
# nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn

Context-parallel support for MiniMax M3 block-sparse DSA attention.

Under context parallelism the sequence is sharded across CP ranks with a
*load-balanced* layout (PyTorch's causal CP splits the sequence into
`2 * cp_size` chunks and assigns rank `r` the pair `{r, 2*cp_size-1-r}`),
so a rank's local positions are **not** a contiguous global span. The M3
lightning indexer builds its block-sparse mask from index q/k over the *global*
causal sequence, so a CP-aware sparse layer must gather the indexer inputs from
every rank and reorder them into global token order before selecting blocks.

This module holds the reorder primitives shared by the CP-aware attention. The
reorder math (`order_by_positions` / `restore_by_positions`) is factored out
as pure tensor functions so that the load-balance inverse can be unit-tested on
CPU without a process group. That inverse is a silent-failure trap: a wrong
inverse trains without shape errors but never converges.

## Module Contents

### Classes

| Name                                                                                                                      | Description                                                             |
| ------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------- |
| [`MiniMaxM3CPSparseAttention`](#nemo_automodel-components-models-minimax_m3_vl-cp_sparse_attn-MiniMaxM3CPSparseAttention) | Context-parallel-aware drop-in for a MiniMax M3 sparse-attention layer. |
| [`_AllGatherConcatFn`](#nemo_automodel-components-models-minimax_m3_vl-cp_sparse_attn-_AllGatherConcatFn)                 | All-gather + concat with an autograd-safe backward.                     |

### Functions

| Name                                                                                                                            | Description                                                                           |
| ------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------- |
| [`_all_gather_concat_nograd`](#nemo_automodel-components-models-minimax_m3_vl-cp_sparse_attn-_all_gather_concat_nograd)         | Plain (non-differentiable) all-gather + concat along `dim`.                           |
| [`_get_compiled_flex_attention`](#nemo_automodel-components-models-minimax_m3_vl-cp_sparse_attn-_get_compiled_flex_attention)   | -                                                                                     |
| [`cp_document_ids`](#nemo_automodel-components-models-minimax_m3_vl-cp_sparse_attn-cp_document_ids)                             | Per-token document id from packed position ids (reset to 0 per document).             |
| [`cp_load_balanced_global_slots`](#nemo_automodel-components-models-minimax_m3_vl-cp_sparse_attn-cp_load_balanced_global_slots) | Global token-slot indices for PyTorch's causal context-parallel load balancing.       |
| [`order_by_positions`](#nemo_automodel-components-models-minimax_m3_vl-cp_sparse_attn-order_by_positions)                       | Reorder a CP-gathered tensor from load-balanced order into global token order.        |
| [`restore_by_positions`](#nemo_automodel-components-models-minimax_m3_vl-cp_sparse_attn-restore_by_positions)                   | Select rows of a global-ordered tensor back into an arbitrary (local) position order. |

### Data

[`_COMPILED_FLEX_ATTENTION`](#nemo_automodel-components-models-minimax_m3_vl-cp_sparse_attn-_COMPILED_FLEX_ATTENTION)

### API

```python
class nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn.MiniMaxM3CPSparseAttention(
    config: typing.Any,
    backend: typing.Any,
    is_sparse_attention_layer: bool = True
)
```

**Bases:** [MiniMaxM3Attention](/nemo-automodel/nemo_automodel/components/models/minimax_m3_vl/layers#nemo_automodel-components-models-minimax_m3_vl-layers-MiniMaxM3Attention)

Context-parallel-aware drop-in for a MiniMax M3 sparse-attention layer.

Inherits every parameter and the eager forward from `MiniMaxM3Attention`.
The only addition is `_cp_mesh`, installed post-FSDP via
`setup_cp_attention` (called by the MoE parallelizer's `apply_cp`).
When CP is off (`_cp_mesh` is `None` or has size 1), it delegates to the
parent's eager sparse forward, so non-CP runs are unaffected.

Under CP (`cp_size > 1`) the sequence is sharded across ranks, so the DSA
block selection, which is causal over the *global* sequence, cannot be
built from a rank's local shard. This forward instead:

1. projects q/k/v + indexer q/k locally and applies QK-norm + RoPE locally
   (`freqs_cis` already encodes each token's global position, so phases
   stay correct after gathering);
2. all-gathers k/v (autograd-safe) and the indexer key + token positions
   across the CP group, then reorders them into global token order
   (load-balanced CP sharding is non-contiguous; see
   `order_by_positions`);
3. selects the top-k key blocks for the *local* queries against the global
   key sequence (`select_sparse_blocks`);
4. attends with FlexAttention over a `BlockMask` that encodes the block
   selection + token-level causal, with the local queries against the full
   gathered K/V (`enable_gqa=True`). FlexAttention has a real backward,
   so the gathered K/V gradients flow back to the local shards.

Dense layers (0-2) are untouched; they use the standard DTensor-SDPA CP path.

```python
nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn.MiniMaxM3CPSparseAttention._cp_forward(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
    attn_kwargs: typing.Any = {}
) -> torch.Tensor
```

```python
nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn.MiniMaxM3CPSparseAttention._flex_sparse_attention(
    q: torch.Tensor,
    k_global: torch.Tensor,
    v_global: torch.Tensor,
    block_sel: torch.Tensor,
    q_positions: torch.Tensor,
    key_valid: torch.Tensor | None = None,
    doc_global: torch.Tensor | None = None,
    q_doc: torch.Tensor | None = None
) -> torch.Tensor
```

```python
nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn.MiniMaxM3CPSparseAttention.forward(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    attn_kwargs: typing.Any = {}
) -> torch.Tensor
```

```python
nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn.MiniMaxM3CPSparseAttention.setup_cp_attention(
    cp_mesh: typing.Any
) -> None
```

Install the CP submesh consumed by `_cp_forward` (model-owned CP).

Called post-FSDP by the MoE parallelizer's `apply_cp` for each sparse
layer. Routing M3 through this hook -- rather than having `apply_cp` set
`_cp_mesh` directly -- keeps it on the same model-owned CP path as the
other custom-attention models (Gemma4, DeepSeek-V4).

```python
class nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn._AllGatherConcatFn()
```

**Bases:** `Function`

All-gather + concat with an autograd-safe backward.

Forward concatenates equal-sized local shards from all CP ranks along
`dim`. Backward all-reduces the concatenated gradient and slices out this
rank's shard. Mirrors `qwen3_5_moe/cp_linear_attn.py`'s helper.

```python
nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn._AllGatherConcatFn.backward(
    ctx,
    grad_output: torch.Tensor
)
```

staticmethod

```python
nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn._AllGatherConcatFn.forward(
    ctx,
    local_tensor: torch.Tensor,
    group: 'dist.ProcessGroup',
    dim: int
)
```

staticmethod
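The backward rule of `_AllGatherConcatFn` can be checked on one process. In this sketch a Python list stands in for the `cp_size` ranks: each simulated rank produces its own upstream gradient with respect to the same concatenated tensor, and the gradient of shard `r` is the sum of the `r`-th slices of all of them, which is exactly what "all-reduce, then slice out this rank's shard" computes. The helper `simulated_allgather_concat_backward` is hypothetical and makes no `torch.distributed` calls; the last lines compare it against autograd.

```python
import torch


def simulated_allgather_concat_backward(grads_per_rank, dim):
    """Per-shard gradients, given every simulated rank's upstream gradient
    w.r.t. the same concatenated tensor (equal-sized shards assumed)."""
    cp_size = len(grads_per_rank)
    # all-reduce (SUM): every rank ends up holding the sum of all upstream gradients
    reduced = torch.stack(grads_per_rank).sum(dim=0)
    # each rank keeps only the slice that corresponds to its own shard
    return list(torch.chunk(reduced, cp_size, dim=dim))


# Check against autograd: shards are leaves, and each "rank" computes its own
# loss from the shared concatenation.
torch.manual_seed(0)
cp_size, dim = 2, 1
shards = [torch.randn(1, 3, 4, requires_grad=True) for _ in range(cp_size)]
gathered = torch.cat(shards, dim=dim)
weights = [torch.randn_like(gathered) for _ in range(cp_size)]
# The gradient of rank r's loss (weights[r] * gathered).sum() w.r.t. gathered is weights[r].
expected = torch.autograd.grad(sum((w * gathered).sum() for w in weights), shards)
got = simulated_allgather_concat_backward(weights, dim)
assert all(torch.allclose(e, g) for e, g in zip(expected, got))
```

A plain all-gather has no gradient path back to the other ranks' shards, which is why the module keeps the non-differentiable `_all_gather_concat_nograd` separate from this autograd-aware version.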

```python
nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn._all_gather_concat_nograd(
    tensor: torch.Tensor,
    group: 'dist.ProcessGroup',
    dim: int
) -> torch.Tensor
```

Plain (non-differentiable) all-gather + concat along `dim`.

```python
nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn._get_compiled_flex_attention()
```

```python
nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn.cp_document_ids(
    positions: torch.Tensor
) -> torch.Tensor
```

Per-token document id from packed position ids (reset to 0 per document).

`doc_id = cumsum(positions == 0) - 1` along the sequence dim: a 0-based id
that increments at every position 0 (a document start). For a single unpacked
sequence the ids are all zeros, so a same-document mask is all-True (a no-op).
A trailing cp-pad (also position 0) opens a spurious extra document, but pad
keys and queries are already excluded by causality and the padding mask, so it
is harmless.

**Parameters:**

* `positions`: `[B, T]` long global-ordered position ids.

**Returns:** `torch.Tensor`

`[B, T]` long document ids.
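A minimal sketch of the formula above, using a hypothetical helper name and assuming `positions` is a `[B, T]` long tensor in global order. The usage lines also show how such ids would typically become a same-document mask:

```python
import torch


def document_ids_from_positions(positions: torch.Tensor) -> torch.Tensor:
    """[B, T] long position ids (reset to 0 at each document start) -> [B, T] doc ids."""
    # 1 at every document start, 0 elsewhere; the running count numbers the documents.
    starts = (positions == 0).long()
    return torch.cumsum(starts, dim=-1) - 1


pos = torch.tensor([[0, 1, 2, 0, 1, 0]])        # two documents plus a trailing cp-pad slot
doc = document_ids_from_positions(pos)          # tensor([[0, 0, 0, 1, 1, 2]])
same_doc = doc[:, :, None] == doc[:, None, :]   # [B, T, T] same-document mask
```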

```python
nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn.cp_load_balanced_global_slots(
    cp_size: int,
    t_local: int,
    device: torch.device,
    rank: int | None = None
) -> torch.Tensor
```

Global token-slot indices for PyTorch's causal context-parallel load balancing.

Causal CP splits the (cp-padded) sequence into `2 * cp_size` equal chunks and
assigns rank `r` the pair `&#123;r, 2*cp_size-1-r&#125;` (concatenated in that order),
so the local length is `2 * chunk`. This reconstructs each local slot's global
index *structurally* -- independent of `position_ids` values -- which is robust
to cp-padding (pad slots land at the global tail, where causality excludes them)
and to the indexer's pad `position_id` fill.

**Parameters:**

* `cp_size`: context-parallel size.
* `t_local`: local (per-rank) sequence length; must be even.
* `device`: device on which the returned tensor is allocated.
* `rank`: if given, return the `[t_local]` slots for that CP rank; otherwise
  return the `[cp_size * t_local]` slots for the rank-major all-gathered
  concatenation (rank 0's tokens, then rank 1's, ...).

**Returns:** `torch.Tensor`

1-D long tensor of global slot indices. When `rank` is None, this is a
permutation of `0..T_global-1`, where `T_global = cp_size * t_local`.

```python
nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn.order_by_positions(
    gathered: torch.Tensor,
    gathered_positions: torch.Tensor,
    seq_dim: int
) -> tuple[torch.Tensor, torch.Tensor]
```

Reorder a CP-gathered tensor from load-balanced order into global token order.

**Parameters:**

* `gathered`: tensor whose `seq_dim` concatenates every CP rank's local shard
  in rank order (the output of an all-gather + concat).
* `gathered_positions`: 1-D global token positions aligned with `gathered`
  along `seq_dim` (gathered the same way). Must be a permutation of
  `0..S-1`, where `S = gathered.size(seq_dim)`.
* `seq_dim`: the sequence dimension of `gathered`.

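**Returns:** `tuple[torch.Tensor, torch.Tensor]`

`(global_tensor, sort_order)`, where `global_tensor` is `gathered` reordered
along `seq_dim` into global token order and `sort_order` is the index
permutation used to reorder it.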

**Raises:**

* `ValueError`: if `gathered_positions` is not a dense permutation of 0..S-1.
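One way to implement this reorder, sketched under the assumption that `gathered_positions` is a 1-D integer tensor: because the positions form a permutation, `argsort` yields its inverse, and `index_select` with that order puts position `p` at index `p`. The permutation check mirrors the documented `ValueError`. The name `order_by_positions_sketch` is hypothetical.

```python
import torch


def order_by_positions_sketch(gathered: torch.Tensor, gathered_positions: torch.Tensor, seq_dim: int):
    s = gathered.size(seq_dim)
    positions = gathered_positions.long()
    # For a dense permutation of 0..S-1, sorting the positions yields exactly arange(S).
    sort_order = torch.argsort(positions)
    expected = torch.arange(s, device=positions.device)
    if positions.numel() != s or not torch.equal(positions[sort_order], expected):
        raise ValueError("gathered_positions must be a dense permutation of 0..S-1")
    # Index j of the output holds the row whose global position is j.
    return gathered.index_select(seq_dim, sort_order), sort_order
```

For a valid permutation a scatter (writing each row to index `positions[i]`) would be equivalent. Returning `sort_order` as well plausibly lets a caller reorder other tensors that were gathered the same way without recomputing the sort.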

```python
nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn.restore_by_positions(
    global_tensor: torch.Tensor,
    target_positions: torch.Tensor,
    seq_dim: int
) -> torch.Tensor
```

Select rows of a global-ordered tensor back into an arbitrary (local) position order.

Inverse companion to :func:`order_by_positions`. Given a tensor indexed by
global position along `seq_dim` (position `p` at index `p`), return the
slice in `target_positions` order -- e.g. this rank's load-balanced local
positions, recovering the CP-sharded layout.

**Parameters:**

* `global_tensor`: tensor indexed by global position along `seq_dim`.
* `target_positions`: 1-D positions to select, in the desired output order.
* `seq_dim`: the sequence dimension of `global_tensor`.

**Returns:** `torch.Tensor`

`global_tensor` gathered along `seq_dim` at `target_positions`.
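Because a global-ordered tensor stores position `p` at index `p`, this direction reduces to an `index_select`. Below is a minimal sketch with a hypothetical name, using rank 0's slots `[0, 1, 6, 7]` for `cp_size=2`, `t_local=4` (rank 0 owns chunks 0 and 3 of four two-token chunks):

```python
import torch


def restore_by_positions_sketch(global_tensor: torch.Tensor, target_positions: torch.Tensor, seq_dim: int) -> torch.Tensor:
    # Global order means position p sits at index p, so selecting positions is a plain index_select.
    return global_tensor.index_select(seq_dim, target_positions.long())


# A [B=1, T=8, D=2] tensor in global order: the row at position p holds [2p, 2p+1].
global_x = torch.arange(16).view(1, 8, 2)
rank0_positions = torch.tensor([0, 1, 6, 7])  # rank 0's load-balanced slots
local_x = restore_by_positions_sketch(global_x, rank0_positions, seq_dim=1)
```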

```python
nemo_automodel.components.models.minimax_m3_vl.cp_sparse_attn._COMPILED_FLEX_ATTENTION = None
```