nemo_automodel.components.speculative.eagle.peagle_data


Conditional-On-Distribution (COD) sampling for P-EAGLE parallel drafting.

P-EAGLE (https://github.com/vllm-project/speculators/pull/480) trains all K draft depths in a single parallel forward pass rather than through EAGLE-3’s sequential test-time-training (TTT) unroll. To keep the flattened multi-depth sequence affordable, COD subsamples deeper depths with geometric decay: depth 0 keeps all n positions, depth 1 keeps n * r, depth 2 keeps n * r**2, and so on, so the attention cost drops from O((nK)**2) to O((n * sum r**i)**2), where the sum runs over depths i = 0, ..., K-1.
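The saving from the geometric decay can be checked with a few lines of arithmetic. The sketch below is an illustration only: the helper names are hypothetical, and it ignores integer rounding and the down_sample_ratio_min floor, both of which the real sampler may apply.

```python
def cod_relative_lengths(num_depths: int, ratio: float) -> list[float]:
    """Fraction of the n positions kept at each depth: 1, r, r**2, ... (hypothetical helper)."""
    return [ratio ** d for d in range(num_depths)]


def attention_cost_ratio(num_depths: int, ratio: float) -> float:
    """Quadratic attention cost of the COD sequence relative to the full n*K sequence."""
    sampled = sum(cod_relative_lengths(num_depths, ratio))
    return (sampled / num_depths) ** 2


# Defaults documented for generate_cod_sample_indices: K = 8, r = 0.7.
fractions = cod_relative_lengths(8, 0.7)
total = sum(fractions)               # flattened length in units of n
cost = attention_cost_ratio(8, 0.7)  # relative quadratic cost
print(f"flattened length = {total:.2f} * n, relative attention cost = {cost:.2f}")
```

Under these idealized assumptions the flattened sequence is about 3.14 n long instead of 8 n, which puts the quadratic attention cost at roughly 15% of the unsampled case.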

This is a verbatim port of speculators’ generate_cod_sample_indices so the on-disk draft trains against the exact distribution vLLM’s parallel-drafting runtime samples at inference.

Module Contents

Functions

| Name | Description |
| --- | --- |
| assign_cod_segments | Assign every COD-sampled element to one of num_segments segments. |
| generate_cod_sample_indices | Generate COD sampling indices for one sequence. |

API

nemo_automodel.components.speculative.eagle.peagle_data.assign_cod_segments(
anchor_pos: torch.Tensor,
depth: torch.Tensor,
seq_length: int,
num_segments: int
) -> torch.Tensor

Assign every COD-sampled element to one of num_segments segments.

Implements the sequence-partitioning assignment of P-EAGLE’s Algorithm 1 (arXiv:2602.01469). A single COD sequence is split into S segments so the parallel-drafting loss can be accumulated with separate forward/backward passes, one per segment. This lowers peak attention memory and lets the draft train on long contexts that would otherwise run out of memory (OOM) in the single flat forward pass (the PEagleTrainerModule forward with sequence_partitions == 1).

The algorithm’s three phases collapse to a closed form here because every COD element carries its chain start anchor_pos and depth (reference position anchor_pos + depth):

  • Phase 1 (depths 0 and 1, by position). A[p] = max{s : B_s <= p} with boundaries B = {0, L/S, ..., L}. A depth-0 element sits at anchor_pos; a depth-1 element at anchor_pos + 1.
  • Phase 2 (depths >= 2, by dependency). A^g[p] = A^{g-1}[p-1]: each deeper element inherits the segment of its parent one position back in the same rollout. Because a rollout shares one anchor_pos across depths, the inheritance chain terminates at the depth-1 ancestor, so every element at depth >= 1 lands in the bucket of anchor_pos + 1.

Phase 3 (causal completion — each segment also reads every depth-0 position up to its right boundary as key/value context) is applied by the caller when it assembles a segment’s flat input, not encoded in the returned assignment.
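The closed form described above can be sketched in plain Python. This is not the library's implementation: it uses lists instead of torch tensors, the function name is hypothetical, and it assumes boundaries B_s = s * L / S (so the bucket of position p is floor(p * S / L)) and clamps a position equal to L into the last segment.

```python
def assign_segments_sketch(
    anchor_pos: list[int],
    depth: list[int],
    seq_length: int,
    num_segments: int,
) -> list[int]:
    """Closed-form Phase 1 + Phase 2 assignment (illustrative sketch only)."""
    segments = []
    for a, d in zip(anchor_pos, depth):
        # Depth 0 is bucketed by its own position; every depth >= 1 element
        # inherits the bucket of its depth-1 ancestor at anchor_pos + 1.
        p = a if d == 0 else a + 1
        # max{s : s * L / S <= p} equals floor(p * S / L).
        s = (p * num_segments) // seq_length
        # Assumption: a position that reaches L falls into the last segment.
        segments.append(min(s, num_segments - 1))
    return segments


# Hand-written toy input: L = 8 and S = 2, so the only interior boundary is 4.
anchor = [0, 3, 4, 3, 3, 7]
depths = [0, 0, 0, 1, 2, 1]
segs = assign_segments_sketch(anchor, depths, seq_length=8, num_segments=2)
print(segs)
```

In this toy input the depth-1 and depth-2 elements of the rollout anchored at 3 both land in segment 1, because their shared depth-1 ancestor sits at position 4, while the depth-0 element at position 3 stays in segment 0.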

Parameters:

anchor_pos
torch.Tensor

Chain-start position per element, shape [total_sampled].

depth
torch.Tensor

COD round per element, shape [total_sampled].

seq_length
int

Original (padded) sequence length L.

num_segments
int

Number of segments S (>= 1).

Returns: torch.Tensor

Per-element segment id in [0, num_segments), shape [total_sampled].

nemo_automodel.components.speculative.eagle.peagle_data.generate_cod_sample_indices(
seq_length: int,
loss_mask: torch.Tensor,
num_depths: int = 8,
down_sample_ratio: float = 0.7,
down_sample_ratio_min: float = 0.2,
filter_position_zero: bool = True
) -> tuple[torch.Tensor, torch.Tensor]

Generate COD sampling indices for one sequence.

Parameters:

seq_length
int

Length of the (padded) sequence.

loss_mask
torch.Tensor

Binary mask of valid training positions, shape [1, seq_len] or [seq_len]. Padding / unsupervised positions must be 0.

num_depths
int, defaults to 8

Number of parallel prediction depths K.

down_sample_ratio
float, defaults to 0.7

Geometric decay ratio r in (0, 1].

down_sample_ratio_min
float, defaults to 0.2

Minimum retention ratio floor.

filter_position_zero
bool, defaults to True

Drop position 0 from every depth>=1 candidate pool (it has no preceding token to predict, so its chain-start anchor would be negative). Keep True unless the caller guarantees loss_mask[0] == 0.

Returns: tuple[torch.Tensor, torch.Tensor]

A tuple (anchor_pos, depth):

  • anchor_pos: start position in the original sequence from which each sampled element’s chain began, shape [total_sampled].
  • depth: the COD round each element belongs to, shape [total_sampled].

The reference (target) position of an element is anchor_pos + depth.
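To make the relation between the two returned tensors concrete, the sketch below recovers reference positions and per-depth counts from a pair of hand-written lists. The values are invented for illustration and are not real output of generate_cod_sample_indices, which returns torch tensors rather than lists.

```python
from collections import Counter

# Hypothetical stand-ins for the returned (anchor_pos, depth) pair.
anchor_pos = [1, 2, 3, 1, 3, 1]
depth = [0, 0, 0, 1, 1, 2]

# Reference (target) position of each sampled element: anchor_pos + depth.
target_pos = [a + d for a, d in zip(anchor_pos, depth)]

# Number of sampled elements per depth; under COD this shrinks as depth grows.
per_depth = Counter(depth)

print(target_pos)
print(dict(per_depth))
```

With real tensors the same relation is a single elementwise addition, anchor_pos + depth.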