nemo_automodel.components.speculative.eagle.core
Core EAGLE-3 training logic for the minimal Llama MVP.
Module Contents
Classes
Functions
API
Aggregated metrics from one EAGLE-3 training step.
Bases: Module
Draft-side EAGLE-3 trainer module with test-time-training unroll.
Run the EAGLE-3 unrolled draft loss for one batch.
The attention layer is driven through a shared cache_hidden
list so each TTT step can attend to the K/V branches produced by
every previous step at the same position. This matches the
SpecForge llama3_eagle.py recurrence; without it, multi-step
TTT would degenerate into ttt_steps independent single-step
passes and the draft would never learn the multi-step
distribution it sees at deployment time.
attention_mask is held constant across TTT steps; only
input_ids, loss_mask, position_mask, and
target_probs roll forward by one position per step.
Packing: position_ids and seq_lens make the draft’s Block-1 attention
document-level block-causal, and doc_remaining gates supervision per
step: slot t is valid at step k only while k < doc_remaining[t],
which masks every TTT prediction that would cross a document boundary.
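The per-step gate described above can be sketched in a few lines. This is an illustration only, using plain Python lists in place of tensors. Both helper names are hypothetical, and the way doc_remaining is derived from seq_lens here (the number of tokens that follow a slot inside its own document) is an assumption, not something this page confirms.

```python
def doc_remaining_from_seq_lens(seq_lens):
    """Hypothetical helper: per-slot count of tokens left in the same document.

    ASSUMPTION: doc_remaining[t] is the number of tokens after slot t within
    its document, so the last slot of every document gets 0.
    """
    remaining = []
    for n in seq_lens:
        remaining.extend(range(n - 1, -1, -1))
    return remaining


def step_valid_mask(doc_remaining, ttt_steps):
    """Row k marks the slots still supervised at TTT step k.

    Implements the rule from the text: slot t is valid at step k
    only while k < doc_remaining[t].
    """
    return [
        [1 if k < r else 0 for r in doc_remaining]
        for k in range(ttt_steps)
    ]


# Two packed documents of lengths 3 and 2.
doc_remaining = doc_remaining_from_seq_lens([3, 2])
mask = step_valid_mask(doc_remaining, ttt_steps=3)
```

Under that assumption, the last slot of each document is never supervised, and deeper TTT steps progressively drop the slots closest to a document boundary.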
Two supervision sources are accepted: the live path passes the
target’s full-vocab target_logits and the draft distribution is
derived here; the offline-cache path (precompute_eagle3) passes the
already-derived target_probs (over the draft vocab) and
position_mask directly, so the full-vocab logits never have to be
stored. Provide exactly one of the two.
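A minimal sketch of the "exactly one of the two" contract follows. The function name resolve_supervision and the error messages are hypothetical, and requiring position_mask alongside target_probs is an assumption inferred from the description of the offline-cache path; the actual argument checks in the module may differ.

```python
def resolve_supervision(target_logits=None, target_probs=None, position_mask=None):
    """Hypothetical check that exactly one supervision source was supplied.

    Returns "live" when full-vocab target_logits are given and "offline"
    when precomputed target_probs (plus position_mask) are given.
    """
    has_live = target_logits is not None
    has_cached = target_probs is not None
    if has_live == has_cached:
        # Both or neither were provided.
        raise ValueError("provide exactly one of target_logits or target_probs")
    if has_cached and position_mask is None:
        # ASSUMPTION: the offline path always carries its own position_mask.
        raise ValueError("target_probs requires position_mask")
    return "live" if has_live else "offline"
```

Calling it with only target_logits selects the live path; calling it with target_probs and position_mask selects the offline-cache path.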
Project target logits into draft vocabulary space and build the supervision mask.
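The projection can be illustrated with a dependency-free sketch. It assumes the draft vocabulary is a subset of target-vocabulary ids (keep_ids, a hypothetical name), that the draft distribution is a softmax over only those logits, and that a position is supervised only when the target's top token falls inside the draft vocabulary. That is one common EAGLE-3 convention; this page does not confirm that the function works exactly this way.

```python
import math


def project_to_draft_vocab(target_logits, keep_ids, loss_mask):
    """Hypothetical sketch: full-vocab logits -> draft-vocab probs + mask.

    target_logits: one list of full-vocab logits per position.
    keep_ids:      target-vocab ids that make up the draft vocab (assumed).
    loss_mask:     0/1 per position.
    """
    keep = set(keep_ids)
    target_probs, position_mask = [], []
    for logits, allowed in zip(target_logits, loss_mask):
        # Softmax restricted to the draft vocabulary (max-subtracted for stability).
        sub = [logits[i] for i in keep_ids]
        peak = max(sub)
        exps = [math.exp(x - peak) for x in sub]
        total = sum(exps)
        target_probs.append([e / total for e in exps])
        # ASSUMPTION: supervise only where the target's argmax is a draft token.
        top = max(range(len(logits)), key=logits.__getitem__)
        position_mask.append(1 if (allowed and top in keep) else 0)
    return target_probs, position_mask


probs, pos_mask = project_to_draft_vocab(
    target_logits=[[0.0, 2.0, 1.0, 5.0], [3.0, 0.0, 1.0, 0.0]],
    keep_ids=[0, 2],
    loss_mask=[1, 1],
)
```

In this toy input the first position's top token (id 3) is outside the draft vocabulary, so it is masked out, while the second position (top token id 0) stays supervised.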
Shift a batched sequence tensor left and zero-fill the tail.
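As a rough picture of the shift, the sketch below works on a batch of plain Python lists rather than tensors. The name shift_left and the fill parameter are hypothetical; the real function operates on batched tensors and its signature is not shown on this page.

```python
def shift_left(batch, fill=0):
    """Shift every row one position left along the sequence axis.

    The first element of each row is dropped and the tail is padded
    with `fill`, so row lengths are unchanged.
    """
    return [row[1:] + [fill] if row else [] for row in batch]


shifted = shift_left([[1, 2, 3], [4, 5, 6]])
```

Applying the shift once per TTT step is what rolls input_ids, loss_mask, position_mask, and target_probs forward by one position.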