nemo_automodel.components.speculative.eagle.core_v12


Core EAGLE-1 / EAGLE-2 draft-training logic.

Module Contents

Classes

Name                  Description
EagleStepMetrics      Aggregated metrics from one EAGLE-1 / EAGLE-2 training step.
EagleTrainerModule    Draft-side trainer for EAGLE-1 / EAGLE-2 hidden-state prediction.

API

class nemo_automodel.components.speculative.eagle.core_v12.EagleStepMetrics(
loss: torch.Tensor,
hidden_loss: torch.Tensor,
token_loss: torch.Tensor,
accuracy: torch.Tensor,
valid_tokens: torch.Tensor
)
Dataclass

Aggregated metrics from one EAGLE-1 / EAGLE-2 training step.

accuracy: Tensor
hidden_loss: Tensor
loss: Tensor
token_loss: Tensor
valid_tokens: Tensor
class nemo_automodel.components.speculative.eagle.core_v12.EagleTrainerModule(
draft_model: torch.nn.Module,
target_lm_head: torch.nn.Module,
hidden_loss_weight: float = 1.0,
token_loss_weight: float = 0.1,
feature_noise: float = 0.0
)

Bases: torch.nn.Module

Draft-side trainer for EAGLE-1 / EAGLE-2 hidden-state prediction.
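The constructor's `feature_noise` argument defaults to 0.0. One plausible reading, which is assumed here and not confirmed by this page, is that it sets the half-width of uniform noise added to the input hidden states during training. The EAGLE paper does something similar with U(-0.1, 0.1) noise to make the draft robust to errors in its own predicted features. The sketch below is minimal, and `add_feature_noise` is a hypothetical helper.

```python
import torch


def add_feature_noise(
    hidden_states: torch.Tensor, feature_noise: float, training: bool = True
) -> torch.Tensor:
    """Hypothetical helper: perturb features with U(-feature_noise, feature_noise) noise."""
    if not training or feature_noise <= 0.0:
        # feature_noise=0.0 (the documented default) leaves the features untouched.
        return hidden_states
    # torch.rand_like samples U[0, 1); rescale it to U[-feature_noise, feature_noise).
    noise = (torch.rand_like(hidden_states) * 2.0 - 1.0) * feature_noise
    return hidden_states + noise
```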

hidden_loss_fn = nn.SmoothL1Loss(reduction='none')
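With `reduction='none'`, the loss is returned per element. This lets `loss_mask` be applied before averaging, so padding and unsupervised positions do not contribute. The sketch below shows one such masked reduction. It assumes hidden states shaped `[batch, seq, hidden]` and a `[batch, seq]` mask. This page does not document the module's exact averaging scheme, and `masked_hidden_loss` is a hypothetical name.

```python
import torch
from torch import nn

hidden_loss_fn = nn.SmoothL1Loss(reduction="none")


def masked_hidden_loss(
    pred: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor
) -> torch.Tensor:
    """Hypothetical: mean Smooth L1 over hidden dims, then over masked-in positions."""
    per_elem = hidden_loss_fn(pred, target)   # [B, T, H], no reduction applied
    per_token = per_elem.mean(dim=-1)         # [B, T], average over the hidden dimension
    mask = loss_mask.to(per_token.dtype)      # accepts bool or integer masks
    # clamp keeps the division safe when the batch has no supervised tokens
    return (per_token * mask).sum() / mask.sum().clamp(min=1.0)
```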
nemo_automodel.components.speculative.eagle.core_v12.EagleTrainerModule.compute_logits(
hidden_states: torch.Tensor
) -> torch.Tensor

Project predicted hidden states through the frozen target lm_head.
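The lm_head belongs to the target model and is frozen. The draft's predicted hidden states can therefore still receive gradients through it, while the head's own weights stay fixed. The sketch below is minimal and uses a stand-in `nn.Linear` head. The sizes and the `requires_grad_(False)` freezing mechanism are assumptions made for illustration.

```python
import torch
from torch import nn

hidden_size, vocab_size = 8, 16  # illustrative sizes, not taken from any real model
target_lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
target_lm_head.requires_grad_(False)  # freeze: training the draft must not update this head


def compute_logits(hidden_states: torch.Tensor) -> torch.Tensor:
    # Gradients still flow back into hidden_states even though the head is frozen.
    return target_lm_head(hidden_states)
```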

nemo_automodel.components.speculative.eagle.core_v12.EagleTrainerModule.forward(
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
loss_mask: torch.Tensor,
input_hidden_states: torch.Tensor,
target_hidden_states: torch.Tensor,
target_logits: torch.Tensor
) -> nemo_automodel.components.speculative.eagle.core_v12.EagleStepMetrics

Run one EAGLE-1 / EAGLE-2 training step and return its aggregated metrics as an EagleStepMetrics instance.
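The sketch below shows how the token-side terms and the weighted total could be computed. It makes several assumptions that follow the EAGLE paper's objective and are not confirmed by this page:

- `token_loss` is a soft cross-entropy against the target model's next-token distribution.
- `accuracy` is argmax agreement between the draft and the target.
- `valid_tokens` counts the positions set in `loss_mask`.
- The logits are already aligned position by position.

`token_metrics_sketch` is a hypothetical helper. It omits the draft model call and takes an already-computed `hidden_loss`. Its default weights mirror the documented constructor defaults.

```python
import torch
import torch.nn.functional as F


def token_metrics_sketch(
    draft_logits: torch.Tensor,   # [B, T, V], from the frozen target lm_head
    target_logits: torch.Tensor,  # [B, T, V], from the target model
    loss_mask: torch.Tensor,      # [B, T], 1 where the position is supervised
    hidden_loss: torch.Tensor,    # scalar, e.g. a masked Smooth L1 loss
    hidden_loss_weight: float = 1.0,
    token_loss_weight: float = 0.1,
) -> dict[str, torch.Tensor]:
    mask = loss_mask.to(draft_logits.dtype)
    valid_tokens = mask.sum()
    denom = valid_tokens.clamp(min=1.0)  # avoid dividing by zero on empty batches
    # Soft cross-entropy: -sum_v p_target(v) * log p_draft(v) at each position.
    target_probs = F.softmax(target_logits, dim=-1)
    per_token_ce = -(target_probs * F.log_softmax(draft_logits, dim=-1)).sum(dim=-1)
    token_loss = (per_token_ce * mask).sum() / denom
    # Accuracy: fraction of supervised positions where both argmax tokens agree.
    correct = (draft_logits.argmax(dim=-1) == target_logits.argmax(dim=-1)).to(mask.dtype)
    accuracy = (correct * mask).sum() / denom
    loss = hidden_loss_weight * hidden_loss + token_loss_weight * token_loss
    return {
        "loss": loss,
        "hidden_loss": hidden_loss,
        "token_loss": token_loss,
        "accuracy": accuracy,
        "valid_tokens": valid_tokens,
    }
```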