nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter


State-dict adapter for the MiniMax M3 (text) backbone.

Converts between the released HF checkpoint layout and the native AutoModel layout:

  • block_sparse_moe.{gate,e_score_correction_bias} -> mlp.gate.*
  • block_sparse_moe.experts.{e}.{w1,w3,w2} -> grouped mlp.experts.* (gate/up/down) via MoESplitExpertsStateDictMixin
  • block_sparse_moe.shared_experts.* -> shared_experts.* (a sibling of mlp on the decoder block)
  • dense (non-MoE) layers keep mlp.{gate,up,down}_proj.* unchanged
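The sketch below illustrates the non-expert renames in the list above by rewriting individual HF keys with regular expressions. It is a minimal sketch and not the adapter's _hf_key_to_native. The helper name rename_hf_key is hypothetical, and so are the exact key spellings (for example, that the router weight lives at block_sparse_moe.gate.weight). Per-expert tensors are left to the grouped-expert merge.

```python
import re

# Hypothetical rename table for the non-expert keys listed above (assumed spellings).
# Per-expert tensors (block_sparse_moe.experts.*) are intentionally not handled here.
_RENAMES = [
    # router weight: block_sparse_moe.gate.* -> mlp.gate.*
    (re.compile(r"\.block_sparse_moe\.gate\."), ".mlp.gate."),
    # router bias: block_sparse_moe.e_score_correction_bias -> mlp.gate.e_score_correction_bias
    (re.compile(r"\.block_sparse_moe\.e_score_correction_bias$"), ".mlp.gate.e_score_correction_bias"),
    # shared experts become a sibling of mlp on the decoder block
    (re.compile(r"\.block_sparse_moe\.shared_experts\."), ".shared_experts."),
]


def rename_hf_key(key: str) -> str:
    """Apply the first matching rename; otherwise return the key unchanged."""
    for pattern, repl in _RENAMES:
        new_key, n = pattern.subn(repl, key)
        if n:
            return new_key
    # dense layers (mlp.{gate,up,down}_proj.*) and all other keys pass through
    return key
```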

MXFP8 weights (FP8 e4m3 values plus *_scale_inv tensors stored as e8m0/uint8, with a [1, 32] block along the input dimension) are dequantized to dtype on load (design decision Q2: train in BF16). Stage 1 drops the sparse-attention index branch (self_attn.index_*) and the MTP (mtp.*) tensors; those are wired in during Stages 2 and 4.
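A minimal sketch of the Stage 1 filtering. It assumes index-branch keys contain .self_attn.index_ and MTP keys either start with mtp. or contain .mtp. when nested under a prefix. The function name drop_stage1_unsupported is hypothetical.

```python
def drop_stage1_unsupported(hf_state_dict: dict) -> dict:
    """Hypothetical sketch: delete index-branch and MTP keys in place."""
    for key in list(hf_state_dict):  # snapshot keys so we can delete while iterating
        is_index = ".self_attn.index_" in key
        is_mtp = key.startswith("mtp.") or ".mtp." in key  # assumed key patterns
        if is_index or is_mtp:
            del hf_state_dict[key]
    return hf_state_dict
```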

Module Contents

Classes

| Name | Description |
| --- | --- |
| MiniMaxM3StateDictAdapter | Convert MiniMax M3 HF checkpoints to/from the native grouped-expert format. |
| MiniMaxM3VLStateDictAdapter | VLM adapter: splits the M3 VL checkpoint into text / vision / projector parts. |

Functions

| Name | Description |
| --- | --- |
| _dequantize_mxfp8_local | - |
| _should_quantize_mxfp8_key | True for HF-format weight keys stored as MXFP8 in the checkpoint. |
| _slice_mxfp8_scale_for_dtensor | Slice a global scale_inv to a DTensor weight’s local shard. |
| create_mxfp8_scale_inv | Load-time placeholder scale_inv (e8m0/uint8, GLOBAL shape [out, ceil(in/block)]). |
| dequantize_mxfp8 | Dequantize an MXFP8 weight (FP8 e4m3 + e8m0/uint8 block scales) to dtype. |

Data

MXFP8_BLOCK_SIZE

_MXFP8_QUANT_KEY_RE

_MXFP8_SCALE_INV_IDENTITY

API

class nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter.MiniMaxM3StateDictAdapter(
config: typing.Any,
moe_config: nemo_automodel.components.moe.layers.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
dtype: torch.dtype = torch.bfloat16
)

Bases: MoESplitExpertsStateDictMixin, StateDictAdapter

Convert MiniMax M3 HF checkpoints to/from the native grouped-expert format.
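The grouped-expert format stacks per-expert weights along a new leading expert dimension. The sketch below assumes HF keys of the form model.layers.{L}.block_sparse_moe.experts.{e}.{w1,w3,w2}.weight and writes hypothetical native keys mlp.experts.{gate,up,down}. The real native layout (fused gate/up, transposes, parameter names) is defined by MoESplitExpertsStateDictMixin and may differ.

```python
import torch


def group_experts(hf_sd: dict, layer: int, n_experts: int) -> dict:
    """Hypothetical sketch: stack per-expert w1/w3/w2 into grouped tensors."""
    prefix = f"model.layers.{layer}.block_sparse_moe.experts"
    grouped = {}
    for hf_name, native_name in (("w1", "gate"), ("w3", "up"), ("w2", "down")):
        # each HF expert weight is an nn.Linear weight [out, in]; pop it so the
        # per-expert copies are no longer referenced by the input dict
        per_expert = [hf_sd.pop(f"{prefix}.{e}.{hf_name}.weight") for e in range(n_experts)]
        # stacked result has shape [n_experts, out, in]
        grouped[f"model.layers.{layer}.mlp.experts.{native_name}"] = torch.stack(per_expert)
    return grouped
```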

_expert_path_segment: str
_mtp_enabled: bool
nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter.MiniMaxM3StateDictAdapter._dequantize(
state_dict: dict[str, typing.Any]
) -> dict[str, typing.Any]
nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter.MiniMaxM3StateDictAdapter._hf_key_to_native(
key: str
) -> str
nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter.MiniMaxM3StateDictAdapter._mtp_from_hf(
mtp_keys: dict[str, typing.Any],
device_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh] = None
) -> dict[str, typing.Any]

Convert MTP tensors. The transformer_layer is converted by reusing the full text from_hf on a fake one-layer model, so expert merging, the index branch, and dequantization all apply to it. The enorm / hnorm / eh_proj / final_layernorm fusion tensors pass through (eh_proj is FP8).
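A sketch of how MTP keys could be partitioned. It assumes a flat mtp.transformer_layer.* / mtp.{enorm,hnorm,eh_proj,final_layernorm}.* layout; the real key layout is not documented here, and split_mtp_keys is a hypothetical name. Re-prefixing the layer as model.layers.0.* is what lets the regular text conversion treat it as a one-layer model.

```python
def split_mtp_keys(mtp_keys: dict, prefix: str = "mtp."):
    """Hypothetical sketch: separate the MTP transformer layer from fusion tensors."""
    marker = "transformer_layer."
    layer_sd, fusion_sd = {}, {}
    for key, value in mtp_keys.items():
        rest = key[len(prefix):] if key.startswith(prefix) else key
        if rest.startswith(marker):
            # present the layer as layer 0 of a fake one-layer text model so the
            # text conversion (expert merge, index branch, dequant) can run on it
            layer_sd["model.layers.0." + rest[len(marker):]] = value
        else:
            # enorm / hnorm / eh_proj / final_layernorm pass through
            fusion_sd[key] = value
    return layer_sd, fusion_sd
```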

nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter.MiniMaxM3StateDictAdapter._mtp_tensor_to_hf(
fqn: str,
tensor: typing.Any,
kwargs = {}
) -> list[tuple[str, typing.Any]]
nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter.MiniMaxM3StateDictAdapter._native_key_to_hf(
key: str
) -> str
nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter.MiniMaxM3StateDictAdapter.convert_single_tensor_to_hf(
fqn: str,
tensor: typing.Any,
kwargs = {}
) -> list[tuple[str, typing.Any]]
nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter.MiniMaxM3StateDictAdapter.from_hf(
hf_state_dict: dict[str, typing.Any],
device_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh] = None,
kwargs = {}
) -> dict[str, typing.Any]

Convert an HF checkpoint to native format (operates in-place to limit peak memory).
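Operating in place means popping each tensor and reinserting it under its new key. No second dict referencing every tensor is ever built, and a tensor consumed by a conversion step can be garbage-collected once it is popped. Below is a generic sketch of the pattern rather than the adapter's code; convert_in_place and its rename callback are hypothetical.

```python
from typing import Callable


def convert_in_place(state_dict: dict, rename: Callable[[str], str]) -> dict:
    """Generic sketch: rename keys without building a second full dict."""
    for key in list(state_dict):  # snapshot the keys; the dict is mutated below
        new_key = rename(key)
        if new_key != key:
            # pop + reinsert moves the same tensor object; nothing is copied
            state_dict[new_key] = state_dict.pop(key)
    return state_dict
```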

nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter.MiniMaxM3StateDictAdapter.to_hf(
state_dict: dict[str, typing.Any],
exclude_key_regex: typing.Optional[str] = None,
quantization: bool = False,
kwargs = {}
) -> dict[str, typing.Any]
class nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter.MiniMaxM3VLStateDictAdapter(
config: typing.Any,
moe_config: nemo_automodel.components.moe.layers.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
dtype: torch.dtype = torch.bfloat16
)

Bases: StateDictAdapter

VLM adapter: splits the M3 VL checkpoint into text / vision / projector parts.

The released checkpoint stores the language backbone under language_model.model.* and language_model.lm_head, and the vision side under vision_tower.vision_model.*, with the projector and patch merger at the top level (multi_modal_projector.* / patch_merge_mlp.*). The native VLM keeps the text model at model.* / lm_head and nests the projector and merger under vision_tower.*. Text tensors are delegated to MiniMaxM3StateDictAdapter (block_sparse_moe -> mlp, index branch, MXFP8 dequant, grouped experts); vision tensors are BF16 and pass through.
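The prefix remapping can be sketched as a pure string function. The native projector and merger names (vision_tower.multi_modal_projector.*, vision_tower.patch_merge_mlp.*) are assumptions extrapolated from the description above. The helper name map_vl_key_from_hf is hypothetical, and the text keys it returns would still go through the text adapter.

```python
def map_vl_key_from_hf(key: str) -> str:
    """Hypothetical sketch of the VL prefix remapping (HF -> native)."""
    if key.startswith("language_model.model."):
        return "model." + key[len("language_model.model."):]
    if key.startswith("language_model.lm_head"):
        return "lm_head" + key[len("language_model.lm_head"):]
    for top_level in ("multi_modal_projector.", "patch_merge_mlp."):
        if key.startswith(top_level):
            # assumed nesting: projector / merger move under vision_tower.*
            return "vision_tower." + key
    # vision_tower.vision_model.* is BF16 and is assumed to pass through unchanged
    return key
```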

text_adapter
nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter.MiniMaxM3VLStateDictAdapter._map_non_text_from_hf(
key: str
) -> str | None
staticmethod
nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter.MiniMaxM3VLStateDictAdapter._map_non_text_to_hf(
key: str
) -> str
staticmethod
nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter.MiniMaxM3VLStateDictAdapter.convert_single_tensor_to_hf(
fqn: str,
tensor: typing.Any,
kwargs = {}
) -> list[tuple[str, typing.Any]]
nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter.MiniMaxM3VLStateDictAdapter.from_hf(
hf_state_dict: dict[str, typing.Any],
device_mesh = None,
kwargs = {}
) -> dict[str, typing.Any]
nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter.MiniMaxM3VLStateDictAdapter.to_hf(
state_dict: dict[str, typing.Any],
exclude_key_regex = None,
quantization: bool = False,
kwargs = {}
)
nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter._dequantize_mxfp8_local(
w_local: torch.Tensor,
scale_local: torch.Tensor,
block_size: int,
dtype
) -> torch.Tensor
nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter._should_quantize_mxfp8_key(
key: str
) -> bool

True for HF-format weight keys stored as MXFP8 in the checkpoint.

nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter._slice_mxfp8_scale_for_dtensor(
scale_inv: torch.Tensor,
weight_dtensor: torch.Tensor,
weight_local: torch.Tensor,
block_size: int
) -> torch.Tensor

Slice a global scale_inv to a DTensor weight’s local shard.

The MXFP8 block is [1, block_size]. Dim 0 (out) is full resolution (block size 1, so a row range maps 1:1 onto scale rows), and dim 1 (in) is grouped by block_size. The custom MoE always runs with tp=1, so sharding is on dim 0 (FSDP / ep_shard); dim 1 is handled as well for safety.
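The index arithmetic can be sketched without DTensor. Given a weight shard's row and column ranges, the helper computes the matching rows and block-columns of the global scale_inv. The helper is hypothetical. Rejecting column shards that do not start on a block boundary is a choice made here, not documented behavior.

```python
import math


def scale_slice_for_shard(row_start: int, row_len: int, col_start: int, col_len: int,
                          block_size: int = 32):
    """Hypothetical helper: [start, end) ranges into a global scale_inv for a weight shard."""
    # dim 0 has block size 1, so weight rows map to scale rows one-to-one
    rows = (row_start, row_start + row_len)
    # dim 1 is grouped by block_size; this sketch requires block-aligned column shards
    if col_start % block_size != 0:
        raise ValueError("column shard does not start on an MXFP8 block boundary")
    # ceil covers a trailing partial block when in_features % block_size != 0
    cols = (col_start // block_size, math.ceil((col_start + col_len) / block_size))
    return rows, cols
```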

nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter.create_mxfp8_scale_inv(
weight: torch.Tensor,
block_size: int = MXFP8_BLOCK_SIZE
) -> torch.Tensor

Load-time placeholder scale_inv (e8m0/uint8, GLOBAL shape [out, ceil(in/block)]).

Emitted by to_hf(quantization=True) so that the DCP planner requests the checkpoint’s *_scale_inv tensors; the values here are overwritten by the load. The placeholder is kept as a regular (non-DTensor) tensor with the global shape; the per-shard slice happens in dequantize_mxfp8 (mirroring deepseek_v3).
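A minimal sketch of such a placeholder. The fill with the identity exponent 127 is an assumption; 127 is the value of _MXFP8_SCALE_INV_IDENTITY and decodes to 2 ** 0 = 1.0. For a DTensor, weight.shape already reports the global shape. The name placeholder_scale_inv is hypothetical.

```python
import math

import torch

MXFP8_BLOCK_SIZE = 32   # same value as the module constant
IDENTITY_EXPONENT = 127  # e8m0 exponent meaning 2 ** 0 == 1.0 (assumed fill value)


def placeholder_scale_inv(weight: torch.Tensor, block_size: int = MXFP8_BLOCK_SIZE) -> torch.Tensor:
    """Sketch: uint8 placeholder with global shape [out, ceil(in / block_size)]."""
    out_features, in_features = weight.shape
    n_blocks = math.ceil(in_features / block_size)
    # values are placeholders only; the checkpoint load overwrites them
    return torch.full((out_features, n_blocks), IDENTITY_EXPONENT, dtype=torch.uint8)
```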

nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter.dequantize_mxfp8(
weight: torch.Tensor,
scale_inv: torch.Tensor,
block_size: int = MXFP8_BLOCK_SIZE,
dtype: torch.dtype = torch.bfloat16
) -> torch.Tensor

Dequantize an MXFP8 weight (FP8 e4m3 + e8m0/uint8 block scales) to dtype.

weight is FP8 e4m3 with shape [out, in]. scale_inv holds e8m0 (uint8) exponents with shape [out, ceil(in/block_size)], and the dequantization scale for input block b is 2 ** (scale_inv[:, b] - 127) (MX e8m0 encoding; confirmed against sglang). DTensor weights are handled: the local shard is dequantized against the matching slice of the global scale_inv and rewrapped with the weight’s placements.

nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter.MXFP8_BLOCK_SIZE = 32
nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter._MXFP8_QUANT_KEY_RE = re.compile('\\.layers\\.\\d+\\.(?:self_attn\\.[qkvo]_proj|mlp\\.(?:gate|up|down)...
nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter._MXFP8_SCALE_INV_IDENTITY = 127