nemo_automodel.components.speculative.eagle.core_v12

Core EAGLE-1 / EAGLE-2 draft-training logic.

Module Contents

Classes

EagleStepMetrics

Aggregated metrics from one EAGLE-1 / EAGLE-2 training step.

EagleTrainerModule

Draft-side trainer for EAGLE-1 / EAGLE-2 hidden-state prediction.

API

class nemo_automodel.components.speculative.eagle.core_v12.EagleStepMetrics

Aggregated metrics from one EAGLE-1 / EAGLE-2 training step.

loss: torch.Tensor

hidden_loss: torch.Tensor

token_loss: torch.Tensor

accuracy: torch.Tensor

valid_tokens: torch.Tensor

class nemo_automodel.components.speculative.eagle.core_v12.EagleTrainerModule(
    draft_model: torch.nn.Module,
    *,
    target_lm_head: torch.nn.Module,
    hidden_loss_weight: float = 1.0,
    token_loss_weight: float = 0.1,
)

Bases: torch.nn.Module

Draft-side trainer for EAGLE-1 / EAGLE-2 hidden-state prediction.

Initialization
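
The `hidden_loss_weight` and `token_loss_weight` arguments suggest that the total loss of a step is a weighted sum of a hidden-state term and a token term. This page does not state the formula, so the following is only a sketch under that assumption; `combine_losses` is a hypothetical helper and not part of `nemo_automodel`.

```python
def combine_losses(hidden_loss, token_loss,
                   hidden_loss_weight=1.0, token_loss_weight=0.1):
    """Weighted sum of the two loss terms (assumed combination rule).

    The default weights mirror the constructor defaults documented above.
    Works on plain floats as well as on tensors that support * and +.
    """
    return hidden_loss_weight * hidden_loss + token_loss_weight * token_loss


# With the default weights, the token term contributes a tenth of its value.
total = combine_losses(hidden_loss=2.0, token_loss=5.0)
```

Under this assumption, raising `token_loss_weight` would shift the objective toward matching the target model's token predictions rather than its hidden states.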

compute_logits(hidden_states: torch.Tensor) -> torch.Tensor

Project predicted hidden states through the frozen target lm_head.
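
As an illustration of what projecting through a frozen head can look like, the sketch below uses a stand-in `torch.nn.Linear` as the target `lm_head` and disables gradients on its parameters. The sizes and the function name `compute_logits_sketch` are assumptions made for the example; the actual method may differ in detail.

```python
import torch
from torch import nn

torch.manual_seed(0)
hidden_size, vocab_size = 8, 32  # hypothetical sizes for the example

# Stand-in for the target model's lm_head, frozen so it receives no gradients.
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
for param in lm_head.parameters():
    param.requires_grad_(False)


def compute_logits_sketch(hidden_states: torch.Tensor) -> torch.Tensor:
    """Map [batch, seq, hidden] hidden states to [batch, seq, vocab] logits."""
    return lm_head(hidden_states)


# Predicted hidden states would come from the draft model; random values here.
predicted = torch.randn(2, 5, hidden_size, requires_grad=True)
logits = compute_logits_sketch(predicted)

# Gradients flow back to the predicted hidden states, not to the frozen head.
logits.sum().backward()
```

Freezing the head keeps the draft's logits in the same vocabulary space as the target model while only the draft-side parameters are trained.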

forward(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    loss_mask: torch.Tensor,
    input_hidden_states: torch.Tensor,
    target_hidden_states: torch.Tensor,
    target_logits: torch.Tensor,
) -> nemo_automodel.components.speculative.eagle.core_v12.EagleStepMetrics

Run one EAGLE-1 / EAGLE-2 training step.
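
The loss part of such a step might look like the sketch below. It is not the library's implementation: it assumes a smooth L1 regression loss on hidden states, a soft-label cross-entropy against the target logits, averaging over positions where `loss_mask` is nonzero, and a weighted sum using the constructor defaults. The draft-model call itself is omitted, and `eagle_step_sketch` with its `pred_hidden` and `pred_logits` arguments is hypothetical.

```python
import torch
import torch.nn.functional as F


def eagle_step_sketch(pred_hidden, target_hidden, pred_logits, target_logits,
                      loss_mask, hidden_loss_weight=1.0, token_loss_weight=0.1):
    """Masked hidden-state and token losses for one step (assumed formulation)."""
    mask = loss_mask.to(pred_hidden.dtype)          # [batch, seq]
    valid_tokens = mask.sum().clamp(min=1.0)        # avoid division by zero

    # Regression term: per-position smooth L1, averaged over the hidden dim.
    per_pos_hidden = F.smooth_l1_loss(
        pred_hidden, target_hidden, reduction="none").mean(dim=-1)
    hidden_loss = (per_pos_hidden * mask).sum() / valid_tokens

    # Token term: cross-entropy against the target's softmax distribution.
    target_probs = target_logits.softmax(dim=-1)
    per_pos_token = -(target_probs * pred_logits.log_softmax(dim=-1)).sum(dim=-1)
    token_loss = (per_pos_token * mask).sum() / valid_tokens

    # Accuracy: how often the draft's top token matches the target's top token.
    correct = (pred_logits.argmax(dim=-1) == target_logits.argmax(dim=-1))
    accuracy = (correct.to(mask.dtype) * mask).sum() / valid_tokens

    loss = hidden_loss_weight * hidden_loss + token_loss_weight * token_loss
    # Keys mirror the EagleStepMetrics field names listed on this page.
    return {"loss": loss, "hidden_loss": hidden_loss, "token_loss": token_loss,
            "accuracy": accuracy, "valid_tokens": valid_tokens}


torch.manual_seed(0)
B, T, H, V = 2, 4, 8, 16  # hypothetical batch, sequence, hidden, vocab sizes
target_hidden = torch.randn(B, T, H)
target_logits = torch.randn(B, T, V)
loss_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])

metrics = eagle_step_sketch(torch.randn(B, T, H), target_hidden,
                            torch.randn(B, T, V), target_logits, loss_mask)
```

Normalizing by the number of unmasked positions keeps the loss scale comparable across batches with different amounts of padding or prompt tokens.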