nemo_automodel.components.speculative.dflash.draft_qwen3
DFlash draft model (Qwen3-style).
Ported from SpecForge’s specforge/modeling/draft/dflash.py. DFlash drafts a
whole block of block_size tokens in parallel: the block’s first position
holds the real anchor token and the rest are MASK tokens, and the draft
predicts the whole block in a single non-causal forward conditioned on the
target model’s context hidden states.
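The block layout above can be sketched as a plain list of token ids. This is a minimal illustration under stated assumptions: token ids are integers, a dedicated MASK id exists, and the helper name `make_draft_block` and its `mask_token_id` parameter are hypothetical rather than part of this module's API.

```python
from typing import List


def make_draft_block(anchor_token: int, block_size: int, mask_token_id: int) -> List[int]:
    """Lay out one draft block: the real anchor first, MASK tokens after it.

    Hypothetical helper; mask_token_id is an assumed placeholder id.
    """
    if block_size < 1:
        raise ValueError("block_size must be at least 1")
    # Position 0 carries the real anchor token; the draft predicts the rest.
    return [anchor_token] + [mask_token_id] * (block_size - 1)


print(make_draft_block(anchor_token=1234, block_size=4, mask_token_id=0))
# [1234, 0, 0, 0]
```

Because every non-anchor position holds the same MASK id, the draft has to rely on the anchor and the target's context hidden states to fill the block in one pass.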
Because the whole block is predicted in one forward, the draft attention is not causal: a draft block’s queries
attend to (a) the projected target-hidden context strictly before its anchor, and
(b) all other (noise) tokens of the same block, in both directions. The attention
mask that enforces this is built by the trainer wrapper in
nemo_automodel.components.speculative.dflash.core.
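The masking rule can be illustrated as a boolean matrix with one row per noise query and one column per key in [context | noise]. This is a hypothetical sketch, not the trainer's code: the function name `dflash_block_mask` is invented, `True` means "may attend", and the real wrapper may instead build an additive float mask or a kernel-specific block mask.

```python
import numpy as np
from typing import Sequence


def dflash_block_mask(ctx_len: int, anchor_positions: Sequence[int], block_size: int) -> np.ndarray:
    """Boolean [num_noise, ctx_len + num_noise] mask; True means "may attend".

    Hypothetical illustration of rules (a) and (b) above.
    """
    num_noise = len(anchor_positions) * block_size
    mask = np.zeros((num_noise, ctx_len + num_noise), dtype=bool)
    for b, anchor in enumerate(anchor_positions):
        if not 0 <= anchor <= ctx_len:
            raise ValueError("anchor position must lie within the context")
        rows = slice(b * block_size, (b + 1) * block_size)
        # (a) context keys strictly before this block's anchor
        mask[rows, :anchor] = True
        # (b) every noise key of the same block, in both directions
        start = ctx_len + b * block_size
        mask[rows, start:start + block_size] = True
    return mask


m = dflash_block_mask(ctx_len=6, anchor_positions=[2, 4], block_size=3)
print(m.astype(int))
```

In this example the first block's queries (rows 0-2) see context columns 0-1 and their own noise columns 6-8, while the second block's queries (rows 3-5) see context columns 0-3 and noise columns 9-11; no block sees another block's noise tokens.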
Module Contents
API
Bases: Module
Non-causal attention whose keys/values are [context | noise-block].
Queries come from the draft (noise) tokens only; keys and values are the
concatenation of the projected target-hidden context and the noise tokens.
The bidirectional/block structure is supplied entirely by attention_mask.
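A single-head NumPy sketch of this attention pattern follows. Assumptions: the context and noise tokens have already been projected to queries, keys, and values; there is no RoPE, no grouped-query attention, and no batching; and the function name `context_noise_attention` is hypothetical. Each query row must have at least one allowed key, which holds under the DFlash rule because a block always sees its own noise tokens.

```python
import numpy as np


def context_noise_attention(q_noise, k_ctx, v_ctx, k_noise, v_noise, allowed):
    """Single-head sketch.

    q_noise, k_noise, v_noise: [n_noise, d]  (projected draft/noise tokens)
    k_ctx, v_ctx:              [n_ctx, d]    (projected target-hidden context)
    allowed:                   [n_noise, n_ctx + n_noise] bool, True = may attend
    """
    k = np.concatenate([k_ctx, k_noise], axis=0)  # keys   = [context | noise]
    v = np.concatenate([v_ctx, v_noise], axis=0)  # values = [context | noise]
    scores = q_noise @ k.T / np.sqrt(q_noise.shape[-1])
    # No causal triangle: the mask alone decides what each query sees.
    scores = np.where(allowed, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


rng = np.random.default_rng(0)
n_ctx, n_noise, d = 5, 3, 4
q, kn, vn = (rng.standard_normal((n_noise, d)) for _ in range(3))
kc, vc = (rng.standard_normal((n_ctx, d)) for _ in range(2))
allowed = np.zeros((n_noise, n_ctx + n_noise), dtype=bool)
allowed[:, 0] = True  # every query may see only context position 0
out = context_noise_attention(q, kc, vc, kn, vn, allowed)
print(np.allclose(out, vc[0]))  # True
```

The usage restricts every query to one key, so the output collapses to that key's value; with a real DFlash mask the same function mixes context and same-block noise values.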
Bases: GradientCheckpointingLayer
A DFlash decoder block: non-causal attention over [context | noise] + MLP.
Bases: Qwen3PreTrainedModel
DFlash draft model: a small non-causal Qwen3 stack over [context | noise].
Block-parallel speculative decoding: draft a block, verify with the target, accept the matching prefix.
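The "accept the matching prefix" step can be sketched for greedy verification. This assumes the common speculative-decoding convention of keeping the target's own token at the first mismatch (or one extra target token when every draft token matches); whether this module's generation loop does exactly that is an assumption, and `accept_matching_prefix` is a hypothetical name.

```python
from typing import List


def accept_matching_prefix(draft: List[int], target: List[int]) -> List[int]:
    """Greedy verification sketch.

    draft:  tokens proposed by the draft for the non-anchor block positions.
    target: the target's greedy predictions at those same positions, plus one
            extra prediction after the last drafted position.
    Returns the accepted draft prefix followed by the target's token at the
    first mismatch (or its extra token when every draft token matched).
    """
    if len(target) != len(draft) + 1:
        raise ValueError("target must have exactly one more token than draft")
    n = 0
    while n < len(draft) and draft[n] == target[n]:
        n += 1
    return draft[:n] + [target[n]]


print(accept_matching_prefix([5, 7, 9], [5, 7, 2, 4]))  # [5, 7, 2]
print(accept_matching_prefix([5, 7, 9], [5, 7, 9, 4]))  # [5, 7, 9, 4]
```

Each verification step therefore commits at least one token (the target's own), and up to the full block plus one when the draft is entirely right.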
Apply RoPE where queries (draft block) are a suffix of the key positions.
The keys span [context | noise-block] while the queries are only the
noise block, so q is rotated with the trailing q_len slice of the
rotary tables and k with the full table.
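The suffix slicing can be checked with a NumPy sketch of standard rotate-half RoPE. The helpers `rope_tables` and `apply_rope_query_suffix` are hypothetical names, tensors are unbatched `[seq, dim]` arrays, and the real function operates on batched, multi-head PyTorch tensors.

```python
import numpy as np


def rope_tables(seq_len: int, dim: int, base: float = 10000.0):
    """Standard RoPE cos/sin tables of shape [seq_len, dim] (rotate-half layout)."""
    inv_freq = 1.0 / base ** (np.arange(0, dim, 2) / dim)
    freqs = np.outer(np.arange(seq_len), inv_freq)  # [seq_len, dim // 2]
    emb = np.concatenate([freqs, freqs], axis=-1)   # [seq_len, dim]
    return np.cos(emb), np.sin(emb)


def rotate_half(x: np.ndarray) -> np.ndarray:
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate([-x2, x1], axis=-1)


def apply_rope_query_suffix(q, k, cos, sin):
    """q: [q_len, dim] (noise block); k, cos, sin: [k_len, dim] ([context | noise])."""
    q_len = q.shape[0]
    # Queries occupy the last q_len key positions, so use the trailing rows.
    q_rot = q * cos[-q_len:] + rotate_half(q) * sin[-q_len:]
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot


rng = np.random.default_rng(0)
ctx_len, q_len, dim = 5, 3, 8
k = rng.standard_normal((ctx_len + q_len, dim))
q = k[-q_len:].copy()  # pretend the queries are the noise rows of k
cos, sin = rope_tables(ctx_len + q_len, dim)
q_rot, k_rot = apply_rope_query_suffix(q, k, cos, sin)
print(np.allclose(q_rot, k_rot[-q_len:]))  # True
```

The check confirms that a query vector is rotated by the same angle as the key at its position. Rotating q with the leading rows instead would place the noise block at positions 0..q_len-1 and distort its relative offsets to the context.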
Pick num_draft_layers target layers spread across the target’s depth.
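One way to spread the selection is even spacing across the target's depth. This is a hypothetical sketch: the name `pick_target_layers`, the inclusion of both the first and last layer, the middle-layer choice for a single draft layer, and the floor rounding are all assumptions, and the module's actual spacing rule may differ.

```python
from typing import List


def pick_target_layers(num_target_layers: int, num_draft_layers: int) -> List[int]:
    """Evenly spaced target-layer indices (hypothetical spacing rule)."""
    if not 1 <= num_draft_layers <= num_target_layers:
        raise ValueError("need 1 <= num_draft_layers <= num_target_layers")
    if num_draft_layers == 1:
        return [num_target_layers // 2]  # a single middle layer
    last = num_target_layers - 1
    # Integer division keeps indices exact; endpoints 0 and last are included.
    return [i * last // (num_draft_layers - 1) for i in range(num_draft_layers)]


print(pick_target_layers(36, 5))  # [0, 8, 17, 26, 35]
```

Since the step `last / (num_draft_layers - 1)` is at least 1 under the guard, the returned indices are strictly increasing and never repeat a layer.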
Concatenate the selected target layers’ hidden states along the feature dim.
hidden_states follows HF’s output_hidden_states convention, in which
index 0 is the embedding output, so layer i’s output is at index
i + 1.
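The index shift is the easy part to get wrong, so here is a NumPy stand-in for the concatenation. Assumptions: `hidden_states` is a sequence of `[batch, seq, hidden]` arrays, NumPy replaces the PyTorch tensors the module uses, and the helper name `concat_target_layers` is hypothetical.

```python
import numpy as np
from typing import Sequence


def concat_target_layers(hidden_states: Sequence[np.ndarray], layer_ids: Sequence[int]) -> np.ndarray:
    """hidden_states[0] is the embedding output; layer i lives at index i + 1."""
    return np.concatenate([hidden_states[i + 1] for i in layer_ids], axis=-1)


# Fake [batch=2, seq=3, hidden=4] states, each filled with its own index.
hs = [np.full((2, 3, 4), float(idx)) for idx in range(5)]
feat = concat_target_layers(hs, layer_ids=[0, 2])
print(feat.shape)  # (2, 3, 8)
```

With `layer_ids=[0, 2]`, the first four features come from `hs[1]` and the last four from `hs[3]`, so the embedding output at `hs[0]` is never used.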
Greedy (temperature ~ 0) or temperature sampling over the last dim.
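A NumPy sketch of this sampling rule follows. The function name `sample_last_dim` and the `greedy_eps` threshold that decides what counts as "temperature ~ 0" are hypothetical; the module itself works on PyTorch logits.

```python
import numpy as np
from typing import Optional


def sample_last_dim(logits: np.ndarray, temperature: float,
                    rng: Optional[np.random.Generator] = None,
                    greedy_eps: float = 1e-5) -> np.ndarray:
    """Return token indices chosen over the last dim of logits."""
    if temperature < greedy_eps:
        # Treat (near-)zero temperature as greedy decoding.
        return logits.argmax(axis=-1)
    if rng is None:
        rng = np.random.default_rng()
    scaled = logits / temperature
    scaled = scaled - scaled.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(scaled)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Draw one index per row of the flattened leading dims, then restore shape.
    flat = probs.reshape(-1, probs.shape[-1])
    picks = np.array([rng.choice(flat.shape[-1], p=row) for row in flat])
    return picks.reshape(probs.shape[:-1])


logits = np.array([[0.1, 2.0, 0.3], [3.0, 0.0, 1.0]])
print(sample_last_dim(logits, temperature=0.0))  # [1 0]
```

Lower temperatures sharpen the softmax toward the argmax, so the greedy branch is the limiting case of the sampling branch rather than a separate behavior.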