core.ssm.mamba_layer#
Module Contents#
Classes#
- `MambaLayerSubmodules` – Configuration class for specifying the submodules of a Mamba layer.
- `MambaLayer` – A single Mamba layer.
API#
- class core.ssm.mamba_layer.MambaLayerSubmodules#
Configuration class for specifying the submodules of a Mamba layer.
This class defines the structure and default implementations for various components of a Mamba layer, allowing for flexible customization of the layer’s architecture.
- Parameters:
norm (Union[ModuleSpec, type]) – Specification for the input layer normalization.
mixer (Union[ModuleSpec, type]) – Specification for the along-sequence mixing mechanism.
mamba_bda (Union[ModuleSpec, type]) – Specification for the bias-dropout-add operation after the mixer.
- norm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- mixer: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- mamba_bda: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- sharded_state_dict_keys_map: Dict[str, str]#
'field(...)'
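The sketch below shows one illustrative way to fill these submodule slots through `ModuleSpec`. It is a minimal sketch, not the library's canonical spec: `torch.nn.Identity` stands in for all three slots purely to show the structure, and the `megatron.core.ssm.mamba_layer` import path is an assumption that may differ across Megatron-LM versions.

```python
# Minimal sketch: filling the MambaLayerSubmodules slots inside a ModuleSpec.
# torch.nn.Identity is a placeholder, NOT the library's default implementation.
import torch

from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules

mamba_layer_spec = ModuleSpec(
    module=MambaLayer,
    submodules=MambaLayerSubmodules(
        norm=torch.nn.Identity,       # input layer normalization
        mixer=torch.nn.Identity,      # along-sequence mixing (the Mamba mixer)
        mamba_bda=torch.nn.Identity,  # bias-dropout-add after the mixer
    ),
)
```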
- class core.ssm.mamba_layer.MambaLayer(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- submodules: core.ssm.mamba_layer.MambaLayerSubmodules,
- layer_number: int = 1,
- residual_in_fp32=False,
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
- pp_layer_offset: int = 0,
- )#
Bases:
megatron.core.transformer.module.GraphableMegatronModule

A single Mamba layer.
Mamba layer takes input with size [s, b, h] and returns an output of the same size.
Initialization
Initialize Mamba Layer.
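The construction sketch below is hedged: it assumes `mamba_layer_submodules` is a fully populated `MambaLayerSubmodules` spec (for example, taken from a real Mamba stack spec rather than the placeholder spec above) and that the Mamba and causal-conv1d kernels are installed; the `TransformerConfig` values are illustrative only.

```python
# Hedged sketch: constructing a MambaLayer directly from a config and a
# fully populated submodules spec (`mamba_layer_submodules` is assumed).
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.ssm.mamba_layer import MambaLayer

config = TransformerConfig(
    num_layers=2,                # illustrative hyperparameters only
    hidden_size=256,
    num_attention_heads=8,
    use_cpu_initialization=True,
)
layer = MambaLayer(
    config=config,
    submodules=mamba_layer_submodules,
    layer_number=1,
)
```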
- mamba_state_shapes_per_request() → Tuple[Tuple[int], Tuple[int]]#
Returns the Mamba conv and SSM state shapes per request.
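For example, with a layer constructed as in the sketch above, the two shapes can be queried to size per-request inference state buffers; the variable names here are illustrative.

```python
# Query the per-request conv and SSM state shapes of a constructed layer.
conv_state_shape, ssm_state_shape = layer.mamba_state_shapes_per_request()
print(conv_state_shape, ssm_state_shape)
```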
- forward(
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- rotary_pos_emb: Optional[torch.Tensor] = None,
- *,
- inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- )#
Perform a forward pass through the Mamba layer.
This method implements the core computation of a Mamba layer, including the convolution and the selective SSM/SSD.
- Parameters:
hidden_states (Tensor) – Input tensor of shape [s, b, h] where s is sequence length, b is batch size, and h is hidden size.
attention_mask (Tensor) – Mask tensor for self-attention. Not used by this layer.
inference_context (BaseInferenceContext, optional) – Parameters for inference-time optimizations.
rotary_pos_emb (Tensor, optional) – Rotary positional embeddings.
- Returns:
Transformed hidden states of shape [s, b, h].
- Return type:
output (Tensor)
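A hedged usage sketch for the `layer` constructed above follows; the shapes, device, and dtype are illustrative, and the underlying Mamba kernels expect GPU tensors.

```python
# Forward pass with an [s, b, h] input; the output has the same shape.
import torch

s, b, h = 128, 2, 256  # sequence length, batch size, hidden size (illustrative)
layer = layer.to(device="cuda", dtype=torch.bfloat16)
hidden_states = torch.randn(s, b, h, device="cuda", dtype=torch.bfloat16)
output = layer(hidden_states, attention_mask=None)
assert output.shape == (s, b, h)
```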
- sharded_state_dict(
- prefix: str = '',
- sharded_offsets: tuple = (),
- metadata: Optional[dict] = None,
- )#
Generate a sharded state dictionary for the Mamba layer.
- Parameters:
prefix (str, optional) – Prefix to be added to all keys in the state dict.
sharded_offsets (tuple, optional) – Tuple of sharding offsets.
metadata (Optional[dict], optional) – Additional metadata for sharding.
- Returns:
A dictionary containing the sharded state of the Mamba layer.
- Return type:
ShardedStateDict
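A hedged sketch of generating the layer's sharded state dict for distributed checkpointing; the `prefix` value mimics the usual decoder-layer naming but is purely illustrative.

```python
# Collect the sharded tensors/objects used by distributed checkpointing.
sharded_sd = layer.sharded_state_dict(prefix="decoder.layers.0.")
for key in sharded_sd:
    print(key)
```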
- _te_cuda_graph_replay(*args, **kwargs)#
CUDA graph replay for this layer and microbatch `self.current_microbatch`, using the TE interface. TransformerEngine versions >= 1.10 allow keyword arguments with CUDA graph. However, CUDA graph accepts only Tensor inputs, so `inference_context` is excluded from the input list.
- _should_call_local_cudagraph(*args, **kwargs)#
Check if we should call the local cudagraph path.