bridge.models.minimax_m2.minimax_m2_bridge#

Module Contents#

Classes#

_FullDimQKNormMapping

TP-sharded mapping for full-dimension QK norm weights.

MiniMaxM2Bridge

Megatron Bridge for MiniMax-M2 MoE Causal LM.

Data#

API#

bridge.models.minimax_m2.minimax_m2_bridge.__all__#

['MiniMaxM2Bridge', '_FullDimQKNormMapping', '_dequant_fp8_blockwise']

bridge.models.minimax_m2.minimax_m2_bridge._dequant_fp8_blockwise#

None

class bridge.models.minimax_m2.minimax_m2_bridge._FullDimQKNormMapping#

Bases: megatron.bridge.models.conversion.param_mapping.MegatronParamMapping[torch.Tensor]

TP-sharded mapping for full-dimension QK norm weights.

HF weight shape: [num_heads * head_dim]

Megatron weight shape per rank: [num_heads_per_partition * head_dim]

Uses broadcast-then-slice instead of scatter_to_tp_ranks because the _FullDimRMSNorm module may reside on CPU / meta device where NCCL scatter is not available.
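The slicing half of broadcast-then-slice can be sketched in a few lines. This is a minimal illustration, not the bridge's implementation: it assumes heads are partitioned contiguously across TP ranks, uses a plain Python list in place of a torch.Tensor, and simulates the broadcast by handing every rank the same full weight. The helper name slice_for_tp_rank is hypothetical.

```python
def slice_for_tp_rank(full_weight, tp_rank, tp_size):
    """Return the contiguous shard of a 1-D weight owned by tp_rank.

    Assumption: rank r owns elements [r * shard, (r + 1) * shard), i.e.
    heads are split contiguously across tensor-parallel ranks.
    """
    if len(full_weight) % tp_size != 0:
        raise ValueError("weight length must be divisible by tp_size")
    shard = len(full_weight) // tp_size
    start = tp_rank * shard
    return full_weight[start:start + shard]


# 4 heads x head_dim 2 -> 8 elements; with tp_size=2 each rank keeps 2 heads.
full = [float(i) for i in range(8)]

# "Broadcast": every rank sees the same full weight, then slices locally.
shards = [slice_for_tp_rank(full, rank, 2) for rank in range(2)]
```

Because each rank only indexes into a tensor it already holds, no collective scatter is needed once the full weight is available locally.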

hf_to_megatron(
hf_weights: torch.Tensor,
megatron_module: torch.nn.Module,
) → torch.Tensor#

megatron_to_hf(
megatron_weights: Optional[torch.Tensor],
megatron_module: Optional[torch.nn.Module],
) → Dict[str, torch.Tensor]#
class bridge.models.minimax_m2.minimax_m2_bridge.MiniMaxM2Bridge#

Bases: megatron.bridge.models.conversion.model_bridge.MegatronModelBridge

Megatron Bridge for MiniMax-M2 MoE Causal LM.

Also supports MiniMax-M2.5 and MiniMax-M2.7, which share the same model_type (minimax_m2) and MiniMaxM2ForCausalLM architecture.

MiniMax-M2 is a sparse MoE model (256 experts, top-8 routing with sigmoid scoring and expert bias correction). Use the native transformers >= 5.0 implementation (no trust_remote_code required).

The on-disk checkpoint format (both the HF hub checkpoint and models saved with save_pretrained) uses the legacy block_sparse_moe key prefix with per-expert w1 (gate), w3 (up), and w2 (down) weight tensors. The in-memory model API uses mlp / gate_up_proj / down_proj, but serialization reverts to the legacy layout.
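To make the legacy naming concrete, the sketch below classifies a checkpoint key by expert index and projection role. Only the block_sparse_moe prefix and the w1/w3/w2 roles come from the description above; the full key shape used in the example (model.layers.N.block_sparse_moe.experts.E.wX.weight) is an assumption modelled on similar MoE checkpoints, and expert_weight_role is a hypothetical helper, not part of the bridge.

```python
# Role of each legacy per-expert tensor, as described above.
LEGACY_TO_ROLE = {"w1": "gate", "w3": "up", "w2": "down"}


def expert_weight_role(key):
    """Return (expert_index, role) for a legacy per-expert key, else None."""
    parts = key.split(".")
    if "block_sparse_moe" not in parts or "experts" not in parts:
        return None
    i = parts.index("experts")
    if i + 2 >= len(parts):
        return None
    index, name = parts[i + 1], parts[i + 2]
    if not index.isdigit() or name not in LEGACY_TO_ROLE:
        return None
    return int(index), LEGACY_TO_ROLE[name]


# Hypothetical key; the exact on-disk names are not confirmed here.
example = expert_weight_role(
    "model.layers.0.block_sparse_moe.experts.3.w1.weight"
)
non_expert = expert_weight_role("model.layers.0.self_attn.q_proj.weight")
```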

QK normalization: MiniMax-M2 applies full-dimension RMSNorm to Q/K (weight shape = num_heads * head_dim) before splitting into heads. Megatron’s built-in QK norm is per-head (weight shape = head_dim). This bridge uses a custom layer spec (minimax_m2_layer_spec) with FullDimQNorm/FullDimKNorm that normalizes over the full partition dimension. With TP > 1 the sum-of-squares is all-reduced across TP ranks so the RMS denominator matches the single-GPU case.
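The effect of all-reducing the sum of squares can be checked with a small numeric sketch. It uses plain Python lists rather than tensors, simulates the all-reduce by summing per-shard partial sums, and assumes a standard RMSNorm with eps = 1e-6; the function names are illustrative and not taken from the bridge.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Full-dimension RMSNorm over the whole vector (single-GPU case)."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def rms_norm_tp(x_shards, weight_shards, eps=1e-6):
    """Same norm when x is split across TP ranks.

    Each rank computes a local sum of squares; summing them stands in for
    an all-reduce(SUM) across the tensor-parallel group.
    """
    local_sq = [sum(v * v for v in shard) for shard in x_shards]
    total_sq = sum(local_sq)  # simulated all-reduce
    full_dim = sum(len(shard) for shard in x_shards)
    inv_rms = 1.0 / math.sqrt(total_sq / full_dim + eps)
    return [
        [v * inv_rms * w for v, w in zip(shard, w_shard)]
        for shard, w_shard in zip(x_shards, weight_shards)
    ]


x = [1.0, -2.0, 3.0, 0.5, 4.0, -1.0, 2.0, 0.25]
w = [1.0, 0.5, 2.0, 1.5, 1.0, 1.0, 0.25, 3.0]

full_out = rms_norm(x, w)
tp_out = rms_norm_tp([x[:4], x[4:]], [w[:4], w[4:]])
tp_flat = tp_out[0] + tp_out[1]
```

Without the reduction, each rank would divide by the RMS of its own shard only, and the result would depend on the TP size.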

Known limitations:

- MTP (Multi-Token Prediction) modules are not mapped.

Example

from megatron.bridge import AutoBridge
bridge = AutoBridge.from_hf_pretrained("MiniMaxAI/MiniMax-M2")
provider = bridge.to_megatron_provider()

provider_bridge(hf_pretrained)#

Convert HuggingFace MiniMax-M2 config to GPTModelProvider.

maybe_modify_loaded_hf_weight(
hf_param: str | dict[str, str],
hf_state_dict: collections.abc.Mapping[str, torch.Tensor],
) → torch.Tensor | dict[str, torch.Tensor]#

Load HF weights with FP8 block-wise dequantization when needed.

MiniMax-M2 stores linear weights as float8_e4m3fn with per-block scale factors in <key>_scale_inv tensors (128x128 blocks).
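A block-wise dequantization of this kind can be sketched as follows. The sketch assumes the common convention that the dequantized value is the stored value multiplied by the scale_inv entry of its block; the text above does not state the formula, so treat that as an assumption. Plain floats in nested lists stand in for float8_e4m3fn tensors, and dequant_blockwise is a hypothetical name, not the module's _dequant_fp8_blockwise.

```python
def dequant_blockwise(weight, scale_inv, block=128):
    """Multiply each element by the scale of the (block x block) tile it lies in.

    weight:    2-D nested list, shape [rows][cols]
    scale_inv: 2-D nested list, shape [ceil(rows/block)][ceil(cols/block)]
    Assumption: dequantized = quantized * scale_inv[block_row][block_col].
    """
    rows, cols = len(weight), len(weight[0])
    return [
        [weight[r][c] * scale_inv[r // block][c // block] for c in range(cols)]
        for r in range(rows)
    ]


# Tiny demo with 2x2 blocks instead of 128x128 so the result is easy to read.
w = [[1.0] * 4 for _ in range(4)]
s = [[0.5, 2.0],
     [4.0, 8.0]]
out = dequant_blockwise(w, s, block=2)
```

With tensors, the same effect is usually obtained by expanding the scale tensor to the weight's shape and multiplying elementwise, after casting the FP8 weight to a higher-precision dtype.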

_load_and_dequant(
key: str,
hf_state_dict: collections.abc.Mapping[str, torch.Tensor],
) → torch.Tensor#

mapping_registry() → megatron.bridge.models.conversion.mapping_registry.MegatronMappingRegistry#