bridge.models.minimax_m2.minimax_m2_bridge#

Module Contents#

Classes#

_FullDimQKNormMapping

TP-sharded mapping for full-dimension QK norm weights.

MiniMaxM2Bridge

Megatron Bridge for MiniMax-M2 MoE Causal LM.

Functions#

_dequant_fp8_blockwise

Block-wise FP8 dequantization: out = fp8_val * scale_inv per 128x128 block.

Data#

API#

bridge.models.minimax_m2.minimax_m2_bridge._FP8_BLOCK_SIZE#

128

bridge.models.minimax_m2.minimax_m2_bridge._dequant_fp8_blockwise(
weight: torch.Tensor,
scale_inv: torch.Tensor,
) → torch.Tensor#

Block-wise FP8 dequantization: out = fp8_val * scale_inv per 128x128 block.
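The per-block scaling can be illustrated with a plain-Python sketch. This is not the bridge's implementation (which operates on torch tensors with 128×128 blocks); `dequant_blockwise` and the tiny block size here are illustrative assumptions:

```python
def dequant_blockwise(weight, scale_inv, block_size=2):
    # Each element is multiplied by the scale_inv entry for the block
    # it falls in: out[i][j] = weight[i][j] * scale_inv[i//B][j//B].
    # The real function does this on torch tensors with B = 128.
    rows, cols = len(weight), len(weight[0])
    return [
        [weight[i][j] * scale_inv[i // block_size][j // block_size]
         for j in range(cols)]
        for i in range(rows)
    ]

# 4x4 weight of ones with a 2x2 grid of per-block scales
out = dequant_blockwise([[1.0] * 4 for _ in range(4)],
                        [[2.0, 3.0], [4.0, 5.0]])
```

Note that `scale_inv` has one entry per block, so its shape is the weight shape divided (ceil) by the block size along each axis.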

class bridge.models.minimax_m2.minimax_m2_bridge._FullDimQKNormMapping#

Bases: megatron.bridge.models.conversion.param_mapping.MegatronParamMapping[torch.Tensor]

TP-sharded mapping for full-dimension QK norm weights.

HF weight shape: [num_heads * head_dim]. Megatron weight shape per rank: [num_heads_per_partition * head_dim].

Uses broadcast-then-slice instead of scatter_to_tp_ranks because the _FullDimRMSNorm module may reside on CPU / meta device where NCCL scatter is not available.
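Per rank, broadcast-then-slice reduces to taking a contiguous head-aligned slice of the full-dimension weight. A hypothetical plain-Python sketch (`shard_qk_norm` is not part of the bridge API; it only shows the indexing):

```python
def shard_qk_norm(full_weight, head_dim, tp_size, tp_rank):
    # Every rank receives the full [num_heads * head_dim] weight
    # (broadcast), then keeps only the slice covering its own heads,
    # avoiding an NCCL scatter that would fail on CPU / meta devices.
    num_heads = len(full_weight) // head_dim
    per_rank = (num_heads // tp_size) * head_dim
    start = tp_rank * per_rank
    return full_weight[start:start + per_rank]

full = list(range(8))  # num_heads=4, head_dim=2
rank0 = shard_qk_norm(full, head_dim=2, tp_size=2, tp_rank=0)
rank1 = shard_qk_norm(full, head_dim=2, tp_size=2, tp_rank=1)
```

Concatenating the per-rank slices in rank order reconstructs the original HF weight.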

hf_to_megatron(
hf_weights: torch.Tensor,
megatron_module: torch.nn.Module,
) → torch.Tensor#
megatron_to_hf(
megatron_weights: Optional[torch.Tensor],
megatron_module: Optional[torch.nn.Module],
) → Dict[str, torch.Tensor]#
class bridge.models.minimax_m2.minimax_m2_bridge.MiniMaxM2Bridge#

Bases: megatron.bridge.models.conversion.model_bridge.MegatronModelBridge

Megatron Bridge for MiniMax-M2 MoE Causal LM.

MiniMax-M2 is a sparse MoE model (256 experts, top-8 routing with sigmoid scoring and expert bias correction). It uses the native transformers >= 5.0 implementation, so no trust_remote_code is required.

On-disk checkpoint format (both the HF hub checkpoint and models saved with save_pretrained) uses the legacy block_sparse_moe key prefix with per-expert w1 (gate), w3 (up), and w2 (down) weight tensors. The in-memory model API uses mlp / gate_up_proj / down_proj but serialization reverts to the legacy layout.
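For illustration, a hedged sketch of how a legacy on-disk expert key might be classified. The exact key pattern and the `parse_legacy_expert_key` helper are assumptions for this example, not the bridge's actual mapping code:

```python
import re

# Assumed legacy key layout: per-layer, per-expert w1/w2/w3 tensors
# under the block_sparse_moe prefix.
_LEGACY_EXPERT_RE = re.compile(
    r"model\.layers\.(\d+)\.block_sparse_moe\."
    r"experts\.(\d+)\.(w1|w2|w3)\.weight"
)

def parse_legacy_expert_key(key):
    # Map the legacy per-expert tensor names to their roles:
    # w1 -> gate projection, w3 -> up projection, w2 -> down projection.
    m = _LEGACY_EXPERT_RE.fullmatch(key)
    if m is None:
        return None
    role = {"w1": "gate", "w3": "up", "w2": "down"}[m.group(3)]
    return int(m.group(1)), int(m.group(2)), role
```

The in-memory `mlp` / `gate_up_proj` / `down_proj` names never appear on disk, which is why the bridge must target the legacy layout for both loading and saving.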

QK normalization: MiniMax-M2 applies full-dimension RMSNorm to Q/K (weight shape = num_heads * head_dim) before splitting into heads. Megatron’s built-in QK norm is per-head (weight shape = head_dim). This bridge uses a custom layer spec (minimax_m2_layer_spec) with FullDimQNorm/FullDimKNorm that normalizes over the full partition dimension. With TP > 1 the sum-of-squares is all-reduced across TP ranks so the RMS denominator matches the single-GPU case.
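The TP-invariance claim can be checked numerically: summing the per-rank sums of squares (the all-reduce step) yields the same RMS denominator as computing over the unsharded vector. A small sketch with a hypothetical `rms_denominator` helper:

```python
import math

def rms_denominator(shards, eps=1e-6):
    # Each TP rank holds one shard of the full-dimension vector.
    # Summing the per-rank sums of squares (what the all-reduce does)
    # and the per-rank lengths reproduces the single-GPU denominator.
    total_sq = sum(sum(x * x for x in shard) for shard in shards)
    n = sum(len(shard) for shard in shards)
    return math.sqrt(total_sq / n + eps)

vec = [3.0, -4.0, 12.0, 0.0]
single_gpu = rms_denominator([vec])        # unsharded
tp2 = rms_denominator([vec[:2], vec[2:]])  # sharded across 2 "ranks"
```

Without the all-reduce, each rank would normalize by its local partial sum and the result would diverge from the single-GPU case whenever TP > 1.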

Known limitations:

- MTP (Multi-Token Prediction) modules are not mapped.

Example:

```python
from megatron.bridge import AutoBridge

bridge = AutoBridge.from_hf_pretrained("MiniMaxAI/MiniMax-M2")
provider = bridge.to_megatron_provider()
```

provider_bridge(hf_pretrained)#

Convert HuggingFace MiniMax-M2 config to GPTModelProvider.

maybe_modify_loaded_hf_weight(
hf_param: str | dict[str, str],
hf_state_dict: collections.abc.Mapping[str, torch.Tensor],
) → torch.Tensor#

Load HF weights with FP8 block-wise dequantization when needed.

MiniMax-M2 stores linear weights as float8_e4m3fn with per-block scale factors in <key>_scale_inv tensors (128x128 blocks).
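The dispatch this describes can be sketched as follows. This is a simplified stand-in for the actual method: `load_and_maybe_dequant` is hypothetical, and scalars stand in for tensors:

```python
def load_and_maybe_dequant(key, state_dict, dequant_fn):
    # If a companion '<key>_scale_inv' tensor exists, the weight is
    # stored as FP8 and must be dequantized; otherwise pass it through.
    weight = state_dict[key]
    scale_key = key + "_scale_inv"
    if scale_key in state_dict:
        return dequant_fn(weight, state_dict[scale_key])
    return weight

# dequant_fn stands in for a block-wise dequantization routine.
sd = {"w.weight": 8.0, "w.weight_scale_inv": 0.25, "v.weight": 3.0}
deq = load_and_maybe_dequant("w.weight", sd, lambda w, s: w * s)
raw = load_and_maybe_dequant("v.weight", sd, lambda w, s: w * s)
```

The presence of the `_scale_inv` companion key is what signals FP8 storage; non-quantized tensors are returned unchanged.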

_load_and_dequant(
key: str,
hf_state_dict: collections.abc.Mapping[str, torch.Tensor],
) → torch.Tensor#
mapping_registry() → megatron.bridge.models.conversion.mapping_registry.MegatronMappingRegistry#