nemo_automodel.components.distributed.mamba_cp#

Context parallelism for Mamba/SSM layers using a hidden-parallel strategy.

Instead of splitting the sequence across CP ranks (as attention CP does), this module uses an all-to-all redistribution so that each CP rank processes the full sequence but only a subset of heads (d_inner / cp_size). The data flow is::

[B, L_local, D]  -->  all-to-all  -->  [B, L_global, D/cp]
    -->  conv1d + SSM kernel  -->
[B, L_global, D/cp]  -->  all-to-all  -->  [B, L_local, D]

The MambaContextParallel helper is intentionally not an nn.Module subclass because it owns no trainable parameters. It holds references to the Mamba mixer's parameters and slices them in the forward path, so gradients flow back to the full (unsliced) parameters.
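The round trip above can be sketched in a single process. The following is an illustrative emulation, not the module's API: the `cp_size` ranks are represented as a list of local shards, and the equal-split all-to-all becomes a chunk-level exchange between the sequence and hidden dimensions.

```python
import torch

# Single-process sketch of the hidden-parallel round trip (illustrative names).
def cp2hp(local_shards, cp_size):
    """Emulate the all-to-all: [cp] x [B, L_local, D] -> [cp] x [B, L_global, D/cp]."""
    outs = []
    for r in range(cp_size):
        # Rank r receives the r-th hidden slice of every rank's sequence shard,
        # concatenated along the sequence dimension.
        pieces = [local_shards[s].chunk(cp_size, dim=-1)[r] for s in range(cp_size)]
        outs.append(torch.cat(pieces, dim=1))
    return outs

def hp2cp(hidden_shards, cp_size):
    """Inverse: [cp] x [B, L_global, D/cp] -> [cp] x [B, L_local, D]."""
    outs = []
    for r in range(cp_size):
        pieces = [hidden_shards[s].chunk(cp_size, dim=1)[r] for s in range(cp_size)]
        outs.append(torch.cat(pieces, dim=-1))
    return outs

cp, B, L_local, D = 2, 1, 4, 8
shards = [torch.randn(B, L_local, D) for _ in range(cp)]
hp = cp2hp(shards, cp)
assert hp[0].shape == (B, L_local * cp, D // cp)   # full sequence, 1/cp of hidden
back = hp2cp(hp, cp)
assert all(torch.equal(a, b) for a, b in zip(shards, back))  # exact round trip
```

Because the splits are equal-sized, the two transforms are exact inverses, which is why the backward pass of the real collective can reuse the same all-to-all.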

Module Contents#

Classes#

_AllToAll

Autograd wrapper around torch.distributed.all_to_all_single.

MambaContextParallel

Hidden-parallel context parallelism for a Mamba2 mixer layer.

Functions#

_all_to_all

Functional entry point for the autograd-aware all-to-all.

_all_to_all_cp2hp

Transform from sequence-sharded to hidden-sharded layout (batch-first).

_all_to_all_hp2cp

Transform from hidden-sharded to sequence-sharded layout (batch-first).

_reorder_chunks

Reorder equal-sized chunks of a tensor according to order.

_deinterleave_packed_seqs

Rearrange tokens from rank-major to sequence-major order after all-to-all.

_reinterleave_packed_seqs

Inverse of _deinterleave_packed_seqs.

_undo_attention_load_balancing

Reorder from DualChunkSwap to sequential for SSM processing.

_redo_attention_load_balancing

Reorder from sequential back to DualChunkSwap for attention.

API#

class nemo_automodel.components.distributed.mamba_cp._AllToAll#

Bases: torch.autograd.Function

Autograd wrapper around torch.distributed.all_to_all_single.

For equal-sized splits the all-to-all operation is its own inverse, so the backward pass is simply another all-to-all on the same group.
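A minimal sketch of such a wrapper, run here on a single rank with the gloo backend for illustration (the class name and single-rank shortcut are this sketch's own, not the module's):

```python
import os
import torch
import torch.distributed as dist

class AllToAll(torch.autograd.Function):
    """Sketch: autograd wrapper around all_to_all_single. With equal splits the
    collective is its own inverse, so backward is the same all-to-all."""

    @staticmethod
    def forward(ctx, input_, group):
        ctx.group = group
        if dist.get_world_size(group) == 1:
            return input_.clone()            # degenerate case: identity
        output = torch.empty_like(input_)
        dist.all_to_all_single(output, input_.contiguous(), group=group)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Same collective on the same group; None for the non-tensor arg.
        return AllToAll.apply(grad_output, ctx.group), None

# Single-rank demo: gradients flow straight through the collective.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29511")
dist.init_process_group("gloo", rank=0, world_size=1)
x = torch.randn(4, 8, requires_grad=True)
y = AllToAll.apply(x, dist.group.WORLD)
y.sum().backward()
assert torch.equal(x.grad, torch.ones_like(x))
dist.destroy_process_group()
```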

static forward(
ctx,
input_: torch.Tensor,
group: torch.distributed.ProcessGroup,
) → torch.Tensor#
static backward(ctx, grad_output: torch.Tensor)#
nemo_automodel.components.distributed.mamba_cp._all_to_all(
input_: torch.Tensor,
group: torch.distributed.ProcessGroup,
) → torch.Tensor#

Functional entry point for the autograd-aware all-to-all.

nemo_automodel.components.distributed.mamba_cp._all_to_all_cp2hp(
input_: torch.Tensor,
cp_group: torch.distributed.ProcessGroup,
batch_size: int,
) → torch.Tensor#

Transform from sequence-sharded to hidden-sharded layout (batch-first).

Parameters:
  • input_ – Tensor of shape [B, L_local, H] (BSHD) or [T, H] (THD) where H is the full hidden dimension on this rank.

  • cp_group – Context-parallel process group.

  • batch_size – Batch size B (needed to recover dimensions after reshape).

Returns:

Tensor of shape [B, L_global, H / cp_size] (BSHD) or [T, H / cp_size] (THD).

nemo_automodel.components.distributed.mamba_cp._all_to_all_hp2cp(
input_: torch.Tensor,
cp_group: torch.distributed.ProcessGroup,
batch_size: int,
) → torch.Tensor#

Transform from hidden-sharded to sequence-sharded layout (batch-first).

This is the inverse of _all_to_all_cp2hp.

Parameters:
  • input_ – Tensor of shape [B, L_global, H_local] (BSHD) or [T, H_local] (THD) where H_local = H / cp_size.

  • cp_group – Context-parallel process group.

  • batch_size – Batch size B.

Returns:

Tensor of shape [B, L_local, H] (BSHD) or [T, H] (THD) where L_local = L_global / cp_size and H = H_local * cp_size.

nemo_automodel.components.distributed.mamba_cp._reorder_chunks(
input_: torch.Tensor,
order: list[int],
cu_seqlens: torch.Tensor | None = None,
) → torch.Tensor#

Reorder equal-sized chunks of a tensor according to order.

Parameters:
  • input_ – [B, L, H] (BSHD) or [T, H] (THD).

  • order – Permutation indices (length must equal the number of chunks).

  • cu_seqlens – If provided, reorder per-sequence on dim=0 (THD).
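For the BSHD case, the operation reduces to splitting the sequence dimension into equal chunks and permuting them; a hypothetical re-implementation:

```python
import torch

# Illustrative sketch of the BSHD chunk reorder: split dim 1 into
# len(order) equal chunks, then concatenate them in the given order.
def reorder_chunks(x, order):
    chunks = x.chunk(len(order), dim=1)          # [B, L, H] -> equal chunks
    return torch.cat([chunks[i] for i in order], dim=1)

x = torch.arange(8).view(1, 8, 1)                # chunks: [0,1] [2,3] [4,5] [6,7]
y = reorder_chunks(x, [0, 3, 1, 2])
assert y.flatten().tolist() == [0, 1, 6, 7, 2, 3, 4, 5]
```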

nemo_automodel.components.distributed.mamba_cp._deinterleave_packed_seqs(
input_: torch.Tensor,
cu_seqlens: torch.Tensor,
cp_size: int,
) → torch.Tensor#

Rearrange tokens from rank-major to sequence-major order after all-to-all.

After _all_to_all_cp2hp on packed 2-D data the token layout along the sequence dimension is::

[rank0_seq0 | rank0_seq1 | ... | rank1_seq0 | rank1_seq1 | ...]

This function rearranges to::

[rank0_seq0 | rank1_seq0 | ... | rank0_seq1 | rank1_seq1 | ...]

so that each sequence’s tokens are contiguous (required by the _undo_attention_load_balancing reorder that follows).

Parameters:
  • input_ – 2-D tensor [T_global, H].

  • cu_seqlens – Local (pre-all-to-all) cumulative sequence lengths.

  • cp_size – Context-parallel world size.

Returns:

Rearranged 2-D tensor with the same shape.
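A small sketch of the regrouping, under the assumption stated above that each rank's block of T_local tokens shares the same local cu_seqlens layout (function name and loop structure are illustrative):

```python
import torch

# Illustrative sketch: regroup packed tokens from rank-major to sequence-major.
def deinterleave_packed_seqs(x, cu_seqlens, cp_size):
    t_local = cu_seqlens[-1].item()              # tokens contributed per rank
    out = []
    for s in range(len(cu_seqlens) - 1):         # for each packed sequence...
        for r in range(cp_size):                 # ...gather its slice from every rank
            start = r * t_local + cu_seqlens[s].item()
            end = r * t_local + cu_seqlens[s + 1].item()
            out.append(x[start:end])
    return torch.cat(out, dim=0)

# Two ranks, two sequences of local lengths 2 and 1; tokens labeled rank*10+pos.
x = torch.tensor([[0.], [1.], [2.], [10.], [11.], [12.]])
cu = torch.tensor([0, 2, 3])
y = deinterleave_packed_seqs(x, cu, cp_size=2)
# seq0's tokens (rank0 then rank1) now precede seq1's tokens.
assert y.flatten().tolist() == [0.0, 1.0, 10.0, 11.0, 2.0, 12.0]
```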

nemo_automodel.components.distributed.mamba_cp._reinterleave_packed_seqs(
input_: torch.Tensor,
cu_seqlens: torch.Tensor,
cp_size: int,
) → torch.Tensor#

Inverse of _deinterleave_packed_seqs.

Rearranges from sequence-major back to rank-major order before the inverse all-to-all in post_conv_ssm.

nemo_automodel.components.distributed.mamba_cp._undo_attention_load_balancing(
input_: torch.Tensor,
cp_size: int,
cu_seqlens: torch.Tensor | None = None,
) → torch.Tensor#

Reorder from DualChunkSwap to sequential for SSM processing.

nemo_automodel.components.distributed.mamba_cp._redo_attention_load_balancing(
input_: torch.Tensor,
cp_size: int,
cu_seqlens: torch.Tensor | None = None,
) → torch.Tensor#

Reorder from sequential back to DualChunkSwap for attention.

Inverse of _undo_attention_load_balancing.
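The bookkeeping can be sketched assuming the usual DualChunkSwap scheme (the sequence is split into 2 * cp_size chunks and rank r keeps chunks r and 2 * cp_size - 1 - r, so after the all-to-all the chunk order is rank-major); the helper names below are this sketch's own:

```python
import torch

def dcs_layout(cp_size):
    """Global chunk order after all-to-all under DualChunkSwap (assumed scheme)."""
    layout = []
    for r in range(cp_size):
        layout += [r, 2 * cp_size - 1 - r]
    return layout                                # e.g. cp_size=2 -> [0, 3, 1, 2]

def undo_order(cp_size):
    """Inverse permutation: DualChunkSwap order back to sequential."""
    layout = dcs_layout(cp_size)
    order = [0] * len(layout)
    for pos, chunk in enumerate(layout):
        order[chunk] = pos
    return order

def reorder(x, order):
    chunks = x.chunk(len(order), dim=0)
    return torch.cat([chunks[i] for i in order], dim=0)

x = torch.tensor([0, 3, 1, 2])                   # one token per chunk, cp_size=2
seq = reorder(x, undo_order(2))                  # DualChunkSwap -> sequential
assert seq.tolist() == [0, 1, 2, 3]
back = reorder(seq, dcs_layout(2))               # sequential -> DualChunkSwap
assert back.tolist() == [0, 3, 1, 2]
```

Undo uses the inverse permutation, redo the forward one, which is why the pair composes to the identity.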

class nemo_automodel.components.distributed.mamba_cp.MambaContextParallel(
cp_group: torch.distributed.ProcessGroup,
num_heads: int,
head_dim: int,
n_groups: int,
d_state: int,
mixer: torch.nn.Module,
)#

Hidden-parallel context parallelism for a Mamba2 mixer layer.

This class does not own trainable parameters. It stores a reference to the mixer module and accesses its parameters (conv1d, dt_bias, A_log, D) on the fly so that gradients propagate to the original (full) parameters and FSDP-managed DTensor replacements are picked up correctly.

DualChunkSwap reordering is always undone before the SSM kernel and redone after, because both TE CP (p2p) and PyTorch's context_parallel (all-gather) reorder sequence chunks for load balancing.

Parameters:
  • cp_group – Context-parallel process group.

  • num_heads – Total number of SSM heads (before any parallelism).

  • head_dim – Dimension per head.

  • n_groups – Number of SSM groups (for grouped B/C states).

  • d_state – SSM state dimension.

  • mixer – Reference to the Mamba mixer module (owns conv1d, dt_bias, A_log, D).

Initialization

pre_conv_ssm(
projected_states: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
) → torch.Tensor#

Redistribute from sequence-sharded to hidden-sharded layout, undoing DualChunkSwap.

post_conv_ssm(
output: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
) → torch.Tensor#

Redistribute SSM output from hidden-sharded back to sequence-sharded layout.

get_conv1d_weight() → torch.Tensor#

Slice conv1d.weight for the current CP rank.

Weight shape: [conv_dim, 1, kernel_size], where conv_dim = d_inner + 2 * n_groups * d_state. Returns [conv_dim_local, kernel_size] (squeezed for the causal_conv1d kernel).

get_conv1d_bias() → torch.Tensor#

Slice conv1d.bias for the current CP rank.

Bias shape: [conv_dim]. Returns [conv_dim_local].

get_dt_bias() → torch.Tensor#

Slice dt_bias for the current CP rank.

get_A_log() → torch.Tensor#

Slice A_log for the current CP rank.

get_D() → torch.Tensor#

Slice D for the current CP rank.

_repeat_group_state(state: torch.Tensor) → torch.Tensor#

Repeat group states for CP ranks when n_groups < cp_size.

[B, L, n_groups * d_state] -> [B, L, n_groups * repeat * d_state]. Also supports THD 2-D input [T, n_groups * d_state].
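A hypothetical sketch of the repeat, assuming each group is duplicated cp_size // n_groups times so every CP rank can own at least one full group (the function signature is illustrative):

```python
import torch

# Illustrative sketch: replicate grouped B/C states when n_groups < cp_size.
def repeat_group_state(state, n_groups, d_state, cp_size):
    if n_groups >= cp_size:
        return state                              # enough groups to shard directly
    repeat = cp_size // n_groups                  # assumed replication factor
    g = state.reshape(*state.shape[:-1], n_groups, d_state)
    g = g.repeat_interleave(repeat, dim=-2)       # duplicate each group in place
    return g.reshape(*state.shape[:-1], n_groups * repeat * d_state)

state = torch.arange(4.).view(1, 1, 4)            # B=1, L=1, n_groups=1, d_state=4
out = repeat_group_state(state, n_groups=1, d_state=4, cp_size=2)
assert out.shape == (1, 1, 8)
assert out.flatten().tolist() == [0.0, 1.0, 2.0, 3.0, 0.0, 1.0, 2.0, 3.0]
```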

_slice_vector_param(param: torch.Tensor) → torch.Tensor#

Slice a per-head vector parameter for the current CP rank.

_slice_conv_param(param: torch.Tensor) → torch.Tensor#

Slice a conv1d parameter (weight or bias) along its channel dimension.

Parameter slicing is done in the forward path so that gradients backpropagate to the original (full) parameters.
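A minimal sketch of why forward-path slicing preserves gradient flow, assuming heads are split contiguously across CP ranks (the helper name is this sketch's own):

```python
import torch

# Illustrative per-head slicing: indexing inside forward keeps autograd
# history, so gradients land on the corresponding slice of the full parameter.
def slice_vector_param(param, num_heads, cp_size, cp_rank):
    per_rank = num_heads // cp_size
    return param[cp_rank * per_rank:(cp_rank + 1) * per_rank]

A_log = torch.randn(8, requires_grad=True)        # 8 heads, cp_size = 2
local = slice_vector_param(A_log, num_heads=8, cp_size=2, cp_rank=1)
local.sum().backward()
# Only this rank's slice of the full parameter receives gradient.
assert A_log.grad[:4].abs().sum().item() == 0.0
assert torch.equal(A_log.grad[4:], torch.ones(4))
```

Were the parameters copied and sliced ahead of time instead, the copies would detach from the originals and FSDP's DTensor replacements would be missed, which is why slicing happens on the fly.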