nemo_automodel.components.speculative.dflash.domino_core
Domino online training wrapper.
Ported from SpecForge’s specforge/core/domino.py (sgl-project/SpecForge#571).
Domino extends the parallel DFlash draft backbone with a lightweight causal
correction head. DFlash drafts a whole block in a single non-causal forward pass, so
each predicted position is blind to the (drafted) tokens earlier in its own
block. The Domino head fixes that: a GRU encodes a causal state from the block’s
previous tokens, and a low-rank projection of [backbone hidden | GRU state]
produces a logit delta that is added to the parallel base logits.
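To make the mechanism concrete, here is a minimal NumPy sketch of the correction for one block. It is not the SpecForge or NeMo AutoModel code. The function names (gru_step, domino_refine), the single-bias GRU gate layout, and the down/up rank-R factorization are illustrative assumptions, and the real prefix_gru and embed_proj modules may be wired differently. What the sketch does show is the causal structure: the GRU state at position k has seen only the previous tokens for positions 0..k, and a low-rank projection of [hidden | state] yields a delta that is added to the base logits.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_step(x, h, W, U, b):
    """One standard GRU update. W: (3H, E), U: (3H, H), b: (3H,)."""
    H = h.shape[0]
    gx = W @ x + b  # input contributions to reset, update, candidate
    gh = U @ h      # recurrent contributions
    r = sigmoid(gx[:H] + gh[:H])              # reset gate
    z = sigmoid(gx[H:2 * H] + gh[H:2 * H])    # update gate
    n = np.tanh(gx[2 * H:] + r * gh[2 * H:])  # candidate state
    return (1.0 - z) * n + z * h


def domino_refine(base_logits, hidden, prev_emb, gru, down, up):
    """Refine one block of parallel base logits with a causal correction.

    base_logits: (K, V) backbone logits for the block
    hidden:      (K, D) backbone hidden states
    prev_emb:    (K, E) embedding of each position's previous token
    gru:         (W, U, b) hypothetical GRU weights with state size H
    down, up:    (D + H, R) and (R, V), a hypothetical rank-R factorization
    """
    W, U, b = gru
    h = np.zeros(U.shape[1])
    states = []
    for k in range(prev_emb.shape[0]):
        # State k has only consumed previous tokens 0..k, so it is causal.
        h = gru_step(prev_emb[k], h, W, U, b)
        states.append(h)
    feats = np.concatenate([hidden, np.stack(states)], axis=-1)  # (K, D + H)
    delta = (feats @ down) @ up  # low-rank logit delta, (K, V)
    return base_logits + delta


rng = np.random.default_rng(0)
K, D, E, H, R, V = 4, 8, 6, 5, 2, 10
gru = (rng.normal(size=(3 * H, E)), rng.normal(size=(3 * H, H)), np.zeros(3 * H))
down, up = rng.normal(size=(D + H, R)), rng.normal(size=(R, V))
base = rng.normal(size=(K, V))
hidden = rng.normal(size=(K, D))
prev = rng.normal(size=(K, E))
refined = domino_refine(base, hidden, prev, gru, down, up)
```

In this sketch, changing the last previous-token embedding leaves the refined logits of earlier positions untouched, and zeroing `up` returns the base logits exactly.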
Training jointly supervises two logits with a base-anchor curriculum:
    loss = (1 - lambda_base) * final_loss + lambda_base * base_loss
final_loss is the cross-entropy (CE) on the Domino-refined logits, and base_loss is the CE on
the backbone-only base logits. lambda_base decays linearly from its start value to 0
over the first decay_ratio fraction of training, so early steps keep the
parallel backbone strong and later steps let the correction head take over.
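The curriculum mix is a convex combination of the two losses. A minimal sketch, with the function name domino_loss chosen for illustration and the losses treated as plain floats rather than tensors:

```python
def domino_loss(final_loss: float, base_loss: float, lambda_base: float) -> float:
    """Convex mix of the refined-logit CE and the backbone-only CE."""
    if not 0.0 <= lambda_base <= 1.0:
        raise ValueError("lambda_base must lie in [0, 1]")
    # lambda_base = 1 trains on base_loss only; lambda_base = 0 on final_loss only.
    return (1.0 - lambda_base) * final_loss + lambda_base * base_loss
```

At the start of training (large lambda_base) the gradient is dominated by the backbone-only term; once the weight reaches 0, only the refined logits are supervised.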
DominoTrainerModule reuses the DFlash anchor sampling, noise-block
construction, and block attention mask (it subclasses DFlashTrainerModule);
only the head, the dual-logit loss, and the metrics are Domino-specific. The
Domino head parameters (prefix_gru, embed_proj) live on the DFlash draft
model and are enabled via dflash_config.projector_type='domino'.
API
Per-step training outputs for the Domino draft.
The loss, accuracy, and valid_tokens fields mirror DFlashStepMetrics, so the
DFlash training loop consumes them unchanged. The remaining fields are
diagnostics for the two supervised logits and the curriculum weight.
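The compatibility claim amounts to duck typing: code that reads only loss, accuracy, and valid_tokens works on either metrics object. In this sketch the class name HypotheticalDominoMetrics, the diagnostic field names, the float types, and the log_step helper are all assumptions; only the three shared field names come from the description above.

```python
from dataclasses import dataclass


@dataclass
class HypotheticalDominoMetrics:
    # Fields shared with DFlashStepMetrics.
    loss: float
    accuracy: float
    valid_tokens: int
    # Domino-only diagnostics; these names are illustrative guesses.
    final_loss: float = 0.0
    base_loss: float = 0.0
    base_accuracy: float = 0.0
    lambda_base: float = 0.0


def log_step(metrics) -> str:
    """A loop written against DFlash metrics reads only the shared fields."""
    return (f"loss={metrics.loss:.3f} acc={metrics.accuracy:.3f} "
            f"tokens={metrics.valid_tokens}")
```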
Bases: DFlashTrainerModule
Domino online training wrapper: DFlash backbone + causal correction head.
First block position that receives the Domino correction.
With shift_label, block position k predicts the token at anchor+1+k, so position 0 is
already a real next-token prediction; otherwise position 0 is the clean
anchor and is excluded. pure_draft_prefix_len reserves additional
leading positions for the backbone-only (uncorrected) base logits.
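Read literally, the rule reduces to a small offset calculation. This sketch assumes the reserved prefix is counted on top of the first predicted position (the text says "additional" leading positions); the function name is hypothetical.

```python
def correction_start(shift_label: bool, pure_draft_prefix_len: int = 0) -> int:
    """First block position that receives the Domino delta (sketch)."""
    if pure_draft_prefix_len < 0:
        raise ValueError("pure_draft_prefix_len must be non-negative")
    # Without shift_label, position 0 is the clean anchor and is never predicted.
    first_predicted = 0 if shift_label else 1
    # Leading positions reserved for backbone-only base logits are skipped too.
    return first_predicted + pure_draft_prefix_len
```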
Add the GRU-conditioned low-rank correction to the suffix base logits.
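A sketch of splitting the block at the correction start: positions before `start` keep their base logits and only the suffix receives the delta. The name refine_suffix and the (..., K, V) layout are assumptions for illustration.

```python
import numpy as np


def refine_suffix(base_logits, suffix_delta, start):
    """base_logits: (..., K, V); suffix_delta: (..., K - start, V)."""
    prefix = base_logits[..., :start, :]  # backbone-only positions, untouched
    suffix = base_logits[..., start:, :]
    if suffix_delta.shape != suffix.shape:
        raise ValueError("suffix_delta must match the suffix of base_logits")
    # Only positions >= start get the low-rank correction.
    return np.concatenate([prefix, suffix + suffix_delta], axis=-2)
```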
Reshape backbone hidden states and gather the block’s previous tokens.
prev_ids[..., k] is the token that precedes block-position k’s
target. With shift_label it is the input token at anchor+k, so the GRU
consumes ground-truth context; otherwise it is taken from the target sequence itself.
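A NumPy sketch of the shift_label case only, where block position k predicts input_ids[anchor + 1 + k] and its previous token is input_ids[anchor + k]. The block-major layout of the flat hidden states, the clamping of out-of-range positions, the returned valid mask, and the name gather_block_context are assumptions made for illustration.

```python
import numpy as np


def gather_block_context(hidden, input_ids, anchors, block_size):
    """hidden: (B, N * block_size, D), assumed block-major.
    input_ids: (B, T); anchors: (B, N) start index of each block.
    Returns blocks (B, N, K, D) plus prev_ids, targets, valid (B, N, K).
    """
    B, _, D = hidden.shape
    N = anchors.shape[1]
    T = input_ids.shape[1]
    blocks = hidden.reshape(B, N, block_size, D)
    prev_pos = anchors[..., None] + np.arange(block_size)  # anchor + k
    tgt_pos = prev_pos + 1                                  # anchor + 1 + k
    valid = tgt_pos < T                  # targets past the sequence are masked
    batch = np.arange(B)[:, None, None]  # broadcasts against (B, N, K)
    prev_ids = input_ids[batch, np.minimum(prev_pos, T - 1)]
    targets = input_ids[batch, np.minimum(tgt_pos, T - 1)]
    return blocks, prev_ids, targets, valid
```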
Diagnostics for both heads (acceptance length, base accuracy). No gradient.
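Accuracy diagnostics only compare argmax predictions with targets over valid positions. In PyTorch such a metric would typically run under torch.no_grad(); NumPy has no autograd, so this sketch just computes the number. The function name is hypothetical.

```python
import numpy as np


def masked_accuracy(logits, targets, valid):
    """Fraction of valid positions whose argmax equals the target."""
    correct = (logits.argmax(axis=-1) == targets) & valid
    n = valid.sum()
    return float(correct.sum() / n) if n > 0 else 0.0
```

Calling it once on the base logits and once on the refined logits gives the two accuracies side by side.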
Decay-weighted CE on both logits, mixed by the curriculum weight.
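The description does not state the decay form. This NumPy sketch assumes an exponential per-position weight exp(-k / gamma) over valid positions, normalized by the total weight; gamma and all function names are hypothetical. Both logits are scored against the same targets and then mixed with the curriculum weight.

```python
import numpy as np


def log_softmax(x):
    x = x - x.max(axis=-1, keepdims=True)  # stabilize before exponentiating
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))


def decay_weighted_ce(logits, targets, valid, gamma):
    """logits: (..., K, V); targets (int) and valid (bool): (..., K).

    Assumed weighting: position k gets exp(-k / gamma); masked positions get 0.
    """
    K = logits.shape[-2]
    nll = -np.take_along_axis(log_softmax(logits), targets[..., None], axis=-1)[..., 0]
    w = np.exp(-np.arange(K) / gamma) * valid
    return float((w * nll).sum() / max(w.sum(), 1e-12))


def domino_step_loss(final_logits, base_logits, targets, valid, lambda_base, gamma):
    """Curriculum mix of the two decay-weighted CE terms (sketch)."""
    final_loss = decay_weighted_ce(final_logits, targets, valid, gamma)
    base_loss = decay_weighted_ce(base_logits, targets, valid, gamma)
    loss = (1.0 - lambda_base) * final_loss + lambda_base * base_loss
    return loss, final_loss, base_loss
```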
Parallel block-wise training forward with the Domino correction head.
Per-block acceptance length: consecutive correct predictions from position 0.
Invalid (masked-out) positions count as correct so they never truncate a block prematurely; the trailing valid mask then zeros their contribution.
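The masking trick maps directly onto a cumulative product. A NumPy sketch, assuming boolean correct and valid arrays shaped (..., K) and a hypothetical function name; it returns the raw run length described above.

```python
import numpy as np


def acceptance_length(correct, valid):
    """correct, valid: bool arrays (..., K). Returns per-block run lengths."""
    ok = correct | ~valid                            # masked positions never break the run
    run = np.cumprod(ok.astype(np.int64), axis=-1)   # 1 until the first real miss
    return (run * valid).sum(axis=-1)                # count only valid positions
```

For example, a block whose second position is masked out still counts the correct third position, because the masked slot does not end the run.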
Base-anchor curriculum weight at global_step.
lambda_base starts at lambda_start and decays linearly to 0 over the
first decay_ratio fraction of total_steps, then stays at 0. The result
is clamped to [0, 1].
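A plain-Python sketch of the schedule as described. The function name and argument order are illustrative, and treating a zero-length decay window as an immediate drop to 0 is an assumption the description does not cover.

```python
def lambda_base_at(global_step: int, total_steps: int,
                   lambda_start: float, decay_ratio: float) -> float:
    """Linear decay from lambda_start to 0 over decay_ratio * total_steps."""
    decay_steps = decay_ratio * total_steps
    if decay_steps <= 0:
        value = 0.0  # assumption: no decay window means no base anchoring
    else:
        # Remaining fraction of the decay window, floored at 0 once it ends.
        value = lambda_start * max(0.0, 1.0 - global_step / decay_steps)
    return min(1.0, max(0.0, value))  # clamp to [0, 1]
```

With total_steps=1000, decay_ratio=0.2, and lambda_start=0.8, the weight is 0.8 at step 0, 0.4 at step 100, and 0 from step 200 onward.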