nemo_automodel.components.models.diffusion_gemma.attention_mask


Block-causal training attention mask for diffusion_gemma block diffusion.

This module builds the training decoder attention mask for the diffusion_gemma block-diffusion model. It is the highest correctness-risk piece of the SFT path; read the leakage invariant below before changing it.

Layout

At inference the model runs the shared transformer twice: once causally over the clean prefix to populate a per-layer KV cache (the “encoder” KV), and once bidirectionally over a single noised canvas block (the “decoder”). Each decoder layer concatenates [encoder_KV ; canvas_KV] along the key axis (modeling_diffusion_gemma.py:579-582), so the decoder query attends over a key axis of length enc_len + canvas_len.
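To make the key-axis arithmetic concrete, here is a minimal sketch of the concatenation described above. It uses toy tensors with arbitrary sizes; the names `enc_k`, `canvas_q` and `canvas_k` are illustrative and do not come from modeling_diffusion_gemma.py.

```python
import torch

# Toy sizes (illustrative only): batch, heads, head dim, encoder and canvas lengths.
B, H, D = 2, 4, 8
enc_len, canvas_len = 12, 4

enc_k = torch.randn(B, H, enc_len, D)        # keys cached from the clean prefix ("encoder" KV)
canvas_q = torch.randn(B, H, canvas_len, D)  # queries from the noised canvas block
canvas_k = torch.randn(B, H, canvas_len, D)  # keys from the same canvas block

# Decoder layers attend over [encoder_KV ; canvas_KV] along the key (sequence) axis.
keys = torch.cat([enc_k, canvas_k], dim=2)
scores = canvas_q @ keys.transpose(-2, -1) / D**0.5

print(scores.shape)  # torch.Size([2, 4, 4, 16]): key axis is enc_len + canvas_len
```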

For training we run the whole sequence at once and supervise all response blocks jointly (joint block-causal block diffusion). The encoder holds the clean full sequence (prompt + full response); the canvas holds the noised full response. The training mask therefore has shape [B, 1, canvas_len, enc_len + canvas_len] and splits column-wise into:

  • Left columns [0, enc_len) — the clean encoder KV. A canvas query in block i may attend a clean encoder column only if that column belongs to a response block strictly before block i (offset-block-causal, M_OBC). Prompt columns (encoder positions < prefix_len) are always visible.
  • Right columns [enc_len, enc_len + canvas_len) — the noised canvas KV. Block-diagonal (M_BD): canvas block i attends bidirectionally within block i only, never to another canvas block.
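The two column groups can be illustrated for a single example with a toy keep mask. This is a sketch under stated assumptions, not the module's implementation: `toy_keep_mask` is a hypothetical helper, and any encoder columns that are neither prompt nor response (tail padding) are assumed to be masked.

```python
import torch

def toy_keep_mask(prefix_len: int, response_len: int, enc_len: int, block_size: int) -> torch.Tensor:
    """Boolean keep mask [response_len, enc_len + response_len] for one example (True = attend)."""
    q_block = torch.arange(response_len) // block_size  # block index of each canvas query

    # Left columns (M_OBC): clean encoder KV.
    enc_pos = torch.arange(enc_len)
    is_prompt = enc_pos < prefix_len
    in_response = (enc_pos >= prefix_len) & (enc_pos < prefix_len + response_len)
    kv_block = (enc_pos - prefix_len).clamp(min=0) // block_size
    # Prompt always visible; response columns only from strictly earlier blocks.
    # Assumption: tail-padding columns (neither prompt nor response) stay masked.
    m_obc = is_prompt[None, :] | (in_response[None, :] & (q_block[:, None] > kv_block[None, :]))

    # Right columns (M_BD): noised canvas KV, bidirectional within the same block only.
    c_block = torch.arange(response_len) // block_size
    m_bd = q_block[:, None] == c_block[None, :]

    return torch.cat([m_obc, m_bd], dim=1)

mask = toy_keep_mask(prefix_len=2, response_len=4, enc_len=6, block_size=2)
print(mask.int())
```

With two blocks of size 2, queries in block 0 see only the prompt (left) and their own canvas block (right); queries in block 1 additionally see the clean tokens of block 0, but never their own clean tokens.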

This mirrors the 3-part BD3LM mask (M_BD / M_OBC / M_BC) in dllm-zhz/dllm/core/trainers/bd3lm.py, adapted to the encoder-KV/canvas layout. There is no M_BC term: in BD3LM, M_BC is block-causal attention within the clean (x_0) half, but here the clean tokens live only in the encoder columns (there is no clean-canvas query half), so it does not apply.

Leakage invariant (THE correctness property)

M_OBC uses a strict block_q > block_kv comparison. A canvas query in block i MUST be masked against the clean encoder column at response-relative position i * block_size (the first token of its own block) and against every later clean position. Using >= instead of > causes silent, total leakage: the canvas would see the clean answer for the very tokens it is being trained to denoise, the loss would collapse, and the model would learn nothing useful. The unit test tests/.../test_diffusion_gemma_mask.py asserts exactly this boundary and gates the rest of the SFT work.
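The boundary can be checked mechanically. The sketch below is not the repository's unit test; `obc_mask` and `leaks` are hypothetical helpers. It builds only the encoder (M_OBC) half for one example and reports whether any canvas query can see a clean token at or after the start of its own block. The strict comparison passes; the `>=` variant leaks.

```python
import torch

def obc_mask(prefix_len, response_len, enc_len, block_size, strict=True):
    """Encoder half only: [response_len, enc_len] keep mask (True = attend)."""
    q_block = torch.arange(response_len) // block_size
    enc_pos = torch.arange(enc_len)
    in_resp = (enc_pos >= prefix_len) & (enc_pos < prefix_len + response_len)
    kv_block = (enc_pos - prefix_len).clamp(min=0) // block_size
    if strict:
        seen = q_block[:, None] > kv_block[None, :]   # correct: strictly earlier blocks
    else:
        seen = q_block[:, None] >= kv_block[None, :]  # buggy: includes the query's own block
    return (enc_pos < prefix_len)[None, :] | (in_resp[None, :] & seen)

def leaks(keep, prefix_len, response_len, block_size):
    """True if any query sees a clean response token at or after its own block start."""
    for q in range(response_len):
        own_start = prefix_len + (q // block_size) * block_size
        if keep[q, own_start:prefix_len + response_len].any():
            return True
    return False

print(leaks(obc_mask(3, 8, 11, 4, strict=True), 3, 8, 4))   # False
print(leaks(obc_mask(3, 8, 11, 4, strict=False), 3, 8, 4))  # True
```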

Module Contents

Functions

Name | Description
--- | ---
_block_ids | Response-relative block index for each of num_positions positions.
_to_additive | Convert a boolean keep-mask to an additive mask (0 / -inf).
build_block_diffusion_training_mask | Build the block-causal training mask and its sliding-window variant.

API

nemo_automodel.components.models.diffusion_gemma.attention_mask._block_ids(
    num_positions: int,
    block_size: int,
    device: torch.device
) -> torch.Tensor

Response-relative block index for each of num_positions positions.
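A minimal sketch of what such a helper plausibly computes, assuming the block index is plain integer division of the response-relative position by block_size. `block_ids_sketch` is a hypothetical name; the private helper's actual body may differ.

```python
import torch

def block_ids_sketch(num_positions: int, block_size: int, device: torch.device) -> torch.Tensor:
    # Assumption: response-relative position r belongs to block r // block_size.
    return torch.arange(num_positions, device=device) // block_size

print(block_ids_sketch(10, 4, torch.device("cpu")))  # tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2])
```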

nemo_automodel.components.models.diffusion_gemma.attention_mask._to_additive(
    keep: torch.Tensor,
    dtype: torch.dtype
) -> torch.Tensor

Convert a boolean keep-mask to an additive mask (0 / -inf).
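The conversion can be sketched with `Tensor.masked_fill`. `to_additive_sketch` is a hypothetical stand-in, not the module's code. The usage lines show why the convention works: adding -inf to a score drives its softmax weight to exactly zero.

```python
import torch

def to_additive_sketch(keep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # 0 where keep is True (attend), -inf where keep is False (masked).
    additive = torch.zeros(keep.shape, dtype=dtype, device=keep.device)
    return additive.masked_fill(~keep, float("-inf"))

keep = torch.tensor([[True, False], [True, True]])
additive = to_additive_sketch(keep, torch.float32)

scores = torch.zeros(2, 2)                      # toy attention scores
probs = torch.softmax(scores + additive, dim=-1)
print(probs)  # row 0 puts all weight on column 0; row 1 splits it evenly
```

A row that is masked everywhere would softmax to NaN under this convention. The block-diagonal canvas term always keeps each query's own block visible, so the training mask has no such rows.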

nemo_automodel.components.models.diffusion_gemma.attention_mask.build_block_diffusion_training_mask(
    prefix_lengths: torch.Tensor | int,
    response_length: int,
    enc_len: int,
    block_size: int,
    sliding_window: int | None = None,
    batch_size: int | None = None,
    device: torch.device | str = 'cpu',
    dtype: torch.dtype | None = None
) -> tuple[torch.Tensor, torch.Tensor]

Build the block-causal training mask and its sliding-window variant.

The canvas (decoder query axis) has length response_length. The key axis is [encoder_KV (enc_len) ; canvas_KV (response_length)], so the returned masks have shape [B, 1, response_length, enc_len + response_length].
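The batched shape can be re-derived with broadcasting. The sketch below is a hypothetical re-derivation (`build_mask_sketch` is not the library function) that covers only the full boolean mask. It omits the sliding variant, device and dtype handling, and assumes tail-padding encoder columns are masked.

```python
import torch

def build_mask_sketch(prefix_lengths, response_length, enc_len, block_size, batch_size=None):
    """Hypothetical full keep mask (True = attend), shape [B, 1, R, enc_len + R]."""
    if isinstance(prefix_lengths, int):
        if batch_size is None:
            raise ValueError("batch_size is required when prefix_lengths is an int")
        prefix = torch.full((batch_size,), prefix_lengths, dtype=torch.long)
    else:
        prefix = torch.as_tensor(prefix_lengths, dtype=torch.long)
    B, R = prefix.shape[0], response_length
    if enc_len < int(prefix.max()) + R:
        raise ValueError("enc_len must be >= max(prefix_lengths) + response_length")

    q_block = (torch.arange(R) // block_size).view(1, R, 1)   # [1, R, 1]
    rel = torch.arange(enc_len).view(1, 1, enc_len) - prefix.view(B, 1, 1)  # [B, 1, E]
    kv_block = rel.clamp(min=0) // block_size
    in_resp = (rel >= 0) & (rel < R)
    # M_OBC: prompt columns always visible; response columns from strictly earlier blocks.
    m_obc = (rel < 0) | (in_resp & (q_block > kv_block))      # [B, R, E]

    # M_BD: canvas columns, same block only.
    c_block = (torch.arange(R) // block_size).view(1, 1, R)
    m_bd = (q_block == c_block).expand(B, R, R)               # [B, R, R]

    return torch.cat([m_obc, m_bd], dim=-1).unsqueeze(1)      # [B, 1, R, E + R]

mask = build_mask_sketch(torch.tensor([2, 3]), response_length=4, enc_len=8, block_size=2)
print(mask.shape)  # torch.Size([2, 1, 4, 12])
```

Per-example prefixes only shift where each row's response columns start. The canvas half is identical across the batch because every canvas begins at response position 0.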

Parameters:

prefix_lengths
torch.Tensor | int

Per-example prompt length(s) in the encoder, i.e. the number of leading clean encoder positions that are prompt (always attendable). An int is broadcast to all examples; a 1-D tensor of shape [B] gives per-example prefixes. The response occupies encoder positions [prefix_length, prefix_length + response_length).

response_length
int

Canvas length (number of noised response positions).

enc_len
int

Total encoder key length (prefix + response, plus any trailing padding columns). Must satisfy enc_len >= max(prefix_lengths) + response_length.

block_size
int

Diffusion block size, equal to the inference canvas length (256 for the checkpoint).

sliding_window
int | None, defaults to None

If given, the sliding variant additionally restricts attention to key positions within sliding_window absolute positions of the query (see module docstring for position-id convention). If None, the sliding variant equals the full mask.

batch_size
int | None, defaults to None

Batch dimension. Required when prefix_lengths is an int; inferred from the tensor otherwise.

device
torch.device | str, defaults to 'cpu'

Device for the returned tensors.

dtype
torch.dtype | None, defaults to None

If None (default), return boolean keep masks (True = attend). If a floating-point dtype, return an additive mask (0 where attended, -inf where masked) ready to add to attention scores.
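Both return forms plug into PyTorch's `torch.nn.functional.scaled_dot_product_attention`, which treats a boolean `attn_mask` as True = attend and adds a float `attn_mask` to the scores. The sketch below uses a hand-built toy mask (two prompt columns, two canvas blocks of size 2; it is not produced by the library) to show that the boolean and additive forms give the same output. A [B, 1, R, E + R] mask broadcasts over the head dimension.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, R, E, D = 1, 2, 4, 6, 8
q = torch.randn(B, H, R, D)        # canvas queries
k = torch.randn(B, H, E + R, D)    # [encoder_KV ; canvas_KV] keys
v = torch.randn(B, H, E + R, D)

# Toy keep mask [B, 1, R, E + R]; every row keeps at least one key (no NaN rows).
keep = torch.zeros(B, 1, R, E + R, dtype=torch.bool)
keep[..., :2] = True               # two always-visible prompt columns
keep[..., :2, E:E + 2] = True      # canvas block 0 (queries 0-1) sees itself
keep[..., 2:, E + 2:] = True       # canvas block 1 (queries 2-3) sees itself

additive = torch.zeros(keep.shape).masked_fill(~keep, float("-inf"))

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=keep)
out_add = F.scaled_dot_product_attention(q, k, v, attn_mask=additive)
print(torch.allclose(out_bool, out_add, atol=1e-5))  # True
```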

Returns: tuple[torch.Tensor, torch.Tensor]

Tuple (mask_full, mask_sliding):