# nemo_automodel.components.speculative.eagle.draft_llama

Llama-style dense LLM draft model for EAGLE-3 / EAGLE-3.1 training.

The implementation is config-driven and supports any HuggingFace dense
decoder-only architecture whose layout matches Llama: GQA attention with
optional Q/K/V/O bias (`config.attention_bias`), SwiGLU MLP with optional
bias (`config.mlp_bias`), RMSNorm, and rotary position embeddings parameterized
by `config.rope_theta` / `config.rope_scaling`. This currently covers Llama,
Phi-3, and Qwen3 dense (Phi-3 omits `attention_bias` / `mlp_bias`, which
the attention and MLP layers already read via
`getattr(config, "<field>", False)`; Qwen3 decouples `head_dim` from
`hidden_size / num_attention_heads`, which the attention layer reads via
`getattr(config, "head_dim", ...)`).
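The optional-field handling described above can be sketched in plain Python. The helper name `read_attention_settings`, the two example configs, and their sizes are hypothetical. The `head_dim` fallback to `hidden_size // num_attention_heads` is an assumption inferred from the text, since the source elides the default.

```python
from types import SimpleNamespace


def read_attention_settings(config):
    """Resolve optional config fields with getattr fallbacks (illustrative only)."""
    # Missing bias flags fall back to False (the Phi-3 case).
    attention_bias = getattr(config, "attention_bias", False)
    mlp_bias = getattr(config, "mlp_bias", False)
    # Assumed fallback: derive head_dim when the config does not set it.
    default_head_dim = config.hidden_size // config.num_attention_heads
    head_dim = getattr(config, "head_dim", default_head_dim)
    return attention_bias, mlp_bias, head_dim


# Hypothetical configs with made-up sizes.
phi3_like = SimpleNamespace(hidden_size=3072, num_attention_heads=32)
qwen3_like = SimpleNamespace(hidden_size=1024, num_attention_heads=16, head_dim=128)

print(read_attention_settings(phi3_like))
print(read_attention_settings(qwen3_like))
```

In the second config, `head_dim` (128) differs from `1024 // 16 = 64`, which is the decoupling the paragraph refers to.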

Class names and the public `architectures` string remain `LlamaEagle3*` for
backward compatibility with already-trained checkpoints and with SGLang's
`LlamaForCausalLMEagle3.load_weights` (the saved state dict layout is
unchanged):

* `model.embed_tokens.weight`
* `model.fc.weight`
* `model.layers.0.input_layernorm.weight`
* `model.layers.0.hidden_norm.weight`
* `model.layers.0.post_attention_layernorm.weight`
* `model.layers.0.self_attn.{q,k,v,o}_proj.weight`
* `model.layers.0.mlp.{gate,up,down}_proj.weight`
* `model.norm.weight`
* `lm_head.weight`

SGLang merges `q_proj/k_proj/v_proj` into a single `qkv_proj` and
`gate_proj/up_proj` into `gate_up_proj` via its `stacked_params_mapping`
at load time, so the un-fused storage above is the canonical on-disk format.
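As a rough illustration of the load-time fusion, the sketch below renames un-fused checkpoint keys to their fused destinations. The table `STACKED_PARAMS` and the helper `fused_destination` are hypothetical stand-ins written for this example. They are not copied from SGLang, and the real loader also copies each shard into the right slice of the fused tensor.

```python
# Illustrative table: (fused name, on-disk name, shard id).
STACKED_PARAMS = [
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]


def fused_destination(checkpoint_key):
    """Return (runtime key, shard id); shard id is None for weights that stay un-fused."""
    for fused, unfused, shard_id in STACKED_PARAMS:
        if unfused in checkpoint_key:
            return checkpoint_key.replace(unfused, fused), shard_id
    return checkpoint_key, None


for key in (
    "model.layers.0.self_attn.k_proj.weight",
    "model.layers.0.mlp.up_proj.weight",
    "model.layers.0.self_attn.o_proj.weight",
):
    print(key, "->", fused_destination(key))
```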

EAGLE-3.1 introduces two optional drafter-side toggles that together address
the "attention drift" failure mode observed when speculation depth grows:

* `config.fc_norm` (bool, default False) -- when True, an
  `nn.ModuleList` of `num_aux_hidden_states` independent RMSNorms (each
  of size `target_hidden_size`) is applied per chunk before the
  concatenated auxiliary hidden states enter `model.fc`. The on-disk keys
  are `model.fc_norm.0.weight`, `model.fc_norm.1.weight`, ...; the
  module layout matches vLLM's EAGLE-3.1 integration in PR
  [https://github.com/vllm-project/vllm/pull/42764](https://github.com/vllm-project/vllm/pull/42764) so checkpoints trained
  here load directly into vLLM / SGLang.
* `config.norm_output` (bool, default False) -- when True, the existing
  final RMSNorm (`model.norm`) is applied to the per-step hidden state
  returned by `forward` so that the next TTT step (and the `lm_head`)
  consume the post-norm state instead of the raw decoder output. Adds no
  new parameters.

Both flags default to False so EAGLE-3 checkpoints continue to load and
behave identically. Enabling them applies the EAGLE-3.1 drafter toggles to
the Llama-style draft used here; the MLA-backbone Kimi K2.6 draft
(`Eagle3DeepseekV2ForCausalLM` in `lightseekorg/kimi-k2.6-eagle3.1-mla`)
is a separate architecture and is not covered by this module.

P-EAGLE (parallel-drafting EAGLE-3) adds one further optional toggle:

* `config.parallel_drafting` (bool, default False) -- when True, the draft
  registers a single learnable `mask_hidden` placeholder of shape
  `[1, 1, num_aux_hidden_states * target_hidden_size]` (the pre-`fc`
  concatenated-aux dimension) and exposes `LlamaEagle3DraftModel.forward_peagle`,
  a single parallel forward over a flat, COD-subsampled sequence with a
  `flex_attention` cross-depth mask (see `peagle_attention.py` /
  `peagle_data.py`). The trainer feeds the `mask_hidden` placeholder --
  projected through the same `project_hidden_states` path as real aux states --
  at every masked depth (`>= 1`), together with the masked token
  `config.mask_token_id`, so the draft predicts all `config.num_depths` tokens
  in one forward instead of autoregressively. The shape, the on-disk key
  `mask_hidden`, and the COD config (`num_depths` / `down_sample_ratio` /
  `mask_token_id`) mirror speculators
  ([https://github.com/vllm-project/speculators/pull/480](https://github.com/vllm-project/speculators/pull/480)) so the checkpoint loads
  into vLLM's parallel-drafting runtime unchanged. The masked token slot reuses
  `embed_tokens[config.mask_token_id]`. SGLang does not serve a P-EAGLE head
  today ([https://github.com/sgl-project/sglang/issues/23171](https://github.com/sgl-project/sglang/issues/23171)). The flag only ever
  adds the `mask_hidden` key, so EAGLE-3 / EAGLE-3.1 checkpoints round-trip
  unchanged.
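The effect of the three toggles on the checkpoint can be summarized with a small sketch. The helpers `extra_checkpoint_keys` and `mask_hidden_shape` are hypothetical names introduced here, and the config values are made up. Only the key names and the shape formula come from the description above.

```python
from types import SimpleNamespace


def extra_checkpoint_keys(config):
    """List the keys each optional toggle adds on top of the EAGLE-3 layout."""
    keys = []
    if getattr(config, "fc_norm", False):
        # One independent RMSNorm per auxiliary hidden state.
        keys += [f"model.fc_norm.{i}.weight" for i in range(config.num_aux_hidden_states)]
    if getattr(config, "parallel_drafting", False):
        keys.append("mask_hidden")
    # norm_output reuses the existing model.norm, so it adds nothing.
    return keys


def mask_hidden_shape(config):
    """Shape of the learnable placeholder in the pre-fc concatenated-aux dimension."""
    return (1, 1, config.num_aux_hidden_states * config.target_hidden_size)


# Hypothetical config with all toggles enabled.
cfg = SimpleNamespace(
    fc_norm=True,
    norm_output=True,
    parallel_drafting=True,
    num_aux_hidden_states=3,
    target_hidden_size=4096,
)
print(extra_checkpoint_keys(cfg))
print(mask_hidden_shape(cfg))
```

A config with none of the flags set yields an empty list, which is why plain EAGLE-3 checkpoints keep the original layout.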

## Module Contents

### Classes

| Name                                                                                                          | Description                                                         |
| ------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------- |
| [`Eagle3LlamaAttention`](#nemo_automodel-components-speculative-eagle-draft_llama-Eagle3LlamaAttention)       | EAGLE-3 draft attention over `[input_emb, hidden]` 2H features.     |
| [`Eagle3LlamaDecoderLayer`](#nemo_automodel-components-speculative-eagle-draft_llama-Eagle3LlamaDecoderLayer) | Single decoder layer used by the minimal EAGLE-3 draft model.       |
| [`Eagle3LlamaMLP`](#nemo_automodel-components-speculative-eagle-draft_llama-Eagle3LlamaMLP)                   | Standard Llama-style SwiGLU MLP on hidden-size activations.         |
| [`Eagle3LlamaModel`](#nemo_automodel-components-speculative-eagle-draft_llama-Eagle3LlamaModel)               | Inner backbone matching SGLang's `LlamaModel` in `llama_eagle3.py`. |
| [`Eagle3LlamaPeagleLayer`](#nemo_automodel-components-speculative-eagle-draft_llama-Eagle3LlamaPeagleLayer)   | Vanilla Llama decoder layer for P-EAGLE depths `>= 1`.              |
| [`LlamaEagle3DraftModel`](#nemo_automodel-components-speculative-eagle-draft_llama-LlamaEagle3DraftModel)     | Llama-style dense EAGLE-3 draft model (Llama, Phi-3, Qwen3).        |

### Functions

| Name                                                                                                                          | Description                                                                 |
| ----------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------- |
| [`_build_causal_mask`](#nemo_automodel-components-speculative-eagle-draft_llama-_build_causal_mask)                           | Build a standard causal + padding mask for SDPA/eager attention.            |
| [`_is_right_padded_attention_mask`](#nemo_automodel-components-speculative-eagle-draft_llama-_is_right_padded_attention_mask) | Return True when each row is a contiguous valid-prefix followed by padding. |
| [`_load_flash_attn_func`](#nemo_automodel-components-speculative-eagle-draft_llama-_load_flash_attn_func)                     | Best-effort load of flash-attn without breaking eager-only users.           |
| [`_seq_lens_to_cu_seqlens`](#nemo_automodel-components-speculative-eagle-draft_llama-_seq_lens_to_cu_seqlens)                 | Build FlashAttention varlen `cu_seqlens` (int32) from packed `seq_lens`.    |

### Data

[`_SUPPORTED_ATTN_IMPLEMENTATIONS`](#nemo_automodel-components-speculative-eagle-draft_llama-_SUPPORTED_ATTN_IMPLEMENTATIONS)

[`logger`](#nemo_automodel-components-speculative-eagle-draft_llama-logger)

### API

```python
class nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaAttention(
    config: transformers.PretrainedConfig,
    fuse_input: bool = True
)
```

**Bases:** [\_PeagleAttentionMixin](/nemo-automodel/nemo_automodel/components/speculative/eagle/peagle_draft#nemo_automodel-components-speculative-eagle-peagle_draft-_PeagleAttentionMixin), `Module`

EAGLE-3 draft attention over `[input_emb, hidden]` 2H features.

Driven through a shared `cache_hidden = [K_list, V_list]` pair. At
step `k` (0-indexed), with `K_list` and `V_list` already holding
entries from steps `0..k-1`:

1. `step_idx = len(K_list)` (equal to `k`) gives the rotary phase
   shift, so the draft's `K_k` encodes "this is `k` tokens into
   the future". The shifted `cos` / `sin` are computed from
   `position_ids + step_idx`.
2. The freshly projected K, V (after GQA expansion) are appended to
   the cache lists in place.
3. The attention output is the EAGLE-3 mixed pattern:

   `attn_weights = [ Q @ K_0^T / sqrt(d) + mask ]  ||  diag_1  ||  ...  ||  diag_k`

   where `diag_i[t] = (Q_t * K_i_t).sum(-1) / sqrt(d)`. The softmax
   is taken over the full extended column axis of length `T + k`.
   Output is

   `out = attn_probs[..., :T] @ V_0  +  sum_{i=1..k} attn_probs[..., T+i-1, None] * V_i`.

   In English: Q at position `t` attends to all `K_0` positions (the
   regular `T x T` causal block), and additionally to the *same*
   position `t` in each previous draft step `i >= 1`.
   Implementation-wise we replace SpecForge `llama3_eagle.py`'s
   two `O(k^2)` `cat` / `add` Python loops with single
   vectorized `einsum` calls.

`cache_hidden` is mutated in place; callers are responsible for
re-initializing it to `[[], []]` at the start of each training
batch.
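The mixed pattern described above can be checked against a deliberately slow reference. This is a minimal single-head, single-sequence sketch in plain Python lists, with no rotary embedding, padding mask, or GQA. The function name `eagle3_attention_reference` is hypothetical, and masked causal positions are simply skipped, which is equivalent to giving them a logit of negative infinity.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def eagle3_attention_reference(q, k_list, v_list):
    """q: [T][d]; k_list / v_list: k+1 cached entries, each [T][d]."""
    seq_len, dim = len(q), len(q[0])
    scale = 1.0 / math.sqrt(dim)
    out = []
    for t in range(seq_len):
        # Block 1: causal logits against K_0 (positions s <= t).
        logits = [dot(q[t], k_list[0][s]) * scale for s in range(t + 1)]
        # Block 2: one same-position logit per earlier draft step i >= 1.
        logits += [dot(q[t], k_i[t]) * scale for k_i in k_list[1:]]
        # Softmax over the full extended axis.
        peak = max(logits)
        weights = [math.exp(x - peak) for x in logits]
        total = sum(weights)
        probs = [w / total for w in weights]
        # Values line up with the logits: V_0 rows, then V_i at position t.
        values = [v_list[0][s] for s in range(t + 1)] + [v_i[t] for v_i in v_list[1:]]
        out.append([sum(p * v[j] for p, v in zip(probs, values)) for j in range(dim)])
    return out


# Zero queries give uniform attention, which makes the result easy to verify by hand.
q = [[0.0, 0.0], [0.0, 0.0]]
k0 = k1 = [[1.0, 2.0], [3.0, 4.0]]
v0 = [[1.0, 0.0], [0.0, 1.0]]
v1 = [[3.0, 2.0], [5.0, 4.0]]
step0 = eagle3_attention_reference(q, [k0], [v0])
step1 = eagle3_attention_reference(q, [k0, k1], [v0, v1])
print(step0)
print(step1)
```

With a single cached entry the function reduces to ordinary causal attention. With two entries, position `t` additionally mixes in `V_1[t]`.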

```python
nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaAttention._eager_attention_forward(
    q: torch.Tensor,
    cache_k: list[torch.Tensor],
    cache_v: list[torch.Tensor],
    attention_mask: torch.Tensor,
    step_idx: int,
    batch_size: int,
    seq_len: int
) -> torch.Tensor
```

```python
nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaAttention._flash_attention_forward(
    q: torch.Tensor,
    cache_k: list[torch.Tensor],
    cache_v: list[torch.Tensor],
    step_idx: int,
    batch_size: int,
    seq_len: int,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: int | None = None
) -> torch.Tensor
```

EAGLE-3 attention via FlashAttention-2 for the T x T causal block.

FA2 covers Block 1 (causal attention against `K_0`) and returns its
log-sum-exp. The diagonal Block 2 (cached steps `i >= 1`) is computed
eagerly and merged via the log-space identity
`lse_full = logaddexp(lse_fa, logsumexp(diag))`: the FA output is scaled
by `exp(lse_fa - lse_full)` and each diagonal by `exp(diag - lse_full)`.

With `cu_seqlens` (packing), Block 1 uses `flash_attn_varlen_func` for
document-level causal attention; the position-wise Block 2 is unchanged.
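The log-space merge can be verified for a single query row with plain Python. In this sketch `softmax_attend` stands in for the FlashAttention call (it returns an output and its log-sum-exp), and `merge_with_diagonals` applies the identity quoted above. Both names and all numbers are illustrative assumptions, not the module's code.

```python
import math


def logsumexp(values):
    peak = max(values)
    return peak + math.log(sum(math.exp(v - peak) for v in values))


def softmax_attend(logits, values):
    """Plain softmax attention for one query row; returns (output, lse)."""
    lse = logsumexp(logits)
    probs = [math.exp(x - lse) for x in logits]
    dim = len(values[0])
    out = [sum(p * v[j] for p, v in zip(probs, values)) for j in range(dim)]
    return out, lse


def merge_with_diagonals(block_out, block_lse, diag_logits, diag_values):
    """Fold extra logits into an already-normalized block without redoing its softmax."""
    lse_full = logsumexp([block_lse] + list(diag_logits))
    block_scale = math.exp(block_lse - lse_full)  # rescale the Block 1 output
    merged = [block_scale * x for x in block_out]
    for logit, value in zip(diag_logits, diag_values):
        weight = math.exp(logit - lse_full)  # probability of this diagonal entry
        merged = [m + weight * v for m, v in zip(merged, value)]
    return merged


block_logits = [0.2, -1.0, 0.7]
block_values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
diag_logits = [0.5, -0.3]
diag_values = [[4.0, 1.0], [0.0, 3.0]]

block_out, block_lse = softmax_attend(block_logits, block_values)
merged = merge_with_diagonals(block_out, block_lse, diag_logits, diag_values)
direct, _ = softmax_attend(block_logits + diag_logits, block_values + diag_values)
print(merged, direct)
```

The merged result matches a single softmax over the concatenated logits up to floating-point error, which is what allows Block 1 to stay inside the fused kernel.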

```python
nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaAttention._flash_block1_varlen(
    q_fa: torch.Tensor,
    k0_fa: torch.Tensor,
    v0_fa: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: int,
    batch_size: int,
    seq_len: int
) -> tuple[torch.Tensor, torch.Tensor]
```

Document-level causal Block 1 via `flash_attn_varlen_func`.

Flattens `(B, T, H, D)` to varlen `(total_tokens, H, D)` and reshapes
outputs back to `[B, H, T, D]` / `[B, H, T]` for the dense-path merge.
Note varlen `softmax_lse` is `[H, total_tokens]` (head-major), unlike
the dense `[B, H, T]` -- hence the explicit reshape + shape check.

```python
nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaAttention._project_qkv(
    combined_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]
```

```python
nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaAttention._repeat_kv(
    k: torch.Tensor,
    v: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]
```

```python
nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaAttention.forward(
    combined_states: torch.Tensor,
    attention_mask: torch.Tensor,
    position_ids: torch.Tensor,
    cache_hidden: list[list[torch.Tensor]],
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: int | None = None
) -> torch.Tensor
```

```python
class nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaDecoderLayer(
    config: transformers.PretrainedConfig,
    layer_id: int = 0
)
```

**Bases:** [\_PeagleDecoderLayerMixin](/nemo-automodel/nemo_automodel/components/speculative/eagle/peagle_draft#nemo_automodel-components-speculative-eagle-peagle_draft-_PeagleDecoderLayerMixin), `Module`

Single decoder layer used by the minimal EAGLE-3 draft model.

Attribute names mirror SGLang's `LlamaDecoderLayer` in
`sglang/srt/models/llama_eagle3.py`: `input_layernorm` is applied
to the per-step token embeddings (`embeds` in SGLang),
`hidden_norm` is applied to the carried hidden state.
`is_input_layer` is the layer-0 flag that gates the `[embeds,
hidden]` concatenation (always true for our single-layer draft).

```python
nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaDecoderLayer.forward(
    input_embeds: torch.Tensor,
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
    position_ids: torch.Tensor,
    cache_hidden: list[list[torch.Tensor]],
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: int | None = None
) -> torch.Tensor
```

```python
class nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaMLP(
    config: transformers.PretrainedConfig
)
```

**Bases:** `Module`

Standard Llama-style SwiGLU MLP on hidden-size activations.

```python
nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaMLP.forward(
    hidden_states: torch.Tensor
) -> torch.Tensor
```

```python
class nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaModel(
    config: transformers.PretrainedConfig
)
```

**Bases:** `Module`

Inner backbone matching SGLang's `LlamaModel` in `llama_eagle3.py`.

Owns `embed_tokens`, the `fc` projection from concatenated target
aux hidden states to draft hidden size, the (single-element) draft
`layers` ModuleList, and the final `norm`. The `LlamaEagle3DraftModel`
wrapper around this module adds the top-level `lm_head` and the
training-facing public API.

```python
class nemo_automodel.components.speculative.eagle.draft_llama.Eagle3LlamaPeagleLayer(
    config: transformers.PretrainedConfig,
    layer_id: int
)
```

**Bases:** [\_PeagleVanillaLayerMixin](/nemo-automodel/nemo_automodel/components/speculative/eagle/peagle_draft#nemo_automodel-components-speculative-eagle-peagle_draft-_PeagleVanillaLayerMixin), `Module`

Vanilla Llama decoder layer for P-EAGLE depths `>= 1`.

The EAGLE-3 first layer (`Eagle3LlamaDecoderLayer`) fuses the token
embedding and the projected target hidden state (`2H` attention input).
P-EAGLE stacks `num_hidden_layers` layers; every layer after the first is
a standard Llama block operating on plain hidden states (`H`), matching
speculators' `decoder_layer_class` (a vanilla `LlamaDecoderLayer`). Only
the P-EAGLE flex-attention path is implemented (these deeper layers do not
participate in the EAGLE-3 `cache_hidden` TTT recurrence).

```python
class nemo_automodel.components.speculative.eagle.draft_llama.LlamaEagle3DraftModel(
    config: transformers.PretrainedConfig
)
```

**Bases:** [\_PeagleDraftMixin](/nemo-automodel/nemo_automodel/components/speculative/eagle/peagle_draft#nemo_automodel-components-speculative-eagle-peagle_draft-_PeagleDraftMixin), `PreTrainedModel`

Llama-style dense EAGLE-3 draft model (Llama, Phi-3, Qwen3).

State dict keys match SGLang's `LlamaForCausalLMEagle3` so the saved
checkpoint can be loaded by SGLang's inference engine without any
remapping (SGLang's `load_weights` fuses `q/k/v_proj` into
`qkv_proj` and `gate/up_proj` into `gate_up_proj` via its
standard `stacked_params_mapping`).

The class name is retained for checkpoint-architectures compatibility; the
implementation is config-driven and works for any HF dense decoder-only
config that exposes `hidden_size`, `num_attention_heads`,
`num_key_value_heads`, `attention_bias`, `mlp_bias`, `rope_theta`,
and `rms_norm_eps`. A decoupled `head_dim` is read via
`getattr(config, "head_dim", ...)` in the attention layer.

Scope:

* single draft decoder layer
* no KV-cache optimization
* no speculative runtime integration

```python
nemo_automodel.components.speculative.eagle.draft_llama.LlamaEagle3DraftModel.compute_logits(
    hidden_states: torch.Tensor
) -> torch.Tensor
```

Compute draft logits on the configured draft vocabulary.

With `config.norm_output` unset (EAGLE-3 default) the input is the
raw decoder-layer output and the final `model.norm` is applied
here. With `config.norm_output` set (EAGLE-3.1) `forward` has
already returned the post-norm state, so `lm_head` is applied
directly to avoid a double normalisation.
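The branching described here can be sketched with toy stand-ins. `rms_norm`, `lm_head`, and `compute_logits_sketch` below are hypothetical pure-Python helpers, not the module's methods. They only illustrate that both settings produce the same logits for the same raw decoder output, while differing in which state `forward` hands to the next step.

```python
import math


def rms_norm(hidden, weight, eps=1e-6):
    mean_sq = sum(x * x for x in hidden) / len(hidden)
    return [x / math.sqrt(mean_sq + eps) * w for x, w in zip(hidden, weight)]


def lm_head(hidden, rows):
    """Toy linear head: one dot product per vocabulary row."""
    return [sum(h * r for h, r in zip(hidden, row)) for row in rows]


def compute_logits_sketch(hidden, norm_output, norm_weight, head_rows):
    # EAGLE-3 default: the input is the raw decoder output, so normalize here.
    if not norm_output:
        hidden = rms_norm(hidden, norm_weight)
    # EAGLE-3.1: the input is already post-norm, so go straight to the head.
    return lm_head(hidden, head_rows)


raw = [3.0, -4.0]
norm_weight = [1.0, 1.0]
head_rows = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

logits_eagle3 = compute_logits_sketch(raw, False, norm_weight, head_rows)
# With norm_output set, `forward` would already have returned the normed state.
logits_eagle31 = compute_logits_sketch(rms_norm(raw, norm_weight), True, norm_weight, head_rows)
print(logits_eagle3, logits_eagle31)
```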

```python
nemo_automodel.components.speculative.eagle.draft_llama.LlamaEagle3DraftModel.copy_embeddings_from_target(
    target_embedding: torch.nn.Embedding
) -> None
```

Initialize draft embeddings from the target model embeddings.

When the target model is wrapped with FSDP2, `target_embedding.weight`
is a `DTensor` sharded across ranks.  The draft embedding is a plain
`nn.Parameter` (the draft is not FSDP-wrapped), so a direct
`copy_` of a DTensor into a regular tensor raises a mixed-type
distributed-operator error.  Gather to a full local tensor first.

```python
nemo_automodel.components.speculative.eagle.draft_llama.LlamaEagle3DraftModel.embed_input_ids(
    input_ids: torch.Tensor
) -> torch.Tensor
```

Embed input ids with the draft embedding table.

```python
nemo_automodel.components.speculative.eagle.draft_llama.LlamaEagle3DraftModel.forward(
    input_ids: torch.Tensor,
    projected_hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
    position_ids: typing.Optional[torch.Tensor] = None,
    cache_hidden: typing.Optional[list[list[torch.Tensor]]] = None,
    seq_lens: typing.Optional[torch.Tensor] = None
) -> torch.Tensor
```

Run one full-sequence draft update step.

`cache_hidden` is the EAGLE-3 TTT cache. Pass `[[], []]` on
the first step of a TTT unroll and the same list object on each
subsequent step; the attention layer appends the per-step K and V
to it. If `None` is passed (e.g. from a one-shot evaluation
call) a fresh `[[], []]` is allocated locally -- step 0 of TTT
is mathematically equivalent to a plain causal forward.

`seq_lens` (packing) makes Block-1 attention document-level block-causal
(eager mask / FA2 varlen); callers must pass per-document `position_ids`.

```python
nemo_automodel.components.speculative.eagle.draft_llama.LlamaEagle3DraftModel.freeze_embeddings() -> None
```

Freeze draft input embeddings.

```python
nemo_automodel.components.speculative.eagle.draft_llama.LlamaEagle3DraftModel.gradient_checkpointing_disable() -> None
```

Disable activation checkpointing for the P-EAGLE draft layers.

```python
nemo_automodel.components.speculative.eagle.draft_llama.LlamaEagle3DraftModel.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs = None
) -> None
```

Enable activation checkpointing for the P-EAGLE draft layers.

Training-only memory knob: recomputes each `forward_peagle` layer in the
backward instead of storing its activations (the EAGLE-3 TTT `forward`
path is unaffected). `gradient_checkpointing_kwargs` is accepted for
HF-API parity but ignored -- recompute is always non-reentrant, the only
mode compatible with the non-tensor `block_mask`.

```python
nemo_automodel.components.speculative.eagle.draft_llama.LlamaEagle3DraftModel.project_hidden_states(
    aux_hidden_states: torch.Tensor
) -> torch.Tensor
```

Project concatenated target aux states from `num_aux * H_target` to draft hidden size.

When `config.fc_norm` is set (EAGLE-3.1), the input is split into
`num_aux_hidden_states` equal chunks along the last dim and each
chunk is passed through its own RMSNorm in `model.fc_norm` (the
modules are independent, matching vLLM's upstream implementation).
The normalized chunks are then re-concatenated and fed to `fc`,
stabilising the per-aux-state scale before the projection mixes them
and removing the speculation-depth drift observed with raw inputs.
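A minimal sketch of the per-chunk normalization, using plain Python lists instead of tensors. The helper names and the example values are hypothetical, and the `eps` default is an assumption; the real module would read `config.rms_norm_eps` and apply `fc` afterwards.

```python
import math


def rms_norm(chunk, weight, eps=1e-6):
    mean_sq = sum(x * x for x in chunk) / len(chunk)
    inv = 1.0 / math.sqrt(mean_sq + eps)
    return [x * inv * w for x, w in zip(chunk, weight)]


def normalize_aux_chunks(aux, norm_weights):
    """Split one concatenated aux vector into equal chunks and RMSNorm each independently."""
    num_aux = len(norm_weights)
    if len(aux) % num_aux != 0:
        raise ValueError("aux length must be divisible by the number of aux states")
    size = len(aux) // num_aux
    out = []
    for i, weight in enumerate(norm_weights):
        out.extend(rms_norm(aux[i * size:(i + 1) * size], weight))
    return out  # this re-concatenated vector would then be fed to `fc`


# Three aux states with very different scales, two features each.
aux = [1.0, 1.0, 100.0, 100.0, 0.1, 0.1]
weights = [[1.0, 1.0]] * 3
normalized = normalize_aux_chunks(aux, weights)
print(normalized)
```

After normalization every chunk sits at roughly unit scale, so no single auxiliary layer dominates the projection.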

```python
nemo_automodel.components.speculative.eagle.draft_llama.LlamaEagle3DraftModel.set_vocab_mapping(
    selected_token_ids: torch.Tensor
) -> None
```

Populate the `d2t` / `t2d` vocab-remap buffers from the draft->target id map.

`selected_token_ids` has shape `[draft_vocab_size]`; entry `i` is
the *target* vocab id of draft id `i` (the frequency-pruned mapping
built by `build_eagle3_token_mapping`). This writes the two buffers
inference engines consume:

* `d2t[i] = selected_token_ids[i] - i` -- the offset form vLLM expects
  (`target_id = draft_id + d2t[draft_id]`);
* `t2d[target_id] = True` for every selected target id -- the boolean
  presence mask SGLang consumes.

These must be in the saved checkpoint: without them vLLM/SGLang find no
mapping, silently align draft ids to the first `draft_vocab_size`
target ids, and acceptance rate collapses.

No-op when the draft vocab is not compressed (the buffers do not exist
and the draft logits are already in target space).
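The two buffer layouts can be reproduced with plain lists. `build_vocab_buffers` and its `target_vocab_size` parameter are hypothetical names for this sketch, and the token ids are made up. The formulas are the ones stated above.

```python
def build_vocab_buffers(selected_token_ids, target_vocab_size):
    """Build the offset-form d2t list and the boolean t2d mask."""
    # d2t[i] is the offset that turns draft id i into its target id.
    d2t = [target_id - draft_id for draft_id, target_id in enumerate(selected_token_ids)]
    # t2d marks which target ids are present in the draft vocabulary.
    t2d = [False] * target_vocab_size
    for target_id in selected_token_ids:
        t2d[target_id] = True
    return d2t, t2d


selected = [0, 3, 4, 9]  # draft id -> target id
d2t, t2d = build_vocab_buffers(selected, target_vocab_size=10)
recovered = [draft_id + d2t[draft_id] for draft_id in range(len(selected))]
print(d2t)        # [0, 2, 2, 6]
print(recovered)  # [0, 3, 4, 9]
```

If the buffers were missing, every offset would effectively be zero, so draft id 3 would be read as target id 3 instead of 9. That is the silent misalignment described above.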

```python
nemo_automodel.components.speculative.eagle.draft_llama._build_causal_mask(
    attention_mask: torch.Tensor,
    dtype: torch.dtype
) -> torch.Tensor
```

Build a standard causal + padding mask for SDPA/eager attention.
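For illustration, an additive causal-plus-padding mask can be built as follows. This sketch uses nested lists of shape `[B][T][T]` and negative infinity as the masked value. The real function works on tensors and takes a `dtype`, so the masked value and the output shape are assumptions here.

```python
NEG_INF = float("-inf")


def build_causal_mask_sketch(attention_mask):
    """attention_mask: [B][T] of 0/1. Returns an additive [B][T][T] mask."""
    masks = []
    for row in attention_mask:
        seq_len = len(row)
        # Query t may see key s only if s <= t and key s is a real token.
        masks.append([
            [0.0 if (s <= t and row[s]) else NEG_INF for s in range(seq_len)]
            for t in range(seq_len)
        ])
    return masks


mask = build_causal_mask_sketch([[1, 1, 0]])
for query_row in mask[0]:
    print(query_row)
```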

```python
nemo_automodel.components.speculative.eagle.draft_llama._is_right_padded_attention_mask(
    attention_mask: torch.Tensor
) -> bool
```

Return True when each row is a contiguous valid-prefix followed by padding.
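A list-based sketch of the check, assuming a 0/1 mask of shape `[B][T]`. The function name is hypothetical, and treating an all-padding row as right-padded is an assumption of this example.

```python
def is_right_padded_sketch(attention_mask):
    """True when every row is a run of valid tokens followed only by padding."""
    for row in attention_mask:
        seen_padding = False
        for value in row:
            if not value:
                seen_padding = True
            elif seen_padding:
                # A valid token after padding breaks the prefix pattern.
                return False
    return True


print(is_right_padded_sketch([[1, 1, 0], [1, 1, 1]]))  # True
print(is_right_padded_sketch([[0, 1, 1]]))             # False
```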

```python
nemo_automodel.components.speculative.eagle.draft_llama._load_flash_attn_func() -> tuple[bool, object | None, object | None]
```

Best-effort load of flash-attn without breaking eager-only users.

`safe_import_from` already handles missing modules and missing symbols, but
some broken `flash-attn` installs fail with lower-level loader errors
(e.g. ABI / shared-library issues) that should not prevent importing this
module for the eager path. Returns the dense `flash_attn_func` and the
`flash_attn_varlen_func` (used by the packed block-causal path).
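The best-effort pattern can be approximated with the standard library alone. This sketch uses `importlib` rather than the module's `safe_import_from` helper, and `load_flash_attn_sketch` is a hypothetical name. It catches any exception so that a broken install degrades to the eager path.

```python
import importlib


def load_flash_attn_sketch():
    """Return (available, flash_attn_func, flash_attn_varlen_func)."""
    try:
        module = importlib.import_module("flash_attn")
        dense = getattr(module, "flash_attn_func")
        varlen = getattr(module, "flash_attn_varlen_func")
    except Exception:
        # Covers a missing package as well as loader-level failures
        # such as ABI or shared-library errors.
        return False, None, None
    return True, dense, varlen


available, dense_fn, varlen_fn = load_flash_attn_sketch()
print("flash-attn available:", available)
```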

```python
nemo_automodel.components.speculative.eagle.draft_llama._seq_lens_to_cu_seqlens(
    seq_lens: torch.Tensor,
    seq_length: int
) -> tuple[torch.Tensor, int]
```

Build FlashAttention varlen `cu_seqlens` (int32) from packed `seq_lens`.

Documents are flattened row-major to match the varlen attention's
`reshape(B*T, ...)` token order. Returns `(cu_seqlens, max_seqlen)`.
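A list-based sketch of the conversion. It assumes `seq_lens` is a `[B][max_docs]` table of document lengths padded with zeros, and that any tokens left over in a row form one extra segment so that the final offset equals `B * T`. Both points are assumptions of this example rather than documented behavior.

```python
from itertools import accumulate


def seq_lens_to_cu_seqlens_sketch(seq_lens, seq_length):
    """Return (cu_seqlens, max_seqlen) for rows flattened in row-major order."""
    lengths = []
    for row in seq_lens:
        docs = [n for n in row if n > 0]
        remainder = seq_length - sum(docs)
        if remainder < 0:
            raise ValueError("document lengths exceed the row length")
        if remainder:
            # Assumed handling: trailing padding becomes its own segment.
            docs.append(remainder)
        lengths.extend(docs)
    cu_seqlens = [0] + list(accumulate(lengths))
    return cu_seqlens, max(lengths)


cu_seqlens, max_seqlen = seq_lens_to_cu_seqlens_sketch([[3, 2, 0], [4, 0, 0]], seq_length=5)
print(cu_seqlens, max_seqlen)  # [0, 3, 5, 9, 10] 4
```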

```python
nemo_automodel.components.speculative.eagle.draft_llama._SUPPORTED_ATTN_IMPLEMENTATIONS = ('eager', 'flash_attention_2')
```

```python
nemo_automodel.components.speculative.eagle.draft_llama.logger = logging.getLogger(__name__)
```