nemo_automodel.components.speculative.eagle.draft_llama_v12

Minimal Llama-based draft model for EAGLE-1 / EAGLE-2 training.

Module Contents

Classes

EagleLlamaAttention

Standard Llama-style self-attention for the EAGLE-1/2 draft.

EagleLlamaMLP

Standard SwiGLU MLP used by the EAGLE-1/2 draft.

EagleLlamaDecoderLayer

Single decoder layer for the minimal EAGLE-1/2 draft model.

LlamaEagleDraftModel

Minimal Llama draft that predicts next-step hidden states.

Functions

_build_causal_mask

Build a standard causal + padding mask for eager attention.

API

nemo_automodel.components.speculative.eagle.draft_llama_v12._build_causal_mask(
    attention_mask: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor

Build a standard causal + padding mask for eager attention.
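
As an illustration of what a causal + padding mask for eager attention typically looks like, here is a minimal sketch. It is not the library's implementation: the function name `build_causal_padding_mask`, the additive-mask convention, and the `(batch, 1, seq, seq)` output shape are assumptions, chosen because they match common Llama-style eager attention code.

```python
import torch


def build_causal_padding_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical helper: additive mask of shape (batch, 1, seq, seq).

    attention_mask is assumed to be (batch, seq) with 1 for real tokens and 0 for padding.
    """
    batch, seq_len = attention_mask.shape
    device = attention_mask.device
    min_value = torch.finfo(dtype).min

    # Lower-triangular matrix: query position i may attend to key positions j <= i.
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))

    # Broadcast the padding mask over the query dimension: padded keys are never attended to.
    key_is_token = attention_mask.to(torch.bool)[:, None, None, :]

    allowed = causal[None, None, :, :] & key_is_token

    # 0 where attention is allowed, a large negative value where it is blocked.
    mask = torch.zeros(batch, 1, seq_len, seq_len, dtype=dtype, device=device)
    return mask.masked_fill(~allowed, min_value)
```

The result is added to the raw attention scores before the softmax, so blocked positions receive effectively zero weight.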

class nemo_automodel.components.speculative.eagle.draft_llama_v12.EagleLlamaAttention(config: transformers.LlamaConfig)

Bases: torch.nn.Module

Standard Llama-style self-attention for the EAGLE-1/2 draft.

_repeat_kv(tensor: torch.Tensor) -> torch.Tensor

forward(
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
    position_ids: torch.Tensor,
) -> torch.Tensor
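
In Llama-style attention, a helper named like `_repeat_kv` usually expands the key/value heads so that they line up with the query heads under grouped-query attention. The sketch below assumes that is its role here. The standalone function `repeat_kv` and its explicit `n_rep` argument are hypothetical; the documented method takes only `tensor`, so it presumably derives the repetition factor from the config.

```python
import torch


def repeat_kv(tensor: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Hypothetical helper: repeat each key/value head n_rep times.

    Input:  (batch, num_kv_heads, seq, head_dim)
    Output: (batch, num_kv_heads * n_rep, seq, head_dim)
    """
    if n_rep == 1:
        return tensor
    batch, num_kv_heads, seq_len, head_dim = tensor.shape
    # Insert a repeat axis after the head axis, broadcast along it, then fold it into the heads.
    expanded = tensor[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
```

With this layout, consecutive groups of `n_rep` output heads share the same underlying key/value head.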

class nemo_automodel.components.speculative.eagle.draft_llama_v12.EagleLlamaMLP(config: transformers.LlamaConfig)

Bases: torch.nn.Module

Standard SwiGLU MLP used by the EAGLE-1/2 draft.

forward(hidden_states: torch.Tensor) -> torch.Tensor
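
For readers unfamiliar with SwiGLU, the following is a minimal sketch of the general technique, not the source of `EagleLlamaMLP`. The class name `SwiGLUMLP`, the projection names (borrowed from the usual Llama convention), and the bias-free linear layers are assumptions.

```python
import torch
import torch.nn.functional as F
from torch import nn


class SwiGLUMLP(nn.Module):
    """Hypothetical SwiGLU block: down(silu(gate(x)) * up(x))."""

    def __init__(self, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The SiLU-activated gate modulates the linear "up" branch elementwise.
        gated = F.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)
        # Project back to the model width so the output can be added to the residual stream.
        return self.down_proj(gated)
```

The output has the same shape as the input, which is what lets a decoder layer add it back to the residual stream.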

class nemo_automodel.components.speculative.eagle.draft_llama_v12.EagleLlamaDecoderLayer(config: transformers.LlamaConfig)

Bases: torch.nn.Module

Single decoder layer for the minimal EAGLE-1/2 draft model.

forward(
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
    position_ids: torch.Tensor,
) -> torch.Tensor

class nemo_automodel.components.speculative.eagle.draft_llama_v12.LlamaEagleDraftModel(config: transformers.LlamaConfig)

Bases: transformers.PreTrainedModel

Minimal Llama draft that predicts next-step hidden states.

config_class

None

main_input_name

'input_ids'

copy_embeddings_from_target(
    target_embeddings: torch.nn.Embedding,
) -> None

Copy the target model token embeddings into the draft embeddings.

freeze_embeddings() -> None

Freeze draft token embeddings.
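
The two embedding helpers above can be illustrated with plain `torch.nn.Embedding` modules. This is a sketch of the general pattern only: the free functions `copy_embeddings` and `freeze` are hypothetical stand-ins, and the shape check is an assumption about what a careful implementation would do, not a documented behavior of `LlamaEagleDraftModel`.

```python
import torch
from torch import nn


def copy_embeddings(draft_embeddings: nn.Embedding, target_embeddings: nn.Embedding) -> None:
    """Hypothetical helper: overwrite the draft embedding weights with the target's."""
    if draft_embeddings.weight.shape != target_embeddings.weight.shape:
        raise ValueError("draft and target embeddings must have the same shape")
    # Copy in place without recording the operation in the autograd graph.
    with torch.no_grad():
        draft_embeddings.weight.copy_(target_embeddings.weight)


def freeze(embeddings: nn.Embedding) -> None:
    """Hypothetical helper: exclude the embedding weights from gradient updates."""
    embeddings.weight.requires_grad_(False)


# Usage: initialize the draft from the target, then keep the embeddings fixed.
target = nn.Embedding(10, 4)
draft = nn.Embedding(10, 4)
copy_embeddings(draft, target)
freeze(draft)
```

Because the weights are copied rather than shared, later changes to the target embeddings do not propagate to the draft.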

forward(
    input_ids: torch.Tensor,
    target_hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
) -> torch.Tensor