nemo_automodel.components.distributed.mamba_cp#
Context parallelism for Mamba/SSM layers using a hidden-parallel strategy.
Instead of splitting the sequence across CP ranks (as attention CP does), this module uses an all-to-all redistribution so that each CP rank processes the full sequence but only a subset of heads (`d_inner / cp_size`). The data flow is::
    [B, L_local, D]     --> all-to-all --> [B, L_global, D/cp]
                        --> conv1d + SSM kernel -->
    [B, L_global, D/cp] --> all-to-all --> [B, L_local, D]
This module is intentionally not a subclass of nn.Module because it owns
no trainable parameters. It holds references to the Mamba mixer’s parameters
and slices them in the forward path so that gradients flow back to the full
(unsliced) parameters.
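As a rough, single-process sketch of the resulting layouts (illustrative shapes only; no collective is issued here):

```python
import torch

# Hidden-parallel CP shape arithmetic, emulated on one process.
B, L_local, D, cp_size = 2, 8, 16, 4
L_global = L_local * cp_size
h = D // cp_size  # per-rank hidden shard width

# Before the all-to-all: every rank holds its sequence shard, full hidden dim.
shards = [torch.randn(B, L_local, D) for _ in range(cp_size)]

# After the all-to-all: each rank sees the full sequence but only D/cp
# channels; emulated by concatenating sequence shards and slicing hidden.
rank = 0
full_seq = torch.cat(shards, dim=1)                       # [B, L_global, D]
hidden_sharded = full_seq[..., rank * h:(rank + 1) * h]   # [B, L_global, D/cp]
assert hidden_sharded.shape == (B, L_global, h)
```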
Module Contents#
Classes#
- `_AllToAll` – Autograd wrapper around `torch.distributed.all_to_all_single`.
- `MambaContextParallel` – Hidden-parallel context parallelism for a Mamba2 mixer layer.
Functions#
- `_all_to_all` – Functional entry-point for the autograd-aware all-to-all.
- `_all_to_all_cp2hp` – Transform from sequence-sharded to hidden-sharded layout (batch-first).
- `_all_to_all_hp2cp` – Transform from hidden-sharded to sequence-sharded layout (batch-first).
- `_reorder_chunks` – Reorder equal-sized chunks of a tensor according to `order`.
- `_deinterleave_packed_seqs` – Rearrange tokens from rank-major to sequence-major order after all-to-all.
- `_reinterleave_packed_seqs` – Inverse of `_deinterleave_packed_seqs`.
- `_undo_attention_load_balancing` – Reorder from DualChunkSwap to sequential for SSM processing.
- `_redo_attention_load_balancing` – Reorder from sequential back to DualChunkSwap for attention.
API#
- class nemo_automodel.components.distributed.mamba_cp._AllToAll#
Bases: `torch.autograd.Function`
Autograd wrapper around `torch.distributed.all_to_all_single`.
For equal-sized splits the all-to-all operation is its own inverse, so the backward pass is simply another all-to-all on the same group.
- static forward(
- ctx,
- input_: torch.Tensor,
- group: torch.distributed.ProcessGroup,
- )#
- static backward(ctx, grad_output: torch.Tensor)#
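A minimal sketch of this pattern, assuming equal-sized splits (names are illustrative, not the exact class body):

```python
import torch
import torch.distributed as dist

class AllToAllSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
        ctx.group = group
        output = torch.empty_like(input_)
        # Equal splits: every rank sends and receives chunks of the same size.
        dist.all_to_all_single(output, input_.contiguous(), group=group)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # The op is its own inverse for equal splits, so the gradient is
        # routed back with the same collective on the same group.
        grad_input = torch.empty_like(grad_output)
        dist.all_to_all_single(grad_input, grad_output.contiguous(), group=ctx.group)
        return grad_input, None  # no gradient w.r.t. the process group
```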
- nemo_automodel.components.distributed.mamba_cp._all_to_all(
- input_: torch.Tensor,
- group: torch.distributed.ProcessGroup,
- )#
Functional entry-point for the autograd-aware all-to-all.
- nemo_automodel.components.distributed.mamba_cp._all_to_all_cp2hp(
- input_: torch.Tensor,
- cp_group: torch.distributed.ProcessGroup,
- batch_size: int,
- )#
Transform from sequence-sharded to hidden-sharded layout (batch-first).
- Parameters:
input_ – Tensor of shape `[B, L_local, H]` (BSHD) or `[T, H]` (THD), where `H` is the full hidden dimension on this rank.
cp_group – Context-parallel process group.
batch_size – Batch size `B` (needed to recover dimensions after reshape).
- Returns:
Tensor of shape `[B, L_global, H / cp_size]` (BSHD) or `[T, H / cp_size]` (THD).
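A single-process emulation of what the exchange computes, assuming rank r keeps hidden chunk r and the gathered sequence is rank-major (the real function issues the all-to-all on the CP group):

```python
import torch

B, L_global, H, cp = 2, 16, 12, 4
L_local, H_local = L_global // cp, H // cp
x = torch.arange(B * L_global * H, dtype=torch.float32).reshape(B, L_global, H)

# What rank r holds before: its sequence shard, full hidden dim.
before = [x[:, r * L_local:(r + 1) * L_local, :] for r in range(cp)]
# What rank r holds after cp2hp: the full sequence, its hidden shard.
after = [x[..., r * H_local:(r + 1) * H_local] for r in range(cp)]

# The all-to-all realizes exactly this exchange: rank r sends hidden chunk j
# of its sequence shard to rank j and receives hidden chunk r from everyone.
for r in range(cp):
    rebuilt = torch.cat(
        [before[src][..., r * H_local:(r + 1) * H_local] for src in range(cp)],
        dim=1,
    )
    assert torch.equal(rebuilt, after[r])
```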
- nemo_automodel.components.distributed.mamba_cp._all_to_all_hp2cp(
- input_: torch.Tensor,
- cp_group: torch.distributed.ProcessGroup,
- batch_size: int,
- )#
Transform from hidden-sharded to sequence-sharded layout (batch-first).
This is the inverse of :func:`_all_to_all_cp2hp`.
- Parameters:
input_ – Tensor of shape `[B, L_global, H_local]` (BSHD) or `[T, H_local]` (THD), where `H_local = H / cp_size`.
cp_group – Context-parallel process group.
batch_size – Batch size `B`.
- Returns:
Tensor of shape `[B, L_local, H]` (BSHD) or `[T, H]` (THD), where `L_local = L_global / cp_size` and `H = H_local * cp_size`.
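Continuing the same single-process emulation, the inverse direction can be checked directly (illustrative shapes, same assumed chunk assignment):

```python
import torch

B, L_global, H, cp = 2, 16, 12, 4
L_local, H_local = L_global // cp, H // cp
x = torch.randn(B, L_global, H)

before = [x[:, r * L_local:(r + 1) * L_local, :] for r in range(cp)]  # seq-sharded
after = [x[..., r * H_local:(r + 1) * H_local] for r in range(cp)]    # hidden-sharded

# hp2cp sends sequence chunk j back to rank j and concatenates the received
# hidden shards, recovering the original sequence-sharded tensors exactly.
for r in range(cp):
    recovered = torch.cat(
        [after[src][:, r * L_local:(r + 1) * L_local, :] for src in range(cp)],
        dim=-1,
    )
    assert torch.equal(recovered, before[r])
```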
- nemo_automodel.components.distributed.mamba_cp._reorder_chunks(
- input_: torch.Tensor,
- order: list[int],
- cu_seqlens: torch.Tensor | None = None,
- )#
Reorder equal-sized chunks of a tensor according to `order`.
- Parameters:
input_ – `[B, L, H]` (BSHD) or `[T, H]` (THD).
order – Permutation indices (length must equal the number of chunks).
cu_seqlens – If provided, reorder per-sequence on dim 0 (THD).
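A hypothetical equivalent for the BSHD case (the real function also handles per-sequence THD reordering via `cu_seqlens`):

```python
import torch

def reorder_chunks_sketch(x: torch.Tensor, order: list[int]) -> torch.Tensor:
    # Split dim 1 into len(order) equal chunks and permute them.
    chunks = x.chunk(len(order), dim=1)
    return torch.cat([chunks[i] for i in order], dim=1)

x = torch.arange(8).reshape(1, 8, 1)       # [B=1, L=8, H=1]
y = reorder_chunks_sketch(x, [0, 3, 1, 2])
assert y.flatten().tolist() == [0, 1, 6, 7, 2, 3, 4, 5]
```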
- nemo_automodel.components.distributed.mamba_cp._deinterleave_packed_seqs(
- input_: torch.Tensor,
- cu_seqlens: torch.Tensor,
- cp_size: int,
- )#
Rearrange tokens from rank-major to sequence-major order after all-to-all.
After :func:`_all_to_all_cp2hp` on packed 2-D data, the token layout along the sequence dimension is::

    [rank0_seq0 | rank0_seq1 | ... | rank1_seq0 | rank1_seq1 | ...]

This function rearranges it to::

    [rank0_seq0 | rank1_seq0 | ... | rank0_seq1 | rank1_seq1 | ...]

so that each sequence's tokens are contiguous (required by the :func:`_undo_attention_load_balancing` reorder that follows).
- Parameters:
input_ – 2-D tensor `[T_global, H]`.
cu_seqlens – Local (pre-all-to-all) cumulative sequence lengths.
cp_size – Context-parallel world size.
- Returns:
Rearranged 2-D tensor with the same shape.
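A single-process sketch of the rearrangement (hypothetical helper; the real function derives chunk boundaries from the local `cu_seqlens`):

```python
import torch

def deinterleave_sketch(x, cu_seqlens, cp_size):
    t_local = x.shape[0] // cp_size                       # tokens per rank chunk
    lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    out = []
    for s, start in enumerate(cu_seqlens[:-1].tolist()):
        for r in range(cp_size):                          # gather seq s from every rank
            out.append(x[r * t_local + start : r * t_local + start + lengths[s]])
    return torch.cat(out, dim=0)

# Two packed sequences of lengths 2 and 1 per rank, cp_size = 2.
cu = torch.tensor([0, 2, 3])
x = torch.tensor([[0.], [1.], [2.],        # rank0: seq0 tok0-1, seq1 tok0
                  [10.], [11.], [12.]])    # rank1: seq0 tok0-1, seq1 tok0
y = deinterleave_sketch(x, cu, cp_size=2)
assert y.flatten().tolist() == [0., 1., 10., 11., 2., 12.]
```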
- nemo_automodel.components.distributed.mamba_cp._reinterleave_packed_seqs(
- input_: torch.Tensor,
- cu_seqlens: torch.Tensor,
- cp_size: int,
- )#
Inverse of :func:`_deinterleave_packed_seqs`.
Rearranges from sequence-major back to rank-major order before the inverse all-to-all in `post_conv_ssm`.
- nemo_automodel.components.distributed.mamba_cp._undo_attention_load_balancing(
- input_: torch.Tensor,
- cp_size: int,
- cu_seqlens: torch.Tensor | None = None,
- )#
Reorder from DualChunkSwap to sequential for SSM processing.
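A sketch, assuming the zig-zag chunk assignment used for attention load balancing (rank r holds chunks r and 2*cp_size - 1 - r of the 2*cp_size sequence chunks):

```python
import torch

def undo_dual_chunk_swap(x: torch.Tensor, cp_size: int) -> torch.Tensor:
    # After cp2hp the full sequence arrives in rank-major chunk order:
    # [c0, c(2cp-1), c1, c(2cp-2), ...]; put the chunks back in order.
    chunks = x.chunk(2 * cp_size, dim=1)
    ordered = [None] * (2 * cp_size)
    for r in range(cp_size):
        ordered[r] = chunks[2 * r]                     # rank r's first chunk
        ordered[2 * cp_size - 1 - r] = chunks[2 * r + 1]
    return torch.cat(ordered, dim=1)

x = torch.tensor([[0, 3, 1, 2]]).unsqueeze(-1)         # cp=2: rank0=(c0,c3), rank1=(c1,c2)
assert undo_dual_chunk_swap(x, cp_size=2).flatten().tolist() == [0, 1, 2, 3]
```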
- nemo_automodel.components.distributed.mamba_cp._redo_attention_load_balancing(
- input_: torch.Tensor,
- cp_size: int,
- cu_seqlens: torch.Tensor | None = None,
- )#
Reorder from sequential back to DualChunkSwap for attention.
Inverse of :func:`_undo_attention_load_balancing`.
- class nemo_automodel.components.distributed.mamba_cp.MambaContextParallel(
- cp_group: torch.distributed.ProcessGroup,
- num_heads: int,
- head_dim: int,
- n_groups: int,
- d_state: int,
- mixer: torch.nn.Module,
- )#
Hidden-parallel context parallelism for a Mamba2 mixer layer.
This class does not own trainable parameters. It stores a reference to the mixer module and accesses its parameters (conv1d, dt_bias, A_log, D) on the fly so that gradients propagate to the original (full) parameters and FSDP-managed DTensor replacements are picked up correctly.
DualChunkSwap reordering is always undone before the SSM kernel and redone after, because both TE CP (p2p) and PyTorch's `context_parallel` (allgather) reorder sequence chunks for load balancing.
- Parameters:
cp_group – Context-parallel process group.
num_heads – Total number of SSM heads (before any parallelism).
head_dim – Dimension per head.
n_groups – Number of SSM groups (for grouped B/C states).
d_state – SSM state dimension.
mixer – Reference to the Mamba mixer module (owns conv1d, dt_bias, A_log, D).
Initialization
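A hypothetical wiring sketch; the mixer attribute names and the surrounding kernel call are assumptions, not the actual model code:

```python
from nemo_automodel.components.distributed.mamba_cp import MambaContextParallel

def build_cp(mixer, cp_group) -> MambaContextParallel:
    # Construct once per layer; the instance only references mixer parameters.
    return MambaContextParallel(
        cp_group=cp_group,
        num_heads=mixer.num_heads,    # total heads, before any parallelism
        head_dim=mixer.head_dim,
        n_groups=mixer.n_groups,
        d_state=mixer.d_state,
        mixer=mixer,                  # owns conv1d, dt_bias, A_log, D
    )

def forward_with_cp(cp: MambaContextParallel, projected_states):
    hp = cp.pre_conv_ssm(projected_states)   # sequence-sharded -> hidden-sharded
    # ... run causal_conv1d + SSM kernel here with the rank-local slices:
    #     cp.get_conv1d_weight(), cp.get_conv1d_bias(),
    #     cp.get_dt_bias(), cp.get_A_log(), cp.get_D() ...
    return cp.post_conv_ssm(hp)              # hidden-sharded -> sequence-sharded
```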
- pre_conv_ssm(
- projected_states: torch.Tensor,
- cu_seqlens: torch.Tensor | None = None,
- )#
Redistribute from sequence-sharded to hidden-sharded layout, undoing DualChunkSwap.
- post_conv_ssm(
- output: torch.Tensor,
- cu_seqlens: torch.Tensor | None = None,
- )#
Redistribute SSM output from hidden-sharded back to sequence-sharded layout.
- get_conv1d_weight() → torch.Tensor#
Slice `conv1d.weight` for the current CP rank.
Weight shape: `[conv_dim, 1, kernel_size]` where `conv_dim = d_inner + 2 * n_groups * d_state`. Returns `[conv_dim_local, kernel_size]` (squeezed for the causal_conv1d kernel).
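A sketch of the slicing shape arithmetic, assuming the channel layout `[x (d_inner) | B (n_groups * d_state) | C (n_groups * d_state)]` with each segment split contiguously across CP ranks (an assumption for illustration):

```python
import torch

d_inner, n_groups, d_state, kernel_size, cp_size, rank = 64, 4, 16, 4, 2, 0
gs = n_groups * d_state
w = torch.randn(d_inner + 2 * gs, 1, kernel_size)   # full conv1d.weight

def seg_slice(offset: int, seg: int) -> torch.Tensor:
    # Take this rank's contiguous block out of one channel segment.
    local = seg // cp_size
    return w[offset + rank * local : offset + (rank + 1) * local]

w_local = torch.cat(
    [seg_slice(0, d_inner), seg_slice(d_inner, gs), seg_slice(d_inner + gs, gs)]
)
assert w_local.squeeze(1).shape == ((d_inner + 2 * gs) // cp_size, kernel_size)
```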
- get_conv1d_bias() → torch.Tensor#
Slice `conv1d.bias` for the current CP rank.
Bias shape: `[conv_dim]`. Returns `[conv_dim_local]`.
- get_dt_bias() → torch.Tensor#
Slice `dt_bias` for the current CP rank.
- get_A_log() → torch.Tensor#
Slice `A_log` for the current CP rank.
- get_D() → torch.Tensor#
Slice `D` for the current CP rank.
- _repeat_group_state(state: torch.Tensor) → torch.Tensor#
Repeat group states for CP ranks when `n_groups < cp_size`.
`[B, L, n_groups * d_state]` -> `[B, L, n_groups * repeat * d_state]`. Also supports THD 2-D input `[T, n_groups * d_state]`.
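A sketch of the repetition, assuming `repeat = cp_size // n_groups` when groups are fewer than ranks:

```python
import torch

B, L, n_groups, d_state, cp_size = 2, 8, 2, 4, 4
repeat = cp_size // n_groups                      # each group serves repeat ranks

state = torch.randn(B, L, n_groups * d_state)
s = state.view(B, L, n_groups, d_state)
s = s.repeat_interleave(repeat, dim=2)            # duplicate each group's state
out = s.reshape(B, L, n_groups * repeat * d_state)
assert out.shape == (B, L, n_groups * repeat * d_state)
```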
- _slice_vector_param(param: torch.Tensor) → torch.Tensor#
Slice a per-head vector parameter for the current CP rank.
- _slice_conv_param(param: torch.Tensor) → torch.Tensor#
Slice a conv1d parameter (weight or bias) along its channel dimension.
Parameter slicing is done in the forward path so that gradients backpropagate to the original (full) parameters.
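A small demonstration of why forward-path slicing preserves gradient flow: indexing a parameter is itself an autograd op, so gradients from the slice accumulate into the full tensor:

```python
import torch

w = torch.nn.Parameter(torch.randn(8, 3))
local = w[2:4]                 # rank-local slice, still part of the graph
loss = (local * 2.0).sum()
loss.backward()

assert torch.all(w.grad[2:4] == 2.0)                 # grads land in the slice
assert w.grad[[0, 1, 4, 5, 6, 7]].abs().sum() == 0   # rest stays zero
```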