nemo_automodel.components.speculative.eagle.target_v12#

Target-model wrapper for EAGLE-1 / EAGLE-2 training.

Module Contents#

Classes#

EagleTargetBatch

Target-model outputs needed by the EAGLE-1 / EAGLE-2 trainer.

HFEagleTargetModel

Thin wrapper that exposes hidden-state supervision from a causal LM.

Functions#

_shift_left_with_zero

Shift a batched sequence tensor left and zero-fill the tail.

API#

nemo_automodel.components.speculative.eagle.target_v12._shift_left_with_zero(tensor: torch.Tensor) → torch.Tensor#

Shift a batched sequence tensor left and zero-fill the tail.
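As an illustration of the described behavior, here is a minimal sketch. It assumes the sequence axis is dimension 1 (batch first); the name `shift_left_with_zero` and its body are a hypothetical re-implementation, not the library's source.

```python
import torch


def shift_left_with_zero(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: drop step 0 along dim 1 and append one zero step."""
    # Assumption: dim 0 is batch, dim 1 is sequence; trailing dims are kept as-is.
    tail = torch.zeros_like(tensor[:, :1])
    return torch.cat([tensor[:, 1:], tail], dim=1)


x = torch.tensor([[1, 2, 3], [4, 5, 6]])
shifted = shift_left_with_zero(x)
print(shifted.tolist())
```

After the shift, position `t` holds what was at position `t + 1`, and the last position is zero. The same sketch works for a 3-D tensor such as `(batch, seq, hidden)`.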

class nemo_automodel.components.speculative.eagle.target_v12.EagleTargetBatch#

Target-model outputs needed by the EAGLE-1 / EAGLE-2 trainer.

input_hidden_states: torch.Tensor#

target_hidden_states: torch.Tensor#

target_logits: torch.Tensor#

input_ids: torch.Tensor#

attention_mask: torch.Tensor#

loss_mask: torch.Tensor#
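To make the field layout concrete, the sketch below mirrors the six fields in a stand-in dataclass. `TargetBatchSketch` is a hypothetical name, and the tensor shapes (`(batch, seq, hidden)` for hidden states, `(batch, seq, vocab)` for logits, `(batch, seq)` for ids and masks) are assumptions that this page does not state.

```python
from dataclasses import dataclass

import torch


@dataclass
class TargetBatchSketch:
    """Hypothetical stand-in mirroring the EagleTargetBatch field names."""

    input_hidden_states: torch.Tensor
    target_hidden_states: torch.Tensor
    target_logits: torch.Tensor
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    loss_mask: torch.Tensor


# Arbitrary sizes chosen for illustration only.
batch_size, seq_len, hidden_size, vocab_size = 2, 4, 8, 16

batch = TargetBatchSketch(
    input_hidden_states=torch.randn(batch_size, seq_len, hidden_size),
    target_hidden_states=torch.randn(batch_size, seq_len, hidden_size),
    target_logits=torch.randn(batch_size, seq_len, vocab_size),
    input_ids=torch.randint(0, vocab_size, (batch_size, seq_len)),
    attention_mask=torch.ones(batch_size, seq_len, dtype=torch.long),
    loss_mask=torch.tensor([[0, 1, 1, 0], [1, 1, 1, 1]]),
)

# Number of positions that a trainer would supervise under this mask.
supervised_positions = int(batch.loss_mask.sum())
```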

class nemo_automodel.components.speculative.eagle.target_v12.HFEagleTargetModel(model: torch.nn.Module)#

Thin wrapper that exposes hidden-state supervision from a causal LM.

get_input_embeddings() → torch.nn.Embedding#

Return the target model input embeddings.

get_lm_head() → torch.nn.Module#

Return the target model lm_head.
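The two accessors can be pictured as simple delegation to the wrapped model. The sketch below is an assumption about how such a wrapper might look: `ToyLM` and `TargetWrapperSketch` are hypothetical classes, and reading the head from an `lm_head` attribute is a guess based on the method name, not confirmed library behavior.

```python
import torch
from torch import nn


class ToyLM(nn.Module):
    """Hypothetical stand-in for a causal LM; not a real pretrained model."""

    def __init__(self, vocab_size: int = 16, hidden_size: int = 8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embed_tokens


class TargetWrapperSketch:
    """Hypothetical wrapper that forwards accessor calls to the wrapped model."""

    def __init__(self, model: nn.Module):
        self.model = model

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.get_input_embeddings()

    def get_lm_head(self) -> nn.Module:
        # Assumption: the wrapped model exposes its output head as `lm_head`.
        return self.model.lm_head


wrapper = TargetWrapperSketch(ToyLM())
embed = wrapper.get_input_embeddings()
head = wrapper.get_lm_head()

token_ids = torch.tensor([[1, 2, 3]])
logits = head(embed(token_ids))  # embeddings fed straight into the head
```

Exposing both modules lets a draft model reuse the target's embedding table and output head instead of training its own copies.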

generate_batch(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    loss_mask: torch.Tensor,
) → nemo_automodel.components.speculative.eagle.target_v12.EagleTargetBatch#

Run the target transformer and prepare shifted supervision tensors.
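A rough sketch of what "run the target and prepare shifted supervision" could look like follows. Everything here is an assumption for illustration: `ToyCausalLM` stands in for the real model (and ignores the attention mask), `generate_batch_sketch` returns a plain dict rather than `EagleTargetBatch`, and the choice of which tensors are shifted is a guess, not the library's documented behavior.

```python
import torch
from torch import nn


class ToyCausalLM(nn.Module):
    """Hypothetical stand-in model returning (hidden_states, logits)."""

    def __init__(self, vocab_size: int = 16, hidden_size: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.body = nn.Linear(hidden_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor):
        hidden = torch.tanh(self.body(self.embed(input_ids)))
        return hidden, self.lm_head(hidden)


def shift_left(t: torch.Tensor) -> torch.Tensor:
    # Drop step 0 along the sequence axis (dim 1) and zero-fill the tail.
    return torch.cat([t[:, 1:], torch.zeros_like(t[:, :1])], dim=1)


@torch.no_grad()
def generate_batch_sketch(model, input_ids, attention_mask, loss_mask):
    """Hypothetical: pair each position's hidden state with next-step targets."""
    hidden, logits = model(input_ids)
    return {
        "input_hidden_states": hidden,
        # Assumption: targets at position t come from position t + 1.
        "target_hidden_states": shift_left(hidden),
        "target_logits": shift_left(logits),
        "input_ids": shift_left(input_ids),
        "attention_mask": attention_mask,
        "loss_mask": shift_left(loss_mask),
    }


torch.manual_seed(0)
ids = torch.tensor([[1, 2, 3, 4]])
out = generate_batch_sketch(
    ToyCausalLM(), ids, torch.ones_like(ids), torch.tensor([[0, 1, 1, 1]])
)
```

Under these assumptions the final sequence position has no next-step target, so its supervision tensors are zero and its shifted loss mask entry is 0.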