nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn
Context-parallel support for MiniMax M3 block-sparse DSA attention.
Under context parallelism the sequence is sharded across CP ranks with a
load-balanced layout (PyTorch’s causal CP splits the sequence into
2 * cp_size chunks and assigns rank r the pair {r, 2*cp_size-1-r}),
so a rank’s local positions are not a contiguous global span. The M3
lightning indexer builds its block-sparse mask from index q/k over the global
causal sequence, so a CP-aware sparse layer must gather the indexer inputs from
every rank and reorder them into global token order before selecting blocks.
This module holds the reorder primitives shared by the CP-aware attention. The
reorder math (order_by_positions / restore_by_positions) is factored out
as pure tensor functions so the load-balance inverse — a silent-failure trap: a
wrong inverse trains without shape errors but never converges — is unit-testable
on CPU without a process group.
Module Contents
Classes
Functions
Data
API
Bases: MiniMaxM3Attention
Context-parallel-aware drop-in for a MiniMax M3 sparse-attention layer.
Inherits every parameter and the eager forward from MiniMaxM3Attention.
The only addition is _cp_mesh, installed post-FSDP via
:meth:setup_cp_attention (called by the MoE parallelizer’s apply_cp).
When CP is off (_cp_mesh is None / size 1) it delegates to the parent’s
eager sparse forward, so non-CP runs are unaffected.
Under CP (cp_size > 1) the sequence is sharded across ranks, so the DSA
block selection — which is causal over the global sequence — cannot be
built from a rank’s local shard. This forward instead:
- projects q/k/v + indexer q/k locally and applies QK-norm + RoPE locally
  (freqs_cis already encodes each token’s global position, so phases stay
  correct after gathering);
- all-gathers k/v (autograd-safe) and the indexer key + token positions
  across the CP group, then reorders them into global token order
  (load-balanced CP sharding is non-contiguous; see :func:order_by_positions);
- selects the top-k key blocks for the local queries against the global
  key sequence (:func:select_sparse_blocks);
- attends with FlexAttention over a BlockMask that encodes the block
  selection + token-level causal, with the local queries against the full
  gathered K/V (enable_gqa=True). FlexAttention has a real backward, so the
  gathered K/V gradients flow back to the local shards.
Dense layers (0-2) are untouched; they use the standard DTensor-SDPA CP path.
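The block-selection step can be illustrated with a small single-process sketch. This is not the module's select_sparse_blocks: the function name, the [Tq, n_blocks] score layout, and the visibility rule (a key block becomes eligible once its first key sits at or before the query's global position) are assumptions made for illustration only.

```python
import torch


def select_blocks_sketch(block_scores, q_positions, block_size, top_k):
    """Hypothetical top-k key-block selection with a block-level causal filter.

    block_scores: [Tq, n_blocks] indexer scores (assumed layout).
    q_positions:  [Tq] global positions of the local queries.
    Returns a [Tq, n_blocks] bool mask of the selected key blocks.
    """
    n_blocks = block_scores.size(-1)
    # Global position of the first key in each block.
    block_start = torch.arange(n_blocks) * block_size
    # A block is visible to a query if it starts at or before the query.
    visible = block_start.unsqueeze(0) <= q_positions.unsqueeze(1)
    # Future blocks can never win the top-k.
    masked = block_scores.masked_fill(~visible, float("-inf"))
    idx = masked.topk(min(top_k, n_blocks), dim=-1).indices
    picked = torch.zeros(block_scores.shape, dtype=torch.long)
    picked.scatter_(-1, idx, 1)
    # Drop any -inf filler picks made when fewer than top_k blocks are visible.
    return picked.bool() & visible


scores = torch.tensor([[0.3, 0.2, 0.1],
                       [0.1, 0.8, 0.9],
                       [0.1, 0.9, 0.5]])
q_pos = torch.tensor([0, 5, 9])  # global positions, not local indices
selected = select_blocks_sketch(scores, q_pos, block_size=4, top_k=2)
```

The query positions here are global, which is the reason the indexer inputs have to be gathered and reordered first: a rank's local index says nothing about which key blocks are causally visible.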
Install the CP submesh consumed by :meth:_cp_forward (model-owned CP).
Called post-FSDP by the MoE parallelizer’s apply_cp for each sparse
layer. Routing M3 through this hook — rather than having apply_cp set
_cp_mesh directly — keeps it on the same model-owned CP path as the
other custom-attention models (Gemma4, DeepSeek-V4).
Bases: Function
All-gather + concat with an autograd-safe backward.
Forward concatenates equal-sized local shards from all CP ranks along
dim. Backward all-reduces the concatenated gradient and slices out this
rank’s shard. Mirrors qwen3_5_moe/cp_linear_attn.py’s helper.
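The "all-reduce, then slice" backward can be checked without a process group by simulating the ranks in one process. In the sketch below the helper name simulated_backward and the per-rank loss weights are hypothetical; each simulated rank computes its own loss on the concatenated tensor, so a shard's true gradient is the sum of every rank's gradient over that shard.

```python
import torch

torch.manual_seed(0)
world, shard = 2, 3

# One local shard per simulated rank, plus per-rank loss weights.
shards = [torch.randn(shard, requires_grad=True) for _ in range(world)]
weights = [torch.randn(world * shard) for _ in range(world)]

# Reference: let autograd differentiate through the concatenation.
gathered = torch.cat(shards, dim=0)
loss = sum((gathered * w).sum() for w in weights)
loss.backward()


def simulated_backward(grad_outputs, rank, dim=0):
    """Sum the per-rank gradients (the all-reduce), then slice this rank's shard."""
    total = torch.stack(grad_outputs).sum(dim=0)
    size = total.size(dim) // len(grad_outputs)
    return total.narrow(dim, rank * size, size)


# d(loss_r)/d(gathered) is w_r, i.e. the gradient each rank would hold locally.
manual = [simulated_backward(weights, r) for r in range(world)]
matches = all(torch.allclose(manual[r], shards[r].grad) for r in range(world))
```

Slicing without the reduction would keep only the contribution of the rank's own loss and silently drop the gradient that other ranks' queries send to this rank's keys and values.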
Plain (non-differentiable) all-gather + concat along dim.
Per-token document id from packed position ids (reset to 0 per document).
doc_id = cumsum(positions == 0) - 1 along the sequence dim: a 0-based id
that increments at every position-0 (document start). A single sequence -> all
zeros (so a same-document mask is all-True, a no-op). A trailing cp-pad (also
position 0) opens a spurious extra document, but pad keys/queries are excluded
by causality / the padding mask, so it is harmless.
Parameters:
[B, T] long global-ordered position ids.
Returns: torch.Tensor
[B, T] long document ids.
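A minimal sketch of the formula above, using a hypothetical helper name (document_ids); the same-document mask at the end is one illustrative way such ids could be consumed.

```python
import torch


def document_ids(positions):
    """doc_id = cumsum(positions == 0) - 1 along the sequence dim."""
    return (positions == 0).long().cumsum(dim=-1) - 1


# Three packed documents of lengths 3, 2 and 4.
positions = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])
doc_ids = document_ids(positions)

# A single unpacked sequence yields all zeros.
single = document_ids(torch.arange(5).unsqueeze(0))

# [B, T, T] mask that is True where query and key share a document.
same_doc = doc_ids.unsqueeze(-1) == doc_ids.unsqueeze(-2)
```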
Global token-slot indices for PyTorch’s causal context-parallel load balancing.
Causal CP splits the (cp-padded) sequence into 2 * cp_size equal chunks and
assigns rank r the pair {r, 2*cp_size-1-r} (concatenated in that order),
so the local length is 2 * chunk. This reconstructs each local slot’s global
index structurally — independent of position_ids values — which is robust
to cp-padding (pad slots land at the global tail, where causality excludes them)
and to the indexer’s pad position_id fill.
Parameters:
context-parallel size.
local (per-rank) sequence length; must be even.
if given, return the [t_local] slots for that CP rank; otherwise
return the [cp_size * t_local] slots for the rank-major all-gathered
concatenation (rank 0’s tokens, then rank 1’s, …).
Returns: torch.Tensor
1-D long tensor of global slot indices (a permutation of 0..T_global-1).
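The chunk-pair layout described above is plain integer arithmetic, so it can be sketched with Python lists. The function and argument names below are hypothetical, and the sketch returns a list rather than a long tensor.

```python
def load_balanced_slots(cp_size, t_local, rank=None):
    """Global slot indices for the {r, 2*cp_size-1-r} chunk-pair layout."""
    if t_local % 2 != 0:
        raise ValueError("t_local must be even")
    chunk = t_local // 2

    def slots_for(r):
        # Rank r holds chunk r followed by its mirror chunk from the tail.
        mirror = 2 * cp_size - 1 - r
        front = range(r * chunk, (r + 1) * chunk)
        back = range(mirror * chunk, (mirror + 1) * chunk)
        return list(front) + list(back)

    if rank is not None:
        return slots_for(rank)
    # Rank-major concatenation: rank 0's slots, then rank 1's, and so on.
    return [s for r in range(cp_size) for s in slots_for(r)]


rank0 = load_balanced_slots(cp_size=2, t_local=4, rank=0)
rank1 = load_balanced_slots(cp_size=2, t_local=4, rank=1)
all_slots = load_balanced_slots(cp_size=2, t_local=4)
```

With cp_size=2 and t_local=4 the global sequence has 8 slots in 4 chunks of 2: rank 0 holds chunks 0 and 3, rank 1 holds chunks 1 and 2.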
Reorder a CP-gathered tensor from load-balanced order into global token order.
Parameters:
tensor whose seq_dim concatenates every CP rank’s local shard
in rank order (the output of an all-gather+concat).
1-D global token positions aligned with gathered
along seq_dim (gathered the same way). Must be a permutation of
0..S-1 where S = gathered.size(seq_dim).
the sequence dimension of gathered.
Returns: torch.Tensor
(global_tensor, sort_order) where global_tensor is gathered
Raises:
ValueError: if gathered_positions is not a dense permutation of 0..S-1.
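One way the reorder could be written as a pure tensor function is an argsort over the gathered positions followed by an index_select. The sketch below is an assumption about the approach, not the module's implementation; the function name and the default seq_dim are illustrative.

```python
import torch


def order_by_positions_sketch(gathered, gathered_positions, seq_dim=1):
    """Reorder a rank-major gathered tensor into global token order."""
    s = gathered.size(seq_dim)
    if gathered_positions.numel() != s:
        raise ValueError("positions must align with gathered along seq_dim")
    sort_order = torch.argsort(gathered_positions)
    expected = torch.arange(s, dtype=gathered_positions.dtype)
    # Sorted positions must be exactly 0..S-1 (no gaps, no duplicates).
    if not torch.equal(gathered_positions[sort_order], expected):
        raise ValueError("gathered_positions is not a dense permutation of 0..S-1")
    return gathered.index_select(seq_dim, sort_order), sort_order


# Rank-major layout for cp_size=2: rank 0 holds slots 0,1,6,7; rank 1 holds 2..5.
positions = torch.tensor([0, 1, 6, 7, 2, 3, 4, 5])
gathered = (positions.float() * 10).view(1, 8, 1)  # value encodes the position
global_tensor, sort_order = order_by_positions_sketch(gathered, positions)
```

Because the check is on values rather than shapes, a duplicated or missing position is rejected up front instead of producing a correctly shaped but misordered tensor.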
Select rows of a global-ordered tensor back into an arbitrary (local) position order.
Inverse companion to :func:order_by_positions. Given a tensor indexed by
global position along seq_dim (position p at index p), return the
slice in target_positions order — e.g. this rank’s load-balanced local
positions, recovering the CP-sharded layout.
Parameters:
tensor indexed by global position along seq_dim.
1-D positions to select, in the desired output order.
the sequence dimension of global_tensor.
Returns: torch.Tensor
global_tensor gathered along seq_dim at target_positions.
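The selection described here amounts to an index_select along the sequence dimension. The sketch below uses a hypothetical function name and a cp_size=2 layout to show the round trip with the argsort-based reorder, and also why a wrong inverse is easy to miss: it returns a tensor of the right shape with the wrong contents.

```python
import torch


def restore_by_positions_sketch(global_tensor, target_positions, seq_dim=1):
    """Pick rows of a global-ordered tensor in target_positions order."""
    return global_tensor.index_select(seq_dim, target_positions)


# Global-ordered tensor: index p along dim 1 holds position p.
x = torch.arange(16.0).view(1, 8, 2)

# Rank-major load-balanced order for cp_size=2 (rank 0: 0,1,6,7; rank 1: 2..5).
lb = torch.tensor([0, 1, 6, 7, 2, 3, 4, 5])
sharded = restore_by_positions_sketch(x, lb)

# Correct inverse: argsort of the positions restores global order.
round_trip = sharded.index_select(1, torch.argsort(lb))

# Wrong inverse: reusing the forward permutation. Same shape, wrong rows.
wrong = sharded.index_select(1, lb)

# A single rank's local view, here rank 1.
rank1_local = restore_by_positions_sketch(x, torch.tensor([2, 3, 4, 5]))
```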