nemo_automodel.components.speculative.eagle.draft_gpt_oss

EAGLE-3 draft model for gpt-oss targets (GptOssForCausalLM).

gpt-oss is a Mixture-of-Experts decoder whose target backbone differs from a Llama-style dense LLM in three ways: alternating sliding-window / full attention layers, learnable attention sinks, and YaRN-scaled RoPE. The first two never reach the EAGLE-3 draft. The draft is a single from-scratch decoder layer that consumes only the post-block auxiliary hidden states emitted by the frozen target (captured via register_forward_hook) and re-projects its own Q/K, so it never sees the target’s experts, sinks, or sliding mask. Structurally it is the same Llama-style dense draft used for every other registry entry.

RoPE is the exception, and it must match the target. During speculative decoding the draft runs at the same token positions as the target, and gpt-oss is a long-context model whose rotary frequencies are reshaped by YaRN (NTK-by-parts with a concentration scale; base 150000, factor=32, extending the context from 4096 to 131072 positions). A draft trained with a different rotary schedule is positionally inconsistent with the target: it may converge at the short context used during training, but its acceptance rate collapses at the long positions gpt-oss is built for, because its notion of “position p” diverges from the target’s. (SpecForge trains the gpt-oss draft from a plain model_type: llama config with standard RoPE; that is a latent bug masked by short-context training, not a recipe to copy.)
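To make the schedule mismatch concrete, here is a minimal pure-Python sketch of a YaRN NTK-by-parts frequency table. The base, factor, and original context length come from the text above. The head dimension of 64 and the ramp bounds beta_fast=32 and beta_slow=1 are assumptions chosen for illustration, and the function name is hypothetical. It follows the commonly published YaRN formulation and is not a copy of the library's rope_utils code.

```python
import math

def yarn_inv_freq(head_dim=64, base=150000.0, factor=32.0,
                  original_ctx=4096, beta_fast=32.0, beta_slow=1.0):
    """Return (inv_freq, concentration) for a YaRN NTK-by-parts schedule.

    head_dim, beta_fast and beta_slow are illustrative assumptions.
    """
    d_half = head_dim // 2
    # Unscaled per-pair divisor: base^(2i / head_dim).
    freq = [base ** (2 * i / head_dim) for i in range(d_half)]
    # Concentration: a scale applied to cos/sin alongside the reshaped frequencies.
    concentration = 0.1 * math.log(factor) + 1.0
    # Fractional pair indices that bound the interpolation ramp.
    low = d_half * math.log(original_ctx / (beta_fast * 2 * math.pi)) / math.log(base)
    high = d_half * math.log(original_ctx / (beta_slow * 2 * math.pi)) / math.log(base)
    inv_freq = []
    for i, f in enumerate(freq):
        ramp = min(max((i - low) / (high - low), 0.0), 1.0)
        keep = 1.0 - ramp  # 1.0: keep the original frequency, 0.0: divide it by factor
        inv_freq.append(keep / f + (1.0 - keep) / (factor * f))
    return inv_freq, concentration

inv_freq, concentration = yarn_inv_freq()
# Standard (unscaled) RoPE with the same base, for comparison.
plain = [150000.0 ** (-2 * i / 64) for i in range(32)]
# Ratio of plain to YaRN frequency per rotary pair: 1.0 for fast pairs, up to 32 for slow pairs.
ratios = [p / y for p, y in zip(plain, inv_freq)]
```

Under these assumptions the fastest pairs are untouched while the slowest pairs rotate 32 times more slowly than under plain RoPE, so a draft using the plain table assigns different angles to the same position, and the gap grows with the position index.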

The shared LlamaRotaryEmbedding cannot represent YaRN: it implements only {"default", "llama3"} and silently falls back to the llama3 NTK schedule for rope_type="yarn". This draft therefore swaps in GptOssDraftRotaryEmbedding, which reproduces gpt-oss’s exact YaRN inv_freq and concentration (reusing the target’s own components/models/gpt_oss/rope_utils.RotaryEmbedding) but returns (cos, sin) in the duplicated [..., head_dim] layout. gpt-oss’s interleaved apply_rotary_emb and the draft’s rotate_half-based apply_rotary_pos_emb are algebraically identical under that layout, so the draft’s rotation is bit-faithful to the target’s.
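The layout equivalence can be checked on a single toy vector. The sketch below is pure Python with hypothetical helper names. It assumes the target-side rotation pairs element i with element i + head_dim/2 and uses cos/sin tables of size head_dim/2. The real apply_rotary_emb and apply_rotary_pos_emb are not reproduced here.

```python
import math

def rotate_with_half_tables(x, cos_half, sin_half):
    """Rotate each pair (x[i], x[i + d/2]) using tables of size d/2."""
    h = len(x) // 2
    x1, x2 = x[:h], x[h:]
    out1 = [a * c - b * s for a, b, c, s in zip(x1, x2, cos_half, sin_half)]
    out2 = [b * c + a * s for a, b, c, s in zip(x1, x2, cos_half, sin_half)]
    return out1 + out2

def rotate_half(x):
    """Map [x1, x2] to [-x2, x1], the usual Llama-style helper."""
    h = len(x) // 2
    return [-v for v in x[h:]] + x[:h]

def rotate_with_duplicated_tables(x, cos_full, sin_full):
    """Compute x * cos + rotate_half(x) * sin with tables of size d."""
    rx = rotate_half(x)
    return [v * c + r * s for v, r, c, s in zip(x, rx, cos_full, sin_full)]

angles = [0.3, 1.1, 2.5, 4.0]                      # one angle per rotary pair
cos_half = [math.cos(a) for a in angles]
sin_half = [math.sin(a) for a in angles]
x = [1.0, -2.0, 0.5, 3.0, 4.0, -1.5, 2.0, 0.25]    # toy head vector, d = 8

target_style = rotate_with_half_tables(x, cos_half, sin_half)
draft_style = rotate_with_duplicated_tables(x, cos_half + cos_half, sin_half + sin_half)
```

Duplicating the half-size tables is what lets the rotate_half form reuse the same angles for both halves of the vector, which is why the two outputs agree element by element in this sketch.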

Everything else (GQA, attention/MLP bias, RMSNorm, the EAGLE-3 TTT cache attention, the fc projection, the draft lm_head and vocab mapping) is inherited unchanged from LlamaEagle3DraftModel. The on-disk state-dict layout and the saved architectures: ["LlamaEagle3DraftModel"] string are unchanged, so checkpoints trained here load into SGLang exactly like the Llama draft.

Module Contents

Classes

| Name | Description |
| --- | --- |
| GptOssDraftRotaryEmbedding | gpt-oss YaRN RoPE exposed through the LlamaRotaryEmbedding (cos, sin) API. |
| GptOssEagle3DraftModel | EAGLE-3 draft model for gpt-oss targets. |

API

class nemo_automodel.components.speculative.eagle.draft_gpt_oss.GptOssDraftRotaryEmbedding(
config: transformers.PretrainedConfig
)

Bases: Module

gpt-oss YaRN RoPE exposed through the LlamaRotaryEmbedding (cos, sin) API.

Produces the same per-position rotary frequencies as the gpt-oss target (YaRN NTK-by-parts with concentration), reusing the target’s own RotaryEmbedding so the YaRN math lives in one place. Unlike that class — which builds cos/sin of size rotary_dim // 2 and applies them with an interleaved split — this returns them duplicated to size head_dim so the draft attention’s rotate_half-based apply_rotary_pos_emb performs the identical rotation. The values in position_ids are honored (EAGLE-3 TTT passes arange(seq_len) + step_idx).
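The duplicated layout and the position offset described above can be sketched without tensors. The function name, the toy inv_freq values, and the concentration value are hypothetical. The sketch also assumes that concentration multiplies both cos and sin, which the page does not spell out.

```python
import math

def draft_cos_sin(position_ids, inv_freq, concentration=1.0):
    """Return (cos, sin) rows of width 2 * len(inv_freq), one row per position."""
    cos_rows, sin_rows = [], []
    for p in position_ids:
        angles = [p * f for f in inv_freq]
        cos_half = [math.cos(a) * concentration for a in angles]
        sin_half = [math.sin(a) * concentration for a in angles]
        # Duplicate the half-size tables so a rotate_half-style kernel can consume them.
        cos_rows.append(cos_half + cos_half)
        sin_rows.append(sin_half + sin_half)
    return cos_rows, sin_rows

seq_len, step_idx = 4, 2
# Plain-Python analogue of arange(seq_len) + step_idx.
position_ids = [i + step_idx for i in range(seq_len)]
toy_inv_freq = [1.0, 0.1, 0.01, 0.001]   # toy values, head_dim = 8
cos, sin = draft_cos_sin(position_ids, toy_inv_freq, concentration=1.25)

# Honoring position_ids means a shifted call matches a slice of the unshifted table.
full_cos, _ = draft_cos_sin(list(range(seq_len + step_idx)), toy_inv_freq, concentration=1.25)
```

Because the angles depend on the values in position_ids rather than on the row index, the rows produced at step_idx=2 are the same as rows 2 through 5 of a table built from position 0.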

_rope

nemo_automodel.components.speculative.eagle.draft_gpt_oss.GptOssDraftRotaryEmbedding.forward(
x: torch.Tensor,
position_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]

class nemo_automodel.components.speculative.eagle.draft_gpt_oss.GptOssEagle3DraftModel(
config: transformers.PretrainedConfig
)

Bases: LlamaEagle3DraftModel

EAGLE-3 draft model for gpt-oss targets.

Identical to LlamaEagle3DraftModel except that the single draft layer’s rotary embedding is replaced with GptOssDraftRotaryEmbedding, so the draft reproduces gpt-oss’s YaRN RoPE instead of relying on LlamaRotaryEmbedding, which cannot represent YaRN.