core.ssm.mamba_mixer#
Module Contents#
Classes#
| ExtendedRMSNorm | RMSNormGated with sharded state dict. |
| MambaMixerSubmodules | Contains the module specs for the input and output linear layers. |
| MambaMixer | |
Functions#
| _split_tensor_factory | Builds a factory that splits a given ShardedTensor into several independent chunks. |
Data#
API#
- core.ssm.mamba_mixer.logger#
'getLogger(...)'
- class core.ssm.mamba_mixer.ExtendedRMSNorm(/, *args, **kw)#
Bases: mamba_ssm.ops.triton.layernorm_gated.RMSNorm
RMSNormGated with sharded state dict.
Initialization
- sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#
Sharding along axis 0, bias not sharded
- class core.ssm.mamba_mixer.MambaMixerSubmodules#
Contains the module specs for the input and output linear layers.
- in_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- out_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- class core.ssm.mamba_mixer.MambaMixer(
- config: megatron.core.transformer.TransformerConfig,
- submodules: core.ssm.mamba_mixer.MambaMixerSubmodules,
- d_model,
- d_conv=4,
- conv_init=None,
- expand=2,
- A_init_range=(1, 16),
- D_has_hdim=False,
- rmsnorm=True,
- norm_before_gate=False,
- dt_min=0.001,
- dt_max=0.1,
- dt_init='random',
- dt_scale=1.0,
- dt_init_floor=0.0001,
- bias=False,
- conv_bias=True,
- chunk_size=128,
- layer_number=None,
- use_mem_eff_path=None,
- d_state=None,
- headdim=None,
- ngroups=None,
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
- pp_layer_offset: int = 0,
Bases: megatron.core.transformer.module.MegatronModule
- Parameters:
config – The config of the model.
submodules – Contains the module specs for the input and output linear layers.
d_model – The hidden size of the model.
d_state – The state size of the SSM.
d_conv – The kernel width of the causal convolution.
conv_init – The initialization range for the causal convolution weights.
expand – The expansion factor for the SSM.
headdim – The hidden size of each SSM head.
ngroups – The number of groups for the B and C projections, shared across heads (analogous to grouped-query attention).
A_init_range – The initialization range for the SSM A parameter.
D_has_hdim – Whether the D parameter has the same number of dimensions as the hidden state.
rmsnorm – Whether to use root mean square normalization.
norm_before_gate – Whether to apply normalization before the gating mechanism.
dt_min – The minimum value of the dt parameter.
dt_max – The maximum value of the dt parameter.
dt_init – The initialization method for the dt parameter.
dt_scale – The scaling factor for the dt parameter.
dt_init_floor – The minimum value of the dt parameter after initialization.
bias – Whether to use bias in the linear layers.
conv_bias – Whether to use bias in the causal convolution.
chunk_size – The chunk size for the fused kernel.
use_mem_eff_path – Whether to use the memory-efficient path for the Mamba model.
layer_number – The layer number of this Mamba layer.
pg_collection – The required process groups to use for tensor model parallel and context parallel.
pp_layer_offset – The pipeline-parallel offset applied to this layer's number.
Initialization
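A minimal construction sketch: the projection classes and `TransformerConfig` values below are illustrative assumptions (not prescribed settings), and tensor-model-parallel state is assumed to already be initialized.

```python
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.spec_utils import ModuleSpec

from core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules

# Illustrative config values; real models set many more fields.
config = TransformerConfig(num_layers=1, hidden_size=1024, num_attention_heads=16)

# The fused input projection and the output projection are supplied as specs.
submodules = MambaMixerSubmodules(
    in_proj=ModuleSpec(module=ColumnParallelLinear),
    out_proj=ModuleSpec(module=RowParallelLinear),
)

mixer = MambaMixer(config=config, submodules=submodules, d_model=config.hidden_size)
```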
- forward(
- hidden_states,
- inference_context=None,
- *,
- inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
hidden_states: shape (L, B, D), i.e. sequence-first. Returns: a tensor of the same shape as hidden_states.
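For example, a training-mode call with sequence-first layout (shapes, device, and dtype here are illustrative; `mixer` is a constructed MambaMixer as sketched above):

```python
import torch

L, B, D = 128, 2, 1024  # sequence length, batch, hidden size (illustrative)
hidden_states = torch.randn(L, B, D, device="cuda", dtype=torch.bfloat16)

out = mixer(hidden_states)  # inference_context=None selects the training path
# `out` has the same (L, B, D) shape as hidden_states; depending on the
# out_proj spec, the projection bias may also be returned alongside it.
```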
- dynamic_inference(
- hidden_states: torch.Tensor,
- context: megatron.core.inference.contexts.DynamicInferenceContext,
Executes dynamic inference by separating decode and prefill requests and running them independently. Also runs the chunked prefill request independently if it exists.
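Conceptually, the separation amounts to grouping requests by how many tokens they contribute this step. A sketch of that idea with plain tensors (names are illustrative and do not reflect the actual DynamicInferenceContext API):

```python
import torch

# Per-request token counts for this step (illustrative values):
# decode requests contribute exactly one token, prefill requests more.
tokens_this_step = torch.tensor([1, 1, 17, 1, 42])
is_decode = tokens_this_step == 1

decode_ids = torch.nonzero(is_decode, as_tuple=True)[0]    # run as one group
prefill_ids = torch.nonzero(~is_decode, as_tuple=True)[0]  # run as another
```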
- decode(
- hidden_states,
- conv_state,
- ssm_state,
- batch_indices: Optional[torch.Tensor] = None,
Performs inference step for decoding.
- ssm_training(zxBCdt: torch.Tensor) → torch.Tensor#
Performs SSM computation for the training step.
Uses the memory-efficient kernel mamba_split_conv1d_scan_combined, which reduces the size of the forward activations stored for backprop and therefore reduces memory pressure during training.
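The `zxBCdt` name reflects the concatenated projection layout fed to this kernel. A hedged sketch of that layout, assuming Mamba2-style sizing (all values illustrative; tensor-parallel partitioning omitted):

```python
import torch

d_model, expand, d_state, ngroups, headdim = 1024, 2, 128, 8, 64
d_inner = expand * d_model
nheads = d_inner // headdim

# in_proj output concatenates z, x, B, C, and dt along the channel dim.
d_in_proj = 2 * d_inner + 2 * ngroups * d_state + nheads
zxBCdt = torch.randn(16, 1, d_in_proj)  # (l, b, d)

z, x, B, C, dt = torch.split(
    zxBCdt,
    [d_inner, d_inner, ngroups * d_state, ngroups * d_state, nheads],
    dim=-1,
)
```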
- ssm_prefill(
- zxBCdt: torch.Tensor,
- conv_state: Optional[torch.Tensor],
- ssm_state: Optional[torch.Tensor],
- seq_idx: Optional[torch.Tensor] = None,
- cu_seqlens: Optional[torch.Tensor] = None,
- return_varlen_states: bool = False,
- batch_indices: Optional[torch.Tensor] = None,
- is_chunked_prefill: bool = False,
Performs SSM computation for inference prefill step.
- Parameters:
zxBCdt – The input tensor of shape (l, b, d), which is a concatenation of z, x, B, C, and dt projections.
conv_state – The convolution state tensor for inference.
ssm_state – The selective scan state tensor for inference.
seq_idx – A map from token index to request index for variable-length sequences.
cu_seqlens – Cumulative sequence lengths for variable-length sequences.
return_varlen_states – Whether to return variable-length states from the SSM kernel.
batch_indices – A map from batch id to position in the Mamba state tensors for dynamic inference.
is_chunked_prefill – Whether the request is a chunked prefill request.
- Returns:
The output tensor of shape (l, b, d).
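For variable-length prefill, `seq_idx` and `cu_seqlens` describe how tokens from several requests are packed into one sequence. A small illustration (values are made up):

```python
import torch

# Three packed requests of lengths 3, 1, and 2 (illustrative values).
lengths = torch.tensor([3, 1, 2], dtype=torch.int32)

# cu_seqlens: cumulative lengths with a leading zero -> [0, 3, 4, 6]
cu_seqlens = torch.nn.functional.pad(lengths.cumsum(0), (1, 0)).int()

# seq_idx: request index for every packed token -> [[0, 0, 0, 1, 2, 2]]
seq_idx = torch.repeat_interleave(
    torch.arange(len(lengths)), lengths.long()
).int().unsqueeze(0)
```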
- ssm_decode(
- zxBCdt: torch.Tensor,
- conv_state: torch.Tensor,
- ssm_state: torch.Tensor,
- batch_indices: Optional[torch.Tensor] = None,
Performs SSM computation for inference decode step.
- Parameters:
zxBCdt – The input tensor of shape (l, b, d), which is a concatenation of z, x, B, C, and dt projections. For decoding, l must be 1 (one new token per request).
conv_state – The convolution state tensor for inference.
ssm_state – The selective scan state tensor for inference.
batch_indices – A map from batch id to position in the Mamba state tensors for dynamic inference.
- Returns:
The output tensor of shape (l, b, d).
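Conceptually, the decode step advances the SSM recurrence by one token. A plain-PyTorch sketch of that single-step update (the real code uses fused kernels such as selective_state_update; shapes and the discretization follow Mamba2 conventions, with B and C shared across heads as for ngroups=1):

```python
import torch

nheads, headdim, d_state = 8, 64, 128  # illustrative sizes
ssm_state = torch.zeros(nheads, headdim, d_state)
x  = torch.randn(nheads, headdim)  # conv output for this token
dt = torch.rand(nheads)            # per-head timestep (post-softplus)
A  = -torch.rand(nheads)           # per-head decay rate (negative)
B  = torch.randn(d_state)
C  = torch.randn(d_state)
D  = torch.randn(nheads)

dA = torch.exp(dt * A)                                # (nheads,)
ssm_state = ssm_state * dA[:, None, None] \
    + (dt[:, None] * x)[..., None] * B                # h <- dA*h + (dt*x) B^T
y = (ssm_state @ C) + D[:, None] * x                  # y = C.h + D*x
```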
- _get_varlen_generation_state(
- inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
Constructs the variable length generation state for non-decode dynamic inference.
The returned state includes the following:
- seq_idx (Tensor): A map from token idx to request idx.
- cu_seqlens (Tensor): The cumulative sequence lengths.
- return_varlen_states (bool): Whether to return a varlen states tensor for mamba_chunk_scan_combined.
Returns an empty state for training, static inference, or decode-only dynamic inference.
- Parameters:
inference_context (InferenceContext) – The inference context.
- Returns:
A tuple of (seq_idx, cu_seqlens, return_varlen_states).
- mamba_state_shapes_per_request() → Tuple[Tuple[int], Tuple[int]]#
Returns the Mamba conv and SSM state shapes per request.
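Assuming Mamba2-style sizing, the per-request shapes would typically be the conv window over the xBC channels and the per-head SSM state. A hedged sketch (values illustrative; tensor-parallel partitioning omitted):

```python
# Illustrative hyperparameters.
d_model, expand, headdim, ngroups, d_state, d_conv = 1024, 2, 64, 8, 128, 4
d_inner = expand * d_model
nheads = d_inner // headdim

conv_state_shape = (d_inner + 2 * ngroups * d_state, d_conv)  # xBC conv window
ssm_state_shape = (nheads, headdim, d_state)                  # per-head SSM state
```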
- _get_states_from_cache(
- inference_context,
- batch_size,
- *,
- inference_params=None,
Initializes or retrieves the SSM state tensors from the cache.
At the start of any inference (at the prefill step), if there is no cache or if the cached batch size has changed, then new tensors are initialized and stored in the cache. Otherwise the existing tensors are retrieved from the cache and zeroed out.
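The pattern the text describes amounts to the following allocate-or-reuse logic (an illustrative sketch, not the actual implementation):

```python
import torch

def get_states(cache: dict, layer_number: int, batch_size: int,
               conv_shape, ssm_shape, device, dtype):
    """Allocate new state tensors, or reuse and zero the cached ones."""
    entry = cache.get(layer_number)
    if entry is None or entry[0].shape[0] != batch_size:
        # No cache yet, or the batch size changed: allocate fresh tensors.
        conv_state = torch.zeros(batch_size, *conv_shape, device=device, dtype=dtype)
        ssm_state = torch.zeros(batch_size, *ssm_shape, device=device, dtype=dtype)
        cache[layer_number] = (conv_state, ssm_state)
    else:
        # Reuse the cached tensors, but reset them for the new prefill.
        conv_state, ssm_state = entry
        conv_state.zero_()
        ssm_state.zero_()
    return cache[layer_number]
```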
- sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#
Provide a sharded state dictionary for distributed checkpointing.
- core.ssm.mamba_mixer._split_tensor_factory(
- orig_sh_ten: megatron.core.dist_checkpointing.ShardedTensor,
- split_sections: List[int],
- split_names: List[str],
- split_dim: int,
Builds a factory that splits a given ShardedTensor into several independent chunks.
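The splitting itself follows the standard torch.split pattern. A simplified sketch of building such a factory, with plain tensors standing in for ShardedTensor (so this is conceptual only, not the dist-checkpointing implementation):

```python
import torch
from typing import Callable, Dict, List

def make_split_factory(
    split_sections: List[int], split_names: List[str], split_dim: int
) -> Callable[[torch.Tensor], Dict[str, torch.Tensor]]:
    """Return a function that splits a tensor into named, independent chunks."""
    def split_fn(tensor: torch.Tensor) -> Dict[str, torch.Tensor]:
        chunks = torch.split(tensor, split_sections, dim=split_dim)
        return dict(zip(split_names, chunks))
    return split_fn

# e.g. splitting a fused projection weight back into z / xBC / dt chunks
factory = make_split_factory([2048, 4096, 32], ["z", "xBC", "dt"], split_dim=0)
parts = factory(torch.randn(2048 + 4096 + 32, 1024))
```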