nemo_automodel.components.speculative.eagle.draft_llama#
Minimal Llama-based draft model for EAGLE-3 training.
Module naming is aligned to sglang/srt/models/llama_eagle3.py so that a
checkpoint produced by this trainer can be loaded directly by SGLang’s
LlamaForCausalLMEagle3.load_weights without any key remapping. The state
dict layout is:
model.embed_tokens.weight
model.fc.weight
model.layers.0.input_layernorm.weight
model.layers.0.hidden_norm.weight
model.layers.0.post_attention_layernorm.weight
model.layers.0.self_attn.{q,k,v,o}_proj.weight
model.layers.0.mlp.{gate,up,down}_proj.weight
model.norm.weight
lm_head.weight
SGLang merges q_proj/k_proj/v_proj into a single qkv_proj and
gate_proj/up_proj into gate_up_proj via its stacked_params_mapping
at load time, so the un-fused storage above is the canonical on-disk format.
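As a rough illustration of the layout above, the sketch below expands the brace patterns into concrete keys and shows how an un-fused key could be routed to a fused parameter. The table `ASSUMED_STACKED_PARAMS_MAPPING` and the helpers `expand_braces`, `expected_keys`, and `fused_target` are hypothetical names for this example; the real mapping lives in SGLang and may differ in detail.

```python
import re

# Key patterns copied from the layout above; braces list alternatives.
LAYOUT = [
    "model.embed_tokens.weight",
    "model.fc.weight",
    "model.layers.0.input_layernorm.weight",
    "model.layers.0.hidden_norm.weight",
    "model.layers.0.post_attention_layernorm.weight",
    "model.layers.0.self_attn.{q,k,v,o}_proj.weight",
    "model.layers.0.mlp.{gate,up,down}_proj.weight",
    "model.norm.weight",
    "lm_head.weight",
]

# Assumed shape of the loader's fusion table: (fused name, on-disk name, shard id).
ASSUMED_STACKED_PARAMS_MAPPING = [
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]


def expand_braces(pattern: str) -> list[str]:
    """Expand a single {a,b,c} group into one key per alternative."""
    match = re.search(r"\{([^}]*)\}", pattern)
    if match is None:
        return [pattern]
    head, tail = pattern[: match.start()], pattern[match.end():]
    return [head + alt + tail for alt in match.group(1).split(",")]


def expected_keys() -> list[str]:
    """All concrete state-dict keys implied by LAYOUT."""
    return [key for pattern in LAYOUT for key in expand_braces(pattern)]


def fused_target(key: str):
    """Return (fused key, shard id) for a stacked weight, else (key, None)."""
    for fused, on_disk, shard_id in ASSUMED_STACKED_PARAMS_MAPPING:
        if on_disk in key:
            return key.replace(on_disk, fused), shard_id
    return key, None
```

Comparing `expected_keys()` against the keys of a saved checkpoint is one cheap way to catch a naming drift before handing the file to an inference engine.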
Module Contents#
Classes#
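Eagle3LlamaAttention | EAGLE-3 draft attention over [input_emb, hidden] 2H features. |
Eagle3LlamaMLP | Standard Llama SwiGLU MLP on hidden-size activations. |
Eagle3LlamaDecoderLayer | Single decoder layer used by the minimal EAGLE-3 draft model. |
Eagle3LlamaModel | Inner backbone matching SGLang’s LlamaModel in llama_eagle3.py. |
LlamaEagle3DraftModel | Minimal Llama-only EAGLE-3 draft model. |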
Functions#
_load_flash_attn_func | Best-effort load of flash-attn without breaking eager-only users. |
_build_causal_mask | Build a standard causal + padding mask for SDPA/eager attention. |
_is_right_padded_attention_mask | Return True when each row is a contiguous valid-prefix followed by padding. |
Data#
API#
- nemo_automodel.components.speculative.eagle.draft_llama.logger#
‘getLogger(…)’
- nemo_automodel.components.speculative.eagle.draft_llama._load_flash_attn_func() tuple[bool, object | None]#
Best-effort load of flash-attn without breaking eager-only users.
safe_import_from already handles missing modules and missing symbols, but some broken flash-attn installs fail with lower-level loader errors (e.g. ABI / shared-library issues) that should not prevent importing this module for the eager path.
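A minimal sketch of the best-effort pattern, using only `importlib`. The helper name `load_optional` is hypothetical and is not the module's actual implementation; the broad `except` is an assumption about how loader-level failures would be absorbed.

```python
import importlib
from typing import Optional


def load_optional(module_name: str, symbol: str) -> tuple[bool, Optional[object]]:
    """Return (True, attr) if module_name.symbol imports cleanly, else (False, None)."""
    try:
        module = importlib.import_module(module_name)
        return True, getattr(module, symbol)
    except Exception:
        # Deliberately broad: a broken binary install can raise OSError or
        # RuntimeError from the dynamic loader, not just ImportError.
        return False, None


# Evaluated once at import time; the eager path never touches flash_attn_func.
HAS_FLASH_ATTN, flash_attn_func = load_optional("flash_attn", "flash_attn_func")
```

Callers would then branch on the boolean rather than re-importing, so a failed probe costs nothing on the eager path.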
- nemo_automodel.components.speculative.eagle.draft_llama._SUPPORTED_ATTN_IMPLEMENTATIONS#
(‘eager’, ‘flash_attention_2’)
- nemo_automodel.components.speculative.eagle.draft_llama._build_causal_mask(
- attention_mask: torch.Tensor,
- dtype: torch.dtype,
Build a standard causal + padding mask for SDPA/eager attention.
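The combined mask can be sketched with plain Python lists standing in for tensors. This is an illustration of the logic only: the function name is hypothetical, `-inf` is used as the blocked value, and the real helper takes a `dtype`, presumably to choose a dtype-appropriate fill value.

```python
NEG_INF = float("-inf")


def build_causal_padding_mask(attention_mask: list[list[int]]) -> list[list[list[float]]]:
    """Additive mask indexed [batch][query][key]: 0.0 = attend, -inf = blocked."""
    out = []
    for row in attention_mask:
        seq_len = len(row)
        out.append([
            # A key is visible only if it is not in the future and not padding.
            [0.0 if (key <= query and row[key] == 1) else NEG_INF
             for key in range(seq_len)]
            for query in range(seq_len)
        ])
    return out
```

Adding this mask to the raw attention scores before the softmax drives blocked positions to zero probability.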
- nemo_automodel.components.speculative.eagle.draft_llama._is_right_padded_attention_mask(attention_mask: torch.Tensor) bool#
Return True when each row is a contiguous valid-prefix followed by padding.
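A small sketch of the check on nested lists of 0/1 values; the function name and the list-based input are assumptions made for illustration, while the module's helper operates on a `torch.Tensor`.

```python
def is_right_padded(attention_mask: list[list[int]]) -> bool:
    """True when no valid token (1) appears after a padding token (0) in any row."""
    for row in attention_mask:
        seen_padding = False
        for value in row:
            if value == 0:
                seen_padding = True
            elif seen_padding:
                # A valid token after padding: left-padded or ragged row.
                return False
    return True
```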
- class nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaAttention(config: transformers.LlamaConfig)#
Bases: torch.nn.Module
EAGLE-3 draft attention over [input_emb, hidden] 2H features.
Driven through a shared cache_hidden = [K_list, V_list] pair. At step k (0-indexed), with K_list and V_list already holding entries from steps 0..k-1:
step_idx = len(K_list) (equal to k) gives the rotary phase shift, so the draft’s K_k encodes “this is k tokens into the future”. The shifted cos/sin are computed from position_ids + step_idx.
The freshly projected K, V (after GQA expansion) are appended to the cache lists in place.
The attention output is the EAGLE-3 mixed pattern:
attn_weights = [ Q @ K_0^T / sqrt(d) + mask ] || diag_1 || ... || diag_k
where diag_i[t] = (Q_t * K_i_t).sum(-1) / sqrt(d). The softmax is taken over the full extended column axis of length T + k. Output is
out = attn_probs[..., :T] @ V_0 + sum_{i=1..k} attn_probs[..., T+i-1, None] * V_i.
In English: Q at position t attends to all K_0 positions (the regular T x T causal block), and additionally to the same position t in each previous draft step i >= 1. Implementation-wise we replace SpecForge llama3_eagle.py’s two O(k^2) cat/add Python loops with single vectorized einsum calls.
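The mixed pattern described above can be traced on a toy example. The sketch below is a loop-based, single-head, single-sequence version in plain Python, not the vectorized einsum implementation; the function names are hypothetical and the padding mask is omitted for brevity.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def eagle3_attention(q, k_cache, v_cache):
    """q: [T][d]; k_cache and v_cache: lists of k+1 entries, each [T][d]."""
    seq_len, dim = len(q), len(q[0])
    scale = 1.0 / math.sqrt(dim)
    outputs = []
    for t in range(seq_len):
        # Block 1: causal scores against K_0 (future keys get -inf).
        scores = [dot(q[t], k_cache[0][s]) * scale if s <= t else float("-inf")
                  for s in range(seq_len)]
        # Block 2: one diagonal score per cached draft step i >= 1.
        scores += [dot(q[t], k_i[t]) * scale for k_i in k_cache[1:]]
        # One softmax over all T + k columns.
        peak = max(scores)
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        probs = [w / total for w in weights]
        # Values: every V_0 row for block 1, then V_i[t] per diagonal column.
        values = list(v_cache[0]) + [v_i[t] for v_i in v_cache[1:]]
        outputs.append([sum(p * v[j] for p, v in zip(probs, values))
                        for j in range(dim)])
    return outputs
```

With a single cache entry the function reduces to ordinary causal attention, which matches the statement that step 0 has no diagonal columns.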
cache_hidden is mutated in place; callers are responsible for re-initializing it to [[], []] at the start of each training batch.
Initialization
- _project_qkv(
- combined_states: torch.Tensor,
- _repeat_kv(
- k: torch.Tensor,
- v: torch.Tensor,
- forward(
- combined_states: torch.Tensor,
- attention_mask: torch.Tensor,
- position_ids: torch.Tensor,
- cache_hidden: list[list[torch.Tensor]],
- _eager_attention_forward(
- q: torch.Tensor,
- cache_k: list[torch.Tensor],
- cache_v: list[torch.Tensor],
- attention_mask: torch.Tensor,
- step_idx: int,
- batch_size: int,
- seq_len: int,
- _flash_attention_forward(
- q: torch.Tensor,
- cache_k: list[torch.Tensor],
- cache_v: list[torch.Tensor],
- step_idx: int,
- batch_size: int,
- seq_len: int,
EAGLE-3 attention via FlashAttention-2 for the T x T causal block.
FA2 covers Block 1 (full T x T causal attention against K_0) and returns the un-normalized log-sum-exp (softmax_lse) alongside the per-token output. The diagonal extension columns (Block 2) for cached steps i >= 1 are computed in eager mode, then merged into a single softmax via the log-space identity lse_full = logaddexp(lse_fa, logsumexp(diag)); the FA output is rescaled by exp(lse_fa - lse_full) and the diagonal contribution is added with weights exp(diag - lse_full).
Padding handling: FA2 is invoked with causal=True. For right-padded batches, padding keys always lie strictly above the diagonal relative to any non-padded query position, so causal masking alone yields the same output as the eager additive padding mask at every valid query position. Outputs at padding query positions differ, but those are masked out at loss time.
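The log-space merge can be checked numerically for a single query position. In the sketch below, `softmax_attention` is a hypothetical stand-in for the FA2 call (it returns a normalized output plus its log-sum-exp), and `merge_with_diagonal` applies the identity from the text; all names and inputs are illustrative.

```python
import math


def logsumexp(xs):
    peak = max(xs)
    return peak + math.log(sum(math.exp(x - peak) for x in xs))


def logaddexp(a, b):
    peak = max(a, b)
    return peak + math.log(math.exp(a - peak) + math.exp(b - peak))


def softmax_attention(scores, values):
    """Reference attention for one query: returns (output, log-sum-exp)."""
    lse = logsumexp(scores)
    probs = [math.exp(s - lse) for s in scores]
    dim = len(values[0])
    out = [sum(p * v[j] for p, v in zip(probs, values)) for j in range(dim)]
    return out, lse


def merge_with_diagonal(out_fa, lse_fa, diag_scores, diag_values):
    """Fold k diagonal columns into an already-normalized block-1 output."""
    lse_full = logaddexp(lse_fa, logsumexp(diag_scores))
    rescale = math.exp(lse_fa - lse_full)  # shrink block 1 to its new share
    merged = [rescale * x for x in out_fa]
    for score, value in zip(diag_scores, diag_values):
        weight = math.exp(score - lse_full)  # softmax weight of this column
        merged = [m + weight * v for m, v in zip(merged, value)]
    return merged


block1_scores = [0.3, -1.2, 0.8]
block1_values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
diag_scores = [0.1, -0.4]
diag_values = [[5.0, 5.0], [-1.0, 3.0]]

out_fa, lse_fa = softmax_attention(block1_scores, block1_values)
merged = merge_with_diagonal(out_fa, lse_fa, diag_scores, diag_values)
reference, _ = softmax_attention(block1_scores + diag_scores,
                                 block1_values + diag_values)
```

Here `merged` and `reference` agree up to floating-point error, which is the property that lets the two blocks be computed by different kernels.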
- class nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaMLP(config: transformers.LlamaConfig)#
Bases: torch.nn.Module
Standard Llama SwiGLU MLP on hidden-size activations.
Initialization
- forward(hidden_states: torch.Tensor) torch.Tensor#
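For reference, the standard SwiGLU computation is `down_proj(silu(gate_proj(x)) * up_proj(x))`. The sketch below applies it to a single token vector with plain lists; the helper names are hypothetical, weights are stored as `[out][in]` rows, and the projections are assumed to be bias-free, consistent with the weight-only state dict.

```python
import math


def silu(x: float) -> float:
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def matvec(weight, x):
    """Multiply an [out][in] weight matrix by a length-in vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def swiglu_mlp(x, gate_w, up_w, down_w):
    """down_proj(silu(gate_proj(x)) * up_proj(x)) for one token vector x."""
    gate = [silu(g) for g in matvec(gate_w, x)]
    up = matvec(up_w, x)
    return matvec(down_w, [g * u for g, u in zip(gate, up)])
```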
- class nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaDecoderLayer(
- config: transformers.LlamaConfig,
- layer_id: int = 0,
Bases: torch.nn.Module
Single decoder layer used by the minimal EAGLE-3 draft model.
Attribute names mirror SGLang’s LlamaDecoderLayer in sglang/srt/models/llama_eagle3.py: input_layernorm is applied to the per-step token embeddings (embeds in SGLang), hidden_norm is applied to the carried hidden state. is_input_layer is the layer-0 flag that gates the [embeds, hidden] concatenation (always true for our single-layer draft).
Initialization
- forward(
- input_embeds: torch.Tensor,
- hidden_states: torch.Tensor,
- attention_mask: torch.Tensor,
- position_ids: torch.Tensor,
- cache_hidden: list[list[torch.Tensor]],
- class nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaModel(config: transformers.LlamaConfig)#
Bases: torch.nn.Module
Inner backbone matching SGLang’s LlamaModel in llama_eagle3.py.
Owns embed_tokens, the fc projection from concatenated target aux hidden states to draft hidden size, the (single-element) draft layers ModuleList, and the final norm. The LlamaEagle3DraftModel wrapper around this module adds the top-level lm_head and the training-facing public API.
Initialization
- class nemo_automodel.components.speculative.eagle.draft_llama.LlamaEagle3DraftModel(config: transformers.LlamaConfig)#
Bases: transformers.PreTrainedModel
Minimal Llama-only EAGLE-3 draft model.
State dict keys match SGLang’s LlamaForCausalLMEagle3 so the saved checkpoint can be loaded by SGLang’s inference engine without any remapping (SGLang’s load_weights fuses q/k/v_proj into qkv_proj and gate/up_proj into gate_up_proj via its standard stacked_params_mapping).
This intentionally starts narrow:
Llama config only
single draft decoder layer
no KV-cache optimization
no speculative runtime integration
Initialization
- config_class#
None
- base_model_prefix#
‘model’
- copy_embeddings_from_target(
- target_embedding: torch.nn.Embedding,
Initialize draft embeddings from the target model embeddings.
- freeze_embeddings() None#
Freeze draft input embeddings.
Project concatenated target aux states from num_aux * H_target to draft hidden size.
- embed_input_ids(input_ids: torch.Tensor) torch.Tensor#
Embed input ids with the draft embedding table.
- compute_logits(hidden_states: torch.Tensor) torch.Tensor#
Compute draft logits on the configured draft vocabulary.
- forward(
- input_ids: torch.Tensor,
- projected_hidden_states: torch.Tensor,
- attention_mask: torch.Tensor,
- position_ids: Optional[torch.Tensor] = None,
- cache_hidden: Optional[list[list[torch.Tensor]]] = None,
Run one full-sequence draft update step.
cache_hidden is the EAGLE-3 TTT (training-time test) cache. Pass [[], []] on the first step of a TTT unroll and the same list object on each subsequent step; the attention layer appends the per-step K and V to it. If None is passed (e.g. from a one-shot evaluation call), a fresh [[], []] is allocated locally; step 0 of TTT is mathematically equivalent to a plain causal forward.
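The calling contract for cache_hidden can be sketched without the model itself. `fake_draft_forward` below is a hypothetical stand-in that only mimics how the cache lists grow; it does not reproduce the real forward signature or any tensor computation.

```python
def fake_draft_forward(cache_hidden=None):
    """Hypothetical stand-in for the draft forward; models only the cache contract."""
    if cache_hidden is None:
        cache_hidden = [[], []]  # one-shot call: behaves like step 0
    k_list, v_list = cache_hidden
    step_idx = len(k_list)  # number of earlier steps already cached
    k_list.append(f"K_{step_idx}")  # placeholder for the projected keys
    v_list.append(f"V_{step_idx}")  # placeholder for the projected values
    return step_idx


def unroll(num_steps):
    """Run num_steps draft steps against one shared cache."""
    cache_hidden = [[], []]  # fresh cache at the start of every batch
    seen = [fake_draft_forward(cache_hidden) for _ in range(num_steps)]
    return seen, cache_hidden
```

Reusing the same list object across steps is what lets each call infer its step index from the cache length; allocating a new list per step would silently reset every call to step 0.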