nemo_automodel.components.speculative.dflash.domino_core


Domino online training wrapper.

Ported from SpecForge’s specforge/core/domino.py (sgl-project/SpecForge#571).

Domino extends the parallel DFlash draft backbone with a lightweight causal correction head. DFlash drafts a whole block in a single non-causal forward, so each predicted position is blind to the (drafted) tokens earlier in its own block. The Domino head fixes that: a GRU encodes a causal state from the block’s previous tokens, and a low-rank projection of [backbone hidden | GRU state] produces a logit delta that is added to the parallel base logits.
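A minimal PyTorch sketch of the correction idea, for illustration only. ToyDominoHead, its layer sizes, and the use of a plain embedding table for the previous tokens are assumptions. The real module keeps its head parameters (prefix_gru, embed_proj) on the DFlash draft model, and their exact layout is not reproduced here. The GRU runs left to right over the block's previous tokens, so the delta at position k depends only on tokens that precede position k's target.

```python
import torch
import torch.nn as nn


class ToyDominoHead(nn.Module):
    """Hypothetical causal correction head: GRU state + low-rank logit delta."""

    def __init__(self, vocab_size: int, hidden_size: int, gru_size: int, rank: int):
        super().__init__()
        # Stand-in for the target model's token embeddings (assumption).
        self.embed = nn.Embedding(vocab_size, hidden_size)
        # Causal encoder over the block's previous tokens.
        self.gru = nn.GRU(hidden_size, gru_size, batch_first=True)
        # Low-rank projection: [backbone hidden | GRU state] -> rank -> vocab.
        self.down = nn.Linear(hidden_size + gru_size, rank, bias=False)
        self.up = nn.Linear(rank, vocab_size, bias=False)

    def forward(self, base_logits: torch.Tensor, hidden: torch.Tensor,
                prev_ids: torch.Tensor) -> torch.Tensor:
        # base_logits: (B, K, V); hidden: (B, K, H); prev_ids: (B, K)
        gru_state, _ = self.gru(self.embed(prev_ids))  # (B, K, G), causal over K
        delta = self.up(self.down(torch.cat([hidden, gru_state], dim=-1)))
        return base_logits + delta  # refined logits = parallel base + correction
```

Because the GRU is causal along the block axis, perturbing prev_ids at position j changes the refined logits only at positions j and later; the parallel backbone's hidden states are left untouched.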

Training jointly supervises two sets of logits with a base-anchor curriculum:

loss = (1 - lambda_base) * final_loss + lambda_base * base_loss

final_loss is the cross-entropy (CE) on the Domino-refined logits, and base_loss is the CE on the backbone-only base logits. lambda_base decays from its start value to 0 over the first decay_ratio fraction of training, so early steps keep the parallel backbone strong and later steps let the correction head take over.

DominoTrainerModule reuses the DFlash anchor sampling, noise-block construction, and block attention mask (it subclasses DFlashTrainerModule); only the head, the dual-logit loss, and the metrics are Domino-specific. The Domino head parameters (prefix_gru, embed_proj) live on the DFlash draft model and are enabled via dflash_config.projector_type='domino'.

Module Contents

Classes

| Name | Description |
|---|---|
| DominoStepMetrics | Per-step training outputs for the Domino draft. |
| DominoTrainerModule | Domino online training wrapper: DFlash backbone + causal correction head. |

Functions

| Name | Description |
|---|---|
| compute_accept_len | Per-block acceptance length: consecutive correct predictions from position 0. |
| get_lambda_base | Base-anchor curriculum weight at global_step. |

API

class nemo_automodel.components.speculative.dflash.domino_core.DominoStepMetrics(
loss: torch.Tensor,
accuracy: torch.Tensor,
valid_tokens: torch.Tensor,
final_loss: torch.Tensor,
base_loss: torch.Tensor,
base_accuracy: torch.Tensor,
accept_len: torch.Tensor,
base_accept_len: torch.Tensor,
lambda_base: torch.Tensor
)
Dataclass

Per-step training outputs for the Domino draft.

loss/accuracy/valid_tokens mirror DFlashStepMetrics so the DFlash training loop consumes them unchanged. The remaining fields are diagnostics for the two supervised logits and the curriculum weight.

accept_len: Tensor
accuracy: Tensor
base_accept_len: Tensor
base_accuracy: Tensor
base_loss: Tensor
final_loss: Tensor
lambda_base: Tensor
loss: Tensor
valid_tokens: Tensor
class nemo_automodel.components.speculative.dflash.domino_core.DominoTrainerModule(
draft_model: nemo_automodel.components.speculative.dflash.draft_qwen3.Qwen3DFlashDraftModel,
target_lm_head: torch.nn.Module,
target_embed_tokens: torch.nn.Module,
mask_token_id: int,
block_size: int = 16,
attention_backend: str = 'flex_attention',
num_anchors: int = 512,
loss_decay_gamma: typing.Optional[float] = None,
shift_label: bool = False
)

Bases: DFlashTrainerModule

Domino online training wrapper: DFlash backbone + causal correction head.

_suffix_start: int

First block position that receives the Domino correction.

With shift_label, block position k predicts the token at anchor+1+k, so position 0 is already a real next-token prediction; otherwise, position 0 is the clean anchor and is excluded. pure_draft_prefix_len reserves additional leading positions that keep the backbone-only (uncorrected) base logits.
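The rule above fits in a one-liner. This is a hypothetical reconstruction (suffix_start_sketch is not a library function), assuming pure_draft_prefix_len simply adds to the first corrected position:

```python
def suffix_start_sketch(shift_label: bool, pure_draft_prefix_len: int = 0) -> int:
    # Without shift_label, position 0 is the clean anchor, so the first real
    # prediction is position 1; with shift_label it is position 0.
    first_real_prediction = 0 if shift_label else 1
    # Assumption: the pure-draft prefix is reserved after the first real prediction.
    return first_real_prediction + pure_draft_prefix_len
```

For example, with block_size=16, shift_label=False, and pure_draft_prefix_len=2, this sketch would apply the correction to positions 3 through 15.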

nemo_automodel.components.speculative.dflash.domino_core.DominoTrainerModule._apply_domino_head(
base_logits4d: torch.Tensor,
hidden4d: torch.Tensor,
prev_ids: torch.Tensor,
target_ids: torch.Tensor
) -> torch.Tensor

Add the GRU-conditioned low-rank correction to the suffix base logits.
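A minimal sketch of restricting the correction to the suffix, assuming the delta has already been computed for positions suffix_start onward. apply_delta_to_suffix is a hypothetical helper, not the module's method:

```python
import torch


def apply_delta_to_suffix(base_logits4d: torch.Tensor, delta: torch.Tensor,
                          suffix_start: int) -> torch.Tensor:
    # base_logits4d: (B, N, K, V); delta: (B, N, K - suffix_start, V)
    prefix = base_logits4d[..., :suffix_start, :]  # keeps pure base logits
    suffix = base_logits4d[..., suffix_start:, :] + delta  # corrected positions
    # torch.cat builds a new tensor, so the caller's base logits stay unchanged.
    return torch.cat([prefix, suffix], dim=-2)
```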

nemo_automodel.components.speculative.dflash.domino_core.DominoTrainerModule._build_domino_head_inputs(
input_ids: torch.Tensor,
anchor_positions: torch.Tensor,
target_ids: torch.Tensor,
output_hidden: torch.Tensor
) -> typing.Tuple[torch.Tensor, torch.Tensor]

Reshape backbone hidden states and gather the block’s previous tokens.

prev_ids[..., k] is the token that precedes the target at block position k. With shift_label, it is the input token at anchor+k, so the GRU consumes ground-truth context; otherwise, prev_ids is the target sequence itself.
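A sketch of the shift_label gather, assuming anchor_positions holds the sequence index of each block's anchor and that indices past the end of the sequence are clamped to the last token (the real method's boundary handling may differ). gather_prev_ids is hypothetical:

```python
import torch


def gather_prev_ids(input_ids: torch.Tensor, anchor_positions: torch.Tensor,
                    target_ids: torch.Tensor, shift_label: bool) -> torch.Tensor:
    # input_ids: (B, S); anchor_positions: (B, N); target_ids: (B, N, K)
    if not shift_label:
        return target_ids  # the GRU consumes the target sequence itself
    K = target_ids.size(-1)
    offsets = torch.arange(K, device=input_ids.device)  # (K,)
    positions = anchor_positions.unsqueeze(-1) + offsets  # (B, N, K): anchor + k
    positions = positions.clamp(max=input_ids.size(1) - 1)  # assumption: clamp at end
    B, N, _ = positions.shape
    gathered = torch.gather(input_ids, 1, positions.reshape(B, N * K))
    return gathered.reshape(B, N, K)
```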

nemo_automodel.components.speculative.dflash.domino_core.DominoTrainerModule._compute_extra_metrics(
pred_ids: torch.Tensor,
flat_base_logits: torch.Tensor,
flat_targets: torch.Tensor,
binary_eval_mask: torch.Tensor,
actual_token_count: torch.Tensor,
target_ids: torch.Tensor,
eval_weight_mask: torch.Tensor,
final_loss: torch.Tensor,
base_loss: torch.Tensor,
lambda_base: float
) -> typing.Dict[str, torch.Tensor]

Diagnostics for both the refined and the base logits (acceptance length, base accuracy), computed without gradients.

nemo_automodel.components.speculative.dflash.domino_core.DominoTrainerModule._compute_weighted_losses(
final_logits: torch.Tensor,
base_logits: torch.Tensor,
target_ids: torch.Tensor,
weight_mask: torch.Tensor,
lambda_base: float
) -> typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Decay-weighted CE on both logits, mixed by the curriculum weight.
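A sketch of the dual-logit objective. mixed_block_loss and decay_weights are hypothetical names, and the exponential form exp(-k / gamma) for loss_decay_gamma is an assumption about how later block positions are down-weighted, not a confirmed detail of the DFlash implementation. weight_mask is taken to already combine the validity mask with any decay:

```python
from typing import Optional, Tuple

import torch
import torch.nn.functional as F


def decay_weights(block_size: int, gamma: Optional[float]) -> torch.Tensor:
    # Hypothetical per-position weights; gamma=None disables decay.
    if gamma is None:
        return torch.ones(block_size)
    k = torch.arange(block_size, dtype=torch.float32)
    return torch.exp(-k / gamma)


def mixed_block_loss(final_logits: torch.Tensor, base_logits: torch.Tensor,
                     target_ids: torch.Tensor, weight_mask: torch.Tensor,
                     lambda_base: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # final_logits / base_logits: (B, N, K, V); target_ids / weight_mask: (B, N, K)
    V = final_logits.size(-1)
    w = weight_mask.reshape(-1).float()
    denom = w.sum().clamp(min=1e-6)  # avoid division by zero on empty batches

    def weighted_ce(logits: torch.Tensor) -> torch.Tensor:
        ce = F.cross_entropy(logits.reshape(-1, V), target_ids.reshape(-1), reduction="none")
        return (ce * w).sum() / denom

    final_loss = weighted_ce(final_logits)
    base_loss = weighted_ce(base_logits)
    loss = (1.0 - lambda_base) * final_loss + lambda_base * base_loss
    return loss, final_loss, base_loss
```

With uniform weights, each term reduces to the ordinary mean CE over valid positions.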

nemo_automodel.components.speculative.dflash.domino_core.DominoTrainerModule.forward(
input_ids: torch.Tensor,
hidden_states: torch.Tensor,
loss_mask: torch.Tensor,
lambda_base: float = 0.0
) -> nemo_automodel.components.speculative.dflash.domino_core.DominoStepMetrics

Parallel block-wise training forward with the Domino correction head.

nemo_automodel.components.speculative.dflash.domino_core.compute_accept_len(
pred_ids_4d: torch.Tensor,
target_ids_4d: torch.Tensor,
valid_mask_4d: torch.Tensor
) -> torch.Tensor

Per-block acceptance length: consecutive correct predictions from position 0.

Invalid (masked-out) positions count as correct so they never truncate a block prematurely; the trailing valid mask then zeros their contribution.
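The masking trick maps onto a running product. A minimal sketch, assuming (B, N, K) inputs and a per-block integer result; accept_len_sketch is hypothetical, and the library function may reduce or cast differently:

```python
import torch


def accept_len_sketch(pred_ids_4d: torch.Tensor, target_ids_4d: torch.Tensor,
                      valid_mask_4d: torch.Tensor) -> torch.Tensor:
    # Shapes: (B, N, K); returns (B, N) acceptance lengths.
    valid = valid_mask_4d.bool()
    # Masked-out positions count as correct so they never break the run.
    correct = (pred_ids_4d == target_ids_4d) | ~valid
    # 1 up to (and excluding) the first miss, then 0 for the rest of the block.
    run = torch.cumprod(correct.long(), dim=-1)
    # Count only valid positions inside the unbroken run.
    return (run * valid.long()).sum(dim=-1)
```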

nemo_automodel.components.speculative.dflash.domino_core.get_lambda_base(
global_step: int,
total_steps: int,
lambda_start: float = 1.0,
decay_ratio: float = 0.5
) -> float

Base-anchor curriculum weight at global_step.

lambda_base starts at lambda_start and decays linearly to 0 over the first decay_ratio fraction of total_steps, then stays at 0. The result is clamped to [0, 1].
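A minimal sketch of the schedule as described, assuming the decay length is total_steps * decay_ratio rounded down (lambda_base_sketch is hypothetical, and the library's rounding at the boundary may differ):

```python
def lambda_base_sketch(global_step: int, total_steps: int,
                       lambda_start: float = 1.0, decay_ratio: float = 0.5) -> float:
    # Assumption: decay length rounded down, with at least one step.
    decay_steps = max(1, int(total_steps * decay_ratio))
    progress = min(global_step / decay_steps, 1.0)  # 0 -> 1, then held at 1
    value = lambda_start * (1.0 - progress)  # linear decay to 0
    return min(max(value, 0.0), 1.0)  # clamp to [0, 1]
```

With total_steps=1000 and the defaults, this sketch gives lambda_base = 1.0 at step 0, 0.5 at step 250, and 0.0 from step 500 onward, so at step 250 the loss weights final_loss and base_loss equally.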