nemo_automodel.components.distributed.mamba_cp
Context parallelism for Mamba/SSM layers using a hidden-parallel strategy.
Instead of keeping the sequence split across CP ranks (as attention CP does), this module uses an all-to-all redistribution. Each CP rank then processes the full sequence but only a subset of the heads: nheads / cp_size heads, i.e. d_inner / cp_size channels. The data flow is::

    [B, L_local, D] -> all-to-all -> [B, L_global, D/cp] -> conv1d + SSM kernel -> [B, L_global, D/cp] -> all-to-all -> [B, L_local, D]
This module is intentionally not a subclass of nn.Module because it owns
no trainable parameters. It holds references to the Mamba mixer’s parameters
and slices them in the forward path so that gradients flow back to the full
(unsliced) parameters.
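The sketch below shows why slicing in the forward path is enough. Indexing a full parameter returns a view that stays in the autograd graph, so `backward()` writes gradients into the full tensor: nonzero on this rank's slice and zero elsewhere. The parameter size, `cp_size`, and `cp_rank` are illustrative assumptions.

```python
import torch

# Stand-in for a full (unsliced) per-head parameter such as A_log; sizes are illustrative.
full_param = torch.nn.Parameter(torch.zeros(8))
cp_size, cp_rank = 4, 1
n_local = full_param.numel() // cp_size

# Slicing inside the forward pass returns a view that stays in the autograd graph.
local_param = full_param[cp_rank * n_local : (cp_rank + 1) * n_local]
loss = (2.0 * local_param).sum()
loss.backward()

# The gradient lands on the full parameter: 2.0 on this rank's slice, 0.0 elsewhere.
print(full_param.grad)
```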
API
Hidden-parallel context parallelism for a Mamba2 mixer layer.
This class does not own trainable parameters. It stores a reference to the mixer module and accesses its parameters (conv1d, dt_bias, A_log, D) on the fly so that gradients propagate to the original (full) parameters and FSDP-managed DTensor replacements are picked up correctly.
DualChunkSwap reordering is always undone before the SSM kernel and redone after it. Both TE CP (p2p) and PyTorch's context_parallel (allgather) reorder sequence chunks for load balancing, but the SSM scan needs tokens in their original order.
Parameters:
Context-parallel process group.
Total number of SSM heads (before any parallelism).
Dimension per head.
Number of SSM groups (for grouped B/C states).
SSM state dimension.
Reference to the Mamba mixer module (owns conv1d, dt_bias, A_log, D).
Repeat group states for CP ranks when n_groups < cp_size.
[B, L, n_groups * d_state] -> [B, L, n_groups * repeat * d_state]
Also supports THD 2D input [T, n_groups * d_state].
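A minimal sketch of the repetition follows. `repeat_group_states` is a hypothetical name. The sketch assumes the copies are interleaved per group (g0, g0, g1, g1, ...) rather than tiled. With that layout, an equal split into cp_size slices gives rank r the group r // repeat, which is the group that owns rank r's contiguous block of heads. The same reshape handles BSHD and THD input.

```python
import torch


def repeat_group_states(x: torch.Tensor, n_groups: int, d_state: int, repeat: int) -> torch.Tensor:
    """Hypothetical helper: repeat each group's state `repeat` times along the last dim.

    Works for BSHD [B, L, n_groups * d_state] and THD [T, n_groups * d_state].
    """
    lead = x.shape[:-1]
    # Expose the group axis: [..., n_groups, d_state]
    grouped = x.reshape(*lead, n_groups, d_state)
    # repeat_interleave keeps copies of one group adjacent: g0, g0, g1, g1, ...
    grouped = grouped.repeat_interleave(repeat, dim=-2)
    return grouped.reshape(*lead, n_groups * repeat * d_state)


# n_groups=2 with cp_size=4 gives repeat = cp_size // n_groups = 2.
x = torch.arange(6.0).reshape(1, 1, 6)  # B=1, L=1, n_groups=2, d_state=3
y = repeat_group_states(x, n_groups=2, d_state=3, repeat=2)
```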
Slice a conv1d parameter (weight or bias) along its channel dimension.
Parameter slicing is done in the forward path so that gradients backpropagate to the original (full) parameters.
Slice a per-head vector parameter for the current CP rank.
Slice A_log for the current CP rank.
Slice D for the current CP rank.
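The per-head slices (A_log, D, dt_bias) can be sketched as one contiguous slice per rank. This assumes heads are laid out head-major, so rank r's d_inner / cp_size channels correspond to heads r * nheads / cp_size onward. It also assumes the parameter is a vector of length nheads. `slice_per_head` is a hypothetical name.

```python
import torch


def slice_per_head(param: torch.Tensor, nheads: int, cp_size: int, cp_rank: int) -> torch.Tensor:
    """Hypothetical helper: contiguous per-rank slice of an [nheads] vector (A_log, D, dt_bias)."""
    assert param.shape[0] == nheads and nheads % cp_size == 0
    nheads_local = nheads // cp_size
    start = cp_rank * nheads_local
    # Plain slicing returns a view, so gradients flow back into `param`.
    return param[start : start + nheads_local]


A_log = torch.nn.Parameter(torch.arange(8.0))
local = slice_per_head(A_log, nheads=8, cp_size=4, cp_rank=3)
```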
Slice conv1d.bias for the current CP rank.
Bias shape: [conv_dim]. Returns [conv_dim_local].
Slice conv1d.weight for the current CP rank.
Weight shape: [conv_dim, 1, kernel_size] where
conv_dim = d_inner + 2 * n_groups * d_state.
Returns [conv_dim_local, kernel_size] (squeezed for causal_conv1d kernel).
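The conv1d slicing can be sketched as selecting three channel ranges along dim 0, which covers both the weight and the bias. The sketch assumes the Mamba2 conv channel layout [x (d_inner) | B (n_groups*d_state) | C (n_groups*d_state)]. It also assumes that when n_groups < cp_size, rank r uses group r * n_groups // cp_size, matching the group-state repetition described above. `slice_conv_channels` is a hypothetical helper, not necessarily how the source computes its slices.

```python
import torch


def slice_conv_channels(t: torch.Tensor, d_inner: int, n_groups: int, d_state: int,
                        cp_size: int, cp_rank: int) -> torch.Tensor:
    """Hypothetical helper: pick this rank's channels (dim 0) of a conv1d weight or bias."""
    d_inner_local = d_inner // cp_size
    # When n_groups < cp_size several ranks share one group; otherwise each gets n_groups // cp_size.
    groups_local = max(n_groups // cp_size, 1)
    g_start = cp_rank * n_groups // cp_size
    bc_width = groups_local * d_state
    x_part = t[cp_rank * d_inner_local : (cp_rank + 1) * d_inner_local]
    b_off = d_inner + g_start * d_state                        # B block starts after x
    c_off = d_inner + n_groups * d_state + g_start * d_state   # C block starts after B
    b_part = t[b_off : b_off + bc_width]
    c_part = t[c_off : c_off + bc_width]
    return torch.cat([x_part, b_part, c_part], dim=0)


d_inner, n_groups, d_state, k = 8, 2, 2, 4
conv_dim = d_inner + 2 * n_groups * d_state  # 16
# Label each channel with its index so the selection is easy to read.
weight = torch.arange(conv_dim, dtype=torch.float32).reshape(conv_dim, 1, 1).expand(conv_dim, 1, k)
bias = torch.arange(conv_dim, dtype=torch.float32)

# [conv_dim_local, 1, k] -> [conv_dim_local, k] for a causal_conv1d-style kernel.
w_local = slice_conv_channels(weight, d_inner, n_groups, d_state, cp_size=2, cp_rank=1).squeeze(1)
b_local = slice_conv_channels(bias, d_inner, n_groups, d_state, cp_size=2, cp_rank=1)
```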
Slice dt_bias for the current CP rank.
Redistribute SSM output from hidden-sharded back to sequence-sharded layout.
Redistribute from sequence-sharded to hidden-sharded layout, undoing DualChunkSwap.
Bases: Function
Autograd wrapper around torch.distributed.all_to_all_single.
For equal-sized splits the all-to-all operation is its own inverse, so the backward pass is simply another all-to-all on the same group.
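Below is a sketch of such a wrapper, together with a single-process model of the collective that checks the self-inverse property. `_AllToAllSketch` and `simulate_all_to_all` are hypothetical names. The autograd class needs an initialized process group to run and is only defined here. The simulation models equal-split `all_to_all_single`: rank r receives chunk r of every rank's input. Applying the simulation twice is a block transpose applied twice, so each rank gets its original tensor back.

```python
import torch
import torch.distributed as dist


class _AllToAllSketch(torch.autograd.Function):
    """Hypothetical sketch: equal-split all_to_all_single with a self-inverse backward."""

    @staticmethod
    def forward(ctx, x, group):
        ctx.group = group
        x = x.contiguous()
        out = torch.empty_like(x)
        dist.all_to_all_single(out, x, group=group)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        grad_out = grad_out.contiguous()
        grad_in = torch.empty_like(grad_out)
        # Equal splits: the backward is the same all-to-all on the same group.
        dist.all_to_all_single(grad_in, grad_out, group=ctx.group)
        return grad_in, None  # no gradient for the process group


def simulate_all_to_all(per_rank):
    """Single-process model: rank dst receives chunk dst of every rank's input."""
    world = len(per_rank)
    chunks = [t.chunk(world, dim=0) for t in per_rank]
    return [torch.cat([chunks[src][dst] for src in range(world)], dim=0) for dst in range(world)]


inputs = [torch.arange(4.0) + 10 * r for r in range(2)]  # rank0: [0,1,2,3], rank1: [10,11,12,13]
once = simulate_all_to_all(inputs)
twice = simulate_all_to_all(once)
```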
Functional entry-point for the autograd-aware all-to-all.
Transform from sequence-sharded to hidden-sharded layout (batch-first).
Parameters:
Tensor of shape [B, L_local, H] (BSHD) or [T, H] (THD)
where H is the full hidden dimension on this rank.
Context-parallel process group.
Batch size B (needed to recover dimensions after reshape).
Returns: torch.Tensor
Tensor of shape [B, L_global, H / cp_size] (BSHD) or [T, H / cp_size] (THD).
Transform from hidden-sharded to sequence-sharded layout (batch-first).
This is the inverse of :func:`_all_to_all_cp2hp`.
Parameters:
Tensor of shape [B, L_global, H_local] (BSHD) or [T, H_local] (THD)
where H_local = H / cp_size.
Context-parallel process group.
Batch size B.
Returns: torch.Tensor
Tensor of shape [B, L_local, H] (BSHD) or [T, H] (THD)
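The layout change performed by the two transforms can be modeled in one process for the BSHD case. This only simulates the resulting layout. The real functions use a collective with reshapes, and `simulate_cp2hp` / `simulate_hp2cp` are hypothetical names. The global sequence is assembled in rank order, so under DualChunkSwap it is not yet in original token order.

```python
import torch


def simulate_cp2hp(per_rank):
    """Model: sequence-sharded [B, L_local, H] per rank -> hidden-sharded [B, L_global, H/cp]."""
    cp = len(per_rank)
    # Rank j receives hidden slice j from every rank, concatenated in rank order along L.
    return [torch.cat([x.chunk(cp, dim=-1)[j] for x in per_rank], dim=1) for j in range(cp)]


def simulate_hp2cp(per_rank):
    """Model: hidden-sharded [B, L_global, H/cp] per rank -> sequence-sharded [B, L_local, H]."""
    cp = len(per_rank)
    # Rank r receives sequence slice r from every rank, concatenated in rank order along H.
    return [torch.cat([y.chunk(cp, dim=1)[r] for y in per_rank], dim=-1) for r in range(cp)]


B, L_local, H, cp = 2, 3, 8, 4
shards = [torch.randn(B, L_local, H) for _ in range(cp)]
hidden = simulate_cp2hp(shards)      # each: [B, cp * L_local, H // cp]
restored = simulate_hp2cp(hidden)    # each: [B, L_local, H], equal to shards
```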
Rearrange tokens from rank-major to sequence-major order after all-to-all.
After _all_to_all_cp2hp on packed 2-D data, the token layout along the
sequence dimension is::
[rank0_seq0 | rank0_seq1 | … | rank1_seq0 | rank1_seq1 | …]
This function rearranges to::
[rank0_seq0 | rank1_seq0 | … | rank0_seq1 | rank1_seq1 | …]
so that each sequence’s tokens are contiguous (required by the
_undo_attention_load_balancing reorder that follows).
Parameters:
2-D tensor [T_global, H].
Local (pre-all-to-all) cumulative sequence lengths.
Context-parallel world size.
Returns: torch.Tensor
Rearranged 2-D tensor with the same shape.
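The rearrangement can be sketched as a single gather with a precomputed index. This assumes every rank holds the same local cu_seqlens, so each sequence contributes an equal-length block from every rank, and that the local cu_seqlens is a plain list of ints. `deinterleave_packed` is a hypothetical name. In the usage below, each token value encodes rank*100 + seq*10 + position.

```python
import torch


def deinterleave_packed(x: torch.Tensor, cu_seqlens_local: list[int], cp_size: int) -> torch.Tensor:
    """Hypothetical sketch: rank-major -> sequence-major token order for packed [T_global, H]."""
    t_local = cu_seqlens_local[-1]
    index = []
    for start, end in zip(cu_seqlens_local[:-1], cu_seqlens_local[1:]):
        for rank in range(cp_size):
            # This sequence's tokens from `rank` sit inside that rank's block of t_local tokens.
            offset = rank * t_local
            index.extend(range(offset + start, offset + end))
    return x[torch.tensor(index, dtype=torch.long)]


# Two sequences with local lengths 2 and 1 on each of 2 ranks.
rank_major = torch.tensor([0, 1, 10, 100, 101, 110]).unsqueeze(-1)
seq_major = deinterleave_packed(rank_major, [0, 2, 3], cp_size=2)
```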
Reorder from sequential back to DualChunkSwap for attention.
Inverse of :func:`_undo_attention_load_balancing`.
Inverse of :func:`_deinterleave_packed_seqs`.
Rearranges from sequence-major back to rank-major order before the
inverse all-to-all in post_conv_ssm.
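The inverse gather can be built directly from the local cu_seqlens. In sequence-major order, sequence i starts at cp_size * cu_seqlens_local[i] and holds one equal-length block per rank, in rank order. The sketch below assumes equal per-rank sequence lengths and a plain list of ints for cu_seqlens. `reinterleave_packed` is a hypothetical name.

```python
import torch


def reinterleave_packed(x: torch.Tensor, cu_seqlens_local: list[int], cp_size: int) -> torch.Tensor:
    """Hypothetical sketch: sequence-major -> rank-major token order for packed [T_global, H]."""
    index = []
    for rank in range(cp_size):
        for start, end in zip(cu_seqlens_local[:-1], cu_seqlens_local[1:]):
            length = end - start
            # Sequence i begins at cp_size * start; this rank's block is the rank-th of length `length`.
            base = cp_size * start + rank * length
            index.extend(range(base, base + length))
    return x[torch.tensor(index, dtype=torch.long)]


# Values encode rank*100 + seq*10 + position; input is in sequence-major order.
seq_major = torch.tensor([0, 1, 100, 101, 10, 110]).unsqueeze(-1)
rank_major = reinterleave_packed(seq_major, [0, 2, 3], cp_size=2)
```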
Reorder equal-sized chunks of a tensor according to order.
Parameters:
[B, L, H] (BSHD) or [T, H] (THD).
Permutation indices (length must equal the number of chunks).
If provided, reorder per-sequence on dim=0 (THD).
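A minimal sketch of the chunk reorder follows. It uses the convention that output chunk j is input chunk `order[j]`, and it requires each reordered length to be divisible by the number of chunks. Both are assumptions about the source's semantics, and `reorder_chunks` is a hypothetical name. Without `cu_seqlens`, the sequence dimension is dim 1 for BSHD and dim 0 for THD. With `cu_seqlens`, each packed sequence is reordered independently.

```python
import torch


def reorder_chunks(x: torch.Tensor, order: list[int], cu_seqlens=None) -> torch.Tensor:
    """Hypothetical sketch: output chunk j is input chunk order[j] along the sequence dim."""
    n = len(order)
    if cu_seqlens is None:
        dim = 1 if x.dim() == 3 else 0  # BSHD: [B, L, H]; THD: [T, H]
        assert x.shape[dim] % n == 0, "sequence length must split into equal chunks"
        chunks = x.chunk(n, dim=dim)
        return torch.cat([chunks[i] for i in order], dim=dim)
    bounds = [int(b) for b in cu_seqlens]
    pieces = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        assert (end - start) % n == 0, "each packed sequence must split into equal chunks"
        chunks = x[start:end].chunk(n, dim=0)
        pieces.extend(chunks[i] for i in order)
    return torch.cat(pieces, dim=0)


x = torch.arange(8).reshape(8, 1)
whole = reorder_chunks(x, [0, 3, 1, 2])                      # one sequence, 4 chunks
per_seq = reorder_chunks(x, [1, 0], cu_seqlens=[0, 4, 8])    # two packed sequences
```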
Reorder from DualChunkSwap to sequential for SSM processing.
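The permutations for undoing and redoing DualChunkSwap can be derived as follows. The sketch assumes the usual layout: the sequence is cut into 2 * cp_size chunks and rank r holds chunks r and 2 * cp_size - 1 - r. After the all-to-all, the chunks therefore arrive in that rank order. Undoing applies the argsort of this order, and redoing applies the order itself, with the convention that output chunk j is input chunk `order[j]`. Both function names are hypothetical.

```python
def dual_chunk_swap_order(cp_size: int) -> list[int]:
    """Chunk ids in rank order: rank r holds chunks r and 2*cp_size - 1 - r."""
    order = []
    for rank in range(cp_size):
        order.extend([rank, 2 * cp_size - 1 - rank])
    return order


def undo_order(cp_size: int) -> list[int]:
    """Permutation that restores sequential chunk order (argsort of the swap order)."""
    swap = dual_chunk_swap_order(cp_size)
    return sorted(range(len(swap)), key=swap.__getitem__)


swapped = [f"c{i}" for i in dual_chunk_swap_order(2)]       # ['c0', 'c3', 'c1', 'c2']
restored = [swapped[i] for i in undo_order(2)]              # ['c0', 'c1', 'c2', 'c3']
redone = [restored[i] for i in dual_chunk_swap_order(2)]    # ['c0', 'c3', 'c1', 'c2']
```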