nemo_automodel.components.speculative.eagle.peagle_data
Conditional-On-Distribution (COD) sampling for P-EAGLE parallel drafting.
P-EAGLE (https://github.com/vllm-project/speculators/pull/480) trains all K
draft depths in a single parallel forward rather than EAGLE-3’s sequential
test-time-training (TTT) unroll. To keep the flattened multi-depth sequence
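affordable, COD subsamples deeper depths with geometric decay: depth 0 keeps all
n positions, depth 1 keeps n * r, depth 2 keeps n * r**2, and so on up to depth
K - 1. The flattened length therefore falls from n * K to n * sum_{i<K} r**i, and the
attention cost drops from O((nK)**2) to O((n * sum_{i<K} r**i)**2).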
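The budget arithmetic above can be checked with a short pure-Python sketch. It assumes per-depth counts are truncated with int(); the rounding the actual port uses is not stated here, and the helper names cod_budget and attention_cost_ratio are hypothetical.

```python
def cod_budget(n: int, num_depths: int, r: float) -> list[int]:
    """Hypothetical helper: positions kept at each depth under geometric decay.

    Depth d keeps n * r**d positions; truncation via int() is an assumption.
    """
    return [int(n * r**d) for d in range(num_depths)]


def attention_cost_ratio(n: int, num_depths: int, r: float) -> float:
    """Quadratic attention cost of the COD-flattened sequence relative to
    the dense n * K flattening, i.e. (total kept / (n * K)) ** 2."""
    kept = sum(cod_budget(n, num_depths, r))
    return (kept / (n * num_depths)) ** 2


# n = 1024 positions, K = 4 depths, r = 0.5
print(cod_budget(1024, 4, 0.5))            # [1024, 512, 256, 128]
print(attention_cost_ratio(1024, 4, 0.5))  # 0.2197265625
```

With r = 0.5 the flattened length is 1920 instead of 4096, so the quadratic attention term shrinks to about 22% of the dense multi-depth cost.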
This is a verbatim port of speculators’ generate_cod_sample_indices so the
on-disk draft trains against the exact distribution vLLM’s parallel-drafting
runtime samples at inference.
API
Assign every COD-sampled element to one of num_segments segments.
Implements the sequence-partitioning assignment of P-EAGLE’s Algorithm 1
(arXiv:2602.01469). A single COD sequence is split into S segments so the
parallel-drafting loss can be accumulated over separate forward/backward
passes, one per segment. This lowers peak attention memory and lets the
draft train on long contexts that would otherwise OOM in the single flat
forward (PEagleTrainerModule forward with sequence_partitions == 1).
The algorithm’s three phases collapse to a closed form here because every COD
element carries its chain-start position anchor_pos and its depth (reference
position anchor_pos + depth):
- Phase 1 (depths 0 and 1, by position).
  A[p] = max{s : B_s <= p} with boundaries B = {0, L/S, ..., L}. A depth-0 element sits at anchor_pos; a depth-1 element sits at anchor_pos + 1.
- Phase 2 (depths >= 2, by dependency).
  A^g[p] = A^{g-1}[p-1]: each deeper element inherits the segment of its parent one position back in the same rollout. Because a rollout shares one anchor_pos across depths, the inheritance chain terminates at the depth-1 ancestor, so every element at depth >= 1 lands in the bucket of anchor_pos + 1.
- Phase 3 (causal completion: each segment also reads every depth-0 position up to its right boundary as key/value context) is applied by the caller when it assembles a segment’s flat input, not encoded in the returned assignment.
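A minimal pure-Python sketch of the closed form for Phases 1 and 2, assuming real-valued boundaries B_s = s * L / S so that max{s : B_s <= p} equals floor(p * S / L). The name assign_segments_sketch is hypothetical, and it works on plain lists rather than the torch.Tensor inputs the documented function takes; Phase 3 stays with the caller.

```python
def assign_segments_sketch(
    anchor_pos: list[int], depth: list[int], seq_len: int, num_segments: int
) -> list[int]:
    """Hypothetical closed-form segment assignment (Phases 1 and 2 only)."""
    if num_segments < 1:
        raise ValueError("num_segments must be >= 1")
    seg_ids = []
    for a, d in zip(anchor_pos, depth):
        # Phase 1: a depth-0 element is placed by anchor_pos itself.
        # Phase 2: every depth >= 1 element inherits the bucket of its
        # depth-1 ancestor, which sits at anchor_pos + 1.
        p = a if d == 0 else a + 1
        # max{s : s * L / S <= p} == floor(p * S / L) for non-negative ints;
        # clamp so a position at or past L still maps to the last segment.
        seg_ids.append(min((p * num_segments) // seq_len, num_segments - 1))
    return seg_ids


# L = 8, S = 2, boundaries B = {0, 4, 8}
print(assign_segments_sketch([0, 3, 3, 3, 5], [0, 0, 1, 2, 1], 8, 2))
# [0, 0, 1, 1, 1]
```

The rollout anchored at position 3 shows the boundary case: its depth-0 element stays in segment 0, while its depth-1 and depth-2 elements land in segment 1 because anchor_pos + 1 = 4 = B_1.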
Parameters:
Chain-start position per element, shape [total_sampled].
COD round per element, shape [total_sampled].
Original (padded) sequence length L.
Number of segments S (>= 1).
Returns: torch.Tensor
Per-element segment id in [0, num_segments), shape [total_sampled].
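Since Phase 3 is left to the caller, one way a caller might consume the returned ids is sketched below: bucket element indices per segment and, for segment s, collect every depth-0 position before the right boundary B_{s+1} as key/value context. The helper build_segment_plan is hypothetical and is not taken from PEagleTrainerModule; it assumes real-valued boundaries B_s = s * L / S and treats all L positions as available depth-0 context.

```python
def build_segment_plan(
    seg_ids: list[int], seq_len: int, num_segments: int
) -> list[dict[str, list[int]]]:
    """Hypothetical per-segment plan for separate forward/backward passes.

    For each segment s it returns:
      members: indices into the COD-sampled elements assigned to s
      context: depth-0 positions before the right boundary B_{s+1}
    """
    plan = []
    for s in range(num_segments):
        members = [i for i, sid in enumerate(seg_ids) if sid == s]
        # p < B_{s+1} = (s + 1) * L / S  <=>  p * S < (s + 1) * L
        context = [p for p in range(seq_len) if p * num_segments < (s + 1) * seq_len]
        plan.append({"members": members, "context": context})
    return plan


for s, seg in enumerate(build_segment_plan([0, 0, 1, 1, 1], seq_len=8, num_segments=2)):
    print(s, seg["members"], seg["context"])
# 0 [0, 1] [0, 1, 2, 3]
# 1 [2, 3, 4] [0, 1, 2, 3, 4, 5, 6, 7]
```

Each entry would then drive one forward/backward pass, with gradients accumulated across segments before the optimizer step; that loop is an assumption about the caller, not something this function does.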
Generate COD sampling indices for one sequence.
Parameters:
Length of the (padded) sequence.
Binary mask of valid training positions, shape [1, seq_len]
or [seq_len]. Padding / unsupervised positions must be 0.
Number of parallel prediction depths K.
Geometric decay ratio r in (0, 1].
Minimum retention ratio floor.
Drop position 0 from every depth >= 1 candidate pool
(it has no preceding token to predict, so its chain-start anchor
would be negative). Keep True unless the caller guarantees
loss_mask[0] == 0.
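How the decay ratio and the retention floor might combine into per-depth sample counts is sketched below. This assumes the floor is applied as max(r**d, floor), that n is the number of nonzero loss_mask entries, and that counts are truncated with int(); none of these details is confirmed by the docstring, and the real port may differ (for instance in how position 0 is excluded). The helper cod_depth_counts and its parameter names are hypothetical.

```python
def cod_depth_counts(
    loss_mask: list[int], num_depths: int, ratio: float, min_ratio: float
) -> list[int]:
    """Hypothetical per-depth sample counts for one sequence."""
    if not 0.0 < ratio <= 1.0:
        raise ValueError("ratio must be in (0, 1]")
    n = sum(1 for m in loss_mask if m)  # valid training positions
    counts = []
    for d in range(num_depths):
        keep = max(ratio**d, min_ratio)  # assumed floor on the retention ratio
        counts.append(int(n * keep))
    return counts


mask = [0] * 4 + [1] * 16  # 4 padded/unsupervised positions, 16 valid
print(cod_depth_counts(mask, num_depths=4, ratio=0.5, min_ratio=0.2))
# [16, 8, 4, 3]
```

At depth 3 the pure geometric term would give 16 * 0.125 = 2, so under this assumption the floor keeps deep depths from being starved of samples.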
Returns: tuple[torch.Tensor, torch.Tensor]
Tuple of:
anchor_pos: Start position in the original sequence each sampled
element’s chain began from, shape [total_sampled].
depth: Which COD round each element belongs to, shape
[total_sampled]. The reference (target) position of an
element is anchor_pos + depth.