nemo_automodel.components.speculative.eagle.target_v12

Target-model wrapper for EAGLE-1 / EAGLE-2 training.

Module Contents

Classes

| Name | Description |
| --- | --- |
| EagleTargetBatch | Target-model outputs needed by the EAGLE-1 / EAGLE-2 trainer. |
| HFEagleTargetModel | Thin wrapper that exposes hidden-state supervision from a causal LM. |

Functions

| Name | Description |
| --- | --- |
| _shift_left_with_zero | Shift a batched sequence tensor left and zero-fill the tail. |

API

class nemo_automodel.components.speculative.eagle.target_v12.EagleTargetBatch(
input_hidden_states: torch.Tensor,
target_hidden_states: torch.Tensor,
target_logits: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
loss_mask: torch.Tensor
)
Dataclass

Target-model outputs needed by the EAGLE-1 / EAGLE-2 trainer.

attention_mask: Tensor
input_hidden_states: Tensor
input_ids: Tensor
loss_mask: Tensor
target_hidden_states: Tensor
target_logits: Tensor
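
To make the field layout concrete, here is a minimal sketch of a container with the same six field names. `EagleTargetBatchSketch` is a hypothetical stand-in, not the library class, and the shapes used (`[batch, seq_len, hidden]` for hidden states, `[batch, seq_len, vocab]` for logits, `[batch, seq_len]` for ids and masks) are assumptions that this page does not confirm.

```python
from dataclasses import dataclass

import torch


@dataclass
class EagleTargetBatchSketch:
    """Hypothetical stand-in that mirrors the documented field names."""

    input_hidden_states: torch.Tensor
    target_hidden_states: torch.Tensor
    target_logits: torch.Tensor
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    loss_mask: torch.Tensor


# Assumed sizes, chosen only for illustration.
batch, seq_len, hidden, vocab = 2, 5, 8, 11

example = EagleTargetBatchSketch(
    input_hidden_states=torch.zeros(batch, seq_len, hidden),
    target_hidden_states=torch.zeros(batch, seq_len, hidden),
    target_logits=torch.zeros(batch, seq_len, vocab),
    input_ids=torch.zeros(batch, seq_len, dtype=torch.long),
    attention_mask=torch.ones(batch, seq_len, dtype=torch.long),
    loss_mask=torch.ones(batch, seq_len, dtype=torch.long),
)
```
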
class nemo_automodel.components.speculative.eagle.target_v12.HFEagleTargetModel(
model: torch.nn.Module
)

Thin wrapper that exposes hidden-state supervision from a causal LM.

model = model.eval()
nemo_automodel.components.speculative.eagle.target_v12.HFEagleTargetModel.generate_batch(
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
loss_mask: torch.Tensor
) -> nemo_automodel.components.speculative.eagle.target_v12.EagleTargetBatch

Run the target transformer and prepare shifted supervision tensors.
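
As a rough illustration of what "shifted supervision tensors" could look like, the sketch below runs a toy model without gradients and shifts the next-position tensors left by one step. Everything here is an assumption made for illustration: `ToyCausalLM`, `shift_left`, and `generate_batch_sketch` are hypothetical names, the toy model ignores `attention_mask`, and the choice of which tensors are shifted may differ from the actual implementation.

```python
import torch
from torch import nn


def shift_left(t: torch.Tensor) -> torch.Tensor:
    """Shift along dim 1 (assumed to be the sequence axis); zero-fill the tail."""
    out = torch.zeros_like(t)
    out[:, :-1] = t[:, 1:]
    return out


class ToyCausalLM(nn.Module):
    """Hypothetical stand-in for a causal LM; not a real transformer."""

    def __init__(self, vocab_size: int = 11, hidden_size: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.body = nn.Linear(hidden_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor):
        hidden = torch.tanh(self.body(self.embed(input_ids)))
        return hidden, self.lm_head(hidden)


@torch.no_grad()
def generate_batch_sketch(model, input_ids, attention_mask, loss_mask):
    # Assumption: position t is supervised with the target outputs at t + 1.
    hidden, logits = model(input_ids)
    return {
        "input_hidden_states": hidden,
        "target_hidden_states": shift_left(hidden),
        "target_logits": shift_left(logits),
        "input_ids": shift_left(input_ids),
        "attention_mask": attention_mask,  # passed through unchanged here
        "loss_mask": shift_left(loss_mask),
    }


torch.manual_seed(0)
model = ToyCausalLM().eval()
ids = torch.tensor([[1, 2, 3, 4]])
mask = torch.ones_like(ids)
out = generate_batch_sketch(model, ids, mask, mask)
```

With this convention the last position carries no valid target, which is why the shifted `loss_mask` ends in zero.
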

nemo_automodel.components.speculative.eagle.target_v12.HFEagleTargetModel.get_input_embeddings() -> torch.nn.Embedding

Return the target model input embeddings.

nemo_automodel.components.speculative.eagle.target_v12.HFEagleTargetModel.get_lm_head() -> torch.nn.Module

Return the target model lm_head.
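
The two accessors can be pictured as simple delegation to the wrapped model. The sketch below is a hypothetical illustration, not the library code: `ToyCausalLM` and `TargetWrapperSketch` are invented names, and it assumes the wrapped model provides a Hugging Face style `get_input_embeddings()` method and an `lm_head` attribute.

```python
import torch
from torch import nn


class ToyCausalLM(nn.Module):
    """Hypothetical model exposing HF-style accessors (assumed interface)."""

    def __init__(self, vocab_size: int = 11, hidden_size: int = 8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embed_tokens


class TargetWrapperSketch:
    """Hypothetical wrapper; keeps the target in eval mode as documented."""

    def __init__(self, model: nn.Module):
        self.model = model.eval()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.get_input_embeddings()

    def get_lm_head(self) -> nn.Module:
        # Assumption: the output projection is stored as `lm_head`.
        return self.model.lm_head


wrapper = TargetWrapperSketch(ToyCausalLM())
embeddings = wrapper.get_input_embeddings()
lm_head = wrapper.get_lm_head()
```
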

nemo_automodel.components.speculative.eagle.target_v12._shift_left_with_zero(
tensor: torch.Tensor
) -> torch.Tensor

Shift a batched sequence tensor left and zero-fill the tail.
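
A minimal sketch of the shift described above, assuming the sequence axis is dimension 1 of a `[batch, seq_len, ...]` tensor. `shift_left_with_zero_sketch` is a hypothetical re-implementation written for illustration and may not match the private helper line for line.

```python
import torch


def shift_left_with_zero_sketch(tensor: torch.Tensor) -> torch.Tensor:
    """Move every sequence position one step left; the last position becomes zero."""
    shifted = torch.zeros_like(tensor)
    # Assumption: dim 0 is the batch axis and dim 1 is the sequence axis.
    shifted[:, :-1] = tensor[:, 1:]
    return shifted


x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
y = shift_left_with_zero_sketch(x)
# y is [[2, 3, 4, 0], [6, 7, 8, 0]]
```

The same function works for 3-D tensors such as hidden states, because the slice only touches the first two dimensions.
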