nemo_automodel.components.speculative.eagle.draft_llama
Llama-style dense LLM draft model for EAGLE-3 / EAGLE-3.1 training.
The implementation is config-driven and supports any HuggingFace dense
decoder-only architecture whose layout matches Llama: GQA attention with
optional Q/K/V/O bias (config.attention_bias), SwiGLU MLP with optional
bias (config.mlp_bias), RMSNorm, and rotary position embeddings parameterized
by config.rope_theta / config.rope_scaling. This currently covers Llama,
Phi-3, and Qwen3 dense models. Phi-3 configs omit the attention_bias /
mlp_bias fields, which the attention and MLP layers already read via
getattr(config, "<field>", False). Qwen3 decouples head_dim from
hidden_size / num_attention_heads, which the attention layer reads via
getattr(config, "head_dim", ...).
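The config-driven dimension logic described above can be sketched as follows. This is a minimal illustration, not the module's code. The fallback head_dim = hidden_size // num_attention_heads is an assumption, because the source elides the default. The helper attention_dims and the example sizes are hypothetical.

```python
from types import SimpleNamespace


def attention_dims(config):
    """Return (head_dim, q_out, kv_out, has_attn_bias) for a dense Llama-style config.

    Assumption: when config.head_dim is absent, fall back to
    hidden_size // num_attention_heads (the classic Llama layout).
    """
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    q_out = config.num_attention_heads * head_dim
    kv_out = config.num_key_value_heads * head_dim  # GQA: fewer K/V heads than Q heads
    # Configs that omit the field (e.g. Phi-3 style) are treated as "no bias".
    has_attn_bias = getattr(config, "attention_bias", False)
    return head_dim, q_out, kv_out, has_attn_bias


# Illustrative sizes only (hypothetical configs, not real checkpoints).
llama_like = SimpleNamespace(hidden_size=4096, num_attention_heads=32,
                             num_key_value_heads=8, attention_bias=False)
qwen3_like = SimpleNamespace(hidden_size=2048, num_attention_heads=32,
                             num_key_value_heads=8, head_dim=128)  # decoupled head_dim
print(attention_dims(llama_like))
print(attention_dims(qwen3_like))
```

In the second config, q_out (4096) differs from hidden_size (2048). That is why head_dim cannot be derived from hidden_size alone when the config decouples it.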
Class names and the public architectures string remain LlamaEagle3* for
backward compatibility with already-trained checkpoints and with SGLang’s
LlamaForCausalLMEagle3.load_weights (the saved state dict layout is
unchanged):
    model.embed_tokens.weight
    model.fc.weight
    model.layers.0.input_layernorm.weight
    model.layers.0.hidden_norm.weight
    model.layers.0.post_attention_layernorm.weight
    model.layers.0.self_attn.{q,k,v,o}_proj.weight
    model.layers.0.mlp.{gate,up,down}_proj.weight
    model.norm.weight
    lm_head.weight
SGLang merges q_proj/k_proj/v_proj into a single qkv_proj and
gate_proj/up_proj into gate_up_proj via its stacked_params_mapping
at load time, so the un-fused storage above is the canonical on-disk format.
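The load-time fusion can be pictured as concatenating the un-fused shards along the output dimension. The sketch below illustrates only that idea and is not SGLang's loader. fuse_stacked is a hypothetical helper, and the toy shapes assume q/k/v_proj read the 2H [input_emb, hidden] features while the MLP reads H.

```python
import numpy as np


def fuse_stacked(state_dict, prefix="model.layers.0"):
    """Illustrative fusion of un-fused projection weights into stacked weights.

    Mimics the idea behind a stacked_params_mapping: shards are concatenated
    along the output (row) dimension of an nn.Linear-style [out, in] weight.
    """
    attn, mlp = f"{prefix}.self_attn", f"{prefix}.mlp"
    qkv = np.concatenate([state_dict[f"{attn}.{p}_proj.weight"] for p in ("q", "k", "v")], axis=0)
    gate_up = np.concatenate([state_dict[f"{mlp}.{p}_proj.weight"] for p in ("gate", "up")], axis=0)
    return {f"{attn}.qkv_proj.weight": qkv, f"{mlp}.gate_up_proj.weight": gate_up}


# Toy GQA sizes: 4 query heads, 2 K/V heads, head_dim 2, hidden 8, intermediate 12.
rng = np.random.default_rng(7)
H, heads, kv_heads, hd, inter = 8, 4, 2, 2, 12
sd = {
    "model.layers.0.self_attn.q_proj.weight": rng.standard_normal((heads * hd, 2 * H)),
    "model.layers.0.self_attn.k_proj.weight": rng.standard_normal((kv_heads * hd, 2 * H)),
    "model.layers.0.self_attn.v_proj.weight": rng.standard_normal((kv_heads * hd, 2 * H)),
    "model.layers.0.mlp.gate_proj.weight": rng.standard_normal((inter, H)),
    "model.layers.0.mlp.up_proj.weight": rng.standard_normal((inter, H)),
}
fused = fuse_stacked(sd)
```

Keeping the un-fused layout on disk means a runtime that fuses differently (or not at all) can still consume the same checkpoint.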
EAGLE-3.1 introduces two optional drafter-side toggles that together address the “attention drift” failure mode observed when speculation depth grows:
- config.fc_norm (bool, default False): when True, an nn.ModuleList of num_aux_hidden_states independent RMSNorms (each of size target_hidden_size) is applied per chunk before the concatenated auxiliary hidden states enter model.fc. The on-disk keys are model.fc_norm.0.weight, model.fc_norm.1.weight, …; the module layout matches vLLM's EAGLE-3.1 integration in PR https://github.com/vllm-project/vllm/pull/42764, so checkpoints trained here load directly into vLLM / SGLang.
- config.norm_output (bool, default False): when True, the existing final RMSNorm (model.norm) is applied to the per-step hidden state returned by forward, so that the next TTT step (and the lm_head) consume the post-norm state instead of the raw decoder output. Adds no new parameters.
Both flags default to False so EAGLE-3 checkpoints continue to load and
behave identically. Enabling them applies the EAGLE-3.1 drafter toggles to
the Llama-style draft used here; the MLA-backbone Kimi K2.6 draft
(Eagle3DeepseekV2ForCausalLM in lightseekorg/kimi-k2.6-eagle3.1-mla)
is a separate architecture and is not covered by this module.
P-EAGLE (parallel-drafting EAGLE-3) adds one further optional toggle:
- config.parallel_drafting (bool, default False): when True, the draft registers a single learnable mask_hidden placeholder of shape [1, 1, num_aux_hidden_states * target_hidden_size] (the pre-fc concatenated-aux dimension) and exposes LlamaEagle3DraftModel.forward_peagle, a single parallel forward over a flat, COD-subsampled sequence with a flex_attention cross-depth mask (see peagle_attention.py / peagle_data.py). At every masked depth (>= 1), the trainer feeds the mask_hidden placeholder, projected through the same project_hidden_states path as real aux states, together with the masked token config.mask_token_id, so the draft predicts all config.num_depths tokens in one forward instead of autoregressively. The shape, the on-disk key mask_hidden, and the COD config (num_depths / down_sample_ratio / mask_token_id) mirror speculators (https://github.com/vllm-project/speculators/pull/480), so the checkpoint loads into vLLM's parallel-drafting runtime unchanged. The masked token slot reuses embed_tokens[config.mask_token_id]. SGLang does not serve a P-EAGLE head today (https://github.com/sgl-project/sglang/issues/23171). The flag only ever adds the mask_hidden key, so EAGLE-3 / EAGLE-3.1 checkpoints round-trip unchanged.
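The per-depth input substitution described above can be sketched as follows. This is a simplified, hypothetical helper (build_depth_inputs). It omits COD subsampling, any token shifting, and the project_hidden_states call that the real trainer applies. It only shows that depth 0 uses the real aux states and tokens, while every depth >= 1 uses mask_hidden and mask_token_id.

```python
import numpy as np


def build_depth_inputs(aux_hidden, input_ids, mask_hidden, mask_token_id, num_depths):
    """Hypothetical per-depth inputs for parallel drafting (no COD subsampling).

    aux_hidden:  [T, num_aux * H_target] concatenated target aux states.
    mask_hidden: [1, 1, num_aux * H_target] learnable placeholder.
    Returns two lists indexed by depth: hidden inputs [T, D] and token ids [T].
    """
    T, D = aux_hidden.shape
    hiddens, tokens = [], []
    for depth in range(num_depths):
        if depth == 0:
            hiddens.append(aux_hidden)  # real target aux states
            tokens.append(input_ids)
        else:
            # Same placeholder at every position of every masked depth.
            hiddens.append(np.broadcast_to(mask_hidden[0], (T, D)))
            tokens.append(np.full(T, mask_token_id, dtype=input_ids.dtype))
    return hiddens, tokens


hs, toks = build_depth_inputs(np.ones((4, 6)), np.arange(4),
                              np.full((1, 1, 6), 0.5), mask_token_id=99, num_depths=3)
```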
Module Contents
Classes
Functions
Data
_SUPPORTED_ATTN_IMPLEMENTATIONS
API
Bases: _PeagleAttentionMixin, Module
EAGLE-3 draft attention over [input_emb, hidden] 2H features.
Driven through a shared cache_hidden = [K_list, V_list] pair. At
step k (0-indexed), with K_list and V_list already holding
entries from steps 0..k-1:
- step_idx = len(K_list) (equal to k) gives the rotary phase shift, so the draft's K_k encodes "this is k tokens into the future". The shifted cos/sin are computed from position_ids + step_idx.
- The freshly projected K, V (after GQA expansion) are appended to the cache lists in place.
- The attention output is the EAGLE-3 mixed pattern:

      attn_weights = [ Q @ K_0^T / sqrt(d) + mask ] || diag_1 || ... || diag_k

  where diag_i[t] = (Q_t * K_i_t).sum(-1) / sqrt(d). The softmax is taken over the full extended column axis of length T + k. The output is

      out = attn_probs[..., :T] @ V_0 + sum_{i=1..k} attn_probs[..., T+i-1, None] * V_i

  In English: Q at position t attends to the causally visible K_0 positions (the regular T x T causal block) and, additionally, to the same position t in each draft step i = 1..k (up to and including the current one). Implementation-wise, SpecForge's two O(k^2) cat/add Python loops in llama3_eagle.py are replaced with single vectorized einsum calls.
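The mixed pattern can be written directly as a NumPy sketch. It assumes the cache already holds K_0..K_k and V_0..V_k, meaning the current step's entries have been appended. It also assumes causal_mask is an additive [T, T] mask. eagle3_eager_attention is a hypothetical name, and rotary embeddings and GQA expansion are omitted.

```python
import numpy as np


def eagle3_eager_attention(q, k_list, v_list, causal_mask):
    """Eager EAGLE-3 mixed attention for one TTT step.

    q:              [B, H, T, D] queries for the current step k.
    k_list, v_list: length k+1 lists of [B, H, T, D] (K_0..K_k, V_0..V_k).
    causal_mask:    [T, T] additive mask (0 where allowed, -inf elsewhere).
    """
    T, d = q.shape[-2], q.shape[-1]
    scale = 1.0 / np.sqrt(d)
    # Block 1: regular T x T causal attention against K_0.
    weights = np.einsum("bhtd,bhsd->bhts", q, k_list[0]) * scale + causal_mask
    if len(k_list) > 1:
        # Block 2: one same-position (diagonal) score per draft step i >= 1.
        k_rest = np.stack(k_list[1:], axis=-2)                # [B, H, T, k, D]
        diag = np.einsum("bhtd,bhtkd->bhtk", q, k_rest) * scale
        weights = np.concatenate([weights, diag], axis=-1)    # [B, H, T, T + k]
    # Softmax over the full extended column axis (max-subtracted for stability).
    probs = np.exp(weights - weights.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    out = probs[..., :T] @ v_list[0]
    if len(v_list) > 1:
        v_rest = np.stack(v_list[1:], axis=-2)                # [B, H, T, k, D]
        out = out + np.einsum("bhtk,bhtkd->bhtd", probs[..., T:], v_rest)
    return out
```

With a single cache entry (k = 0), the function reduces to plain causal attention. This mirrors the statement that step 0 of TTT equals a normal causal forward.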
cache_hidden is mutated in place; callers are responsible for
re-initializing it to [[], []] at the start of each training
batch.
EAGLE-3 attention via FlashAttention-2 for the T x T causal block.
FA2 covers Block 1 (causal attention against K_0) and returns its
log-sum-exp. The diagonal Block 2 (cached steps i >= 1) is computed
eagerly and merged via the log-space identity
lse_full = logaddexp(lse_fa, logsumexp(diag)): the FA output is scaled
by exp(lse_fa - lse_full) and each diagonal by exp(diag - lse_full).
With cu_seqlens (packing), Block 1 uses flash_attn_varlen_func for
document-level causal attention; the position-wise Block 2 is unchanged.
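The log-space merge can be checked with a small NumPy sketch. It assumes out_fa is the already softmax-normalized Block-1 output and lse_fa its natural-log log-sum-exp, as a fused kernel would return them. merge_block1_with_diag is a hypothetical name.

```python
import numpy as np


def merge_block1_with_diag(out_fa, lse_fa, diag, v_rest):
    """Merge a normalized Block-1 output with eager diagonal terms via LSE.

    out_fa: [B, H, T, D] softmax-normalized Block-1 output.
    lse_fa: [B, H, T] log-sum-exp of the Block-1 scores.
    diag:   [B, H, T, k] scaled diagonal scores for cached steps i >= 1.
    v_rest: [B, H, T, k, D] matching V_i at the same position.
    """
    m = diag.max(axis=-1, keepdims=True)
    lse_diag = np.log(np.exp(diag - m).sum(axis=-1)) + m[..., 0]
    lse_full = np.logaddexp(lse_fa, lse_diag)                 # LSE over all T + k columns
    out = out_fa * np.exp(lse_fa - lse_full)[..., None]       # rescale Block 1
    diag_probs = np.exp(diag - lse_full[..., None])           # final Block-2 probabilities
    return out + np.einsum("bhtk,bhtkd->bhtd", diag_probs, v_rest)
```

The rescaling works because out_fa is normalized by exp(lse_fa). Multiplying by exp(lse_fa - lse_full) renormalizes it against the full extended row, so the Block-1 kernel never needs to see the diagonal columns.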
Document-level causal Block 1 via flash_attn_varlen_func.
Flattens (B, T, H, D) to varlen (total_tokens, H, D) and reshapes
outputs back to [B, H, T, D] / [B, H, T] for the dense-path merge.
Note that the varlen softmax_lse is [H, total_tokens] (head-major), unlike
the dense [B, H, T]; hence the explicit reshape and shape check.
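The layout conversion can be sketched in NumPy. This assumes total_tokens == B * T, with tokens flattened row-major from [B, T]. varlen_to_dense is a hypothetical name.

```python
import numpy as np


def varlen_to_dense(out_varlen, lse_varlen, B, T):
    """Reshape varlen outputs back to the dense layouts used for the merge.

    out_varlen: [B*T, H, D] token-major output (row-major flatten of [B, T]).
    lse_varlen: [H, B*T] head-major log-sum-exp.
    Returns out [B, H, T, D] and lse [B, H, T].
    """
    total, H, D = out_varlen.shape
    if total != B * T or lse_varlen.shape != (H, total):
        raise ValueError(f"unexpected varlen shapes {out_varlen.shape}, {lse_varlen.shape}")
    out = out_varlen.reshape(B, T, H, D).transpose(0, 2, 1, 3)
    lse = lse_varlen.reshape(H, B, T).transpose(1, 0, 2)  # head-major -> batch-major
    return out, lse
```

Reshaping the LSE straight to [B, H, T] without the transpose would still produce the right shape while silently mixing heads and tokens. The shape check cannot catch that, so the transpose order is what matters.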
Bases: _PeagleDecoderLayerMixin, Module
Single decoder layer used by the minimal EAGLE-3 draft model.
Attribute names mirror SGLang’s LlamaDecoderLayer in
sglang/srt/models/llama_eagle3.py: input_layernorm is applied
to the per-step token embeddings (embeds in SGLang),
hidden_norm is applied to the carried hidden state.
is_input_layer is the layer-0 flag that gates the [embeds, hidden] concatenation (always true for our single-layer draft).
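The layer-0 input can be sketched as two independent RMSNorms followed by a concatenation in [embeds, hidden] order, which yields the 2H attention input. rms_norm and layer0_attention_input are hypothetical helpers, and the eps default is an assumption (the real layer uses config.rms_norm_eps).

```python
import numpy as np


def rms_norm(x, weight, eps=1e-6):
    # RMSNorm: divide by the root-mean-square over the last dim, then scale.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight


def layer0_attention_input(embeds, hidden, w_input_ln, w_hidden_norm, eps=1e-6):
    """Form the 2H attention input of the first draft layer (is_input_layer=True).

    input_layernorm normalizes the token embeddings; hidden_norm normalizes the
    carried hidden state; the two are concatenated as [embeds, hidden].
    """
    return np.concatenate(
        [rms_norm(embeds, w_input_ln, eps), rms_norm(hidden, w_hidden_norm, eps)], axis=-1
    )


rng = np.random.default_rng(3)
H = 8
x = layer0_attention_input(rng.standard_normal((2, 5, H)), 10 * rng.standard_normal((2, 5, H)),
                           np.ones(H), np.ones(H))
```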
Bases: Module
Standard Llama-style SwiGLU MLP on hidden-size activations.
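For reference, the SwiGLU MLP computes down_proj(silu(gate_proj(x)) * up_proj(x)). The NumPy sketch below uses nn.Linear's [out_features, in_features] weight layout. The optional biases correspond to config.mlp_bias=True, and swiglu_mlp is a hypothetical name.

```python
import numpy as np


def silu(x):
    # SiLU / swish: x * sigmoid(x).
    return x / (1.0 + np.exp(-x))


def swiglu_mlp(x, w_gate, w_up, w_down, b_gate=None, b_up=None, b_down=None):
    """down_proj(silu(gate_proj(x)) * up_proj(x)) with optional biases."""
    def linear(t, w, b):
        y = t @ w.T
        return y if b is None else y + b

    return linear(silu(linear(x, w_gate, b_gate)) * linear(x, w_up, b_up), w_down, b_down)


rng = np.random.default_rng(4)
H, I = 6, 10
wg, wu, wd = rng.standard_normal((I, H)), rng.standard_normal((I, H)), rng.standard_normal((H, I))
x = rng.standard_normal((3, H))
y = swiglu_mlp(x, wg, wu, wd)
```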
Bases: Module
Inner backbone matching SGLang’s LlamaModel in llama_eagle3.py.
Owns embed_tokens, the fc projection from concatenated target
aux hidden states to draft hidden size, the (single-element) draft
layers ModuleList, and the final norm. The LlamaEagle3DraftModel
wrapper around this module adds the top-level lm_head and the
training-facing public API.
Bases: _PeagleVanillaLayerMixin, Module
Vanilla Llama decoder layer for P-EAGLE depths >= 1.
The EAGLE-3 first layer (Eagle3LlamaDecoderLayer) fuses the token
embedding and the projected target hidden state (2H attention input).
P-EAGLE stacks num_hidden_layers layers; every layer after the first is
a standard Llama block operating on plain hidden states (H), matching
speculators’ decoder_layer_class (a vanilla LlamaDecoderLayer). Only
the P-EAGLE flex-attention path is implemented (these deeper layers do not
participate in the EAGLE-3 cache_hidden TTT recurrence).
Bases: _PeagleDraftMixin, PreTrainedModel
Llama-style dense EAGLE-3 draft model (Llama, Phi-3, Qwen3).
State dict keys match SGLang’s LlamaForCausalLMEagle3 so the saved
checkpoint can be loaded by SGLang’s inference engine without any
remapping (SGLang’s load_weights fuses q/k/v_proj into
qkv_proj and gate/up_proj into gate_up_proj via its
standard stacked_params_mapping).
The class name is retained for checkpoint-architectures compatibility; the
implementation is config-driven and works for any HF dense decoder-only
config that exposes hidden_size, num_attention_heads,
num_key_value_heads, attention_bias, mlp_bias, rope_theta,
and rms_norm_eps. A decoupled head_dim is read via
getattr(config, "head_dim", ...) in the attention layer.
Scope:
- single draft decoder layer
- no KV-cache optimization
- no speculative runtime integration
Compute draft logits on the configured draft vocabulary.
With config.norm_output unset (EAGLE-3 default) the input is the
raw decoder-layer output and the final model.norm is applied
here. With config.norm_output set (EAGLE-3.1) forward has
already returned the post-norm state, so lm_head is applied
directly to avoid a double normalisation.
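The rule that model.norm is applied exactly once can be sketched as follows. With either flag value, the logits for the same decoder output are identical; only the state handed to the next TTT step differs. All helper names here are hypothetical, and lm_head is assumed to be bias-free (only lm_head.weight appears in the state dict).

```python
import numpy as np


def rms_norm(x, weight, eps=1e-6):
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight


def draft_forward_output(decoder_out, norm_weight, norm_output):
    # EAGLE-3.1 (norm_output=True): forward already returns the post-norm state.
    return rms_norm(decoder_out, norm_weight) if norm_output else decoder_out


def compute_logits(hidden, norm_weight, lm_head_weight, norm_output):
    # Apply model.norm only if forward did not already apply it.
    h = hidden if norm_output else rms_norm(hidden, norm_weight)
    return h @ lm_head_weight.T


rng = np.random.default_rng(5)
x, w, W = rng.standard_normal((2, 4, 8)), rng.standard_normal(8), rng.standard_normal((11, 8))
logits_eagle3 = compute_logits(draft_forward_output(x, w, False), w, W, False)
logits_eagle31 = compute_logits(draft_forward_output(x, w, True), w, W, True)
```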
Initialize draft embeddings from the target model embeddings.
When the target model is wrapped with FSDP2, target_embedding.weight
is a DTensor sharded across ranks. The draft embedding is a plain
nn.Parameter (the draft is not FSDP-wrapped), so a direct
copy_ of a DTensor into a regular tensor raises a mixed-type
distributed-operator error. Gather to a full local tensor first.
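A duck-typed sketch of the gather-then-copy pattern follows. PyTorch's DTensor provides full_tensor(), which all-gathers the shards. It is a collective, so every rank must call it. To keep the sketch runnable without a process group, it uses NumPy arrays and a stand-in sharded object. to_full_local, load_embeddings, and FakeSharded are hypothetical, and np.copyto stands in for the in-place copy_.

```python
import numpy as np


def to_full_local(weight):
    """Return a full, non-distributed view of a possibly sharded weight.

    Objects exposing full_tensor() (like DTensor) are gathered; plain
    tensors / arrays are returned unchanged.
    """
    return weight.full_tensor() if hasattr(weight, "full_tensor") else weight


def load_embeddings(draft_weight, target_weight):
    full = to_full_local(target_weight)
    if full.shape != draft_weight.shape:
        raise ValueError(f"shape mismatch: draft {draft_weight.shape} vs target {full.shape}")
    np.copyto(draft_weight, full)  # stands in for an in-place copy_ into the draft parameter
    return draft_weight


class FakeSharded:
    """Hypothetical stand-in for a row-sharded DTensor."""

    def __init__(self, shards):
        self.shards = shards

    def full_tensor(self):
        return np.concatenate(self.shards, axis=0)


target = np.arange(12.0).reshape(6, 2)
draft = np.zeros((6, 2))
load_embeddings(draft, FakeSharded([target[:3], target[3:]]))
```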
Embed input ids with the draft embedding table.
Run one full-sequence draft update step.
cache_hidden is the EAGLE-3 TTT cache. Pass [[], []] on
the first step of a TTT unroll and the same list object on each
subsequent step; the attention layer appends the per-step K and V
to it. If None is passed (e.g. from a one-shot evaluation
call) a fresh [[], []] is allocated locally — step 0 of TTT
is mathematically equivalent to a plain causal forward.
seq_lens (packing) makes Block-1 attention document-level block-causal
(eager mask / FA2 varlen); callers must pass per-document position_ids.
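The cache protocol can be illustrated with a toy driver: one fresh [[], []] per unroll, the same list object passed on every step, and the step function appending in place. ttt_unroll and toy_step are hypothetical stand-ins for a training loop and for forward. Only the bookkeeping is modeled, not the attention math.

```python
def ttt_unroll(draft_step, hidden, num_steps):
    """Drive a draft step function through one TTT unroll with a shared cache.

    draft_step(hidden, cache_hidden) -> new hidden is a hypothetical stand-in
    for the draft forward; only the cache protocol is illustrated.
    """
    cache_hidden = [[], []]  # fresh [K_list, V_list] at the start of the unroll
    outputs = []
    for _ in range(num_steps):
        hidden = draft_step(hidden, cache_hidden)  # same list object every step
        outputs.append(hidden)
    return outputs, cache_hidden


def toy_step(hidden, cache_hidden):
    # len(K_list) before appending is the step index (the rotary phase shift).
    step_idx = len(cache_hidden[0])
    cache_hidden[0].append(f"K_{step_idx}")
    cache_hidden[1].append(f"V_{step_idx}")
    return hidden + 1


outputs, cache = ttt_unroll(toy_step, 0, 3)
```

Reusing a cache across batches would leave stale K/V entries and shift step_idx. That is why each unroll starts from a new [[], []].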
Freeze draft input embeddings.
Disable activation checkpointing for the P-EAGLE draft layers.
Enable activation checkpointing for the P-EAGLE draft layers.
Training-only memory knob: recomputes each forward_peagle layer during the
backward pass instead of storing its activations (the EAGLE-3 TTT forward
path is unaffected). gradient_checkpointing_kwargs is accepted for
HF-API parity but ignored — recompute is always non-reentrant, the only
mode compatible with the non-tensor block_mask.
Project concatenated target aux states from num_aux * H_target to draft hidden size.
When config.fc_norm is set (EAGLE-3.1), the input is split into
num_aux_hidden_states equal chunks along the last dim and each
chunk is passed through its own RMSNorm in model.fc_norm (the
modules are independent, matching vLLM’s upstream implementation).
The normalized chunks are then re-concatenated and fed to fc,
stabilising the per-aux-state scale before the projection mixes them
and removing the speculation-depth drift observed with raw inputs.
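The per-chunk normalization can be sketched in NumPy. The fc projection is taken to be bias-free because model.fc.weight is the only fc key in the saved layout. project_hidden_states here is a hypothetical free function mirroring the method's description, not its code.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-6):
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight


def project_hidden_states(aux_concat, fc_weight, fc_norm_weights=None, eps=1e-6):
    """aux_concat: [..., num_aux * H_target]; fc_weight: [H_draft, num_aux * H_target].

    fc_norm_weights: optional list of num_aux RMSNorm weights (each [H_target]);
    None reproduces the EAGLE-3 behavior (no per-chunk normalization).
    """
    if fc_norm_weights is not None:
        chunks = np.split(aux_concat, len(fc_norm_weights), axis=-1)  # equal chunks
        aux_concat = np.concatenate(
            [rms_norm(c, w, eps) for c, w in zip(chunks, fc_norm_weights)], axis=-1
        )
    return aux_concat @ fc_weight.T


rng = np.random.default_rng(6)
Ht, num_aux, Hd = 4, 3, 5
x = rng.standard_normal((2, num_aux * Ht))
fc = rng.standard_normal((Hd, num_aux * Ht))
norms = [np.ones(Ht) for _ in range(num_aux)]
x_scaled = x.copy()
x_scaled[:, Ht:2 * Ht] *= 100.0  # one aux layer at a very different scale
```

With fc_norm enabled, rescaling one aux chunk leaves the projection essentially unchanged (up to eps). Without it, the large chunk dominates the mix inside fc.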
Populate the d2t / t2d vocab-remap buffers from the draft->target id map.
selected_token_ids has shape [draft_vocab_size]; entry i is
the target vocab id of draft id i (the frequency-pruned mapping
built by build_eagle3_token_mapping). This writes the two buffers
inference engines consume:
- d2t[i] = selected_token_ids[i] - i: the offset form vLLM expects (target_id = draft_id + d2t[draft_id]).
- t2d[target_id] = True for every selected target id: the boolean presence mask SGLang consumes.
These must be in the saved checkpoint: without them vLLM/SGLang find no
mapping, silently align draft ids to the first draft_vocab_size
target ids, and acceptance rate collapses.
No-op when the draft vocab is not compressed (the buffers do not exist and the draft logits are already in target space).
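The two buffers follow directly from the definitions above. This NumPy sketch uses a hypothetical helper name (build_vocab_buffers) and toy ids.

```python
import numpy as np


def build_vocab_buffers(selected_token_ids, target_vocab_size):
    """selected_token_ids: [draft_vocab_size] ints; entry i is the target id of draft id i."""
    draft_ids = np.arange(selected_token_ids.shape[0])
    d2t = selected_token_ids - draft_ids           # offset form: target = draft + d2t[draft]
    t2d = np.zeros(target_vocab_size, dtype=bool)  # presence mask over the target vocab
    t2d[selected_token_ids] = True
    return d2t, t2d


sel = np.array([0, 5, 7, 42])
d2t, t2d = build_vocab_buffers(sel, target_vocab_size=50)
```

Here d2t is [0, 4, 5, 39]. Without these buffers, a runtime would map draft id 3 to target id 3 instead of 42, which is the silent misalignment described above.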
Build a standard causal + padding mask for SDPA/eager attention.
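One common way to build such a mask is sketched below. It assumes a [B, T] attention_mask with 1 for real tokens and 0 for padding, and uses the dtype's most negative finite value as the additive "blocked" value. That is a common convention and an assumption here. causal_padding_mask is a hypothetical name.

```python
import numpy as np


def causal_padding_mask(attention_mask, dtype=np.float32):
    """Additive [B, 1, T, T] mask: 0 where query t may attend key s, large negative elsewhere.

    A key is visible when s <= t (causal) and s is a real (non-padding) token.
    """
    B, T = attention_mask.shape
    causal = np.tril(np.ones((T, T), dtype=bool))                # [T, T]
    keys_valid = attention_mask.astype(bool)[:, None, None, :]   # [B, 1, 1, T]
    allowed = causal[None, None] & keys_valid                    # [B, 1, T, T]
    return np.where(allowed, 0.0, np.finfo(dtype).min).astype(dtype)


m = causal_padding_mask(np.array([[1, 1, 1, 0]]))
```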
Return True when each row is a contiguous valid prefix followed by padding.
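The check can be written as a comparison against the mask each row would have if its valid tokens formed a prefix. The sketch assumes a [B, T] 0/1 attention_mask. is_valid_prefix_padding is a hypothetical name.

```python
import numpy as np


def is_valid_prefix_padding(attention_mask):
    """True iff every row looks like [1, ..., 1, 0, ..., 0] (right padding only)."""
    m = attention_mask.astype(bool)
    lengths = m.sum(axis=-1, keepdims=True)                   # valid tokens per row
    expected = np.arange(m.shape[-1])[None, :] < lengths      # ideal prefix mask per row
    return bool(np.array_equal(m, expected))
```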
Best-effort load of flash-attn without breaking eager-only users.
safe_import_from already handles missing modules and missing symbols, but
some broken flash-attn installs fail with lower-level loader errors
(e.g. ABI / shared-library issues) that should not prevent importing this
module for the eager path. Returns the dense flash_attn_func and the
flash_attn_varlen_func (used by the packed block-causal path).
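A minimal sketch of the best-effort pattern uses a plain try/except instead of safe_import_from. It catches Exception rather than only ImportError, so loader failures such as OSError from a broken shared library also fall back to the eager path. load_flash_attn is a hypothetical name. flash_attn_func and flash_attn_varlen_func are the names the flash-attn package exports.

```python
def load_flash_attn():
    """Return (flash_attn_func, flash_attn_varlen_func), or (None, None) if unusable."""
    try:
        from flash_attn import flash_attn_func, flash_attn_varlen_func
    except Exception:  # ImportError, OSError (ABI / .so issues), etc.
        return None, None
    return flash_attn_func, flash_attn_varlen_func


dense_fn, varlen_fn = load_flash_attn()
```

Callers can then check `dense_fn is None` to decide between the FA2 path and the eager path, without the module import itself failing.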
Build FlashAttention varlen cu_seqlens (int32) from packed seq_lens.
Documents are flattened row-major to match the varlen attention’s
reshape(B*T, ...) token order. Returns (cu_seqlens, max_seqlen).
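A sketch of the construction follows. It assumes a hypothetical seq_lens layout: a [B, max_docs] int array of per-row document lengths, zero-padded, where each row's lengths sum to the padded row length T so documents tile the row-major reshape(B*T, ...) order exactly. The source does not specify this layout, and build_cu_seqlens is a hypothetical name.

```python
import numpy as np


def build_cu_seqlens(seq_lens):
    """Return (cu_seqlens int32 [num_docs + 1], max_seqlen) from packed document lengths.

    Assumed layout: seq_lens is [B, max_docs], zero-padded, with each row summing to T.
    """
    lens = seq_lens.reshape(-1)        # row-major: row 0's docs, then row 1's, ...
    lens = lens[lens > 0]              # drop zero-length padding entries
    cu_seqlens = np.zeros(lens.shape[0] + 1, dtype=np.int32)
    cu_seqlens[1:] = np.cumsum(lens)   # document boundaries in flattened token order
    return cu_seqlens, int(lens.max())


cu, max_len = build_cu_seqlens(np.array([[3, 2, 0], [5, 0, 0]]))
```

For two rows of length 5, this yields boundaries [0, 3, 5, 10]. Row 1's single document starts at flattened token 5, which is exactly where reshape(B*T, ...) places it.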