nemo_automodel.components.distributed.mamba_cp


Context parallelism for Mamba/SSM layers using a hidden-parallel strategy.

Instead of splitting the sequence across CP ranks (as attention CP does), this module uses an all-to-all redistribution so that each CP rank processes the full sequence but only a subset of the heads (num_heads / cp_size heads, that is, d_inner / cp_size channels). The data flow is:

[B, L_local, D] -> all-to-all -> [B, L_global, D/cp] -> conv1d + SSM kernel -> [B, L_global, D/cp] -> all-to-all -> [B, L_local, D]

MambaContextParallel is intentionally not a subclass of nn.Module because it owns no trainable parameters. It holds references to the Mamba mixer's parameters and slices them in the forward path, so gradients flow back to the full (unsliced) parameters.
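The slicing argument can be checked with plain PyTorch. The snippet below is a minimal illustration, not code from this module: `full_param`, `cp_rank` and `cp_size` are hypothetical stand-ins. A slice taken in the forward pass is a view that stays in the autograd graph, so the gradient lands on the full parameter, nonzero only inside this rank's slice.

```python
import torch

# Hypothetical stand-in for a full (unsliced) mixer parameter, e.g. a per-head vector.
full_param = torch.nn.Parameter(torch.arange(8, dtype=torch.float32))

cp_rank, cp_size = 1, 4                    # hypothetical rank and CP world size
n_local = full_param.numel() // cp_size    # elements owned by this rank

# Slice in the forward path; the slice is a view tracked by autograd.
local = full_param[cp_rank * n_local:(cp_rank + 1) * n_local]
loss = (local * 2.0).sum()
loss.backward()

# Gradient is written into the full parameter: 2.0 inside this rank's slice, 0.0 elsewhere.
print(full_param.grad)
```

Because no sliced copy is ever stored, a wrapper such as FSDP that swaps the mixer's parameters for DTensors is picked up automatically on the next forward call.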

Module Contents

Classes

| Name | Description |
| --- | --- |
| MambaContextParallel | Hidden-parallel context parallelism for a Mamba2 mixer layer. |
| _AllToAll | Autograd wrapper around torch.distributed.all_to_all_single. |

Functions

| Name | Description |
| --- | --- |
| _all_to_all | Functional entry-point for the autograd-aware all-to-all. |
| _all_to_all_cp2hp | Transform from sequence-sharded to hidden-sharded layout (batch-first). |
| _all_to_all_hp2cp | Transform from hidden-sharded to sequence-sharded layout (batch-first). |
| _deinterleave_packed_seqs | Rearrange tokens from rank-major to sequence-major order after all-to-all. |
| _redo_attention_load_balancing | Reorder from sequential back to DualChunkSwap for attention. |
| _reinterleave_packed_seqs | Inverse of _deinterleave_packed_seqs. |
| _reorder_chunks | Reorder equal-sized chunks of a tensor according to order. |
| _undo_attention_load_balancing | Reorder from DualChunkSwap to sequential for SSM processing. |

API

class nemo_automodel.components.distributed.mamba_cp.MambaContextParallel(
cp_group: torch.distributed.ProcessGroup,
num_heads: int,
head_dim: int,
n_groups: int,
d_state: int,
mixer: torch.nn.Module
)

Hidden-parallel context parallelism for a Mamba2 mixer layer.

This class does not own trainable parameters. It stores a reference to the mixer module and accesses its parameters (conv1d, dt_bias, A_log, D) on the fly so that gradients propagate to the original (full) parameters and FSDP-managed DTensor replacements are picked up correctly.

DualChunkSwap reordering is always undone before the SSM kernel and redone after it, because both TE CP (p2p) and PyTorch's context_parallel (allgather) reorder sequence chunks for load balancing, while the SSM scan needs tokens in their original sequential order.

Parameters:

cp_group
torch.distributed.ProcessGroup

Context-parallel process group.

num_heads
int

Total number of SSM heads (before any parallelism).

head_dim
int

Dimension per head.

n_groups
int

Number of SSM groups (for grouped B/C states).

d_state
int

SSM state dimension.

mixer
nn.Module

Reference to the Mamba mixer module (owns conv1d, dt_bias, A_log, D).

Attributes:

cp_rank = cp_group.rank()

cp_size = cp_group.size()

d_inner = num_heads * head_dim

d_inner_local = self.num_heads_local * head_dim

group_repeat_count = self.cp_size // n_groups

n_groups_local = 1

num_heads_local = num_heads // self.cp_size
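To make the attribute formulas concrete, here is the arithmetic for one hypothetical configuration (all numbers are illustrative, not defaults). The divisibility checks and the per-rank conv channel count are assumptions inferred from the integer divisions and from n_groups_local = 1; they are not documented checks of this class.

```python
# Hypothetical configuration, chosen only to make the arithmetic concrete.
num_heads, head_dim, n_groups, d_state = 8, 64, 2, 128
cp_size, cp_rank = 4, 3

# The attribute formulas listed above.
d_inner = num_heads * head_dim                # 512
num_heads_local = num_heads // cp_size        # 2 heads per rank
d_inner_local = num_heads_local * head_dim    # 128 channels per rank
group_repeat_count = cp_size // n_groups      # each group is shared by 2 ranks
n_groups_local = 1                            # every rank works with exactly one group

# Divisibility the integer divisions rely on (assumption, not a documented check).
assert num_heads % cp_size == 0
assert cp_size % n_groups == 0

# Full conv1d channel count (formula from get_conv1d_weight) and an assumed per-rank count.
conv_dim = d_inner + 2 * n_groups * d_state                     # 1024
conv_dim_local = d_inner_local + 2 * n_groups_local * d_state   # 384

# Assuming heads are assigned to ranks contiguously, rank 3 works with group 1.
local_group = cp_rank // group_repeat_count
```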
nemo_automodel.components.distributed.mamba_cp.MambaContextParallel._repeat_group_state(
state: torch.Tensor
) -> torch.Tensor

Repeat group states for CP ranks when n_groups < cp_size.

Shape: [B, L, n_groups * d_state] -> [B, L, n_groups * repeat * d_state]. Also supports THD 2-D input [T, n_groups * d_state].
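A minimal sketch of the shape transformation, written as a hypothetical free function `repeat_group_state` rather than the method itself. It assumes copies of the same group stay adjacent (g0, g0, g1, g1, ...), which is what a contiguous head-to-rank assignment needs: after the hidden-dimension all-to-all, rank r then receives group r // repeat.

```python
import torch


def repeat_group_state(state: torch.Tensor, n_groups: int, d_state: int, repeat: int) -> torch.Tensor:
    """Sketch: repeat each group's d_state block `repeat` times along the last dim.

    Works for BSHD [B, L, n_groups * d_state] and THD [T, n_groups * d_state].
    """
    lead = state.shape[:-1]
    grouped = state.reshape(*lead, n_groups, d_state)
    # repeat_interleave keeps copies of the same group adjacent: g0, g0, g1, g1, ...
    repeated = grouped.repeat_interleave(repeat, dim=-2)
    return repeated.reshape(*lead, n_groups * repeat * d_state)


# Two groups with d_state = 2, each repeated twice (e.g. n_groups = 2, cp_size = 4).
bshd = torch.arange(4.0).reshape(1, 1, 4)
print(repeat_group_state(bshd, n_groups=2, d_state=2, repeat=2))
```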

nemo_automodel.components.distributed.mamba_cp.MambaContextParallel._slice_conv_param(
param: torch.Tensor
) -> torch.Tensor

Slice a conv1d parameter (weight or bias) along its channel dimension.

Parameter slicing is done in the forward path so that gradients backpropagate to the original (full) parameters.
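A sketch of what per-rank channel slicing can look like, under stated assumptions: the conv1d channels are laid out as [x (d_inner) | B (n_groups * d_state) | C (n_groups * d_state)], consistent with conv_dim = d_inner + 2 * n_groups * d_state, heads are assigned to ranks contiguously, and rank r uses group r // group_repeat_count. The free function `slice_conv_param` and its argument list are hypothetical, not this class's signature.

```python
import torch


def slice_conv_param(param, cp_rank, d_inner, d_inner_local, n_groups, d_state, group_repeat_count):
    """Sketch: pick this rank's x, B and C channels from a [conv_dim, ...] conv1d tensor.

    Assumes the channel layout [x (d_inner) | B (n_groups*d_state) | C (n_groups*d_state)].
    """
    x_start = cp_rank * d_inner_local
    group = cp_rank // group_repeat_count       # assumed group used by this rank
    b_start = d_inner + group * d_state
    c_start = d_inner + n_groups * d_state + group * d_state
    # Plain slicing plus cat keeps the result in the autograd graph of `param`.
    return torch.cat(
        [
            param[x_start:x_start + d_inner_local],
            param[b_start:b_start + d_state],
            param[c_start:c_start + d_state],
        ],
        dim=0,
    )


# Toy sizes: d_inner = 8 (4 heads x 2), cp_size = 4, n_groups = 2, d_state = 3.
weight = torch.randn(20, 1, 4)                  # [conv_dim, 1, kernel_size]
w_local = slice_conv_param(weight, 0, 8, 2, 2, 3, 2).squeeze(1)   # [conv_dim_local, kernel_size]
```

The same function covers the bias ([conv_dim] -> [conv_dim_local]); for the weight, squeezing dim 1 gives the [conv_dim_local, kernel_size] shape described for get_conv1d_weight.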

nemo_automodel.components.distributed.mamba_cp.MambaContextParallel._slice_vector_param(
param: torch.Tensor
) -> torch.Tensor

Slice a per-head vector parameter for the current CP rank.
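Per-head vectors such as dt_bias, A_log and D have one entry per head, so a rank only needs its own block of num_heads_local entries. The helper below is a hypothetical sketch that assumes contiguous head-to-rank assignment. It also shows that the gradient reaches only this rank's entries of the full parameter.

```python
import torch


def slice_vector_param(param: torch.Tensor, cp_rank: int, num_heads_local: int) -> torch.Tensor:
    """Sketch: select this rank's contiguous block of heads from a [num_heads] vector."""
    start = cp_rank * num_heads_local
    return param[start:start + num_heads_local]


A_log = torch.nn.Parameter(torch.arange(8.0))   # hypothetical [num_heads] parameter
local = slice_vector_param(A_log, cp_rank=2, num_heads_local=2)
local.sum().backward()                          # gradient flows into A_log[4:6] only
```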

nemo_automodel.components.distributed.mamba_cp.MambaContextParallel.get_A_log() -> torch.Tensor

Slice A_log for the current CP rank.

nemo_automodel.components.distributed.mamba_cp.MambaContextParallel.get_D() -> torch.Tensor

Slice D for the current CP rank.

nemo_automodel.components.distributed.mamba_cp.MambaContextParallel.get_conv1d_bias() -> torch.Tensor

Slice conv1d.bias for the current CP rank.

Bias shape: [conv_dim]. Returns [conv_dim_local].

nemo_automodel.components.distributed.mamba_cp.MambaContextParallel.get_conv1d_weight() -> torch.Tensor

Slice conv1d.weight for the current CP rank.

Weight shape: [conv_dim, 1, kernel_size] where conv_dim = d_inner + 2 * n_groups * d_state. Returns [conv_dim_local, kernel_size] (squeezed for causal_conv1d kernel).

nemo_automodel.components.distributed.mamba_cp.MambaContextParallel.get_dt_bias() -> torch.Tensor

Slice dt_bias for the current CP rank.

nemo_automodel.components.distributed.mamba_cp.MambaContextParallel.post_conv_ssm(
output: torch.Tensor,
cu_seqlens: torch.Tensor | None = None
) -> torch.Tensor

Redistribute SSM output from hidden-sharded back to sequence-sharded layout.

nemo_automodel.components.distributed.mamba_cp.MambaContextParallel.pre_conv_ssm(
projected_states: torch.Tensor,
cu_seqlens: torch.Tensor | None = None
) -> torch.Tensor

Redistribute from sequence-sharded to hidden-sharded layout, undoing DualChunkSwap.

class nemo_automodel.components.distributed.mamba_cp._AllToAll()

Bases: Function

Autograd wrapper around torch.distributed.all_to_all_single.

For equal-sized splits the all-to-all operation is its own inverse, so the backward pass is simply another all-to-all on the same group.
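A minimal sketch of such a wrapper, not the module's actual code. The class name `AllToAll` is hypothetical, and the sketch assumes equal splits along dim 0, which is what `torch.distributed.all_to_all_single` does when no split sizes are passed. Backward reuses the same exchange because an equal-split all-to-all is a permutation that is its own inverse.

```python
import torch
import torch.distributed as dist


class AllToAll(torch.autograd.Function):
    """Sketch of an autograd-aware all_to_all_single with equal splits along dim 0."""

    @staticmethod
    def forward(ctx, input_: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
        ctx.group = group
        send = input_.contiguous()
        output = torch.empty_like(send)
        # Chunk i of `send` goes to rank i; chunk j of `output` comes from rank j.
        dist.all_to_all_single(output, send, group=group)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # The adjoint of an equal-split all-to-all is the same exchange.
        send = grad_output.contiguous()
        grad_input = torch.empty_like(send)
        dist.all_to_all_single(grad_input, send, group=ctx.group)
        return grad_input, None  # no gradient for the process-group argument
```

It is used as `AllToAll.apply(x, group)` after `torch.distributed` has been initialized. With a world size of 1 the exchange is the identity, which makes a convenient smoke test.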

nemo_automodel.components.distributed.mamba_cp._AllToAll.backward(
ctx,
grad_output: torch.Tensor
)
staticmethod
nemo_automodel.components.distributed.mamba_cp._AllToAll.forward(
ctx,
input_: torch.Tensor,
group: torch.distributed.ProcessGroup
) -> torch.Tensor
staticmethod
nemo_automodel.components.distributed.mamba_cp._all_to_all(
input_: torch.Tensor,
group: torch.distributed.ProcessGroup
) -> torch.Tensor

Functional entry-point for the autograd-aware all-to-all.

nemo_automodel.components.distributed.mamba_cp._all_to_all_cp2hp(
input_: torch.Tensor,
cp_group: torch.distributed.ProcessGroup,
batch_size: int
) -> torch.Tensor

Transform from sequence-sharded to hidden-sharded layout (batch-first).

Parameters:

input_
torch.Tensor

Tensor of shape [B, L_local, H] (BSHD) or [T, H] (THD) where H is the full hidden dimension on this rank.

cp_group
torch.distributed.ProcessGroup

Context-parallel process group.

batch_size
int

Batch size B (needed to recover dimensions after reshape).

Returns: torch.Tensor

Tensor of shape [B, L_global, H / cp_size] (BSHD) or [T * cp_size, H / cp_size] (THD), where L_global = L_local * cp_size.
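The layout change can be modeled in a single process by holding every rank's tensor in a Python list. `simulate_cp2hp` is a hypothetical model of where the data ends up, not the real implementation (which uses one all_to_all_single plus reshapes). It assumes the sequence shards are in plain sequential order; with DualChunkSwap the gathered order would still need to be undone.

```python
import torch


def simulate_cp2hp(seq_shards: list[torch.Tensor]) -> list[torch.Tensor]:
    """Single-process model of the cp2hp exchange for BSHD tensors.

    seq_shards[s] is rank s's input [B, L_local, H]; returns, for each rank r,
    the [B, L_global, H / cp_size] tensor that rank r ends up with.
    """
    cp_size = len(seq_shards)
    # Each rank splits its hidden dim into cp_size equal pieces; piece r is sent to rank r.
    pieces = [shard.chunk(cp_size, dim=-1) for shard in seq_shards]
    # Rank r concatenates the received pieces, in source-rank order, along the sequence dim.
    return [torch.cat([pieces[s][r] for s in range(cp_size)], dim=1) for r in range(cp_size)]


X = torch.randn(2, 8, 12)                            # full tensor: B=2, L_global=8, H=12
hidden_shards = simulate_cp2hp(list(X.chunk(4, dim=1)))   # cp_size = 4
```

Each rank ends up with X[..., r*3:(r+1)*3]: the full sequence, but only its slice of the hidden dimension.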

nemo_automodel.components.distributed.mamba_cp._all_to_all_hp2cp(
input_: torch.Tensor,
cp_group: torch.distributed.ProcessGroup,
batch_size: int
) -> torch.Tensor

Transform from hidden-sharded to sequence-sharded layout (batch-first).

This is the inverse of _all_to_all_cp2hp.

Parameters:

input_
torch.Tensor

Tensor of shape [B, L_global, H_local] (BSHD) or [T_global, H_local] (THD), where H_local = H / cp_size.

cp_group
torch.distributed.ProcessGroup

Context-parallel process group.

batch_size
int

Batch size B.

Returns: torch.Tensor

Tensor of shape [B, L_local, H] (BSHD) or [T_global / cp_size, H] (THD).
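The reverse direction can be modeled the same way. `simulate_hp2cp` is a hypothetical single-process model in which each rank splits its sequence into cp_size pieces and sends piece s back to rank s, which reassembles the full hidden dimension.

```python
import torch


def simulate_hp2cp(hidden_shards: list[torch.Tensor]) -> list[torch.Tensor]:
    """Single-process model of hp2cp: hidden_shards[r] is [B, L_global, H_local] on rank r."""
    cp_size = len(hidden_shards)
    # Each rank splits the sequence dim into cp_size pieces; piece s is sent to rank s.
    pieces = [shard.chunk(cp_size, dim=1) for shard in hidden_shards]
    # Rank s rebuilds the full hidden dim from the received pieces, in source-rank order.
    return [torch.cat([pieces[r][s] for r in range(cp_size)], dim=-1) for s in range(cp_size)]


X = torch.randn(2, 8, 12)                            # full tensor: B=2, L_global=8, H=12
seq_shards = simulate_hp2cp(list(X.chunk(4, dim=-1)))   # cp_size = 4
```

Rank s gets back X[:, s*2:(s+1)*2, :], its original sequence shard with the full hidden dimension.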

nemo_automodel.components.distributed.mamba_cp._deinterleave_packed_seqs(
input_: torch.Tensor,
cu_seqlens: torch.Tensor,
cp_size: int
) -> torch.Tensor

Rearrange tokens from rank-major to sequence-major order after all-to-all.

After _all_to_all_cp2hp on packed 2-D data the token layout along the sequence dimension is:

[rank0_seq0 | rank0_seq1 | … | rank1_seq0 | rank1_seq1 | …]

This function rearranges to:

[rank0_seq0 | rank1_seq0 | … | rank0_seq1 | rank1_seq1 | …]

so that each sequence’s tokens are contiguous (required by the _undo_attention_load_balancing reorder that follows).

Parameters:

input_
torch.Tensor

2-D tensor [T_global, H].

cu_seqlens
torch.Tensor

Local (pre-all-to-all) cumulative sequence lengths.

cp_size
int

Context-parallel world size.

Returns: torch.Tensor

Rearranged 2-D tensor with the same shape.
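A minimal index-based sketch of the rearrangement. The function name mirrors the documented one, but the body is an assumption: it presumes every packed sequence is split evenly across ranks, so the local cu_seqlens are identical on all ranks and each rank's block has cu_seqlens[-1] tokens.

```python
import torch


def deinterleave_packed_seqs(input_: torch.Tensor, cu_seqlens: torch.Tensor, cp_size: int) -> torch.Tensor:
    """Sketch: rank-major [T_global, H] -> sequence-major [T_global, H]."""
    bounds = cu_seqlens.tolist()          # local cumulative lengths, e.g. [0, 2, 5]
    t_local = bounds[-1]                  # tokens each rank contributed before the all-to-all
    pieces = []
    for i in range(len(bounds) - 1):      # for each packed sequence ...
        for r in range(cp_size):          # ... take its piece from every source rank, in rank order
            base = r * t_local
            pieces.append(input_[base + bounds[i]:base + bounds[i + 1]])
    return torch.cat(pieces, dim=0)


# Token value 100*rank + 10*seq + position: rank-major input for cp_size = 2.
x = torch.tensor([0, 1, 10, 11, 12, 100, 101, 110, 111, 112], dtype=torch.float32).unsqueeze(1)
y = deinterleave_packed_seqs(x, torch.tensor([0, 2, 5]), cp_size=2)
```

After the call, sequence 0's pieces from ranks 0 and 1 (values 0, 1, 100, 101) are adjacent, followed by sequence 1's pieces.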

nemo_automodel.components.distributed.mamba_cp._redo_attention_load_balancing(
input_: torch.Tensor,
cp_size: int,
cu_seqlens: torch.Tensor | None = None
) -> torch.Tensor

Reorder from sequential back to DualChunkSwap for attention.

Inverse of _undo_attention_load_balancing.

nemo_automodel.components.distributed.mamba_cp._reinterleave_packed_seqs(
input_: torch.Tensor,
cu_seqlens: torch.Tensor,
cp_size: int
) -> torch.Tensor

Inverse of _deinterleave_packed_seqs.

Rearranges from sequence-major back to rank-major order before the inverse all-to-all in post_conv_ssm.
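A matching sketch of the inverse, under the same assumption that every sequence is split evenly across ranks. In the sequence-major layout, sequence i starts at cp_size * cu_seqlens[i], and its piece from rank r is the r-th block of length cu_seqlens[i+1] - cu_seqlens[i].

```python
import torch


def reinterleave_packed_seqs(input_: torch.Tensor, cu_seqlens: torch.Tensor, cp_size: int) -> torch.Tensor:
    """Sketch: sequence-major [T_global, H] -> rank-major [T_global, H]."""
    bounds = cu_seqlens.tolist()          # local cumulative lengths, e.g. [0, 2, 5]
    per_rank = [[] for _ in range(cp_size)]
    for i in range(len(bounds) - 1):
        n = bounds[i + 1] - bounds[i]     # local length of sequence i
        seq_start = cp_size * bounds[i]   # where sequence i starts in sequence-major order
        for r in range(cp_size):
            start = seq_start + r * n
            per_rank[r].append(input_[start:start + n])
    # Concatenate rank 0's pieces of all sequences, then rank 1's, and so on.
    return torch.cat([piece for rank_pieces in per_rank for piece in rank_pieces], dim=0)


x = torch.tensor([0, 1, 100, 101, 10, 11, 12, 110, 111, 112], dtype=torch.float32).unsqueeze(1)
y = reinterleave_packed_seqs(x, torch.tensor([0, 2, 5]), cp_size=2)
```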

nemo_automodel.components.distributed.mamba_cp._reorder_chunks(
input_: torch.Tensor,
order: list[int],
cu_seqlens: torch.Tensor | None = None
) -> torch.Tensor

Reorder equal-sized chunks of a tensor according to order.

Parameters:

input_
torch.Tensor

[B, L, H] (BSHD) or [T, H] (THD).

order
list[int]

Permutation indices (length must equal the number of chunks).

cu_seqlens
torch.Tensor | None, defaults to None

If provided, reorder per-sequence on dim=0 (THD).

nemo_automodel.components.distributed.mamba_cp._undo_attention_load_balancing(
input_: torch.Tensor,
cp_size: int,
cu_seqlens: torch.Tensor | None = None
) -> torch.Tensor

Reorder from DualChunkSwap to sequential for SSM processing.
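The undo/redo pair can be sketched for BSHD input, assuming the usual head-tail DualChunkSwap layout. In that layout the sequence is cut into 2 * cp_size chunks, and rank r holds chunks r and 2 * cp_size - 1 - r, so gathering the shards rank by rank yields the chunk order [0, 2*cp_size-1, 1, 2*cp_size-2, ...]. The helper names below are hypothetical, and the real functions also handle THD input per sequence via cu_seqlens.

```python
import torch


def dual_chunk_swap_layout(cp_size: int) -> list[int]:
    """Chunk ids in the order they appear after gathering DualChunkSwap shards rank by rank."""
    layout = []
    for r in range(cp_size):
        layout += [r, 2 * cp_size - 1 - r]   # rank r holds chunks r and 2*cp_size-1-r
    return layout


def undo_load_balancing(x: torch.Tensor, cp_size: int) -> torch.Tensor:
    """Sketch (BSHD): DualChunkSwap order -> sequential order along dim=1."""
    layout = dual_chunk_swap_layout(cp_size)
    chunks = x.chunk(2 * cp_size, dim=1)
    # Output chunk c is the input chunk at the position where chunk id c currently sits.
    return torch.cat([chunks[layout.index(c)] for c in range(2 * cp_size)], dim=1)


def redo_load_balancing(x: torch.Tensor, cp_size: int) -> torch.Tensor:
    """Sketch (BSHD): sequential order -> DualChunkSwap order along dim=1."""
    chunks = x.chunk(2 * cp_size, dim=1)
    return torch.cat([chunks[c] for c in dual_chunk_swap_layout(cp_size)], dim=1)


x = torch.arange(8.0).view(1, 8, 1)          # sequential tokens, cp_size = 2 -> 4 chunks of 2
swapped = redo_load_balancing(x, cp_size=2)  # chunk order [0, 3, 1, 2]
restored = undo_load_balancing(swapped, cp_size=2)
```

Pairing the first and last chunks on each rank balances causal-attention work. The SSM scan, however, is a left-to-right recurrence, which is why the order is undone before the kernel and restored afterwards.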