# nemo_automodel.components.speculative.dflash.domino_core

Domino online training wrapper.

Ported from SpecForge's `specforge/core/domino.py` (sgl-project/SpecForge#571).

Domino extends the parallel DFlash draft backbone with a lightweight *causal*
correction head. DFlash drafts a whole block in a single non-causal forward, so
each predicted position is blind to the (drafted) tokens earlier in its own
block. The Domino head fixes that: a GRU encodes a causal state from the block's
previous tokens, and a low-rank projection of `[backbone hidden | GRU state]`
produces a logit delta that is *added* to the parallel base logits.
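
The correction path described above can be sketched in PyTorch as follows. This is a minimal illustration, not the library's implementation: the class name `ToyDominoHead`, its own embedding table, the layer sizes, and the `down`/`up` factorization are all assumptions made for the example. Only the overall shape of the idea (a GRU over the block's previous tokens, then a low-rank projection of `[backbone hidden | GRU state]` added to the base logits) comes from the text.

```python
import torch
import torch.nn as nn


class ToyDominoHead(nn.Module):
    """Hypothetical stand-in for a Domino-style correction head."""

    def __init__(self, hidden_size: int, vocab_size: int, gru_size: int = 32, rank: int = 8):
        super().__init__()
        # Assumption: a private embedding table feeds the GRU in this toy version.
        self.embed = nn.Embedding(vocab_size, gru_size)
        self.prefix_gru = nn.GRU(gru_size, gru_size, batch_first=True)
        # Low-rank projection: (hidden + GRU state) -> rank -> vocabulary.
        self.down = nn.Linear(hidden_size + gru_size, rank, bias=False)
        self.up = nn.Linear(rank, vocab_size, bias=False)

    def forward(self, base_logits: torch.Tensor, hidden: torch.Tensor, prev_ids: torch.Tensor) -> torch.Tensor:
        # base_logits: [num_blocks, block_size, vocab]
        # hidden:      [num_blocks, block_size, hidden_size]
        # prev_ids:    [num_blocks, block_size]
        # The GRU output at position k depends only on prev_ids[:, :k + 1],
        # which is what makes the correction causal within a block.
        state, _ = self.prefix_gru(self.embed(prev_ids))
        delta = self.up(self.down(torch.cat([hidden, state], dim=-1)))
        # The delta is added to the parallel base logits, not substituted for them.
        return base_logits + delta
```

Because the GRU runs left to right, changing the last token of `prev_ids` leaves the refined logits at all earlier block positions unchanged, while the backbone's base logits stay available for separate supervision.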

Training jointly supervises two logits with a base-anchor curriculum:

`loss = (1 - lambda_base) * final_loss + lambda_base * base_loss`

`final_loss` is the cross-entropy (CE) on the Domino-refined logits and
`base_loss` is the CE on the backbone-only base logits. `lambda_base` decays
from its start value to 0 over the first `decay_ratio` fraction of training, so
early steps keep the parallel backbone strong and later steps let the
correction head take over.
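
As a worked illustration of the mixing rule, the sketch below evaluates the combined loss for a few values of `lambda_base`. The helper name `mix_losses` and the numeric loss values are made up for the example; the real module operates on tensors rather than Python floats.

```python
def mix_losses(final_loss: float, base_loss: float, lambda_base: float) -> float:
    """Combine the two cross-entropy terms with the curriculum weight."""
    return (1.0 - lambda_base) * final_loss + lambda_base * base_loss


# Hypothetical loss values chosen only to make the arithmetic easy to follow.
final_loss, base_loss = 2.0, 4.0

start_of_training = mix_losses(final_loss, base_loss, lambda_base=1.0)   # only base_loss counts
mid_decay = mix_losses(final_loss, base_loss, lambda_base=0.25)          # 0.75 * 2.0 + 0.25 * 4.0
after_decay = mix_losses(final_loss, base_loss, lambda_base=0.0)         # only final_loss counts
```

With `lambda_base = 1.0` the objective reduces to `base_loss`, and once the weight reaches 0 it reduces to `final_loss`.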

`DominoTrainerModule` reuses the DFlash anchor sampling, noise-block
construction, and block attention mask (it subclasses `DFlashTrainerModule`);
only the head, the dual-logit loss, and the metrics are Domino-specific. The
Domino head parameters (`prefix_gru`, `embed_proj`) live on the DFlash draft
model and are enabled via `dflash_config.projector_type='domino'`.

## Module Contents

### Classes

| Name                                                                                                   | Description                                                               |
| ------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------- |
| [`DominoStepMetrics`](#nemo_automodel-components-speculative-dflash-domino_core-DominoStepMetrics)     | Per-step training outputs for the Domino draft.                           |
| [`DominoTrainerModule`](#nemo_automodel-components-speculative-dflash-domino_core-DominoTrainerModule) | Domino online training wrapper: DFlash backbone + causal correction head. |

### Functions

| Name                                                                                                 | Description                                                                   |
| ---------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------- |
| [`compute_accept_len`](#nemo_automodel-components-speculative-dflash-domino_core-compute_accept_len) | Per-block acceptance length: consecutive correct predictions from position 0. |
| [`get_lambda_base`](#nemo_automodel-components-speculative-dflash-domino_core-get_lambda_base)       | Base-anchor curriculum weight at `global_step`.                               |

### API

```python
class nemo_automodel.components.speculative.dflash.domino_core.DominoStepMetrics(
    loss: torch.Tensor,
    accuracy: torch.Tensor,
    valid_tokens: torch.Tensor,
    final_loss: torch.Tensor,
    base_loss: torch.Tensor,
    base_accuracy: torch.Tensor,
    accept_len: torch.Tensor,
    base_accept_len: torch.Tensor,
    lambda_base: torch.Tensor
)
```

Dataclass

Per-step training outputs for the Domino draft.

`loss`/`accuracy`/`valid_tokens` mirror `DFlashStepMetrics` so the
DFlash training loop consumes them unchanged. The remaining fields are
diagnostics for the two supervised logits and the curriculum weight.

```python
class nemo_automodel.components.speculative.dflash.domino_core.DominoTrainerModule(
    draft_model: nemo_automodel.components.speculative.dflash.draft_qwen3.Qwen3DFlashDraftModel,
    target_lm_head: torch.nn.Module,
    target_embed_tokens: torch.nn.Module,
    mask_token_id: int,
    block_size: int = 16,
    attention_backend: str = 'flex_attention',
    num_anchors: int = 512,
    loss_decay_gamma: typing.Optional[float] = None,
    shift_label: bool = False
)
```

**Bases:** [DFlashTrainerModule](/nemo-automodel/nemo_automodel/components/speculative/dflash/core#nemo_automodel-components-speculative-dflash-core-DFlashTrainerModule)

Domino online training wrapper: DFlash backbone + causal correction head.

The first block position that receives the Domino correction depends on
`shift_label`.

With `shift_label` enabled, the block predicts `anchor+1+k`, so position 0 is
already a real next-token prediction; otherwise position 0 is the clean
anchor and is excluded. `pure_draft_prefix_len` reserves additional
leading positions for the backbone-only (uncorrected) base logits.
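
One way to read this rule is as a small index computation. The sketch below is an interpretation, not the library's code: the function name `first_corrected_position` is hypothetical, and it assumes that `pure_draft_prefix_len` simply adds to the starting offset.

```python
def first_corrected_position(shift_label: bool, pure_draft_prefix_len: int = 0) -> int:
    """Index of the first block position that would receive the correction.

    Assumption: the reserved backbone-only prefix is added on top of the
    offset implied by `shift_label`.
    """
    # With shifted labels, position 0 is already a next-token prediction.
    # Without them, position 0 is the clean anchor and is skipped.
    start = 0 if shift_label else 1
    return start + pure_draft_prefix_len
```

Under this reading, `shift_label=False` with `pure_draft_prefix_len=2` would leave positions 0 to 2 on the base logits and start the correction at position 3.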

```python
nemo_automodel.components.speculative.dflash.domino_core.DominoTrainerModule._apply_domino_head(
    base_logits4d: torch.Tensor,
    hidden4d: torch.Tensor,
    prev_ids: torch.Tensor,
    target_ids: torch.Tensor
) -> torch.Tensor
```

Add the GRU-conditioned low-rank correction to the suffix base logits.

```python
nemo_automodel.components.speculative.dflash.domino_core.DominoTrainerModule._build_domino_head_inputs(
    input_ids: torch.Tensor,
    anchor_positions: torch.Tensor,
    target_ids: torch.Tensor,
    output_hidden: torch.Tensor
) -> typing.Tuple[torch.Tensor, torch.Tensor]
```

Reshape backbone hidden states and gather the block's previous tokens.

`prev_ids[..., k]` is the token that precedes block-position `k`'s
target: the input token at `anchor+k` when `shift_label` is set (so the GRU
consumes ground-truth context), otherwise the target sequence itself.

```python
nemo_automodel.components.speculative.dflash.domino_core.DominoTrainerModule._compute_extra_metrics(
    pred_ids: torch.Tensor,
    flat_base_logits: torch.Tensor,
    flat_targets: torch.Tensor,
    binary_eval_mask: torch.Tensor,
    actual_token_count: torch.Tensor,
    target_ids: torch.Tensor,
    eval_weight_mask: torch.Tensor,
    final_loss: torch.Tensor,
    base_loss: torch.Tensor,
    lambda_base: float
) -> typing.Dict[str, torch.Tensor]
```

Diagnostics for both heads (acceptance length, base accuracy), computed without gradient tracking.

```python
nemo_automodel.components.speculative.dflash.domino_core.DominoTrainerModule._compute_weighted_losses(
    final_logits: torch.Tensor,
    base_logits: torch.Tensor,
    target_ids: torch.Tensor,
    weight_mask: torch.Tensor,
    lambda_base: float
) -> typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
```

Decay-weighted CE on both logits, mixed by the curriculum weight.

```python
nemo_automodel.components.speculative.dflash.domino_core.DominoTrainerModule.forward(
    input_ids: torch.Tensor,
    hidden_states: torch.Tensor,
    loss_mask: torch.Tensor,
    lambda_base: float = 0.0
) -> nemo_automodel.components.speculative.dflash.domino_core.DominoStepMetrics
```

Parallel block-wise training forward with the Domino correction head.

```python
nemo_automodel.components.speculative.dflash.domino_core.compute_accept_len(
    pred_ids_4d: torch.Tensor,
    target_ids_4d: torch.Tensor,
    valid_mask_4d: torch.Tensor
) -> torch.Tensor
```

Per-block acceptance length: consecutive correct predictions from position 0.

Invalid (masked-out) positions count as correct so they never truncate a block
prematurely; the trailing valid mask then zeros their contribution.
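
The masking rule can be illustrated on a single block with plain Python lists. This is a sketch of the described behavior rather than the function itself, which takes 4-D tensors; the helper name `accept_len_for_block` is hypothetical. A tensor version would typically express the same idea with a cumulative product along the block dimension.

```python
from typing import Sequence


def accept_len_for_block(
    pred_ids: Sequence[int],
    target_ids: Sequence[int],
    valid_mask: Sequence[bool],
) -> int:
    """Count valid positions inside the leading run of correct predictions."""
    still_accepting = True
    accepted = 0
    for pred, target, valid in zip(pred_ids, target_ids, valid_mask):
        # Masked-out positions are treated as correct, so they do not end the run.
        correct = (pred == target) or not valid
        still_accepting = still_accepting and correct
        # Only valid positions contribute to the final length.
        if still_accepting and valid:
            accepted += 1
    return accepted


# Position 1 is masked out: it neither breaks the run nor adds to the count.
example = accept_len_for_block(
    pred_ids=[5, 0, 7, 1],
    target_ids=[5, 6, 7, 8],
    valid_mask=[True, False, True, True],
)
```

Here positions 0 and 2 are accepted, position 1 is skipped, and the mismatch at position 3 ends the run.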

```python
nemo_automodel.components.speculative.dflash.domino_core.get_lambda_base(
    global_step: int,
    total_steps: int,
    lambda_start: float = 1.0,
    decay_ratio: float = 0.5
) -> float
```

Base-anchor curriculum weight at `global_step`.

`lambda_base` starts at `lambda_start` and decays linearly to 0 over the
first `decay_ratio` fraction of `total_steps`, then stays at 0. The result
is clamped to `[0, 1]`.
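
A schedule with this behavior could look like the sketch below. It follows the description above but is not copied from the library: the function name `lambda_base_schedule` is hypothetical, and returning 0 when the decay window is empty is an assumption.

```python
def lambda_base_schedule(
    global_step: int,
    total_steps: int,
    lambda_start: float = 1.0,
    decay_ratio: float = 0.5,
) -> float:
    """Linear decay from `lambda_start` to 0 over the first part of training."""
    decay_steps = total_steps * decay_ratio
    if decay_steps <= 0:
        # Assumption: with no decay window, the weight is already 0.
        return 0.0
    value = lambda_start * (1.0 - global_step / decay_steps)
    # Clamp to [0, 1]; this also keeps the weight at 0 after the decay window.
    return min(1.0, max(0.0, value))


# With the defaults and 1000 total steps, the decay window is steps 0 to 500.
schedule = [lambda_base_schedule(step, total_steps=1000) for step in (0, 250, 500, 900)]
```

For these four steps the weight moves from 1.0 through 0.5 to 0.0 and then stays at 0.0 for the remainder of training.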