nemo_automodel.components.speculative.eagle.peagle_trainer


P-EAGLE (parallel-drafting EAGLE-3) training logic.

Split out of core.py so that the P-EAGLE trainer can evolve independently of the EAGLE-3 test-time-training (TTT) trainer. The shared step-metrics container (`Eagle3StepMetrics`) still lives in core.py and is imported here.

Module Contents

Classes

| Name | Description |
| --- | --- |
| `PEagleTrainerModule` | Draft-side P-EAGLE (parallel-drafting EAGLE-3) trainer module. |
| `_PeaglePlan` | Sequence-partitioning plan: one unit per non-empty (row, segment). |

Functions

| Name | Description |
| --- | --- |
| `_kl_div_loss` | Per-position KL(target \|\| draft) over the draft vocabulary. |

API

class nemo_automodel.components.speculative.eagle.peagle_trainer.PEagleTrainerModule(
draft_model: torch.nn.Module,
selected_token_ids: torch.Tensor,
selected_token_mask: torch.Tensor,
num_depths: int,
mask_token_id: int,
down_sample_ratio: float = 0.7,
down_sample_ratio_min: float = 0.2,
sequence_partitions: int = 1
)

Bases: Module

Draft-side P-EAGLE (parallel-drafting EAGLE-3) trainer module.

Faithful port of speculators’ P-EAGLE (https://github.com/vllm-project/speculators/pull/480): the draft predicts all num_depths tokens in a single parallel forward over a flat, COD-subsampled sequence. It does NOT run EAGLE-3’s autoregressive TTT recurrence.

Per training sequence:

  1. COD sampling (`generate_cod_sample_indices`) draws (anchor_pos, depth) pairs: depth 0 keeps every position, and depth d keeps a geometrically decaying down_sample_ratio**d fraction. The reference position of each element is anchor_pos + depth.
  2. Flat input assembly. All depths are concatenated into one [1, total_sampled] sequence. Depth-0 slots take the real token id and the fc-projected target aux hidden state; depth >= 1 slots take mask_token_id and the single learnable mask_hidden placeholder (projected through the same fc).
  3. COD flex attention. A single flex_attention forward uses the `create_peagle_mask_mod` block mask: each element attends to the causal depth-0 context of its document plus the earlier-or-equal depths of its own rollout. This is exactly the attention pattern vLLM’s parallel-drafting runtime sees at inference.
  4. Count-normalized KL loss. KL(target || draft) is computed over the draft vocab at every supervised sampled position and normalized by a single total token count, so deeper depths (which have fewer COD positions) naturally contribute less gradient. No 0.8**d per-depth weighting schedule is applied.
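Steps 1 to 3 can be sketched in plain Python for a single row. This is a minimal illustration under stated assumptions, not the library code: `sample_cod`, `assemble_flat_ids` and `can_attend` are hypothetical names, the sampler assumes each depth keeps a random `ratio` fraction of the previous depth's anchors (ignoring down_sample_ratio_min), and the mask predicate is one reading of the "causal depth-0 context plus earlier-or-equal depths of its own rollout" rule for a single document.

```python
import random


def sample_cod(seq_len, num_depths, ratio, seed=0):
    """Hypothetical COD sampler returning (anchor_pos, depth) pairs.

    Assumption: depth d keeps a random `ratio` fraction of the anchors that
    survived depth d - 1 (so roughly ratio**d of all positions overall),
    limited to anchors whose reference position anchor_pos + d is in range.
    """
    rng = random.Random(seed)
    anchors = list(range(seq_len))          # depth 0 keeps every position
    elements = [(a, 0) for a in anchors]
    for d in range(1, num_depths):
        eligible = [a for a in anchors if a + d < seq_len]
        keep = round(len(eligible) * ratio)
        anchors = sorted(rng.sample(eligible, keep))
        elements.extend((a, d) for a in anchors)
    return elements


def assemble_flat_ids(input_ids, elements, mask_token_id):
    # Depth 0 uses the real token at its position; depth >= 1 uses the mask token.
    # (Hidden states follow the same rule: aux hidden state vs. a learned placeholder.)
    return [input_ids[a] if d == 0 else mask_token_id for a, d in elements]


def can_attend(q, kv, elements):
    """Hypothetical single-document mask predicate over flat element indices."""
    q_anchor, q_depth = elements[q]
    kv_anchor, kv_depth = elements[kv]
    # Depth-0 prefix up to (and including) the query's anchor position.
    causal_context = kv_depth == 0 and kv_anchor <= q_anchor
    # Earlier-or-equal depths of the query's own rollout (same anchor).
    own_rollout = kv_anchor == q_anchor and kv_depth <= q_depth
    return causal_context or own_rollout
```

In PyTorch's FlexAttention the same rule would be written as a `mask_mod(b, h, q_idx, kv_idx)` over index tensors and compiled with `torch.nn.attention.flex_attention.create_block_mask`; the per-element anchor and depth would be looked up from tensors instead of a Python list.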

Batches with batch_size > 1 are processed row by row (the speculators implementation assumes batch size 1); per-row losses are accumulated with a shared denominator, so the normalization stays count-based across the whole batch.
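The difference between a shared denominator and averaging per-row means can be seen with plain lists standing in for tensors. This is a hypothetical sketch: `count_normalized_loss` mirrors the count-based normalization described above, while `mean_of_row_means` is shown only as the contrasting alternative that the module does not use.

```python
def count_normalized_loss(per_row_kl):
    """per_row_kl: one list per batch row with the KL value at each supervised position."""
    loss_num = sum(sum(row) for row in per_row_kl)  # sum of KL over the whole batch
    loss_den = sum(len(row) for row in per_row_kl)  # one shared token count
    return loss_num / loss_den if loss_den else 0.0


def mean_of_row_means(per_row_kl):
    # Alternative NOT used here: every non-empty row gets equal weight regardless of length.
    rows = [row for row in per_row_kl if row]
    return sum(sum(row) / len(row) for row in rows) / len(rows)


rows = [[1.0, 1.0, 1.0], [4.0]]
count_normalized_loss(rows)  # 7 / 4 = 1.75
mean_of_row_means(rows)      # (1.0 + 4.0) / 2 = 2.5
```

With the shared denominator, every supervised token carries the same weight, so a short row cannot dominate the batch loss.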

Sequence partitioning (sequence_partitions > 1). The flat COD forward attends over roughly n * sum(r**d) positions (n = sequence length, r = down_sample_ratio, d ranging over the num_depths depths), so its peak attention and activation memory grow with the context length and it runs out of memory on long sequences. P-EAGLE’s Algorithm 1 (arXiv:2602.01469) splits each sequence into S segments by dependency lineage (`assign_cod_segments`) and runs a separate forward+backward per segment, so only one segment’s activations are resident at a time. The partition is exact: each segment additionally reads every depth-0 position up to its right boundary as key/value context (causal completion), so a segment’s queries see exactly the key/value set they would see in the single flat forward. As a result, the gradients accumulated across segments equal the single-forward gradient.
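The n * sum(r**d) size can be computed directly. This sketch treats r as a fixed down_sample_ratio and ignores down_sample_ratio_min and any rounding or end-of-row truncation in the real sampler, so it gives an approximate expected length rather than an exact count.

```python
def flat_cod_length(n, ratio, num_depths):
    # Expected flat COD length: n * (ratio**0 + ratio**1 + ... + ratio**(K-1)).
    return n * sum(ratio ** d for d in range(num_depths))


def flat_cod_length_closed_form(n, ratio, num_depths):
    # Same geometric series in closed form (valid for ratio != 1).
    return n * (1 - ratio ** num_depths) / (1 - ratio)


# n = 4096, r = 0.7, K = 4: 4096 * (1 + 0.7 + 0.49 + 0.343) = 4096 * 2.533
flat_cod_length(4096, 0.7, 4)
```

Because the series is bounded by 1 / (1 - r), the flat length stays linear in n no matter how many depths are drafted; it is the factor n itself that makes long contexts expensive, which is what partitioning into S segments addresses.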

The split is caller-driven so that gradient synchronization stays correct under DDP: `build_peagle_plan` (no-grad) assigns COD elements to segments, then the recipe runs one forward(..., peagle_segment=(plan, i)) per segment and owns the backward(). Running the per-segment backward here, inside a single forward, would bypass DistributedDataParallel’s reducer (its gradient all-reduce hooks only fire for backward passes over the tensor that DDP.forward returns) and silently desynchronize ranks. With sequence_partitions == 1, and during eval, the module takes the single flat forward unchanged.
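A recipe-side loop following this protocol might look like the sketch below. It is hypothetical: `peagle_train_step` and the batch key names are illustrative, and the real recipe may differ (for example in how it handles logging or ranks with different unit counts). Only `build_peagle_plan`, `forward(..., peagle_segment=(plan, i))` and `plan.units` come from this page; `ddp_model.module` is the standard way to reach the module wrapped by DistributedDataParallel.

```python
def peagle_train_step(ddp_model, batch, optimizer):
    """Hypothetical training step for sequence_partitions > 1.

    ddp_model: a DistributedDataParallel-wrapped PEagleTrainerModule.
    batch: dict holding the forward() inputs (key names are assumptions).
    """
    trainer = ddp_model.module                            # unwrapped module
    plan = trainer.build_peagle_plan(batch["loss_mask"])  # no-grad COD sampling + assignment
    total_loss = 0.0
    for i in range(len(plan.units)):
        # Call the DDP wrapper (not the bare module) so its reducer hooks
        # observe the backward pass below and all-reduce the gradients.
        metrics = ddp_model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            loss_mask=batch["loss_mask"],
            aux_hidden_states=batch["aux_hidden_states"],
            target_logits=batch["target_logits"],
            peagle_segment=(plan, i),
        )
        metrics.loss.backward()          # frees this segment's graph before the next one
        total_loss += metrics.loss.item()  # segment shares sum to the flat-forward loss
    optimizer.step()
    optimizer.zero_grad()
    return total_loss
```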

down_sample_ratio = float(down_sample_ratio)
down_sample_ratio_min = float(down_sample_ratio_min)
mask_token_id = int(mask_token_id)
sequence_partitions = int(sequence_partitions)

nemo_automodel.components.speculative.eagle.peagle_trainer.PEagleTrainerModule._forward_peagle_segment(
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
aux_hidden_states: torch.Tensor,
target_logits: torch.Tensor,
peagle_segment: tuple
) -> nemo_automodel.components.speculative.eagle.core.Eagle3StepMetrics

Loss for one segment of a plan built by `build_peagle_plan`.

The recipe drives this once per plan.units entry and back-propagates each result, so each segment owns a self-contained autograd graph that is freed before the next one, and the backward flows through DDP.forward so gradients are all-reduced correctly. metrics.loss is the segment’s share of the count-normalized batch loss (loss / total_den); summing it over the plan reproduces the single flat forward’s loss.

nemo_automodel.components.speculative.eagle.peagle_trainer.PEagleTrainerModule._peagle_position_loss(
input_ids_row: torch.Tensor,
aux_row: torch.Tensor,
target_logits_row: torch.Tensor,
anchor_pos: torch.Tensor,
depth: torch.Tensor,
orig_positions: torch.Tensor,
loss_positions: torch.Tensor,
row_length: torch.Tensor,
seq_len: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

Draft forward + count-normalized KL for one row’s COD elements.

Shared by the single flat forward (one call per row; all sampled positions are charged) and the partitioned forward (one call per segment; only the segment’s owned, supervised positions are charged, and the rest ride along as key/value context). Returns (loss_num, loss_den, correct, valid) as float scalars, where loss_num is Σ KL over loss_positions and loss_den is their count; the caller normalizes.

nemo_automodel.components.speculative.eagle.peagle_trainer.PEagleTrainerModule.build_peagle_plan(
loss_mask: torch.Tensor
) -> '_PeaglePlan'

Assign COD elements to segments for the sequence-partitioning path.

Samples COD once per row (the indices must be reused across the segment forwards), runs the Algorithm 1 assignment (`assign_cod_segments`) plus causal completion, and emits one unit per non-empty (row, segment) as (b, anchor, depth, orig_positions, loss_positions). loss_positions marks the segment’s owned supervised elements, which are charged loss; the other elements are depth-0 causal-completion context (key/value only). The shared total_den is the batch’s total supervised-token count, so the per-segment values loss / total_den sum to the single-forward loss.
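Ownership plus causal completion can be illustrated with a deliberately simplified segment assignment. In this hypothetical sketch, `build_units` gives segment s every COD element whose anchor falls in the s-th contiguous slice of the row; the real `assign_cod_segments` assigns by dependency lineage instead. It shows the two properties stated above: every element is charged loss in exactly one unit, and earlier depth-0 positions join a segment as key/value-only context.

```python
def build_units(elements, num_segments, seq_len):
    """Hypothetical, simplified plan builder for one row.

    elements: list of (anchor_pos, depth) pairs in flat COD order.
    Returns (segment, element_indices, loss_flags) per non-empty segment.
    """
    bounds = [round(s * seq_len / num_segments) for s in range(num_segments + 1)]
    units = []
    for s in range(num_segments):
        lo, hi = bounds[s], bounds[s + 1]
        # Owned elements: anchors inside [lo, hi); these are charged loss.
        owned = [i for i, (a, d) in enumerate(elements) if lo <= a < hi]
        if not owned:
            continue  # only non-empty segments become units
        # Causal completion: depth-0 positions before the segment are key/value context.
        context = [i for i, (a, d) in enumerate(elements) if d == 0 and a < lo]
        indices = context + owned
        loss_flags = [False] * len(context) + [True] * len(owned)
        units.append((s, indices, loss_flags))
    return units
```

Together with the owned depth-0 positions, the context covers every depth-0 position up to the segment's right boundary, which is what lets each owned query see the same keys and values it would see in the single flat forward.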

nemo_automodel.components.speculative.eagle.peagle_trainer.PEagleTrainerModule.forward(
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
loss_mask: torch.Tensor,
aux_hidden_states: torch.Tensor,
target_logits: torch.Tensor,
peagle_segment: tuple | None = None
) -> nemo_automodel.components.speculative.eagle.core.Eagle3StepMetrics

Run the P-EAGLE parallel-drafting loss for one batch.

attention_mask supplies the per-row valid length so padded positions are excluded from attention (document mask) and from supervision.

peagle_segment selects the sequence-partitioning path: when it is a (plan, index) pair built by `build_peagle_plan`, this computes the loss for that one segment only; the recipe calls this once per segment and owns the backward() so that DDP’s gradient synchronization stays correct. When it is None (sequence_partitions == 1, and during eval), a single flat forward over the whole COD sequence returns a grad-carrying loss.

class nemo_automodel.components.speculative.eagle.peagle_trainer._PeaglePlan(
units: list[tuple],
total_den: torch.Tensor
)

Sequence-partitioning plan: one unit per non-empty (row, segment).

A plain (non-container) class on purpose: DistributedDataParallel only scatters tensors and built-in containers across its inputs, so passing this object through forward(peagle_segment=(plan, i)) leaves its tensors intact (a dict/tuple of tensors would be sliced along dim 0 by DDP’s scatter).

__slots__ = ('units', 'total_den')

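The reason for a plain class can be shown with a toy model of scatter-style input handling. Everything here is hypothetical: `PlanSketch` stands in for `_PeaglePlan`, `FakeTensor` stands in for a tensor, and `toy_scatter` is a simplified model of the behavior described above (recurse into built-in containers, split tensors along dim 0, pass any other object through unchanged), not DDP's actual code.

```python
class PlanSketch:
    """Stand-in for _PeaglePlan: a plain class with __slots__, not a container."""
    __slots__ = ("units", "total_den")

    def __init__(self, units, total_den):
        self.units = units
        self.total_den = total_den


class FakeTensor:
    """Minimal tensor stand-in: a list of rows along dim 0."""
    def __init__(self, rows):
        self.rows = list(rows)


def toy_scatter(obj, num_chunks):
    if isinstance(obj, FakeTensor):  # tensors are split along dim 0
        k = len(obj.rows) // num_chunks
        return [FakeTensor(obj.rows[i * k:(i + 1) * k]) for i in range(num_chunks)]
    if isinstance(obj, (list, tuple)):  # built-in sequences are recursed into
        parts = [toy_scatter(x, num_chunks) for x in obj]
        return [type(obj)(p[i] for p in parts) for i in range(num_chunks)]
    if isinstance(obj, dict):  # dicts are recursed into as well
        parts = {key: toy_scatter(v, num_chunks) for key, v in obj.items()}
        return [{key: v[i] for key, v in parts.items()} for i in range(num_chunks)]
    return [obj] * num_chunks  # opaque objects are passed through untouched
```

Passing `(PlanSketch(...), i)` through `toy_scatter` keeps the plan's tensors whole, whereas a dict holding the same tensors comes out sliced.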
nemo_automodel.components.speculative.eagle.peagle_trainer._kl_div_loss(
logits: torch.Tensor,
target_logits: torch.Tensor
) -> torch.Tensor

Per-position KL(target || draft) over the draft vocabulary.

Matches speculators’ kl_div_loss: log_softmax the draft logits, softmax the target logits, and sum the elementwise KL over the vocab axis. Shapes [*, draft_vocab] -> [*].
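The per-position formula can be written out in pure Python for a single position. This is a sketch of the math, not the module's code: `log_softmax` and `kl_div_loss` here are illustrative helpers. In PyTorch the same quantity is `torch.nn.functional.kl_div(draft_logits.log_softmax(-1), target_logits.softmax(-1), reduction="none").sum(-1)`, though the actual implementation may be written differently.

```python
import math


def log_softmax(xs):
    # Numerically stable log-softmax: subtract the max before exponentiating.
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]


def kl_div_loss(draft_logits, target_logits):
    """KL(target || draft) at one position: sum over v of p(v) * (log p(v) - log q(v))."""
    log_q = log_softmax(draft_logits)   # draft log-probabilities
    log_p = log_softmax(target_logits)  # target log-probabilities
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
```

For a [*, draft_vocab] input, apply it row by row, for example `[kl_div_loss(d, t) for d, t in zip(draft_rows, target_rows)]`, which yields one value per position as in the [*] output shape.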