nemo_automodel.components.models.mistral4.state_dict_adapter#
Module Contents#
Classes#
- Mistral4StateDictAdapter: State dict adapter for Mistral 4 text-only (CausalLM).
- Mistral4MultimodalStateDictAdapter: State dict adapter for the full multimodal Mistral 4 (ForConditionalGeneration).
Functions#
- _should_quantize_key: Check if a key should be quantized based on its name.
- _dequantize_state_dict: Dequantize FP8 weights in-place. Handles both per-tensor and block-wise formats.
- _convert_aggregated_experts: Convert aggregated expert weights from HF format to native format.
- _inject_missing_gate_bias: Inject zero e_score_correction_bias for MoE layers that lack it.
Data#
API#
- nemo_automodel.components.models.mistral4.state_dict_adapter.logger#
'getLogger(...)'
- nemo_automodel.components.models.mistral4.state_dict_adapter._HF_PREFIX#
‘language_model.’
- nemo_automodel.components.models.mistral4.state_dict_adapter._NON_QUANTIZED_PATTERNS#
['input_layernorm.weight', 'post_attention_layernorm.weight', 'norm.weight', 'lm_head.weight', 'embe…
- nemo_automodel.components.models.mistral4.state_dict_adapter._should_quantize_key(key: str) -> bool#
Check if a key should be quantized based on its name.
Handles both standard keys (*.weight) and Mistral4 aggregated expert keys (*.gate_up_proj, *.down_proj), which don't have a .weight suffix. Only text-model weights are FP8; the vision tower, projector, and lm_head are not.
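The rule above can be sketched as a small name-based filter. This is an illustrative reimplementation, not the library's code; the pattern list mirrors _NON_QUANTIZED_PATTERNS from this module but is abbreviated, and the vision/projector prefixes are assumptions based on the multimodal adapter below.

```python
# Illustrative sketch of a name-based quantization filter (not the exact
# library implementation). Pattern list abbreviated from the module's
# _NON_QUANTIZED_PATTERNS constant.
_NON_QUANTIZED_PATTERNS = [
    "input_layernorm.weight",
    "post_attention_layernorm.weight",
    "norm.weight",
    "lm_head.weight",
]

def should_quantize_key(key: str) -> bool:
    """Return True if the tensor named `key` is expected to be FP8."""
    # Vision tower and projector weights are never quantized.
    if key.startswith(("vision_tower.", "multi_modal_projector.")):
        return False
    # Norm / head weights stay in high precision.
    if any(key.endswith(pat) for pat in _NON_QUANTIZED_PATTERNS):
        return False
    # Aggregated expert tensors lack a .weight suffix but are still FP8.
    if key.endswith((".gate_up_proj", ".down_proj")):
        return True
    return key.endswith(".weight")
```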
- nemo_automodel.components.models.mistral4.state_dict_adapter._dequantize_state_dict(
- state_dict: dict[str, Any],
- dtype: torch.dtype,
Dequantize FP8 weights in-place. Handles both per-tensor and block-wise formats.
The Mistral 4 HF checkpoint has two FP8 patterns:

- Standard weights: *.weight + *.weight_scale_inv (attention, shared experts)
- Expert weights: mlp.experts.gate_up_proj + mlp.experts.gate_up_proj_scale_inv (no .weight suffix)
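The scale arithmetic for the two formats can be illustrated with plain tensors (no actual FP8 dtype is needed to show the math). The block size of 128 and the broadcasting scheme are assumptions for illustration; the real helper operates in place on the state dict.

```python
import torch

# Hedged sketch of per-tensor vs block-wise dequantization using plain
# tensors; BLOCK = 128 and the tiling scheme are illustrative assumptions.
BLOCK = 128

def dequantize(weight: torch.Tensor, scale_inv: torch.Tensor,
               dtype: torch.dtype = torch.float32) -> torch.Tensor:
    w = weight.to(dtype)
    if scale_inv.numel() == 1:
        # Per-tensor format: a single scalar scale for the whole tensor.
        return w * scale_inv.to(dtype)
    # Block-wise format: scale_inv[i, j] covers a BLOCK x BLOCK tile.
    scales = scale_inv.to(dtype)
    scales = scales.repeat_interleave(BLOCK, 0).repeat_interleave(BLOCK, 1)
    return w * scales[: w.shape[0], : w.shape[1]]
```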
- nemo_automodel.components.models.mistral4.state_dict_adapter._convert_aggregated_experts(
- state_dict: dict[str, Any],
Convert aggregated expert weights from HF format to native format.
HF format (aggregated 3D tensors):

- mlp.experts.gate_up_proj [128, 2*moe_inter_dim, hidden_size]
- mlp.experts.down_proj [128, hidden_size, moe_inter_dim]

Native format:

- mlp.experts.gate_and_up_projs [128, hidden_size, 2*moe_inter_dim]
- mlp.experts.down_projs [128, moe_inter_dim, hidden_size]
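Per the shapes above, both expert tensors are transposed in their last two dimensions and renamed. A hedged sketch of that conversion, using the key names from the docstring (not the library's actual code):

```python
import torch

# Sketch of the aggregated-expert conversion: transpose the last two dims
# and rename the key. Key names and shapes follow the docstring above.
def convert_aggregated_experts(state_dict: dict) -> dict:
    out = {}
    for key, value in state_dict.items():
        if key.endswith("mlp.experts.gate_up_proj"):
            # [E, 2*I, H] -> [E, H, 2*I]
            out[key.replace("gate_up_proj", "gate_and_up_projs")] = (
                value.transpose(1, 2).contiguous())
        elif key.endswith("mlp.experts.down_proj"):
            # [E, H, I] -> [E, I, H]
            out[key.replace("down_proj", "down_projs")] = (
                value.transpose(1, 2).contiguous())
        else:
            out[key] = value
    return out
```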
- nemo_automodel.components.models.mistral4.state_dict_adapter._inject_missing_gate_bias(
- state_dict: dict[str, Any],
- n_routed_experts: int,
Inject zero e_score_correction_bias for MoE layers that lack it.

Some checkpoints (e.g. vv4) don't include the gate bias; it starts at zero and is learned during training. The model always expects the key, so we inject torch.zeros(n_routed_experts) for any layer that has a gate weight but no bias.
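A minimal sketch of the injection, assuming router gate weights end in mlp.gate.weight (the exact key pattern is an assumption for illustration):

```python
import torch

# Illustrative sketch: for every layer with a router gate weight but no
# e_score_correction_bias, insert a zero bias of size n_routed_experts.
# The "mlp.gate.weight" suffix is an assumed key pattern.
def inject_missing_gate_bias(state_dict: dict, n_routed_experts: int) -> None:
    for key in list(state_dict):
        if not key.endswith("mlp.gate.weight"):
            continue
        bias_key = key.replace("gate.weight", "gate.e_score_correction_bias")
        if bias_key not in state_dict:
            state_dict[bias_key] = torch.zeros(n_routed_experts)
```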
- class nemo_automodel.components.models.mistral4.state_dict_adapter.Mistral4StateDictAdapter(
- config,
- moe_config: nemo_automodel.components.moe.config.MoEConfig,
- backend: nemo_automodel.components.models.common.BackendConfig,
- dtype: torch.dtype = torch.float32,
Bases: nemo_automodel.components.checkpoint.state_dict_adapter.StateDictAdapter

State dict adapter for Mistral 4 text-only (CausalLM).

Handles:

- Stripping the language_model. prefix from HF keys
- FP8 dequantization (per-tensor and block-wise)
- Aggregated expert weight conversion (3D tensors → native format)
- Removing activation scale keys
Initialization
- _strip_prefix(
- state_dict: dict[str, Any],
Strip the language_model. prefix from all keys.
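A minimal sketch of this stripping step, using the module's _HF_PREFIX value shown above (keys without the prefix pass through unchanged):

```python
# Sketch of prefix stripping; _HF_PREFIX matches the module constant above.
_HF_PREFIX = "language_model."

def strip_prefix(state_dict: dict) -> dict:
    # str.removeprefix is a no-op for keys that don't start with the prefix.
    return {k.removeprefix(_HF_PREFIX): v for k, v in state_dict.items()}
```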
- from_hf(
- hf_state_dict: dict[str, Any],
- device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- **kwargs,
- to_hf(
- state_dict: dict[str, Any],
- exclude_key_regex: Optional[str] = None,
- quantization: bool = False,
- **kwargs,
- convert_single_tensor_to_hf(
- fqn: str,
- tensor: Any,
- **kwargs,
- class nemo_automodel.components.models.mistral4.state_dict_adapter.Mistral4MultimodalStateDictAdapter(
- config,
- moe_config: nemo_automodel.components.moe.config.MoEConfig,
- backend: nemo_automodel.components.models.common.BackendConfig,
- dtype: torch.dtype = torch.float32,
Bases: nemo_automodel.components.checkpoint.state_dict_adapter.StateDictAdapter

State dict adapter for the full multimodal Mistral 4 (ForConditionalGeneration).

Checkpoint key prefixes → native model key prefixes:

- language_model.model.X → model.language_model.X (text backbone)
- language_model.lm_head.X → lm_head.X (LM head)
- vision_tower.X → model.vision_tower.X (Pixtral)
- multi_modal_projector.X → model.multi_modal_projector.X

FP8 dequantization is applied only to text-model weights (vision/projector are not quantized). Expert weights are converted from aggregated 3D format to native format.
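The prefix mapping above can be sketched as an ordered scan where the first matching prefix wins. This is an illustration of the mapping table, not the library's implementation:

```python
# Sketch of the checkpoint -> native prefix remapping; first match wins.
_REMAP = [
    ("language_model.lm_head.", "lm_head."),
    ("language_model.model.", "model.language_model."),
    ("vision_tower.", "model.vision_tower."),
    ("multi_modal_projector.", "model.multi_modal_projector."),
]

def remap_key_from_hf(key: str) -> str:
    for hf_prefix, native_prefix in _REMAP:
        if key.startswith(hf_prefix):
            return native_prefix + key[len(hf_prefix):]
    return key  # keys outside the table pass through unchanged
```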
Initialization
- _remap_keys_from_hf(
- state_dict: dict[str, Any],
Remap checkpoint keys to native model keys.
- _remap_keys_to_hf(key: str) -> str#
Remap a single native key back to checkpoint format.
- from_hf(
- hf_state_dict: dict[str, Any],
- device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- **kwargs,
Convert HF checkpoint to native format.
Pipeline:

1. Remap checkpoint keys to native model keys
2. Dequantize FP8 weights (text model only; vision/projector are not quantized)
3. Convert aggregated expert weights to native format
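The ordering of those steps can be sketched with stand-in helpers; the function names below are placeholders for the private helpers documented above, and the bodies only record the order in which they run:

```python
# Sketch of the from_hf pipeline order; remap_keys, dequantize_fp8 and
# convert_experts are trivial stand-ins for the private helpers above.
trace = []

def remap_keys(sd):
    trace.append("remap")          # 1. checkpoint keys -> native keys
    return sd

def dequantize_fp8(sd):
    trace.append("dequantize")     # 2. FP8 -> target dtype (text model only)

def convert_experts(sd):
    trace.append("convert")        # 3. aggregated 3D experts -> native format
    return sd

def from_hf(hf_state_dict):
    sd = remap_keys(dict(hf_state_dict))
    dequantize_fp8(sd)             # modifies sd in place
    return convert_experts(sd)
```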
- to_hf(
- state_dict: dict[str, Any],
- exclude_key_regex: Optional[str] = None,
- quantization: bool = False,
- **kwargs,
- convert_single_tensor_to_hf(
- fqn: str,
- tensor: Any,
- **kwargs,