bridge.models.mimo_v2_flash.mimo_v2_flash_bridge
Megatron Bridge for MiMo-V2-Flash (Hybrid Attention + Fine-Grained MoE).
MiMo-V2-Flash, from Xiaomi, has the following architectural features:
- Hybrid attention: alternating full-attention and sliding-window attention (SWA) layers
- Fine-grained MoE: 256 small experts with top-8 routing
- Asymmetric head dims: head_dim=192 for Q/K, v_head_dim=128 for V
- Partial rotary: only 33.4% of head dims receive RoPE
- Dual RoPE bases: 5M for full-attention layers, 10K for SWA layers
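The partial-rotary and dual-base settings interact as sketched below. This is a plain-Python illustration, not this module's code, and it assumes the common partial-rotary convention: the rotated slice is the leading dims of each Q/K head, rounded down to an even count (192 × 0.334 ≈ 64.1, so 64 dims), with the "rotate half" pairing of dim i and dim i + d/2. The names `rotary_dim`, `inv_freq`, and `apply_partial_rope` are hypothetical.

```python
import math
from typing import List

HEAD_DIM = 192           # Q/K head dim (from the description above)
PARTIAL_ROTARY = 0.334   # fraction of head dims that receive RoPE
ROPE_BASE = {"full": 5_000_000, "swa": 10_000}  # base per layer type

def rotary_dim(head_dim: int = HEAD_DIM, fraction: float = PARTIAL_ROTARY) -> int:
    """Number of rotated dims, rounded down to an even count (assumption)."""
    d = int(head_dim * fraction)
    return d - d % 2

def inv_freq(base: float, dim: int) -> List[float]:
    """Standard RoPE inverse frequencies 1 / base**(2i / dim), i in [0, dim/2)."""
    return [1.0 / base ** (2 * i / dim) for i in range(dim // 2)]

def apply_partial_rope(x: List[float], pos: int, layer_type: str) -> List[float]:
    """Rotate the leading rotary_dim entries of one head vector; copy the rest unchanged."""
    d = rotary_dim(len(x))
    half = d // 2
    out = list(x)
    for i, f in enumerate(inv_freq(ROPE_BASE[layer_type], d)):
        c, s = math.cos(pos * f), math.sin(pos * f)
        a, b = x[i], x[i + half]  # "rotate half" pairing: dim i with dim i + d/2
        out[i] = a * c - b * s
        out[i + half] = a * s + b * c
    return out

print(rotary_dim())  # 64 of 192 dims are rotated
```

Because both layer types share the same rotary dim, the highest frequency (index 0) is identical for both; the two bases differ only in how quickly the frequencies decay toward the low end.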
Module Contents
Classes
| Class | Description |
|---|---|
| MiMoV2FlashQKVMapping | QKV mapping for MiMo-V2-Flash asymmetric head dims. |
| MiMoV2FlashBridge | Megatron Bridge for MiMo-V2-Flash. |
Functions
| Function | Description |
|---|---|
| _dequant_fp8_blockwise | Block-wise FP8 dequantization: out = fp8_val * scale_inv. |
API
- class bridge.models.mimo_v2_flash.mimo_v2_flash_bridge.MiMoV2FlashQKVMapping
Bases: megatron.bridge.models.conversion.param_mapping.QKVMapping
QKV mapping for MiMo-V2-Flash asymmetric head dims.
MiMo-V2-Flash uses head_dim=192 for Q/K but v_head_dim=128 for V. The standard merge_qkv_weights uses kv_channels (192) as the per-head size for all three projections, which produces a shape mismatch for V. This mapping therefore temporarily patches v_head_dim onto the config before merging.
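The V shape problem is easiest to see by building a fused QKV weight by hand. The sketch below assumes Megatron's usual fused layout, in which rows are grouped per query group as [Q heads of that group | K | V]. Weights are plain lists of rows so it runs without PyTorch, and `merge_qkv_rows` is a hypothetical helper, not the bridge's merge_qkv_weights. The point is that each group's V slice has v_head_dim rows, not head_dim rows.

```python
from typing import List, Sequence, TypeVar

Row = TypeVar("Row")

def merge_qkv_rows(
    q: Sequence[Row], k: Sequence[Row], v: Sequence[Row],
    num_heads: int, num_groups: int, head_dim: int, v_head_dim: int,
) -> List[Row]:
    """Interleave Q, K, V rows into a per-query-group fused layout (hypothetical helper)."""
    heads_per_group = num_heads // num_groups
    q_rows = heads_per_group * head_dim  # Q rows owned by one query group
    # Validate row counts; V is sized with v_head_dim, not head_dim.
    if len(q) != num_heads * head_dim:
        raise ValueError("Q has the wrong number of rows")
    if len(k) != num_groups * head_dim:
        raise ValueError("K has the wrong number of rows")
    if len(v) != num_groups * v_head_dim:
        raise ValueError("V has the wrong number of rows")
    fused: List[Row] = []
    for g in range(num_groups):
        fused.extend(q[g * q_rows:(g + 1) * q_rows])
        fused.extend(k[g * head_dim:(g + 1) * head_dim])
        # Slicing V with head_dim here would misalign every group's V block.
        fused.extend(v[g * v_head_dim:(g + 1) * v_head_dim])
    return fused

# Toy sizes: 4 heads, 2 query groups, head_dim=3, v_head_dim=2 (rows are labels).
q = [f"q{i}" for i in range(12)]
k = [f"k{i}" for i in range(6)]
v = [f"v{i}" for i in range(4)]
fused = merge_qkv_rows(q, k, v, num_heads=4, num_groups=2, head_dim=3, v_head_dim=2)
print(len(fused))  # 22 = 2 groups x (6 + 3 + 2)
```

Passing v_head_dim=3 (that is, reusing head_dim for V) makes the V row check fail, which mirrors the mismatch the patch avoids.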
- hf_to_megatron(hf_weights: Dict[str, torch.Tensor], megatron_module: torch.nn.Module)
- megatron_to_hf(megatron_weights: Optional[torch.Tensor], megatron_module: Optional[torch.nn.Module])
Gather the QKV shards and split them into Q, K, and V.
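The reverse direction can be sketched the same way. Assuming each tensor-parallel rank holds a contiguous run of whole query groups (the usual column-parallel split of Megatron's fused layout), gathering is a concatenation along the output (row) dimension, after which each group of (heads per group × head_dim) + head_dim + v_head_dim rows is cut back into Q, K, and V. `gather_and_split_qkv` is a hypothetical helper operating on plain lists.

```python
from typing import List, Sequence, Tuple, TypeVar

Row = TypeVar("Row")

def gather_and_split_qkv(
    shards: Sequence[Sequence[Row]],
    num_heads: int, num_groups: int, head_dim: int, v_head_dim: int,
) -> Tuple[List[Row], List[Row], List[Row]]:
    """Concatenate per-rank fused QKV shards and split them into Q, K, V (hypothetical helper)."""
    # "Gather": concatenate shards along the output (row) dimension, in rank order.
    fused = [row for shard in shards for row in shard]
    q_rows = (num_heads // num_groups) * head_dim
    group_rows = q_rows + head_dim + v_head_dim  # asymmetric: V contributes v_head_dim rows
    if len(fused) != num_groups * group_rows:
        raise ValueError(f"expected {num_groups * group_rows} rows, got {len(fused)}")
    q: List[Row] = []
    k: List[Row] = []
    v: List[Row] = []
    for g in range(num_groups):
        block = fused[g * group_rows:(g + 1) * group_rows]
        q.extend(block[:q_rows])
        k.extend(block[q_rows:q_rows + head_dim])
        v.extend(block[q_rows + head_dim:])
    return q, k, v

# Two ranks, each holding one whole query group (4 heads, 2 groups, head_dim=3, v_head_dim=2).
rank0 = ["q0", "q1", "q2", "q3", "q4", "q5", "k0", "k1", "k2", "v0", "v1"]
rank1 = ["q6", "q7", "q8", "q9", "q10", "q11", "k3", "k4", "k5", "v2", "v3"]
q, k, v = gather_and_split_qkv([rank0, rank1], num_heads=4, num_groups=2, head_dim=3, v_head_dim=2)
print(v)  # ['v0', 'v1', 'v2', 'v3']
```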
- bridge.models.mimo_v2_flash.mimo_v2_flash_bridge._dequant_fp8_blockwise(weight: torch.Tensor, scale_inv: torch.Tensor)
Block-wise FP8 dequantization: out = fp8_val * scale_inv, where every element in a block of the weight is multiplied by that block's single scale_inv entry.
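The formula amounts to broadcasting one scale per block over that block's elements. The sketch below assumes a default 128×128 block size (a common choice for block-scaled FP8 checkpoints, but an assumption here) and a scale grid of ceil(rows/bm) × ceil(cols/bn), so edge blocks may be partial. Pure Python has no FP8 type, so the weight is given as already-decoded floats; `dequant_blockwise` is a hypothetical stand-in for `_dequant_fp8_blockwise`.

```python
import math
from typing import List, Sequence, Tuple

def dequant_blockwise(
    weight: Sequence[Sequence[float]],
    scale_inv: Sequence[Sequence[float]],
    block: Tuple[int, int] = (128, 128),  # assumed block size
) -> List[List[float]]:
    """out[i][j] = weight[i][j] * scale_inv[i // bm][j // bn] (hypothetical sketch)."""
    rows, cols = len(weight), len(weight[0])
    bm, bn = block
    # One scale per (bm x bn) tile; the last tile in each dimension may be partial.
    grid = (math.ceil(rows / bm), math.ceil(cols / bn))
    if len(scale_inv) != grid[0] or any(len(r) != grid[1] for r in scale_inv):
        raise ValueError(f"scale_inv must have shape {grid}")
    return [
        [weight[i][j] * scale_inv[i // bm][j // bn] for j in range(cols)]
        for i in range(rows)
    ]

# 3x3 weight with 2x2 blocks -> a 2x2 grid of scales; edge blocks are partial.
w = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
s = [[0.5, 2.0], [10.0, 1.0]]
print(dequant_blockwise(w, s, block=(2, 2)))
# [[0.5, 1.0, 6.0], [2.0, 2.5, 12.0], [70.0, 80.0, 9.0]]
```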
- class bridge.models.mimo_v2_flash.mimo_v2_flash_bridge.MiMoV2FlashBridge
Bases: megatron.bridge.models.conversion.model_bridge.MegatronModelBridge
Megatron Bridge for MiMo-V2-Flash.
- provider_bridge(hf_pretrained: megatron.bridge.models.hf_pretrained.causal_lm.PreTrainedCausalLM)
Convert a HuggingFace MiMo-V2-Flash config to a MiMoV2FlashModelProvider.
- classmethod megatron_to_hf_config(provider: megatron.bridge.models.mimo_v2_flash.mimo_v2_flash_provider.MiMoV2FlashModelProvider)
Convert a Megatron provider config to a HuggingFace config dict.
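For fields that map one-to-one, both config methods reduce to a key translation. In the minimal sketch below, the Megatron-side names num_layers, hidden_size, num_attention_heads, kv_channels, num_moe_experts, and moe_router_topk are Megatron-Core TransformerConfig fields. Everything else is an assumption: all HuggingFace-side key names and the provider's v_head_dim field are hypothetical and may not match the real MiMo-V2-Flash config. Keeping the map one-to-one makes the reverse direction a simple inversion.

```python
from typing import Any, Dict

# Hypothetical HF key -> Megatron provider field map (one-to-one so it can be inverted).
HF_TO_MEGATRON: Dict[str, str] = {
    "num_hidden_layers": "num_layers",
    "hidden_size": "hidden_size",
    "num_attention_heads": "num_attention_heads",
    "head_dim": "kv_channels",           # 192 for Q/K
    "v_head_dim": "v_head_dim",          # 128 for V (hypothetical provider field)
    "n_routed_experts": "num_moe_experts",
    "num_experts_per_tok": "moe_router_topk",
}
MEGATRON_TO_HF: Dict[str, str] = {mg: hf for hf, mg in HF_TO_MEGATRON.items()}

def hf_to_provider_kwargs(hf_config: Dict[str, Any]) -> Dict[str, Any]:
    """HF config dict -> provider kwargs; keys outside the map are ignored."""
    return {HF_TO_MEGATRON[key]: val for key, val in hf_config.items() if key in HF_TO_MEGATRON}

def provider_kwargs_to_hf(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Provider kwargs -> HF config dict (the inverse mapping)."""
    return {MEGATRON_TO_HF[key]: val for key, val in kwargs.items() if key in MEGATRON_TO_HF}

hf_cfg = {"head_dim": 192, "v_head_dim": 128, "n_routed_experts": 256, "num_experts_per_tok": 8}
kwargs = hf_to_provider_kwargs(hf_cfg)
print(kwargs["kv_channels"], kwargs["moe_router_topk"])  # 192 8
```

Settings that do not map one-to-one, such as the two RoPE bases or the per-layer attention pattern, would need explicit conversion logic rather than a rename.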
- mapping_registry() -> megatron.bridge.models.conversion.mapping_registry.MegatronMappingRegistry
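As its name suggests, the returned registry holds the parameter mappings between the HuggingFace and Megatron weights. The following is a rough illustration of how wildcard-based name mappings can work. It is not the MegatronMappingRegistry API: the pattern strings, the `MAPPINGS` list, and the `translate` and `megatron_to_hf_name` helpers are all hypothetical, with `*` standing for an integer layer or expert index.

```python
import re
from typing import List, Optional, Tuple

# Hypothetical (megatron_pattern, hf_pattern) pairs; "*" stands for an integer index.
MAPPINGS: List[Tuple[str, str]] = [
    ("decoder.layers.*.mlp.experts.local_experts.*.linear_fc2.weight",
     "model.layers.*.mlp.experts.*.down_proj.weight"),
    ("decoder.layers.*.self_attention.linear_proj.weight",
     "model.layers.*.self_attn.o_proj.weight"),
]

def translate(name: str, src: str, dst: str) -> Optional[str]:
    """If name matches src, return dst with the captured indices substituted in order."""
    regex = "^" + re.escape(src).replace(r"\*", r"(\d+)") + "$"
    match = re.match(regex, name)
    if match is None:
        return None
    out = dst
    for index in match.groups():
        out = out.replace("*", index, 1)  # fill wildcards left to right
    return out

def megatron_to_hf_name(name: str) -> Optional[str]:
    """Try each mapping in order and return the first translation that matches."""
    for mg_pattern, hf_pattern in MAPPINGS:
        hf_name = translate(name, mg_pattern, hf_pattern)
        if hf_name is not None:
            return hf_name
    return None

print(megatron_to_hf_name("decoder.layers.3.mlp.experts.local_experts.17.linear_fc2.weight"))
# model.layers.3.mlp.experts.17.down_proj.weight
```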
- maybe_modify_loaded_hf_weight(hf_param: str | dict[str, str], hf_state_dict: Mapping[str, torch.Tensor])
Dequantize FP8 weights during import.
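A sketch of such an import hook rests on two assumptions. First, each FP8 tensor stored under `key` has its scales stored under `key + "_scale_inv"`; this is a convention used by some block-scaled FP8 checkpoints and is not confirmed for this model. Second, a dict-valued hf_param maps logical names to HF keys, each of which is loaded and dequantized independently. `maybe_dequant` is hypothetical, and a scalar multiply stands in for the block-wise dequant function.

```python
from typing import Any, Callable, Dict, Mapping, Union

def maybe_dequant(
    hf_param: Union[str, Dict[str, str]],
    hf_state_dict: Mapping[str, Any],
    dequant: Callable[[Any, Any], Any],
) -> Union[Any, Dict[str, Any]]:
    """Return the loaded weight(s), dequantizing any key that has a companion scale tensor."""

    def load(key: str) -> Any:
        scale_key = key + "_scale_inv"  # assumed naming convention for FP8 scales
        weight = hf_state_dict[key]
        if scale_key in hf_state_dict:
            return dequant(weight, hf_state_dict[scale_key])
        return weight  # not quantized: pass through unchanged

    if isinstance(hf_param, str):
        return load(hf_param)
    return {name: load(key) for name, key in hf_param.items()}

def scalar_dequant(weight: float, scale_inv: float) -> float:
    """Stand-in for a real block-wise dequant."""
    return weight * scale_inv

state = {"a.weight": 2.0, "a.weight_scale_inv": 3.0, "b.weight": 5.0}
print(maybe_dequant("a.weight", state, scalar_dequant))                           # 6.0
print(maybe_dequant({"q": "a.weight", "k": "b.weight"}, state, scalar_dequant))  # {'q': 6.0, 'k': 5.0}
```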
- _load_and_dequant(
- key: str,
- hf_state_dict: Mapping[str, torch.Tensor],