core.ssm.mamba_block#
Module Contents#
Classes#
- MambaStackSubmodules: A class for the module specs for the MambaStack.
- MambaStack: Constructor for the MambaStack class.
API#
- class core.ssm.mamba_block.MambaStackSubmodules#
A class for the module specs for the MambaStack.
- mamba_layer: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- attention_layer: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- mlp_layer: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- moe_layer: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
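Each field is a `Union[ModuleSpec, type]`, so a stack spec can mix fully parameterised specs with bare layer classes. The container pattern can be sketched without Megatron itself; the `ModuleSpec` stand-in and layer classes below are hypothetical, not the real `megatron.core` types:

```python
from dataclasses import dataclass
from typing import Optional, Type, Union


@dataclass
class ModuleSpec:
    """Stand-in for megatron.core.transformer.spec_utils.ModuleSpec."""
    module: type
    params: Optional[dict] = None


class MambaLayer: ...      # stand-in layer classes, for illustration only
class AttentionLayer: ...
class MLPLayer: ...


@dataclass
class StackSubmodules:
    """Mirrors the shape of MambaStackSubmodules: each field is a spec or a class."""
    mamba_layer: Union[ModuleSpec, Type, None] = None
    attention_layer: Union[ModuleSpec, Type, None] = None
    mlp_layer: Union[ModuleSpec, Type, None] = None
    moe_layer: Union[ModuleSpec, Type, None] = None


subs = StackSubmodules(
    mamba_layer=ModuleSpec(module=MambaLayer, params={"d_state": 128}),
    attention_layer=AttentionLayer,  # a bare class is also accepted
    mlp_layer=MLPLayer,
)
assert isinstance(subs.mamba_layer, ModuleSpec)
assert subs.moe_layer is None
```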
- class core.ssm.mamba_block.MambaStack(
- config: megatron.core.transformer.TransformerConfig,
- submodules: core.ssm.mamba_block.MambaStackSubmodules,
- residual_in_fp32=False,
- pre_process: bool = True,
- hybrid_attention_ratio: float = 0.0,
- hybrid_mlp_ratio: float = 0.0,
- hybrid_override_pattern: str = None,
- post_layer_norm: bool = True,
- post_process: bool = True,
- device=None,
- dtype=None,
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
)#
Bases:
megatron.core.transformer.module.MegatronModule

Constructor for the MambaStack class.
- Parameters:
config (TransformerConfig) – the model configuration
submodules (MambaStackSubmodules) – the submodules for the stack
residual_in_fp32 (bool, optional) – whether to do residual connections in fp32. Defaults to False.
pre_process (bool, optional) – whether to include an embedding layer. Defaults to True.
hybrid_attention_ratio (float, optional) – the target ratio of attention layers to total layers. Defaults to 0.0.
hybrid_mlp_ratio (float, optional) – the target ratio of mlp layers to total layers. Defaults to 0.0.
hybrid_override_pattern (str, optional) – the hybrid layer pattern to override with. Defaults to None.
post_layer_norm (bool, optional) – whether to include a final layer norm. Defaults to True.
post_process (bool, optional) – whether to include an output layer. Defaults to True.
device (optional) – the device to use. Defaults to None.
dtype (optional) – the data type to use. Defaults to None.
pg_collection (ProcessGroupCollection) – the required model communication process groups to use.
Initialization
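The hybrid arguments interact: `hybrid_override_pattern` pins the exact layer sequence, while the two ratios steer automatic allocation. The sketch below shows how a pattern string implies the ratios; the symbol convention ('M' = Mamba, '*' = attention, '-' = MLP) follows Megatron's hybrid layer allocation, but treat it as an assumption here:

```python
def pattern_ratios(pattern: str) -> tuple[float, float]:
    """Return (attention_ratio, mlp_ratio) implied by a hybrid override pattern.

    Assumed symbol convention: 'M' = Mamba layer, '*' = attention, '-' = MLP.
    """
    total = len(pattern)
    attention_layers = pattern.count("*")
    mlp_layers = pattern.count("-")
    return attention_layers / total, mlp_layers / total


# An 8-layer stack with 2 attention and 2 MLP layers interleaved with Mamba:
attn_ratio, mlp_ratio = pattern_ratios("MM*-MM*-")
assert (attn_ratio, mlp_ratio) == (0.25, 0.25)
```

When `hybrid_override_pattern` is given, the ratios describe the pattern rather than drive it; otherwise the ratios guide where attention and MLP layers are placed among the Mamba layers.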
- _select_layers_for_pipeline_parallel(layer_type_list)#
- set_input_tensor(input_tensor: torch.Tensor)#
Set input tensor to be used instead of forward()’s input.
When doing pipeline parallelism, the input from the previous stage arrives via communication rather than through forward()'s arguments, so the model's forward_step_func will not have it. This function is therefore used by internal code to bypass the input provided by forward_step_func.
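The bypass pattern can be sketched without any Megatron machinery; the class and names below are illustrative stand-ins, not the real implementation:

```python
class StageStack:
    """Toy pipeline stage: uses a communicated tensor when one was set."""

    def __init__(self):
        self.input_tensor = None

    def set_input_tensor(self, tensor):
        # Called by pipeline-parallel schedule code after receiving
        # activations from the previous stage.
        self.input_tensor = tensor

    def forward(self, hidden_states):
        # A non-first stage ignores forward()'s own input and uses the
        # tensor delivered via set_input_tensor() instead.
        if self.input_tensor is not None:
            hidden_states = self.input_tensor
        return [h * 2 for h in hidden_states]  # stand-in computation


stage = StageStack()
stage.set_input_tensor([1, 2, 3])   # simulates data from p2p communication
out = stage.forward([0, 0, 0])      # forward's own input is bypassed
assert out == [2, 4, 6]
```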
- mamba_state_shapes_per_request() → Optional[Tuple[Tuple[int], Tuple[int]]]#
Returns the Mamba conv and ssm states shapes per input sequence if this block contains Mamba layers (this may not be the case with PP > 1).
- forward(
- hidden_states: Union[torch.Tensor, megatron.core.utils.WrappedTensor],
- attention_mask: torch.Tensor,
- inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- rotary_pos_emb: Optional[torch.Tensor] = None,
- *,
- inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
)#
Forward function of the MambaStack class.
It returns either the loss values (if labels are given) or the final hidden states.
- Parameters:
hidden_states (Union[Tensor, WrappedTensor]) – the input tensor. Can be passed as a WrappedTensor during inference to avoid an obsolete reference in the calling function.
attention_mask (Tensor) – the attention mask.
inference_context (BaseInferenceContext, optional) – the inference context. Defaults to None.
rotary_pos_emb (Tensor, optional) – the rotary positional embeddings. Defaults to None.
- Returns:
the output tensor.
- Return type:
Tensor
- sharded_state_dict(
- prefix: str = '',
- sharded_offsets: Optional[tuple] = None,
- metadata: Optional[dict] = None,
)#
Returns a sharded state dictionary for the current object.
This function constructs a sharded state dictionary by iterating over the layers in the current object, computing the sharded state dictionary for each layer, and combining the results into a single dictionary.
- Parameters:
prefix (str) – The prefix to use for the state dictionary keys.
sharded_offsets (tuple) – The sharded offsets to use for the state dictionary.
metadata (dict) – Additional metadata to use when computing the sharded state dictionary.
- Returns:
The sharded state dictionary for the current object.
- Return type:
dict
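The per-layer combination described above can be sketched with plain dicts; the layer key naming and the omission of offset handling are illustrative only, not how Megatron's sharded checkpointing actually formats keys:

```python
def combine_layer_state_dicts(layer_dicts, prefix=""):
    """Combine each layer's (already sharded) state dict under prefixed keys."""
    sharded = {}
    for i, layer_sd in enumerate(layer_dicts):
        layer_prefix = f"{prefix}layers.{i}."
        for key, value in layer_sd.items():
            sharded[layer_prefix + key] = value
    return sharded


layer_dicts = [{"weight": "shard0"}, {"weight": "shard1"}]
sd = combine_layer_state_dicts(layer_dicts, prefix="decoder.")
assert set(sd) == {"decoder.layers.0.weight", "decoder.layers.1.weight"}
```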