nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn


Context-parallel support for MiniMax M3 block-sparse DSA attention.

Under context parallelism the sequence is sharded across CP ranks with a load-balanced layout (PyTorch’s causal CP splits the sequence into 2 * cp_size chunks and assigns rank r the pair {r, 2*cp_size-1-r}), so a rank’s local positions are not a contiguous global span. The M3 lightning indexer builds its block-sparse mask from index q/k over the global causal sequence, so a CP-aware sparse layer must gather the indexer inputs from every rank and reorder them into global token order before selecting blocks.

This module holds the reorder primitives shared by the CP-aware attention. The reorder math (order_by_positions / restore_by_positions) is factored out into pure tensor functions so that the load-balance inverse can be unit-tested on CPU without a process group. That inverse is a silent-failure trap: a wrong inverse trains without any shape error but never converges.
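A minimal sketch of the kind of CPU-only test this enables, assuming cp_size = 2 and t_local = 4. It simulates the all-gather with a hand-built, rank-major position tensor and compares two candidate inverses. The helper names are hypothetical and the code does not call the module's functions. Indexing with the positions themselves (a plausible slip) yields a tensor of the right shape in the wrong order, which only a value-level check catches.

```python
import torch


def gather_two_ranks():
    """Simulated all-gather+concat for cp_size=2, t_local=4 (chunk=2).

    Rank 0 holds global chunks {0, 3} -> positions 0, 1, 6, 7;
    rank 1 holds chunks {1, 2}       -> positions 2, 3, 4, 5.
    """
    positions = torch.tensor([0, 1, 6, 7, 2, 3, 4, 5])
    # Use the position itself as the payload so the correct order is obvious.
    gathered = positions.clone().float().unsqueeze(-1)  # [S, 1]
    return gathered, positions


def correct_inverse(gathered, positions):
    # argsort gives, for each global slot p, the gathered row that holds position p.
    return gathered.index_select(0, torch.argsort(positions))


def wrong_inverse(gathered, positions):
    # Plausible bug: uses the positions as a gather index instead of their inverse.
    return gathered.index_select(0, positions)


gathered, positions = gather_two_ranks()
expected = torch.arange(positions.numel()).float().unsqueeze(-1)
good = correct_inverse(gathered, positions)
bad = wrong_inverse(gathered, positions)
print(good.shape == bad.shape)      # True: shapes cannot catch the bug
print(torch.equal(good, expected))  # True
print(torch.equal(bad, expected))   # False
```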

Module Contents

Classes

| Name | Description |
| --- | --- |
| MiniMaxM3CPSparseAttention | Context-parallel-aware drop-in for a MiniMax M3 sparse-attention layer. |
| _AllGatherConcatFn | All-gather + concat with an autograd-safe backward. |

Functions

| Name | Description |
| --- | --- |
| _all_gather_concat_nograd | Plain (non-differentiable) all-gather + concat along dim. |
| _get_compiled_flex_attention | - |
| cp_document_ids | Per-token document id from packed position ids (reset to 0 per document). |
| cp_load_balanced_global_slots | Global token-slot indices for PyTorch’s causal context-parallel load balancing. |
| order_by_positions | Reorder a CP-gathered tensor from load-balanced order into global token order. |
| restore_by_positions | Select rows of a global-ordered tensor back into an arbitrary (local) position order. |

Data

_COMPILED_FLEX_ATTENTION

API

class nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn.MiniMaxM3CPSparseAttention(
config: typing.Any,
backend: typing.Any,
is_sparse_attention_layer: bool = True
)

Bases: MiniMaxM3Attention

Context-parallel-aware drop-in for a MiniMax M3 sparse-attention layer.

Inherits every parameter and the eager forward from MiniMaxM3Attention. The only addition is _cp_mesh, installed post-FSDP via setup_cp_attention (called by the MoE parallelizer’s apply_cp). When CP is off (_cp_mesh is None or has size 1), the layer delegates to the parent’s eager sparse forward, so non-CP runs are unaffected.

Under CP (cp_size > 1) the sequence is sharded across ranks, so the DSA block selection — which is causal over the global sequence — cannot be built from a rank’s local shard. This forward instead:

  1. projects q/k/v + indexer q/k locally and applies QK-norm + RoPE locally (freqs_cis already encodes each token’s global position, so phases stay correct after gathering);
  2. all-gathers k/v (autograd-safe) and the indexer key + token positions across the CP group, then reorders them into global token order (load-balanced CP sharding is non-contiguous; see order_by_positions);
  3. selects the top-k key blocks for the local queries against the global key sequence (select_sparse_blocks);
  4. attends with FlexAttention over a BlockMask that encodes the block selection + token-level causal, with the local queries against the full gathered K/V (enable_gqa=True). FlexAttention has a real backward, so the gathered K/V gradients flow back to the local shards.
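The masking in step 4 can also be written densely, which is useful as a small test oracle for the FlexAttention path. The sketch below is a dense reference, not the module’s _flex_sparse_attention: it assumes hypothetical shapes (block_sel as a per-query boolean over key blocks, equal query and key head counts so no GQA) and ignores the padding and document masks (key_valid, doc_global, q_doc). It expands the block selection to per-key granularity, ANDs it with token-level causality on global positions, and passes the result as a boolean attn_mask to torch.nn.functional.scaled_dot_product_attention.

```python
import torch
import torch.nn.functional as F


def dense_block_sparse_reference(q, k_global, v_global, block_sel, q_positions, block_size):
    """Dense reference: local queries attend to global K/V under block selection + causal.

    Hypothetical shapes (assumptions, not the module's):
      q:           [B, H, Tq, D]  local queries (arbitrary global positions)
      k_global:    [B, H, S, D]   keys in global token order
      v_global:    [B, H, S, D]
      block_sel:   [B, H, Tq, S // block_size] bool, True = key block selected
      q_positions: [Tq] long, global position of each local query
    """
    S = k_global.size(-2)
    key_pos = torch.arange(S, device=q.device)
    selected = block_sel[..., key_pos // block_size]      # [B, H, Tq, S]
    causal = key_pos[None, :] <= q_positions[:, None]     # [Tq, S]
    mask = selected & causal                              # True = may attend
    return F.scaled_dot_product_attention(q, k_global, v_global, attn_mask=mask)


# Usage: rank 0 of cp_size=2 holds global positions 0, 1, 6, 7 (non-contiguous).
torch.manual_seed(0)
k, v = torch.randn(1, 2, 8, 4), torch.randn(1, 2, 8, 4)
q_pos = torch.tensor([0, 1, 6, 7])
sel = torch.ones(1, 2, 4, 4, dtype=torch.bool)  # every key block selected
out = dense_block_sparse_reference(torch.randn(1, 2, 4, 4), k, v, sel, q_pos, block_size=2)
print(out.shape)  # torch.Size([1, 2, 4, 4])
```

Each query row must keep at least one allowed key (for example its own block) or the softmax has nothing to normalize over.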

Dense layers (0-2) are untouched; they use the standard DTensor-SDPA CP path.

nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn.MiniMaxM3CPSparseAttention._cp_forward(
x: torch.Tensor,
freqs_cis: torch.Tensor,
attn_kwargs: typing.Any = {}
) -> torch.Tensor
nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn.MiniMaxM3CPSparseAttention._flex_sparse_attention(
q: torch.Tensor,
k_global: torch.Tensor,
v_global: torch.Tensor,
block_sel: torch.Tensor,
q_positions: torch.Tensor,
key_valid: torch.Tensor | None = None,
doc_global: torch.Tensor | None = None,
q_doc: torch.Tensor | None = None
) -> torch.Tensor
nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn.MiniMaxM3CPSparseAttention.forward(
x: torch.Tensor,
freqs_cis: torch.Tensor,
attention_mask: torch.Tensor | None = None,
attn_kwargs: typing.Any = {}
) -> torch.Tensor
nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn.MiniMaxM3CPSparseAttention.setup_cp_attention(
cp_mesh: typing.Any
) -> None

Install the CP submesh consumed by _cp_forward (model-owned CP).

Called post-FSDP by the MoE parallelizer’s apply_cp for each sparse layer. Routing M3 through this hook — rather than having apply_cp set _cp_mesh directly — keeps it on the same model-owned CP path as the other custom-attention models (Gemma4, DeepSeek-V4).
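A minimal sketch of the hook-and-dispatch pattern described above, assuming only that the mesh exposes size() (as torch.distributed.device_mesh.DeviceMesh does). SparseLayerSketch, install_cp_hooks and FakeMesh are hypothetical stand-ins, not the module’s classes or apply_cp itself; the real layer runs attention where this sketch only records which path it took.

```python
import torch
from torch import nn


class SparseLayerSketch(nn.Module):
    """Hypothetical stand-in for a sparse layer; shows only the hook + dispatch."""

    def __init__(self):
        super().__init__()
        self._cp_mesh = None  # stays None until the CP hook installs a mesh
        self.last_path = None

    def setup_cp_attention(self, cp_mesh):
        # Model-owned CP: the layer stores the submesh it will gather over.
        self._cp_mesh = cp_mesh

    def forward(self, x):
        if self._cp_mesh is None or self._cp_mesh.size() == 1:
            self.last_path = "eager"  # real layer: parent's eager sparse forward
        else:
            self.last_path = "cp"     # real layer: gather + reorder + sparse attention
        return x


def install_cp_hooks(model: nn.Module, cp_mesh) -> int:
    """Hypothetical apply_cp-style pass: call the hook on every module that has one."""
    installed = 0
    for module in model.modules():
        hook = getattr(module, "setup_cp_attention", None)
        if callable(hook):
            hook(cp_mesh)
            installed += 1
    return installed


class FakeMesh:
    """Stand-in for a 1-D DeviceMesh; only size() is used here."""

    def __init__(self, n: int):
        self.n = n

    def size(self) -> int:
        return self.n


model = nn.Sequential(SparseLayerSketch(), nn.Identity(), SparseLayerSketch())
model(torch.zeros(1))
print(model[0].last_path)                       # eager
print(install_cp_hooks(model, FakeMesh(2)))     # 2
model(torch.zeros(1))
print(model[2].last_path)                       # cp
```

Routing through a per-layer method keeps the parallelizer ignorant of layer internals: it only needs to know the hook name.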

class nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn._AllGatherConcatFn()

Bases: Function

All-gather + concat with an autograd-safe backward.

Forward concatenates equal-sized local shards from all CP ranks along dim. Backward all-reduces the concatenated gradient and slices out this rank’s shard. Mirrors qwen3_5_moe/cp_linear_attn.py’s helper.

nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn._AllGatherConcatFn.backward(
ctx,
grad_output: torch.Tensor
)
staticmethod
nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn._AllGatherConcatFn.forward(
ctx,
local_tensor: torch.Tensor,
group: 'dist.ProcessGroup',
dim: int
)
staticmethod
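The backward rule can be checked without a process group. Below is a hypothetical autograd.Function written against standard torch.distributed collectives (not the module’s _AllGatherConcatFn; it assumes equal-sized shards), followed by a single-process simulation of cp_size ranks. The simulation confirms that "all-reduce the concatenated gradient, then slice this rank’s shard" matches what autograd computes when every rank consumes its own copy of the full concatenation.

```python
import torch
import torch.distributed as dist


class AllGatherConcatSketch(torch.autograd.Function):
    """Hypothetical re-implementation for illustration; assumes equal-sized shards."""

    @staticmethod
    def forward(ctx, local_tensor, group, dim):
        local = local_tensor.contiguous()
        ctx.group, ctx.dim = group, dim
        ctx.rank = dist.get_rank(group)
        ctx.local_size = local.size(dim)
        shards = [torch.empty_like(local) for _ in range(dist.get_world_size(group))]
        dist.all_gather(shards, local, group=group)
        return torch.cat(shards, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        # Every rank consumed the full concatenation, so this shard's gradient is
        # the sum over ranks of its slice: all-reduce, then slice out the shard.
        grad = grad_output.contiguous().clone()
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=ctx.group)
        return grad.narrow(ctx.dim, ctx.rank * ctx.local_size, ctx.local_size), None, None


def simulate_backward_rule(cp_size=3, t_local=2, d=4, seed=0):
    """Single-process check of the backward rule; no process group needed."""
    torch.manual_seed(seed)
    shards = [torch.randn(t_local, d, requires_grad=True) for _ in range(cp_size)]
    # grad_output each simulated rank sends into its own copy of the concat
    grad_outs = [torch.randn(cp_size * t_local, d) for _ in range(cp_size)]
    total = sum((torch.cat(shards, dim=0) * g).sum() for g in grad_outs)
    total.backward()  # autograd's answer for every shard
    reduced = torch.stack(grad_outs).sum(dim=0)  # what all_reduce(SUM) produces
    manual = [reduced.narrow(0, r * t_local, t_local) for r in range(cp_size)]
    return [s.grad for s in shards], manual


auto, manual = simulate_backward_rule()
print(all(torch.allclose(a, m) for a, m in zip(auto, manual)))  # True
```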
nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn._all_gather_concat_nograd(
tensor: torch.Tensor,
group: 'dist.ProcessGroup',
dim: int
) -> torch.Tensor

Plain (non-differentiable) all-gather + concat along dim.

nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn._get_compiled_flex_attention()
nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn.cp_document_ids(
positions: torch.Tensor
) -> torch.Tensor

Per-token document id from packed position ids (reset to 0 per document).

doc_id = cumsum(positions == 0) - 1 along the sequence dim: a 0-based id that increments at every position 0 (a document start). For a single unpacked sequence every id is 0, so a same-document mask is all True and has no effect. A trailing cp-pad (also position 0) opens a spurious extra document, but pad keys and queries are already excluded by causality and the padding mask, so this is harmless.

Parameters:

positions
torch.Tensor

[B, T] long global-ordered position ids.

Returns: torch.Tensor

[B, T] long document ids.
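A minimal sketch of the formula above, plus one way the ids could be turned into a same-document mask. Both helper names are hypothetical; the module consumes document ids inside _flex_sparse_attention (via doc_global and q_doc), whose exact masking is not shown here.

```python
import torch


def document_ids_sketch(positions: torch.Tensor) -> torch.Tensor:
    """doc_id = cumsum(positions == 0) - 1 along the last (sequence) dim."""
    return (positions == 0).long().cumsum(dim=-1) - 1


def same_document_mask(doc: torch.Tensor) -> torch.Tensor:
    """[B, T, T] bool: True where token i and token j belong to the same document."""
    return doc.unsqueeze(-1) == doc.unsqueeze(-2)


packed = torch.tensor([[0, 1, 2, 0, 1, 0]])  # three packed documents (last could be a cp-pad)
print(document_ids_sketch(packed))           # tensor([[0, 0, 0, 1, 1, 2]])

single = torch.arange(5).unsqueeze(0)        # one unpacked sequence
print(same_document_mask(document_ids_sketch(single)).all())  # tensor(True)
```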

nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn.cp_load_balanced_global_slots(
cp_size: int,
t_local: int,
device: torch.device,
rank: int | None = None
) -> torch.Tensor

Global token-slot indices for PyTorch’s causal context-parallel load balancing.

Causal CP splits the (cp-padded) sequence into 2 * cp_size equal chunks and assigns rank r the pair {r, 2*cp_size-1-r} (concatenated in that order), so the local length is 2 * chunk. This reconstructs each local slot’s global index structurally — independent of position_ids values — which is robust to cp-padding (pad slots land at the global tail, where causality excludes them) and to the indexer’s pad position_id fill.

Parameters:

cp_size
int

context-parallel size.

t_local
int

local (per-rank) sequence length; must be even.

rank
int | None, defaults to None

if given, return the [t_local] slots for that CP rank; otherwise return the [cp_size * t_local] slots for the rank-major all-gathered concatenation (rank 0’s tokens, then rank 1’s, …).

Returns: torch.Tensor

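1-D long tensor of global slot indices. With rank=None this is a permutation of 0..T_global-1 (T_global = cp_size * t_local); with a rank given it is that rank’s t_local slots.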
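A minimal sketch of the slot layout described above: rank r owns chunks r and 2*cp_size-1-r, each of length t_local // 2, concatenated in that order. load_balanced_slots_sketch is a hypothetical name, and it omits the device argument for brevity.

```python
from typing import Optional

import torch


def load_balanced_slots_sketch(cp_size: int, t_local: int, rank: Optional[int] = None) -> torch.Tensor:
    if t_local % 2:
        raise ValueError("t_local must be even")
    chunk = t_local // 2

    def rank_slots(r: int) -> torch.Tensor:
        first = torch.arange(r * chunk, (r + 1) * chunk)          # chunk r
        mirror = 2 * cp_size - 1 - r                               # paired chunk from the tail
        second = torch.arange(mirror * chunk, (mirror + 1) * chunk)
        return torch.cat([first, second])

    if rank is not None:
        return rank_slots(rank)
    # Rank-major order, matching an all-gather+concat of the local shards.
    return torch.cat([rank_slots(r) for r in range(cp_size)])


print(load_balanced_slots_sketch(2, 4, rank=0).tolist())  # [0, 1, 6, 7]
print(load_balanced_slots_sketch(2, 4, rank=1).tolist())  # [2, 3, 4, 5]
print(load_balanced_slots_sketch(2, 4).tolist())          # [0, 1, 6, 7, 2, 3, 4, 5]
```

Pairing an early chunk with a late one balances causal work: each rank gets one cheap (short-context) and one expensive (long-context) chunk.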

nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn.order_by_positions(
gathered: torch.Tensor,
gathered_positions: torch.Tensor,
seq_dim: int
) -> tuple[torch.Tensor, torch.Tensor]

Reorder a CP-gathered tensor from load-balanced order into global token order.

Parameters:

gathered
torch.Tensor

tensor whose seq_dim concatenates every CP rank’s local shard in rank order (the output of an all-gather+concat).

gathered_positions
torch.Tensor

1-D global token positions aligned with gathered along seq_dim (gathered the same way). Must be a permutation of 0..S-1 where S = gathered.size(seq_dim).

seq_dim
int

the sequence dimension of gathered.

Returns: tuple[torch.Tensor, torch.Tensor]

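(global_tensor, sort_order), where global_tensor is gathered reordered along seq_dim into global token order (position p at index p) and sort_order is the index permutation applied along seq_dim.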

Raises:

  • ValueError: if gathered_positions is not a dense permutation of 0..S-1.
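A minimal sketch of this contract, assuming the reorder is an argsort of the gathered positions followed by index_select along seq_dim. order_by_positions_sketch is hypothetical and may differ from the module’s implementation in details such as error messages.

```python
import torch


def order_by_positions_sketch(gathered, gathered_positions, seq_dim):
    size = gathered.size(seq_dim)
    sort_order = torch.argsort(gathered_positions)
    # Dense-permutation check: the sorted positions must be exactly 0..S-1.
    expected = torch.arange(size, dtype=gathered_positions.dtype, device=gathered_positions.device)
    if gathered_positions.numel() != size or not torch.equal(gathered_positions[sort_order], expected):
        raise ValueError("gathered_positions must be a permutation of 0..S-1")
    return gathered.index_select(seq_dim, sort_order), sort_order


# Rank-major gather of cp_size=2, t_local=4: rank 0 then rank 1.
positions = torch.tensor([0, 1, 6, 7, 2, 3, 4, 5])
gathered = positions.float().view(1, -1)  # [B=1, S=8], value = global position
global_tensor, sort_order = order_by_positions_sketch(gathered, positions, seq_dim=1)
print(global_tensor.tolist())  # [[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]]
print(sort_order.tolist())     # [0, 1, 4, 5, 6, 7, 2, 3]
```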
nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn.restore_by_positions(
global_tensor: torch.Tensor,
target_positions: torch.Tensor,
seq_dim: int
) -> torch.Tensor

Select rows of a global-ordered tensor back into an arbitrary (local) position order.

Inverse companion to order_by_positions. Given a tensor indexed by global position along seq_dim (position p at index p), return the slice in target_positions order, e.g. this rank’s load-balanced local positions, recovering the CP-sharded layout.

Parameters:

global_tensor
torch.Tensor

tensor indexed by global position along seq_dim.

target_positions
torch.Tensor

1-D positions to select, in the desired output order.

seq_dim
int

the sequence dimension of global_tensor.

Returns: torch.Tensor

global_tensor gathered along seq_dim at target_positions.
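Because position p sits at index p in a global-ordered tensor, selecting target_positions with index_select is enough. The sketch below (restore_by_positions_sketch is a hypothetical name) recovers rank 0’s non-contiguous shard for cp_size = 2, t_local = 4.

```python
import torch


def restore_by_positions_sketch(global_tensor, target_positions, seq_dim):
    # Row p of a global-ordered tensor holds position p, so a plain select suffices.
    return global_tensor.index_select(seq_dim, target_positions)


global_hidden = torch.arange(8).float().view(1, 8, 1) * 10  # [B, S, D], value = 10 * position
rank0_positions = torch.tensor([0, 1, 6, 7])                # rank 0's load-balanced slots
local = restore_by_positions_sketch(global_hidden, rank0_positions, seq_dim=1)
print(local.flatten().tolist())  # [0.0, 10.0, 60.0, 70.0]
```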

nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn._COMPILED_FLEX_ATTENTION = None