nemo_automodel.components.speculative.eagle.target

Target-model wrapper for minimal EAGLE-3 training.

Module Contents

Classes

| Name | Description |
| --- | --- |
| Eagle3TargetBatch | Target-model supervision for one draft-training batch. |
| HFEagle3TargetModel | Co-located backend that captures three auxiliary hidden states from a causal LM. |

Functions

| Name | Description |
| --- | --- |
| _shift_left_with_zero | Shift a batched sequence tensor left and zero-fill the tail. |

API

class nemo_automodel.components.speculative.eagle.target.Eagle3TargetBatch(
aux_hidden_states: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
loss_mask: torch.Tensor,
logits: torch.Tensor | None = None,
target_probs: torch.Tensor | None = None,
position_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
seq_lens: torch.Tensor | None = None,
doc_remaining: torch.Tensor | None = None
)
Dataclass

Target-model supervision for one draft-training batch.

Carries exactly one of two supervision encodings (validated in __post_init__). Either encoding is consumed directly by Eagle3TrainerModule.forward:

  • logits — the target’s full-vocab logits; the draft-vocab projection happens trainer-side. Used by the co-located backend, where the tensor never leaves the GPU.
  • target_probs + position_mask — the already-projected draft-vocab distribution, so a backend that computes it itself (e.g. a remote server) only transfers draft-vocab-sized tensors.
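The "exactly one encoding" rule above can be illustrated with a minimal sketch. This is not the library's implementation: the class name `TargetBatchSketch` is hypothetical, only the three supervision fields are modeled, and plain Python objects stand in for tensors so the example has no dependencies. The exact checks and error messages in the real `__post_init__` may differ.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class TargetBatchSketch:
    """Hypothetical stand-in that models only the supervision fields."""

    logits: Optional[Any] = None
    target_probs: Optional[Any] = None
    position_mask: Optional[Any] = None

    def __post_init__(self) -> None:
        has_logits = self.logits is not None
        has_probs = self.target_probs is not None
        has_mask = self.position_mask is not None

        # Assumption: target_probs and position_mask always travel together.
        if has_probs != has_mask:
            raise ValueError("target_probs and position_mask must be set together")

        # Exactly one encoding: reject both "neither" and "both".
        if has_logits == has_probs:
            raise ValueError(
                "provide exactly one of: logits, or target_probs + position_mask"
            )


# Valid: full-vocab logits only (co-located style).
colocated = TargetBatchSketch(logits=[0.1, 0.9])

# Valid: pre-projected distribution plus its mask (remote style).
remote = TargetBatchSketch(target_probs=[0.2, 0.8], position_mask=[1, 1])
```

Validating at construction time means a malformed batch fails where it is built rather than later inside the trainer's forward pass.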
attention_mask: Tensor
aux_hidden_states: Tensor
doc_remaining: Tensor | None = None
input_ids: Tensor
logits: Tensor | None = None
loss_mask: Tensor
position_ids: Tensor | None = None
position_mask: Tensor | None = None
seq_lens: Tensor | None = None
target_probs: Tensor | None = None
nemo_automodel.components.speculative.eagle.target.Eagle3TargetBatch.__post_init__() -> None
nemo_automodel.components.speculative.eagle.target.Eagle3TargetBatch.to_trainer_inputs() -> dict[str, torch.Tensor]

Return kwargs for Eagle3TrainerModule.forward, dispatching on whichever supervision encoding this batch carries.
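The dispatch described here could look roughly like the sketch below. It is an illustration only: the function name `to_trainer_inputs_sketch` is hypothetical, and the returned key names, as well as the choice to omit unset packing metadata, are assumptions rather than the documented contract of Eagle3TrainerModule.forward.

```python
from types import SimpleNamespace


def to_trainer_inputs_sketch(batch) -> dict:
    """Build forward kwargs from a batch-like object (hypothetical helper)."""
    # Fields every batch carries, regardless of supervision encoding.
    inputs = {
        "aux_hidden_states": batch.aux_hidden_states,
        "input_ids": batch.input_ids,
        "attention_mask": batch.attention_mask,
        "loss_mask": batch.loss_mask,
    }

    # Dispatch on the supervision encoding (assumes the batch was validated
    # to carry exactly one of the two).
    if batch.logits is not None:
        inputs["logits"] = batch.logits
    else:
        inputs["target_probs"] = batch.target_probs
        inputs["position_mask"] = batch.position_mask

    # Assumption: optional packing metadata is passed only when present.
    for name in ("position_ids", "seq_lens", "doc_remaining"):
        value = getattr(batch, name)
        if value is not None:
            inputs[name] = value
    return inputs


# Plain strings stand in for tensors to keep the example dependency-free.
batch = SimpleNamespace(
    aux_hidden_states="aux", input_ids="ids", attention_mask="attn",
    loss_mask="loss", logits="full-vocab logits", target_probs=None,
    position_mask=None, position_ids=None, seq_lens=None, doc_remaining=None,
)
kwargs = to_trainer_inputs_sketch(batch)
```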

class nemo_automodel.components.speculative.eagle.target.HFEagle3TargetModel(
model: torch.nn.Module,
aux_layer_ids: typing.Sequence[int] | None = None,
cp_mesh = None
)

Bases: Eagle3TargetBackend

Co-located backend that captures three auxiliary hidden states from a causal LM.

_cp_size = cp_mesh.size() if cp_mesh is not None else 1
aux_layer_ids = self._validate_aux_layer_ids(candidate_ids)
model = model.eval()
nemo_automodel.components.speculative.eagle.target.HFEagle3TargetModel._check_captured(
captured: dict[int, torch.Tensor]
) -> None
nemo_automodel.components.speculative.eagle.target.HFEagle3TargetModel._default_aux_layer_ids() -> list[int]
nemo_automodel.components.speculative.eagle.target.HFEagle3TargetModel._get_transformer_layers() -> list[torch.nn.Module]

Return decoder layers as an ordered list indexable by integer.

Supports both the HuggingFace layouts (where layers is a ModuleList) and AutoModel’s custom-impl layouts (where layers is a ModuleDict keyed by str(i)). Returning a plain list normalizes the access pattern for downstream register_forward_hook calls.
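A minimal sketch of this normalization, assuming PyTorch is installed. The helper name `as_layer_list` and the toy `nn.Linear` layers are illustrative stand-ins. The real method also has to locate the layer container on the wrapped model, which is not shown here.

```python
import torch
import torch.nn as nn


def as_layer_list(layers) -> list:
    """Return decoder layers as a plain list, whatever container holds them."""
    if isinstance(layers, nn.ModuleDict):
        # Keys are assumed to be "0", "1", ...; index by str(i) to keep order.
        return [layers[str(i)] for i in range(len(layers))]
    return list(layers)  # nn.ModuleList or any iterable of modules


# Toy stand-ins for decoder layers, stored in a ModuleDict keyed by str(i).
layers = nn.ModuleDict({str(i): nn.Linear(4, 4) for i in range(3)})
layer_list = as_layer_list(layers)

captured = {}


def make_hook(layer_id: int):
    # Forward hooks receive (module, positional inputs, output).
    def hook(module, args, output):
        captured[layer_id] = output.detach()
    return hook


# Integer indexing now works the same way for both container types.
handles = [layer_list[i].register_forward_hook(make_hook(i)) for i in (0, 2)]

x = torch.randn(2, 4)
with torch.no_grad():
    for layer in layer_list:
        x = layer(x)

for handle in handles:
    handle.remove()  # always detach hooks once the capture is done
```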

nemo_automodel.components.speculative.eagle.target.HFEagle3TargetModel._validate_aux_layer_ids(
aux_layer_ids: typing.Sequence[int]
) -> list[int]

Validate aux-layer selection before any forward hooks are registered.

nemo_automodel.components.speculative.eagle.target.HFEagle3TargetModel.generate_batch(
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
loss_mask: torch.Tensor,
position_ids: torch.Tensor | None = None,
seq_lens: torch.Tensor | None = None,
doc_remaining: torch.Tensor | None = None
) -> nemo_automodel.components.speculative.eagle.target.Eagle3TargetBatch

Run the target model and capture aux hidden states plus logits.

With seq_lens (packing), the target runs with a [B, 1, T, T] block-causal mask and per-document position_ids so its outputs respect document boundaries; the packing metadata is forwarded unshifted to the trainer. seq_lens=None keeps the original 2D-mask path.
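To make the packed path concrete, the sketch below builds a block-causal mask and per-document position ids for a single packed row. It rests on assumptions the page does not confirm: `seq_lens` is treated as a 1-D tensor of document lengths that sum to the row length, the mask is boolean with True meaning "may attend", and the helper name is hypothetical. The layout and mask dtype used by generate_batch may differ.

```python
import torch


def block_causal_mask_and_positions(seq_lens: torch.Tensor):
    """Return a [1, 1, T, T] boolean mask and [T] position ids for one packed row."""
    total_len = int(seq_lens.sum())

    # Document index of every token, e.g. lengths [2, 3] -> [0, 0, 1, 1, 1].
    doc_ids = torch.repeat_interleave(torch.arange(seq_lens.numel()), seq_lens)

    # A token may attend to an earlier (or the same) token in its own document.
    same_doc = doc_ids[:, None] == doc_ids[None, :]
    causal = torch.tril(torch.ones(total_len, total_len, dtype=torch.bool))
    mask = (same_doc & causal)[None, None]  # add batch and head dimensions

    # Position ids restart at 0 at the start of every document.
    doc_starts = torch.cumsum(seq_lens, dim=0) - seq_lens
    position_ids = torch.arange(total_len) - torch.repeat_interleave(doc_starts, seq_lens)
    return mask, position_ids


mask, position_ids = block_causal_mask_and_positions(torch.tensor([2, 3]))
```

In this example, the first token of the second document (index 2) cannot attend to index 1, which belongs to the first document, so the target's hidden states for one document are not influenced by its neighbors in the pack.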

nemo_automodel.components.speculative.eagle.target.HFEagle3TargetModel.get_input_embeddings() -> torch.nn.Embedding

Return the target model input embeddings.

nemo_automodel.components.speculative.eagle.target._shift_left_with_zero(
tensor: torch.Tensor
) -> torch.Tensor

Shift a batched sequence tensor left and zero-fill the tail.

This matches the reference EAGLE-3 target preparation used by SpecForge, where sequence-aligned tensors are shifted with padding(..., left=False). See the target preparation logic in SpecForge's eagle3_target_model.py.
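The shift itself is small enough to sketch directly. The version below is an illustration under the assumption that the sequence axis is dimension 1 (shape `[B, T]` or `[B, T, ...]`). It is not copied from the library, and the function name is a stand-in.

```python
import torch


def shift_left_with_zero_sketch(tensor: torch.Tensor) -> torch.Tensor:
    """Shift along the sequence axis (dim 1) by one step and zero-fill the tail."""
    shifted = torch.zeros_like(tensor)
    # Position t now holds what was at position t + 1; the last slot stays zero.
    shifted[:, :-1] = tensor[:, 1:]
    return shifted


ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
shifted_ids = shift_left_with_zero_sketch(ids)
# shifted_ids is [[2, 3, 0], [5, 6, 0]]
```

Shifting this way aligns each position with the token that follows it, and the zero-filled tail marks the final position, which has no next token to supervise.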