nemo_automodel.components.speculative.eagle.core
Core EAGLE-3 training logic for the minimal Llama MVP.
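Module Contents

Classes

| Name | Description |
| --- | --- |
| `Eagle3StepMetrics` | Aggregated metrics from one EAGLE-3 training step. |
| `Eagle3TrainerModule` | Draft-side EAGLE-3 trainer module with test-time-training unroll. |

Functions

| Name | Description |
| --- | --- |
| `_shift_left_with_zero` | Shift a batched sequence tensor left and zero-fill the tail. |
| `_compute_target_distribution` | Project target logits into draft vocabulary space and build supervision mask. |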
API
- nemo_automodel.components.speculative.eagle.core._shift_left_with_zero(tensor: torch.Tensor) -> torch.Tensor
Shift a batched sequence tensor left and zero-fill the tail.
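As a rough illustration, the documented behavior fits in a few lines of PyTorch. This is a sketch, not the module's code: `shift_left_with_zero_sketch` is a hypothetical name, and it assumes the sequence axis is dimension 1 of a `[batch, seq, ...]` tensor.

```python
import torch


def shift_left_with_zero_sketch(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for _shift_left_with_zero (assumed [batch, seq, ...] layout)."""
    shifted = torch.zeros_like(tensor)   # same shape, dtype and device, all zeros
    shifted[:, :-1] = tensor[:, 1:]      # position t takes the value from t + 1
    return shifted                       # the last position stays zero


ids = torch.tensor([[11, 12, 13, 14],
                    [21, 22, 23, 24]])
shifted = shift_left_with_zero_sketch(ids)
# shifted.tolist() == [[12, 13, 14, 0], [22, 23, 24, 0]]
```

Zero-filling, rather than wrapping around as `torch.roll` would, keeps the last position from receiving a stale value from the start of the sequence; in this sketch, a shifted `loss_mask` gets a zero tail and so drops those positions from the loss automatically.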
- nemo_automodel.components.speculative.eagle.core._compute_target_distribution(target_logits: torch.Tensor, selected_token_ids: torch.Tensor, selected_token_mask: torch.Tensor, loss_mask: torch.Tensor)
Project target logits into draft vocabulary space and build supervision mask.
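One plausible reading of this helper, written as a hedged sketch: it assumes `selected_token_ids` is a `[V_draft]` long tensor mapping each draft-vocabulary slot to a target-vocabulary id, and `selected_token_mask` is a `[V_target]` boolean mask marking the target tokens kept in the draft vocabulary. Both layouts, the `(target_probs, position_mask)` return pair, and the name `compute_target_distribution_sketch` are assumptions, modeled on the common EAGLE-3 recipe of supervising only positions whose target argmax is representable in the draft vocabulary.

```python
import torch


def compute_target_distribution_sketch(
    target_logits: torch.Tensor,        # [B, T, V_target]
    selected_token_ids: torch.Tensor,   # [V_draft] long: target id per draft slot (assumed)
    selected_token_mask: torch.Tensor,  # [V_target] bool: token kept in draft vocab (assumed)
    loss_mask: torch.Tensor,            # [B, T]
) -> tuple[torch.Tensor, torch.Tensor]:
    # Supervise only positions that are in the loss mask AND whose target
    # argmax token exists in the draft vocabulary.
    target_argmax = target_logits.argmax(dim=-1)          # [B, T]
    in_draft_vocab = selected_token_mask[target_argmax]   # [B, T] bool
    position_mask = loss_mask.bool() & in_draft_vocab
    # Keep only the draft-vocabulary columns, then renormalize with softmax.
    draft_logits = target_logits[..., selected_token_ids]        # [B, T, V_draft]
    target_probs = torch.softmax(draft_logits.float(), dim=-1)
    return target_probs, position_mask


logits = torch.tensor([[[0.0, 1.0, 5.0, 0.0, 2.0],    # argmax = target id 2 (kept)
                        [0.0, 9.0, 1.0, 0.0, 0.0]]])  # argmax = target id 1 (dropped)
ids = torch.tensor([0, 2, 4])
mask = torch.zeros(5, dtype=torch.bool)
mask[ids] = True
probs, pos_mask = compute_target_distribution_sketch(logits, ids, mask, torch.ones(1, 2))
# pos_mask.tolist() == [[True, False]]; probs has shape (1, 2, 3)
```

Renormalizing over the draft columns means the draft is trained against a proper distribution over its own, smaller vocabulary, while the mask avoids training on positions whose most likely target token the draft could never emit.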
- class nemo_automodel.components.speculative.eagle.core.Eagle3StepMetrics
Aggregated metrics from one EAGLE-3 training step.
- loss: torch.Tensor
- accuracy: torch.Tensor
- valid_tokens: torch.Tensor
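The three fields could be produced roughly as follows. This is a hypothetical sketch: the soft cross-entropy loss, the argmax-agreement accuracy, and the names `StepMetricsSketch` and `summarize_step_sketch` are assumptions, not the module's definitions.

```python
from dataclasses import dataclass

import torch


@dataclass
class StepMetricsSketch:
    """Hypothetical mirror of the three Eagle3StepMetrics fields."""
    loss: torch.Tensor
    accuracy: torch.Tensor
    valid_tokens: torch.Tensor


def summarize_step_sketch(draft_logits, target_probs, position_mask) -> StepMetricsSketch:
    # Assumed definitions: soft cross-entropy against target_probs and
    # argmax agreement, both averaged over supervised positions only.
    mask = position_mask.float()                          # [B, T]
    valid_tokens = mask.sum()
    denom = valid_tokens.clamp(min=1.0)                   # avoid 0/0 on empty batches
    log_q = torch.log_softmax(draft_logits.float(), dim=-1)
    per_pos_loss = -(target_probs * log_q).sum(dim=-1)    # [B, T]
    loss = (per_pos_loss * mask).sum() / denom
    hits = (draft_logits.argmax(-1) == target_probs.argmax(-1)).float()
    accuracy = (hits * mask).sum() / denom
    return StepMetricsSketch(loss=loss, accuracy=accuracy, valid_tokens=valid_tokens)
```

Dividing by `valid_tokens` rather than by the full `B * T` keeps padded or unsupervised positions from diluting either the loss or the accuracy.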
- class nemo_automodel.components.speculative.eagle.core.Eagle3TrainerModule(draft_model: torch.nn.Module, *, selected_token_ids: torch.Tensor, selected_token_mask: torch.Tensor, ttt_steps: int)
Bases: `torch.nn.Module`

Draft-side EAGLE-3 trainer module with test-time-training (TTT) unroll.
- forward(input_ids: torch.Tensor, attention_mask: torch.Tensor, loss_mask: torch.Tensor, aux_hidden_states: torch.Tensor, target_logits: torch.Tensor)
Compute the unrolled EAGLE-3 draft loss for one batch.

The attention layer is driven through a shared `cache_hidden` list, so each TTT step can attend to the K/V branches produced by every previous step at the same position. This matches the SpecForge `llama3_eagle.py` recurrence. Without it, multi-step TTT would degenerate into `ttt_steps` independent single-step passes, and the draft would never learn the multi-step distribution it sees at deployment time.

`attention_mask` is held constant across TTT steps; only `input_ids`, `loss_mask`, `position_mask`, and `target_probs` roll forward by one position per step.
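The recurrence can be sketched as a single-head toy in PyTorch. It assumes the cache-branch layout of SpecForge-style EAGLE-3 attention (step-0 keys get ordinary causal attention; keys from each later step are scored only at the query's own position), but it is not NeMo's code: there is no RoPE, no multi-head or GQA split, and a zero tensor stands in for the fused `aux_hidden_states`. `cached_branch_attention`, `ToyDraftLayer`, and `toy_ttt_unroll` are hypothetical names.

```python
import math

import torch
import torch.nn as nn


def cached_branch_attention(q, k, v, cache_k, cache_v, causal_mask):
    # q, k, v: [B, T, D] for the current TTT step (single head for brevity).
    # cache_k / cache_v: lists shared by all TTT steps of one batch.
    cache_k.append(k)
    cache_v.append(v)
    T, scale = q.shape[1], 1.0 / math.sqrt(q.shape[-1])
    # Branch 0: ordinary causal attention over the step-0 keys.
    scores = [q @ cache_k[0].transpose(-1, -2) * scale + causal_mask]  # [B, T, T]
    # Branches 1..: each position also scores its OWN key from later steps.
    for k_i in cache_k[1:]:
        scores.append((q * k_i).sum(-1, keepdim=True) * scale)          # [B, T, 1]
    weights = torch.softmax(torch.cat(scores, dim=-1), dim=-1)
    out = weights[..., :T] @ cache_v[0]
    for i, v_i in enumerate(cache_v[1:]):
        out = out + weights[..., T + i : T + i + 1] * v_i
    return out


class ToyDraftLayer(nn.Module):
    """Hypothetical one-layer draft: embeds ids and mixes in the previous hidden state."""

    def __init__(self, vocab: int, dim: int):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.qkv = nn.Linear(2 * dim, 3 * dim)

    def forward(self, input_ids, hidden, cache_k, cache_v, causal_mask):
        x = torch.cat([self.embed(input_ids), hidden], dim=-1)
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        return cached_branch_attention(q, k, v, cache_k, cache_v, causal_mask)


def toy_ttt_unroll(layer, input_ids, hidden, ttt_steps):
    T = input_ids.shape[1]
    causal_mask = torch.full((T, T), float("-inf")).triu(1)  # same mask at every step
    cache_k, cache_v, outputs = [], [], []
    for _ in range(ttt_steps):
        hidden = layer(input_ids, hidden, cache_k, cache_v, causal_mask)
        outputs.append(hidden)
        # Roll the per-step inputs forward by one position, zero-filling the tail.
        input_ids = torch.cat([input_ids[:, 1:], torch.zeros_like(input_ids[:, :1])], dim=1)
    return outputs, cache_k


torch.manual_seed(0)
layer = ToyDraftLayer(vocab=32, dim=8)
outs, cache = toy_ttt_unroll(layer, torch.randint(0, 32, (2, 5)), torch.zeros(2, 5, 8), ttt_steps=3)
print(len(outs), len(cache), outs[-1].shape)  # 3 3 torch.Size([2, 5, 8])
```

Because `cache_k` and `cache_v` outlive a single layer call, step *k* sees *k* + 1 branches. If the lists were recreated inside the loop, every step would fall back to plain causal attention over its own keys and the cross-step K/V links would be lost.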