nemo_automodel.components.models.mimo_v2_flash.state_dict_adapter
Module Contents
Classes
MiMoV2FlashStateDictAdapter | Convert MiMo-V2-Flash HF checkpoints to Automodel’s grouped MoE layout.
Functions
_should_quantize_key
Data
logger
NON_QUANTIZED_KEY_PATTERNS
API
- nemo_automodel.components.models.mimo_v2_flash.state_dict_adapter.logger
‘getLogger(…)’
- nemo_automodel.components.models.mimo_v2_flash.state_dict_adapter.NON_QUANTIZED_KEY_PATTERNS
[‘input_layernorm.weight’, ‘post_attention_layernorm.weight’, ‘norm.weight’, ‘lm_head.weight’, ‘embe…
- nemo_automodel.components.models.mimo_v2_flash.state_dict_adapter._should_quantize_key(key: str) -> bool
- class nemo_automodel.components.models.mimo_v2_flash.state_dict_adapter.MiMoV2FlashStateDictAdapter(
- config: Any,
- moe_config: nemo_automodel.components.moe.config.MoEConfig,
- backend: nemo_automodel.components.models.common.BackendConfig,
- dtype: torch.dtype = torch.bfloat16,
)
Bases:
nemo_automodel.components.moe.state_dict_mixin.MoESplitExpertsStateDictMixin, nemo_automodel.components.checkpoint.state_dict_adapter.StateDictAdapter
Convert MiMo-V2-Flash HF checkpoints to Automodel’s grouped MoE layout.
HF stores routed experts as split per-expert projections:
mlp.experts.{E}.{gate,up,down}_proj.weight. Automodel groups those into gate_and_up_projs and down_projs so EP can shard experts without materializing every expert on every rank.
Initialization
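The regrouping described in the class docstring can be pictured at the key level. The following is a minimal sketch, not the adapter's implementation: it only buckets split per-expert HF keys by MoE block and projection, ordered by expert index, which is the bookkeeping needed before any per-expert tensors could be stacked. The helper name `group_expert_keys` and the `model.layers.N` prefix in the usage example are assumptions made for illustration.

```python
import re
from collections import defaultdict

# Matches split per-expert keys of the form
# <prefix>.mlp.experts.{E}.{gate,up,down}_proj.weight
_EXPERT_KEY = re.compile(
    r"^(?P<prefix>.*\.mlp\.experts)\.(?P<expert>\d+)\."
    r"(?P<proj>gate|up|down)_proj\.weight$"
)


def group_expert_keys(hf_keys):
    """Bucket per-expert keys by (MoE block prefix, projection).

    Hypothetical helper: returns (grouped, passthrough), where grouped maps
    (prefix, proj) to the matching keys sorted by expert index, and
    passthrough holds every key that is not a routed-expert weight.
    """
    buckets = defaultdict(dict)
    passthrough = []
    for key in hf_keys:
        match = _EXPERT_KEY.match(key)
        if match is None:
            passthrough.append(key)
            continue
        bucket = (match["prefix"], match["proj"])
        buckets[bucket][int(match["expert"])] = key
    grouped = {
        bucket: [by_expert[e] for e in sorted(by_expert)]
        for bucket, by_expert in buckets.items()
    }
    return grouped, passthrough


# Usage with assumed key names:
keys = [
    "model.layers.0.mlp.experts.1.gate_proj.weight",
    "model.layers.0.mlp.experts.0.gate_proj.weight",
    "model.layers.0.mlp.experts.0.down_proj.weight",
    "model.norm.weight",
]
grouped, passthrough = group_expert_keys(keys)
```

In a grouped layout, the gate and up buckets for one block would presumably feed gate_and_up_projs and the down bucket would feed down_projs; how the adapter actually stacks and shards the tensors is not shown here.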
- from_hf(
- hf_state_dict: dict[str, Any],
- device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
- **kwargs,
)
- to_hf(
- state_dict: dict[str, Any],
- exclude_key_regex: str | None = None,
- quantization: bool = False,
- **kwargs,
)
Convert Automodel state_dict to the HF MiMo-V2-Flash layout.
Note: The quantization parameter is accepted for interface compatibility but is ignored. MiMo-V2-Flash is distributed as an FP8 HF checkpoint, so this adapter always emits FP8 weights plus _scale_inv companions for keys that match _should_quantize_key, regardless of the caller’s preference.
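To make the note concrete, here is a hedged sketch of which HF keys one might expect for a given set of weights. Everything in it is an assumption for illustration: the pattern list contains only the entries visible in the truncated NON_QUANTIZED_KEY_PATTERNS listing, the substring-matching rule in `should_quantize_key_sketch` is a guess at the behavior of `_should_quantize_key`, and the companion key is assumed to be the weight key with `_scale_inv` appended.

```python
# Only the patterns visible in the truncated listing; the real list is longer.
NON_QUANTIZED_KEY_PATTERNS_SKETCH = [
    "input_layernorm.weight",
    "post_attention_layernorm.weight",
    "norm.weight",
    "lm_head.weight",
]


def should_quantize_key_sketch(key: str) -> bool:
    """Assumed rule: quantize weights unless an excluded pattern appears."""
    if not key.endswith(".weight"):
        return False
    return not any(p in key for p in NON_QUANTIZED_KEY_PATTERNS_SKETCH)


def expected_hf_keys(keys):
    """List each key, followed by an assumed `_scale_inv` companion if quantized."""
    out = []
    for key in keys:
        out.append(key)
        if should_quantize_key_sketch(key):
            out.append(key + "_scale_inv")
    return out


# Usage with assumed key names:
emitted = expected_hf_keys(
    ["model.layers.0.self_attn.q_proj.weight", "model.norm.weight"]
)
```

Under these assumptions the attention projection gains a `_scale_inv` companion while the norm weight passes through unchanged, and no `quantization` argument appears anywhere because the adapter ignores it.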
- convert_single_tensor_to_hf(
- fqn: str,
- tensor: Any,
- **kwargs,
)
- _create_scale_inv_for_hf_key(
- key: str,
- weight: torch.Tensor,
)
- _dequantize(
- state_dict: dict[str, Any],