core.ssm.mamba_mixer#

Module Contents#

Classes#

ExtendedRMSNorm

RMSNormGated with sharded state dict.

MambaMixerSubmodules

Contains the module specs for the input and output linear layers.

MambaMixer


Functions#

_split_tensor_factory

Builds a factory that splits a given ShardedTensor into several independent chunks.

_check_mamba_sequence_packing_support

Checks whether causal_conv1d and mamba_ssm support sequence packing.

Data#

API#

core.ssm.mamba_mixer.logger#

‘getLogger(…)’

class core.ssm.mamba_mixer.ExtendedRMSNorm(/, *args, **kw)#

Bases: mamba_ssm.ops.triton.layernorm_gated.RMSNorm

RMSNormGated with sharded state dict.

Initialization

sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#

The weight is sharded along axis 0; the bias is not sharded.
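As a sketch of what axis-0 sharding means here (a hypothetical 2-rank tensor-parallel layout; the real implementation builds `ShardedTensor`s via `megatron.core.dist_checkpointing`), each rank owns a contiguous slice of the weight along dimension 0:

```python
# Hypothetical sketch of axis-0 sharding across tensor-parallel ranks.
# Only the slicing arithmetic is illustrated, not the actual API.

def shard_axis0(weight, tp_rank, tp_size):
    """Return the slice of `weight` (a list of rows) owned by `tp_rank`."""
    n = len(weight)
    assert n % tp_size == 0, "dim 0 must divide evenly across ranks"
    per_rank = n // tp_size
    start = tp_rank * per_rank
    return weight[start:start + per_rank]

weight = [[i] for i in range(8)]       # a weight with 8 rows
shard0 = shard_axis0(weight, 0, 2)     # rank 0 owns rows 0..3
shard1 = shard_axis0(weight, 1, 2)     # rank 1 owns rows 4..7
assert shard0 + shard1 == weight       # concatenating shards restores the weight
```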

class core.ssm.mamba_mixer.MambaMixerSubmodules#

Contains the module specs for the input and output linear layers.

in_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

out_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

class core.ssm.mamba_mixer.MambaMixer(
config: megatron.core.transformer.TransformerConfig,
submodules: core.ssm.mamba_mixer.MambaMixerSubmodules,
d_model,
d_conv=4,
conv_init=None,
expand=2,
A_init_range=(1, 16),
D_has_hdim=False,
rmsnorm=True,
norm_before_gate=False,
dt_min=0.001,
dt_max=0.1,
dt_init='random',
dt_scale=1.0,
dt_init_floor=0.0001,
bias=False,
conv_bias=True,
chunk_size=128,
layer_number=None,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
pp_layer_offset: int = 0,
)#

Bases: megatron.core.transformer.module.MegatronModule

Parameters:
  • config – The config of the model.

  • submodules – Contains the module specs for the input and output linear layers.

  • d_model – The hidden size of the model.

  • d_state – The state size of the SSM.

  • d_conv – The kernel width of the causal convolution.

  • conv_init – The initialization range for the causal convolution weights.

  • expand – The expansion factor for the SSM.

  • headdim – The hidden size of each SSM head.

  • ngroups – The number of groups for the B and C projections (analogous to grouped-query attention).

  • A_init_range – The initialization range for the A parameter of the SSM.

  • D_has_hdim – Whether the D parameter has the same number of dimensions as the hidden state.

  • rmsnorm – Whether to use root mean square normalization.

  • norm_before_gate – Whether to apply normalization before the gating mechanism.

  • dt_min – The minimum value of the dt parameter.

  • dt_max – The maximum value of the dt parameter.

  • dt_init – The initialization value of the dt parameter.

  • dt_scale – The scaling factor for the dt parameter.

  • dt_init_floor – The minimum value of the dt parameter after initialization.

  • bias – Whether to use bias in the linear layers.

  • conv_bias – Whether to use bias in the causal convolution.

  • chunk_size – The chunk size for the Mamba SSM fused kernel.

  • use_mem_eff_path – Whether to use the memory-efficient path for the Mamba model.

  • layer_number – The layer number of this Mamba layer.

  • pg_collection – The required process groups to use for tensor model parallel and context parallel.
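The dt_* parameters above can be read together. A minimal sketch of how they interact, assuming the standard Mamba-2 initialization scheme (sample dt log-uniformly in [dt_min, dt_max], floor it at dt_init_floor, then store the inverse softplus so that softplus of the parameter recovers dt at run time); this is an assumption about the scheme, not confirmed from this module's source:

```python
import math, random

# Assumed dt initialization following reference Mamba-2 implementations.
def init_dt(nheads, dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, seed=0):
    rng = random.Random(seed)
    params = []
    for _ in range(nheads):
        u = rng.random()
        # log-uniform sample in [dt_min, dt_max]
        dt = math.exp(u * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
        dt = max(dt, dt_init_floor)
        # inverse softplus: inv such that log(1 + exp(inv)) == dt
        inv_dt = dt + math.log(-math.expm1(-dt))
        params.append(inv_dt)
    return params

params = init_dt(nheads=4)
for p in params:
    dt = math.log1p(math.exp(p))             # softplus recovers dt
    assert 1e-4 - 1e-9 <= dt <= 0.1 + 1e-9   # within [dt_init_floor, dt_max]
```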

Initialization

forward(
hidden_states,
inference_context=None,
*,
inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
)#

Parameters:
  • hidden_states – Input tensor of shape (L, B, D) (sequence length, batch size, hidden dimension).

Returns:

A tensor of the same shape as hidden_states.

_dynamic_inference(
hidden_states: torch.Tensor,
context: megatron.core.inference.contexts.DynamicInferenceContext,
)#

Executes dynamic inference by separating decode and prefill requests and running them independently.

_dynamic_inference_prefill(
zxBCdt: torch.Tensor,
context: megatron.core.inference.contexts.DynamicInferenceContext,
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
mamba_layer_idx: Optional[int] = None,
) torch.Tensor#

Helper to run dynamic inference prefill.

All prefill requests (including chunked prefill) are processed together through the unified varlen path. Uses precomputed metadata from MambaMetadata.update() to avoid .item() calls and data-dependent control flow, enabling CUDA graph compatibility.

Intermediate state extraction (for Mamba prefix caching) is performed inside _ssm_prefill via pre-allocated output buffers, making it fully CUDA graph compatible.
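The varlen metadata used by the unified prefill path can be sketched as follows: given cumulative sequence lengths, derive the per-token request index (what seq_idx encodes). This is an illustration of the layout, not the MambaMetadata implementation:

```python
# Illustrative mapping from cumulative sequence lengths to per-token
# request indices, as used by variable-length (packed) prefill.

def seq_idx_from_cu_seqlens(cu_seqlens):
    """Map each token position to the request it belongs to."""
    seq_idx = []
    for req, (start, end) in enumerate(zip(cu_seqlens, cu_seqlens[1:])):
        seq_idx.extend([req] * (end - start))
    return seq_idx

# Three packed requests of lengths 3, 1, and 2:
cu = [0, 3, 4, 6]
assert seq_idx_from_cu_seqlens(cu) == [0, 0, 0, 1, 2, 2]
```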

_decode(
hidden_states,
conv_state,
ssm_state,
batch_indices: Optional[torch.Tensor] = None,
) Tuple[torch.Tensor, torch.Tensor]#

Performs a single inference step for decoding.
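A decode step advances the convolution state by exactly one token. A minimal pure-Python sketch of that rolling-buffer update (the real kernel is the fused `causal_conv1d` update; this only illustrates the bookkeeping for one channel):

```python
# Single-token causal convolution step over a rolling window of width d_conv.

def conv1d_update(conv_state, x_new, weights):
    """Shift the window left by one, append the new input, and return the
    dot product with the kernel weights together with the updated state."""
    conv_state = conv_state[1:] + [x_new]
    y = sum(w * x for w, x in zip(weights, conv_state))
    return y, conv_state

state = [0.0, 0.0, 0.0, 1.0]     # d_conv = 4 window for one channel
w = [0.25, 0.25, 0.25, 0.25]
y, state = conv1d_update(state, 3.0, w)
assert state == [0.0, 0.0, 1.0, 3.0]
assert y == 1.0                  # (0 + 0 + 1 + 3) / 4
```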

_ssm_training(
zxBCdt: torch.Tensor,
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
) torch.Tensor#

Performs SSM computation for training step.

Uses the memory-efficient kernel mamba_split_conv1d_scan_combined which reduces the size of forward activations stored for backprop and therefore reduces memory pressure during training.

_ssm_prefill(
zxBCdt: torch.Tensor,
conv_state: Optional[torch.Tensor],
ssm_state: Optional[torch.Tensor],
seq_idx: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
batch_indices: Optional[torch.Tensor] = None,
intermediate_chunk_indices: Optional[torch.Tensor] = None,
intermediate_abs_positions: Optional[torch.Tensor] = None,
intermediate_ssm_out: Optional[torch.Tensor] = None,
intermediate_conv_out: Optional[torch.Tensor] = None,
conv_gather_offsets: Optional[torch.Tensor] = None,
cu_chunk_seqlens: Optional[torch.Tensor] = None,
last_chunk_indices: Optional[torch.Tensor] = None,
seq_idx_for_varlen: Optional[torch.Tensor] = None,
cu_seqlens_list: Optional[List[int]] = None,
real_token_count: Optional[int] = None,
conv_seq_idx: Optional[torch.Tensor] = None,
conv_seq_start: Optional[torch.Tensor] = None,
) torch.Tensor#

Performs SSM computation for inference prefill step.

Parameters:
  • zxBCdt – The input tensor of shape (l, b, d), which is a concatenation of z, x, B, C, and dt projections.

  • conv_state – The convolution state tensor for inference.

  • ssm_state – The selective scan state tensor for inference.

  • seq_idx – A map from token index to request index for variable-length sequences.

  • cu_seqlens – Cumulative sequence lengths for variable-length sequences.

  • batch_indices – A map from batch id to position in the Mamba state tensors for dynamic inference.

  • intermediate_chunk_indices – Pre-allocated tensor of chunk indices for intermediate state extraction (fixed size, padded with 0).

  • intermediate_abs_positions – Pre-allocated tensor of absolute token positions for conv state extraction (fixed size, padded with d_conv).

  • intermediate_ssm_out – Output buffer for extracted SSM states [max_intermediate_count, *ssm_shape].

  • intermediate_conv_out – Output buffer for extracted conv states [max_intermediate_count, *conv_shape].

  • conv_gather_offsets – Constant tensor [-d_conv, …, -1] for gathering conv states.

  • cu_chunk_seqlens – Precomputed chunk boundaries from MambaMetadata.

  • last_chunk_indices – Precomputed last chunk index per sequence.

  • seq_idx_for_varlen – Precomputed request ID per chunk.

  • cu_seqlens_list – Python list of cumulative sequence lengths (avoids .item()).

  • real_token_count – Number of real (non-padding) tokens.

  • conv_seq_idx – Precomputed per-token request ID for Triton conv1d.

  • conv_seq_start – Precomputed per-token request start for Triton conv1d.

Returns:

Output tensor of shape (l, b, d). Intermediate states (if any) are written directly to intermediate_ssm_out and intermediate_conv_out.
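The precomputed chunk metadata (cu_chunk_seqlens, last_chunk_indices) can be illustrated like this, assuming the natural layout (each packed sequence is cut into pieces of at most chunk_size tokens, and last_chunk_indices records each sequence's final chunk); this is a sketch, not the MambaMetadata implementation:

```python
# Illustrative chunk-boundary computation for packed sequences.

def chunk_metadata(cu_seqlens, chunk_size):
    """Split each packed sequence into chunks of at most chunk_size tokens."""
    cu_chunk_seqlens = [0]
    last_chunk_indices = []
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        pos = start
        while pos < end:
            pos = min(pos + chunk_size, end)
            cu_chunk_seqlens.append(pos)
        # index of the last chunk belonging to this sequence
        last_chunk_indices.append(len(cu_chunk_seqlens) - 2)
    return cu_chunk_seqlens, last_chunk_indices

# Two requests of lengths 5 and 3, chunk_size 4:
cu, last = chunk_metadata([0, 5, 8], 4)
assert cu == [0, 4, 5, 8]    # chunks: [0,4), [4,5), [5,8)
assert last == [1, 2]        # request 0 ends in chunk 1, request 1 in chunk 2
```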

_ssm_decode(
zxBCdt: torch.Tensor,
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
batch_indices: Optional[torch.Tensor] = None,
intermediate_conv_state: Optional[torch.Tensor] = None,
intermediate_ssm_state: Optional[torch.Tensor] = None,
) torch.Tensor#

Performs SSM computation for inference decode step.

Parameters:
  • zxBCdt – The input tensor of shape (b, s, d), which is a concatenation of z, x, B, C, and dt projections. s is the sequence length (1 + num_speculative_tokens).

  • conv_state – The convolution state tensor for inference.

  • ssm_state – The selective scan state tensor for inference.

  • batch_indices – A map from batch id to position in the Mamba state tensors.

  • intermediate_conv_state – Optional buffer for storing conv state at each sequence step (for speculative decoding rollback).

  • intermediate_ssm_state – Optional buffer for storing SSM state at each sequence step (for speculative decoding rollback).

Returns:

The output tensor of shape (b, s, d).

mamba_state_shapes_per_request() Tuple[Tuple[int], Tuple[int]]#

Returns the per-request shapes of the Mamba convolution and SSM state tensors.
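Following standard Mamba-2 conventions (an assumption; the exact shapes returned here are not confirmed from this module's source), the per-request states would be sized like this, with d_inner, nheads, and conv_dim as illustrative derived names:

```python
# Assumed per-request state shapes under Mamba-2 conventions.
d_model, expand, headdim, ngroups, d_state, d_conv = 1024, 2, 64, 8, 128, 4

d_inner = expand * d_model                   # width of the SSM branch
nheads = d_inner // headdim                  # number of SSM heads
conv_dim = d_inner + 2 * ngroups * d_state   # conv channels (x, B, C)

conv_state_shape = (conv_dim, d_conv)        # rolling conv window per channel
ssm_state_shape = (nheads, headdim, d_state) # recurrent SSM state per head

assert conv_state_shape == (4096, 4)
assert ssm_state_shape == (32, 64, 128)
```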

_get_states_from_cache(
inference_context,
batch_size,
*,
inference_params=None,
)#

Initializes or retrieves the SSM state tensors from the cache.

At the start of any inference (at the prefill step), if there is no cache or if the cached batch size has changed, then new tensors are initialized and stored in the cache. Otherwise the existing tensors are retrieved from the cache and zeroed out.
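The cache policy described above can be sketched as follows (illustrative pure Python, not the actual Megatron implementation): allocate on first use or on a batch-size change, otherwise zero and reuse the existing buffers.

```python
# Illustrative state-cache policy: reallocate only when needed, else reuse.

def get_states_from_cache(cache, layer_number, batch_size, state_size):
    entry = cache.get(layer_number)
    if entry is None or len(entry) != batch_size:
        # no cache yet, or the batch size changed: allocate fresh buffers
        entry = [[0.0] * state_size for _ in range(batch_size)]
        cache[layer_number] = entry
    else:
        # reuse the existing buffers, zeroing out stale state
        for row in entry:
            for i in range(len(row)):
                row[i] = 0.0
    return entry

cache = {}
a = get_states_from_cache(cache, layer_number=3, batch_size=2, state_size=4)
a[0][0] = 7.0
b = get_states_from_cache(cache, 3, 2, 4)
assert b is a            # same buffers reused
assert b[0][0] == 0.0    # but zeroed for the new prefill
c = get_states_from_cache(cache, 3, 4, 4)
assert c is not a        # batch size changed -> reallocated
```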

sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#

Provide a sharded state dictionary for distributed checkpointing.

core.ssm.mamba_mixer._split_tensor_factory(
orig_sh_ten: megatron.core.dist_checkpointing.ShardedTensor,
split_sections: List[int],
split_names: List[str],
split_dim: int,
) megatron.core.dist_checkpointing.mapping.ShardedTensorFactory#

Builds a factory that splits a given ShardedTensor into several independent chunks.
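The split described by split_sections/split_names can be illustrated with a hypothetical helper (this is the slicing arithmetic only, not the real `ShardedTensorFactory`; the part names below are illustrative):

```python
# Hypothetical named split of one flat tensor along a dimension (here dim 0).

def split_named(flat, split_sections, split_names):
    """Cut `flat` into consecutive chunks of the given sizes, keyed by name."""
    assert sum(split_sections) == len(flat)
    out, offset = {}, 0
    for name, size in zip(split_names, split_sections):
        out[name] = flat[offset:offset + size]
        offset += size
    return out

# e.g. a fused projection output holding several parts side by side:
flat = list(range(10))
parts = split_named(flat, [4, 3, 2, 1], ["z", "x", "B", "dt"])
assert parts["z"] == [0, 1, 2, 3]
assert parts["dt"] == [9]
```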

core.ssm.mamba_mixer._check_mamba_sequence_packing_support(
for_inference_not_training: bool = True,
) Tuple[bool, Optional[str]]#

Checks whether causal_conv1d and mamba_ssm support sequence packing.