bridge.models.mamba.mamba_provider

Module Contents

Classes

MambaModelProvider

Configuration and provider for Megatron Core Mamba models.

Functions

modelopt_mamba_stack_spec

Mamba stack specification for quantization with ModelOpt.

transformer_engine_mamba_stack_spec

Return the default Mamba stack spec with Transformer Engine layers.

get_default_mamba_stack_spec

Determine the most appropriate Mamba stack specification based on configuration.

Data

logger

API

bridge.models.mamba.mamba_provider.logger

‘getLogger(…)’

bridge.models.mamba.mamba_provider.modelopt_mamba_stack_spec(
config: MambaModelProvider,
) → megatron.core.transformer.ModuleSpec

Mamba stack specification for quantization with ModelOpt.

Uses the plain Norm module instead of TENorm, and ColumnParallelLinear/RowParallelLinear instead of the Transformer Engine (TE) linear layers, so that ModelOpt can insert its quantizers into these modules correctly.

Parameters:

config – Mamba configuration object

Returns:

Module specification for a quantization-ready Mamba stack

Return type:

ModuleSpec
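The substitution described above can be pictured as a class-for-class swap over the stack's submodule slots. The sketch below is a toy model, not Megatron Core code: `SpecSketch`, `TE_TO_PLAIN`, `quantization_ready`, and the slot names are hypothetical, and a real `ModuleSpec` holds classes and nested specs rather than strings.

```python
from dataclasses import dataclass, field
from typing import Dict


# Hypothetical string-based stand-in for a Megatron Core ModuleSpec:
# each submodule slot maps to the name of the class that will fill it.
@dataclass
class SpecSketch:
    submodules: Dict[str, str] = field(default_factory=dict)


# Assumed mapping from Transformer Engine classes to the plain
# Megatron Core classes that the docstring says ModelOpt needs.
TE_TO_PLAIN = {
    "TENorm": "Norm",
    "TEColumnParallelLinear": "ColumnParallelLinear",
    "TERowParallelLinear": "RowParallelLinear",
}


def quantization_ready(spec: SpecSketch) -> SpecSketch:
    """Return a copy of spec with every TE class swapped for its plain counterpart."""
    # Classes without a TE counterpart (e.g. the Mamba mixer) are left unchanged.
    return SpecSketch({slot: TE_TO_PLAIN.get(cls, cls) for slot, cls in spec.submodules.items()})


te_spec = SpecSketch({
    "norm": "TENorm",
    "in_proj": "TEColumnParallelLinear",
    "out_proj": "TERowParallelLinear",
    "mixer": "MambaMixer",
})
print(quantization_ready(te_spec).submodules)
```

Returning a new spec rather than mutating the input keeps the TE-based default available for runs that do not need quantization.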

bridge.models.mamba.mamba_provider.transformer_engine_mamba_stack_spec() → megatron.core.transformer.ModuleSpec

Return the default Mamba stack spec with Transformer Engine layers.

It is defined as a named function rather than a lambda so that it can be serialized and reconstructed from checkpoints: a named function can be re-imported from its module path, whereas a lambda has no importable name.

Returns:

Default Mamba stack specification from megatron.core
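Why a lambda cannot be reconstructed can be shown with plain Python. The sketch below assumes a serializer that records a callable as a `module:qualname` string and re-imports it later; that string format and the helpers `to_import_path` / `from_import_path` are hypothetical, and Megatron Bridge's own checkpoint serializer may work differently. `json.dumps` stands in for any module-level named function such as `transformer_engine_mamba_stack_spec`.

```python
import importlib
import json
from typing import Callable


def to_import_path(fn: Callable) -> str:
    """Record a callable as 'module:qualname' (hypothetical format)."""
    # Lambdas and nested functions have no name reachable from their module.
    if "<lambda>" in fn.__qualname__ or "<locals>" in fn.__qualname__:
        raise ValueError(f"{fn!r} cannot be re-imported from a module path")
    return f"{fn.__module__}:{fn.__qualname__}"


def from_import_path(path: str) -> Callable:
    """Re-import a callable from a 'module:qualname' string."""
    module_name, _, qualname = path.partition(":")
    obj = importlib.import_module(module_name)
    for attr in qualname.split("."):
        obj = getattr(obj, attr)
    return obj


# A named, module-level function round-trips through its path...
path = to_import_path(json.dumps)
print(path)                                   # json:dumps
print(from_import_path(path) is json.dumps)   # True

# ...while a lambda has nothing importable to record.
try:
    to_import_path(lambda: None)
except ValueError as err:
    print("lambda rejected:", err)
```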

bridge.models.mamba.mamba_provider.get_default_mamba_stack_spec(
config: MambaModelProvider,
) → megatron.core.transformer.ModuleSpec

Determine the most appropriate Mamba stack specification based on configuration.

Parameters:

config – Mamba configuration object

Returns:

The module specification best suited to this configuration

Return type:

ModuleSpec
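The docstring does not spell out the selection rule. A plausible reading, given the two builders modelopt_mamba_stack_spec and transformer_engine_mamba_stack_spec and the restore_modelopt_state field of MambaModelProvider, is that the ModelOpt spec is chosen when ModelOpt state is being restored and the Transformer Engine spec otherwise. The sketch below encodes that assumption with hypothetical string stand-ins (`ConfigSketch`, `*_sketch` functions); check the module source before relying on it.

```python
from dataclasses import dataclass


# Hypothetical stand-ins: the real builders return megatron.core ModuleSpec objects.
def modelopt_stack_spec_sketch(config) -> str:
    return "modelopt (Norm + ColumnParallelLinear/RowParallelLinear)"


def te_stack_spec_sketch() -> str:
    return "transformer_engine"


@dataclass
class ConfigSketch:
    # Mirrors the MambaModelProvider field of the same name (default False).
    restore_modelopt_state: bool = False


def default_stack_spec_sketch(config: ConfigSketch) -> str:
    """Assumed rule: ModelOpt spec when restoring ModelOpt state, TE spec otherwise."""
    if config.restore_modelopt_state:
        return modelopt_stack_spec_sketch(config)
    return te_stack_spec_sketch()


print(default_stack_spec_sketch(ConfigSketch()))
print(default_stack_spec_sketch(ConfigSketch(restore_modelopt_state=True)))
```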

class bridge.models.mamba.mamba_provider.MambaModelProvider

Bases: megatron.bridge.models.transformer_config.TransformerConfig, megatron.bridge.models.model_provider.ModelProviderMixin[megatron.core.models.mamba.MambaModel]

Configuration and provider for Megatron Core Mamba models.

This class extends TransformerConfig with Mamba-specific parameters and provides a provide() method that instantiates a Megatron Core Mamba model from them.

fp16_lm_cross_entropy: bool

False

parallel_output: bool

True

share_embeddings_and_output_weights: bool

False

params_dtype: torch.dtype

None

fp16: bool

False

bf16: bool

True

num_layers: int

2

mamba_num_groups: int

8

num_attention_heads: int

1

hybrid_attention_ratio: float

0.0

hybrid_mlp_ratio: float

0.0

hybrid_override_pattern: Optional[str]

None

seq_length: int

8192

position_embedding_type: Literal['learned_absolute', 'rope', 'none']

‘none’

rotary_percent: float

1.0

rotary_base: int

10000

seq_len_interpolation_factor: Optional[float]

None

apply_rope_fusion: bool

True

make_vocab_size_divisible_by: int

128

gated_linear_unit: bool

False

normalization: str

‘RMSNorm’

add_bias_linear: bool

False

hidden_dropout: float

0.0

attention_dropout: float

0.0

layernorm_epsilon: float

1e-05

attention_backend: megatron.core.transformer.enums.AttnBackend

None

deallocate_pipeline_outputs: bool

True

bias_dropout_fusion: bool

True

cross_entropy_loss_fusion: bool

True

mamba_stack_spec: Union[megatron.core.transformer.ModuleSpec, Callable[[], megatron.core.transformer.ModuleSpec], Callable[[bridge.models.mamba.mamba_provider.MambaModelProvider], megatron.core.transformer.ModuleSpec]]

None
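The Union type means a caller may supply a ready spec, a zero-argument factory, or a factory that takes the provider itself. One way such a field can be resolved is sketched below; it assumes the provider decides between the two factory forms by inspecting the callable's parameter count, which is an assumption rather than documented behavior, and `ModuleSpecSketch` / `resolve_stack_spec` are hypothetical names.

```python
import inspect
from dataclasses import dataclass
from typing import Callable, Union


@dataclass
class ModuleSpecSketch:
    """Hypothetical stand-in for megatron.core.transformer.ModuleSpec."""
    name: str


SpecOrFactory = Union[
    ModuleSpecSketch,
    Callable[[], ModuleSpecSketch],
    Callable[[object], ModuleSpecSketch],
]


def resolve_stack_spec(spec: SpecOrFactory, config: object) -> ModuleSpecSketch:
    """Turn any accepted form of mamba_stack_spec into a concrete spec."""
    if isinstance(spec, ModuleSpecSketch):
        return spec  # already a spec: use it as-is
    # A factory: pass the config only if the callable declares a parameter for it.
    if len(inspect.signature(spec).parameters) > 0:
        return spec(config)
    return spec()


print(resolve_stack_spec(ModuleSpecSketch("fixed"), config=None).name)
print(resolve_stack_spec(lambda: ModuleSpecSketch("zero-arg"), config=None).name)
print(resolve_stack_spec(lambda cfg: ModuleSpecSketch(f"from {cfg}"), config="cfg").name)
```

Accepting the config-taking form is what lets a builder such as get_default_mamba_stack_spec choose a spec from the provider's own fields.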

vocab_size: Optional[int]

None

should_pad_vocab: bool

False

hf_model_id: Optional[str]

None

Optional HuggingFace model identifier associated with this provider.

_pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection]

None

restore_modelopt_state: bool

False

provide(
pre_process=None,
post_process=None,
vp_stage=None,
) → megatron.core.models.mamba.MambaModel

Configure and instantiate a Megatron Core Mamba model based on this configuration.

Parameters:
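  • pre_process – Whether to include pre-processing (the input embedding) in the model; if None, defaults to True only on the first pipeline stage

  • post_process – Whether to include post-processing (the output layer) in the model; if None, defaults to True only on the last pipeline stage

  • vp_stage – Virtual pipeline stage index of the model chunk being built, used with interleaved pipeline parallelism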

Returns:

Configured Megatron Core Mamba model instance

Return type:

MCoreMambaModel (alias of megatron.core.models.mamba.MambaModel)
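The pre_process/post_process defaults described for provide() can be sketched as pure functions of the pipeline and virtual-pipeline position. This is a hedged illustration: the function name and its integer arguments are hypothetical (Megatron Core derives these flags from its parallel-state helpers), and the rule "first model chunk on the first rank, last model chunk on the last rank" for interleaved schedules is stated here as an assumption.

```python
from typing import Optional, Tuple


def default_pre_post_process(
    pp_rank: int,
    pp_size: int,
    vp_stage: Optional[int] = None,
    vp_size: Optional[int] = None,
    pre_process: Optional[bool] = None,
    post_process: Optional[bool] = None,
) -> Tuple[bool, bool]:
    """Fill in pre_process/post_process when the caller leaves them as None."""
    if vp_stage is not None and vp_size is None:
        raise ValueError("vp_size is required when vp_stage is given")
    if pre_process is None:
        # First stage: first pipeline rank and, with interleaving, its first chunk.
        pre_process = pp_rank == 0 and (vp_stage is None or vp_stage == 0)
    if post_process is None:
        # Last stage: last pipeline rank and, with interleaving, its last chunk.
        post_process = pp_rank == pp_size - 1 and (vp_stage is None or vp_stage == vp_size - 1)
    return pre_process, post_process


print(default_pre_post_process(0, 4))                      # (True, False)
print(default_pre_post_process(3, 4))                      # (False, True)
print(default_pre_post_process(0, 1))                      # (True, True)
print(default_pre_post_process(1, 2, vp_stage=1, vp_size=2))  # (False, True)
```

Explicit True/False arguments always win over the defaults, which mirrors the "defaults to" wording in the parameter list.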