nemo_automodel.components.speculative.eagle.peagle_draft


P-EAGLE (parallel-drafting) draft-model behavior, split out of draft_llama.py.

The P-EAGLE forward path is provided as mixins so the EAGLE-3 draft classes in draft_llama.py opt into it by inheritance without interleaving parallel_drafting branches through the shared EAGLE-3 code:

  • _PeagleAttentionMixin -> Eagle3LlamaAttention
  • _PeagleDecoderLayerMixin -> Eagle3LlamaDecoderLayer
  • _PeagleVanillaLayerMixin -> Eagle3LlamaPeagleLayer
  • _PeagleDraftMixin -> LlamaEagle3DraftModel

The mixins reference only attributes the host classes already define (self) plus the helpers in this module, so the dependency stays one-way (draft_llama -> peagle_draft) with no circular import.

Module Contents

Classes

| Name | Description |
| --- | --- |
| _PeagleAttentionMixin | P-EAGLE single parallel-group attention for Eagle3LlamaAttention. |
| _PeagleDecoderLayerMixin | P-EAGLE forward for the fused first layer Eagle3LlamaDecoderLayer. |
| _PeagleDraftMixin | P-EAGLE draft-model methods for LlamaEagle3DraftModel. |
| _PeagleVanillaLayerMixin | P-EAGLE forward for the vanilla deep layer Eagle3LlamaPeagleLayer. |

Functions

| Name | Description |
| --- | --- |
| _peagle_compile_supported | Return whether Inductor’s flex-attention lowering supports these tensors. |
| _peagle_flex_attention | Run the P-EAGLE flex attention, compiling only when Inductor supports it. |

Data

_peagle_flex_attention_compile_failed

_peagle_flex_attention_compiled

logger

API

class nemo_automodel.components.speculative.eagle.peagle_draft._PeagleAttentionMixin()

P-EAGLE single parallel-group attention for Eagle3LlamaAttention.

nemo_automodel.components.speculative.eagle.peagle_draft._PeagleAttentionMixin.forward_peagle(
combined_states: torch.Tensor,
position_ids: torch.Tensor,
block_mask
) -> torch.Tensor

P-EAGLE single parallel-group attention.

Unlike the EAGLE-3 cache_hidden recurrence, P-EAGLE flattens all COD depths into one sequence and attends in a single pass. There is no per-step rotary phase offset, because the depth is already baked into position_ids = anchor_pos + depth, and there is no diagonal-extension cache. Cross-depth visibility is enforced entirely by block_mask (see create_peagle_mask_mod), so this is plain scaled-dot-product attention through flex_attention.
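A minimal sketch of how flattened positions could be derived from position_ids = anchor_pos + depth. The helper name and the depth-major slot ordering are assumptions made for illustration; the real layout is chosen by the caller and is not documented here.

```python
def flatten_cod_positions(anchor_positions, num_depths):
    """Flatten every (anchor, depth) pair into one sequence of slots.

    Assumed ordering: all depth-0 slots first, then depth 1, and so on.
    """
    anchor_pos, depth = [], []
    for d in range(num_depths):
        for a in anchor_positions:
            anchor_pos.append(a)
            depth.append(d)
    # The rotary position of each slot is its anchor position plus its depth.
    position_ids = [a + d for a, d in zip(anchor_pos, depth)]
    return anchor_pos, depth, position_ids


anchor_pos, depth, position_ids = flatten_cod_positions([0, 3], num_depths=2)
```

With anchors 0 and 3 and two depths, the depth-1 slots receive positions 1 and 4, so no separate per-step rotary offset is needed.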

class nemo_automodel.components.speculative.eagle.peagle_draft._PeagleDecoderLayerMixin()

P-EAGLE forward for the fused first layer Eagle3LlamaDecoderLayer.

nemo_automodel.components.speculative.eagle.peagle_draft._PeagleDecoderLayerMixin.forward_peagle(
input_embeds: torch.Tensor,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
block_mask
) -> torch.Tensor

Decoder-layer variant for the P-EAGLE single parallel forward.

Mirrors forward (same norms, residuals, MLP, and [embeds, hidden] concatenation) but routes attention through self_attn.forward_peagle with a COD block_mask instead of the cache_hidden recurrence.

class nemo_automodel.components.speculative.eagle.peagle_draft._PeagleDraftMixin()

P-EAGLE draft-model methods for LlamaEagle3DraftModel.

nemo_automodel.components.speculative.eagle.peagle_draft._PeagleDraftMixin._init_peagle_parameters(
config
) -> None

Register the learnable P-EAGLE mask_hidden placeholder.

A single learnable placeholder that substitutes for the target auxiliary hidden states at every masked multi-token-prediction position (COD depths >= 1). It lives at the pre-fc concatenated-aux dimension (num_aux_hidden_states * target_hidden_size == model.fc.in_features), so it flows through project_hidden_states (and fc_norm when set) exactly like a real aux-hidden vector. Both the shape [1, 1, 3 * H] and the on-disk key mask_hidden mirror speculators (https://github.com/vllm-project/speculators/pull/480), so the checkpoint loads into vLLM’s parallel-drafting runtime unchanged. This method is called only when parallel_drafting is set, so EAGLE-3 / EAGLE-3.1 checkpoints round-trip with no extra keys.
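The shape arithmetic can be checked with a small sketch. The helper name is hypothetical, the value 3 for num_aux_hidden_states is taken from the [1, 1, 3 * H] shape quoted above, and H = 4096 is an arbitrary illustrative size.

```python
def mask_hidden_shape(num_aux_hidden_states, target_hidden_size):
    """Return the placeholder shape at the pre-fc concatenated-aux dimension."""
    return (1, 1, num_aux_hidden_states * target_hidden_size)


H = 4096  # Illustrative target hidden size, not taken from any real config.
shape = mask_hidden_shape(num_aux_hidden_states=3, target_hidden_size=H)
fc_in_features = 3 * H  # The fc projection must accept the same last dimension.
```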

nemo_automodel.components.speculative.eagle.peagle_draft._PeagleDraftMixin._run_draft_layer(
layer_fn,
args = (),
kwargs = {}
)

Run one P-EAGLE draft layer, optionally under activation checkpointing.

When gradient_checkpointing is enabled (training only), the layer’s activations are freed after the forward and recomputed during the backward (torch.utils.checkpoint), trading one extra forward per layer for a lower activation-memory peak on the long flattened COD sequence. use_reentrant=False is required so the non-tensor P-EAGLE argument (block_mask) passes through and so the recompute composes with the flex-attention path. With the flag off (the default, and in eval) the layer runs directly with no overhead.
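The dispatch described above might look roughly like the following sketch. The function name, the use_checkpoint flag, and the injected checkpoint_fn are assumptions for illustration; in PyTorch the injected callable would be torch.utils.checkpoint.checkpoint, which accepts use_reentrant=False and forwards keyword arguments in that mode. A stand-in is used here so the sketch runs without PyTorch.

```python
def run_layer(layer_fn, args=(), kwargs=None, *, use_checkpoint=False, checkpoint_fn=None):
    """Run layer_fn directly, or through an activation-checkpoint wrapper."""
    kwargs = {} if kwargs is None else kwargs
    if use_checkpoint:
        # Non-reentrant mode lets non-tensor arguments such as block_mask pass through.
        return checkpoint_fn(layer_fn, *args, use_reentrant=False, **kwargs)
    return layer_fn(*args, **kwargs)


seen_reentrant_flags = []


def fake_checkpoint(fn, *args, use_reentrant, **kwargs):
    # Stand-in for a real checkpoint function: records the flag, then calls fn.
    seen_reentrant_flags.append(use_reentrant)
    return fn(*args, **kwargs)


def toy_layer(x, block_mask=None):
    return x + 1


direct = run_layer(toy_layer, (1,), {"block_mask": "mask"})
checkpointed = run_layer(
    toy_layer, (1,), {"block_mask": "mask"},
    use_checkpoint=True, checkpoint_fn=fake_checkpoint,
)
```

Both calls return the same value; only the checkpointed path pays the recompute cost during a real backward pass.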

nemo_automodel.components.speculative.eagle.peagle_draft._PeagleDraftMixin.build_peagle_block_mask(
anchor_pos,
depth,
lengths,
total_seq_len
)

Construct the COD flex_attention block mask for one sequence.
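To give a feel for what a cross-depth visibility mask encodes, here is a dense toy version. The actual rule is defined by create_peagle_mask_mod and is not documented on this page; the rule below is an invented stand-in (a slot sees depth-0 slots at or before its anchor, plus slots on its own anchor at the same or shallower depth), and it ignores lengths and total_seq_len.

```python
def toy_cod_mask(anchor_pos, depth):
    """Build a dense boolean mask[q][k] under an assumed, illustrative rule."""
    n = len(anchor_pos)
    mask = [[False] * n for _ in range(n)]
    for q in range(n):
        for k in range(n):
            # Assumed rule 1: causal view of the real (depth-0) context.
            context = depth[k] == 0 and anchor_pos[k] <= anchor_pos[q]
            # Assumed rule 2: earlier or equal depths on the same anchor.
            same_chain = anchor_pos[k] == anchor_pos[q] and depth[k] <= depth[q]
            mask[q][k] = context or same_chain
    return mask


mask = toy_cod_mask(anchor_pos=[0, 1, 0, 1], depth=[0, 0, 1, 1])
```

A real implementation would express the rule as a mask function over indices and hand it to flex_attention's block-mask builder rather than materializing a dense matrix.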

nemo_automodel.components.speculative.eagle.peagle_draft._PeagleDraftMixin.forward_peagle(
sampled_input_ids: torch.Tensor,
sampled_projected_hidden: torch.Tensor,
position_ids: torch.Tensor,
block_mask
) -> torch.Tensor

Run the P-EAGLE single parallel-group forward.

All COD depths are already flattened into one [1, total_sampled] sequence by the caller:

  • sampled_input_ids: real token ids at depth-0 slots, the masked mask_token_id at depth >= 1 slots;
  • sampled_projected_hidden: fc-projected target aux states at depth-0 slots, the projected mask_hidden placeholder elsewhere;
  • position_ids: anchor_pos + depth (the reference position);
  • block_mask: the COD cross-depth visibility mask.

Returns the pre-logits hidden states (post-norm when config.norm_output is set), one row per sampled element.
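A caller-side sketch of how sampled_input_ids could be assembled. The helper name is hypothetical, and taking the depth-0 token directly from the anchor position is an assumption; the real trainer may apply a different offset.

```python
def build_sampled_input_ids(token_ids, anchor_pos, depth, mask_token_id):
    """Real token ids at depth-0 slots, mask_token_id at every deeper slot."""
    return [
        token_ids[a] if d == 0 else mask_token_id  # Assumed: index by anchor.
        for a, d in zip(anchor_pos, depth)
    ]


sampled = build_sampled_input_ids(
    token_ids=[11, 22, 33, 44],
    anchor_pos=[0, 2, 0, 2],
    depth=[0, 0, 1, 1],
    mask_token_id=-1,  # Illustrative value; the real id comes from the config.
)
```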

nemo_automodel.components.speculative.eagle.peagle_draft._PeagleDraftMixin.masked_projected_hidden() -> torch.Tensor

Project the learnable P-EAGLE mask_hidden placeholder to draft hidden size.

Returns a [1, hidden_size] tensor obtained by running the [1, 1, num_aux_hidden_states * target_hidden_size] placeholder through the same project_hidden_states path (fc plus optional fc_norm) used for real auxiliary hidden states. The P-EAGLE trainer scatters the result into every masked COD depth. Only valid when the draft was built with config.parallel_drafting=True.
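The scatter step performed by the trainer can be sketched with plain lists. The helper name and the convention that real rows arrive in depth-0 slot order are assumptions for illustration.

```python
def scatter_placeholder(real_rows, depth, placeholder_row):
    """Place real projected rows at depth-0 slots and the placeholder elsewhere."""
    real_iter = iter(real_rows)
    out = []
    for d in depth:
        if d == 0:
            out.append(next(real_iter))
        else:
            # Every masked slot receives a copy of the same projected placeholder.
            out.append(list(placeholder_row))
    return out


rows = scatter_placeholder(
    real_rows=[[1.0, 2.0], [3.0, 4.0]],
    depth=[0, 0, 1, 1],
    placeholder_row=[9.0, 9.0],  # Stands in for the [1, hidden_size] result.
)
```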

class nemo_automodel.components.speculative.eagle.peagle_draft._PeagleVanillaLayerMixin()

P-EAGLE forward for the vanilla deep layer Eagle3LlamaPeagleLayer.

nemo_automodel.components.speculative.eagle.peagle_draft._PeagleVanillaLayerMixin.forward_peagle(
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
block_mask
) -> torch.Tensor

Standard pre-norm Llama block over H hidden states with the COD mask.
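For readers unfamiliar with the term, a pre-norm block normalizes before each sublayer and adds the sublayer output back onto the residual stream. The sketch below uses scalar stand-ins for the norm, attention, and MLP callables; all names are hypothetical.

```python
def pre_norm_block(x, norm1, attn, norm2, mlp):
    """Pre-norm residual block: normalize, apply sublayer, add residual."""
    x = x + attn(norm1(x))  # Attention sublayer with residual connection.
    x = x + mlp(norm2(x))   # MLP sublayer with residual connection.
    return x


out = pre_norm_block(
    1.0,
    norm1=lambda v: v * 2,
    attn=lambda v: v + 1,
    norm2=lambda v: v,
    mlp=lambda v: v * 3,
)
```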

nemo_automodel.components.speculative.eagle.peagle_draft._peagle_compile_supported(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor
) -> bool

Return whether Inductor’s flex-attention lowering supports these tensors.

nemo_automodel.components.speculative.eagle.peagle_draft._peagle_flex_attention(
q,
k,
v,
block_mask,
scale
)

Run the P-EAGLE flex attention, compiling only when Inductor supports it.
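The module-level names (a compiled callable plus a "compile failed" flag) suggest a try-compiled-then-fall-back pattern. The page does not show the implementation, so the following is an assumption about the control flow, written with hypothetical names and plain Python callables in place of the real attention kernels.

```python
import logging

logger = logging.getLogger(__name__)

# Hypothetical module state playing the role of a "compile failed" flag.
state = {"compile_failed": False}


def run_attention(compiled_fn, eager_fn, supported, *args):
    """Prefer the compiled path; after one failure, always use the eager path."""
    if supported and not state["compile_failed"]:
        try:
            return compiled_fn(*args)
        except Exception as exc:  # Assumed: any lowering error triggers fallback.
            state["compile_failed"] = True
            logger.warning("compiled path failed (%s); using eager path", exc)
    return eager_fn(*args)


compiled_calls = []


def broken_compiled(x):
    compiled_calls.append(x)
    raise RuntimeError("lowering not supported")


def eager(x):
    return x * 2


first = run_attention(broken_compiled, eager, True, 3)
second = run_attention(broken_compiled, eager, True, 4)
```

Remembering the failure avoids paying the compile attempt again on every call.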

nemo_automodel.components.speculative.eagle.peagle_draft._peagle_flex_attention_compile_failed = False
nemo_automodel.components.speculative.eagle.peagle_draft._peagle_flex_attention_compiled = torch.compile(flex_attention, mode='max-autotune-no-cudagraphs', dynamic=True)
nemo_automodel.components.speculative.eagle.peagle_draft.logger = logging.getLogger(__name__)