nemo_automodel.components.models.diffusion_gemma.attention_mask
Block-causal training attention mask for diffusion_gemma block diffusion.
This module builds the training decoder attention mask for the
diffusion_gemma block-diffusion model. It is the highest correctness-risk
piece of the SFT path; read the leakage invariant below before changing it.
Layout
At inference the model runs the shared transformer twice: once causally over
the clean prefix to populate a per-layer KV cache (the “encoder” KV), and once
bidirectionally over a single noised canvas block (the “decoder”). Each
decoder layer concatenates [encoder_KV ; canvas_KV] along the key axis
(modeling_diffusion_gemma.py:579-582), so the decoder query attends over a
key axis of length enc_len + canvas_len.
For training we run the whole sequence at once and supervise all response
blocks jointly (joint block-causal block diffusion). The encoder holds the
clean full sequence (prompt + full response); the canvas holds the
noised full response. The training mask therefore has shape
[B, 1, canvas_len, enc_len + canvas_len] and splits column-wise into:
- Left columns [0, enc_len): the clean encoder KV. A canvas query in block i may attend a clean encoder column only if that column belongs to a response block strictly before block i (offset-block-causal, M_OBC). Prompt columns (encoder positions < prefix_len) are always visible.
- Right columns [enc_len, enc_len + canvas_len): the noised canvas KV. Block-diagonal (M_BD): canvas block i attends bidirectionally within block i only, never to another canvas block.
This mirrors the 3-part BD3LM mask (M_BD / M_OBC / M_BC) in
dllm-zhz/dllm/core/trainers/bd3lm.py adapted to the encoder-KV/canvas
layout. There is no M_BC term: in BD3LM M_BC is block-causal attention
within the clean (x_0) half, but here the clean tokens live only in the
encoder columns (there is no clean-canvas query half), so it does not apply.
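The two-part layout above can be sketched for a single example in plain Python. This is an illustrative sketch only, not the module's implementation: the name build_keep_mask is hypothetical, nested lists stand in for the [B, 1, canvas_len, enc_len + canvas_len] tensor, and masking the tail padding columns is an assumption, since the text does not say how they are treated.

```python
def build_keep_mask(prefix_len, response_len, enc_len, block_size):
    """Hypothetical helper: boolean keep-mask (True = attend) for one example.

    Returns response_len rows (canvas queries), each with
    enc_len + response_len columns (encoder KV followed by canvas KV).
    """
    if enc_len < prefix_len + response_len:
        raise ValueError("enc_len must be >= prefix_len + response_len")
    mask = []
    for q in range(response_len):
        block_q = q // block_size  # response-relative block of the query
        row = []
        # Left columns [0, enc_len): clean encoder KV (M_OBC plus prompt).
        for k in range(enc_len):
            if k < prefix_len:
                row.append(True)  # prompt columns are always visible
            elif k < prefix_len + response_len:
                block_kv = (k - prefix_len) // block_size
                row.append(block_q > block_kv)  # strict: earlier blocks only
            else:
                row.append(False)  # tail padding: masked here by assumption
        # Right columns [enc_len, enc_len + response_len): noised canvas KV (M_BD).
        for c in range(response_len):
            row.append(c // block_size == block_q)  # same block only
        mask.append(row)
    return mask


# Toy sizes: 2 prompt tokens, 4 response tokens, blocks of 2.
mask = build_keep_mask(prefix_len=2, response_len=4, enc_len=6, block_size=2)
for row in mask:
    left = "".join("x" if keep else "." for keep in row[:6])
    right = "".join("x" if keep else "." for keep in row[6:])
    print(left, right)
# xx.... xx..
# xx.... xx..
# xxxx.. ..xx
# xxxx.. ..xx
```

In the printed rows, the first canvas block sees only the prompt on the encoder side, while the second block also sees the clean tokens of the first block; on the canvas side each block sees only itself.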
Leakage invariant (THE correctness property)
M_OBC uses a strict block_q > block_kv comparison. A canvas query in
block i MUST be masked against the clean encoder column at response-relative
position i * block_size (the first token of its own block) and every later
clean position. Using >= instead of > is silent total leakage: the
canvas would see the clean answer for the very tokens it is being trained to
denoise, the loss would collapse, and the model would learn nothing useful.
The unit test tests/.../test_diffusion_gemma_mask.py asserts exactly this
boundary and is the gate for the rest of the SFT work.
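A minimal sketch of the boundary check described above, assuming response-relative positions and hypothetical helper names (obc_visible, leaks). It is not the referenced unit test; it only contrasts the strict comparison with the leaking >= variant on a toy size.

```python
def obc_visible(q, k_rel, block_size, strict=True):
    """Whether canvas query q may attend clean response column k_rel.

    Both q and k_rel are response-relative positions. strict=True is the
    block_q > block_kv rule; strict=False is the buggy >= variant.
    """
    block_q = q // block_size
    block_kv = k_rel // block_size
    return block_q > block_kv if strict else block_q >= block_kv


def leaks(response_len, block_size, strict=True):
    """Collect (q, k_rel) pairs where a query sees clean tokens it must not see.

    For a query in block i, every clean position from i * block_size onward
    must be masked.
    """
    bad = []
    for q in range(response_len):
        own_block_start = (q // block_size) * block_size
        for k_rel in range(own_block_start, response_len):
            if obc_visible(q, k_rel, block_size, strict):
                bad.append((q, k_rel))
    return bad


strict_leaks = leaks(response_len=4, block_size=2, strict=True)
buggy_leaks = leaks(response_len=4, block_size=2, strict=False)
print(strict_leaks)      # []
print(len(buggy_leaks))  # 8
```

With >=, every query sees the clean tokens of its own block, including the clean token at its own position, which is exactly the leakage the invariant forbids.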
Module Contents
Functions
API
Response-relative block index for each of num_positions positions.
Convert a boolean keep-mask to an additive mask (0 / -inf).
Build the block-causal training mask and its sliding-window variant.
The canvas (decoder query axis) has length response_length. The key axis
is [encoder_KV (enc_len) ; canvas_KV (response_length)], so the returned
masks have shape [B, 1, response_length, enc_len + response_length].
Parameters:
Per-example prompt length(s) in the encoder, i.e. the
number of leading clean encoder positions that are prompt (always
attendable). An int is broadcast to all examples; a 1-D tensor
of shape [B] gives per-example prefixes. The response occupies
encoder positions [prefix_length, prefix_length + response_length).
Canvas length (number of noised response positions).
Total encoder key length (prefix + response and any tail
padding columns). Must satisfy enc_len >= max(prefix) + response_length.
Diffusion block size (canvas_length; 256 for the ckpt).
If given, the sliding variant additionally restricts
attention to key positions within sliding_window absolute
positions of the query (see module docstring for position-id
convention). If None the sliding variant equals the full mask.
Batch dimension. Required when prefix_lengths is an int;
inferred from the tensor otherwise.
Device for the returned tensors.
If None (default) return boolean keep masks (True =
attend). If a floating dtype, return an additive mask
(0 where attended, -inf where masked) ready to add to
attention scores.
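The boolean-to-additive conversion can be illustrated with plain Python lists. This is a sketch under assumptions: to_additive and softmax are hypothetical names, and the real module presumably operates on tensors rather than lists.

```python
import math


def to_additive(keep_mask):
    """Map True -> 0.0 and False -> -inf, row by row."""
    return [[0.0 if keep else float("-inf") for keep in row] for row in keep_mask]


def softmax(scores):
    """Numerically stable softmax over one row of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


keep = [[True, False, True]]
additive = to_additive(keep)

raw_scores = [1.0, 2.0, 3.0]
masked_scores = [s + a for s, a in zip(raw_scores, additive[0])]
weights = softmax(masked_scores)
print(weights[1])  # 0.0
```

Adding -inf before the softmax drives the masked key's weight to exactly zero. A row with no attended key at all would produce NaN in this sketch, so every query needs at least one visible column.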
Returns: torch.Tensor
Tuple (mask_full, mask_sliding):