core.ssm.mamba_layer#

Module Contents#

Classes#

MambaLayerSubmodules

Configuration class for specifying the submodules of a Mamba layer.

MambaLayer

A single Mamba layer.

API#

class core.ssm.mamba_layer.MambaLayerSubmodules#

Configuration class for specifying the submodules of a Mamba layer.

This class defines the structure and default implementations for various components of a Mamba layer, allowing for flexible customization of the layer’s architecture.

Parameters:
  • norm (Union[ModuleSpec, type]) – Specification for the input layer normalization.

  • mixer (Union[ModuleSpec, type]) – Specification for the along-sequence mixing mechanism.

  • mamba_bda (Union[ModuleSpec, type]) – Specification for the bias-dropout-add operation after the mixer.

norm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

mixer: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

mamba_bda: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

sharded_state_dict_keys_map: Dict[str, str]#

'field(…)'
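
As an illustration of how these fields are populated, the sketch below builds a ModuleSpec for a MambaLayer from a MambaLayerSubmodules instance. The import paths and the particular choices of norm, mixer, and bias-dropout-add implementations are assumptions modelled on the stock Megatron-Core layer specs and may differ between releases; treat this as a sketch rather than the bundled spec.

```python
# Sketch only: import locations and submodule choices are assumptions and
# may need adjusting to match the installed Megatron-Core version.
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
from megatron.core.ssm.mamba_mixer import MambaMixer
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.extensions.transformer_engine import TENorm  # assumed location

mamba_layer_spec = ModuleSpec(
    module=MambaLayer,
    submodules=MambaLayerSubmodules(
        norm=TENorm,                           # input layer normalization
        mixer=ModuleSpec(module=MambaMixer),   # along-sequence mixing (SSM/SSD)
        mamba_bda=get_bias_dropout_add,        # bias-dropout-add after the mixer
    ),
)
```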

class core.ssm.mamba_layer.MambaLayer(
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: core.ssm.mamba_layer.MambaLayerSubmodules,
layer_number: int = 1,
residual_in_fp32=False,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
pp_layer_offset: int = 0,
)#

Bases: megatron.core.transformer.module.GraphableMegatronModule

A single Mamba layer.

The Mamba layer takes an input of size [s, b, h] and returns an output of the same size.

Initialization

Initialize Mamba Layer.

mamba_state_shapes_per_request() → Tuple[Tuple[int], Tuple[int]]#

Returns the shapes of the Mamba convolution and SSM states per request.
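
For example, given an already-constructed MambaLayer instance (construction is sketched after the forward() documentation below), the per-request inference state shapes could be queried as follows; the variable names are illustrative only.

```python
# `layer` is assumed to be a constructed MambaLayer (see the later sketch).
conv_state_shape, ssm_state_shape = layer.mamba_state_shapes_per_request()
# These shapes are what an inference context would use to allocate the
# per-request convolution and SSM state caches.
```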

forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
rotary_pos_emb: Optional[torch.Tensor] = None,
*,
inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
)#

Perform a forward pass through the Mamba layer.

This method implements the core computation of a Mamba layer, including the convolution and the selective SSM/SSD.

Parameters:
  • hidden_states (Tensor) – Input tensor of shape [s, b, h] where s is sequence length, b is batch size, and h is hidden size.

  • attention_mask (Tensor) – Mask tensor for self-attention. Not used by this layer.

  • inference_context (BaseInferenceContext, optional) – Parameters for inference-time optimizations.

  • rotary_pos_emb (Tensor, optional) – Rotary positional embeddings.

Returns:

Transformed hidden states of shape [s, b, h].

Return type:

output (Tensor)
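
The sketch below exercises the forward pass, reusing the mamba_layer_spec from the earlier MambaLayerSubmodules example. The TransformerConfig values are arbitrary assumptions chosen to keep the example small, and a real run additionally requires initialized model-parallel state and the Mamba CUDA extensions, so this is an outline rather than a guaranteed-runnable recipe.

```python
import torch
from megatron.core.transformer.transformer_config import TransformerConfig

# Illustrative configuration; all values are assumptions for the sketch.
config = TransformerConfig(
    num_layers=1,
    hidden_size=256,
    num_attention_heads=4,
    use_cpu_initialization=True,
)

# Build the layer from the spec sketched earlier (mamba_layer_spec).
layer = MambaLayer(
    config=config,
    submodules=mamba_layer_spec.submodules,
    layer_number=1,
)

s, b, h = 16, 2, config.hidden_size  # sequence length, batch size, hidden size
hidden_states = torch.randn(s, b, h)
output = layer(hidden_states, attention_mask=None)  # output shape: [s, b, h]
```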

sharded_state_dict(
prefix: str = '',
sharded_offsets: tuple = (),
metadata: Optional[dict] = None,
) → megatron.core.dist_checkpointing.mapping.ShardedStateDict#

Generate a sharded state dictionary for the Mamba layer.

Parameters:
  • prefix (str, optional) – Prefix to be added to all keys in the state dict.

  • sharded_offsets (tuple, optional) – Tuple of sharding offsets.

  • metadata (Optional[dict], optional) – Additional metadata for sharding.

Returns:

A dictionary containing the sharded state of the Mamba layer.

Return type:

ShardedStateDict
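
A hypothetical usage, assuming a constructed layer as in the earlier sketch and an initialized distributed environment; the prefix value mirrors how a stack would namespace its layers, and the resulting mapping is the kind of object consumed by megatron.core.dist_checkpointing.save.

```python
# Hypothetical call; the prefix value is illustrative.
sharded_sd = layer.sharded_state_dict(prefix='decoder.layers.0.')
# Each entry describes how a parameter or buffer is sharded across ranks and
# can be handed to megatron.core.dist_checkpointing.save for checkpointing.
```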

_te_cuda_graph_replay(*args, **kwargs)#

CUDA graph replay for this layer and microbatch self.current_microbatch, using the TransformerEngine (TE) interface. TransformerEngine versions >= 1.10 allow keyword arguments with CUDA graphs. However, CUDA graphs accept only Tensor inputs, so inference_context is excluded from the input list.

_should_call_local_cudagraph(*args, **kwargs)#

Check whether the local CUDA graph path should be used.