
# nemo_automodel.components.distributed.mamba_cp

Context parallelism for Mamba/SSM layers using a hidden-parallel strategy.

Instead of splitting the sequence across CP ranks (as attention CP does), this module
uses an all-to-all redistribution so that each CP rank processes the *full* sequence
but only a *subset* of heads (`d_inner / cp_size`). The data flow is:

`[B, L_local, D]  -->  all-to-all  -->  [B, L_global, D/cp]  -->  conv1d + SSM kernel  -->  [B, L_global, D/cp]  -->  all-to-all  -->  [B, L_local, D]`

This module is intentionally **not** a subclass of `nn.Module` because it owns
no trainable parameters.  It holds *references* to the Mamba mixer's parameters
and slices them in the forward path so that gradients flow back to the full
(unsliced) parameters.
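
The shape bookkeeping behind this data flow can be sketched in a few lines of plain Python. This is only an illustration: the helper name `hidden_parallel_shapes` and its arguments are hypothetical and not part of the module, and it assumes both the global sequence length and the hidden size divide evenly by the CP size.

```python
def hidden_parallel_shapes(batch, seq_global, hidden, cp_size):
    """Return the per-rank shape at each stage of the hidden-parallel data flow."""
    if seq_global % cp_size or hidden % cp_size:
        raise ValueError("seq_global and hidden must be divisible by cp_size")
    seq_local = seq_global // cp_size
    hidden_local = hidden // cp_size
    return [
        ("input (sequence-sharded)", (batch, seq_local, hidden)),
        ("after first all-to-all (hidden-sharded)", (batch, seq_global, hidden_local)),
        ("after conv1d + SSM kernel", (batch, seq_global, hidden_local)),
        ("after second all-to-all (sequence-sharded)", (batch, seq_local, hidden)),
    ]


stages = hidden_parallel_shapes(batch=2, seq_global=4096, hidden=8192, cp_size=4)
for name, shape in stages:
    print(f"{name:45s} {shape}")
```

The number of elements each rank holds is the same at every stage; only the axis that is sharded changes.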

## Module Contents

### Classes

| Name                                                                                           | Description                                                    |
| ---------------------------------------------------------------------------------------------- | -------------------------------------------------------------- |
| [`MambaContextParallel`](#nemo_automodel-components-distributed-mamba_cp-MambaContextParallel) | Hidden-parallel context parallelism for a Mamba2 mixer layer.  |
| [`_AllToAll`](#nemo_automodel-components-distributed-mamba_cp-_AllToAll)                       | Autograd wrapper around `torch.distributed.all_to_all_single`. |

### Functions

| Name                                                                                                               | Description                                                                |
| ------------------------------------------------------------------------------------------------------------------ | -------------------------------------------------------------------------- |
| [`_all_to_all`](#nemo_automodel-components-distributed-mamba_cp-_all_to_all)                                       | Functional entry-point for the autograd-aware all-to-all.                  |
| [`_all_to_all_cp2hp`](#nemo_automodel-components-distributed-mamba_cp-_all_to_all_cp2hp)                           | Transform from sequence-sharded to hidden-sharded layout (batch-first).    |
| [`_all_to_all_hp2cp`](#nemo_automodel-components-distributed-mamba_cp-_all_to_all_hp2cp)                           | Transform from hidden-sharded to sequence-sharded layout (batch-first).    |
| [`_deinterleave_packed_seqs`](#nemo_automodel-components-distributed-mamba_cp-_deinterleave_packed_seqs)           | Rearrange tokens from rank-major to sequence-major order after all-to-all. |
| [`_redo_attention_load_balancing`](#nemo_automodel-components-distributed-mamba_cp-_redo_attention_load_balancing) | Reorder from sequential back to DualChunkSwap for attention.               |
| [`_reinterleave_packed_seqs`](#nemo_automodel-components-distributed-mamba_cp-_reinterleave_packed_seqs)           | Inverse of `_deinterleave_packed_seqs`.                                    |
| [`_reorder_chunks`](#nemo_automodel-components-distributed-mamba_cp-_reorder_chunks)                               | Reorder equal-sized chunks of a tensor according to *order*.               |
| [`_undo_attention_load_balancing`](#nemo_automodel-components-distributed-mamba_cp-_undo_attention_load_balancing) | Reorder from DualChunkSwap to sequential for SSM processing.               |

### API

```python
class nemo_automodel.components.distributed.mamba_cp.MambaContextParallel(
    cp_group: torch.distributed.ProcessGroup,
    num_heads: int,
    head_dim: int,
    n_groups: int,
    d_state: int,
    mixer: torch.nn.Module
)
```

Hidden-parallel context parallelism for a Mamba2 mixer layer.

This class does **not** own trainable parameters. It stores a *reference*
to the mixer module and accesses its parameters (`conv1d`, `dt_bias`, `A_log`, `D`)
on the fly so that gradients propagate to the original (full) parameters
and FSDP-managed DTensor replacements are picked up correctly.

DualChunkSwap reordering is always undone before the SSM kernel and redone
after, because both TE CP (p2p) and PyTorch's `context_parallel(allgather)`
reorder sequence chunks for load balancing.
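
The reordering can be illustrated with chunk indices alone. The sketch below assumes the common load-balancing scheme in which the sequence is cut into `2 * cp_size` chunks and rank `r` holds chunks `r` and `2 * cp_size - 1 - r`; that layout is an assumption about DualChunkSwap, and all function names here are hypothetical rather than the module's own.

```python
def dual_chunk_swap_layout(cp_size):
    """Chunk ids in rank-major order: rank r holds chunks r and 2*cp_size-1-r."""
    n = 2 * cp_size
    layout = []
    for rank in range(cp_size):
        layout += [rank, n - 1 - rank]
    return layout


def undo_order(cp_size):
    """Positions to read from the rank-major layout to obtain chunks 0..2*cp_size-1."""
    n = 2 * cp_size
    return [2 * i for i in range(cp_size)] + [n - 1 - 2 * i for i in range(cp_size)]


def redo_order(cp_size):
    """Inverse permutation of undo_order (sequential back to load-balanced)."""
    order = undo_order(cp_size)
    inverse = [0] * len(order)
    for new_pos, old_pos in enumerate(order):
        inverse[old_pos] = new_pos
    return inverse


def reorder(chunks, order):
    """Pick chunks in the given order."""
    return [chunks[i] for i in order]


balanced = dual_chunk_swap_layout(cp_size=4)
sequential = reorder(balanced, undo_order(4))
restored = reorder(sequential, redo_order(4))
print(balanced)    # [0, 7, 1, 6, 2, 5, 3, 4]
print(sequential)  # [0, 1, 2, 3, 4, 5, 6, 7]
```

An SSM scan is causal and stateful along the sequence, which is why the chunks have to be in sequential order before the kernel runs.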

**Parameters:**

`cp_group`: Context-parallel process group.

`num_heads`: Total number of SSM heads (before any parallelism).

`head_dim`: Dimension per head.

`n_groups`: Number of SSM groups (for grouped B/C states).

`d_state`: SSM state dimension.

`mixer`: Reference to the Mamba mixer module (owns `conv1d`, `dt_bias`, `A_log`, `D`).

```python
nemo_automodel.components.distributed.mamba_cp.MambaContextParallel._repeat_group_state(
    state: torch.Tensor
) -> torch.Tensor
```

Repeat group states for CP ranks when `n_groups < cp_size`.

`[B, L, n_groups * d_state]` -> `[B, L, n_groups * repeat * d_state]`
Also supports THD 2D input `[T, n_groups * d_state]`.
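
A minimal sketch of the repetition for a single token's state vector, in plain Python. It assumes that the copies of each group are kept adjacent (so that a contiguous per-rank slice lands on the right group) and that `repeat = cp_size // n_groups`; both are assumptions, and `repeat_group_state` below is a hypothetical stand-in, not the method itself.

```python
def repeat_group_state(state, n_groups, d_state, repeat):
    """Repeat each group's d_state block `repeat` times, keeping copies adjacent."""
    if len(state) != n_groups * d_state:
        raise ValueError("state must have n_groups * d_state entries")
    out = []
    for g in range(n_groups):
        block = state[g * d_state:(g + 1) * d_state]
        out.extend(block * repeat)
    return out


n_groups, d_state, cp_size = 2, 3, 4
repeat = cp_size // n_groups  # assumed repeat factor
token_state = ["g0a", "g0b", "g0c", "g1a", "g1b", "g1c"]

expanded = repeat_group_state(token_state, n_groups, d_state, repeat)
# After the repeat, every CP rank can take one contiguous d_state-sized slice.
per_rank = [expanded[r * d_state:(r + 1) * d_state] for r in range(cp_size)]
print(per_rank)
```

With these numbers, ranks 0 and 1 both receive group 0 and ranks 2 and 3 both receive group 1.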

```python
nemo_automodel.components.distributed.mamba_cp.MambaContextParallel._slice_conv_param(
    param: torch.Tensor
) -> torch.Tensor
```

Slice a conv1d parameter (weight or bias) along its channel dimension.

Parameter slicing is done in the forward path so that gradients
backpropagate to the original (full) parameters.

```python
nemo_automodel.components.distributed.mamba_cp.MambaContextParallel._slice_vector_param(
    param: torch.Tensor
) -> torch.Tensor
```

Slice a per-head vector parameter for the current CP rank.
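
The reason for slicing in the forward path can be seen in a small PyTorch experiment. This is a sketch, not the method's implementation: the 8-element parameter stands in for a per-head vector such as `dt_bias`, and the contiguous per-rank slice is an assumption.

```python
import torch

# Stand-in for a full per-head parameter with 8 heads.
full = torch.nn.Parameter(torch.arange(8, dtype=torch.float32))

cp_size, rank = 4, 1
heads_local = full.numel() // cp_size

# Basic slicing returns a view that stays in the autograd graph of `full`.
local = full[rank * heads_local:(rank + 1) * heads_local]

loss = (local * 2.0).sum()
loss.backward()

grad = full.grad.tolist()
print(grad)  # nonzero only in the slice owned by this rank
```

The gradient lands on the full parameter, with nonzero entries only at the heads this rank processed, so the optimizer and any sharding wrapper keep working with the original tensor.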

```python
nemo_automodel.components.distributed.mamba_cp.MambaContextParallel.get_A_log() -> torch.Tensor
```

Slice `A_log` for the current CP rank.

```python
nemo_automodel.components.distributed.mamba_cp.MambaContextParallel.get_D() -> torch.Tensor
```

Slice `D` for the current CP rank.

```python
nemo_automodel.components.distributed.mamba_cp.MambaContextParallel.get_conv1d_bias() -> torch.Tensor
```

Slice `conv1d.bias` for the current CP rank.

Bias shape: `[conv_dim]`.  Returns `[conv_dim_local]`.

```python
nemo_automodel.components.distributed.mamba_cp.MambaContextParallel.get_conv1d_weight() -> torch.Tensor
```

Slice `conv1d.weight` for the current CP rank.

Weight shape: `[conv_dim, 1, kernel_size]` where
`conv_dim = d_inner + 2 * n_groups * d_state`.
Returns `[conv_dim_local, kernel_size]` (squeezed for the `causal_conv1d` kernel).
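
Because `conv_dim` concatenates three segments, a per-rank slice is not one contiguous range. The sketch below computes which channel indices a rank would own, assuming the channel axis is laid out as `[x | B | C]` with sizes `d_inner`, `n_groups * d_state`, and `n_groups * d_state`, and assuming that ranks share a group when `n_groups < cp_size`. The helper `conv_channel_indices` is hypothetical.

```python
def conv_channel_indices(d_inner, n_groups, d_state, cp_size, rank):
    """Indices along the conv_dim axis owned by `rank` (x, B and C segments)."""
    d_inner_local = d_inner // cp_size
    if n_groups >= cp_size:
        groups_local = n_groups // cp_size
        group_start = rank * groups_local
    else:
        # Assumption: several ranks share one group when there are fewer groups than ranks.
        groups_local = 1
        group_start = rank // (cp_size // n_groups)

    bc_full = n_groups * d_state
    bc_local = groups_local * d_state
    b_offset = d_inner + group_start * d_state
    c_offset = d_inner + bc_full + group_start * d_state

    x = range(rank * d_inner_local, (rank + 1) * d_inner_local)
    b = range(b_offset, b_offset + bc_local)
    c = range(c_offset, c_offset + bc_local)
    return list(x) + list(b) + list(c)


d_inner, n_groups, d_state, cp_size = 8, 2, 2, 2
conv_dim = d_inner + 2 * n_groups * d_state  # 16
idx0 = conv_channel_indices(d_inner, n_groups, d_state, cp_size, rank=0)
idx1 = conv_channel_indices(d_inner, n_groups, d_state, cp_size, rank=1)
print(idx0)  # [0, 1, 2, 3, 8, 9, 12, 13]
print(idx1)  # [4, 5, 6, 7, 10, 11, 14, 15]
```

In this configuration the two ranks partition all 16 channels, each taking `conv_dim_local = 8`.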

```python
nemo_automodel.components.distributed.mamba_cp.MambaContextParallel.get_dt_bias() -> torch.Tensor
```

Slice `dt_bias` for the current CP rank.

```python
nemo_automodel.components.distributed.mamba_cp.MambaContextParallel.post_conv_ssm(
    output: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None
) -> torch.Tensor
```

Redistribute SSM output from hidden-sharded back to sequence-sharded layout.

```python
nemo_automodel.components.distributed.mamba_cp.MambaContextParallel.pre_conv_ssm(
    projected_states: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None
) -> torch.Tensor
```

Redistribute from sequence-sharded to hidden-sharded layout, undoing DualChunkSwap.

```python
class nemo_automodel.components.distributed.mamba_cp._AllToAll()
```

**Bases:** `Function`

Autograd wrapper around `torch.distributed.all_to_all_single`.

For equal-sized splits the all-to-all operation is its own inverse,
so the backward pass is simply another all-to-all on the same group.
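
The self-inverse property can be checked with a single-process simulation. The sketch models an equal-split all-to-all as a block transpose across ranks; `simulated_all_to_all` is a hypothetical helper and involves no `torch.distributed` calls.

```python
def simulated_all_to_all(per_rank_blocks):
    """Simulate an equal-split all-to-all.

    per_rank_blocks[src][dst] is the block that rank `src` sends to rank `dst`.
    The result is indexed as result[dst][src]: the block `dst` received from `src`.
    """
    world = len(per_rank_blocks)
    return [[per_rank_blocks[src][dst] for src in range(world)] for dst in range(world)]


blocks = [["a0", "a1"], ["b0", "b1"]]  # rank 0 holds a0, a1; rank 1 holds b0, b1
once = simulated_all_to_all(blocks)
twice = simulated_all_to_all(once)
print(once)   # [['a0', 'b0'], ['a1', 'b1']]
print(twice)  # [['a0', 'a1'], ['b0', 'b1']]
```

Gradients follow the same route in reverse: the gradient of the block that rank `dst` received from rank `src` has to travel back to `src`, which is the same block transpose.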

```python
nemo_automodel.components.distributed.mamba_cp._AllToAll.backward(
    ctx,
    grad_output: torch.Tensor
)
```

Static method.

```python
nemo_automodel.components.distributed.mamba_cp._AllToAll.forward(
    ctx,
    input_: torch.Tensor,
    group: torch.distributed.ProcessGroup
) -> torch.Tensor
```

Static method.

```python
nemo_automodel.components.distributed.mamba_cp._all_to_all(
    input_: torch.Tensor,
    group: torch.distributed.ProcessGroup
) -> torch.Tensor
```

Functional entry-point for the autograd-aware all-to-all.

```python
nemo_automodel.components.distributed.mamba_cp._all_to_all_cp2hp(
    input_: torch.Tensor,
    cp_group: torch.distributed.ProcessGroup,
    batch_size: int
) -> torch.Tensor
```

Transform from sequence-sharded to hidden-sharded layout (batch-first).

**Parameters:**

`input_`: Tensor of shape `[B, L_local, H]` (BSHD) or `[T, H]` (THD),
where `H` is the full hidden dimension on this rank.

`cp_group`: Context-parallel process group.

`batch_size`: Batch size `B` (needed to recover dimensions after reshape).

**Returns:** `torch.Tensor`

Tensor of shape `[B, L_global, H / cp_size]` (BSHD) or `[T, H / cp_size]` (THD).
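
To make the layout change concrete, here is a single-process simulation on nested lists, with the batch dimension omitted. It assumes tokens arrive in source-rank order and that each rank receives a contiguous block of channels; `cp2hp` and `hp2cp` below are hypothetical stand-ins for this function and its inverse `_all_to_all_hp2cp`, not their implementations.

```python
def cp2hp(per_rank_tokens):
    """Sequence-sharded -> hidden-sharded. per_rank_tokens[r] is a list of rows of length H."""
    cp = len(per_rank_tokens)
    h_local = len(per_rank_tokens[0][0]) // cp
    out = []
    for dst in range(cp):
        cols = slice(dst * h_local, (dst + 1) * h_local)
        # Every rank receives all tokens, but only its own channel block.
        out.append([row[cols] for src in range(cp) for row in per_rank_tokens[src]])
    return out


def hp2cp(per_rank_tokens):
    """Hidden-sharded -> sequence-sharded (inverse of cp2hp)."""
    cp = len(per_rank_tokens)
    t_local = len(per_rank_tokens[0]) // cp
    out = []
    for dst in range(cp):
        rows = []
        for t in range(dst * t_local, (dst + 1) * t_local):
            row = []
            for src in range(cp):
                row += per_rank_tokens[src][t]  # stitch channel blocks back together
            rows.append(row)
        out.append(rows)
    return out


cp, t_local, h = 2, 2, 4
sequence_sharded = [
    [[f"t{r * t_local + t}c{c}" for c in range(h)] for t in range(t_local)]
    for r in range(cp)
]
hidden_sharded = cp2hp(sequence_sharded)
print(hidden_sharded[0])  # all 4 tokens, channels 0-1
print(hidden_sharded[1])  # all 4 tokens, channels 2-3
```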

```python
nemo_automodel.components.distributed.mamba_cp._all_to_all_hp2cp(
    input_: torch.Tensor,
    cp_group: torch.distributed.ProcessGroup,
    batch_size: int
) -> torch.Tensor
```

Transform from hidden-sharded to sequence-sharded layout (batch-first).

This is the inverse of `_all_to_all_cp2hp`.

**Parameters:**

`input_`: Tensor of shape `[B, L_global, H_local]` (BSHD) or `[T, H_local]` (THD),
where `H_local = H / cp_size`.

`cp_group`: Context-parallel process group.

`batch_size`: Batch size `B`.

**Returns:** `torch.Tensor`

Tensor of shape `[B, L_local, H]` (BSHD) or `[T, H]` (THD).

```python
nemo_automodel.components.distributed.mamba_cp._deinterleave_packed_seqs(
    input_: torch.Tensor,
    cu_seqlens: torch.Tensor,
    cp_size: int
) -> torch.Tensor
```

Rearrange tokens from rank-major to sequence-major order after all-to-all.

After `_all_to_all_cp2hp` on packed 2-D data, the token layout along
the sequence dimension is:

`[rank0_seq0 | rank0_seq1 | ... | rank1_seq0 | rank1_seq1 | ...]`

This function rearranges it to:

`[rank0_seq0 | rank1_seq0 | ... | rank0_seq1 | rank1_seq1 | ...]`

so that each sequence's tokens are contiguous (required by the
`_undo_attention_load_balancing` reorder that follows).

**Parameters:**

`input_`: 2-D tensor `[T_global, H]`.

`cu_seqlens`: **Local** (pre-all-to-all) cumulative sequence lengths.

`cp_size`: Context-parallel world size.

**Returns:** `torch.Tensor`

Rearranged 2-D tensor with the same shape.
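
The rearrangement and its inverse can be sketched on a flat list of token labels. The sketch assumes every rank holds the same local `cu_seqlens` and that its last entry equals the local token count (no padding); `deinterleave` and `reinterleave` are hypothetical helpers operating on lists, not tensors.

```python
def deinterleave(tokens, cu_seqlens_local, cp_size):
    """Rank-major -> sequence-major: gather each sequence's pieces from all ranks."""
    t_local = cu_seqlens_local[-1]
    out = []
    for start, end in zip(cu_seqlens_local[:-1], cu_seqlens_local[1:]):
        for rank in range(cp_size):
            base = rank * t_local
            out.extend(tokens[base + start:base + end])
    return out


def reinterleave(tokens, cu_seqlens_local, cp_size):
    """Sequence-major -> rank-major (inverse of deinterleave)."""
    t_local = cu_seqlens_local[-1]
    out = [None] * (t_local * cp_size)
    pos = 0
    for start, end in zip(cu_seqlens_local[:-1], cu_seqlens_local[1:]):
        n = end - start
        for rank in range(cp_size):
            base = rank * t_local
            out[base + start:base + end] = tokens[pos:pos + n]
            pos += n
    return out


cu_local = [0, 2, 3]  # two packed sequences with 2 and 1 local tokens per rank
rank_major = ["r0s0a", "r0s0b", "r0s1a", "r1s0a", "r1s0b", "r1s1a"]
seq_major = deinterleave(rank_major, cu_local, cp_size=2)
print(seq_major)  # ['r0s0a', 'r0s0b', 'r1s0a', 'r1s0b', 'r0s1a', 'r1s1a']
```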

```python
nemo_automodel.components.distributed.mamba_cp._redo_attention_load_balancing(
    input_: torch.Tensor,
    cp_size: int,
    cu_seqlens: torch.Tensor | None = None
) -> torch.Tensor
```

Reorder from sequential back to DualChunkSwap for attention.

Inverse of `_undo_attention_load_balancing`.

```python
nemo_automodel.components.distributed.mamba_cp._reinterleave_packed_seqs(
    input_: torch.Tensor,
    cu_seqlens: torch.Tensor,
    cp_size: int
) -> torch.Tensor
```

Inverse of `_deinterleave_packed_seqs`.

Rearranges from sequence-major back to rank-major order before the
inverse all-to-all in `post_conv_ssm`.

```python
nemo_automodel.components.distributed.mamba_cp._reorder_chunks(
    input_: torch.Tensor,
    order: list[int],
    cu_seqlens: torch.Tensor | None = None
) -> torch.Tensor
```

Reorder equal-sized chunks of a tensor according to *order*.

**Parameters:**

`input_`: `[B, L, H]` (BSHD) or `[T, H]` (THD).

`order`: Permutation indices (length must equal the number of chunks).

`cu_seqlens`: If provided, reorder per-sequence on dim=0 (THD).
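
A minimal list-based sketch of the chunk reordering, covering both the whole-tensor case and the per-sequence case selected by `cu_seqlens`. It assumes each sequence length is divisible by `len(order)`; `reorder_chunks` here is a hypothetical illustration on a flat token list, not the function itself.

```python
def reorder_chunks(tokens, order, cu_seqlens=None):
    """Split into len(order) equal chunks and emit them in the given order."""

    def reorder_one(seq):
        n = len(order)
        if len(seq) % n:
            raise ValueError("sequence length must be divisible by the number of chunks")
        size = len(seq) // n
        chunks = [seq[i * size:(i + 1) * size] for i in range(n)]
        return [tok for i in order for tok in chunks[i]]

    if cu_seqlens is None:
        return reorder_one(tokens)

    # Packed (THD-style) input: reorder each sequence independently.
    out = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        out.extend(reorder_one(tokens[start:end]))
    return out


order = [0, 2, 3, 1]
whole = reorder_chunks(list(range(8)), order)
packed = reorder_chunks(list(range(12)), order, cu_seqlens=[0, 4, 12])
print(whole)   # [0, 1, 4, 5, 6, 7, 2, 3]
print(packed)  # [0, 2, 3, 1, 4, 5, 8, 9, 10, 11, 6, 7]
```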

```python
nemo_automodel.components.distributed.mamba_cp._undo_attention_load_balancing(
    input_: torch.Tensor,
    cp_size: int,
    cu_seqlens: torch.Tensor | None = None
) -> torch.Tensor
```

Reorder from DualChunkSwap to sequential for SSM processing.