bridge.models.mimo_v2_flash.mimo_v2_flash_bridge

Megatron Bridge for MiMo-V2-Flash (Hybrid Attention + Fine-Grained MoE).

Key architectural features of Xiaomi's MiMo-V2-Flash:

  • Hybrid attention: alternating full and sliding-window attention layers

  • Fine-grained MoE: 256 small experts with top-8 routing

  • Asymmetric head dims: head_dim=192 for Q/K, v_head_dim=128 for V

  • Partial rotary embedding: RoPE is applied to only 33.4% of each head's dimensions

  • Dual RoPE bases: 5M for full-attention layers and 10K for sliding-window attention (SWA) layers
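The head-dimension and RoPE figures above can be made concrete with a little arithmetic. This is a sketch under stated assumptions:

- The `AttentionSpec` class, its field names, and the `layer_kinds` helper are hypothetical. They are not attributes of `MiMoV2FlashModelProvider`.
- Rounding the rotary width down with `int()` is an assumed convention.
- The full/SWA schedule is passed in as a parameter rather than asserted, because the real per-layer pattern is not given here.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class AttentionSpec:
    # Hypothetical container; field names are illustrative only.
    head_dim: int = 192                  # Q/K head dimension
    v_head_dim: int = 128                # V head dimension
    partial_rotary_factor: float = 0.334
    rope_base_full: float = 5_000_000.0  # full-attention layers
    rope_base_swa: float = 10_000.0      # sliding-window layers

    def rotary_dim(self) -> int:
        # Number of Q/K channels per head that receive RoPE (rounded down).
        return int(self.head_dim * self.partial_rotary_factor)

    def rope_base(self, is_swa: bool) -> float:
        # Each attention type uses its own RoPE base.
        return self.rope_base_swa if is_swa else self.rope_base_full

    def proj_sizes(self, num_heads: int, num_kv_heads: int) -> Tuple[int, int, int]:
        # Output widths of the Q, K and V projections in one layer.
        return (
            num_heads * self.head_dim,
            num_kv_heads * self.head_dim,
            num_kv_heads * self.v_head_dim,
        )


def layer_kinds(num_layers: int, swa_pattern: List[bool]) -> List[str]:
    # swa_pattern is a hypothetical repeating schedule; True marks an SWA layer.
    return ["swa" if swa_pattern[i % len(swa_pattern)] else "full" for i in range(num_layers)]
```

Two worked values follow from this:

- `AttentionSpec().rotary_dim()` gives `int(192 * 0.334) = 64` rotary channels per Q/K head.
- With a hypothetical 16 query heads and 4 KV heads, `proj_sizes(16, 4)` gives `(3072, 768, 512)`. The V projection is narrower than K because `v_head_dim < head_dim`.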

Module Contents

Classes

MiMoV2FlashQKVMapping

QKV mapping for MiMo-V2-Flash asymmetric head dims.

MiMoV2FlashBridge

Megatron Bridge for MiMo-V2-Flash.

Functions

_dequant_fp8_blockwise

Block-wise FP8 dequantization: out = fp8_val * scale_inv.

API

class bridge.models.mimo_v2_flash.mimo_v2_flash_bridge.MiMoV2FlashQKVMapping

Bases: megatron.bridge.models.conversion.param_mapping.QKVMapping

QKV mapping for MiMo-V2-Flash asymmetric head dims.

MiMo-V2-Flash uses head_dim=192 for Q/K but v_head_dim=128 for V. The standard merge_qkv_weights sizes all three projections with kv_channels (192), which gives V the wrong shape. To avoid this, the mapping temporarily patches v_head_dim onto the config before merging.
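The mismatch is easiest to see in the fused layout. This sketch makes three assumptions:

- It follows Megatron's grouped QKV layout, in which the fused linear_qkv rows are ordered per query group as [that group's Q heads, K, V].
- Rows are modelled as plain list entries rather than tensor rows.
- `merge_qkv_rows` is a hypothetical helper, not Megatron's merge_qkv_weights.

```python
from typing import List, Sequence


def merge_qkv_rows(q: Sequence, k: Sequence, v: Sequence,
                   num_heads: int, num_groups: int,
                   head_dim: int, v_head_dim: int) -> List:
    """Interleave Q, K, V rows group by group (sketch of the fused layout)."""
    heads_per_group = num_heads // num_groups
    expected = (num_heads * head_dim, num_groups * head_dim, num_groups * v_head_dim)
    if (len(q), len(k), len(v)) != expected:
        raise ValueError(f"got {(len(q), len(k), len(v))} rows, expected {expected}")
    q_block = heads_per_group * head_dim  # Q rows owned by one query group
    fused: List = []
    for g in range(num_groups):
        fused.extend(q[g * q_block:(g + 1) * q_block])
        fused.extend(k[g * head_dim:(g + 1) * head_dim])
        # V is sliced with v_head_dim; reusing head_dim here is the shape bug.
        fused.extend(v[g * v_head_dim:(g + 1) * v_head_dim])
    return fused
```

With real MiMo-V2-Flash sizes, passing v_head_dim=192 (that is, reusing kv_channels for V) raises the ValueError. The HF V projection has only num_groups * 128 rows, and that gap is exactly what patching v_head_dim avoids.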

hf_to_megatron(
hf_weights: Dict[str, torch.Tensor],
megatron_module: torch.nn.Module,
)

megatron_to_hf(
megatron_weights: Optional[torch.Tensor],
megatron_module: Optional[torch.nn.Module],
) → Dict[str, torch.Tensor]

Gather QKV shards and split into Q, K, V.
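Splitting is the inverse walk over the same assumed per-group [Q, K, V] layout. In this sketch:

- `split_qkv_rows` is hypothetical, and rows are list entries rather than tensors.
- `fused` stands for the full linear_qkv output dimension after the tensor-parallel shards have been gathered.
- Each query group occupies heads_per_group * head_dim + head_dim + v_head_dim rows.

```python
from typing import List, Sequence, Tuple


def split_qkv_rows(fused: Sequence, num_heads: int, num_groups: int,
                   head_dim: int, v_head_dim: int) -> Tuple[List, List, List]:
    """Undo the per-group [Q, K, V] interleave (sketch)."""
    heads_per_group = num_heads // num_groups
    q_block = heads_per_group * head_dim
    group_rows = q_block + head_dim + v_head_dim  # asymmetric: V uses v_head_dim
    if len(fused) != num_groups * group_rows:
        raise ValueError(f"expected {num_groups * group_rows} rows, got {len(fused)}")
    q: List = []
    k: List = []
    v: List = []
    for g in range(num_groups):
        block = fused[g * group_rows:(g + 1) * group_rows]
        q.extend(block[:q_block])
        k.extend(block[q_block:q_block + head_dim])
        v.extend(block[q_block + head_dim:])
    return q, k, v
```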

bridge.models.mimo_v2_flash.mimo_v2_flash_bridge._dequant_fp8_blockwise(
weight: torch.Tensor,
scale_inv: torch.Tensor,
) → torch.Tensor

Block-wise FP8 dequantization: out = fp8_val * scale_inv, where each entry of scale_inv scales one block of weight.
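The formula can be read index by index: element (i, j) is multiplied by the scale of the block containing it. This sketch makes three assumptions:

- It uses nested Python lists instead of FP8 tensors.
- The default block_size=128 is assumed (a common choice for block-FP8 checkpoints), not read from this model's config.
- Partial blocks at the edges are assumed to get their own scale entry, hence the ceiling division.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def dequant_fp8_blockwise(weight: Sequence[Sequence[float]],
                          scale_inv: Sequence[Sequence[float]],
                          block_size: int = 128) -> Matrix:
    """out[i][j] = weight[i][j] * scale_inv[i // block_size][j // block_size]."""
    rows = len(weight)
    cols = len(weight[0]) if rows else 0
    # One scale per (possibly partial) block along each dimension.
    expected = (math.ceil(rows / block_size), math.ceil(cols / block_size))
    got = (len(scale_inv), len(scale_inv[0]) if scale_inv else 0)
    if got != expected:
        raise ValueError(f"scale_inv shape {got}, expected {expected}")
    return [
        [float(weight[i][j]) * scale_inv[i // block_size][j // block_size] for j in range(cols)]
        for i in range(rows)
    ]
```

In PyTorch the same broadcast is usually written in four steps:

1. Expand scale_inv with `repeat_interleave(block_size, dim=0)`.
2. Expand it again with `repeat_interleave(block_size, dim=1)`.
3. Crop the result to `weight.shape`.
4. Multiply after casting the FP8 weight to a wider dtype such as float32 or bfloat16, since most arithmetic is not implemented directly on float8 tensors.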

class bridge.models.mimo_v2_flash.mimo_v2_flash_bridge.MiMoV2FlashBridge

Bases: megatron.bridge.models.conversion.model_bridge.MegatronModelBridge

Megatron Bridge for MiMo-V2-Flash.

provider_bridge(
hf_pretrained: megatron.bridge.models.hf_pretrained.causal_lm.PreTrainedCausalLM,
) → megatron.bridge.models.mimo_v2_flash.mimo_v2_flash_provider.MiMoV2FlashModelProvider

Convert HuggingFace MiMo-V2-Flash config to MiMoV2FlashModelProvider.

classmethod megatron_to_hf_config(
provider: megatron.bridge.models.mimo_v2_flash.mimo_v2_flash_provider.MiMoV2FlashModelProvider,
) → dict

Convert Megatron provider config to HuggingFace config dict.
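At their core, both config conversions are key renames between the two naming schemes. This sketch uses one hypothetical key table for import and inverts it for export. The HF-side keys follow common Hugging Face naming and the Megatron-side names follow TransformerConfig-style fields, but none of these pairs is taken from MiMoV2FlashModelProvider, so treat every name as an assumption.

```python
from typing import Any, Dict

# Hypothetical HF-key -> Megatron-field table, for illustration only.
HF_TO_PROVIDER: Dict[str, str] = {
    "num_hidden_layers": "num_layers",
    "hidden_size": "hidden_size",
    "num_attention_heads": "num_attention_heads",
    "num_key_value_heads": "num_query_groups",
    "head_dim": "kv_channels",
    "n_routed_experts": "num_moe_experts",
    "num_experts_per_tok": "moe_router_topk",
    "partial_rotary_factor": "rotary_percent",
}


def hf_to_provider_kwargs(hf_config: Dict[str, Any]) -> Dict[str, Any]:
    # Import direction: rename known keys; keys missing from the table are dropped.
    return {dst: hf_config[src] for src, dst in HF_TO_PROVIDER.items() if src in hf_config}


def provider_to_hf_config(fields: Dict[str, Any]) -> Dict[str, Any]:
    # Export direction: the table values are unique, so inverting it is safe.
    provider_to_hf = {dst: src for src, dst in HF_TO_PROVIDER.items()}
    return {provider_to_hf[k]: v for k, v in fields.items() if k in provider_to_hf}
```

Model-specific values such as v_head_dim and the two per-layer-type RoPE bases have no entry in this table. A real provider would need dedicated fields for them.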

mapping_registry() → megatron.bridge.models.conversion.mapping_registry.MegatronMappingRegistry

maybe_modify_loaded_hf_weight(
hf_param: str | dict[str, str],
hf_state_dict: Mapping[str, torch.Tensor],
) → torch.Tensor

Dequantize FP8 weights during import.

_load_and_dequant(
key: str,
hf_state_dict: Mapping[str, torch.Tensor],
) → torch.Tensor
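The import hook and its per-key helper can be pictured as: look up the weight, and if a companion scale tensor exists, dequantize it before returning. This sketch makes four assumptions:

- A DeepSeek-V3-style naming convention, where the scale for "<name>.weight" is stored under "<name>.weight_scale_inv".
- A dict-valued hf_param yields one dequantized tensor per entry, under the same keys.
- A block size of 128.
- Tensors are modelled as nested lists, and both function names are hypothetical.

```python
from typing import Dict, List, Mapping, Union

Matrix = List[List[float]]


def load_and_dequant(key: str, state_dict: Mapping[str, Matrix],
                     block_size: int = 128) -> Matrix:
    weight = state_dict[key]
    scale_inv = state_dict.get(key + "_scale_inv")  # assumed companion-key naming
    if scale_inv is None:
        return weight  # not FP8-quantized: pass through unchanged
    # Each scale entry covers one block_size x block_size tile of the weight.
    return [
        [float(w) * scale_inv[i // block_size][j // block_size] for j, w in enumerate(row)]
        for i, row in enumerate(weight)
    ]


def modify_loaded_hf_weight(hf_param: Union[str, Dict[str, str]],
                            state_dict: Mapping[str, Matrix],
                            block_size: int = 128) -> Union[Matrix, Dict[str, Matrix]]:
    if isinstance(hf_param, str):
        return load_and_dequant(hf_param, state_dict, block_size)
    # Dict form (e.g. separate q/k/v source tensors): dequantize each entry
    # and keep the same keys.
    return {name: load_and_dequant(key, state_dict, block_size)
            for name, key in hf_param.items()}
```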