nemo_automodel.components.speculative.eagle.target#

Target-model wrapper for minimal EAGLE-3 training.

Module Contents#

Classes#

Eagle3TargetBatch

Target-model outputs needed by the draft trainer.

HFEagle3TargetModel

Thin wrapper that captures three auxiliary hidden states from a causal LM.

Functions#

_shift_left_with_zero

Shift a batched sequence tensor left and zero-fill the tail.

API#

nemo_automodel.components.speculative.eagle.target._shift_left_with_zero(tensor: torch.Tensor) torch.Tensor#

Shift a batched sequence tensor left and zero-fill the tail.

This matches the reference EAGLE-3 target preparation used by SpecForge, where sequence-aligned tensors are shifted with padding(..., left=False). See the target preparation logic in SpecForge's eagle3_target_model.py.
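The shift described above can be sketched as follows. This is a minimal illustration, not the library's implementation: it assumes the shift is by one position along the sequence dimension (dim 1) of a `(batch, seq_len, ...)` tensor, and the name `shift_left_with_zero_sketch` is hypothetical.

```python
import torch


def shift_left_with_zero_sketch(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: drop position 0 and zero-fill the last position of dim 1."""
    # Assumption: dim 0 is the batch, dim 1 is the sequence, and the shift is one step.
    shifted = torch.zeros_like(tensor)
    shifted[:, :-1] = tensor[:, 1:]
    return shifted


input_ids = torch.tensor([[11, 12, 13, 14], [21, 22, 23, 24]])
shifted_ids = shift_left_with_zero_sketch(input_ids)
print(shifted_ids.tolist())  # [[12, 13, 14, 0], [22, 23, 24, 0]]
```

After the shift, position `t` holds the value that was at position `t + 1`, which is the usual alignment for next-token targets.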

class nemo_automodel.components.speculative.eagle.target.Eagle3TargetBatch#

Target-model outputs needed by the draft trainer.

aux_hidden_states: torch.Tensor#


logits: torch.Tensor#


input_ids: torch.Tensor#


attention_mask: torch.Tensor#


loss_mask: torch.Tensor#

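To make the five fields concrete, here is a hypothetical stand-in container with the same field names. The class name `TargetBatchSketch`, the toy sizes, and every tensor shape are assumptions for illustration. In particular, the `3 * hidden` width assumes the three captured layers are concatenated on the last dimension, which this page does not state.

```python
from dataclasses import dataclass

import torch


@dataclass
class TargetBatchSketch:
    """Hypothetical container mirroring the documented field names."""

    aux_hidden_states: torch.Tensor
    logits: torch.Tensor
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    loss_mask: torch.Tensor


# Assumed toy sizes; real shapes depend on the wrapped model and the batch.
batch, seq_len, hidden, vocab = 2, 4, 8, 16
example = TargetBatchSketch(
    aux_hidden_states=torch.zeros(batch, seq_len, 3 * hidden),  # assumption: 3 layers concatenated
    logits=torch.zeros(batch, seq_len, vocab),
    input_ids=torch.zeros(batch, seq_len, dtype=torch.long),
    attention_mask=torch.ones(batch, seq_len, dtype=torch.long),
    loss_mask=torch.ones(batch, seq_len, dtype=torch.long),
)
print(example.aux_hidden_states.shape, example.logits.shape)
```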

class nemo_automodel.components.speculative.eagle.target.HFEagle3TargetModel(
model: torch.nn.Module,
aux_layer_ids: Sequence[int] | None = None,
)#

Thin wrapper that captures three auxiliary hidden states from a causal LM.

Initialization

_default_aux_layer_ids() list[int]#

_validate_aux_layer_ids(
aux_layer_ids: Sequence[int],
) list[int]#

Validate aux-layer selection before any forward hooks are registered.
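The kind of check this method might perform can be sketched as below. The specific rules (exactly three ids, no duplicates, each within the layer range) are assumptions inferred from the class description, and `validate_aux_layer_ids_sketch` with its `num_layers` parameter is a hypothetical name and signature, not the library's.

```python
from typing import Sequence


def validate_aux_layer_ids_sketch(aux_layer_ids: Sequence[int], num_layers: int) -> list[int]:
    """Hypothetical validator: three distinct layer indices inside [0, num_layers)."""
    ids = list(aux_layer_ids)
    if len(ids) != 3:
        raise ValueError(f"expected exactly 3 aux layer ids, got {len(ids)}")
    if len(set(ids)) != len(ids):
        raise ValueError(f"aux layer ids must be unique, got {ids}")
    for layer_id in ids:
        if not 0 <= layer_id < num_layers:
            raise ValueError(f"aux layer id {layer_id} is outside [0, {num_layers})")
    return ids


print(validate_aux_layer_ids_sketch((1, 15, 28), num_layers=32))  # [1, 15, 28]
```

Failing before hooks are registered means a bad selection raises immediately, without leaving stale hooks attached to the model.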

_get_transformer_layers() Iterable[torch.nn.Module]#

get_input_embeddings() torch.nn.Embedding#

Return the target model input embeddings.

generate_batch(
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
loss_mask: torch.Tensor,
) nemo_automodel.components.speculative.eagle.target.Eagle3TargetBatch#

Run the target model and capture aux hidden states plus logits.
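The capture pattern can be illustrated on a toy model with PyTorch forward hooks. This is a sketch under stated assumptions, not the wrapper's actual code: `ToyCausalLM`, the chosen layer ids, and the concatenation of the captured states on the last dimension are all hypothetical.

```python
import torch
from torch import nn


class ToyCausalLM(nn.Module):
    """Hypothetical stand-in for a causal LM: embeddings, a layer stack, an LM head."""

    def __init__(self, vocab: int = 16, hidden: int = 8, num_layers: int = 4):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_layers)])
        self.lm_head = nn.Linear(hidden, vocab)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden_states = self.embed(input_ids)
        for layer in self.layers:
            hidden_states = torch.tanh(layer(hidden_states))
        return self.lm_head(hidden_states)


model = ToyCausalLM().eval()
aux_layer_ids = [0, 1, 3]  # assumed selection, for illustration only
captured: dict[int, torch.Tensor] = {}


def make_hook(layer_id: int):
    def hook(module, inputs, output):
        # Real decoder layers often return a tuple; this toy layer returns a tensor.
        captured[layer_id] = output.detach()

    return hook


handles = [model.layers[i].register_forward_hook(make_hook(i)) for i in aux_layer_ids]

input_ids = torch.randint(0, 16, (2, 5))
with torch.no_grad():  # the target model is only read from, never updated
    logits = model(input_ids)

for handle in handles:
    handle.remove()

# Assumption: the three captured states are concatenated on the hidden dimension.
aux_hidden_states = torch.cat([captured[i] for i in aux_layer_ids], dim=-1)
print(logits.shape, aux_hidden_states.shape)
```

Hooks let the selected layer outputs be collected in a single forward pass without modifying the model's own `forward`.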