nemo_automodel.components.speculative.eagle.draft_llama#

Minimal Llama-based draft model for EAGLE-3 training.

Module naming is aligned to sglang/srt/models/llama_eagle3.py so that a checkpoint produced by this trainer can be loaded directly by SGLang’s LlamaForCausalLMEagle3.load_weights without any key remapping. The state dict layout is:

model.embed_tokens.weight
model.fc.weight
model.layers.0.input_layernorm.weight
model.layers.0.hidden_norm.weight
model.layers.0.post_attention_layernorm.weight
model.layers.0.self_attn.{q,k,v,o}_proj.weight
model.layers.0.mlp.{gate,up,down}_proj.weight
model.norm.weight
lm_head.weight

SGLang merges q_proj/k_proj/v_proj into a single qkv_proj and gate_proj/up_proj into gate_up_proj via its stacked_params_mapping at load time, so the un-fused storage above is the canonical on-disk format.
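As a rough illustration of the layout above, the sketch below enumerates the expected keys for a single-layer draft and shows how a loader could route un-fused names to fused parameters. `expected_state_dict_keys`, `STACKED_PARAMS_MAPPING`, and `fused_target` are hypothetical names written for this example; they imitate the general shape of a stacked-parameter mapping and are not SGLang's actual code.

```python
def expected_state_dict_keys(num_layers: int = 1) -> list:
    """Expand the layout listed above into concrete key names."""
    keys = ["model.embed_tokens.weight", "model.fc.weight"]
    for i in range(num_layers):
        prefix = f"model.layers.{i}"
        keys += [
            f"{prefix}.{name}.weight"
            for name in ("input_layernorm", "hidden_norm", "post_attention_layernorm")
        ]
        keys += [f"{prefix}.self_attn.{p}_proj.weight" for p in "qkvo"]
        keys += [f"{prefix}.mlp.{p}_proj.weight" for p in ("gate", "up", "down")]
    keys += ["model.norm.weight", "lm_head.weight"]
    return keys


# Hypothetical mapping rows: (fused name, un-fused name, shard id).
STACKED_PARAMS_MAPPING = [
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]


def fused_target(key: str):
    """Return (fused key, shard id), or (key, None) when the key is stored as-is."""
    for fused, unfused, shard in STACKED_PARAMS_MAPPING:
        if unfused in key:
            return key.replace(unfused, fused), shard
    return key, None
```

Keys such as `o_proj` and `down_proj` have no row in the mapping, so in this sketch they pass through unchanged.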

Module Contents#

Classes#

Eagle3LlamaAttention

EAGLE-3 draft attention over [input_emb, hidden] 2H features.

Eagle3LlamaMLP

Standard Llama SwiGLU MLP on hidden-size activations.

Eagle3LlamaDecoderLayer

Single decoder layer used by the minimal EAGLE-3 draft model.

Eagle3LlamaModel

Inner backbone matching SGLang’s LlamaModel in llama_eagle3.py.

LlamaEagle3DraftModel

Minimal Llama-only EAGLE-3 draft model.

Functions#

_load_flash_attn_func

Best-effort load of flash-attn without breaking eager-only users.

_build_causal_mask

Build a standard causal + padding mask for SDPA/eager attention.

_is_right_padded_attention_mask

Return True when each row is a contiguous valid-prefix followed by padding.

Data#

logger

_SUPPORTED_ATTN_IMPLEMENTATIONS

API#

nemo_automodel.components.speculative.eagle.draft_llama.logger#

‘getLogger(…)’

nemo_automodel.components.speculative.eagle.draft_llama._load_flash_attn_func() → tuple[bool, object | None]#

Best-effort load of flash-attn without breaking eager-only users.

safe_import_from already handles missing modules and missing symbols, but some broken flash-attn installs fail with lower-level loader errors (e.g. ABI / shared-library issues) that should not prevent importing this module for the eager path.
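The "best-effort" idea can be sketched with the standard library alone. The helper below is a hypothetical stand-in named `best_effort_import`; it does not use `safe_import_from` (whose signature is not shown here) and only illustrates why the guard catches a broad exception class rather than `ImportError` alone.

```python
import importlib
from typing import Optional, Tuple


def best_effort_import(module_name: str, symbol: str) -> Tuple[bool, Optional[object]]:
    """Return (True, obj) if `symbol` loads from `module_name`, else (False, None)."""
    try:
        module = importlib.import_module(module_name)
        return True, getattr(module, symbol)
    except Exception:
        # Deliberately broad: a broken binary extension can raise OSError or
        # other loader errors, not just ImportError / AttributeError.
        return False, None
```

A caller can then fall back to the eager path whenever the first element of the returned pair is `False`.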

nemo_automodel.components.speculative.eagle.draft_llama._SUPPORTED_ATTN_IMPLEMENTATIONS#

(‘eager’, ‘flash_attention_2’)

nemo_automodel.components.speculative.eagle.draft_llama._build_causal_mask(
    attention_mask: torch.Tensor,
    dtype: torch.dtype,
) → torch.Tensor#

Build a standard causal + padding mask for SDPA/eager attention.
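A minimal sketch of what a causal + padding additive mask looks like, using plain Python lists instead of tensors to stay dependency-free. The function name, the `[B][T][T]` layout, and the use of `-inf` as the masked value are assumptions made for this example; the real helper takes a tensor and a dtype and may differ in shape and fill value.

```python
def build_causal_padding_mask(attention_mask, masked_value=float("-inf")):
    """attention_mask: rows of 0/1 flags (1 = real token).

    Returns a [B][T][T] additive mask: 0.0 where query i may attend to key j,
    `masked_value` where the key is in the future or is padding.
    """
    masks = []
    for row in attention_mask:
        seq_len = len(row)
        masks.append([
            [0.0 if (j <= i and row[j] == 1) else masked_value for j in range(seq_len)]
            for i in range(seq_len)
        ])
    return masks
```

Adding this mask to the raw scores before the softmax drives the probability of masked keys to zero.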

nemo_automodel.components.speculative.eagle.draft_llama._is_right_padded_attention_mask(attention_mask: torch.Tensor) → bool#

Return True when each row is a contiguous valid-prefix followed by padding.
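The check can be sketched on plain 0/1 lists as follows. `is_right_padded` is a hypothetical name, and treating an all-padding row as valid is an assumption of this example rather than documented behavior.

```python
def is_right_padded(attention_mask) -> bool:
    """True if every row has the form 1...1 0...0 (either run may be empty)."""
    for row in attention_mask:
        valid_len = sum(row)  # number of real tokens, assuming 0/1 entries
        prefix_ok = all(v == 1 for v in row[:valid_len])
        suffix_ok = all(v == 0 for v in row[valid_len:])
        if not (prefix_ok and suffix_ok):
            return False
    return True
```

For example, `[[1, 1, 0], [1, 1, 1]]` passes, while a left-padded row such as `[0, 1, 1]` or a row with a hole such as `[1, 0, 1]` does not.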

class nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaAttention(config: transformers.LlamaConfig)#

Bases: torch.nn.Module

EAGLE-3 draft attention over [input_emb, hidden] 2H features.

The attention is driven through a shared cache_hidden = [K_list, V_list] pair. At step k (0-indexed), with K_list and V_list already holding entries from steps 0..k-1:

  1. step_idx = len(K_list) (equal to k) gives the rotary phase shift, so the draft’s K_k encodes “this is k tokens into the future”. The shifted cos / sin are computed from position_ids + step_idx.

  2. The freshly projected K, V (after GQA expansion) are appended to the cache lists in place.

  3. The attention output is the EAGLE-3 mixed pattern:

    attn_weights = [ Q @ K_0^T / sqrt(d) + mask ]  ||  diag_1  ||  ...  ||  diag_k

    where diag_i[t] = (Q_t * K_i_t).sum(-1) / sqrt(d). The softmax is taken over the full extended column axis of length T + k. Output is

    out = attn_probs[..., :T] @ V_0  +  sum_{i=1..k} attn_probs[..., T+i-1, None] * V_i.

    In English: Q at position t attends to all K_0 positions (the regular T x T causal block), and additionally to the same position t in each previous draft step i >= 1. Implementation-wise we replace SpecForge llama3_eagle.py’s two O(k^2) cat / add Python loops with single vectorized einsum calls.

cache_hidden is mutated in place; callers are responsible for re-initializing it to [[], []] at the start of each training batch.
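The mixed pattern in item 3 can be sketched for a single head and a single query position. The version below uses plain Python lists, ignores padding and GQA, and loops instead of using einsum, so it is only an illustration of the formula and not the module's implementation; `eagle3_attention_row` and its argument names are hypothetical.

```python
import math


def eagle3_attention_row(q_t, k0, v0, k_extra_t, v_extra_t, t):
    """Attention output for query position t of one head.

    q_t:        query vector at position t (length d)
    k0, v0:     step-0 keys / values for all T positions (T vectors of length d)
    k_extra_t:  keys at position t from draft steps 1..k (k vectors)
    v_extra_t:  values at position t from draft steps 1..k (k vectors)
    """
    scale = 1.0 / math.sqrt(len(q_t))

    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    # Causal block: positions 0..t of K_0 (later positions are masked out).
    scores = [dot(q_t, k0[j]) * scale for j in range(t + 1)]
    # Diagonal columns: the same position t in each later cache entry.
    scores += [dot(q_t, k_i) * scale for k_i in k_extra_t]

    # One softmax over the extended column axis.
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Weighted sum over the matching value vectors.
    values = list(v0[: t + 1]) + list(v_extra_t)
    return [sum(p * v[c] for p, v in zip(probs, values)) for c in range(len(values[0]))]
```

With empty `k_extra_t` and `v_extra_t` this reduces to ordinary causal attention, which corresponds to step 0.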

Initialization

_project_qkv(
    combined_states: torch.Tensor,
) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]#

_repeat_kv(
    k: torch.Tensor,
    v: torch.Tensor,
) → tuple[torch.Tensor, torch.Tensor]#

forward(
    combined_states: torch.Tensor,
    attention_mask: torch.Tensor,
    position_ids: torch.Tensor,
    cache_hidden: list[list[torch.Tensor]],
) → torch.Tensor#

_eager_attention_forward(
    q: torch.Tensor,
    cache_k: list[torch.Tensor],
    cache_v: list[torch.Tensor],
    attention_mask: torch.Tensor,
    step_idx: int,
    batch_size: int,
    seq_len: int,
) → torch.Tensor#

_flash_attention_forward(
    q: torch.Tensor,
    cache_k: list[torch.Tensor],
    cache_v: list[torch.Tensor],
    step_idx: int,
    batch_size: int,
    seq_len: int,
) → torch.Tensor#

EAGLE-3 attention via FlashAttention-2 for the T x T causal block.

FA2 covers Block 1 (full T x T causal attention against K_0) and returns the un-normalized log-sum-exp (softmax_lse) alongside the per-token output. The diagonal extension columns (Block 2) for cached steps i >= 1 are computed in eager mode, then merged into a single softmax via the log-space identity lse_full = logaddexp(lse_fa, logsumexp(diag)); the FA output is rescaled by exp(lse_fa - lse_full) and the diagonal contribution is added with weights exp(diag - lse_full).
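The log-space merge described above can be checked on scalars. The sketch below works on one query position with plain Python lists; `logsumexp` and `merge_blocks` are hypothetical helpers for this example, and `out_fa` stands for an output that is already softmax-normalized within Block 1.

```python
import math


def logsumexp(xs):
    peak = max(xs)
    return peak + math.log(sum(math.exp(x - peak) for x in xs))


def merge_blocks(out_fa, lse_fa, diag_scores, diag_values):
    """Fold extra diagonal columns into an already-normalized block output.

    out_fa:      Block 1 output, normalized within Block 1 (length d)
    lse_fa:      log-sum-exp of the Block 1 scores
    diag_scores: scaled scores of the k diagonal columns
    diag_values: matching value vectors (k vectors of length d)
    """
    # Same as logaddexp(lse_fa, logsumexp(diag_scores)).
    lse_full = logsumexp([lse_fa] + list(diag_scores))
    w_fa = math.exp(lse_fa - lse_full)                      # rescales Block 1
    w_diag = [math.exp(s - lse_full) for s in diag_scores]  # diagonal weights
    return [
        w_fa * out_fa[c] + sum(w * v[c] for w, v in zip(w_diag, diag_values))
        for c in range(len(out_fa))
    ]
```

The result matches a single softmax over the concatenated scores because `exp(lse_fa - lse_full)` is exactly the probability mass that the full softmax assigns to Block 1.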

Padding handling: FA2 is invoked with causal=True. For right-padded batches, padding keys always lie strictly above the diagonal relative to any non-padded query position, so causal masking alone yields the same output as the eager additive padding mask at every valid query position. Outputs at padding query positions differ, but those are masked out at loss time.
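The padding argument can be verified by comparing the key sets visible under each masking scheme. This is a small sketch on 0/1 lists with hypothetical helper names; it checks key visibility only, not attention values.

```python
def visible_keys(query_pos, key_is_valid=None):
    """Key positions visible to `query_pos` under causal masking,
    optionally intersected with a padding mask."""
    return [
        j for j in range(query_pos + 1)
        if key_is_valid is None or key_is_valid[j] == 1
    ]


def causal_only_is_exact(row):
    """True if causal-only and causal + padding agree at every valid query of `row`."""
    return all(
        visible_keys(t) == visible_keys(t, row)
        for t, is_valid in enumerate(row)
        if is_valid == 1
    )
```

For the right-padded row `[1, 1, 1, 0, 0]` the two schemes agree at every valid query, while for the left-padded row `[0, 0, 1, 1, 1]` they do not, which is why the right-padding check matters before taking the FlashAttention-2 path.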

class nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaMLP(config: transformers.LlamaConfig)#

Bases: torch.nn.Module

Standard Llama SwiGLU MLP on hidden-size activations.
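For reference, the standard bias-free SwiGLU computation is down_proj(silu(gate_proj(x)) * up_proj(x)). The sketch below spells this out on plain Python lists with hypothetical helper names; weights are given as row-major matrices, and it is not the module's code.

```python
import math


def silu(x: float) -> float:
    return x / (1.0 + math.exp(-x))


def matvec(w, x):
    """Multiply matrix `w` (list of rows) by vector `x`."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]


def swiglu_mlp(x, w_gate, w_up, w_down):
    gate = [silu(g) for g in matvec(w_gate, x)]  # gated branch
    up = matvec(w_up, x)                         # linear branch
    hidden = [g * u for g, u in zip(gate, up)]   # elementwise product
    return matvec(w_down, hidden)                # project back to hidden size
```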

Initialization

forward(hidden_states: torch.Tensor) → torch.Tensor#

class nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaDecoderLayer(
    config: transformers.LlamaConfig,
    layer_id: int = 0,
)#

Bases: torch.nn.Module

Single decoder layer used by the minimal EAGLE-3 draft model.

Attribute names mirror SGLang’s LlamaDecoderLayer in sglang/srt/models/llama_eagle3.py: input_layernorm is applied to the per-step token embeddings (embeds in SGLang), hidden_norm is applied to the carried hidden state. is_input_layer is the layer-0 flag that gates the [embeds, hidden] concatenation (always true for our single-layer draft).

Initialization

forward(
    input_embeds: torch.Tensor,
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
    position_ids: torch.Tensor,
    cache_hidden: list[list[torch.Tensor]],
) → torch.Tensor#

class nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaModel(config: transformers.LlamaConfig)#

Bases: torch.nn.Module

Inner backbone matching SGLang’s LlamaModel in llama_eagle3.py.

Owns embed_tokens, the fc projection from concatenated target aux hidden states to draft hidden size, the (single-element) draft layers ModuleList, and the final norm. The LlamaEagle3DraftModel wrapper around this module adds the top-level lm_head and the training-facing public API.

Initialization

class nemo_automodel.components.speculative.eagle.draft_llama.LlamaEagle3DraftModel(config: transformers.LlamaConfig)#

Bases: transformers.PreTrainedModel

Minimal Llama-only EAGLE-3 draft model.

State dict keys match SGLang’s LlamaForCausalLMEagle3 so the saved checkpoint can be loaded by SGLang’s inference engine without any remapping (SGLang’s load_weights fuses q/k/v_proj into qkv_proj and gate/up_proj into gate_up_proj via its standard stacked_params_mapping).

The implementation is intentionally narrow in scope for now:

  • Llama config only

  • single draft decoder layer

  • no KV-cache optimization

  • no speculative runtime integration

Initialization

config_class#

None

base_model_prefix#

‘model’

copy_embeddings_from_target(
    target_embedding: torch.nn.Embedding,
) → None#

Initialize draft embeddings from the target model embeddings.

freeze_embeddings() → None#

Freeze draft input embeddings.

project_hidden_states(aux_hidden_states: torch.Tensor) → torch.Tensor#

Project concatenated target aux states from num_aux * H_target to draft hidden size.

embed_input_ids(input_ids: torch.Tensor) → torch.Tensor#

Embed input ids with the draft embedding table.

compute_logits(hidden_states: torch.Tensor) → torch.Tensor#

Compute draft logits on the configured draft vocabulary.

forward(
    input_ids: torch.Tensor,
    projected_hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
    cache_hidden: Optional[list[list[torch.Tensor]]] = None,
) → torch.Tensor#

Run one full-sequence draft update step.

cache_hidden is the EAGLE-3 TTT (training-time test) cache. Pass [[], []] on the first step of a TTT unroll and the same list object on each subsequent step; the attention layer appends the per-step K and V to it. If None is passed (e.g. from a one-shot evaluation call), a fresh [[], []] is allocated locally; this is safe because step 0 of TTT is mathematically equivalent to a plain causal forward.
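The cache contract can be illustrated without the model itself. In the sketch below, `fake_draft_step` is a hypothetical stand-in for one forward call that models only the bookkeeping: it derives the step index from the cache length, shifts the position ids by that index, and appends placeholder strings where the real layer would append K and V tensors.

```python
def fake_draft_step(position_ids, cache_hidden=None):
    """Stand-in for one draft forward call; only cache bookkeeping is modeled."""
    if cache_hidden is None:
        cache_hidden = [[], []]  # one-shot call: fresh local cache
    k_list, v_list = cache_hidden
    step_idx = len(k_list)  # number of earlier steps in this unroll
    shifted = [p + step_idx for p in position_ids]  # positions used for rotary cos / sin
    k_list.append(f"K_{step_idx}")  # placeholders for the projected tensors
    v_list.append(f"V_{step_idx}")
    return step_idx, shifted


cache = [[], []]  # reset at the start of each training batch
steps = [fake_draft_step([0, 1, 2], cache) for _ in range(3)]
```

After the loop, `cache` holds three entries per list and the third call has used positions shifted by 2. Reusing the same `cache` object for the next batch without resetting it would keep incrementing the step index, which is why the reset is the caller's responsibility.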