nemo_automodel.components.speculative.eagle.peagle_trainer
P-EAGLE (parallel-drafting EAGLE-3) training logic.
Split out of core.py so the P-EAGLE trainer evolves independently of the
EAGLE-3 test-time-training trainer. The shared step-metrics container
(:class:Eagle3StepMetrics) still lives in core.py and is imported here.
Module Contents
Classes
Functions
API
Bases: Module
Draft-side P-EAGLE (parallel-drafting EAGLE-3) trainer module.
Faithful port of speculators’ P-EAGLE
(https://github.com/vllm-project/speculators/pull/480): the draft predicts
all num_depths tokens in a single parallel forward over a flat,
COD-subsampled sequence — it does NOT run EAGLE-3’s autoregressive TTT
recurrence.
Per training sequence:
- COD sampling (:func:generate_cod_sample_indices) draws (anchor_pos, depth) pairs: depth 0 keeps every position, and depth d keeps a geometrically decaying down_sample_ratio**d fraction. The reference position of each element is anchor_pos + depth.
- Flat input assembly. All depths are concatenated into one [1, total_sampled] sequence. Depth-0 slots take the real token id and the fc-projected target aux hidden state; depth >= 1 slots take the mask token id mask_token_id and the single learnable mask_hidden placeholder (projected through the same fc).
- COD flex attention. A single flex_attention forward with the :func:create_peagle_mask_mod block mask: each element attends to the causal depth-0 context of its document plus earlier-or-equal depths of its own rollout. This is exactly what vLLM’s parallel-drafting runtime sees at inference.
- Count-normalized KL loss. KL(target || draft) over the draft vocab at every supervised sampled position, normalized by a single total token count, so deeper depths (fewer COD positions) naturally contribute less gradient. No 0.8**d schedule.
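The COD sampling step can be illustrated with a minimal sketch. This is not the library's generate_cod_sample_indices: the function name sample_cod_indices, its signature, the seeded RNG, the nested subsampling (each depth drawn from the previous depth's anchors), and the rule that anchor_pos + depth must stay inside the sequence are all assumptions made for illustration.

```python
import random


def sample_cod_indices(seq_len, num_depths, down_sample_ratio, seed=0):
    """Hypothetical COD sampler returning (anchor_pos, depth) pairs.

    Assumption: depth d keeps about seq_len * down_sample_ratio**d anchors,
    drawn from the anchors kept at depth d - 1 so each rollout stays contiguous.
    """
    rng = random.Random(seed)
    anchors = list(range(seq_len))          # depth 0 keeps every position
    pairs = [(a, 0) for a in anchors]
    for depth in range(1, num_depths):
        # Assumed validity rule: the reference position anchor + depth
        # must still fall inside the sequence.
        candidates = [a for a in anchors if a + depth < seq_len]
        keep = min(len(candidates), round(seq_len * down_sample_ratio ** depth))
        anchors = sorted(rng.sample(candidates, keep))
        pairs.extend((a, depth) for a in anchors)
    return pairs


pairs = sample_cod_indices(seq_len=16, num_depths=3, down_sample_ratio=0.5)
per_depth = [sum(1 for _, d in pairs if d == k) for k in range(3)]
print(per_depth, len(pairs))
```

With these toy settings the flat sequence holds 16 + 8 + 4 = 28 elements, which is the total_sampled dimension of the assembled input under the assumptions above.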
Batches with batch_size > 1 are processed row-by-row (speculators is
batch-size-1); per-row losses are accumulated with a shared denominator so
the normalization stays count-based across the whole batch.
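A small numeric sketch of why the shared denominator matters. The helper name count_normalized_loss and the (kl_sum, token_count) row representation are hypothetical; only the idea of summing numerators and counts across rows before dividing comes from the text above.

```python
def count_normalized_loss(rows):
    """rows: list of (kl_sum, supervised_token_count), one entry per batch row."""
    total_num = sum(kl_sum for kl_sum, _ in rows)
    total_den = sum(count for _, count in rows)
    # Guard against an all-padding batch (an assumption, not taken from the source).
    return total_num / max(total_den, 1)


rows = [(6.0, 3), (1.0, 1)]                      # a long row and a short row
batch_loss = count_normalized_loss(rows)         # 7.0 / 4 = 1.75
mean_of_row_means = sum(s / c for s, c in rows) / len(rows)  # (2.0 + 1.0) / 2 = 1.5
print(batch_loss, mean_of_row_means)
```

Averaging per-row means would weight the short row's single token as heavily as the long row's three tokens; the shared denominator weights every supervised token equally.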
Sequence partitioning (sequence_partitions > 1). The flat COD
forward attends over n * sum(r**d) positions, so its peak attention /
activation memory grows with the context length and OOMs on long sequences.
P-EAGLE’s Algorithm 1 (arXiv:2602.01469) splits each sequence into S
segments by dependency lineage (:func:assign_cod_segments) and runs a
separate forward+backward per segment so only one segment’s activations are
resident at a time. The partition is exact: each segment additionally reads
every depth-0 position up to its right boundary as key/value context (causal
completion), so a segment’s queries see exactly the key/value set they would
in the single flat forward — the gradients accumulated across segments equal
the single-forward gradient.
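As a worked example of the n * sum(r**d) length of the flat COD forward, the sketch below evaluates it for one setting. The helper name flat_cod_length is hypothetical, and it assumes d ranges over 0 .. num_depths - 1 with r equal to down_sample_ratio; the source does not state the summation range explicitly.

```python
def flat_cod_length(n, num_depths, r):
    """Expected number of positions in the flat COD sequence for context length n."""
    return n * sum(r ** d for d in range(num_depths))


# 8192 context tokens, 4 depths, ratio 0.5:
# 8192 * (1 + 0.5 + 0.25 + 0.125) = 15360 positions in one forward.
print(flat_cod_length(8192, 4, 0.5))
```

The flat length scales linearly with n, so doubling the context doubles the number of positions the single forward has to hold, which is what sequence partitioning is meant to bound.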
The split is caller-driven so the gradient sync stays correct under DDP:
:meth:build_peagle_plan (no-grad) assigns COD elements to segments, then the
recipe runs one forward(..., peagle_segment=(plan, i)) per segment and
owns the backward(). Doing the per-segment backward here (inside a single
forward) would bypass DistributedDataParallel’s reducer — its grad
all-reduce hooks only fire for backwards over the tensor DDP.forward
returns — and silently desynchronize ranks. sequence_partitions == 1 and
eval take the single flat forward unchanged.
Loss for one segment of a :meth:build_peagle_plan plan.
The recipe drives this once per plan.units entry and back-propagates
each result, so each segment owns a self-contained autograd graph that is
freed before the next — and the backward flows through DDP.forward so
gradients all-reduce correctly. metrics.loss is the segment’s share of
the count-normalized batch loss (loss / total_den); summing it over the
plan reproduces the single flat forward’s loss.
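The claim that per-segment shares sum to the flat loss can be checked with a toy calculation. The per-element KL values and the segment ownership lists below are made up for illustration, and segment_shares is a hypothetical helper, not part of the module.

```python
def segment_shares(per_element_kl, segments):
    """Return each segment's share of the count-normalized loss (loss / total_den)."""
    total_den = len(per_element_kl)              # total supervised-token count
    return [sum(per_element_kl[i] for i in seg) / total_den for seg in segments]


per_element_kl = [0.9, 0.4, 0.7, 0.2, 0.5, 0.3]  # toy KL at each supervised element
segments = [[0, 1, 2], [3, 4], [5]]              # elements owned by each segment
flat_loss = sum(per_element_kl) / len(per_element_kl)
shares = segment_shares(per_element_kl, segments)
print(sum(shares), flat_loss)
```

Because every segment divides by the same total_den rather than by its own element count, the shares add up to the flat loss up to floating-point rounding.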
Draft forward + count-normalized KL for one row’s COD elements.
Shared by the single flat forward (one call per row, all sampled
positions charged) and the partitioned forward (one call per segment,
only the segment’s owned/supervised positions charged — the rest ride
along as key/value context). Returns (loss_num, loss_den, correct, valid) as float scalars, where loss_num is the sum of KL over loss_positions
and loss_den is their count; the caller normalizes.
Assign COD elements to segments for the sequence-partitioning path.
Samples COD once per row (the indices must be reused across the segment
forwards), runs Algorithm 1 assignment (:func:assign_cod_segments) plus
causal completion, and emits one unit per non-empty (row, segment)
as (b, anchor, depth, orig_positions, loss_positions). loss_positions
marks the segment’s owned supervised elements (charged loss); the other
elements are depth-0 causal-completion context (key/value only). The shared
total_den is the batch’s total supervised-token count, so each segment’s
loss / total_den sums to the single-forward loss.
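A minimal sketch of segment assignment plus causal completion for a single row, showing why owned queries see the same keys as in the flat forward. Everything here is an assumption for illustration: the visibility rule in visible is one reading of the mask description, contiguous anchor ranges stand in for Algorithm 1's lineage-based assignment, and build_units returns simplified (unit_elements, owned) pairs rather than the real (b, anchor, depth, orig_positions, loss_positions) units.

```python
def visible(query, key):
    """Assumed mask rule: causal depth-0 context plus own rollout at depth <= query depth."""
    (q_anchor, q_depth), (k_anchor, k_depth) = query, key
    causal_context = k_depth == 0 and k_anchor <= q_anchor
    own_rollout = k_anchor == q_anchor and k_depth <= q_depth
    return causal_context or own_rollout


def build_units(elements, seq_len, num_segments):
    """Return (unit_elements, owned) for each non-empty segment of one row."""
    bounds = [seq_len * s // num_segments for s in range(num_segments + 1)]
    units = []
    for lo, hi in zip(bounds, bounds[1:]):
        # Owned elements are charged loss; a whole rollout stays in one segment.
        owned = [e for e in elements if lo <= e[0] < hi]
        if not owned:
            continue
        # Causal completion: earlier depth-0 positions ride along as key/value context.
        context = [e for e in elements if e[1] == 0 and e[0] < lo]
        units.append((context + owned, owned))
    return units


def keys_seen(query, pool):
    return {k for k in pool if visible(query, k)}


# Toy COD sample: 8 depth-0 positions plus a few deeper rollout elements.
elements = [(a, 0) for a in range(8)] + [(1, 1), (4, 1), (6, 1), (4, 2)]
units = build_units(elements, seq_len=8, num_segments=2)
exact = all(
    keys_seen(q, unit) == keys_seen(q, elements)
    for unit, owned in units
    for q in owned
)
print(len(units), exact)
```

Under these assumptions every owned query attends to the same key set inside its unit as it would in the flat sequence, which is the property that makes the partition exact.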
Run the P-EAGLE parallel-drafting loss for one batch.
attention_mask supplies the per-row valid length so padded positions
are excluded from attention (document mask) and from supervision.
peagle_segment selects the sequence-partitioning path: when it is a
(plan, index) pair (built by :meth:build_peagle_plan) this computes
the loss for that one segment only — the recipe calls this once per
segment and owns the backward() so DDP’s gradient sync stays correct.
When None (sequence_partitions == 1 and eval) a single flat
forward over the whole COD sequence returns a grad-carrying loss.
Sequence-partitioning plan: one unit per non-empty (row, segment).
A plain (non-container) class on purpose: DistributedDataParallel only
scatters tensors and built-in containers across its inputs, so passing this
object through forward(peagle_segment=(plan, i)) leaves its tensors intact
(a dict/tuple of tensors would be sliced along dim 0 by DDP’s scatter).
Per-position KL(target || draft) over the draft vocabulary.
Matches speculators’ kl_div_loss: log_softmax the draft logits,
softmax the target logits, and sum the elementwise KL over the vocab
axis. Shapes [*, draft_vocab] -> [*].
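The per-position computation can be written out for a single position in plain Python. This is an illustrative sketch, not the module's tensor implementation: the function names are hypothetical, and it assumes the target logits have already been mapped to the draft vocabulary so both lists have the same length.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one vocabulary vector."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def kl_target_draft(target_logits, draft_logits):
    """KL(target || draft) for one position: sum_v p_v * (log p_v - log q_v)."""
    log_p = log_softmax(target_logits)   # target distribution
    log_q = log_softmax(draft_logits)    # draft distribution
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


# Target is uniform over two tokens; draft puts 0.25 / 0.75 on them.
kl = kl_target_draft([0.0, 0.0], [math.log(0.25), math.log(0.75)])
print(kl)  # 0.5 * ln(4/3), roughly 0.1438
```

A tensor version would apply the same sum over the last (vocabulary) axis, which is what reduces [*, draft_vocab] to [*].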