nemo_automodel.components.speculative.eagle.core#

Core EAGLE-3 training logic for the minimal Llama MVP.

Module Contents#

Classes#

Eagle3StepMetrics

Aggregated metrics from one EAGLE-3 training step.

Eagle3TrainerModule

Draft-side EAGLE-3 trainer module with test-time-training unroll.

Functions#

_shift_left_with_zero

Shift a batched sequence tensor left and zero-fill the tail.

_compute_target_distribution

Project target logits into draft vocabulary space and build supervision mask.

API#

nemo_automodel.components.speculative.eagle.core._shift_left_with_zero(tensor: torch.Tensor) -> torch.Tensor#

Shift a batched sequence tensor left and zero-fill the tail.
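As a rough illustration of what such a shift does, here is a minimal sketch. It is not the library's implementation: it assumes the sequence axis is dimension 1 and a shift of exactly one position, and the name `shift_left_with_zero` is a hypothetical stand-in for the private helper.

```python
import torch


def shift_left_with_zero(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in: shift dim 1 left by one and zero the last slot."""
    # Assumption: the layout is [batch, seq, ...] with the sequence on dim 1.
    shifted = torch.zeros_like(tensor)
    shifted[:, :-1] = tensor[:, 1:]
    return shifted


x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
result = shift_left_with_zero(x)
print(result.tolist())  # [[2, 3, 4, 0], [6, 7, 8, 0]]
```

Zero-filling the tail, rather than wrapping around as `torch.roll` would, keeps tokens from the start of a sequence from leaking into its final position.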

nemo_automodel.components.speculative.eagle.core._compute_target_distribution(
target_logits: torch.Tensor,
selected_token_ids: torch.Tensor,
selected_token_mask: torch.Tensor,
loss_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]#

Project target logits into draft vocabulary space and build supervision mask.
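One plausible reading of this projection is sketched below. The shapes and the masking rule are assumptions, not confirmed by this page: `target_logits` is taken to be `[batch, seq, target_vocab]`, `selected_token_ids` a 1-D index of target-vocabulary ids kept in the draft vocabulary, `selected_token_mask` a boolean vector over the target vocabulary, and a position is supervised only when the target's top token survives in the draft vocabulary. The function name is a hypothetical stand-in for the private helper.

```python
import torch


def compute_target_distribution(
    target_logits: torch.Tensor,
    selected_token_ids: torch.Tensor,
    selected_token_mask: torch.Tensor,
    loss_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch: restrict logits to the draft vocab and mask positions."""
    # Keep only the target-vocab columns that exist in the draft vocabulary.
    draft_logits = target_logits[..., selected_token_ids]
    target_probs = torch.softmax(draft_logits.float(), dim=-1)
    # Assumption: supervise a position only if the target's argmax token
    # is representable in the draft vocabulary.
    top_token = target_logits.argmax(dim=-1)
    in_draft_vocab = selected_token_mask[top_token]
    position_mask = loss_mask.bool() & in_draft_vocab
    return target_probs, position_mask


# Toy setup: target vocab of 5 tokens, draft vocab keeps ids 0, 2, 4.
selected_token_ids = torch.tensor([0, 2, 4])
selected_token_mask = torch.tensor([True, False, True, False, True])
target_logits = torch.tensor(
    [[[0.0, 0.0, 5.0, 0.0, 0.0],   # argmax = 2, kept in the draft vocab
      [0.0, 0.0, 0.0, 5.0, 0.0]]]  # argmax = 3, dropped from the draft vocab
)
loss_mask = torch.tensor([[1, 1]])

target_probs, position_mask = compute_target_distribution(
    target_logits, selected_token_ids, selected_token_mask, loss_mask
)
print(target_probs.shape, position_mask.tolist())
```

In this toy case the second position is masked out, because its most likely target token has no counterpart in the draft vocabulary.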

class nemo_automodel.components.speculative.eagle.core.Eagle3StepMetrics#

Aggregated metrics from one EAGLE-3 training step.

loss: torch.Tensor#

accuracy: torch.Tensor#

valid_tokens: torch.Tensor#

class nemo_automodel.components.speculative.eagle.core.Eagle3TrainerModule(
draft_model: torch.nn.Module,
*,
selected_token_ids: torch.Tensor,
selected_token_mask: torch.Tensor,
ttt_steps: int,
)#

Bases: torch.nn.Module

Draft-side EAGLE-3 trainer module with test-time-training unroll.

forward(
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
loss_mask: torch.Tensor,
aux_hidden_states: torch.Tensor,
target_logits: torch.Tensor,
) -> nemo_automodel.components.speculative.eagle.core.Eagle3StepMetrics#

Run the EAGLE-3 unrolled draft loss for one batch.

The attention layer is driven through a shared cache_hidden list so each TTT step can attend to the K/V branches produced by every previous step at the same position. This matches the SpecForge llama3_eagle.py recurrence; without it, multi-step TTT would degenerate into ttt_steps independent single-step passes and the draft would never learn the multi-step distribution it sees at deployment time.

attention_mask is held constant across TTT steps; only input_ids, loss_mask, position_mask, and target_probs roll forward by one position per step.
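The roll-forward described above can be pictured with a small sketch. It leaves out the draft model, the loss, and the shared cache_hidden list, and it assumes the shift is applied after each step's loss is computed. The helper `shift_left_with_zero` is a hypothetical stand-in, and only `input_ids` and `loss_mask` are rolled here; `position_mask` and `target_probs` would presumably follow the same pattern.

```python
import torch


def shift_left_with_zero(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in: shift dim 1 left by one and zero the last slot."""
    shifted = torch.zeros_like(tensor)
    shifted[:, :-1] = tensor[:, 1:]
    return shifted


ttt_steps = 3
input_ids = torch.tensor([[10, 11, 12, 13, 14]])
loss_mask = torch.ones_like(input_ids)
attention_mask = torch.ones_like(input_ids)  # never modified in the loop

seen_inputs = []
valid_per_step = []
for step in range(ttt_steps):
    # The draft forward pass and loss for this step would go here,
    # reading attention_mask unchanged on every iteration.
    seen_inputs.append(input_ids[0].tolist())
    valid_per_step.append(int(loss_mask.sum()))
    # Roll the supervision tensors forward by one position for the next step.
    input_ids = shift_left_with_zero(input_ids)
    loss_mask = shift_left_with_zero(loss_mask)

print(seen_inputs)     # [[10, 11, 12, 13, 14], [11, 12, 13, 14, 0], [12, 13, 14, 0, 0]]
print(valid_per_step)  # [5, 4, 3]
```

Each step loses one supervised position at the tail, so later TTT steps contribute fewer valid tokens to the loss.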