nemo_automodel.components.attention.dflash_mask


DFlash sparse-attention masks (SDPA + FlexAttention).

Builds the DFlash block-diagonal attention masks (paper §4.2) so that multi-anchor DFlash training (up to ~512 anchors per sequence; paper Appendix A.1) is tractable in memory.
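Some back-of-the-envelope arithmetic shows why memory is the constraint. The sketch below computes the size of a dense [B, 1, N*block_size, S + N*block_size] mask. It also counts the 128×128 tiles that a FlexAttention BlockMask tracks instead (128 is the default BLOCK_SIZE of create_block_mask). The sizes are illustrative assumptions, not values from the paper: 512 anchors, 16-token blocks, an 8192-token context, and a 2-byte dtype. dense_mask_bytes and flex_tile_count are hypothetical helper names.

```python
def dense_mask_bytes(num_anchors: int, block_size: int, ctx_len: int,
                     batch: int = 1, bytes_per_elem: int = 2) -> int:
    """Bytes for a dense [B, 1, N*block_size, S + N*block_size] mask."""
    q_len = num_anchors * block_size             # one draft block per anchor
    kv_len = ctx_len + num_anchors * block_size  # context followed by all blocks
    return batch * q_len * kv_len * bytes_per_elem


def flex_tile_count(num_anchors: int, block_size: int, ctx_len: int,
                    tile: int = 128) -> int:
    """Number of tile x tile blocks a BlockMask tracks per sample and head."""
    q_len = num_anchors * block_size
    kv_len = ctx_len + num_anchors * block_size
    q_tiles = (q_len + tile - 1) // tile         # ceiling division
    kv_tiles = (kv_len + tile - 1) // tile
    return q_tiles * kv_tiles


# Assumed sizes: 512 anchors, 16-token blocks, 8192 context tokens.
n_bytes = dense_mask_bytes(512, 16, 8192)
print(n_bytes / 2**20, "MiB")          # 256.0 MiB per sample at 2 bytes/element
print(flex_tile_count(512, 16, 8192))  # 8192
```

A BlockMask stores per-tile metadata rather than per-element values. Its footprint therefore scales with the tile count, which here is 128² = 16384 times smaller than the number of dense mask elements.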

KV layout: [ context (S tokens) | block_0 | block_1 | ... | block_{N-1} ]
Q layout:  [ block_0 | block_1 | ... | block_{N-1} ]
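Under this layout, block b occupies the same-sized slot in both sequences, shifted by S in KV. The small sketch below makes the index arithmetic concrete; block_ranges is a hypothetical helper, not part of this module.

```python
def block_ranges(block: int, ctx_len: int, block_size: int):
    """Half-open (start, end) ranges of draft block `block` in Q and in KV."""
    q_start = block * block_size              # Q contains only the blocks
    kv_start = ctx_len + block * block_size   # KV prepends the S context tokens
    return (q_start, q_start + block_size), (kv_start, kv_start + block_size)


# Example: S = 10 context tokens, blocks of 4 tokens.
print(block_ranges(0, ctx_len=10, block_size=4))  # ((0, 4), (10, 14))
print(block_ranges(2, ctx_len=10, block_size=4))  # ((8, 12), (18, 22))
```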

Each query in block b attends to:

  1. context positions strictly less than anchor[b] (causal-style prefix)
  2. its own block's noise positions (bidirectional in-block)
  3. nothing else โ€” other blocks are invisible
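The three rules above can be sketched as a per-position predicate for a single sample. The sketch below applies that predicate to a toy case and prints the resulting visibility pattern (x = visible). can_attend and visibility are hypothetical names. This version ignores block_keep_mask, since the rules above do not say how dropped anchors are handled.

```python
def can_attend(q_idx: int, kv_idx: int, anchors: list[int],
               ctx_len: int, block_size: int) -> bool:
    """DFlash visibility for one sample, following rules 1-3."""
    block = q_idx // block_size                  # draft block the query belongs to
    # Rule 1: context positions strictly before this block's anchor.
    sees_prefix = kv_idx < ctx_len and kv_idx < anchors[block]
    # Rule 2: every position of the query's own block (bidirectional).
    own_start = ctx_len + block * block_size
    sees_own_block = own_start <= kv_idx < own_start + block_size
    # Rule 3: nothing else is visible.
    return sees_prefix or sees_own_block


def visibility(anchors: list[int], ctx_len: int, block_size: int) -> list[list[bool]]:
    n = len(anchors)
    q_len, kv_len = n * block_size, ctx_len + n * block_size
    return [[can_attend(q, k, anchors, ctx_len, block_size) for k in range(kv_len)]
            for q in range(q_len)]


# Toy case: S = 6 context tokens, anchors at positions 2 and 5, blocks of 2 tokens.
rows = visibility([2, 5], ctx_len=6, block_size=2)
for r in rows:
    print("".join("x" if v else "." for v in r))
# xx....xx..
# xx....xx..
# xxxxx...xx
# xxxxx...xx
```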

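The context is never queried from (the target LM is frozen and only its hidden states are needed), so Q omits the context positions. The query length drops from S + N*block_size to N*block_size. This halves the attention compute when S equals N*block_size and saves more when the context is longer.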
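A quick numeric check of the saving: the sketch below counts Q×KV score entries per sample and head, with and without context positions in Q. The sizes are assumed and chosen so that S equals N*block_size. score_entries is a hypothetical helper name.

```python
def score_entries(ctx_len: int, num_blocks: int, block_size: int,
                  query_context: bool) -> int:
    """Number of entries in the Q x KV score matrix per sample and head."""
    kv_len = ctx_len + num_blocks * block_size
    q_len = kv_len if query_context else num_blocks * block_size
    return q_len * kv_len


S, N, bs = 4096, 256, 16  # assumed sizes; here S == N * bs == 4096
full = score_entries(S, N, bs, query_context=True)
draft_only = score_entries(S, N, bs, query_context=False)
print(draft_only / full)  # 0.5
```

In general the ratio is N*block_size / (S + N*block_size), so the fraction saved is S / (S + N*block_size).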

Module Contents

Functions

| Name | Description |
| --- | --- |
| _get_compiled_create_block_mask | Lazy-initialise a compiled create_block_mask and cache it. |
| create_dflash_block_mask | Build a sparse FlexAttention BlockMask for DFlash training. |
| create_dflash_sdpa_mask | Build a dense additive attention mask for the SDPA backend. |

Data

_compiled_create_block_mask

API

nemo_automodel.components.attention.dflash_mask._get_compiled_create_block_mask()

Lazy-initialise a compiled create_block_mask and cache it.
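The module-level _compiled_create_block_mask = None listed under Data suggests a build-once cache. The sketch below shows that pattern in a generic form. get_compiled, _compiled_fn and fake_build are hypothetical names. The assumption that the real builder wraps torch.compile(create_block_mask) comes from the description above, not from the module source.

```python
from typing import Callable, Optional

_compiled_fn: Optional[Callable] = None  # module-level cache, empty until first use


def get_compiled(build: Callable[[], Callable]) -> Callable:
    """Run the expensive build step once, then return the cached result."""
    global _compiled_fn
    if _compiled_fn is None:
        _compiled_fn = build()  # assumed: torch.compile(create_block_mask) in real code
    return _compiled_fn


calls = []


def fake_build() -> Callable:
    calls.append(1)             # record each time the expensive step runs
    return lambda x: x * 2


f1 = get_compiled(fake_build)
f2 = get_compiled(fake_build)
print(f1 is f2, len(calls))     # True 1
```

Caching matters because compiling on every call would pay the torch.compile cost each training step.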

nemo_automodel.components.attention.dflash_mask.create_dflash_block_mask(
anchor_positions: torch.Tensor,
block_keep_mask: torch.Tensor,
ctx_len: int,
block_size: int,
device: torch.device,
use_compile: bool = True
) -> 'BlockMask'

Build a sparse FlexAttention BlockMask for DFlash training.

See the module docstring for the mask semantics. The returned BlockMask is consumed directly by the transformers flex_attention backend when _attn_implementation="flex_attention" is set on the draft model; pass it via the attention_mask kwarg.
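FlexAttention describes a mask as a mask_mod(b, h, q_idx, kv_idx) predicate. create_block_mask evaluates this predicate to find which tiles are empty, partial or full. The pure-Python stand-in below sketches such a predicate for the DFlash rules. make_dflash_mask_mod is a hypothetical name, and anchors and keep are nested lists standing in for the [B, N] tensors. Treating a dropped block (keep False) as seeing nothing is an assumption, not documented behaviour of create_dflash_block_mask. In real FlexAttention code the arguments arrive as tensors, and the lookup would be written anchors[b, blk].

```python
def make_dflash_mask_mod(anchors, keep, ctx_len, block_size):
    """Return a FlexAttention-shaped mask_mod(b, h, q_idx, kv_idx) -> bool.

    anchors, keep: [B][N] nested lists standing in for the [B, N] tensors.
    """
    def mask_mod(b, h, q_idx, kv_idx):
        blk = q_idx // block_size                     # query's draft block
        own_start = ctx_len + blk * block_size
        prefix = (kv_idx < ctx_len) & (kv_idx < anchors[b][blk])
        own_block = (kv_idx >= own_start) & (kv_idx < own_start + block_size)
        # Assumption: a block whose keep flag is False sees nothing.
        return keep[b][blk] & (prefix | own_block)
    return mask_mod


mod = make_dflash_mask_mod(anchors=[[2, 5]], keep=[[True, False]], ctx_len=6, block_size=2)
print(mod(0, 0, 1, 0))  # True: context position before anchor 2
print(mod(0, 0, 1, 7))  # True: own block occupies KV 6..7
print(mod(0, 0, 1, 8))  # False: another block
print(mod(0, 0, 3, 0))  # False: block 1 is dropped
```

The predicate ignores h, so the mask is the same for every head. create_block_mask accepts H=None to broadcast such a mask across heads.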

Parameters:

anchor_positions
torch.Tensor

[B, N] anchor positions (long).

block_keep_mask
torch.Tensor

[B, N] valid-anchor mask (bool).

ctx_len
int

Context length S (number of context tokens at the front of the KV sequence).

block_size
int

Number of tokens in each draft block.

device
torch.device

Device on which the mask is built.

use_compile
bool, defaults to True

Cache and reuse a torch.compile'd create_block_mask across calls. Set to False on PyTorch builds that hit Inductor errors during compilation.

Returns: 'BlockMask'

A torch.nn.attention.flex_attention.BlockMask.

nemo_automodel.components.attention.dflash_mask.create_dflash_sdpa_mask(
anchor_positions: torch.Tensor,
block_keep_mask: torch.Tensor,
ctx_len: int,
block_size: int,
device: torch.device,
dtype: torch.dtype
) -> torch.Tensor

Build a dense additive attention mask for the SDPA backend.
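A dense additive mask encodes the same visibility pattern. Visible positions hold 0 and masked positions hold a large negative number. The nested-list sketch below builds such a mask with the documented [B, 1, N*block_size, S + N*block_size] shape. dflash_additive_mask is a hypothetical name. The -inf fill value and the handling of dropped blocks are assumptions. Implementations often use torch.finfo(dtype).min instead of -inf, which keeps fully masked rows from producing NaN in the softmax.

```python
import math


def dflash_additive_mask(anchors, keep, ctx_len, block_size, neg=-math.inf):
    """Nested-list stand-in for a [B, 1, N*block_size, S + N*block_size] additive mask.

    anchors, keep: [B][N] lists of anchor positions and valid-anchor flags.
    """
    batch = []
    for row_anchors, row_keep in zip(anchors, keep):
        n = len(row_anchors)
        q_len, kv_len = n * block_size, ctx_len + n * block_size
        mask = [[neg] * kv_len for _ in range(q_len)]  # start fully masked
        for blk, (anchor, kept) in enumerate(zip(row_anchors, row_keep)):
            if not kept:
                continue  # assumption: a dropped block stays fully masked
            own_start = ctx_len + blk * block_size
            for q in range(blk * block_size, (blk + 1) * block_size):
                for k in range(min(anchor, ctx_len)):               # context before the anchor
                    mask[q][k] = 0.0
                for k in range(own_start, own_start + block_size):  # own block
                    mask[q][k] = 0.0
        batch.append([mask])  # singleton head dimension broadcasts across heads
    return batch


m = dflash_additive_mask([[2, 5]], [[True, True]], ctx_len=6, block_size=2)
print(len(m), len(m[0]), len(m[0][0]), len(m[0][0][0]))  # 1 1 4 10
print([0 if v == 0.0 else 1 for v in m[0][0][0]])       # [0, 0, 1, 1, 1, 1, 0, 0, 1, 1]
```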

Parameters:

anchor_positions
torch.Tensor

[B, N] anchor positions per sample (long).

block_keep_mask
torch.Tensor

[B, N] per-sample valid-anchor mask (bool).

ctx_len
int

Context length S (number of context tokens at the front of the KV sequence).

block_size
int

Number of tokens in each draft block.

device
torch.device

Device on which the mask is built.

dtype
torch.dtype

dtype for the additive mask (typically the model dtype).

Returns: torch.Tensor

[B, 1, N*block_size, S + N*block_size] float tensor: 0 at positions a query may attend to and a large negative value at masked positions.
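An additive mask works because it is added to the attention scores before the softmax. Visible scores stay unchanged, and an entry of -inf makes exp(score + mask) exactly zero. The sketch below shows this on a single row of four keys; masked_softmax is a hypothetical helper.

```python
import math


def masked_softmax(scores, additive_mask):
    """Softmax over scores + mask; -inf entries receive exactly zero weight."""
    logits = [s + m for s, m in zip(scores, additive_mask)]
    top = max(logits)                          # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


w = masked_softmax([1.0, 2.0, 3.0, 4.0], [0.0, 0.0, -math.inf, -math.inf])
print([round(x, 4) for x in w])  # [0.2689, 0.7311, 0.0, 0.0]
```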

nemo_automodel.components.attention.dflash_mask._compiled_create_block_mask = None