nemo_automodel.components.speculative.eagle.core_v12
Core EAGLE-1 / EAGLE-2 draft-training logic.
Module Contents
Classes
| EagleStepMetrics | Aggregated metrics from one EAGLE-1 / EAGLE-2 training step. |
| EagleTrainerModule | Draft-side trainer for EAGLE-1 / EAGLE-2 hidden-state prediction. |
API
- class nemo_automodel.components.speculative.eagle.core_v12.EagleStepMetrics
Aggregated metrics from one EAGLE-1 / EAGLE-2 training step.
- loss: torch.Tensor
None
None
- token_loss: torch.Tensor
None
- accuracy: torch.Tensor
None
- valid_tokens: torch.Tensor
None
- class nemo_automodel.components.speculative.eagle.core_v12.EagleTrainerModule(
- draft_model: torch.nn.Module,
- *,
- target_lm_head: torch.nn.Module,
- hidden_loss_weight: float = 1.0,
- token_loss_weight: float = 0.1,
Bases: torch.nn.Module
Draft-side trainer for EAGLE-1 / EAGLE-2 hidden-state prediction.
Initialization
- compute_logits(hidden_states: torch.Tensor) → torch.Tensor
Project predicted hidden states through the frozen target lm_head.
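As a rough illustration of what projecting through a frozen head can look like, the sketch below assumes the target lm_head is a plain torch.nn.Linear. The helper name project_through_frozen_head and the toy tensor sizes are hypothetical and are not part of this module's API.

```python
import torch
from torch import nn


def project_through_frozen_head(lm_head: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: map hidden states to vocabulary logits with a frozen head."""
    # Freeze the head so no gradient is accumulated in its parameters.
    for param in lm_head.parameters():
        param.requires_grad_(False)
    # Gradients can still flow back into hidden_states (the draft model's output).
    return lm_head(hidden_states)


# Toy sizes (assumed): batch=2, seq_len=3, hidden=8, vocab=16.
lm_head = nn.Linear(8, 16, bias=False)
hidden = torch.randn(2, 3, 8, requires_grad=True)
logits = project_through_frozen_head(lm_head, hidden)
logits.sum().backward()
```

In this sketch the head only shapes the gradient that reaches the predicted hidden states; its own weights receive no update.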
- forward(
- input_ids: torch.Tensor,
- attention_mask: torch.Tensor,
- loss_mask: torch.Tensor,
- input_hidden_states: torch.Tensor,
- target_hidden_states: torch.Tensor,
- target_logits: torch.Tensor,
Run one EAGLE-1 / EAGLE-2 training step.
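The page does not spell out how the step combines its losses, so the following is only a sketch under stated assumptions: a Smooth L1 regression loss on hidden states, a cross-entropy term against the target model's soft distribution, both averaged over positions where loss_mask is set, and a weighted sum using the documented defaults hidden_loss_weight=1.0 and token_loss_weight=0.1. The function eagle_style_losses, the hidden_loss key, and the pred_hidden / pred_logits arguments are hypothetical names introduced for illustration.

```python
import torch
import torch.nn.functional as F


def eagle_style_losses(pred_hidden, target_hidden, pred_logits, target_logits, loss_mask,
                       hidden_loss_weight=1.0, token_loss_weight=0.1):
    """Hypothetical sketch of a masked hidden-state + token loss (not the library's code)."""
    mask = loss_mask.to(pred_hidden.dtype)        # (batch, seq)
    denom = mask.sum().clamp(min=1.0)             # avoid division by zero

    # Assumed regression term: Smooth L1 per position, averaged over the hidden dim.
    hidden_per_tok = F.smooth_l1_loss(pred_hidden, target_hidden, reduction="none").mean(dim=-1)
    hidden_loss = (hidden_per_tok * mask).sum() / denom

    # Assumed token term: cross-entropy against the target's soft distribution.
    target_probs = F.softmax(target_logits, dim=-1)
    token_per_tok = -(target_probs * F.log_softmax(pred_logits, dim=-1)).sum(dim=-1)
    token_loss = (token_per_tok * mask).sum() / denom

    # Top-1 agreement between draft and target on supervised positions.
    agree = (pred_logits.argmax(dim=-1) == target_logits.argmax(dim=-1)).to(mask.dtype)
    accuracy = (agree * mask).sum() / denom

    loss = hidden_loss_weight * hidden_loss + token_loss_weight * token_loss
    return {
        "loss": loss,
        "hidden_loss": hidden_loss,
        "token_loss": token_loss,
        "accuracy": accuracy,
        "valid_tokens": mask.sum(),
    }


# Toy check: a "perfect" draft that reproduces the target exactly.
torch.manual_seed(0)
target_hidden = torch.randn(2, 4, 8)
target_logits = torch.randn(2, 4, 16)
loss_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])
metrics = eagle_style_losses(target_hidden.clone(), target_hidden,
                             target_logits.clone(), target_logits, loss_mask)
```

With a perfect draft the regression term vanishes, while the soft cross-entropy term stays positive because it equals the entropy of the target distribution. Masked positions contribute to neither term.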