nemo_automodel.components.speculative.dflash.draft_qwen3


DFlash draft model (Qwen3-style).

Ported from SpecForge’s specforge/modeling/draft/dflash.py. DFlash drafts a whole block of block_size tokens in parallel. The first position of the block holds the real anchor token and the remaining positions hold MASK tokens. The draft then predicts the whole block in a single non-causal forward pass, conditioned on the target model’s context hidden states.
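
The block layout described above can be sketched in a few lines of plain Python. make_noise_block is a hypothetical helper, not part of this module, and the token ids in the usage line are arbitrary. In the real model the mask id comes from the draft's mask_token_id attribute.

```python
def make_noise_block(anchor_id: int, block_size: int, mask_token_id: int) -> list[int]:
    """Lay out one draft block: the real anchor token, then MASK placeholders."""
    if block_size < 1:
        raise ValueError("block_size must be at least 1")
    # Slot 0 carries the real anchor; slots 1..block_size-1 are what the draft predicts.
    return [anchor_id] + [mask_token_id] * (block_size - 1)


# Hypothetical ids: anchor token 42, mask token 0, block_size 4.
print(make_noise_block(42, 4, 0))  # [42, 0, 0, 0]
```

Every position after the anchor is a placeholder, so one non-causal forward pass can fill all of them at once instead of decoding them one by one.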

The draft attention is therefore not causal: a draft block’s queries attend to (a) the projected target-hidden context strictly before the block’s anchor and (b) bidirectionally to the other noise tokens of the same block. The attention mask that enforces this is built by the trainer wrapper in nemo_automodel.components.speculative.dflash.core.
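
To make the two visibility rules concrete, here is a minimal sketch of the mask for a single block, assuming keys are laid out as [context | noise-block] with ctx_len context positions followed by block_size noise positions. build_block_mask is hypothetical; the mask actually built in dflash.core may differ in form (for example batched, covering several blocks per sequence, or stored as an additive -inf tensor instead of booleans).

```python
def build_block_mask(ctx_len: int, anchor_pos: int, block_size: int) -> list[list[bool]]:
    """Boolean mask of shape [block_size, ctx_len + block_size]; True means "may attend".

    Key layout is [context | noise-block]; queries are the block's noise tokens.
    """
    mask = []
    for _ in range(block_size):
        # (a) context keys strictly before the block's anchor position
        row = [j < anchor_pos for j in range(ctx_len)]
        # (b) every token of the same noise block, in both directions
        row += [True] * block_size
        mask.append(row)
    return mask


for row in build_block_mask(ctx_len=5, anchor_pos=3, block_size=2):
    print("".join("x" if allowed else "." for allowed in row))
# Both query rows print xxx..xx: same context cutoff, full visibility inside the block.
```

All query rows of a block are identical, which is exactly what distinguishes this from a causal mask, where each row would see one more noise position than the row before it.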

Module Contents

Classes

Qwen3DFlashAttention: Non-causal attention whose keys/values are [context | noise-block].
Qwen3DFlashDecoderLayer: A DFlash decoder block: non-causal attention over [context | noise] + MLP.
Qwen3DFlashDraftModel: DFlash draft model: a small non-causal Qwen3 stack over [context | noise].

Functions

apply_rotary_pos_emb: Apply RoPE where queries (draft block) are a suffix of the key positions.
build_target_layer_ids: Pick num_draft_layers target layers spread across the target’s depth.
extract_context_feature: Concatenate the selected target layers’ hidden states along the feature dim.
sample: Greedy (temperature ~ 0) or temperature sampling over the last dim.

API

class nemo_automodel.components.speculative.dflash.draft_qwen3.Qwen3DFlashAttention(
config: transformers.models.qwen3.configuration_qwen3.Qwen3Config,
layer_idx: int
)

Bases: Module

Non-causal attention whose keys/values are [context | noise-block].

Queries come from the draft (noise) tokens only; keys and values are the concatenation of the projected target-hidden context and the noise tokens. The bidirectional/block structure is supplied entirely by attention_mask.
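
The key/value concatenation can be illustrated with a single-head, pure-Python attention. This is a simplified sketch, not the class's forward: it leaves out the q/k/v/o projections, q_norm and k_norm, RoPE, grouped key/value heads and dropout, and dflash_attention_sketch is a hypothetical name. It keeps the two properties described above: queries come only from the noise tokens, and all structure comes from the mask.

```python
import math


def dflash_attention_sketch(q_noise, k_ctx, k_noise, v_ctx, v_noise, mask):
    """Single-head attention whose keys/values are [context | noise-block].

    q_noise: one query vector per noise token (queries come from the draft block only)
    k_ctx, v_ctx: key/value vectors from the projected target-hidden context
    k_noise, v_noise: key/value vectors from the noise tokens
    mask: one row of booleans per query over ctx_len + block_size keys
    """
    keys = k_ctx + k_noise  # list concatenation: [context | noise]
    values = v_ctx + v_noise
    scaling = len(q_noise[0]) ** -0.5  # head_dim ** -0.5, as in the scaling attribute
    out = []
    for q, allowed in zip(q_noise, mask):
        scores = [
            sum(a * b for a, b in zip(q, k)) * scaling if ok else -math.inf
            for k, ok in zip(keys, allowed)
        ]
        peak = max(scores)  # finite as long as each row allows at least one key
        weights = [math.exp(s - peak) for s in scores]  # masked keys get weight 0.0
        total = sum(weights)
        out.append([
            sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))
        ])
    return out


out = dflash_attention_sketch(
    q_noise=[[0.0, 0.0]], k_ctx=[[1.0, 0.0]], k_noise=[[0.0, 1.0]],
    v_ctx=[[2.0, 0.0]], v_noise=[[0.0, 4.0]], mask=[[True, True]],
)
print(out)  # [[1.0, 2.0]]: a zero query weights both visible keys equally
```

Because the context keys and the noise keys live in one sequence, a single softmax decides how much each draft position relies on the target's hidden states versus its own block.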

attention_dropout = config.attention_dropout
head_dim
k_norm
k_proj
num_key_value_groups
o_proj
q_norm
q_proj
scaling = self.head_dim ** -0.5
v_proj
nemo_automodel.components.speculative.dflash.draft_qwen3.Qwen3DFlashAttention.forward(
hidden_states: torch.Tensor,
target_hidden: torch.Tensor,
position_embeddings: typing.Tuple[torch.Tensor, torch.Tensor],
attention_mask: typing.Optional[torch.Tensor],
past_key_values: typing.Optional[transformers.cache_utils.Cache] = None,
cache_position: typing.Optional[torch.LongTensor] = None,
**kwargs
) -> typing.Tuple[torch.Tensor, typing.Optional[torch.Tensor]]

class nemo_automodel.components.speculative.dflash.draft_qwen3.Qwen3DFlashDecoderLayer(
config: transformers.models.qwen3.configuration_qwen3.Qwen3Config,
layer_idx: int
)

Bases: GradientCheckpointingLayer

A DFlash decoder block: non-causal attention over [context | noise] + MLP.
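
A minimal sketch of how such a block could compose its parts, assuming it follows the standard Qwen3 pre-norm residual order (input_layernorm, attention, residual add, then post_attention_layernorm, MLP, residual add). It also assumes target_hidden is handed to the attention without passing through this layer's norms. Both are assumptions, not confirmed by this page. decoder_layer_sketch and the stand-in callables are hypothetical, and vectors are plain lists.

```python
from typing import Callable, Sequence

Vector = list[float]


def add(xs: Sequence[Vector], ys: Sequence[Vector]) -> list[Vector]:
    """Element-wise residual addition, token by token."""
    return [[a + b for a, b in zip(x, y)] for x, y in zip(xs, ys)]


def decoder_layer_sketch(
    hidden: Sequence[Vector],
    target_hidden: Sequence[Vector],
    input_layernorm: Callable[[Vector], Vector],
    self_attn: Callable[[list[Vector], Sequence[Vector]], list[Vector]],
    post_attention_layernorm: Callable[[Vector], Vector],
    mlp: Callable[[Vector], Vector],
) -> list[Vector]:
    # Attention sub-block: pre-norm the noise tokens, attend over [context | noise], add residual.
    attn_out = self_attn([input_layernorm(h) for h in hidden], target_hidden)
    hidden = add(hidden, attn_out)
    # MLP sub-block: pre-norm, feed-forward, add residual.
    mlp_out = [mlp(post_attention_layernorm(h)) for h in hidden]
    return add(hidden, mlp_out)


def identity(v: Vector) -> Vector:
    return v


def zero_attn(xs: list[Vector], ctx: Sequence[Vector]) -> list[Vector]:
    return [[0.0] * len(x) for x in xs]


def double(v: Vector) -> Vector:
    return [2.0 * x for x in v]


print(decoder_layer_sketch([[1.0, 2.0]], [[9.0, 9.0]], identity, zero_attn, identity, double))
# [[3.0, 6.0]]
```

Only the noise-token stream flows through the residual path here; the context enters solely as extra keys and values inside the attention.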

hidden_size = config.hidden_size
input_layernorm
mlp = Qwen3MLP(config)
post_attention_layernorm
self_attn
nemo_automodel.components.speculative.dflash.draft_qwen3.Qwen3DFlashDecoderLayer.forward(
target_hidden: typing.Optional[torch.Tensor] = None,
hidden_states: typing.Optional[torch.Tensor] = None,
attention_mask: typing.Optional[torch.Tensor] = None,
position_ids: typing.Optional[torch.LongTensor] = None,
past_key_value: typing.Optional[transformers.cache_utils.Cache] = None,
use_cache: typing.Optional[bool] = False,
cache_position: typing.Optional[torch.LongTensor] = None,
position_embeddings: typing.Optional[typing.Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs
) -> torch.Tensor

class nemo_automodel.components.speculative.dflash.draft_qwen3.Qwen3DFlashDraftModel(
config
)

Bases: Qwen3PreTrainedModel

DFlash draft model: a small non-causal Qwen3 stack over [context | noise].

_no_split_modules = ['Qwen3DFlashDecoderLayer']
block_size = config.block_size
fc
hidden_norm
layers
mask_token_id = dflash_config.get('mask_token_id', None)
norm
rotary_emb = Qwen3RotaryEmbedding(config)
target_layer_ids
nemo_automodel.components.speculative.dflash.draft_qwen3.Qwen3DFlashDraftModel.forward(
position_ids: torch.LongTensor,
attention_mask: typing.Optional[torch.Tensor] = None,
noise_embedding: typing.Optional[torch.Tensor] = None,
target_hidden: typing.Optional[torch.Tensor] = None,
past_key_values: typing.Optional[transformers.cache_utils.Cache] = None,
use_cache: bool = False,
**kwargs
) -> torch.Tensor

nemo_automodel.components.speculative.dflash.draft_qwen3.Qwen3DFlashDraftModel.spec_generate(
target: torch.nn.Module,
input_ids: torch.LongTensor,
max_new_tokens: int,
stop_token_ids: typing.Optional[list[int]],
temperature: float
) -> torch.LongTensor

Block-parallel speculative decoding: draft a block, verify it with the target model, and accept the longest prefix of draft tokens that matches the target’s predictions.
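
The accept step can be sketched for the greedy case (temperature 0) under one assumption: the target scores [anchor, *draft_tokens] in a single forward pass and yields one greedy prediction per position. accept_draft is a hypothetical helper. Whether spec_generate uses exactly this rule, in particular at temperature > 0, is not confirmed by this page.

```python
def accept_draft(draft_tokens: list[int], target_tokens: list[int]) -> tuple[list[int], int]:
    """Greedy verification of one block.

    draft_tokens:  the draft's guesses for block positions 1..B-1 (B - 1 tokens)
    target_tokens: the target's greedy prediction after each of [anchor, *draft_tokens]
                   (B tokens, all produced by a single target forward pass)
    Returns (tokens to append, number of accepted draft tokens).
    """
    if len(target_tokens) != len(draft_tokens) + 1:
        raise ValueError("expected exactly one more target prediction than draft tokens")
    accepted = 0
    while accepted < len(draft_tokens) and draft_tokens[accepted] == target_tokens[accepted]:
        accepted += 1
    # Accepted draft tokens equal the target's own picks, and the target's prediction
    # at the first mismatch (or after a fully accepted block) comes for free.
    return target_tokens[: accepted + 1], accepted


print(accept_draft([5, 6, 7], [5, 6, 9, 3]))  # ([5, 6, 9], 2)
```

Each verification step therefore emits between 1 and block_size tokens, and under greedy decoding the output is identical to what the target alone would have produced.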

nemo_automodel.components.speculative.dflash.draft_qwen3.apply_rotary_pos_emb(
q,
k,
cos,
sin,
unsqueeze_dim = 1
)

Apply RoPE where queries (draft block) are a suffix of the key positions.

The keys span [context | noise-block] while the queries are only the noise block, so q is rotated with the trailing q_len slice of the rotary tables and k with the full table.
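
A single-head, pure-Python sketch of the suffix slicing. It assumes the usual Hugging Face RoPE convention (rotate_half, with the frequency vector repeated to fill head_dim) and does not model batch/head dimensions or unsqueeze_dim. rope_tables, rope and apply_rotary_suffix are hypothetical names. The usage check shows the point of the slicing: a query at a given position is rotated exactly like the key at that same position.

```python
import math


def rope_tables(seq_len: int, head_dim: int, base: float = 10000.0):
    """cos/sin tables of shape [seq_len][head_dim] (frequencies repeated twice)."""
    inv_freq = [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]
    cos, sin = [], []
    for pos in range(seq_len):
        freqs = [pos * f for f in inv_freq]
        emb = freqs + freqs
        cos.append([math.cos(a) for a in emb])
        sin.append([math.sin(a) for a in emb])
    return cos, sin


def rotate_half(x: list[float]) -> list[float]:
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def rope(x: list[float], cos_row: list[float], sin_row: list[float]) -> list[float]:
    return [a * c + b * s for a, b, c, s in zip(x, rotate_half(x), cos_row, sin_row)]


def apply_rotary_suffix(q, k, cos, sin):
    """q: q_len noise-block vectors; k: vectors for every key position [context | noise]."""
    q_len = len(q)
    # Queries occupy the last q_len key positions, so they use the trailing slice of the tables.
    q_rot = [rope(x, c, s) for x, c, s in zip(q, cos[-q_len:], sin[-q_len:])]
    k_rot = [rope(x, c, s) for x, c, s in zip(k, cos, sin)]
    return q_rot, k_rot


k = [[1.0, 0.0, 0.5, 2.0], [0.3, -1.0, 2.0, 0.1], [1.5, 0.2, -0.7, 0.9]]
q = k[-2:]  # pretend the last two key positions are the noise block
cos, sin = rope_tables(seq_len=len(k), head_dim=4)
q_rot, k_rot = apply_rotary_suffix(q, k, cos, sin)
print(q_rot == k_rot[-2:])  # True
```

Rotating q with the head of the table instead would assign the draft block the positions of the earliest context tokens, which breaks the relative-position geometry between queries and keys.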

nemo_automodel.components.speculative.dflash.draft_qwen3.build_target_layer_ids(
num_target_layers: int,
num_draft_layers: int
) -> list[int]

Pick num_draft_layers target layers spread across the target’s depth.
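
This page does not say exactly how the layers are spread, so the following is only one plausible rule, labeled as an assumption: pick indices evenly from the first to the last target layer, or the middle layer when only one is requested. evenly_spaced_layer_ids is hypothetical, and the library's build_target_layer_ids may choose different endpoints or rounding.

```python
def evenly_spaced_layer_ids(num_target_layers: int, num_draft_layers: int) -> list[int]:
    """Hypothetical spreading rule: num_draft_layers distinct layer indices from first to last."""
    if not 1 <= num_draft_layers <= num_target_layers:
        raise ValueError("need 1 <= num_draft_layers <= num_target_layers")
    if num_draft_layers == 1:
        return [num_target_layers // 2]  # a single layer: take the middle one
    last = num_target_layers - 1
    # Integer arithmetic keeps the indices exact; the step is >= 1, so they stay distinct.
    return [i * last // (num_draft_layers - 1) for i in range(num_draft_layers)]


print(evenly_spaced_layer_ids(36, 5))  # [0, 8, 17, 26, 35]
```

Spreading the taps across depth gives the draft a mix of shallow, token-level features and deep, more predictive ones, rather than only the final layer's view.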

nemo_automodel.components.speculative.dflash.draft_qwen3.extract_context_feature(
hidden_states: list[torch.Tensor],
layer_ids: list[int]
) -> torch.Tensor

Concatenate the selected target layers’ hidden states along the feature dim.

hidden_states follows HF’s output_hidden_states convention where index 0 is the embedding output, so layer i’s output is at index i + 1.
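
The index shift and the concatenation can be sketched with nested lists standing in for tensors. extract_context_feature_sketch is a hypothetical stand-in that handles a single sequence; the real function takes tensors and concatenates along the last (feature) dimension.

```python
def extract_context_feature_sketch(hidden_states, layer_ids):
    """hidden_states[s][t]: feature vector of token t after stage s (s == 0 is the embeddings)."""
    # Layer i's output sits at index i + 1 because index 0 holds the embedding output.
    selected = [hidden_states[i + 1] for i in layer_ids]
    num_tokens = len(selected[0])
    # Concatenate the selected layers' vectors token by token (the "feature dim").
    return [[x for layer in selected for x in layer[t]] for t in range(num_tokens)]


emb = [[0.0], [0.0]]
layer0, layer1, layer2 = [[1.0], [10.0]], [[2.0], [20.0]], [[3.0], [30.0]]
print(extract_context_feature_sketch([emb, layer0, layer1, layer2], [0, 2]))
# [[1.0, 3.0], [10.0, 30.0]]
```

With PyTorch tensors the same operation would presumably be written as torch.cat([hidden_states[i + 1] for i in layer_ids], dim=-1), so the feature width grows to len(layer_ids) times the target's hidden size.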

nemo_automodel.components.speculative.dflash.draft_qwen3.sample(
logits: torch.Tensor,
temperature: float = 0.0
) -> torch.Tensor

Greedy (temperature ~ 0) or temperature sampling over the last dim.
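
A pure-Python sketch of the two branches for a single row of logits. The 1e-5 cutoff standing in for "temperature ~ 0" is a hypothetical choice, as is the name sample_sketch; the real function operates on tensors over their last dimension.

```python
import math
import random
from typing import Optional


def sample_sketch(logits: list[float], temperature: float = 0.0,
                  rng: Optional[random.Random] = None) -> int:
    """Return a token index from one row of logits."""
    if temperature < 1e-5:  # hypothetical cutoff for "temperature ~ 0"
        # Greedy: the index of the largest logit (first one on ties).
        return max(range(len(logits)), key=logits.__getitem__)
    if rng is None:
        rng = random.Random()
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    # Softmax numerators; subtracting the max avoids overflow, and choices() normalizes.
    weights = [math.exp(x - peak) for x in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


print(sample_sketch([0.1, 2.0, 0.5]))  # 1
```

Dividing by the temperature before the softmax sharpens the distribution as the temperature falls toward 0, which is why very small temperatures are treated as plain argmax.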