nemo_automodel.components.speculative.eagle.core

Core EAGLE-3 training logic for the minimal Llama MVP.

Module Contents

Classes

Name | Description
Eagle3StepMetrics | Aggregated metrics from one EAGLE-3 training step.
Eagle3TrainerModule | Draft-side EAGLE-3 trainer module with test-time-training unroll.

Functions

Name | Description
_compute_target_distribution | Project target logits into draft vocabulary space and build supervision mask.
_shift_left_with_zero | Shift a batched sequence tensor left and zero-fill the tail.

API

class nemo_automodel.components.speculative.eagle.core.Eagle3StepMetrics(
loss: torch.Tensor,
accuracy: torch.Tensor,
valid_tokens: torch.Tensor
)
Dataclass

Aggregated metrics from one EAGLE-3 training step.

accuracy: Tensor
loss: Tensor
valid_tokens: Tensor
class nemo_automodel.components.speculative.eagle.core.Eagle3TrainerModule(
draft_model: torch.nn.Module,
selected_token_ids: torch.Tensor,
selected_token_mask: torch.Tensor,
ttt_steps: int
)

Bases: Module

Draft-side EAGLE-3 trainer module with test-time-training unroll.

nemo_automodel.components.speculative.eagle.core.Eagle3TrainerModule.forward(
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
loss_mask: torch.Tensor,
aux_hidden_states: torch.Tensor,
target_logits: torch.Tensor | None = None,
target_probs: torch.Tensor | None = None,
position_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
seq_lens: torch.Tensor | None = None,
doc_remaining: torch.Tensor | None = None
) -> nemo_automodel.components.speculative.eagle.core.Eagle3StepMetrics

Run the EAGLE-3 unrolled draft loss for one batch.

The attention layer is driven through a shared cache_hidden list so each TTT step can attend to the K/V branches produced by every previous step at the same position. This matches the SpecForge llama3_eagle.py recurrence; without it, multi-step TTT would degenerate into ttt_steps independent single-step passes and the draft would never learn the multi-step distribution it sees at deployment time.

attention_mask is held constant across TTT steps. Only input_ids, loss_mask, position_mask, and target_probs roll forward by one position per step.

Packing: position_ids and seq_lens make the draft’s Block-1 attention document-level block-causal. doc_remaining gates supervision per step: slot t is valid at step k only while k < doc_remaining[t], which masks every cross-document TTT prediction.
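The per-step gating rule can be illustrated with a minimal sketch. The helper name step_valid_mask and the example doc_remaining values are hypothetical and chosen only for illustration; the sketch shows the comparison k < doc_remaining[t] and is not the module's actual code.

```python
import torch


def step_valid_mask(doc_remaining: torch.Tensor, step: int) -> torch.Tensor:
    """Hypothetical helper: slot t is supervised at TTT step `step`
    only while step < doc_remaining[t]."""
    return doc_remaining > step


# Illustrative values for one packed row holding two documents
# (assumed layout: 3 slots for the first document, 2 for the second).
doc_remaining = torch.tensor([[3, 2, 1, 2, 1]])

# One boolean mask per TTT step k = 0, 1, 2.
masks = [step_valid_mask(doc_remaining, k) for k in range(3)]

# In a training loop the mask would be combined with the (rolled) loss mask.
loss_mask = torch.ones(1, 5)
gated_step1 = loss_mask * masks[1].to(loss_mask.dtype)
```

With these values every slot is supervised at step 0, and fewer slots remain valid at each later step as predictions would cross a document boundary.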

Two supervision sources are accepted. The live path passes the target’s full-vocab target_logits, and the target distribution over the draft vocabulary is derived here. The offline-cache path (precompute_eagle3) passes the already-derived target_probs (over the draft vocab) and position_mask directly, so the full-vocab logits never have to be stored. Provide exactly one of the two.

nemo_automodel.components.speculative.eagle.core._compute_target_distribution(
target_logits: torch.Tensor,
selected_token_ids: torch.Tensor,
selected_token_mask: torch.Tensor,
loss_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]

Project target logits into draft vocabulary space and build supervision mask.
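One plausible reading of this projection is sketched below. It is an assumption modeled on the vocabulary-mapping pattern common in EAGLE-3 training code, not the function's confirmed implementation. The sketch assumes selected_token_ids holds the full-vocab ids kept in the draft vocabulary, selected_token_mask is a boolean mask over the full vocabulary, and the return order is (target_probs, position_mask). The name compute_target_distribution_sketch is hypothetical.

```python
import torch


def compute_target_distribution_sketch(
    target_logits: torch.Tensor,        # (batch, seq, full_vocab)
    selected_token_ids: torch.Tensor,   # (draft_vocab,) full-vocab ids, assumed
    selected_token_mask: torch.Tensor,  # (full_vocab,) bool, assumed
    loss_mask: torch.Tensor,            # (batch, seq)
) -> tuple[torch.Tensor, torch.Tensor]:
    # Supervise a position only if the target's top token exists in the draft vocab.
    top_token = target_logits.argmax(dim=-1)
    in_draft_vocab = selected_token_mask[top_token]
    position_mask = in_draft_vocab.to(loss_mask.dtype) * loss_mask

    # Keep only the draft-vocab columns, then renormalize with a softmax.
    draft_logits = target_logits[..., selected_token_ids]
    target_probs = torch.softmax(draft_logits.float(), dim=-1)
    return target_probs, position_mask


# Toy example: full vocab of 5 tokens, draft vocab keeps ids 0, 2, 3.
selected_token_ids = torch.tensor([0, 2, 3])
selected_token_mask = torch.tensor([True, False, True, True, False])
target_logits = torch.tensor([[[0.0, 0.0, 5.0, 0.0, 0.0],
                               [0.0, 9.0, 0.0, 0.0, 0.0]]])
loss_mask = torch.ones(1, 2)

target_probs, position_mask = compute_target_distribution_sketch(
    target_logits, selected_token_ids, selected_token_mask, loss_mask
)
```

In the toy example the second position's top token (id 1) is outside the draft vocabulary, so that position drops out of the supervision mask.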

nemo_automodel.components.speculative.eagle.core._shift_left_with_zero(
tensor: torch.Tensor
) -> torch.Tensor

Shift a batched sequence tensor left and zero-fill the tail.
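A minimal sketch of the described behavior, assuming the sequence axis is dimension 1 of a (batch, seq, ...) tensor. The name shift_left_with_zero is a stand-in for the private helper, and the body is an illustration rather than the library's code.

```python
import torch


def shift_left_with_zero(tensor: torch.Tensor) -> torch.Tensor:
    """Shift along dim 1 by one position and zero-fill the last slot."""
    shifted = torch.zeros_like(tensor)
    shifted[:, :-1] = tensor[:, 1:]
    return shifted


x = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8]])
shifted = shift_left_with_zero(x)
# shifted is [[2, 3, 4, 0], [6, 7, 8, 0]]
```

Applying a shift like this once per TTT step is one way the per-step roll of input_ids, loss_mask, position_mask, and target_probs could be realized.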