nemo_automodel.components.models.mistral4.state_dict_adapter
Module Contents
Classes
Functions
Data
API
Bases: StateDictAdapter
State dict adapter for the full multimodal Mistral 4 (ForConditionalGeneration).
Checkpoint key prefixes → native model key prefixes:
language_model.model.X → model.language_model.X (text backbone)
language_model.lm_head.X → lm_head.X (LM head)
vision_tower.X → model.vision_tower.X (Pixtral)
multi_modal_projector.X → model.multi_modal_projector.X
FP8 dequantization is applied only to text-model weights (vision/projector are not quantized). Expert weights are converted from aggregated 3D format to native format.
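The prefix mapping above can be sketched as a small lookup table applied in both directions. This is a minimal illustration rather than the adapter's actual code: the helper names `remap_checkpoint_key` and `remap_native_key`, the pass-through behavior for unmatched keys, and the example layer name are all assumptions.

```python
# (checkpoint prefix, native prefix) pairs copied from the mapping above.
PREFIX_MAP = [
    ("language_model.model.", "model.language_model."),
    ("language_model.lm_head.", "lm_head."),
    ("vision_tower.", "model.vision_tower."),
    ("multi_modal_projector.", "model.multi_modal_projector."),
]


def remap_checkpoint_key(key: str) -> str:
    """Hypothetical helper: checkpoint key -> native model key."""
    for src, dst in PREFIX_MAP:
        if key.startswith(src):
            return dst + key[len(src):]
    return key  # assumption: keys with no known prefix pass through unchanged


def remap_native_key(key: str) -> str:
    """Hypothetical helper: native model key -> checkpoint key."""
    for src, dst in PREFIX_MAP:
        if key.startswith(dst):
            return src + key[len(dst):]
    return key  # assumption: keys with no known prefix pass through unchanged


# The layer name below is illustrative only.
ckpt_key = "language_model.model.layers.0.self_attn.q_proj.weight"
native_key = remap_checkpoint_key(ckpt_key)
```

Because each checkpoint prefix maps to exactly one native prefix, applying the reverse helper to `native_key` recovers `ckpt_key`.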
Remap checkpoint keys to native model keys.
Remap a single native key back to checkpoint format.
Convert HF checkpoint to native format.
Pipeline:
- Remap checkpoint keys to native model keys
- Dequantize FP8 weights (text model only; vision/projector are not quantized)
- Convert aggregated expert weights to native format
Bases: StateDictAdapter
State dict adapter for Mistral 4 text-only (CausalLM).
Handles:
- Stripping language_model. prefix from HF keys
- FP8 dequantization (per-tensor and block-wise)
- Aggregated expert weight conversion (3D tensors → native format)
- Removing activation scale keys
Strip language_model. prefix from all keys.
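As a rough sketch of the prefix stripping described here, the function below rebuilds a state dict with the leading `language_model.` removed from each key. The function name `strip_language_model_prefix` is hypothetical, and leaving keys without the prefix unchanged is an assumption; the example keys and placeholder values are illustrative.

```python
PREFIX = "language_model."


def strip_language_model_prefix(state_dict: dict) -> dict:
    """Hypothetical helper: drop a leading 'language_model.' from every key."""
    return {
        # assumption: keys that lack the prefix are kept as they are
        (key[len(PREFIX):] if key.startswith(PREFIX) else key): value
        for key, value in state_dict.items()
    }


# Integers stand in for tensors; only the key handling is being illustrated.
stripped = strip_language_model_prefix(
    {"language_model.model.embed_tokens.weight": 1, "other.key": 2}
)
```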
Convert aggregated expert weights from HF format to native format.
HF format (aggregated 3D tensors):
mlp.experts.gate_up_proj [128, 2*moe_inter_dim, hidden_size]
mlp.experts.down_proj [128, hidden_size, moe_inter_dim]
Dequantize FP8 weights in-place. Handles both per-tensor and block-wise formats.
Mistral 4 HF checkpoint has two FP8 patterns:
- Standard weights: *.weight + *.weight_scale_inv (attention, shared experts)
- Expert weights: mlp.experts.gate_up_proj + mlp.experts.gate_up_proj_scale_inv (no .weight suffix)
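The two naming patterns and the two scale layouts can be illustrated with plain Python lists standing in for tensors. This sketch makes several assumptions not stated in the source: that dequantization multiplies each stored value by its `_scale_inv` entry, that block-wise scales form a 2D grid with one entry per `block_size` × `block_size` tile, and all three function names are hypothetical.

```python
def find_fp8_pairs(keys):
    """Pair each *_scale_inv key with its weight key.

    Stripping the '_scale_inv' suffix covers both patterns:
    'x.weight_scale_inv' -> 'x.weight' and
    'mlp.experts.gate_up_proj_scale_inv' -> 'mlp.experts.gate_up_proj'.
    """
    keys = set(keys)
    suffix = "_scale_inv"
    pairs = []
    for scale_key in sorted(k for k in keys if k.endswith(suffix)):
        weight_key = scale_key[: -len(suffix)]
        if weight_key in keys:
            pairs.append((weight_key, scale_key))
    return pairs


def dequantize_per_tensor(weight, scale_inv):
    """Per-tensor case: one scalar scale for the whole 2D weight (assumed multiply)."""
    return [[w * scale_inv for w in row] for row in weight]


def dequantize_blockwise(weight, scale_inv, block_size):
    """Block-wise case: scale_inv[i // block_size][j // block_size] scales element (i, j)."""
    return [
        [w * scale_inv[i // block_size][j // block_size] for j, w in enumerate(row)]
        for i, row in enumerate(weight)
    ]


pairs = find_fp8_pairs([
    "a.weight", "a.weight_scale_inv",
    "mlp.experts.gate_up_proj", "mlp.experts.gate_up_proj_scale_inv",
    "norm.weight",  # no scale key, so it is not paired
])
# A 2x4 weight with 2x2 blocks: one scale for columns 0-1, another for columns 2-3.
dequantized = dequantize_blockwise(
    [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], [[0.5, 2.0]], block_size=2
)
```

A real adapter would operate on FP8 tensors and cast the result to a higher-precision dtype; the list arithmetic here only shows how scales line up with weight elements.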
Inject zero e_score_correction_bias for MoE layers that lack it.
Some checkpoints (e.g. vv4) don’t include the gate bias — it starts at zero
and is learned during training. The model always expects the key, so we
inject torch.zeros(n_routed_experts) for any layer that has a gate weight
but no bias.
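The injection step described above might look roughly like the following. The function name and the gate weight key suffix `.mlp.gate.weight` are assumptions made for illustration, and a list of zeros stands in for `torch.zeros(n_routed_experts)` so the sketch runs without PyTorch.

```python
def inject_missing_gate_bias(state_dict: dict, n_routed_experts: int) -> dict:
    """Hypothetical helper: add a zero e_score_correction_bias next to each gate weight."""
    gate_weight_suffix = ".mlp.gate.weight"  # assumption: actual key layout may differ
    # Iterate over a snapshot of the keys because the dict is modified in the loop.
    for key in list(state_dict):
        if key.endswith(gate_weight_suffix):
            bias_key = key[: -len("weight")] + "e_score_correction_bias"
            if bias_key not in state_dict:
                # Stand-in for torch.zeros(n_routed_experts).
                state_dict[bias_key] = [0.0] * n_routed_experts
    return state_dict


sd = {
    "model.layers.0.mlp.gate.weight": "w0",
    "model.layers.0.mlp.gate.e_score_correction_bias": [1.0, 2.0],  # already present
    "model.layers.1.mlp.gate.weight": "w1",  # bias missing, will be injected
}
inject_missing_gate_bias(sd, n_routed_experts=2)
```

Existing biases are left untouched, so the step is safe to run on checkpoints that already include the key.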
Check if a key should be quantized based on its name.
Handles both standard keys (*.weight) and Mistral4 aggregated expert keys (*.gate_up_proj, *.down_proj), which don’t have a .weight suffix. Only text model weights are FP8; vision tower, projector, and lm_head are not.
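A name-based check along these lines could be written as below. It encodes only the rules stated in the description; the function name, the substring matching on module names, and the absence of any further exclusions are assumptions, and the real check may skip additional keys.

```python
# Modules the description lists as not FP8-quantized.
NON_QUANTIZED_MARKERS = ("vision_tower", "multi_modal_projector", "lm_head")
# Standard weights plus the aggregated expert keys that have no .weight suffix.
QUANTIZED_SUFFIXES = (".weight", ".gate_up_proj", ".down_proj")


def should_quantize(key: str) -> bool:
    """Hypothetical predicate: True if the key names an FP8 text-model weight."""
    if any(marker in key for marker in NON_QUANTIZED_MARKERS):
        return False
    # str.endswith accepts a tuple and matches any of the suffixes.
    return key.endswith(QUANTIZED_SUFFIXES)


# The layer names below are illustrative only.
checks = {
    key: should_quantize(key)
    for key in (
        "model.layers.0.self_attn.q_proj.weight",
        "model.layers.0.mlp.experts.gate_up_proj",
        "model.vision_tower.patch_conv.weight",
        "lm_head.weight",
    )
}
```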