nemo_automodel.components.speculative.eagle.draft_llama

Llama-style dense LLM draft model for EAGLE-3 / EAGLE-3.1 training.

The implementation is config-driven and supports any HuggingFace dense decoder-only architecture whose layout matches Llama: GQA attention with optional Q/K/V/O bias (config.attention_bias), SwiGLU MLP with optional bias (config.mlp_bias), RMSNorm, and rotary position embeddings parameterized by config.rope_theta / config.rope_scaling. This currently covers Llama, Phi-3, and dense Qwen3. Phi-3 configs omit attention_bias / mlp_bias, which the attention and MLP layers already read via getattr(config, "<field>", False). Qwen3 decouples head_dim from hidden_size / num_attention_heads, which the attention layer reads via getattr(config, "head_dim", ...).

Class names and the public architectures string remain LlamaEagle3* for backward compatibility with already-trained checkpoints and with SGLang’s LlamaForCausalLMEagle3.load_weights (the saved state dict layout is unchanged):

  model.embed_tokens.weight
  model.fc.weight
  model.layers.0.input_layernorm.weight
  model.layers.0.hidden_norm.weight
  model.layers.0.post_attention_layernorm.weight
  model.layers.0.self_attn.{q,k,v,o}_proj.weight
  model.layers.0.mlp.{gate,up,down}_proj.weight
  model.norm.weight
  lm_head.weight

SGLang merges q_proj/k_proj/v_proj into a single qkv_proj and gate_proj/up_proj into gate_up_proj via its stacked_params_mapping at load time, so the un-fused storage above is the canonical on-disk format.
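To illustrate that load-time fusion, the sketch below rewrites un-fused on-disk keys into fused runtime names. The mapping table is modeled on the (param_name, shard_name, shard_id) tuples used by SGLang's Llama loaders, but treat it as an assumption: the authoritative table and the shard-copy logic live in SGLang, and resolve_fused_name is a hypothetical helper.

```python
# Hypothetical sketch; the table shape is modeled on SGLang's stacked_params_mapping.
STACKED_PARAMS_MAPPING = [
    # (fused runtime param, on-disk shard name, shard id)
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]


def resolve_fused_name(key: str):
    """Return (runtime_param_name, shard_id) for an on-disk state-dict key.

    Keys that are not part of a fused projection (fc, norms, o_proj,
    down_proj, embeddings, lm_head) map to themselves with shard_id None.
    """
    for fused, shard, shard_id in STACKED_PARAMS_MAPPING:
        if shard in key:
            return key.replace(shard, fused), shard_id
    return key, None


print(resolve_fused_name("model.layers.0.self_attn.k_proj.weight"))
```

Because only the runtime loader performs the fusion, the un-fused keys remain the single format that both training and inference agree on.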

EAGLE-3.1 introduces two optional drafter-side toggles that together address the “attention drift” failure mode observed when speculation depth grows:

  • config.fc_norm (bool, default False) — when True, an nn.ModuleList of num_aux_hidden_states independent RMSNorms (each of size target_hidden_size) is applied per chunk before the concatenated auxiliary hidden states enter model.fc. The on-disk keys are model.fc_norm.0.weight, model.fc_norm.1.weight, …; the module layout matches vLLM’s EAGLE-3.1 integration in PR https://github.com/vllm-project/vllm/pull/42764 so checkpoints trained here load directly into vLLM / SGLang.
  • config.norm_output (bool, default False) — when True, the existing final RMSNorm (model.norm) is applied to the per-step hidden state returned by forward so that the next TTT step (and the lm_head) consume the post-norm state instead of the raw decoder output. Adds no new parameters.

Both flags default to False so EAGLE-3 checkpoints continue to load and behave identically. Enabling them applies the EAGLE-3.1 drafter toggles to the Llama-style draft used here; the MLA-backbone Kimi K2.6 draft (Eagle3DeepseekV2ForCausalLM in lightseekorg/kimi-k2.6-eagle3.1-mla) is a separate architecture and is not covered by this module.

P-EAGLE (parallel-drafting EAGLE-3) adds one further optional toggle:

  • config.parallel_drafting (bool, default False) — when True, the draft registers a single learnable mask_hidden placeholder of shape [1, 1, num_aux_hidden_states * target_hidden_size] (the pre-fc concatenated-aux dimension) and exposes LlamaEagle3DraftModel.forward_peagle, a single parallel forward over a flat, COD-subsampled sequence with a flex_attention cross-depth mask (see peagle_attention.py / peagle_data.py). At every masked depth (>= 1) the trainer feeds the mask_hidden placeholder (projected through the same project_hidden_states path as real aux states) together with the masked token config.mask_token_id, so the draft predicts all config.num_depths tokens in one forward instead of autoregressively. The shape, the on-disk key mask_hidden, and the COD config (num_depths / down_sample_ratio / mask_token_id) mirror speculators (https://github.com/vllm-project/speculators/pull/480), so the checkpoint loads into vLLM’s parallel-drafting runtime unchanged. The masked token slot reuses embed_tokens[config.mask_token_id]. SGLang does not serve a P-EAGLE head today (https://github.com/sgl-project/sglang/issues/23171). The flag only ever adds the mask_hidden key, so EAGLE-3 / EAGLE-3.1 checkpoints round-trip unchanged.
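A minimal sketch of the per-depth input slots described in the bullet, assuming a single anchor position and ignoring COD subsampling and the flex_attention mask (both live in peagle_data.py / peagle_attention.py and are not reproduced here). peagle_depth_inputs and mask_hidden_shape are hypothetical helpers; the width arithmetic follows the stated shape [1, 1, num_aux_hidden_states * target_hidden_size].

```python
def mask_hidden_shape(num_aux_hidden_states: int, target_hidden_size: int):
    # Pre-fc concatenated-aux width, as stated for the mask_hidden parameter.
    return (1, 1, num_aux_hidden_states * target_hidden_size)


def peagle_depth_inputs(token_id: int, num_depths: int, mask_token_id: int):
    """Hypothetical: (token_id, hidden_source) for depths 0..num_depths-1.

    Depth 0 uses the real token and the real aux hidden states. Every masked
    depth (>= 1) uses mask_token_id (embedded via embed_tokens) and the shared
    mask_hidden placeholder; both hidden sources then go through
    project_hidden_states.
    """
    slots = [(token_id, "aux_hidden_states")]
    for _ in range(1, num_depths):
        slots.append((mask_token_id, "mask_hidden"))
    return slots


print(mask_hidden_shape(3, 4096))  # (1, 1, 12288)
print(peagle_depth_inputs(42, num_depths=4, mask_token_id=99))
```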

Module Contents

Classes

Name | Description
Eagle3LlamaAttention | EAGLE-3 draft attention over [input_emb, hidden] 2H features.
Eagle3LlamaDecoderLayer | Single decoder layer used by the minimal EAGLE-3 draft model.
Eagle3LlamaMLP | Standard Llama-style SwiGLU MLP on hidden-size activations.
Eagle3LlamaModel | Inner backbone matching SGLang’s LlamaModel in llama_eagle3.py.
Eagle3LlamaPeagleLayer | Vanilla Llama decoder layer for P-EAGLE depths >= 1.
LlamaEagle3DraftModel | Llama-style dense EAGLE-3 draft model (Llama, Phi-3, Qwen3).

Functions

Name | Description
_build_causal_mask | Build a standard causal + padding mask for SDPA/eager attention.
_is_right_padded_attention_mask | Return True when each row is a contiguous valid-prefix followed by padding.
_load_flash_attn_func | Best-effort load of flash-attn without breaking eager-only users.
_seq_lens_to_cu_seqlens | Build FlashAttention varlen cu_seqlens (int32) from packed seq_lens.

Data

_SUPPORTED_ATTN_IMPLEMENTATIONS

logger

API

class nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaAttention(
config: transformers.PretrainedConfig,
fuse_input: bool = True
)

Bases: _PeagleAttentionMixin, Module

EAGLE-3 draft attention over [input_emb, hidden] 2H features.

Driven through a shared cache_hidden = [K_list, V_list] pair. At step k (0-indexed), with K_list and V_list already holding entries from steps 0..k-1:

  1. step_idx = len(K_list) (equal to k) gives the rotary phase shift, so the draft’s K_k encodes “this is k tokens into the future”. The shifted cos / sin are computed from position_ids + step_idx.

  2. The freshly projected K, V (after GQA expansion) are appended to the cache lists in place.

  3. The attention output is the EAGLE-3 mixed pattern:

    attn_weights = [ Q @ K_0^T / sqrt(d) + mask ] || diag_1 || ... || diag_k

    where diag_i[t] = (Q_t * K_i_t).sum(-1) / sqrt(d). The softmax is taken over the full extended column axis of length T + k. Output is

    out = attn_probs[..., :T] @ V_0 + sum_{i=1..k} attn_probs[..., T+i-1, None] * V_i.

    Informally, Q at position t attends causally to K_0 positions 0..t (the regular T x T causal block) and, additionally, to the same position t in each draft step i = 1..k, including the current one. Implementation-wise, the two O(k^2) cat / add Python loops in SpecForge’s llama3_eagle.py are replaced with single vectorized einsum calls.

cache_hidden is mutated in place; callers are responsible for re-initializing it to [[], []] at the start of each training batch.
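To make the mixed pattern concrete, here is a pure-Python, single-head, single-sequence sketch. It assumes no padding (only the causal part of the mask is modeled) and no rotary embedding, and it is a reference for the math only, not the vectorized einsum implementation. Masked causal columns are simply skipped, which is equivalent to adding a -inf mask before the softmax.

```python
import math


def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def eagle3_mixed_attention(q, cache_k, cache_v):
    """Single-head sketch of the EAGLE-3 mixed attention pattern.

    q: T x d query rows for the current step.
    cache_k / cache_v: lists [K_0, ..., K_k] / [V_0, ..., V_k], each T x d.
    Row t attends causally to K_0[0..t] plus K_i[t] for every i >= 1.
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for t in range(len(q)):
        # Block 1: causal scores against K_0 (columns j > t are masked out).
        scores = [_dot(q[t], cache_k[0][j]) * scale for j in range(t + 1)]
        values = [cache_v[0][j] for j in range(t + 1)]
        # Block 2: one diagonal score per cached step i >= 1, same position t.
        for k_i, v_i in zip(cache_k[1:], cache_v[1:]):
            scores.append(_dot(q[t], k_i[t]) * scale)
            values.append(v_i[t])
        # Softmax over the extended column axis (max-subtracted for stability).
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.append([sum(w * v[c] for w, v in zip(weights, values)) / z
                    for c in range(d)])
    return out
```

With only K_0 cached (step 0) the function reduces to plain causal attention, which matches the note on forward that step 0 of TTT is a plain causal forward.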

head_dim
k_proj
num_heads = config.num_attention_heads
num_key_value_groups = self.num_heads // self.num_key_value_heads
num_key_value_heads = config.num_key_value_heads
o_proj
q_proj
rotary_emb = LlamaRotaryEmbedding(config)
scaling = self.head_dim ** -0.5
v_proj
nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaAttention._eager_attention_forward(
q: torch.Tensor,
cache_k: list[torch.Tensor],
cache_v: list[torch.Tensor],
attention_mask: torch.Tensor,
step_idx: int,
batch_size: int,
seq_len: int
) -> torch.Tensor
nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaAttention._flash_attention_forward(
q: torch.Tensor,
cache_k: list[torch.Tensor],
cache_v: list[torch.Tensor],
step_idx: int,
batch_size: int,
seq_len: int,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: int | None = None
) -> torch.Tensor

EAGLE-3 attention via FlashAttention-2 for the T x T causal block.

FA2 covers Block 1 (causal attention against K_0) and returns its log-sum-exp. The diagonal Block 2 (cached steps i >= 1) is computed eagerly and merged via the log-space identity lse_full = logaddexp(lse_fa, logsumexp(diag)): the FA output is scaled by exp(lse_fa - lse_full) and each diagonal by exp(diag - lse_full).

With cu_seqlens (packing), Block 1 uses flash_attn_varlen_func for document-level causal attention; the position-wise Block 2 is unchanged.
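The log-space merge can be checked numerically. The pure-Python sketch below handles one query row and one head, and assumes out_fa is Block 1's softmax-normalised output and lse_fa its log-sum-exp (the two quantities FA2 returns); merge_block1_with_diag is a hypothetical name. Folding lse_fa and the diagonal scores into a single logsumexp is the same as logaddexp(lse_fa, logsumexp(diag)).

```python
import math


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def merge_block1_with_diag(out_fa, lse_fa, diag_scores, diag_values):
    """Merge a normalised Block-1 output with extra diagonal columns.

    diag_scores[i] / diag_values[i] are the scaled score and value vector
    for cached step i >= 1 at this query position.
    """
    lse_full = logsumexp([lse_fa] + list(diag_scores))
    # Rescale the FA output from its own softmax to the full softmax.
    w_fa = math.exp(lse_fa - lse_full)
    merged = [w_fa * o for o in out_fa]
    # Each diagonal column contributes exp(score - lse_full) * value.
    for s, v in zip(diag_scores, diag_values):
        w = math.exp(s - lse_full)
        merged = [acc + w * x for acc, x in zip(merged, v)]
    return merged
```

The merge never materialises the full score row for Block 1, which is why only FA2's output and its LSE are needed.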

nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaAttention._flash_block1_varlen(
q_fa: torch.Tensor,
k0_fa: torch.Tensor,
v0_fa: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: int,
batch_size: int,
seq_len: int
) -> tuple[torch.Tensor, torch.Tensor]

Document-level causal Block 1 via flash_attn_varlen_func.

Flattens (B, T, H, D) to varlen (total_tokens, H, D) and reshapes outputs back to [B, H, T, D] / [B, H, T] for the dense-path merge. Note that varlen softmax_lse is [H, total_tokens] (head-major), unlike the dense [B, H, T]; hence the explicit reshape + shape check.
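The head-major to batch-major reshape amounts to an index permutation. A pure-Python sketch, assuming tokens were flattened row-major (token index b * T + t); in torch terms this corresponds to lse.view(H, B, T).transpose(0, 1). varlen_lse_to_dense is a hypothetical helper.

```python
def varlen_lse_to_dense(lse_varlen, batch_size, seq_len):
    """Reshape head-major varlen LSE [H][B*T] into dense [B][H][T]."""
    total = batch_size * seq_len
    # Shape check: every head row must cover all B*T flattened tokens.
    if any(len(row) != total for row in lse_varlen):
        raise ValueError(f"expected {total} tokens per head")
    return [[[lse_varlen[h][b * seq_len + t] for t in range(seq_len)]
             for h in range(len(lse_varlen))]
            for b in range(batch_size)]
```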

nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaAttention._project_qkv(
combined_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]
nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaAttention._repeat_kv(
k: torch.Tensor,
v: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]
nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaAttention.forward(
combined_states: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: torch.Tensor,
cache_hidden: list[list[torch.Tensor]],
cu_seqlens: torch.Tensor | None = None,
max_seqlen: int | None = None
) -> torch.Tensor
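The attribute formulas listed for Eagle3LlamaAttention imply the projection sizes sketched below. attention_dims is a hypothetical helper that works on any object carrying the named config fields. The fallback head_dim = hidden_size // num_attention_heads is an assumption of the sketch, since only the getattr(config, "head_dim", ...) lookup is documented. The 2H input width for fuse_input=True follows from the [input_emb, hidden] concatenation; fuse_input=False (used by the P-EAGLE vanilla layers) reads plain H.

```python
from types import SimpleNamespace


def attention_dims(config, fuse_input: bool = True):
    """Hypothetical: derived sizes for the draft attention projections.

    Projection shapes are (in_features, out_features).
    """
    num_heads = config.num_attention_heads
    num_kv_heads = config.num_key_value_heads
    # Assumed fallback when the config does not decouple head_dim.
    head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_heads
    # fuse_input=True: Q/K/V read the concatenated [input_emb, hidden] (2H).
    in_features = 2 * config.hidden_size if fuse_input else config.hidden_size
    return {
        "num_key_value_groups": num_heads // num_kv_heads,
        "head_dim": head_dim,
        "scaling": head_dim ** -0.5,
        "q_proj": (in_features, num_heads * head_dim),
        "k_proj": (in_features, num_kv_heads * head_dim),
        "v_proj": (in_features, num_kv_heads * head_dim),
        "o_proj": (num_heads * head_dim, config.hidden_size),
    }


example = SimpleNamespace(hidden_size=1024, num_attention_heads=16,
                          num_key_value_heads=8, head_dim=128)
print(attention_dims(example)["q_proj"])  # (2048, 2048)
```

In the example the decoupled head_dim (128) differs from hidden_size / num_attention_heads (64), which is the case the getattr lookup exists for.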
class nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaDecoderLayer(
config: transformers.PretrainedConfig,
layer_id: int = 0
)

Bases: _PeagleDecoderLayerMixin, Module

Single decoder layer used by the minimal EAGLE-3 draft model.

Attribute names mirror SGLang’s LlamaDecoderLayer in sglang/srt/models/llama_eagle3.py: input_layernorm is applied to the per-step token embeddings (embeds in SGLang), hidden_norm is applied to the carried hidden state. is_input_layer is the layer-0 flag that gates the [embeds, hidden] concatenation (always true for our single-layer draft).

hidden_norm
input_layernorm
is_input_layer = layer_id == 0
mlp = Eagle3LlamaMLP(config)
post_attention_layernorm
self_attn = Eagle3LlamaAttention(config)
nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaDecoderLayer.forward(
input_embeds: torch.Tensor,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: torch.Tensor,
cache_hidden: list[list[torch.Tensor]],
cu_seqlens: torch.Tensor | None = None,
max_seqlen: int | None = None
) -> torch.Tensor
class nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaMLP(
config: transformers.PretrainedConfig
)

Bases: Module

Standard Llama-style SwiGLU MLP on hidden-size activations.

act_fn = ACT2FN[config.hidden_act]
down_proj
gate_proj
up_proj
nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaMLP.forward(
hidden_states: torch.Tensor
) -> torch.Tensor
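A pure-Python sketch of the SwiGLU computation, assuming hidden_act = "silu" (the usual Llama setting) and mlp_bias = False. Weights use nn.Linear's [out_features][in_features] layout; swiglu_mlp and matvec are hypothetical helpers.

```python
import math


def silu(x: float) -> float:
    return x / (1.0 + math.exp(-x))


def matvec(w, x):
    # w is out_features x in_features (nn.Linear weight layout).
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def swiglu_mlp(x, gate_w, up_w, down_w):
    """down_proj(act_fn(gate_proj(x)) * up_proj(x)), bias-free variant."""
    gate = [silu(g) for g in matvec(gate_w, x)]
    up = matvec(up_w, x)
    # Elementwise gating in the intermediate space, then project back to H.
    return matvec(down_w, [g * u for g, u in zip(gate, up)])
```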
class nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaModel(
config: transformers.PretrainedConfig
)

Bases: Module

Inner backbone matching SGLang’s LlamaModel in llama_eagle3.py.

Owns embed_tokens, the fc projection from concatenated target aux hidden states to draft hidden size, the (single-element) draft layers ModuleList, and the final norm. The LlamaEagle3DraftModel wrapper around this module adds the top-level lm_head and the training-facing public API.

embed_tokens
fc
fc_norm
layers = nn.ModuleList(layers)
norm
class nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaPeagleLayer(
config: transformers.PretrainedConfig,
layer_id: int
)

Bases: _PeagleVanillaLayerMixin, Module

Vanilla Llama decoder layer for P-EAGLE depths >= 1.

The EAGLE-3 first layer (Eagle3LlamaDecoderLayer) fuses the token embedding and the projected target hidden state (2H attention input). P-EAGLE stacks num_hidden_layers layers; every layer after the first is a standard Llama block operating on plain hidden states (H), matching speculators’ decoder_layer_class (a vanilla LlamaDecoderLayer). Only the P-EAGLE flex-attention path is implemented (these deeper layers do not participate in the EAGLE-3 cache_hidden TTT recurrence).

input_layernorm
mlp = Eagle3LlamaMLP(config)
post_attention_layernorm
self_attn = Eagle3LlamaAttention(config, fuse_input=False)
class nemo_automodel.components.speculative.eagle.draft_llama.LlamaEagle3DraftModel(
config: transformers.PretrainedConfig
)

Bases: _PeagleDraftMixin, PreTrainedModel

Llama-style dense EAGLE-3 draft model (Llama, Phi-3, Qwen3).

State dict keys match SGLang’s LlamaForCausalLMEagle3 so the saved checkpoint can be loaded by SGLang’s inference engine without any remapping (SGLang’s load_weights fuses q/k/v_proj into qkv_proj and gate/up_proj into gate_up_proj via its standard stacked_params_mapping).

The class name is retained for checkpoint-architectures compatibility; the implementation is config-driven and works for any HF dense decoder-only config that exposes hidden_size, num_attention_heads, num_key_value_heads, attention_bias, mlp_bias, rope_theta, and rms_norm_eps. A decoupled head_dim is read via getattr(config, "head_dim", ...) in the attention layer.

Scope:

  • single draft decoder layer
  • no KV-cache optimization
  • no speculative runtime integration

base_model_prefix = 'model'
draft_vocab_size
has_vocab_compression = self.draft_vocab_size < config.vocab_size
lm_head
model = Eagle3LlamaModel(config)
target_hidden_size
nemo_automodel.components.speculative.eagle.draft_llama.LlamaEagle3DraftModel.compute_logits(
hidden_states: torch.Tensor
) -> torch.Tensor

Compute draft logits on the configured draft vocabulary.

With config.norm_output unset (EAGLE-3 default) the input is the raw decoder-layer output and the final model.norm is applied here. With config.norm_output set (EAGLE-3.1) forward has already returned the post-norm state, so lm_head is applied directly to avoid a double normalisation.
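The two settings differ in which state is carried forward, not in the current step's logits. A pure-Python sketch, assuming an all-ones norm weight and a toy lm_head; draft_step is a hypothetical helper standing in for forward followed by compute_logits.

```python
import math


def rms_norm(h, weight, eps=1e-6):
    rms = math.sqrt(sum(x * x for x in h) / len(h) + eps)
    return [w * x / rms for w, x in zip(weight, h)]


def draft_step(h_raw, norm_w, lm_head_w, norm_output):
    """Hypothetical: return (state carried to the next TTT step, logits)."""
    # forward: with norm_output set, the returned state is already normalised.
    carried = rms_norm(h_raw, norm_w) if norm_output else h_raw
    # compute_logits: apply model.norm only if forward did not already.
    pre_head = carried if norm_output else rms_norm(carried, norm_w)
    logits = [sum(w * x for w, x in zip(row, pre_head)) for row in lm_head_w]
    return carried, logits


h = [3.0, -4.0]
_, logits_eagle3 = draft_step(h, [1.0, 1.0], [[1.0, 0.0], [0.5, 0.5]], False)
_, logits_eagle31 = draft_step(h, [1.0, 1.0], [[1.0, 0.0], [0.5, 0.5]], True)
print(logits_eagle3 == logits_eagle31)  # True
```

In both cases model.norm is applied exactly once before lm_head; only the carried state, which becomes the input to the next TTT step, changes.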

nemo_automodel.components.speculative.eagle.draft_llama.LlamaEagle3DraftModel.copy_embeddings_from_target(
target_embedding: torch.nn.Embedding
) -> None

Initialize draft embeddings from the target model embeddings.

When the target model is wrapped with FSDP2, target_embedding.weight is a DTensor sharded across ranks. The draft embedding is a plain nn.Parameter (the draft is not FSDP-wrapped), so a direct copy_ of a DTensor into a regular tensor raises a mixed-type distributed-operator error. Gather to a full local tensor first.

nemo_automodel.components.speculative.eagle.draft_llama.LlamaEagle3DraftModel.embed_input_ids(
input_ids: torch.Tensor
) -> torch.Tensor

Embed input ids with the draft embedding table.

nemo_automodel.components.speculative.eagle.draft_llama.LlamaEagle3DraftModel.forward(
input_ids: torch.Tensor,
projected_hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: typing.Optional[torch.Tensor] = None,
cache_hidden: typing.Optional[list[list[torch.Tensor]]] = None,
seq_lens: typing.Optional[torch.Tensor] = None
) -> torch.Tensor

Run one full-sequence draft update step.

cache_hidden is the EAGLE-3 TTT cache. Pass [[], []] on the first step of a TTT unroll and the same list object on each subsequent step; the attention layer appends the per-step K and V to it. If None is passed (e.g. from a one-shot evaluation call) a fresh [[], []] is allocated locally — step 0 of TTT is mathematically equivalent to a plain causal forward.

seq_lens (packing) makes Block-1 attention document-level block-causal (eager mask / FA2 varlen); callers must pass per-document position_ids.
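The cache_hidden contract can be illustrated with a stub that only records what the attention layer does with the two lists. StubAttention and ttt_unroll are hypothetical; a real unroll would call the draft's forward with the same cache_hidden object on every step and start from a fresh [[], []] for each batch.

```python
class StubAttention:
    """Hypothetical stand-in that only mimics the cache_hidden bookkeeping."""

    def __call__(self, k, v, cache_hidden):
        k_list, v_list = cache_hidden
        step_idx = len(k_list)  # rotary phase shift for this step
        k_list.append(k)        # mutated in place, visible to later steps
        v_list.append(v)
        return step_idx


def ttt_unroll(attn, per_step_kv, cache_hidden=None):
    if cache_hidden is None:
        cache_hidden = [[], []]  # fresh cache, as forward() does for None
    steps = [attn(k, v, cache_hidden) for k, v in per_step_kv]
    return steps, cache_hidden


steps, cache = ttt_unroll(StubAttention(), [("k0", "v0"), ("k1", "v1")])
print(steps)  # [0, 1]
```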

nemo_automodel.components.speculative.eagle.draft_llama.LlamaEagle3DraftModel.freeze_embeddings() -> None

Freeze draft input embeddings.

nemo_automodel.components.speculative.eagle.draft_llama.LlamaEagle3DraftModel.gradient_checkpointing_disable() -> None

Disable activation checkpointing for the P-EAGLE draft layers.

nemo_automodel.components.speculative.eagle.draft_llama.LlamaEagle3DraftModel.gradient_checkpointing_enable(
gradient_checkpointing_kwargs = None
) -> None

Enable activation checkpointing for the P-EAGLE draft layers.

Training-only memory knob: recomputes each forward_peagle layer in the backward instead of storing its activations (the EAGLE-3 TTT forward path is unaffected). gradient_checkpointing_kwargs is accepted for HF-API parity but ignored — recompute is always non-reentrant, the only mode compatible with the non-tensor block_mask.

nemo_automodel.components.speculative.eagle.draft_llama.LlamaEagle3DraftModel.project_hidden_states(
aux_hidden_states: torch.Tensor
) -> torch.Tensor

Project concatenated target aux states from num_aux * H_target to draft hidden size.

When config.fc_norm is set (EAGLE-3.1), the input is split into num_aux_hidden_states equal chunks along the last dim and each chunk is passed through its own RMSNorm in model.fc_norm (the modules are independent, matching vLLM’s upstream implementation). The normalized chunks are then re-concatenated and fed to fc, stabilising the per-aux-state scale before the projection mixes them and removing the speculation-depth drift observed with raw inputs.
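A pure-Python sketch of the per-chunk normalisation, assuming equal-width chunks and one independent weight vector per chunk (standing in for model.fc_norm.0, model.fc_norm.1, ...). The subsequent fc projection is omitted, and fc_norm_chunks is a hypothetical name.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for w, v in zip(weight, x)]


def fc_norm_chunks(aux_concat, fc_norm_weights):
    """Split [num_aux * H_target] into equal chunks and RMSNorm each one.

    fc_norm_weights holds one independent weight vector per aux chunk.
    The re-concatenated result would then be fed to model.fc.
    """
    num_aux = len(fc_norm_weights)
    if len(aux_concat) % num_aux:
        raise ValueError("input width must be divisible by num_aux_hidden_states")
    h = len(aux_concat) // num_aux
    out = []
    for i, w in enumerate(fc_norm_weights):
        out.extend(rms_norm(aux_concat[i * h:(i + 1) * h], w))
    return out
```

Rescaling one aux chunk by a constant leaves its normalised slice essentially unchanged (only eps breaks exact invariance), which is the scale-stabilising effect described above.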

nemo_automodel.components.speculative.eagle.draft_llama.LlamaEagle3DraftModel.set_vocab_mapping(
selected_token_ids: torch.Tensor
) -> None

Populate the d2t / t2d vocab-remap buffers from the draft->target id map.

selected_token_ids has shape [draft_vocab_size]; entry i is the target vocab id of draft id i (the frequency-pruned mapping built by build_eagle3_token_mapping). This writes the two buffers inference engines consume:

  • d2t[i] = selected_token_ids[i] - i — the offset form vLLM expects (target_id = draft_id + d2t[draft_id]);
  • t2d[target_id] = True for every selected target id — the boolean presence mask SGLang consumes.

These must be in the saved checkpoint: without them vLLM/SGLang find no mapping, silently align draft ids to the first draft_vocab_size target ids, and acceptance rate collapses.

No-op when the draft vocab is not compressed (the buffers do not exist and the draft logits are already in target space).
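A minimal sketch of the two buffers as plain Python lists, assuming selected_token_ids is a sequence of distinct target ids. build_vocab_buffers is hypothetical; the real method writes into registered tensor buffers.

```python
def build_vocab_buffers(selected_token_ids, target_vocab_size):
    """Sketch of the d2t offsets and the t2d presence mask."""
    # vLLM-style offsets: target_id = draft_id + d2t[draft_id].
    d2t = [tid - i for i, tid in enumerate(selected_token_ids)]
    # SGLang-style boolean mask over the target vocabulary.
    t2d = [False] * target_vocab_size
    for tid in selected_token_ids:
        t2d[tid] = True
    return d2t, t2d


d2t, t2d = build_vocab_buffers([0, 5, 7, 42], target_vocab_size=50)
print(d2t)  # [0, 4, 5, 39]
```

For selected_token_ids = [0, 5, 7, 42], draft id 3 maps back to target id 3 + 39 = 42.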

nemo_automodel.components.speculative.eagle.draft_llama._build_causal_mask(
attention_mask: torch.Tensor,
dtype: torch.dtype
) -> torch.Tensor

Build a standard causal + padding mask for SDPA/eager attention.
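A pure-Python sketch of the mask semantics, assuming an additive mask in which visible positions are 0 and hidden ones take a large negative fill (float("-inf") here; implementations commonly use torch.finfo(dtype).min instead). The [B][T][T] layout omits the head axis that a real SDPA mask usually carries as a broadcast dimension; build_causal_mask is hypothetical.

```python
def build_causal_mask(attention_mask, neg=float("-inf")):
    """Additive [B][T][T] mask: 0 where query i may attend key j, else `neg`.

    Key j is visible to query i when j <= i (causal) and
    attention_mask[b][j] is nonzero (not padding).
    """
    return [[[0.0 if (j <= i and row[j]) else neg for j in range(len(row))]
             for i in range(len(row))]
            for row in attention_mask]


print(build_causal_mask([[1, 1, 0]])[0][2])  # [0.0, 0.0, -inf]
```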

nemo_automodel.components.speculative.eagle.draft_llama._is_right_padded_attention_mask(
attention_mask: torch.Tensor
) -> bool

Return True when each row is a contiguous valid-prefix followed by padding.
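A pure-Python sketch of one way to perform this check: count the valid tokens n in each row, then require the first n entries to be valid and the remainder to be padding. is_right_padded is a hypothetical name.

```python
def is_right_padded(attention_mask):
    """True if every row looks like [1, ..., 1, 0, ..., 0]."""
    for row in attention_mask:
        n_valid = sum(1 for v in row if v)
        # Valid tokens must form the prefix; everything after is padding.
        if not (all(row[:n_valid]) and not any(row[n_valid:])):
            return False
    return True
```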

nemo_automodel.components.speculative.eagle.draft_llama._load_flash_attn_func() -> tuple[bool, object | None, object | None]

Best-effort load of flash-attn without breaking eager-only users.

safe_import_from already handles missing modules and missing symbols, but some broken flash-attn installs fail with lower-level loader errors (e.g. ABI / shared-library issues) that should not prevent importing this module for the eager path. Returns the dense flash_attn_func and the flash_attn_varlen_func (used by the packed block-causal path).
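A sketch of the same best-effort pattern using importlib directly; the source uses safe_import_from, whose signature is not shown here, so load_flash_attn_funcs is a hypothetical stand-in. Catching Exception rather than only ImportError is what lets ABI or shared-library failures (which can surface as OSError or similar) fall back to the eager path.

```python
import importlib


def load_flash_attn_funcs():
    """Return (available, flash_attn_func, flash_attn_varlen_func)."""
    try:
        mod = importlib.import_module("flash_attn")
        return True, mod.flash_attn_func, mod.flash_attn_varlen_func
    except Exception:
        # Missing package, missing symbol, or a broken native build:
        # report unavailability instead of failing at import time.
        return False, None, None


available, dense_fn, varlen_fn = load_flash_attn_funcs()
```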

nemo_automodel.components.speculative.eagle.draft_llama._seq_lens_to_cu_seqlens(
seq_lens: torch.Tensor,
seq_length: int
) -> tuple[torch.Tensor, int]

Build FlashAttention varlen cu_seqlens (int32) from packed seq_lens.

Documents are flattened row-major to match the varlen attention’s reshape(B*T, ...) token order. Returns (cu_seqlens, max_seqlen).
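A pure-Python sketch of the offset computation, under a simplifying assumption the source does not state: zero entries in a row of seq_lens are unused slots, and each row's non-zero lengths sum to seq_length. The real function returns an int32 tensor; a list is used here, and seq_lens_to_cu_seqlens is a hypothetical name.

```python
def seq_lens_to_cu_seqlens(seq_lens, seq_length):
    """Per-row document lengths -> flat cumulative offsets and max length."""
    cu_seqlens = [0]
    max_seqlen = 0
    # Row-major: row b covers flattened tokens [b * T, (b + 1) * T).
    for row in seq_lens:
        docs = [n for n in row if n > 0]
        if sum(docs) != seq_length:
            raise ValueError("sketch assumes each row's lengths sum to seq_length")
        for n in docs:
            cu_seqlens.append(cu_seqlens[-1] + n)
            max_seqlen = max(max_seqlen, n)
    return cu_seqlens, max_seqlen


print(seq_lens_to_cu_seqlens([[3, 2, 0], [5, 0, 0]], seq_length=5))  # ([0, 3, 5, 10], 5)
```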

nemo_automodel.components.speculative.eagle.draft_llama._SUPPORTED_ATTN_IMPLEMENTATIONS = ('eager', 'flash_attention_2')
nemo_automodel.components.speculative.eagle.draft_llama.logger = logging.getLogger(__name__)