bridge.models.mamba.mamba_builder#
Module Contents#
Classes#
- MambaModelConfig: Configuration for a Megatron Core Mamba (SSM) model.
- MambaModelBuilder: Builder to construct Megatron Core Mamba models.
Functions#
- transformer_engine_mamba_stack_spec: Return the default Mamba stack spec with Transformer Engine layers.
- modelopt_mamba_stack_spec: Mamba stack specification for quantization with ModelOpt.
- get_default_mamba_stack_spec: Determine the most appropriate Mamba stack specification based on configuration.
Data#
API#
- bridge.models.mamba.mamba_builder.logger#
‘getLogger(…)’
- bridge.models.mamba.mamba_builder.transformer_engine_mamba_stack_spec() → megatron.core.transformer.ModuleSpec#
Return the default Mamba stack spec with Transformer Engine layers.
This is a named function (not a lambda) to allow proper serialization and reconstruction from checkpoints. Named functions can be imported via their module path, unlike lambdas.
- Returns:
Default Mamba stack specification from megatron.core
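The note above (a named function rather than a lambda, so it can be re-imported via its module path) can be illustrated with a minimal pure-Python sketch. `make_spec` is a hypothetical stand-in, not part of the actual API:

```python
def make_spec():
    """Named stand-in for a spec factory like transformer_engine_mamba_stack_spec."""
    return "mamba-stack-spec"

# A named function carries an importable dotted path, so a checkpoint can
# store the path and re-import the callable later:
named_path = f"{make_spec.__module__}.{make_spec.__qualname__}"

# A lambda's qualname is the literal "<lambda>", which cannot be looked up
# by name in its module, so the same round-trip is impossible:
anon = lambda: "spec"
anon_path = f"{anon.__module__}.{anon.__qualname__}"
```

This is why the default spec is exposed as a module-level function even though a lambda would be shorter.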
- bridge.models.mamba.mamba_builder.modelopt_mamba_stack_spec() → megatron.core.transformer.ModuleSpec#
Mamba stack specification for quantization with ModelOpt.
Uses Norm instead of TENorm and ColumnParallelLinear/RowParallelLinear instead of TE layers to enable proper quantizer insertion by ModelOpt.
- Returns:
Module specification for quantization-ready Mamba stack
- Return type:
ModuleSpec
- bridge.models.mamba.mamba_builder.get_default_mamba_stack_spec(
- config: MambaModelConfig,
) → megatron.core.transformer.ModuleSpec#
Determine the most appropriate Mamba stack specification based on configuration.
- Parameters:
config – Mamba configuration object
- Returns:
Appropriate module specification based on config
- Return type:
ModuleSpec
- class bridge.models.mamba.mamba_builder.MambaModelConfig#
Bases:
megatron.bridge.models.common.ModelConfig
Configuration for a Megatron Core Mamba (SSM) model.
This is purely a configuration object. All model construction logic lives in MambaModelBuilder.
Contains a TransformerConfig alongside Mamba-specific parameters. Attributes on the embedded transformer config are accessible directly on this object via __getattr__/__setattr__ proxying.
Supports hybrid SSM/attention architectures via hybrid_attention_ratio, hybrid_mlp_ratio, and hybrid_override_pattern.
.. note::
vocab_size must be set before passing this config to MambaModelBuilder.
- builder: ClassVar[str]#
‘megatron.bridge.models.mamba.MambaModelBuilder’
- transformer: megatron.bridge.models.transformer_config.TransformerConfig#
None
- fp16_lm_cross_entropy: bool#
False
- parallel_output: bool#
True
- hybrid_attention_ratio: float#
0.0
- hybrid_mlp_ratio: float#
0.0
- hybrid_override_pattern: str | None#
None
- seq_length: int#
8192
- position_embedding_type: Literal[learned_absolute, rope, none]#
‘none’
- rotary_percent: float#
1.0
- rotary_base: int#
10000
- seq_len_interpolation_factor: float | None#
None
- make_vocab_size_divisible_by: int#
128
- mamba_stack_spec: megatron.core.transformer.ModuleSpec | Callable[[], megatron.core.transformer.ModuleSpec] | Callable[[bridge.models.mamba.mamba_builder.MambaModelConfig], megatron.core.transformer.ModuleSpec]#
None
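The type of mamba_stack_spec above admits three shapes: a ready ModuleSpec, a zero-argument factory, or a factory taking the model config. A hedged sketch of how such a union could be resolved (resolve_stack_spec is a hypothetical helper, not the builder's actual code; strings stand in for ModuleSpec objects):

```python
import inspect

def resolve_stack_spec(spec_or_factory, config):
    """Resolve a spec that may be a ready object, a zero-arg factory,
    or a factory taking the model config (illustrative sketch only)."""
    if not callable(spec_or_factory):
        return spec_or_factory            # already a ModuleSpec-like object
    params = inspect.signature(spec_or_factory).parameters
    if len(params) == 0:
        return spec_or_factory()          # Callable[[], ModuleSpec]
    return spec_or_factory(config)        # Callable[[config], ModuleSpec]

# Usage with plain strings standing in for ModuleSpec instances:
resolve_stack_spec("te-spec", None)             # passed through unchanged
resolve_stack_spec(lambda: "te-spec", None)     # zero-arg factory invoked
resolve_stack_spec(lambda cfg: f"spec-for-{cfg}", "cfg")  # config-aware factory
```

Dispatching on the callable's arity keeps the config field flexible while leaving the default (None) to fall back to get_default_mamba_stack_spec.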
- vocab_size: int | None#
None
- should_pad_vocab: bool#
False
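The interaction of vocab_size, should_pad_vocab, and make_vocab_size_divisible_by amounts to rounding the vocabulary up to the next multiple of the divisor. A simplified arithmetic sketch (padded_vocab_size is a hypothetical helper; Megatron's real padding also factors in tensor-parallel size):

```python
def padded_vocab_size(vocab_size: int, divisible_by: int = 128) -> int:
    """Round vocab_size up to the next multiple of divisible_by
    (simplified sketch of Megatron-style vocab padding)."""
    return ((vocab_size + divisible_by - 1) // divisible_by) * divisible_by

padded_vocab_size(32000, 128)  # already a multiple of 128, unchanged
padded_vocab_size(50257, 128)  # rounds up to 50304
```

Padding keeps the embedding and output projections aligned for efficient sharded matmuls; when should_pad_vocab is False the size is used as given.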
- __getattr__(name: str, /) Any#
- __setattr__(name: str, value: Any, /) None#
- class bridge.models.mamba.mamba_builder.MambaModelBuilder(
- model_config: bridge.models.mamba.mamba_builder.MambaModelConfig,
)#
Bases:
megatron.bridge.models.common.ModelBuilder[megatron.core.models.mamba.MambaModel, bridge.models.mamba.mamba_builder.MambaModelConfig]
Builder to construct Megatron Core Mamba models.
.. rubric:: Example

```python
transformer_cfg = TransformerConfig(num_layers=32, hidden_size=4096, ...)
model_cfg = MambaModelConfig(transformer=transformer_cfg, vocab_size=32000, seq_length=2048, ...)

# Single stage (e.g. inference)
model = MambaModelBuilder(model_cfg).build_model(pg_collection)

# Distributed training
models = MambaModelBuilder(model_cfg).build_distributed_models(pg_collection)
```
Initialization
- build_model(
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
- pre_process: bool | None = None,
- post_process: bool | None = None,
- vp_stage: int | None = None,
)#
Build a single MCoreMambaModel stage.
- Parameters:
pg_collection – Process groups for distributed training
pre_process – Include embedding layer
post_process – Include output layer
vp_stage – Virtual pipeline stage
- Returns:
The constructed model
.. note:: Virtual pipeline model parallelism is not supported for Mamba models.
- build_distributed_models(
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
- ddp_config: megatron.core.distributed.DistributedDataParallelConfig | None = None,
- overlap_param_gather_with_optimizer_step: bool = False,
- use_megatron_fsdp: bool = False,
- use_torch_fsdp2: bool = False,
- wrap_with_ddp: bool = True,
- data_parallel_random_init: bool = True,
- mixed_precision_wrapper: Callable[[Any, megatron.core.transformer.MegatronModule], megatron.core.transformer.MegatronModule] | None = Float16Module,
- model_type: megatron.core.enums.ModelType = ModelType.encoder_or_decoder,
)#
Build model stages and wrap for distributed training.
- Parameters:
pg_collection – Model communication process groups.
ddp_config – DistributedDataParallel configuration
overlap_param_gather_with_optimizer_step – Whether to overlap parameter gather with optimizer step.
use_megatron_fsdp – Whether to use Megatron FSDP
use_torch_fsdp2 – Whether to use Torch FSDP 2.0
wrap_with_ddp – Set to False to skip the DDP/FSDP wrapper.
data_parallel_random_init – Whether to use data parallel random initialization
mixed_precision_wrapper – Mixed precision wrapper, e.g. Float16Module
model_type – Deprecated flag, only used for backwards compatibility.
- Returns:
List of model stages.