core.ssm.mamba_mixer#

Module Contents#

Classes#

ExtendedRMSNorm

RMSNormGated with sharded state dict.

MambaMixerSubmodules

Contains the module specs for the input and output linear layers.

MambaMixer

The Mamba mixer layer (see the full parameter documentation in the API section below).

Functions#

_split_tensor_factory

Builds a factory that splits a given ShardedTensor into several independent chunks.

Data#

API#

core.ssm.mamba_mixer.logger#

A module-level logger created via logging.getLogger.

class core.ssm.mamba_mixer.ExtendedRMSNorm(/, *args, **kw)#

Bases: mamba_ssm.ops.triton.layernorm_gated.RMSNorm

RMSNormGated with sharded state dict.

Initialization

sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#

Shards the weight along axis 0; the bias is not sharded.

class core.ssm.mamba_mixer.MambaMixerSubmodules#

Contains the module specs for the input and output linear layers.

in_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

out_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None
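For orientation, here is a minimal sketch of how these submodules can be wired into a ModuleSpec. The ColumnParallelLinear/RowParallelLinear choices mirror the column-parallel-in / row-parallel-out pattern used elsewhere in Megatron and are an assumption here, not necessarily the classes used by the official Mamba layer spec:

```python
from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.spec_utils import ModuleSpec

# Hypothetical spec: in_proj fans the hidden states out to the fused
# z/x/B/C/dt projection (column-parallel); out_proj maps the SSM output
# back to the model hidden size (row-parallel).
mamba_mixer_spec = ModuleSpec(
    module=MambaMixer,
    submodules=MambaMixerSubmodules(
        in_proj=ColumnParallelLinear,
        out_proj=RowParallelLinear,
    ),
)
```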

class core.ssm.mamba_mixer.MambaMixer(
config: megatron.core.transformer.TransformerConfig,
submodules: core.ssm.mamba_mixer.MambaMixerSubmodules,
d_model,
d_conv=4,
conv_init=None,
expand=2,
A_init_range=(1, 16),
D_has_hdim=False,
rmsnorm=True,
norm_before_gate=False,
dt_min=0.001,
dt_max=0.1,
dt_init='random',
dt_scale=1.0,
dt_init_floor=0.0001,
bias=False,
conv_bias=True,
chunk_size=128,
layer_number=None,
use_mem_eff_path=None,
d_state=None,
headdim=None,
ngroups=None,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
pp_layer_offset: int = 0,
)#

Bases: megatron.core.transformer.module.MegatronModule

Parameters:
  • config – The config of the model.

  • submodules – Contains the module specs for the input and output linear layers.

  • d_model – The hidden size of the model.

  • d_state – The state size of the SSM.

  • d_conv – The kernel width of the causal convolution.

  • conv_init – The initialization range for the causal convolution weights.

  • expand – The expansion factor for the SSM.

  • headdim – The hidden size of each Mamba (SSM) head.

  • ngroups – The number of groups for the SSM's B and C parameters (analogous to grouped-query attention groups).

  • A_init_range – The initialization range for the SSM's A (state-transition) parameter.

  • D_has_hdim – Whether the D parameter has the same number of dimensions as the hidden state.

  • rmsnorm – Whether to use root mean square normalization.

  • norm_before_gate – Whether to apply normalization before the gating mechanism.

  • dt_min – The minimum value of the dt parameter.

  • dt_max – The maximum value of the dt parameter.

  • dt_init – The initialization method for the dt parameter (e.g. 'random').

  • dt_scale – The scaling factor for the dt parameter.

  • dt_init_floor – The minimum value of the dt parameter after initialization.

  • bias – Whether to use bias in the linear layers.

  • conv_bias – Whether to use bias in the causal convolution.

  • chunk_size – The chunk size for the fused kernel.

  • use_mem_eff_path – Whether to use the memory-efficient path for the Mamba model.

  • layer_number – The layer number of this Mamba layer.

  • pg_collection – The required process groups to use for tensor model parallel and context parallel.

Initialization
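A minimal construction sketch, assuming Megatron's model-parallel state has already been initialized and reusing the hypothetical mamba_mixer_spec from above; all hyperparameter values are illustrative only:

```python
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.spec_utils import build_module

config = TransformerConfig(
    num_layers=1,
    hidden_size=4096,
    num_attention_heads=32,  # required by TransformerConfig even for SSM layers
)

# build_module instantiates MambaMixer with the submodules from the spec.
mixer = build_module(
    mamba_mixer_spec,
    config=config,
    d_model=config.hidden_size,
    d_state=128,      # SSM state size (illustrative)
    headdim=64,       # per-head hidden size (illustrative)
    ngroups=8,        # B/C parameter groups (illustrative)
    layer_number=1,
)
```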

forward(
hidden_states,
inference_context=None,
*,
inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
)#

hidden_states – a tensor of shape (L, B, D), i.e. (sequence length, batch size, hidden dimension).

Returns: a tensor of the same shape as hidden_states.
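Continuing the sketch above, a shape-level usage example. Because out_proj follows Megatron's parallel-linear convention, the call may return the output together with a separate bias tensor; that detail is inferred, not stated in the docstring:

```python
import torch

seq_len, batch, hidden = 512, 2, 4096
hidden_states = torch.randn(seq_len, batch, hidden,
                            device='cuda', dtype=torch.bfloat16)

# Training path: no inference context is passed.
out = mixer(hidden_states)
# If `out` is an (output, bias) pair, `output` has the same (L, B, D)
# shape as hidden_states.
```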

dynamic_inference(
hidden_states: torch.Tensor,
context: megatron.core.inference.contexts.DynamicInferenceContext,
)#

Executes dynamic inference by separating decode and prefill requests and running them independently. Also runs the chunked prefill request independently if it exists.

decode(
hidden_states,
conv_state,
ssm_state,
batch_indices: Optional[torch.Tensor] = None,
) → Tuple[torch.Tensor, torch.Tensor]#

Performs inference step for decoding.

ssm_training(zxBCdt: torch.Tensor) → torch.Tensor#

Performs SSM computation for training step.

Uses the memory-efficient kernel mamba_split_conv1d_scan_combined which reduces the size of forward activations stored for backprop and therefore reduces memory pressure during training.

ssm_prefill(
zxBCdt: torch.Tensor,
conv_state: Optional[torch.Tensor],
ssm_state: Optional[torch.Tensor],
seq_idx: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
return_varlen_states: bool = False,
batch_indices: Optional[torch.Tensor] = None,
is_chunked_prefill: bool = False,
) → torch.Tensor#

Performs SSM computation for inference prefill step.

Parameters:
  • zxBCdt – The input tensor of shape (l, b, d), which is a concatenation of z, x, B, C, and dt projections.

  • conv_state – The convolution state tensor for inference.

  • ssm_state – The selective scan state tensor for inference.

  • seq_idx – A map from token index to request index for variable-length sequences.

  • cu_seqlens – Cumulative sequence lengths for variable-length sequences.

  • return_varlen_states – Whether to return variable-length states from the SSM kernel.

  • batch_indices – A map from batch id to position in the Mamba state tensors for dynamic inference.

  • is_chunked_prefill – Whether the request is a chunked prefill request.

Returns:

The output tensor of shape (l, b, d).
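To make the variable-length bookkeeping concrete, here is how seq_idx and cu_seqlens relate for a packed batch of two requests of lengths 3 and 5 (a generic illustration, not code taken from this module):

```python
import torch

lengths = torch.tensor([3, 5])

# Cumulative sequence lengths: one boundary per request, starting at 0.
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                        lengths.cumsum(0).to(torch.int32)])
# -> tensor([0, 3, 8], dtype=torch.int32)

# seq_idx maps each of the 8 packed token positions to its request index.
seq_idx = torch.repeat_interleave(
    torch.arange(len(lengths), dtype=torch.int32), lengths)
# -> tensor([0, 0, 0, 1, 1, 1, 1, 1], dtype=torch.int32)
```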

ssm_decode(
zxBCdt: torch.Tensor,
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
batch_indices: Optional[torch.Tensor] = None,
) → torch.Tensor#

Performs SSM computation for inference decode step.

Parameters:
  • zxBCdt – The input tensor of shape (l, b, d), which is a concatenation of z, x, B, C, and dt projections. For decoding, l must be 1.

  • conv_state – The convolution state tensor for inference.

  • ssm_state – The selective scan state tensor for inference.

  • batch_indices – A map from batch id to position in the Mamba state tensors for dynamic inference.

Returns:

The output tensor of shape (l, b, d).

_get_varlen_generation_state(
inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
) → Tuple[torch.Tensor, torch.Tensor, bool]#

Constructs the variable length generation state for non-decode dynamic inference.

The returned state includes the following:

  • seq_idx (Tensor) – A map from token index to request index.

  • cu_seqlens (Tensor) – The cumulative sequence lengths.

  • return_varlen_states (bool) – Whether to return a varlen states tensor for mamba_chunk_scan_combined.

Returns empty state for training, static inference, or decode-only dynamic inference.

Parameters:

inference_context (InferenceContext) – The inference context.

Returns:

A tuple of (seq_idx, cu_seqlens, return_varlen_states).

mamba_state_shapes_per_request() → Tuple[Tuple[int], Tuple[int]]#

Returns the per-request shapes of the Mamba convolution and SSM state tensors.
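As a back-of-the-envelope illustration of what these shapes typically look like under standard Mamba2 conventions (the exact layout used by this implementation may differ):

```python
# Illustrative hyperparameters (assumptions, not defaults of this class).
d_model, expand, d_state, ngroups, headdim, d_conv = 4096, 2, 128, 8, 64, 4

d_inner = expand * d_model                   # 8192
nheads = d_inner // headdim                  # 128
conv_dim = d_inner + 2 * ngroups * d_state   # x plus B and C channels: 10240

conv_state_shape = (conv_dim, d_conv)         # (10240, 4)
ssm_state_shape = (nheads, headdim, d_state)  # (128, 64, 128)
```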

_get_states_from_cache(
inference_context,
batch_size,
*,
inference_params=None,
)#

Initializes or retrieves the SSM state tensors from the cache.

At the start of any inference (at the prefill step), if there is no cache or if the cached batch size has changed, then new tensors are initialized and stored in the cache. Otherwise the existing tensors are retrieved from the cache and zeroed out.

sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#

Provide a sharded state dictionary for distributed checkpointing.
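A hedged usage sketch for distributed checkpointing; dist_checkpointing.save is Megatron's standard entry point, and the prefix and directory shown are illustrative:

```python
from megatron.core import dist_checkpointing

# Collect the mixer's sharded tensors under an illustrative prefix.
sharded_sd = mixer.sharded_state_dict(prefix='decoder.layers.0.mixer.')

# Save collectively across ranks (assumes torch.distributed is initialized).
dist_checkpointing.save(sharded_sd, checkpoint_dir='/tmp/mamba_ckpt')
```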

core.ssm.mamba_mixer._split_tensor_factory(
orig_sh_ten: megatron.core.dist_checkpointing.ShardedTensor,
split_sections: List[int],
split_names: List[str],
split_dim: int,
) → megatron.core.dist_checkpointing.mapping.ShardedTensorFactory#

Builds a factory that splits a given ShardedTensor into several independent chunks.
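A hypothetical invocation, splitting a fused weight into named chunks along its first dimension; the tensor, section sizes, and chunk names are illustrative, not the ones used internally:

```python
import torch
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.ssm.mamba_mixer import _split_tensor_factory

# A fused weight whose first dimension concatenates three (hypothetical)
# chunks of sizes 8, 8, and 4.
fused = torch.randn(20, 16)
sh_ten = ShardedTensor.from_rank_offsets('layer.in_proj.weight', fused)

factory = _split_tensor_factory(
    orig_sh_ten=sh_ten,
    split_sections=[8, 8, 4],        # chunk sizes along split_dim; must sum to 20
    split_names=['z', 'xBC', 'dt'],  # hypothetical chunk names
    split_dim=0,
)
```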