bridge.models.mamba.mamba_builder#

Module Contents#

Classes#

MambaModelConfig

Configuration for a Megatron Core Mamba (SSM) model.

MambaModelBuilder

Builder to construct Megatron Core Mamba models.

Functions#

transformer_engine_mamba_stack_spec

Return the default Mamba stack spec with Transformer Engine layers.

modelopt_mamba_stack_spec

Mamba stack specification for quantization with ModelOpt.

get_default_mamba_stack_spec

Determine the most appropriate Mamba stack specification based on configuration.

Data#

API#

bridge.models.mamba.mamba_builder.logger#

'getLogger(…)'

bridge.models.mamba.mamba_builder.transformer_engine_mamba_stack_spec() → megatron.core.transformer.ModuleSpec#

Return the default Mamba stack spec with Transformer Engine layers.

This is a named function (not a lambda) to allow proper serialization and reconstruction from checkpoints. Named functions can be imported via their module path, unlike lambdas.

Returns:

Default Mamba stack specification from megatron.core
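The serialization point can be illustrated with plain `pickle`: a module-level named function round-trips by reference to its module path, while a lambda cannot be pickled at all. This sketch uses a placeholder factory, not the real spec function:

```python
import pickle

def default_spec():
    # Stand-in for a named spec factory such as
    # transformer_engine_mamba_stack_spec (returns a dummy value here).
    return "mamba-stack-spec"

# Named functions pickle by module path, so they survive checkpointing.
restored = pickle.loads(pickle.dumps(default_spec))
print(restored())  # -> mamba-stack-spec

# Lambdas have no importable name and fail to pickle.
try:
    pickle.dumps(lambda: "mamba-stack-spec")
except pickle.PicklingError:
    print("lambda is not picklable")
```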

bridge.models.mamba.mamba_builder.modelopt_mamba_stack_spec() → megatron.core.transformer.ModuleSpec#

Mamba stack specification for quantization with ModelOpt.

Uses Norm instead of TENorm and ColumnParallelLinear/RowParallelLinear instead of TE layers to enable proper quantizer insertion by ModelOpt.

Returns:

Module specification for quantization-ready Mamba stack

Return type:

ModuleSpec

bridge.models.mamba.mamba_builder.get_default_mamba_stack_spec(
config: MambaModelConfig,
) → megatron.core.transformer.ModuleSpec#

Determine the most appropriate Mamba stack specification based on configuration.

Parameters:

config – Mamba configuration object

Returns:

Appropriate module specification based on config

Return type:

ModuleSpec
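The page does not spell out the selection rule, but the two specs above suggest its shape. A hypothetical sketch (the `use_modelopt_spec` flag is illustrative, not a real config attribute, and strings stand in for the ModuleSpec objects):

```python
def pick_mamba_stack_spec(use_modelopt_spec: bool) -> str:
    """Hypothetical selection logic: prefer the ModelOpt-friendly spec
    when quantization is requested, otherwise fall back to the
    Transformer Engine default."""
    if use_modelopt_spec:
        return "modelopt_mamba_stack_spec"
    return "transformer_engine_mamba_stack_spec"

print(pick_mamba_stack_spec(False))  # -> transformer_engine_mamba_stack_spec
print(pick_mamba_stack_spec(True))   # -> modelopt_mamba_stack_spec
```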

class bridge.models.mamba.mamba_builder.MambaModelConfig#

Bases: megatron.bridge.models.common.ModelConfig

Configuration for a Megatron Core Mamba (SSM) model.

This is purely a configuration object. All model construction logic lives in MambaModelBuilder.

Contains a TransformerConfig alongside Mamba-specific parameters. Attributes on the embedded transformer config are accessible directly on this object via __getattr__/__setattr__ proxying.

Supports hybrid SSM/attention architectures via hybrid_attention_ratio, hybrid_mlp_ratio, and hybrid_override_pattern.

.. note:: vocab_size must be set before passing this config to MambaModelBuilder.
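As a rough illustration of how an override pattern relates to the two ratios, assuming (as in Megatron's hybrid SSM models) one character per layer with 'M' for Mamba, '*' for attention, and '-' for MLP; treat the symbol convention as an assumption here, not something stated on this page:

```python
def hybrid_ratios(pattern: str) -> tuple[float, float]:
    # One character per layer: '*' = attention, '-' = MLP, 'M' = Mamba
    # (symbol convention assumed, not taken from this page).
    n = len(pattern)
    return pattern.count("*") / n, pattern.count("-") / n

attn_ratio, mlp_ratio = hybrid_ratios("M-M*M-M*")
print(attn_ratio, mlp_ratio)  # -> 0.25 0.25
```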

builder: ClassVar[str]#

'megatron.bridge.models.mamba.MambaModelBuilder'

transformer: megatron.bridge.models.transformer_config.TransformerConfig#

None

fp16_lm_cross_entropy: bool#

False

parallel_output: bool#

True

share_embeddings_and_output_weights: bool#

False

hybrid_attention_ratio: float#

0.0

hybrid_mlp_ratio: float#

0.0

hybrid_override_pattern: str | None#

None

seq_length: int#

8192

position_embedding_type: Literal['learned_absolute', 'rope', 'none']#

'none'

rotary_percent: float#

1.0

rotary_base: int#

10000

seq_len_interpolation_factor: float | None#

None

make_vocab_size_divisible_by: int#

128

mamba_stack_spec: megatron.core.transformer.ModuleSpec | Callable[[], megatron.core.transformer.ModuleSpec] | Callable[[bridge.models.mamba.mamba_builder.MambaModelConfig], megatron.core.transformer.ModuleSpec]#

None

vocab_size: int | None#

None

should_pad_vocab: bool#

False
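A minimal sketch of the padding implied by make_vocab_size_divisible_by and should_pad_vocab: round the vocabulary up to the nearest multiple. The tensor-parallel multiplier is an assumption based on common Megatron practice, not something stated on this page:

```python
def padded_vocab_size(vocab_size: int,
                      make_vocab_size_divisible_by: int = 128,
                      tensor_parallel_size: int = 1) -> int:
    # Round up so embedding shards divide evenly across
    # tensor-parallel ranks (multiplier assumed, see lead-in).
    multiple = make_vocab_size_divisible_by * tensor_parallel_size
    return ((vocab_size + multiple - 1) // multiple) * multiple

print(padded_vocab_size(32000))  # -> 32000 (already divisible by 128)
print(padded_vocab_size(50257))  # -> 50304
```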

__getattr__(name: str, /) → Any#
__setattr__(name: str, value: Any, /) → None#
class bridge.models.mamba.mamba_builder.MambaModelBuilder(
model_config: bridge.models.mamba.mamba_builder.MambaModelConfig,
)#

Bases: megatron.bridge.models.common.ModelBuilder[megatron.core.models.mamba.MambaModel, bridge.models.mamba.mamba_builder.MambaModelConfig]

Builder to construct Megatron Core Mamba models.

.. rubric:: Example

transformer_cfg = TransformerConfig(num_layers=32, hidden_size=4096, …)
model_cfg = MambaModelConfig(transformer=transformer_cfg, vocab_size=32000, seq_length=2048, …)

# Single stage (e.g. inference)
model = MambaModelBuilder(model_cfg).build_model(pg_collection)

# Distributed training
models = MambaModelBuilder(model_cfg).build_distributed_models(pg_collection)

Initialization

build_model(
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
pre_process: bool | None = None,
post_process: bool | None = None,
vp_stage: int | None = None,
) → megatron.core.models.mamba.MambaModel#

Build a single MCoreMambaModel stage.

Parameters:
  • pg_collection – Process groups for distributed training

  • pre_process – Whether this stage includes the embedding layer

  • post_process – Whether this stage includes the output layer

  • vp_stage – Virtual pipeline stage index

Returns:

The constructed model

.. note:: Virtual pipeline model parallelism is not supported for Mamba models.
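When pre_process and post_process are left as None, a reasonable reading is that they default by pipeline position: the first stage holds the embedding, the last holds the output layer. A hypothetical sketch of that rule (not the library's actual code):

```python
def stage_flags(stage: int, num_stages: int) -> tuple[bool, bool]:
    # First pipeline stage owns the embedding (pre_process);
    # last stage owns the output layer (post_process). Assumed defaults.
    return stage == 0, stage == num_stages - 1

print(stage_flags(0, 4))  # -> (True, False)
print(stage_flags(3, 4))  # -> (False, True)
```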

build_distributed_models(
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
ddp_config: megatron.core.distributed.DistributedDataParallelConfig | None = None,
overlap_param_gather_with_optimizer_step: bool = False,
use_megatron_fsdp: bool = False,
use_torch_fsdp2: bool = False,
wrap_with_ddp: bool = True,
data_parallel_random_init: bool = True,
mixed_precision_wrapper: Callable[[Any, megatron.core.transformer.MegatronModule], megatron.core.transformer.MegatronModule] | None = Float16Module,
model_type: megatron.core.enums.ModelType = ModelType.encoder_or_decoder,
) → list[megatron.core.models.mamba.MambaModel]#

Build model stages and wrap for distributed training.

Parameters:
  • pg_collection – Model communication process groups.

  • ddp_config – DistributedDataParallel configuration

  • overlap_param_gather_with_optimizer_step – Whether to overlap parameter gather with optimizer step.

  • use_megatron_fsdp – Whether to use Megatron FSDP

  • use_torch_fsdp2 – Whether to use Torch FSDP 2.0

  • wrap_with_ddp – Set to False to skip the DDP/FSDP wrapper.

  • data_parallel_random_init – Whether to use data parallel random initialization

  • mixed_precision_wrapper – Mixed precision wrapper, e.g. Float16Module

  • model_type – Deprecated flag, only used for backwards compatibility.

Returns:

List of model stages.