core.models.mamba.mamba_model
Module Contents
Classes
MambaModel – Mamba language model.
API
- class core.models.mamba.mamba_model.MambaModel(
    config: megatron.core.transformer.TransformerConfig,
    mamba_stack_spec: megatron.core.transformer.spec_utils.ModuleSpec,
    vocab_size: int,
    max_sequence_length: int,
    pre_process: bool = True,
    hybrid_attention_ratio: float = 0.0,
    hybrid_mlp_ratio: float = 0.0,
    hybrid_override_pattern: str = None,
    post_process: bool = True,
    fp16_lm_cross_entropy: bool = False,
    parallel_output: bool = True,
    share_embeddings_and_output_weights: bool = False,
    position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'none',
    rotary_percent: float = 1.0,
    rotary_base: int = 10000,
    scatter_embedding_sequence_parallel: bool = True,
    seq_len_interpolation_factor: Optional[float] = None,
    pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
)
Bases: megatron.core.models.common.language_module.language_module.LanguageModule
Mamba language model.
- Parameters:
config (TransformerConfig) – Model configuration.
mamba_stack_spec (ModuleSpec) – Specifies the modules to use for the various layer types.
vocab_size (int) – Vocabulary size.
max_sequence_length (int) – Maximum sequence length. Used for positional embeddings.
pre_process (bool, optional) – Include embedding layer (used with pipeline parallelism). Defaults to True.
hybrid_attention_ratio (float, optional) – Target ratio of attention layers to total layers. Defaults to 0.0.
hybrid_mlp_ratio (float, optional) – Target ratio of MLP layers to total layers. Defaults to 0.0.
hybrid_override_pattern (str, optional) – Layer pattern string that overrides the layer allocation derived from the hybrid ratios. Defaults to None.
post_process (bool, optional) – Include an output layer (used with pipeline parallelism). Defaults to True.
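fp16_lm_cross_entropy (bool, optional) – If True, compute the language-model cross-entropy loss on the fp16 logits directly instead of first casting them to float32. Defaults to False.
parallel_output (bool, optional) – If True, do not gather the outputs; keep them split across tensor-parallel ranks. Defaults to True.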
share_embeddings_and_output_weights (bool, optional) – When True, input embeddings and output logit weights are shared. Defaults to False.
position_embedding_type (Literal['learned_absolute', 'rope', 'none'], optional) – Position embedding type. Defaults to ‘none’.
rotary_percent (float, optional) – Fraction of the rotary dimension to use for rotary position embeddings. Ignored unless position_embedding_type is ‘rope’. Defaults to 1.0.
rotary_base (int, optional) – Base period for rotary position embeddings. Ignored unless position_embedding_type is ‘rope’. Defaults to 10000.
scatter_embedding_sequence_parallel (bool, optional) – Whether to scatter the embedding output across the sequence-parallel region. Defaults to True.
seq_len_interpolation_factor (Optional[float], optional) – Scale factor for linearly interpolating RoPE positions on longer sequences. Must be greater than 1.0. Defaults to None.
pg_collection (ProcessGroupCollection, optional) – Model communication process groups. Defaults to None.
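The hybrid_attention_ratio, hybrid_mlp_ratio, and hybrid_override_pattern arguments all describe how layer types are interleaved in the stack. The sketch below assumes the single-character layer symbols used by Megatron-Core's hybrid layer allocation ('M' for Mamba, '*' for attention, '-' for MLP); treat those symbols as an assumption and check them against your installed version. The helper `pattern_ratios` is hypothetical: it only shows how an override pattern relates to the two target ratios, treating the pattern length as the layer count, and is not how MambaModel itself validates the pattern.

```python
from collections import Counter

# Assumed layer symbols; verify against your Megatron-Core version.
MAMBA, ATTENTION, MLP = "M", "*", "-"
VALID_SYMBOLS = {MAMBA, ATTENTION, MLP}


def pattern_ratios(pattern: str) -> dict:
    """Hypothetical helper: report the layer mix implied by an override pattern."""
    if not pattern:
        raise ValueError("pattern must contain at least one layer symbol")
    unknown = set(pattern) - VALID_SYMBOLS
    if unknown:
        raise ValueError(f"unknown layer symbols: {sorted(unknown)}")
    counts = Counter(pattern)
    total = len(pattern)  # one symbol per layer in this sketch
    return {
        "num_layers": total,
        "mamba": counts[MAMBA],
        "attention_ratio": counts[ATTENTION] / total,
        "mlp_ratio": counts[MLP] / total,
    }


# An 8-layer stack with 4 Mamba, 2 attention, and 2 MLP layers.
print(pattern_ratios("M*-MM*-M"))
# {'num_layers': 8, 'mamba': 4, 'attention_ratio': 0.25, 'mlp_ratio': 0.25}
```

In this reading, the pattern above corresponds to hybrid_attention_ratio=0.25 and hybrid_mlp_ratio=0.25, but with the placement of each layer fixed explicitly rather than left to an allocator.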
Initialization
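The rotary_percent, rotary_base, and seq_len_interpolation_factor arguments interact when position_embedding_type is ‘rope’. The standard-library sketch below follows the common RoPE formulation: only the first rotary_percent fraction of each head's channels is rotated, inverse frequencies are rotary_base^(-2i/d), and linear interpolation divides positions by the factor. The function `rope_angles` and its `head_dim` argument are hypothetical; this is an illustration of the technique, not Megatron-Core's rotary embedding implementation.

```python
from typing import List, Optional


def rope_angles(
    head_dim: int,
    seq_len: int,
    rotary_percent: float = 1.0,
    rotary_base: int = 10000,
    seq_len_interpolation_factor: Optional[float] = None,
) -> List[List[float]]:
    """Return angles[pos][i] = scaled position * inverse frequency i."""
    # Only a fraction of each head's channels is rotated.
    rot_dim = int(head_dim * rotary_percent)
    # One inverse frequency per pair of rotated channels: base^(-2i / rot_dim).
    inv_freq = [rotary_base ** (-(2 * i) / rot_dim) for i in range(rot_dim // 2)]
    scale = 1.0
    if seq_len_interpolation_factor is not None:
        if seq_len_interpolation_factor <= 1.0:
            raise ValueError("seq_len_interpolation_factor must be > 1.0")
        # Linear interpolation: compress positions back into the trained range.
        scale = 1.0 / seq_len_interpolation_factor
    return [[pos * scale * f for f in inv_freq] for pos in range(seq_len)]


plain = rope_angles(head_dim=8, seq_len=4)
stretched = rope_angles(head_dim=8, seq_len=4, seq_len_interpolation_factor=2.0)
# With a factor of 2.0, position 2 reuses the angles of position 1.
print(stretched[2] == plain[1])  # True
```

Interpolating positions this way lets a model trained on shorter sequences see longer inputs without extrapolating to angles it never saw during training.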
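- set_input_tensor(input_tensor: torch.Tensor) → None
Sets the input tensor of the model. With pipeline parallelism, stages other than the first receive their input activations through this method instead of through the embedding layer.
See megatron.model.transformer.set_input_tensor().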
- Parameters:
input_tensor (Tensor) – Sets the input tensor for the model.
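The pipeline-parallel contract behind set_input_tensor can be illustrated without Megatron-Core: a stage built with pre_process=False has no embedding layer, so the schedule hands it the previous stage's activations before calling forward. `ToyStage`, its `scale` "decoder", and the use of plain float lists as activations are hypothetical stand-ins chosen only to make the data flow visible.

```python
from typing import List, Optional


class ToyStage:
    """Hypothetical stand-in for one pipeline stage of a language model."""

    def __init__(self, pre_process: bool, scale: float):
        self.pre_process = pre_process
        self.scale = scale  # toy "decoder" weight
        self.input_tensor: Optional[List[float]] = None

    def set_input_tensor(self, input_tensor: List[float]) -> None:
        # Called by the pipeline schedule before forward() on non-first stages.
        self.input_tensor = input_tensor

    def forward(self, input_ids: List[int]) -> List[float]:
        if self.pre_process:
            # First stage: "embed" the token ids itself.
            hidden = [float(t) for t in input_ids]
        else:
            # Later stages ignore input_ids and read the received activations.
            if self.input_tensor is None:
                raise RuntimeError("set_input_tensor() must be called first")
            hidden = self.input_tensor
        # Toy "decoder": scale every activation.
        return [h * self.scale for h in hidden]


ids = [1, 2, 3]
stage0 = ToyStage(pre_process=True, scale=2.0)
stage1 = ToyStage(pre_process=False, scale=10.0)

act = stage0.forward(ids)     # [2.0, 4.0, 6.0]
stage1.set_input_tensor(act)  # hand activations to the next stage
print(stage1.forward(ids))    # [20.0, 40.0, 60.0]
```

In a real pipeline the activations travel between ranks via point-to-point communication; here they are simply passed in-process.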
- forward(
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    decoder_input: torch.Tensor = None,
    labels: torch.Tensor = None,
    inference_context: megatron.core.inference.contexts.BaseInferenceContext = None,
    runtime_gather_output: Optional[bool] = None,
    *,
    inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
)
Forward function of the Mamba model. This function passes the input tensors through the embedding layer, then the decoder, and finally the (optional) post-processing layer.
If labels are given, it returns the loss values; otherwise it returns the logits, or the final hidden states when post_process is False.
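The return-value logic of forward can be sketched without Megatron-Core: hidden states come back when post_process is False, logits when no labels are given, and per-token losses otherwise. `toy_forward` is a hypothetical, dependency-free illustration, not the library's API; its "decoder" is the identity, and its output layer reuses the embedding matrix, mirroring share_embeddings_and_output_weights=True. The losses are left unreduced.

```python
import math
from typing import List, Optional, Union

Matrix = List[List[float]]


def toy_forward(
    input_ids: List[int],
    embedding: Matrix,  # [vocab_size][hidden]
    labels: Optional[List[int]] = None,
    post_process: bool = True,
) -> Union[Matrix, List[float]]:
    # Embedding lookup: one hidden vector per token.
    hidden = [list(embedding[t]) for t in input_ids]
    # Toy "decoder": identity here; the real model runs the hybrid layer stack.
    if not post_process:
        return hidden  # intermediate pipeline stages return hidden states
    # Output layer with tied weights: logit[v] = dot(h, embedding[v]).
    logits = [
        [sum(h_i * e_i for h_i, e_i in zip(h, row)) for row in embedding]
        for h in hidden
    ]
    if labels is None:
        return logits
    # Per-token cross-entropy: logsumexp(logits) - logits[label].
    losses = []
    for row, y in zip(logits, labels):
        m = max(row)  # subtract the max for numerical stability
        lse = m + math.log(sum(math.exp(z - m) for z in row))
        losses.append(lse - row[y])
    return losses


E = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # vocab_size=3, hidden=2
print(toy_forward([0, 2], E))                      # [[1.0, 0.0, 1.0], [1.0, 1.0, 2.0]]
print(toy_forward([0, 2], E, post_process=False))  # [[1.0, 0.0], [1.0, 1.0]]
print(toy_forward([0, 2], E, labels=[0, 2]))       # per-token losses
```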