nemo_automodel.components.speculative.dflash.core


DFlash online training wrapper.

Ported from SpecForge’s specforge/core/dflash.py. DFlashTrainerModule samples a set of anchor positions per sequence and builds one parallel draft block per anchor; the block’s first token is the real anchor token and the rest are MASK. It then runs the draft model under a bespoke block attention mask and computes a block-wise cross-entropy loss against the ground-truth continuation of each anchor.
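The block attention mask is only named here, not specified. The sketch below is a minimal dense version under stated assumptions, not the module's actual mask. It assumes that a draft query may see context positions strictly before its block's anchor, may see every draft token of its own block (bidirectionally), and sees nothing at all if its block is invalid. The function name `dflash_block_mask` is hypothetical. With the default `attention_backend='flex_attention'`, the real module would more likely express such rules as a mask function than as a materialized tensor.

```python
import torch


def dflash_block_mask(anchor_positions: torch.Tensor,
                      block_keep_mask: torch.Tensor,
                      ctx_len: int,
                      block_size: int) -> torch.Tensor:
    """Hypothetical dense DFlash-style mask (True = query may attend to key).

    anchor_positions: [B, N] long, block_keep_mask: [B, N] bool.
    Returns a bool tensor of shape [B, N * block_size, ctx_len + N * block_size]:
    the first ctx_len keys are context positions, the rest are draft tokens.
    """
    bsz, n_blocks = anchor_positions.shape
    q_len = n_blocks * block_size
    device = anchor_positions.device
    q_block = torch.arange(q_len, device=device) // block_size  # block id per query

    # Assumed context rule: key position < anchor of the query's block.
    anchor_per_q = anchor_positions[:, q_block]                  # [B, q_len]
    kv_ctx = torch.arange(ctx_len, device=device)
    ctx_mask = kv_ctx.view(1, 1, -1) < anchor_per_q.unsqueeze(-1)

    # Assumed draft rule: attend to every token of the same block only.
    same_block = q_block.view(-1, 1) == q_block.view(1, -1)     # [q_len, q_len]
    draft_mask = same_block.unsqueeze(0).expand(bsz, -1, -1)

    mask = torch.cat([ctx_mask, draft_mask], dim=-1)
    # Queries belonging to invalid (padding) blocks attend to nothing.
    keep_q = block_keep_mask[:, q_block]                         # [B, q_len]
    return mask & keep_q.unsqueeze(-1)
```

For example, with anchors `[[2, 5]]`, keep mask `[[True, False]]`, an 8-token context, and `block_size=3`, the three queries of the first block see context positions 0 and 1 plus their own block. Every query of the second block is fully masked.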

Module Contents

Classes

| Name | Description |
| --- | --- |
| DFlashStepMetrics | Per-step training outputs for the DFlash draft. |
| DFlashTrainerModule | DFlash online training wrapper with block-wise CE loss. |
| NoValidAnchorsError | Raised when a batch has no sample long enough to form a DFlash block. |

API

class nemo_automodel.components.speculative.dflash.core.DFlashStepMetrics(
loss: torch.Tensor,
accuracy: torch.Tensor,
valid_tokens: torch.Tensor
)
Dataclass

Per-step training outputs for the DFlash draft.

accuracy: Tensor
loss: Tensor
valid_tokens: Tensor
class nemo_automodel.components.speculative.dflash.core.DFlashTrainerModule(
draft_model: nemo_automodel.components.speculative.dflash.draft_qwen3.Qwen3DFlashDraftModel,
target_lm_head: torch.nn.Module,
target_embed_tokens: torch.nn.Module,
mask_token_id: int,
block_size: int = 16,
attention_backend: str = 'flex_attention',
num_anchors: int = 512,
loss_decay_gamma: typing.Optional[float] = None
)

Bases: Module

DFlash online training wrapper with block-wise CE loss.
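The constructor exposes `loss_decay_gamma` (default `None`), but this page does not say how it shapes the block-wise loss. The sketch assumes a form for illustration only. With `gamma=None`, every predicted slot counts equally. Otherwise, slot k (k ≥ 1) is weighted by exp(-(k - 1) / gamma), so later and harder slots in a block contribute less. Slot 0 holds the given anchor token and is assumed to be unsupervised. `block_position_weights` is a hypothetical helper.

```python
from typing import Optional

import torch


def block_position_weights(block_size: int,
                           gamma: Optional[float] = None) -> torch.Tensor:
    """Per-slot loss weights for one draft block (assumed exponential decay)."""
    k = torch.arange(block_size, dtype=torch.float32)
    if gamma is None:
        weights = torch.ones(block_size)
    else:
        # Slot 1 gets weight 1, slot 2 gets exp(-1/gamma), and so on.
        weights = torch.exp(-(k - 1).clamp(min=0) / gamma)
    weights[0] = 0.0  # the anchor slot is given as input, not predicted
    return weights
```

These weights would broadcast over a `[B, N, block_size]` tensor of per-slot losses before the masked mean is taken.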

loss_fn
nemo_automodel.components.speculative.dflash.core.DFlashTrainerModule._create_noise_embed(
input_ids,
anchor_positions,
block_keep_mask
)

Embed each block as [anchor_token, MASK, MASK, ...]; invalid blocks are embedded as all MASK.
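The sketch below illustrates the token layout that this method embeds. The real method takes only `input_ids`, `anchor_positions`, and `block_keep_mask`, and presumably reads the embedding, mask id, and block size from the module, which receives `target_embed_tokens`, `mask_token_id`, and `block_size` at construction. Here they are passed explicitly, and the standalone function `create_noise_embed` is hypothetical.

```python
import torch
import torch.nn as nn


def create_noise_embed(input_ids: torch.Tensor,
                       anchor_positions: torch.Tensor,
                       block_keep_mask: torch.Tensor,
                       embed_tokens: nn.Module,
                       mask_token_id: int,
                       block_size: int):
    """Returns (embeddings [B, N * block_size, H], block_ids [B, N, block_size]).

    input_ids [B, S] long, anchor_positions [B, N] long, block_keep_mask [B, N] bool.
    """
    bsz, n_blocks = anchor_positions.shape
    # Start with every slot of every block set to MASK.
    block_ids = torch.full((bsz, n_blocks, block_size), mask_token_id,
                           dtype=input_ids.dtype, device=input_ids.device)
    # Slot 0 of each valid block is the real token at its anchor position.
    anchor_tokens = torch.gather(input_ids, 1, anchor_positions)
    block_ids[:, :, 0] = torch.where(block_keep_mask, anchor_tokens,
                                     torch.full_like(anchor_tokens, mask_token_id))
    embeddings = embed_tokens(block_ids.view(bsz, n_blocks * block_size))
    return embeddings, block_ids
```

Because all blocks are flattened into one sequence, the draft model can process every block in a single parallel forward pass.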

nemo_automodel.components.speculative.dflash.core.DFlashTrainerModule._create_position_ids(
anchor_positions: torch.Tensor
) -> torch.Tensor

Absolute position ids for the parallel draft blocks (anchor + offset).
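"Anchor + offset" means that the block anchored at position a occupies absolute positions a, a + 1, ..., a + block_size - 1. Presumably this gives each draft token the position it would have in the original sequence. The minimal sketch below takes `block_size` as an explicit argument, which is an assumption, since the documented method only takes `anchor_positions`.

```python
import torch


def create_position_ids(anchor_positions: torch.Tensor, block_size: int) -> torch.Tensor:
    """anchor_positions [B, N] -> absolute position ids [B, N * block_size]."""
    offsets = torch.arange(block_size, device=anchor_positions.device)
    pos = anchor_positions.unsqueeze(-1) + offsets  # [B, N, block_size]
    return pos.reshape(anchor_positions.size(0), -1)
```

For example, anchors `[[2, 7]]` with `block_size=3` give `[[2, 3, 4, 7, 8, 9]]`.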

nemo_automodel.components.speculative.dflash.core.DFlashTrainerModule._sample_anchor_positions(
seq_len: int,
loss_mask: torch.Tensor,
device: torch.device
) -> typing.Tuple[torch.Tensor, torch.Tensor]

Randomly sample anchor positions per sample; returns (anchors, keep_mask).
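A minimal sketch of per-sample anchor sampling follows. It assumes that position a is a candidate when `loss_mask[b, a]` is set and a full block a ... a + block_size - 1 fits inside the sequence. This candidate rule is an assumption and may differ from the module's. Rows with fewer candidates than the batch maximum are padded, and `keep_mask` marks the real anchors. `sample_anchor_positions` and its `num_anchors`/`generator` arguments are illustrative. The sketch raises a plain `ValueError`, whereas the real module raises `NoValidAnchorsError`, a `ValueError` subclass.

```python
import torch


def sample_anchor_positions(loss_mask: torch.Tensor, block_size: int,
                            num_anchors: int, generator=None):
    """Return (anchors [B, n] long, keep_mask [B, n] bool) with n <= num_anchors."""
    bsz, seq_len = loss_mask.shape
    max_anchor = seq_len - block_size  # last start index where a block fits
    if max_anchor < 0:
        raise ValueError("sequence shorter than one DFlash block")
    candidates = loss_mask[:, : max_anchor + 1] > 0
    counts = candidates.sum(dim=1)
    n = min(num_anchors, int(counts.max()))
    if n == 0:
        raise ValueError("no sample can host a DFlash block")

    # Random score per position; non-candidates get +inf so they sort last.
    scores = torch.rand(candidates.shape, generator=generator, device=loss_mask.device)
    scores = scores.masked_fill(~candidates, float("inf"))
    picked = scores.argsort(dim=1)[:, :n]  # real candidates first, random order
    keep = (torch.arange(n, device=loss_mask.device).unsqueeze(0)
            < counts.clamp(max=n).unsqueeze(1))
    # Push padding past every real index, sort ascending, then park padding at 0.
    picked = picked.masked_fill(~keep, seq_len)
    anchors, _ = picked.sort(dim=1)
    anchors = anchors.masked_fill(~keep, 0)
    return anchors, keep
```

Sorting only after the padding has been pushed past `seq_len` keeps every real anchor in front. As a result, `keep_mask` stays a simple prefix mask per row.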

nemo_automodel.components.speculative.dflash.core.DFlashTrainerModule.forward(
input_ids: torch.Tensor,
hidden_states: torch.Tensor,
loss_mask: torch.Tensor
) -> nemo_automodel.components.speculative.dflash.core.DFlashStepMetrics

Parallel block-wise training forward pass.
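One step of this forward pass is gathering the "ground-truth continuation of each anchor" as targets. The sketch assumes that slot k of the block anchored at a is supervised with `input_ids[a + k]`. Slot 0 (the given anchor), slots past the sequence end, slots whose `loss_mask` is unset, and slots of invalid blocks all receive weight 0. The function name `gather_block_targets` is hypothetical.

```python
import torch


def gather_block_targets(input_ids: torch.Tensor,
                         loss_mask: torch.Tensor,
                         anchor_positions: torch.Tensor,
                         block_keep_mask: torch.Tensor,
                         block_size: int):
    """Returns (targets [B, N, K] long, weights [B, N, K] float) with K = block_size."""
    bsz, seq_len = input_ids.shape
    n_blocks = anchor_positions.size(1)
    offsets = torch.arange(block_size, device=input_ids.device)
    idx = anchor_positions.unsqueeze(-1) + offsets       # [B, N, K]
    in_range = idx < seq_len
    safe_idx = idx.clamp(max=seq_len - 1).view(bsz, -1)  # clamp so gather stays in bounds
    targets = torch.gather(input_ids, 1, safe_idx).view(bsz, n_blocks, block_size)
    supervised = torch.gather(loss_mask, 1, safe_idx).view(bsz, n_blocks, block_size) > 0
    weights = (in_range & supervised & block_keep_mask.unsqueeze(-1)).float()
    weights[..., 0] = 0.0  # the anchor slot is an input, not a prediction
    return targets, weights
```

Clamped indices still produce a target value, but their weight is 0, so they never affect the loss.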

class nemo_automodel.components.speculative.dflash.core.NoValidAnchorsError()

Bases: ValueError

Raised when a batch has no sample long enough to form a DFlash block.

A DFlash anchor needs at least block_size + 1 supervised tokens (the anchor plus its block). Real datasets usually contain some short conversations, so the training loop catches this error and skips the offending micro-batch instead of aborting the run.
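The skip-don't-abort behavior can be sketched as follows. In real code, the exception would be imported from `nemo_automodel.components.speculative.dflash.core`. Here a local stand-in keeps the example self-contained. `train_epoch` and `step_fn` are hypothetical names for the loop and the per-micro-batch step.

```python
class NoValidAnchorsError(ValueError):
    """Local stand-in for the documented exception."""


def train_epoch(micro_batches, step_fn):
    """Run step_fn on each micro-batch; skip those with no valid DFlash anchors.

    Any other exception still propagates and aborts the run.
    """
    losses, skipped = [], 0
    for batch in micro_batches:
        try:
            losses.append(step_fn(batch))
        except NoValidAnchorsError:
            skipped += 1  # batch too short to form a DFlash block; move on
    return losses, skipped
```

Catching the subclass rather than `ValueError` keeps unrelated value errors from being silently swallowed.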