nemo_automodel.components.speculative.eagle.peagle_draft
P-EAGLE (parallel-drafting) draft-model behavior, split out of draft_llama.py.
The P-EAGLE forward path is provided as mixins so the EAGLE-3 draft classes in
draft_llama.py opt into it by inheritance without interleaving
parallel_drafting branches through the shared EAGLE-3 code:
- _PeagleAttentionMixin -> Eagle3LlamaAttention
- _PeagleDecoderLayerMixin -> Eagle3LlamaDecoderLayer
- _PeagleVanillaLayerMixin -> Eagle3LlamaPeagleLayer
- _PeagleDraftMixin -> LlamaEagle3DraftModel
The mixins reference only attributes that the host classes already define on
self, plus the helpers in this module, so the dependency stays one-way
(draft_llama -> peagle_draft) with no circular import.
Module Contents
Classes
Functions
Data
_peagle_flex_attention_compile_failed
_peagle_flex_attention_compiled
API
P-EAGLE single parallel-group attention for Eagle3LlamaAttention.
P-EAGLE single parallel-group attention.
Unlike the EAGLE-3 cache_hidden recurrence, P-EAGLE flattens all COD
depths into one sequence and attends in a single pass: there is no
per-step rotary phase offset (the depth is baked into position_ids = anchor_pos + depth) and no diagonal-extension cache. Cross-depth
visibility is enforced entirely by block_mask (see
create_peagle_mask_mod), so this is plain scaled-dot-product
attention through flex_attention.
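A minimal sketch of how flattening can bake the depth into position_ids. It assumes, hypothetically, that each depth has its own list of sampled anchor positions and that slots are laid out depth by depth; the real ordering and sampling may differ.

```python
def flatten_cod_slots(anchors_per_depth):
    """Flatten COD depths into one sequence of (anchor_pos, depth, position_id).

    Assumptions (not from the source): anchors_per_depth[d] lists the anchor
    positions sampled at depth d, and slots are laid out depth by depth.
    """
    slots = []
    for depth, anchors in enumerate(anchors_per_depth):
        for anchor_pos in anchors:
            # The depth is baked into the position id, so no per-step
            # rotary phase offset is needed.
            slots.append((anchor_pos, depth, anchor_pos + depth))
    return slots


slots = flatten_cod_slots([[0, 1, 2, 3], [1, 2, 3], [2, 3]])
position_ids = [pos for _, _, pos in slots]
print(position_ids)  # [0, 1, 2, 3, 2, 3, 4, 4, 5]
```

All nine slots then go through one attention call; which of them may see each other is left entirely to the block mask.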
P-EAGLE forward for the fused first layer Eagle3LlamaDecoderLayer.
Decoder-layer variant for the P-EAGLE single parallel forward.
Mirrors the layer's regular forward (same norms, residuals, MLP and [embeds, hidden] concatenation) but routes attention through
self_attn.forward_peagle with a COD block_mask instead of the
cache_hidden recurrence.
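A toy, single-token sketch of the fused first layer's data flow. It assumes the usual EAGLE-3 layout, in which the residual is the incoming hidden state and attention maps the 2H concatenation back to H. It uses plain Python lists and one weightless RMSNorm in place of the layer's separate norm modules; attn and mlp are hypothetical callables.

```python
import math


def rms_norm(x, eps=1e-6):
    # Toy RMSNorm without a learned scale (the real layer has one per norm).
    scale = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * scale for v in x]


def toy_fused_first_layer(embeds, hidden, attn, mlp):
    """One token through a toy fused first layer.

    embeds and hidden are length-H vectors. attn (2H -> H) and mlp (H -> H)
    are hypothetical stand-ins for self_attn.forward_peagle (which would
    also take the COD block_mask) and the layer MLP.
    """
    residual = hidden
    # Normalize each stream, then concatenate to [embeds, hidden] of size 2H.
    x = rms_norm(embeds) + rms_norm(hidden)
    x = [r + a for r, a in zip(residual, attn(x))]
    # Post-attention norm, MLP, and a second residual connection.
    return [r + m for r, m in zip(x, mlp(rms_norm(x)))]


seen = []


def toy_attn(x):
    seen.append(len(x))    # record the input width (2H)
    return x[: len(x) // 2]  # keep the embeds half: 2H -> H


out = toy_fused_first_layer([1.0, 1.0], [2.0, 2.0], toy_attn, lambda x: [0.0] * len(x))
print(seen, out)
```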
P-EAGLE draft-model methods for LlamaEagle3DraftModel.
Register the learnable P-EAGLE mask_hidden placeholder.
A single learnable placeholder that substitutes for the target auxiliary
hidden states at every masked multi-token-prediction position (COD depths
>= 1). It lives at the pre-fc concatenated-aux dimension
(num_aux_hidden_states * target_hidden_size == model.fc.in_features)
so it flows through project_hidden_states (and fc_norm when set)
exactly like a real aux-hidden vector. Shape [1, 1, 3 * H] and the
on-disk key mask_hidden mirror speculators
(https://github.com/vllm-project/speculators/pull/480) so the checkpoint
loads into vLLM’s parallel-drafting runtime unchanged. Called only when
parallel_drafting is set so EAGLE-3 / EAGLE-3.1 checkpoints round-trip
with no extra keys.
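A sketch of the registration contract as a pure-Python helper that records parameter shapes instead of tensors. register_mask_hidden and its arguments are hypothetical, and the 4096 hidden size in the usage line is only an example.

```python
def register_mask_hidden(param_shapes, *, parallel_drafting, num_aux_hidden_states,
                         target_hidden_size, fc_in_features):
    """Hypothetical helper: add the mask_hidden entry to a name -> shape map."""
    if not parallel_drafting:
        # EAGLE-3 / EAGLE-3.1 drafts get no extra key, so checkpoints round-trip.
        return param_shapes
    width = num_aux_hidden_states * target_hidden_size
    # The placeholder sits at the pre-fc width, so it must equal fc.in_features.
    if width != fc_in_features:
        raise ValueError(f"mask_hidden width {width} != fc.in_features {fc_in_features}")
    param_shapes["mask_hidden"] = (1, 1, width)
    return param_shapes


shapes = register_mask_hidden({}, parallel_drafting=True, num_aux_hidden_states=3,
                              target_hidden_size=4096, fc_in_features=3 * 4096)
print(shapes)  # {'mask_hidden': (1, 1, 12288)}
```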
Run one P-EAGLE draft layer, optionally under activation checkpointing.
When gradient_checkpointing is enabled (training only), the layer’s
activations are freed after the forward and recomputed during the
backward (torch.utils.checkpoint), trading one extra forward per
layer for a lower activation-memory peak on the long flattened COD
sequence. use_reentrant=False is required so the non-tensor P-EAGLE
argument (block_mask) passes through and so the recompute composes
with the flex-attention path. With the flag off (the default, and in
eval) the layer runs directly with no overhead.
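A minimal PyTorch sketch of the dispatch described above. ToyLayer and FakeBlockMask are hypothetical stand-ins (a real layer would hand a flex-attention BlockMask to flex_attention); torch.utils.checkpoint.checkpoint is the real API, and in non-reentrant mode it forwards keyword arguments such as block_mask to the wrapped module.

```python
from dataclasses import dataclass

import torch
from torch.utils.checkpoint import checkpoint


@dataclass
class FakeBlockMask:
    # Hypothetical non-tensor stand-in for a flex-attention BlockMask.
    scale: float


class ToyLayer(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, block_mask):
        return torch.tanh(self.proj(hidden_states)) * block_mask.scale


def run_layer(layer, hidden_states, block_mask, gradient_checkpointing=False):
    if gradient_checkpointing and layer.training:
        # Free activations now and recompute them during backward.
        # Non-reentrant mode passes the non-tensor block_mask through as a kwarg.
        return checkpoint(layer, hidden_states, block_mask=block_mask, use_reentrant=False)
    # Flag off (the default) or eval: call the layer directly.
    return layer(hidden_states, block_mask=block_mask)
```

Checkpointing only changes when activations are computed, so gradients should match the direct call.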
Construct the COD flex_attention block mask for one sequence.
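A hypothetical mask_mod in the (b, h, q_idx, kv_idx) form that flex_attention uses. The visibility rule is an assumption for illustration, not the rule implemented in this module. With anchor_pos and depth stored as tensors, a function like this could be passed to torch.nn.attention.flex_attention.create_block_mask; the sketch below only evaluates it densely on Python lists.

```python
def make_cod_mask_mod(anchor_pos, depth):
    """Build a hypothetical COD mask_mod over flattened slots.

    anchor_pos[i] and depth[i] describe flattened slot i. Assumed rule:
    a query slot may attend to
      * every depth-0 slot whose anchor is at or before its own anchor, and
      * slots in its own anchor chain at the same or a shallower depth.
    """

    def mask_mod(b, h, q_idx, kv_idx):
        # b and h (batch, head) are unused: one mask is shared by all heads.
        qa, qd = anchor_pos[q_idx], depth[q_idx]
        ka, kd = anchor_pos[kv_idx], depth[kv_idx]
        # & and | (not and/or) keep this valid for torch bool tensors too.
        return ((kd == 0) & (ka <= qa)) | ((ka == qa) & (kd <= qd))

    return mask_mod


# Three anchors at depth 0, then the same three anchors at depth 1.
anchor_pos = [0, 1, 2, 0, 1, 2]
depth = [0, 0, 0, 1, 1, 1]
mask_mod = make_cod_mask_mod(anchor_pos, depth)
dense = [[int(mask_mod(0, 0, q, k)) for k in range(6)] for q in range(6)]
for row in dense:
    print(row)
```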
Run the P-EAGLE single parallel-group forward.
All COD depths are already flattened into one [1, total_sampled]
sequence by the caller:
- sampled_input_ids: real token ids at depth-0 slots and the mask_token_id at depth >= 1 slots;
- sampled_projected_hidden: fc-projected target aux states at depth-0 slots and the projected mask_hidden placeholder elsewhere;
- position_ids: anchor_pos + depth (the reference position);
- block_mask: the COD cross-depth visibility mask.
Returns the pre-logits hidden states (post-norm when
config.norm_output is set), one row per sampled element.
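A toy sketch of how a caller might assemble these flattened inputs. build_sampled_inputs is hypothetical, and it assumes depth-0 slots read the token and projected aux state at the anchor position; the real code may shift indices differently.

```python
def build_sampled_inputs(slots, token_ids, projected_aux, mask_token_id,
                         projected_mask_hidden):
    """Assemble the flattened P-EAGLE inputs for one sequence (toy version).

    slots holds (anchor_pos, depth) per flattened element. Depth-0 slots take
    the real token and its projected aux state; deeper slots take the mask
    token and the shared projected mask_hidden row.
    """
    sampled_input_ids, sampled_projected_hidden, position_ids = [], [], []
    for anchor_pos, depth in slots:
        if depth == 0:
            sampled_input_ids.append(token_ids[anchor_pos])
            sampled_projected_hidden.append(projected_aux[anchor_pos])
        else:
            sampled_input_ids.append(mask_token_id)
            sampled_projected_hidden.append(projected_mask_hidden)
        position_ids.append(anchor_pos + depth)
    return sampled_input_ids, sampled_projected_hidden, position_ids


ids, hidden, pos = build_sampled_inputs(
    [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1)],
    token_ids=[10, 11, 12], projected_aux=[[1.0], [2.0], [3.0]],
    mask_token_id=99, projected_mask_hidden=[0.5])
print(ids, hidden, pos)
```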
Project the learnable P-EAGLE mask_hidden placeholder to draft hidden size.
Returns a [1, hidden_size] tensor obtained by running the
[1, 1, num_aux_hidden_states * target_hidden_size] placeholder through
the same project_hidden_states path (fc plus optional
fc_norm) used for real auxiliary hidden states. The P-EAGLE trainer
scatters the result into every masked COD depth. Only valid when the
draft was built with config.parallel_drafting=True.
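A toy sketch of the placeholder's projection path using nested lists. project_mask_hidden is hypothetical; it models fc as a bias-free linear map and fc_norm as an optional weightless RMSNorm, which may not match the real modules.

```python
import math


def rms_norm(x, eps=1e-6):
    # Toy RMSNorm without a learned scale.
    scale = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * scale for v in x]


def project_mask_hidden(mask_hidden, fc_weight, use_fc_norm=False):
    """Toy projection of a [1, 1, in_features] placeholder to [1, hidden_size].

    fc_weight has hidden_size rows of in_features values (toy fc, no bias);
    use_fc_norm stands in for an optional fc_norm.
    """
    vec = mask_hidden[0][0]
    if len(vec) != len(fc_weight[0]):
        raise ValueError("placeholder width must equal fc.in_features")
    out = [sum(w * v for w, v in zip(row, vec)) for row in fc_weight]
    if use_fc_norm:
        out = rms_norm(out)
    return [out]


# H = 2 and three aux states, so in_features = 3 * 2 = 6.
fc_weight = [[1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0]]
mask_hidden = [[[3.0, 4.0, 0.0, 0.0, 0.0, 0.0]]]
print(project_mask_hidden(mask_hidden, fc_weight))  # [[3.0, 4.0]]
```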
P-EAGLE forward for the vanilla deep layer Eagle3LlamaPeagleLayer.
Standard pre-norm Llama block over H hidden states with the COD mask.
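A toy sketch of the pre-norm block over a list of per-token vectors. vanilla_peagle_layer and its callables are hypothetical; in the real layer the attention step would be forward_peagle under the COD block_mask.

```python
def vanilla_peagle_layer(x, attn, mlp, input_norm, post_attn_norm):
    """Toy pre-norm Llama block; x is a list of per-token vectors of length H."""
    # Attention mixes tokens (under the COD block_mask in the real layer).
    attn_out = attn([input_norm(t) for t in x])
    h = [[a + b for a, b in zip(t, o)] for t, o in zip(x, attn_out)]
    # The MLP acts on each token independently; both sublayers add a residual.
    return [[a + b for a, b in zip(t, mlp(post_attn_norm(t)))] for t in h]


def mean_attn(tokens):
    # Hypothetical stand-in: every token attends uniformly to all tokens.
    n = len(tokens)
    avg = [sum(col) / n for col in zip(*tokens)]
    return [list(avg) for _ in tokens]


def identity(t):
    return t


out = vanilla_peagle_layer([[1.0, 2.0], [3.0, 4.0]], mean_attn,
                           lambda t: [0.0] * len(t), identity, identity)
print(out)  # [[3.0, 5.0], [5.0, 7.0]]
```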
Return whether Inductor’s flex-attention lowering supports these tensors.
Run the P-EAGLE flex attention, compiling only when Inductor supports it.