nemo_automodel.components.models.minimax_m3_vl.state_dict_adapter
State-dict adapter for the MiniMax M3 (text) backbone.
Converts between the released HF checkpoint layout and the native AutoModel layout:
- block_sparse_moe.{gate,e_score_correction_bias} -> mlp.gate.*
- block_sparse_moe.experts.{e}.{w1,w3,w2} -> grouped mlp.experts.* (gate/up/down) via MoESplitExpertsStateDictMixin
- block_sparse_moe.shared_experts.* -> shared_experts.* (a sibling of mlp on the decoder block)
- dense (non-MoE) layers keep mlp.{gate,up,down}_proj.* unchanged
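The key mapping above can be sketched as a small rename helper. This is an illustration only, not the adapter's actual code: rename_hf_key, the regex rules, and the exact native suffix for e_score_correction_bias are assumptions, and per-expert tensors are only tagged here (the real grouping is done by MoESplitExpertsStateDictMixin).

```python
import re

# Hypothetical rename rules; the exact native suffixes are assumptions.
_RULES = [
    (re.compile(r"\.block_sparse_moe\.gate\."), ".mlp.gate."),
    (re.compile(r"\.block_sparse_moe\.e_score_correction_bias$"),
     ".mlp.gate.e_score_correction_bias"),
    (re.compile(r"\.block_sparse_moe\.shared_experts\."), ".shared_experts."),
]
_EXPERT = re.compile(r"\.block_sparse_moe\.experts\.(\d+)\.(w1|w2|w3)\.")
# Role of each per-expert projection, following the order given in the docstring.
_ROLE = {"w1": "gate", "w3": "up", "w2": "down"}


def rename_hf_key(key):
    """Return (new_key, expert_info).

    expert_info is (expert_id, role) for per-expert tensors, which are left
    unrenamed so a later step can merge them into grouped tensors; otherwise None.
    """
    match = _EXPERT.search(key)
    if match:
        return key, (int(match.group(1)), _ROLE[match.group(2)])
    for pattern, replacement in _RULES:
        new_key, count = pattern.subn(replacement, key)
        if count:
            return new_key, None
    # Dense (non-MoE) layers and everything else pass through unchanged.
    return key, None
```

For example, under these assumptions "model.layers.3.block_sparse_moe.gate.weight" becomes "model.layers.3.mlp.gate.weight", while "model.layers.0.mlp.up_proj.weight" is returned as-is.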
MXFP8 weights (FP8 e4m3 + *_scale_inv stored as e8m0/uint8, block [1,32]
along the input dim) are dequantized to dtype on load (Q2 decision: train in
BF16). Stage 1 drops the sparse-attention index branch (self_attn.index_*)
and MTP (mtp.*) tensors; those are wired in Stages 2 and 4.
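A minimal sketch of the Stage 1 filter described above, assuming the dropped tensors can be recognised purely from their key text. The helper name keep_in_stage1 and the matching rules are hypothetical, not the adapter's own API.

```python
def keep_in_stage1(key):
    """Return False for tensors that Stage 1 drops (assumed key patterns)."""
    # Sparse-attention index branch: self_attn.index_*
    if ".self_attn.index_" in key:
        return False
    # Multi-token-prediction tensors: mtp.*
    if key.startswith("mtp.") or ".mtp." in key:
        return False
    return True


def drop_stage1_tensors(state_dict):
    """Remove dropped keys in place and return the same dict."""
    for key in [k for k in state_dict if not keep_in_stage1(k)]:
        del state_dict[key]
    return state_dict
```

Deleting in place rather than building a filtered copy follows the same memory-saving idea the adapter documents for from_hf.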
Module Contents
Classes
Functions
Data
API
Bases: MoESplitExpertsStateDictMixin, StateDictAdapter
Convert MiniMax M3 HF checkpoints to/from the native grouped-expert format.
Convert MTP tensors. The transformer_layer is routed through the full text from_hf (treated as a fake 1-layer model, so expert merging, index handling, and dequantization all apply); the enorm / hnorm / eh_proj / final_layernorm fusion tensors pass through (eh_proj is FP8).
Convert an HF checkpoint to native format (operates in-place to limit peak memory).
Bases: StateDictAdapter
VLM adapter: splits the M3 VL checkpoint into text / vision / projector parts.
The released checkpoint stores the language backbone under
language_model.model.* / language_model.lm_head and the vision side
under vision_tower.vision_model.* with the projector / patch-merger at
top level (multi_modal_projector.* / patch_merge_mlp.*). The native
VLM keeps the text model at model.* / lm_head and nests the projector
/ merger under vision_tower.*. Text tensors are delegated to
:class:MiniMaxM3StateDictAdapter (block_sparse_moe -> mlp, index branch,
MXFP8 dequant, grouped experts); vision tensors are BF16 and pass through.
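The prefix split described above can be sketched as follows. This is a hedged illustration: remap_vl_key is a hypothetical helper, and leaving vision_tower.vision_model.* keys unchanged is an assumption based on the description that vision tensors pass through.

```python
_TEXT_MODEL = "language_model.model."
_TEXT_HEAD = "language_model.lm_head"
_TOP_LEVEL_VISION = ("multi_modal_projector.", "patch_merge_mlp.")


def remap_vl_key(key):
    """Return (part, native_key) for one HF checkpoint key."""
    if key.startswith(_TEXT_MODEL):
        # language_model.model.* -> model.* (then handed to the text adapter)
        return "text", "model." + key[len(_TEXT_MODEL):]
    if key.startswith(_TEXT_HEAD):
        # language_model.lm_head* -> lm_head*
        return "text", key[len("language_model."):]
    if key.startswith(_TOP_LEVEL_VISION):
        # Projector / patch merger move under vision_tower.* in the native VLM.
        return "projector", "vision_tower." + key
    if key.startswith("vision_tower."):
        # Assumed to pass through unchanged (BF16 vision weights).
        return "vision", key
    return "other", key
```

Keys tagged "text" would still need the text adapter's conversions (block_sparse_moe -> mlp, MXFP8 dequant, grouped experts) after the prefix is stripped.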
True for HF-format weight keys stored as MXFP8 in the checkpoint.
Slice a global scale_inv to a DTensor weight’s local shard.
MXFP8 block is [1, block_size]: dim 0 (out) is full-resolution (block 1, so a
row range maps 1:1) and dim 1 (in) is grouped by block_size. Custom MoE is
always tp=1, so sharding is on dim 0 (FSDP / ep_shard); dim 1 is still handled for safety.
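The index arithmetic behind this slicing can be sketched with plain nested lists standing in for tensors. The function name and the (start, stop) range arguments are assumptions made for illustration; the real helper derives the local ranges from the DTensor's placements.

```python
def slice_scale_inv(scale_inv, row_range, col_range, block_size=32):
    """Slice a global [out, ceil(in/block_size)] scale grid to a local shard.

    row_range / col_range are half-open (start, stop) ranges of the shard
    in weight coordinates.
    """
    row_start, row_stop = row_range
    col_start, col_stop = col_range
    # Dim 0 has block size 1, so weight rows map 1:1 onto scale rows.
    # Dim 1 is grouped: weight column c belongs to scale column c // block_size.
    block_start = col_start // block_size
    block_stop = -(-col_stop // block_size)  # ceiling division
    return [row[block_start:block_stop] for row in scale_inv[row_start:row_stop]]
```

With a 4 x 96 weight and block_size=32 the global grid is 4 x 3; a shard holding rows 2..3 and all columns takes the last two scale rows in full.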
Load-time placeholder scale_inv (e8m0/uint8, GLOBAL shape [out, ceil(in/block)]).
Emitted by to_hf(quantization=True) so the DCP planner requests the
checkpoint’s *_scale_inv tensors; the values here are overwritten by the
load. Kept as a regular (non-DTensor) tensor with global shape; the per-shard
slice happens in dequantize_mxfp8 (mirrors deepseek_v3).
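As a rough sketch of the placeholder bookkeeping, the snippet below computes which *_scale_inv entries would be requested and their global shapes. The function name, the is_quantized callback, and the convention of appending "_scale_inv" to the weight key are assumptions; only shapes are tracked here, not real tensors.

```python
def scale_inv_placeholders(weight_shapes, is_quantized, block_size=32):
    """Map each quantized weight key to its placeholder scale_inv shape.

    weight_shapes: {hf_weight_key: (out_features, in_features)}
    is_quantized:  predicate, True for keys stored as MXFP8
    """
    placeholders = {}
    for key, (out_features, in_features) in weight_shapes.items():
        if is_quantized(key):
            n_blocks = -(-in_features // block_size)  # ceil(in / block)
            placeholders[key + "_scale_inv"] = (out_features, n_blocks)
    return placeholders
```

The values of such placeholders do not matter, since the checkpoint load overwrites them; only the key and the global shape need to match what is stored.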
Dequantize an MXFP8 weight (FP8 e4m3 + e8m0/uint8 block scales) to dtype.
weight is FP8 e4m3 [out, in]; scale_inv holds e8m0 (uint8)
exponents [out, ceil(in/block_size)], and the dequant scale for input-block b
is 2 ** (scale_inv[:, b] - 127) (MX e8m0; confirmed vs sglang). Handles
DTensor weights: the local shard is dequantized against the matching slice of a
global scale_inv and rewrapped with the weight’s placements.
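The scale formula can be illustrated with a small pure-Python sketch. It is not the library's dequantize_mxfp8: the function name is hypothetical, the weight is assumed to be already decoded from FP8 e4m3 to Python floats, nested lists stand in for tensors, and DTensor handling and special e8m0 encodings are ignored.

```python
import math


def dequantize_rows(weight, scale_inv, block_size=32):
    """Apply per-block e8m0 scales to a [out, in] weight.

    weight:    list of rows of floats (already decoded FP8 values, assumed)
    scale_inv: list of rows of uint8 exponents, [out, ceil(in/block_size)]
    """
    result = []
    for weight_row, scale_row in zip(weight, scale_inv):
        result.append([
            # ldexp(v, e) == v * 2**e; input column j uses block j // block_size.
            math.ldexp(value, scale_row[j // block_size] - 127)
            for j, value in enumerate(weight_row)
        ])
    return result
```

With block_size=2, a row of ones and exponents [127, 128] yields [1.0, 1.0, 2.0, 2.0]: exponent 127 is a scale of 1, and each step of the exponent doubles or halves the block.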