nemo_automodel.components.speculative.eagle.draft_llama_v12


Llama-style dense LLM draft model for EAGLE-1 / EAGLE-2 training.

Config-driven: supports dense Llama, Phi-3, and Qwen3 models through the standard Hugging Face config fields (attention_bias, mlp_bias, rope_theta/rope_scaling, rms_norm_eps). The Llama-prefixed class names are kept so that the architectures entry recorded in existing checkpoint configs still resolves.

Module Contents

Classes

Name | Description
--- | ---
EagleLlamaAttention | Standard Llama-style self-attention for the EAGLE-1/2 draft.
EagleLlamaDecoderLayer | Single decoder layer for the minimal EAGLE-1/2 draft model.
EagleLlamaMLP | Standard SwiGLU MLP used by the EAGLE-1/2 draft.
LlamaEagleDraftModel | Llama-style dense draft that predicts next-step hidden states.

Functions

Name | Description
--- | ---
_build_causal_mask | Build a standard causal + padding mask for eager attention.

API

class nemo_automodel.components.speculative.eagle.draft_llama_v12.EagleLlamaAttention(
config: transformers.PretrainedConfig
)

Bases: Module

Standard Llama-style self-attention for the EAGLE-1/2 draft.

head_dim
k_proj
num_heads = config.num_attention_heads
num_key_value_groups = self.num_heads // self.num_key_value_heads
num_key_value_heads = config.num_key_value_heads
o_proj
q_proj
rotary_emb = LlamaRotaryEmbedding(config)
scaling = self.head_dim ** -0.5
v_proj
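The derived attributes above follow from a few config fields. The sketch below, assuming a Llama-3-8B-like geometry, shows how the grouped-query sizes and the softmax scaling relate. The function attention_geometry is hypothetical, and the head_dim fallback to hidden_size // num_attention_heads is an assumption borrowed from the usual Hugging Face convention, not something this reference states.

```python
from types import SimpleNamespace


def attention_geometry(config) -> dict:
    """Hypothetical helper: derive attention shape attributes from HF-style config fields."""
    num_heads = config.num_attention_heads
    num_key_value_heads = config.num_key_value_heads
    # Assumption: fall back to hidden_size // num_heads when no explicit head_dim is set.
    head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_heads
    return {
        "head_dim": head_dim,
        # Number of query heads that share one key/value head.
        "num_key_value_groups": num_heads // num_key_value_heads,
        # Multiplier applied to q @ k^T before the softmax.
        "scaling": head_dim ** -0.5,
        # Output widths of q_proj and of k_proj / v_proj.
        "q_out": num_heads * head_dim,
        "kv_out": num_key_value_heads * head_dim,
    }


# 4096 hidden units, 32 query heads, 8 KV heads: head_dim 128, 4 query heads per KV head.
cfg = SimpleNamespace(hidden_size=4096, num_attention_heads=32, num_key_value_heads=8)
geometry = attention_geometry(cfg)
print(geometry["head_dim"], geometry["num_key_value_groups"])  # 128 4
```

With these numbers k_proj and v_proj are a quarter of the width of q_proj, which is where grouped-query attention saves parameters and KV memory.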
nemo_automodel.components.speculative.eagle.draft_llama_v12.EagleLlamaAttention._repeat_kv(
tensor: torch.Tensor
) -> torch.Tensor
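_repeat_kv carries no docstring here. In Hugging Face Llama, the equivalent helper expands each key/value head so that every query head in its group sees the same K/V tensor. A minimal sketch, assuming this method follows that pattern with n_rep taken from self.num_key_value_groups; the standalone function repeat_kv and its n_rep parameter are hypothetical.

```python
import torch


def repeat_kv(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (batch, num_kv_heads, seq, head_dim) to (batch, num_kv_heads * n_rep, seq, head_dim)."""
    if n_rep == 1:
        return hidden
    batch, num_kv_heads, seq_len, head_dim = hidden.shape
    # Add a repeat axis right after the KV-head axis, broadcast it without copying,
    # then fold it into the head axis so copies of one KV head are adjacent.
    hidden = hidden[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


kv = torch.arange(2, dtype=torch.float32).view(1, 2, 1, 1)  # two KV heads holding 0.0 and 1.0
out = repeat_kv(kv, n_rep=3)
print(out.flatten().tolist())  # [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
```

Because copies of each KV head are adjacent, query heads 0..n_rep-1 share KV head 0, the next n_rep share KV head 1, and so on.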
nemo_automodel.components.speculative.eagle.draft_llama_v12.EagleLlamaAttention.forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: torch.Tensor
) -> torch.Tensor
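The attention forward pass is not documented beyond its signature. The sketch below shows eager grouped-query attention with the q_proj/k_proj/v_proj/o_proj layout and the head_dim ** -0.5 scaling listed above, under stated assumptions: the class TinyGQAAttention is hypothetical, the rotary embedding (rotary_emb, position_ids) is omitted for brevity, and attention_mask is taken to be additive (0 keeps a position, a large negative value drops it).

```python
import torch
import torch.nn as nn


class TinyGQAAttention(nn.Module):
    """Hypothetical eager grouped-query attention sketch (rotary embedding omitted)."""

    def __init__(self, hidden_size: int, num_heads: int, num_kv_heads: int, bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.num_key_value_heads = num_kv_heads
        self.num_key_value_groups = num_heads // num_kv_heads
        self.head_dim = hidden_size // num_heads
        self.scaling = self.head_dim ** -0.5
        # bias stands in for config.attention_bias.
        self.q_proj = nn.Linear(hidden_size, num_heads * self.head_dim, bias=bias)
        self.k_proj = nn.Linear(hidden_size, num_kv_heads * self.head_dim, bias=bias)
        self.v_proj = nn.Linear(hidden_size, num_kv_heads * self.head_dim, bias=bias)
        self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_size, bias=bias)

    def _split_heads(self, x: torch.Tensor, n_heads: int) -> torch.Tensor:
        # (batch, seq, n_heads * head_dim) -> (batch, n_heads, seq, head_dim)
        bsz, seq_len, _ = x.shape
        return x.view(bsz, seq_len, n_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        bsz, seq_len, _ = hidden_states.shape
        q = self._split_heads(self.q_proj(hidden_states), self.num_heads)
        k = self._split_heads(self.k_proj(hidden_states), self.num_key_value_heads)
        v = self._split_heads(self.v_proj(hidden_states), self.num_key_value_heads)
        # Each KV head serves num_key_value_groups consecutive query heads.
        k = k.repeat_interleave(self.num_key_value_groups, dim=1)
        v = v.repeat_interleave(self.num_key_value_groups, dim=1)
        # Scaled dot-product scores plus the additive mask; softmax in float32 for stability.
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scaling + attention_mask
        probs = torch.softmax(scores.float(), dim=-1).to(q.dtype)
        out = torch.matmul(probs, v).transpose(1, 2).reshape(bsz, seq_len, -1)
        return self.o_proj(out)


torch.manual_seed(0)
attn = TinyGQAAttention(hidden_size=16, num_heads=4, num_kv_heads=2)
x = torch.randn(2, 4, 16)
causal = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)  # broadcasts over batch and heads
y = attn(x, causal)
print(y.shape)  # torch.Size([2, 4, 16])
```

With a causal mask, perturbing the last token leaves the outputs at earlier positions unchanged, which is a quick way to sanity-check any mask passed to this kind of layer.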
class nemo_automodel.components.speculative.eagle.draft_llama_v12.EagleLlamaDecoderLayer(
config: transformers.PretrainedConfig
)

Bases: Module

Single decoder layer for the minimal EAGLE-1/2 draft model.

input_layernorm
mlp = EagleLlamaMLP(config)
post_attention_layernorm
self_attn = EagleLlamaAttention(config)
nemo_automodel.components.speculative.eagle.draft_llama_v12.EagleLlamaDecoderLayer.forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: torch.Tensor
) -> torch.Tensor
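The attribute names (input_layernorm, self_attn, post_attention_layernorm, mlp) match the standard pre-norm Llama decoder block. A minimal sketch, assuming the layer follows that layout: normalize, run the sublayer, add the residual, once for attention and once for the MLP. PreNormDecoderLayer and RMSNorm here are hypothetical stand-ins, the norm mirrors the Llama RMSNorm formula with eps playing the role of rms_norm_eps, and position_ids is dropped because the rotary embedding is omitted.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Llama-style RMSNorm: x / sqrt(mean(x^2) + eps), then a learned per-channel scale."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x32 = x.float()
        variance = x32.pow(2).mean(-1, keepdim=True)
        return self.weight * (x32 * torch.rsqrt(variance + self.eps)).to(x.dtype)


class PreNormDecoderLayer(nn.Module):
    """Hypothetical pre-norm decoder layer; self_attn and mlp are any compatible callables."""

    def __init__(self, hidden_size: int, self_attn, mlp, eps: float = 1e-6):
        super().__init__()
        self.input_layernorm = RMSNorm(hidden_size, eps)
        self.self_attn = self_attn
        self.post_attention_layernorm = RMSNorm(hidden_size, eps)
        self.mlp = mlp

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Attention sublayer: normalize, attend, add back the residual stream.
        residual = hidden_states
        hidden_states = residual + self.self_attn(self.input_layernorm(hidden_states), attention_mask)
        # MLP sublayer: same pattern with the second norm.
        residual = hidden_states
        return residual + self.mlp(self.post_attention_layernorm(hidden_states))


# Sublayers that output zeros leave the residual stream untouched.
layer = PreNormDecoderLayer(8, self_attn=lambda h, m: torch.zeros_like(h), mlp=torch.zeros_like)
x = torch.randn(2, 3, 8)
print(torch.equal(layer(x, None), x))  # True
```

Keeping the norms on the sublayer inputs (rather than on the residual stream) is what lets an untouched residual path carry information straight through the layer.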
class nemo_automodel.components.speculative.eagle.draft_llama_v12.EagleLlamaMLP(
config: transformers.PretrainedConfig
)

Bases: Module

Standard SwiGLU MLP used by the EAGLE-1/2 draft.

act_fn = ACT2FN[config.hidden_act]
down_proj
gate_proj
up_proj
nemo_automodel.components.speculative.eagle.draft_llama_v12.EagleLlamaMLP.forward(
hidden_states: torch.Tensor
) -> torch.Tensor
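A SwiGLU MLP gates one projection with an activated second projection before projecting back down. The sketch below, assuming the usual Llama layout implied by gate_proj, up_proj, and down_proj, uses torch.nn.functional.silu as a stand-in for ACT2FN[config.hidden_act] when hidden_act is "silu"; the class SwiGLUMLP and its bias flag (standing in for config.mlp_bias) are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUMLP(nn.Module):
    """Hypothetical SwiGLU sketch: down_proj(act(gate_proj(x)) * up_proj(x))."""

    def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
        # Stand-in for ACT2FN[config.hidden_act] when hidden_act == "silu".
        self.act_fn = F.silu

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Activated gate branch times the linear "up" branch, projected back to hidden_size.
        return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))


torch.manual_seed(0)
mlp = SwiGLUMLP(hidden_size=8, intermediate_size=32)
x = torch.randn(2, 3, 8)
print(mlp(x).shape)  # torch.Size([2, 3, 8])
```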
class nemo_automodel.components.speculative.eagle.draft_llama_v12.LlamaEagleDraftModel(
config: transformers.PretrainedConfig
)

Bases: PreTrainedModel

Llama-style dense draft that predicts next-step hidden states.

Works with Llama, Phi-3, and Qwen3 dense configs. The class name is retained for backward compatibility with already-trained checkpoints.

embed_tokens
fc
layers
main_input_name = 'input_ids'
norm
nemo_automodel.components.speculative.eagle.draft_llama_v12.LlamaEagleDraftModel.copy_embeddings_from_target(
target_embeddings: torch.nn.Embedding
) -> None

Copy the target model's token embeddings into the draft embeddings.

When the target is wrapped with FSDP2, its embed_tokens.weight is a DTensor sharded across ranks. The weight is therefore gathered into a full local tensor before it is copied into the (unsharded) draft parameter; otherwise aten.copy_ raises an error about mixing plain Tensor and DTensor arguments.
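The gather-then-copy step described above can be sketched as follows. The helper copy_embeddings is hypothetical; it uses a duck-typed check for DTensor.full_tensor() (which all-gathers the shards) so that it also runs on plain tensors without importing torch.distributed. Because full_tensor() is a collective, every rank would need to call it.

```python
import torch
import torch.nn as nn


def copy_embeddings(draft: nn.Embedding, target: nn.Embedding) -> None:
    """Hypothetical helper: copy target token embeddings into an unsharded draft embedding."""
    with torch.no_grad():
        weight = target.weight
        # Duck-typed DTensor check: under FSDP2 the weight is a DTensor and full_tensor()
        # all-gathers its shards. It is a collective, so every rank must reach this line.
        if hasattr(weight, "full_tensor"):
            weight = weight.full_tensor()
        if weight.shape != draft.weight.shape:
            raise ValueError(f"shape mismatch: {tuple(weight.shape)} vs {tuple(draft.weight.shape)}")
        # no_grad matters: an in-place copy into a leaf parameter that requires grad is an error.
        draft.weight.copy_(weight.to(device=draft.weight.device, dtype=draft.weight.dtype))


torch.manual_seed(0)
target = nn.Embedding(10, 4)
draft = nn.Embedding(10, 4)
copy_embeddings(draft, target)
print(torch.equal(draft.weight, target.weight))  # True
```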

nemo_automodel.components.speculative.eagle.draft_llama_v12.LlamaEagleDraftModel.forward(
input_ids: torch.Tensor,
target_hidden_states: torch.Tensor,
attention_mask: torch.Tensor
) -> torch.Tensor
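The forward signature takes both token ids and target hidden states, and the model has an fc layer. In the original EAGLE draft, fc fuses the token embedding with the target feature by concatenation followed by a 2 * hidden_size to hidden_size projection. A minimal sketch of that fusion step only, assuming this model does the same and that input_ids and target_hidden_states are already aligned position by position; EagleInputFusion is hypothetical, and the decoder layers, mask construction, and final norm are left out.

```python
import torch
import torch.nn as nn


class EagleInputFusion(nn.Module):
    """Hypothetical EAGLE-style input fusion: [token embedding ; target feature] -> hidden."""

    def __init__(self, vocab_size: int, hidden_size: int, bias: bool = False):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        # 2 * hidden -> hidden, assumed to match the draft's fc layer.
        self.fc = nn.Linear(2 * hidden_size, hidden_size, bias=bias)

    def forward(self, input_ids: torch.Tensor, target_hidden_states: torch.Tensor) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)  # (batch, seq, hidden)
        # Assumes both inputs are already aligned position by position.
        fused = torch.cat([inputs_embeds, target_hidden_states], dim=-1)  # (batch, seq, 2 * hidden)
        return self.fc(fused)  # (batch, seq, hidden), the input to the decoder layers


torch.manual_seed(0)
fusion = EagleInputFusion(vocab_size=100, hidden_size=8)
ids = torch.randint(0, 100, (2, 5))
feats = torch.randn(2, 5, 8)
print(fusion(ids, feats).shape)  # torch.Size([2, 5, 8])
```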
nemo_automodel.components.speculative.eagle.draft_llama_v12.LlamaEagleDraftModel.freeze_embeddings() -> None

Freeze draft token embeddings.
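Freezing typically means turning off gradients on the embedding table so the optimizer never updates it, which is a natural companion to copying the embeddings from the target. A minimal sketch, assuming that is what freeze_embeddings does; the helper and the toy ModuleDict model are hypothetical.

```python
import torch.nn as nn


def freeze_embeddings(embed_tokens: nn.Embedding) -> None:
    """Hypothetical helper: stop gradient updates to the token embedding table."""
    embed_tokens.weight.requires_grad_(False)


# Toy draft: an embedding table plus a trainable projection.
draft = nn.ModuleDict({"embed_tokens": nn.Embedding(10, 4), "fc": nn.Linear(8, 4, bias=False)})
freeze_embeddings(draft["embed_tokens"])
trainable = [name for name, p in draft.named_parameters() if p.requires_grad]
print(trainable)  # ['fc.weight']
```

An optimizer built only from parameters with requires_grad set then skips the embedding table entirely.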

nemo_automodel.components.speculative.eagle.draft_llama_v12._build_causal_mask(
attention_mask: torch.Tensor,
dtype: torch.dtype
) -> torch.Tensor

Build a standard causal + padding mask for eager attention.
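A causal + padding mask for eager attention combines two rules: a query may only attend to keys at or before its own position, and never to padded keys. The sketch below assumes the common Hugging Face convention of an additive (batch, 1, seq, seq) mask filled with torch.finfo(dtype).min at disallowed positions, which is why the helper takes a dtype; the function build_causal_mask is hypothetical and may differ from _build_causal_mask in detail.

```python
import torch


def build_causal_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical sketch: (batch, seq) 1/0 padding mask -> additive (batch, 1, seq, seq) mask."""
    bsz, seq_len = attention_mask.shape
    device = attention_mask.device
    # Query i may attend to keys 0..i (lower triangle, diagonal included).
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril()
    # Keys marked 0 in attention_mask are padding and are never attended to.
    key_is_real = attention_mask.bool()[:, None, None, :]  # (batch, 1, 1, seq)
    allowed = causal[None, None, :, :] & key_is_real       # (batch, 1, seq, seq)
    mask = torch.zeros(bsz, 1, seq_len, seq_len, dtype=dtype, device=device)
    # The most negative finite value (rather than -inf) keeps fully masked rows finite.
    return mask.masked_fill(~allowed, torch.finfo(dtype).min)


am = torch.tensor([[1, 1, 0]])  # last position is right padding
m = build_causal_mask(am, torch.float32)
print((m[0, 0] == 0).int().tolist())  # [[1, 0, 0], [1, 1, 0], [1, 1, 0]]
```

The result is added to the scaled attention scores before the softmax, so allowed positions are unaffected and disallowed ones receive effectively zero probability.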