nemo_automodel.components.attention.dflash_mask
DFlash sparse-attention masks (SDPA + FlexAttention).
Builds the DFlash block-diagonal attention masks (paper §4.2) so that multi-anchor DFlash training (up to ~512 anchors per sequence; paper Appendix A.1) is tractable in memory.
KV layout: [ context (S tokens) | block_0 | block_1 | ... | block_{N-1} ]
Q layout: [ block_0 | block_1 | ... | block_{N-1} ]
Each query in block b attends to:
- context positions strictly less than anchor[b] (causal-style prefix)
- its own block's noise positions (bidirectional in-block)
- nothing else; other blocks are invisible
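The three rules can be checked with a small pure-Python reference for a single sample. This is a sketch, not the module's implementation: dflash_allowed and dflash_dense_mask are hypothetical names, indices follow the KV and Q layouts above, and treating every query of an invalid anchor as attending to nothing is an assumption.

```python
from typing import List, Sequence


def dflash_allowed(q_idx: int, kv_idx: int, anchors: Sequence[int],
                   valid: Sequence[bool], context_len: int, block_size: int) -> bool:
    """Hypothetical helper: may query q_idx attend to key/value kv_idx?"""
    block = q_idx // block_size          # draft block this query belongs to
    if not valid[block]:                 # assumption: invalid anchors see nothing
        return False
    if kv_idx < context_len:             # context region: causal-style prefix
        return kv_idx < anchors[block]
    # Noise region: only the query's own block, bidirectionally.
    return (kv_idx - context_len) // block_size == block


def dflash_dense_mask(anchors: Sequence[int], valid: Sequence[bool],
                      context_len: int, block_size: int) -> List[List[bool]]:
    """Materialise the full [N*block_size, S + N*block_size] boolean mask."""
    q_len = len(anchors) * block_size
    kv_len = context_len + len(anchors) * block_size
    return [[dflash_allowed(q, kv, anchors, valid, context_len, block_size)
             for kv in range(kv_len)]
            for q in range(q_len)]


# Toy example: S=6 context tokens, two blocks of size 2, anchors at 2 and 4.
mask = dflash_dense_mask(anchors=[2, 4], valid=[True, True], context_len=6, block_size=2)
for row in mask:
    print("".join("x" if a else "." for a in row))
# xx....xx..
# xx....xx..
# xxxx....xx
# xxxx....xx
```

Each row is one query; the left six columns are the context and the right four are the two noise blocks, so block 1 sees a longer context prefix than block 0 but never block 0's noise tokens.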
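The context is never queried from (the target LM is frozen; only its hidden states are needed), so Q omits it. This shrinks the query length from S + N*block_size to N*block_size, which roughly halves the attention compute compared with also querying from the context positions when S is comparable to N*block_size.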
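To make the saving concrete, write b for block_size and count the entries of the attention-score matrix per sample and per head under the two Q layouts:

```latex
\begin{aligned}
\#\text{scores}_{\text{DFlash}} &= (N b)\,(S + N b) \\
\#\text{scores}_{\text{context in } Q} &= (S + N b)\,(S + N b) \\
\frac{\#\text{scores}_{\text{DFlash}}}{\#\text{scores}_{\text{context in } Q}} &= \frac{N b}{S + N b} = \frac{1}{2} \quad \text{when } S = N b
\end{aligned}
```

The ratio is exactly one half only when the context and the draft blocks have equal length; a longer context makes the saving larger.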
API
Lazily initialise a compiled create_block_mask and cache it.
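One common way to get this lazy, cached behaviour is a memoised zero-argument factory. The sketch below assumes that is the intent; get_compiled_create_block_mask is a hypothetical name, and the only library calls are torch.compile and torch.nn.attention.flex_attention.create_block_mask.

```python
import functools


@functools.lru_cache(maxsize=None)
def get_compiled_create_block_mask():
    """Hypothetical helper: compile create_block_mask once, on first use."""
    # Deferred imports: importing this module neither requires torch nor
    # triggers compilation; the compile cost is paid on the first call only.
    import torch
    from torch.nn.attention.flex_attention import create_block_mask

    return torch.compile(create_block_mask)
```

Later calls return the same compiled callable from the cache, and get_compiled_create_block_mask.cache_clear() drops it if a fresh compile is ever needed.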
Build a sparse FlexAttention BlockMask for DFlash training.
See the module docstring for the mask semantics. The returned BlockMask is
consumed directly by the flex_attention backend in transformers when
_attn_implementation="flex_attention" is set on the draft model; pass
it via the attention_mask kwarg.
Parameters:
[B, N] anchor positions (long).
[B, N] valid-anchor mask (bool).
context length.
block size.
torch device.
Cache and reuse a torch.compile'd create_block_mask
across calls (default True). Set to False when running on PyTorch
builds that hit Inductor errors during compile.
Returns: BlockMask
A torch.nn.attention.flex_attention.BlockMask.
Build a dense additive attention mask for the SDPA backend.
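For SDPA the same rules can be materialised densely with broadcasting. This is a hypothetical sketch: build_dflash_sdpa_mask is an illustrative name, and filling disallowed positions with torch.finfo(dtype).min rather than -inf is an assumption, chosen so that fully masked rows (such as those of an invalid anchor) stay finite after softmax instead of becoming NaN. Those rows then receive uniform weights, so their outputs should be ignored downstream.

```python
import torch


def build_dflash_sdpa_mask(anchors: torch.Tensor, valid: torch.Tensor,
                           context_len: int, block_size: int,
                           device, dtype: torch.dtype) -> torch.Tensor:
    """Return [B, 1, N*block_size, S + N*block_size]: 0 = allowed, min = blocked."""
    batch, num_blocks = anchors.shape
    anchors, valid = anchors.to(device), valid.to(device)
    q_idx = torch.arange(num_blocks * block_size, device=device)                # [Q]
    kv_idx = torch.arange(context_len + num_blocks * block_size, device=device)  # [KV]
    q_blk = q_idx // block_size                                                 # [Q]

    # Per-query anchor position and validity, gathered per block: [B, Q]
    q_anchor = anchors[:, q_blk]
    q_valid = valid[:, q_blk]

    in_ctx = kv_idx < context_len                                               # [KV]
    # Context: strictly before the query's anchor -> [B, Q, KV]
    ctx_ok = in_ctx[None, None, :] & (kv_idx[None, None, :] < q_anchor[:, :, None])
    # Noise: same block only, bidirectional -> [Q, KV]
    kv_blk = (kv_idx - context_len).div(block_size, rounding_mode="floor")
    noise_ok = (~in_ctx)[None, :] & (kv_blk[None, :] == q_blk[:, None])
    allowed = (ctx_ok | noise_ok[None]) & q_valid[:, :, None]                   # [B, Q, KV]

    mask = torch.zeros(allowed.shape, device=device, dtype=dtype)
    mask.masked_fill_(~allowed, torch.finfo(dtype).min)
    return mask.unsqueeze(1)                                                    # [B, 1, Q, KV]
```

The dense mask holds B * N*block_size * (S + N*block_size) elements, so its memory grows with both the context length and the number of anchors; that growth is what the sparse BlockMask path avoids.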
Parameters:
[B, N] anchor positions per sample (long).
[B, N] per-sample valid-anchor mask (bool).
context length S.
block size.
torch device.
dtype for the additive mask (typically the model dtype).
Returns: torch.Tensor
[B, 1, N*block_size, S + N*block_size] float tensor: 0 at