core.ssm.mamba_mixer#
Module Contents#
Classes#
- ExtendedRMSNorm – RMSNormGated with sharded state dict.
- MambaMixerSubmodules – Contains the module specs for the input and output linear layers.
- MambaMixer

Functions#

- _split_tensor_factory – Builds a factory that splits a given ShardedTensor into several independent chunks.
- _check_mamba_sequence_packing_support – Checks whether causal_conv1d and mamba_ssm support sequence packing.
Data#
API#
- core.ssm.mamba_mixer.logger#
'getLogger(…)'
- class core.ssm.mamba_mixer.ExtendedRMSNorm(/, *args, **kw)#
Bases: mamba_ssm.ops.triton.layernorm_gated.RMSNorm

RMSNormGated with sharded state dict.
Initialization
- sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#
Shards the weight along axis 0; the bias is not sharded.
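For reference, the gated RMSNorm semantics this class inherits can be sketched in plain Python (a simplified sketch of the reference math behind the fused Triton kernel; `gated_rmsnorm` is a hypothetical helper, not part of the API):

```python
import math

def silu(v):
    # SiLU (swish) gate: v * sigmoid(v)
    return v / (1.0 + math.exp(-v))

def gated_rmsnorm(x, z, weight, eps=1e-5, norm_before_gate=False):
    """Reference semantics of a gated RMSNorm (sketch only).

    norm_before_gate=False: y = rmsnorm(x * silu(z)) * weight
    norm_before_gate=True:  y = rmsnorm(x) * weight * silu(z)
    """
    if not norm_before_gate:
        x = [xi * silu(zi) for xi, zi in zip(x, z)]
    rms = math.sqrt(sum(xi * xi for xi in x) / len(x) + eps)
    y = [w * xi / rms for w, xi in zip(weight, x)]
    if norm_before_gate:
        y = [yi * silu(zi) for yi, zi in zip(y, z)]
    return y

out = gated_rmsnorm([1.0, -2.0, 3.0], [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
```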
- class core.ssm.mamba_mixer.MambaMixerSubmodules#
Contains the module specs for the input and output linear layers.
- class core.ssm.mamba_mixer.MambaMixer(
- config: megatron.core.transformer.TransformerConfig,
- submodules: core.ssm.mamba_mixer.MambaMixerSubmodules,
- d_model,
- d_conv=4,
- conv_init=None,
- expand=2,
- A_init_range=(1, 16),
- D_has_hdim=False,
- rmsnorm=True,
- norm_before_gate=False,
- dt_min=0.001,
- dt_max=0.1,
- dt_init='random',
- dt_scale=1.0,
- dt_init_floor=0.0001,
- bias=False,
- conv_bias=True,
- chunk_size=128,
- layer_number=None,
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
- pp_layer_offset: int = 0,
Bases:
megatron.core.transformer.module.MegatronModule

- Parameters:
config – The config of the model.
submodules – Contains the module specs for the input and output linear layers.
d_model – The hidden size of the model.
d_state – The state size of the SSM.
d_conv – The kernel width of the causal convolution.
conv_init – The initialization range for the causal convolution weights.
expand – The expansion factor for the SSM.
headdim – The hidden size of each Mamba head.
ngroups – The number of groups for the B and C projections.
A_init_range – The initialization range for the A state-transition parameter.
D_has_hdim – Whether the D parameter has the same number of dimensions as the hidden state.
rmsnorm – Whether to use root mean square normalization.
norm_before_gate – Whether to apply normalization before the gating mechanism.
dt_min – The minimum value of the dt parameter.
dt_max – The maximum value of the dt parameter.
dt_init – The initialization value of the dt parameter.
dt_scale – The scaling factor for the dt parameter.
dt_init_floor – The minimum value of the dt parameter after initialization.
bias – Whether to use bias in the linear layers.
conv_bias – Whether to use bias in the causal convolution.
chunk_size – The chunk size for the Mamba SSM fused kernel.
use_mem_eff_path – Whether to use the memory-efficient path for the Mamba model.
layer_number – The layer number of this Mamba layer.
pg_collection – The required process groups to use for tensor model parallel and context parallel.
Initialization
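The derived sizes used throughout the mixer follow the standard Mamba2 parameterization. The sketch below (a hypothetical helper for illustration; the real layer reads these values from its TransformerConfig) shows how the inner width, head count, and projection widths relate, assuming the usual [z, x, B, C, dt] layout of the input projection:

```python
def mamba2_sizes(d_model, d_state, expand=2, headdim=64, ngroups=1):
    """Derived sizes under the standard Mamba2 parameterization (sketch only)."""
    d_inner = expand * d_model            # width of the inner SSM stream
    nheads = d_inner // headdim           # number of SSM heads
    # in_proj emits the concatenation [z, x, B, C, dt]:
    d_in_proj = 2 * d_inner + 2 * ngroups * d_state + nheads
    # the causal convolution runs over [x, B, C] only:
    conv_channels = d_inner + 2 * ngroups * d_state
    return d_inner, nheads, d_in_proj, conv_channels

sizes = mamba2_sizes(d_model=4096, d_state=128, expand=2, headdim=64, ngroups=8)
```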
- forward(
- hidden_states,
- inference_context=None,
- *,
- inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
hidden_states – tensor of shape (L, B, D) (sequence length, batch, hidden dimension).
Returns: a tensor of the same shape as hidden_states.
- _dynamic_inference(
- hidden_states: torch.Tensor,
- context: megatron.core.inference.contexts.DynamicInferenceContext,
Executes dynamic inference by separating decode and prefill requests and running them independently.
- _dynamic_inference_prefill(
- zxBCdt: torch.Tensor,
- context: megatron.core.inference.contexts.DynamicInferenceContext,
- conv_state: torch.Tensor,
- ssm_state: torch.Tensor,
- mamba_layer_idx: Optional[int] = None,
Helper to run dynamic inference prefill.
All prefill requests (including chunked prefill) are processed together through the unified varlen path. Uses precomputed metadata from MambaMetadata.update() to avoid .item() calls and data-dependent control flow, enabling CUDA graph compatibility.
Intermediate state extraction (for Mamba prefix caching) is performed inside _ssm_prefill via pre-allocated output buffers, making it fully CUDA graph compatible.
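The varlen bookkeeping this path consumes can be illustrated with a small sketch (hypothetical helper; the real values are precomputed by MambaMetadata.update() to avoid data-dependent control flow):

```python
def varlen_metadata(seq_lengths):
    """Build varlen bookkeeping from per-request token counts (sketch only).

    cu_seqlens[i] is the cumulative token offset of request i in the packed
    token stream; seq_idx maps each packed token back to its request index.
    """
    cu_seqlens = [0]
    seq_idx = []
    for req, n in enumerate(seq_lengths):
        cu_seqlens.append(cu_seqlens[-1] + n)
        seq_idx.extend([req] * n)
    return cu_seqlens, seq_idx

cu, idx = varlen_metadata([3, 1, 2])
```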
- _decode(
- hidden_states,
- conv_state,
- ssm_state,
- batch_indices: Optional[torch.Tensor] = None,
Performs an inference step for decoding.
- _ssm_training(
- zxBCdt: torch.Tensor,
- packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
Performs the SSM computation for a training step.
Uses the memory-efficient kernel mamba_split_conv1d_scan_combined, which reduces the size of the forward activations stored for backpropagation and therefore reduces memory pressure during training.
- _ssm_prefill(
- zxBCdt: torch.Tensor,
- conv_state: Optional[torch.Tensor],
- ssm_state: Optional[torch.Tensor],
- seq_idx: Optional[torch.Tensor] = None,
- cu_seqlens: Optional[torch.Tensor] = None,
- batch_indices: Optional[torch.Tensor] = None,
- intermediate_chunk_indices: Optional[torch.Tensor] = None,
- intermediate_abs_positions: Optional[torch.Tensor] = None,
- intermediate_ssm_out: Optional[torch.Tensor] = None,
- intermediate_conv_out: Optional[torch.Tensor] = None,
- conv_gather_offsets: Optional[torch.Tensor] = None,
- cu_chunk_seqlens: Optional[torch.Tensor] = None,
- last_chunk_indices: Optional[torch.Tensor] = None,
- seq_idx_for_varlen: Optional[torch.Tensor] = None,
- cu_seqlens_list: Optional[List[int]] = None,
- real_token_count: Optional[int] = None,
- conv_seq_idx: Optional[torch.Tensor] = None,
- conv_seq_start: Optional[torch.Tensor] = None,
Performs the SSM computation for an inference prefill step.
- Parameters:
zxBCdt – The input tensor of shape (l, b, d), which is a concatenation of z, x, B, C, and dt projections.
conv_state – The convolution state tensor for inference.
ssm_state – The selective scan state tensor for inference.
seq_idx – A map from token index to request index for variable-length sequences.
cu_seqlens – Cumulative sequence lengths for variable-length sequences.
batch_indices – A map from batch id to position in the Mamba state tensors for dynamic inference.
intermediate_chunk_indices – Pre-allocated tensor of chunk indices for intermediate state extraction (fixed size, padded with 0).
intermediate_abs_positions – Pre-allocated tensor of absolute token positions for conv state extraction (fixed size, padded with d_conv).
intermediate_ssm_out – Output buffer for extracted SSM states [max_intermediate_count, *ssm_shape].
intermediate_conv_out – Output buffer for extracted conv states [max_intermediate_count, *conv_shape].
conv_gather_offsets – Constant tensor [-d_conv, …, -1] for gathering conv states.
cu_chunk_seqlens – Precomputed chunk boundaries from MambaMetadata.
last_chunk_indices – Precomputed last chunk index per sequence.
seq_idx_for_varlen – Precomputed request ID per chunk.
cu_seqlens_list – Python list of cumulative sequence lengths (avoids .item()).
real_token_count – Number of real (non-padding) tokens.
conv_seq_idx – Precomputed per-token request ID for Triton conv1d.
conv_seq_start – Precomputed per-token request start for Triton conv1d.
- Returns:
Output tensor of shape (l, b, d). Intermediate states (if any) are written directly to intermediate_ssm_out and intermediate_conv_out.
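The constant-offset gather described for conv_gather_offsets can be sketched as follows (hypothetical helper, for illustration only): the last d_conv tokens of a sequence are read with fixed relative offsets [-d_conv, …, -1] from the sequence end.

```python
def gather_conv_window(tokens, seq_end, d_conv=4):
    """Gather a conv state window with constant offsets relative to a
    sequence end position (sketch mirroring conv_gather_offsets)."""
    offsets = list(range(-d_conv, 0))   # [-4, -3, -2, -1] for d_conv=4
    return [tokens[seq_end + o] for o in offsets]

window = gather_conv_window(list(range(10)), seq_end=7, d_conv=4)
```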
- _ssm_decode(
- zxBCdt: torch.Tensor,
- conv_state: torch.Tensor,
- ssm_state: torch.Tensor,
- batch_indices: Optional[torch.Tensor] = None,
- intermediate_conv_state: Optional[torch.Tensor] = None,
- intermediate_ssm_state: Optional[torch.Tensor] = None,
Performs the SSM computation for an inference decode step.
- Parameters:
zxBCdt – The input tensor of shape (b, s, d), which is a concatenation of z, x, B, C, and dt projections. s is the sequence length (1 + num_speculative_tokens).
conv_state – The convolution state tensor for inference.
ssm_state – The selective scan state tensor for inference.
batch_indices – A map from batch id to position in the Mamba state tensors.
intermediate_conv_state – Optional buffer for storing conv state at each sequence step (for speculative decoding rollback).
intermediate_ssm_state – Optional buffer for storing SSM state at each sequence step (for speculative decoding rollback).
- Returns:
The output tensor of shape (b, s, d).
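The per-channel recurrence a decode step evaluates can be sketched in scalar form (a simplified sketch; the fused kernel computes this per head and state dimension, vectorized over the batch):

```python
import math

def ssm_decode_step(h, x, dt, A, B, C, D):
    """One step of the selective-state-space recurrence for a single channel:
        h' = exp(dt * A) * h + dt * B * x
        y  = C * h' + D * x
    (sketch only; the real kernel operates on full state tensors)."""
    h_new = math.exp(dt * A) * h + dt * B * x
    y = C * h_new + D * x
    return h_new, y

h, y = ssm_decode_step(h=0.0, x=1.0, dt=0.1, A=-1.0, B=1.0, C=1.0, D=1.0)
```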
- mamba_state_shapes_per_request() → Tuple[Tuple[int], Tuple[int]]#
Returns the per-request shapes of the Mamba conv and SSM states.
- _get_states_from_cache(
- inference_context,
- batch_size,
- *,
- inference_params=None,
Initializes or retrieves the SSM state tensors from the cache.
At the start of any inference (at the prefill step), if there is no cache or if the cached batch size has changed, then new tensors are initialized and stored in the cache. Otherwise the existing tensors are retrieved from the cache and zeroed out.
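The allocate-or-reuse-and-zero behavior described above can be sketched as follows (a hypothetical pure-Python cache, not the actual implementation):

```python
class StateCache:
    """Sketch of the cache behavior: allocate on first use or on a batch-size
    change, otherwise reuse the existing allocation and zero its contents."""
    def __init__(self):
        self._store = {}

    def get(self, layer, batch_size, state_size):
        cached = self._store.get(layer)
        if cached is None or len(cached) != batch_size:
            # (re)allocate: fresh zeroed state per request
            cached = [[0.0] * state_size for _ in range(batch_size)]
            self._store[layer] = cached
        else:
            # reuse the allocation but reset contents for the new prefill
            for row in cached:
                for i in range(len(row)):
                    row[i] = 0.0
        return cached

cache = StateCache()
s1 = cache.get(layer=0, batch_size=2, state_size=3)
s1[0][0] = 5.0
s2 = cache.get(layer=0, batch_size=2, state_size=3)
```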
- sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#
Provide a sharded state dictionary for distributed checkpointing.
- core.ssm.mamba_mixer._split_tensor_factory(
- orig_sh_ten: megatron.core.dist_checkpointing.ShardedTensor,
- split_sections: List[int],
- split_names: List[str],
- split_dim: int,
Builds a factory that splits a given ShardedTensor into several independent chunks.
- core.ssm.mamba_mixer._check_mamba_sequence_packing_support(
- for_inference_not_training: bool = True,
Checks whether causal_conv1d and mamba_ssm support sequence packing.
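One plausible shape for such a capability check is a minimum-version comparison (a hedged sketch only; the version thresholds below are illustrative assumptions, not the real requirements checked by _check_mamba_sequence_packing_support):

```python
def supports_sequence_packing(causal_conv1d_version, mamba_ssm_version):
    """Compare installed versions against assumed minimums (sketch only)."""
    def parse(v):
        # "1.5.0" -> (1, 5, 0); ignore any pre-release suffix segments
        return tuple(int(p) for p in v.split(".")[:3])
    MIN_CAUSAL_CONV1D = (1, 4, 0)   # assumed threshold, for illustration
    MIN_MAMBA_SSM = (2, 2, 0)       # assumed threshold, for illustration
    return (parse(causal_conv1d_version) >= MIN_CAUSAL_CONV1D
            and parse(mamba_ssm_version) >= MIN_MAMBA_SSM)

ok = supports_sequence_packing("1.5.0", "2.2.4")
```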