nemo_automodel.components.speculative.eagle.target#
Target-model wrapper for minimal EAGLE-3 training.
Module Contents#
Classes#
Eagle3TargetBatch | Target-model outputs needed by the draft trainer. |
HFEagle3TargetModel | Thin wrapper that captures three auxiliary hidden states from a causal LM. |
Functions#
_shift_left_with_zero | Shift a batched sequence tensor left and zero-fill the tail. |
API#
- nemo_automodel.components.speculative.eagle.target._shift_left_with_zero(tensor: torch.Tensor) → torch.Tensor#
Shift a batched sequence tensor left and zero-fill the tail.
This matches the reference EAGLE-3 target preparation used by SpecForge, where sequence-aligned tensors are shifted with
padding(..., left=False). See SpecForge's eagle3_target_model.py, around the target preparation logic.
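The shift described above can be illustrated with a minimal, dependency-free sketch. It uses nested Python lists in place of a 2-D torch.Tensor of shape (batch, seq_len), which is an assumption made only to keep the example runnable without PyTorch. The name shift_left_with_zero is hypothetical and is not the library's implementation.

```python
def shift_left_with_zero(batch):
    """Shift every sequence one position left and zero-fill the last slot.

    `batch` is a list of equal-length lists standing in for a 2-D tensor of
    shape (batch, seq_len). The input is not modified.
    """
    # Drop the first element of each row and append a zero at the tail,
    # so the output keeps the same sequence length as the input.
    return [(row[1:] + [0]) if row else [] for row in batch]


input_ids = [
    [11, 12, 13, 14],
    [21, 22, 23, 24],
]
shifted = shift_left_with_zero(input_ids)
print(shifted)  # [[12, 13, 14, 0], [22, 23, 24, 0]]
```

After the shift, position t holds the value that was at position t + 1, which is the alignment a next-token target needs. The zero in the final slot marks a position with no successor.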
- class nemo_automodel.components.speculative.eagle.target.Eagle3TargetBatch#
Target-model outputs needed by the draft trainer.
- logits: torch.Tensor#
- input_ids: torch.Tensor#
- attention_mask: torch.Tensor#
- loss_mask: torch.Tensor#
- class nemo_automodel.components.speculative.eagle.target.HFEagle3TargetModel(model: torch.nn.Module, aux_layer_ids: Sequence[int] | None = None)#
Thin wrapper that captures three auxiliary hidden states from a causal LM.
Initialization
- _default_aux_layer_ids() → list[int]#
- _validate_aux_layer_ids(aux_layer_ids: Sequence[int])#
Validate aux-layer selection before any forward hooks are registered.
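As a rough illustration of what validating an aux-layer selection might involve, the sketch below checks a candidate list before it is used. The specific rules (exactly three entries, integer indices, within range, no duplicates) and the num_layers parameter are assumptions for this example. The documented signature shows only aux_layer_ids, and the checks the library actually performs may differ.

```python
from collections.abc import Sequence


def validate_aux_layer_ids(aux_layer_ids: Sequence[int], num_layers: int) -> list[int]:
    """Hypothetical validator: return the ids as a list or raise on bad input."""
    ids = list(aux_layer_ids)

    # Assumption: the wrapper captures exactly three auxiliary hidden states.
    if len(ids) != 3:
        raise ValueError(f"expected 3 aux layer ids, got {len(ids)}")

    for layer_id in ids:
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(layer_id, bool) or not isinstance(layer_id, int):
            raise TypeError(f"aux layer id must be an int, got {layer_id!r}")
        if not 0 <= layer_id < num_layers:
            raise ValueError(
                f"aux layer id {layer_id} is outside [0, {num_layers - 1}]"
            )

    # Assumption: capturing the same layer twice is treated as a mistake.
    if len(set(ids)) != len(ids):
        raise ValueError(f"aux layer ids must be unique, got {ids}")

    return ids


print(validate_aux_layer_ids([1, 15, 28], num_layers=32))  # [1, 15, 28]
```

Failing fast here means a bad selection is reported before any hook is attached to the model, so no partially registered hooks are left behind to clean up.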
- _get_transformer_layers() → Iterable[torch.nn.Module]#
- get_input_embeddings() → torch.nn.Embedding#
Return the target model input embeddings.
- generate_batch(input_ids: torch.Tensor, attention_mask: torch.Tensor, loss_mask: torch.Tensor)#
Run the target model and capture aux hidden states plus logits.