bridge.models.minimax_m2.minimax_m2_bridge#
Module Contents#
Classes#
_FullDimQKNormMapping — TP-sharded mapping for full-dimension QK norm weights.
MiniMaxM2Bridge — Megatron Bridge for MiniMax-M2 MoE Causal LM.
Functions#
_dequant_fp8_blockwise — Block-wise FP8 dequantization: out = fp8_val * scale_inv per 128x128 block.
Data#
API#
- bridge.models.minimax_m2.minimax_m2_bridge._FP8_BLOCK_SIZE#
128
- bridge.models.minimax_m2.minimax_m2_bridge._dequant_fp8_blockwise(weight: torch.Tensor, scale_inv: torch.Tensor)#
Block-wise FP8 dequantization: out = fp8_val * scale_inv per 128x128 block.
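A minimal sketch of the block-wise dequantization described above. The function name and exact broadcasting strategy are illustrative, not the module's implementation; real checkpoints store `weight` as `float8_e4m3fn`, while the sketch works with any float dtype.

```python
import torch

_FP8_BLOCK_SIZE = 128  # per the module-level constant above


def dequant_fp8_blockwise(weight: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
    """out = fp8_val * scale_inv, one scale per 128x128 block.

    weight:    [M, N] quantized values (float8_e4m3fn in real checkpoints)
    scale_inv: [ceil(M/128), ceil(N/128)] per-block inverse scale factors
    """
    m, n = weight.shape
    # Expand each block scale to cover its 128x128 tile, then crop to [M, N]
    # for edge blocks that are smaller than a full tile.
    scales = scale_inv.repeat_interleave(_FP8_BLOCK_SIZE, dim=0)
    scales = scales.repeat_interleave(_FP8_BLOCK_SIZE, dim=1)[:m, :n]
    return weight.float() * scales
```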
- class bridge.models.minimax_m2.minimax_m2_bridge._FullDimQKNormMapping#
Bases: megatron.bridge.models.conversion.param_mapping.MegatronParamMapping[torch.Tensor]

TP-sharded mapping for full-dimension QK norm weights.

HF weight shape: [num_heads * head_dim]
Megatron weight shape per rank: [num_heads_per_partition * head_dim]

Uses broadcast-then-slice instead of scatter_to_tp_ranks because the _FullDimRMSNorm module may reside on a CPU / meta device where NCCL scatter is not available.

- hf_to_megatron(hf_weights: torch.Tensor, megatron_module: torch.nn.Module)#
- megatron_to_hf(megatron_weights: Optional[torch.Tensor], megatron_module: Optional[torch.nn.Module])#
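The broadcast-then-slice idea can be sketched as follows. This is an illustrative single-process stand-in (function name assumed, distributed broadcast omitted): every rank holds the full HF weight and slices out its own contiguous partition, which works even when the target module lives on a CPU or meta device where an NCCL scatter would fail.

```python
import torch


def shard_full_dim_qk_norm(hf_weight: torch.Tensor,
                           tp_rank: int,
                           tp_size: int) -> torch.Tensor:
    """Slice a full [num_heads * head_dim] QK-norm weight down to this
    rank's [num_heads_per_partition * head_dim] shard.

    In the real bridge the full tensor is first broadcast to all TP ranks;
    each rank then takes its contiguous slice locally (broadcast-then-slice)
    rather than relying on an NCCL scatter.
    """
    per_rank = hf_weight.numel() // tp_size
    return hf_weight[tp_rank * per_rank:(tp_rank + 1) * per_rank].clone()
```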
- class bridge.models.minimax_m2.minimax_m2_bridge.MiniMaxM2Bridge#
Bases: megatron.bridge.models.conversion.model_bridge.MegatronModelBridge

Megatron Bridge for MiniMax-M2 MoE Causal LM.

MiniMax-M2 is a sparse MoE model (256 experts, top-8 routing with sigmoid scoring and expert bias correction). Use the native transformers >= 5.0 implementation (no trust_remote_code required).

On-disk checkpoint format (both the HF hub checkpoint and models saved with save_pretrained) uses the legacy block_sparse_moe key prefix with per-expert w1 (gate), w3 (up), and w2 (down) weight tensors. The in-memory model API uses mlp / gate_up_proj / down_proj, but serialization reverts to the legacy layout.

QK normalization: MiniMax-M2 applies full-dimension RMSNorm to Q/K (weight shape = num_heads * head_dim) before splitting into heads. Megatron's built-in QK norm is per-head (weight shape = head_dim). This bridge uses a custom layer spec (minimax_m2_layer_spec) with FullDimQNorm / FullDimKNorm that normalizes over the full partition dimension. With TP > 1 the sum-of-squares is all-reduced across TP ranks so the RMS denominator matches the single-GPU case.

Known limitations:
- MTP (Multi-Token Prediction) modules are not mapped.
.. rubric:: Example

from megatron.bridge import AutoBridge
bridge = AutoBridge.from_hf_pretrained("MiniMaxAI/MiniMax-M2")
provider = bridge.to_megatron_provider()
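The full-dimension QK normalization described above can be sketched as a single-GPU RMSNorm over the whole Q (or K) projection width. This is an assumption-laden illustration, not the bridge's FullDimQNorm/FullDimKNorm implementation; in particular, the TP > 1 all-reduce of the sum of squares is omitted and noted in a comment.

```python
import torch


def full_dim_rms_norm(x: torch.Tensor, weight: torch.Tensor,
                      eps: float = 1e-6) -> torch.Tensor:
    """RMSNorm over the FULL last dimension [num_heads * head_dim],
    not per head as in Megatron's built-in QK norm.

    With TP > 1 the real layer all-reduces the per-rank sum of squares
    across TP ranks before forming the denominator, so the RMS matches
    the single-GPU result (not shown in this single-process sketch).
    """
    var = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(var + eps) * weight
```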
- provider_bridge(hf_pretrained)#
Convert HuggingFace MiniMax-M2 config to GPTModelProvider.
- maybe_modify_loaded_hf_weight(
- hf_param: str | dict[str, str],
- hf_state_dict: collections.abc.Mapping[str, torch.Tensor],
Load HF weights with FP8 block-wise dequantization when needed.
MiniMax-M2 stores linear weights as float8_e4m3fn with per-block scale factors in
<key>_scale_invtensors (128x128 blocks).
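A minimal sketch of the load step this docstring describes: look up the weight, and if a companion `<key>_scale_inv` tensor exists in the state dict, dequantize before returning. The helper name and the simplified exact-tiling assumption are illustrative only.

```python
import torch


def load_and_maybe_dequant(key: str, hf_state_dict) -> torch.Tensor:
    """Fetch a weight; dequantize if a per-block inverse-scale companion
    tensor is stored under '<key>_scale_inv' (128x128 blocks).

    Simplified sketch: assumes the weight dims are exact multiples of 128.
    """
    w = hf_state_dict[key]
    scale_key = f"{key}_scale_inv"
    if scale_key not in hf_state_dict:
        return w  # not FP8-quantized; pass through unchanged
    s = hf_state_dict[scale_key]
    # Expand one scale per 128x128 block to the full weight shape.
    s = s.repeat_interleave(128, dim=0)[: w.shape[0]]
    s = s.repeat_interleave(128, dim=1)[:, : w.shape[1]]
    return w.float() * s
```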
- _load_and_dequant(key: str, hf_state_dict: collections.abc.Mapping[str, torch.Tensor])#
- mapping_registry() -> megatron.bridge.models.conversion.mapping_registry.MegatronMappingRegistry#