nemo_automodel.components.models.mistral4.state_dict_adapter

Module Contents

Classes

Mistral4StateDictAdapter

State dict adapter for Mistral 4 text-only (CausalLM).

Mistral4MultimodalStateDictAdapter

State dict adapter for the full multimodal Mistral 4 (ForConditionalGeneration).

Functions

_should_quantize_key

Check if a key should be quantized based on its name.

_dequantize_state_dict

Dequantize FP8 weights in-place. Handles both per-tensor and block-wise formats.

_convert_aggregated_experts

Convert aggregated expert weights from HF format to native format.

_inject_missing_gate_bias

Inject zero e_score_correction_bias for MoE layers that lack it.

Data

logger

_HF_PREFIX

_NON_QUANTIZED_PATTERNS

API

nemo_automodel.components.models.mistral4.state_dict_adapter.logger

'getLogger(…)'

nemo_automodel.components.models.mistral4.state_dict_adapter._HF_PREFIX

'language_model.'

nemo_automodel.components.models.mistral4.state_dict_adapter._NON_QUANTIZED_PATTERNS

['input_layernorm.weight', 'post_attention_layernorm.weight', 'norm.weight', 'lm_head.weight', 'embe…

nemo_automodel.components.models.mistral4.state_dict_adapter._should_quantize_key(key: str) -> bool

Check if a key should be quantized based on its name.

Handles both standard keys (*.weight) and Mistral4 aggregated expert keys (*.gate_up_proj, *.down_proj), which don’t have a .weight suffix. Only text-model weights are FP8; the vision tower, projector, and lm_head are not.
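A minimal sketch of the rule described above, written as a hypothetical standalone function (should_quantize_key_sketch is not part of the module). It assumes that exclusion is a suffix match against the pattern list and that non-text modules are recognized by the vision_tower. and multi_modal_projector. checkpoint prefixes. Because the documented _NON_QUANTIZED_PATTERNS value is truncated, only its four visible entries are used.

```python
# Hypothetical sketch, not the module's implementation.
# Only the four visible entries of _NON_QUANTIZED_PATTERNS are listed; the
# documented value is truncated ("embe…") and the real list is longer.
NON_QUANTIZED_SUFFIXES = (
    "input_layernorm.weight",
    "post_attention_layernorm.weight",
    "norm.weight",
    "lm_head.weight",
)
# Assumed checkpoint prefixes of the modules that are never FP8.
NON_TEXT_PREFIXES = ("vision_tower.", "multi_modal_projector.")
# Aggregated expert tensors carry no ".weight" suffix.
EXPERT_SUFFIXES = (".gate_up_proj", ".down_proj")


def should_quantize_key_sketch(key: str) -> bool:
    """Return True if a checkpoint key is expected to hold an FP8 tensor."""
    if key.startswith(NON_TEXT_PREFIXES):
        return False  # vision tower / projector weights are not quantized
    if key.endswith(NON_QUANTIZED_SUFFIXES):
        return False  # norms and lm_head stay in high precision
    # Standard linear weights or aggregated expert tensors.
    return key.endswith(".weight") or key.endswith(EXPERT_SUFFIXES)
```

Scale keys such as X.weight_scale_inv end in _scale_inv rather than .weight, so the predicate returns False for them.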

nemo_automodel.components.models.mistral4.state_dict_adapter._dequantize_state_dict(
state_dict: dict[str, Any],
dtype: torch.dtype,
) -> dict[str, Any]

Dequantize FP8 weights in-place. Handles both per-tensor and block-wise formats.

The Mistral 4 HF checkpoint stores FP8 weights in two patterns:

  • Standard weights: *.weight + *.weight_scale_inv (attention, shared experts)

  • Expert weights: mlp.experts.gate_up_proj + mlp.experts.gate_up_proj_scale_inv (no .weight suffix)
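In both patterns the scale lives at the tensor's own key plus the suffix _scale_inv, so the pairing can be done by name. The sketch below uses hypothetical helpers and plain Python floats in place of FP8 tensors. It assumes the common scale_inv convention (dequantized = stored value × scale_inv), a scalar scale for the per-tensor case, and 128 × 128 tiles for the block-wise case; the block size is an assumption, not something the documentation states. It handles 2D matrices only; for the 3D aggregated expert tensors the same rule would presumably apply to each expert slice. Casting to the target dtype is omitted.

```python
import math
from typing import Union

Matrix = list[list[float]]


def dequantize_2d(weight: Matrix, scale_inv: Union[float, Matrix], block: int = 128) -> Matrix:
    """Rescale a 2D weight; plain floats stand in for FP8 values here."""
    if isinstance(scale_inv, (int, float)):
        # Per-tensor: one scale for the whole matrix.
        return [[w * scale_inv for w in row] for row in weight]
    rows, cols = len(weight), len(weight[0])
    # Block-wise: one scale per (block x block) tile, so the scale grid is
    # ceil(rows / block) x ceil(cols / block).
    if len(scale_inv) != math.ceil(rows / block) or len(scale_inv[0]) != math.ceil(cols / block):
        raise ValueError("scale grid does not match weight shape")
    return [
        [weight[i][j] * scale_inv[i // block][j // block] for j in range(cols)]
        for i in range(rows)
    ]


def dequantize_state_dict_sketch(sd: dict, block: int = 128) -> dict:
    """Fold every '<name>_scale_inv' into '<name>' in place and drop the scale key."""
    for scale_key in [k for k in sd if k.endswith("_scale_inv")]:
        # "X.weight_scale_inv" -> "X.weight"; "X.gate_up_proj_scale_inv" -> "X.gate_up_proj"
        base = scale_key[: -len("_scale_inv")]
        if base in sd:
            sd[base] = dequantize_2d(sd[base], sd.pop(scale_key), block)
    return sd
```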

nemo_automodel.components.models.mistral4.state_dict_adapter._convert_aggregated_experts(
state_dict: dict[str, Any],
) -> dict[str, Any]

Convert aggregated expert weights from HF format to native format.

HF format (aggregated 3D tensors):

  • mlp.experts.gate_up_proj: [128, 2*moe_inter_dim, hidden_size]

  • mlp.experts.down_proj: [128, hidden_size, moe_inter_dim]

Native format:

  • mlp.experts.gate_and_up_projs: [128, hidden_size, 2*moe_inter_dim]

  • mlp.experts.down_projs: [128, moe_inter_dim, hidden_size]
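Comparing the shapes, each native tensor is its HF counterpart with the last two dimensions swapped and the key renamed. The sketch below assumes the conversion is exactly that: a rename plus a per-expert transpose. The real helper may do more (for example, reorder the gate and up halves), which the shapes alone cannot reveal. Nested lists stand in for tensors; with torch, the swap would be value.transpose(1, 2). The function and constant names are hypothetical.

```python
from typing import Any

# HF suffix -> native suffix, taken from the shapes listed above.
EXPERT_RENAMES = {
    "mlp.experts.gate_up_proj": "mlp.experts.gate_and_up_projs",
    "mlp.experts.down_proj": "mlp.experts.down_projs",
}


def _transpose(matrix: list[list[Any]]) -> list[list[Any]]:
    return [list(row) for row in zip(*matrix)]


def convert_aggregated_experts_sketch(sd: dict[str, Any]) -> dict[str, Any]:
    out: dict[str, Any] = {}
    for key, value in sd.items():
        for hf_suffix, native_suffix in EXPERT_RENAMES.items():
            if key.endswith(hf_suffix):
                key = key[: -len(hf_suffix)] + native_suffix
                # Swap the last two dims of every expert slice, e.g.
                # [E, 2*moe_inter_dim, hidden_size] -> [E, hidden_size, 2*moe_inter_dim].
                value = [_transpose(expert) for expert in value]
                break
        out[key] = value  # non-expert entries pass through unchanged
    return out
```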

nemo_automodel.components.models.mistral4.state_dict_adapter._inject_missing_gate_bias(
state_dict: dict[str, Any],
n_routed_experts: int,
) -> dict[str, Any]

Inject zero e_score_correction_bias for MoE layers that lack it.

Some checkpoints (e.g., vv4) don’t include the gate bias: it starts at zero and is learned during training. The model always expects the key, so we inject torch.zeros(n_routed_experts) for any layer that has a gate weight but no bias.
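A sketch of the injection rule, assuming a hypothetical key layout in which a layer's gate weight is stored at "<layer>.mlp.gate.weight" and its bias at "<layer>.mlp.gate.e_score_correction_bias". The documentation does not spell out these key names. A list of zeros stands in for torch.zeros(n_routed_experts).

```python
from typing import Any

# Hypothetical key layout; the real gate key names may differ.
GATE_WEIGHT_SUFFIX = ".mlp.gate.weight"


def inject_missing_gate_bias_sketch(sd: dict[str, Any], n_routed_experts: int) -> dict[str, Any]:
    for key in list(sd):  # snapshot the keys because entries are added while looping
        if key.endswith(GATE_WEIGHT_SUFFIX):
            bias_key = key[: -len("weight")] + "e_score_correction_bias"
            if bias_key not in sd:
                # The adapter injects torch.zeros(n_routed_experts); a list stands in here.
                sd[bias_key] = [0.0] * n_routed_experts
    return sd
```

Existing biases are left untouched, so the helper is safe to run on checkpoints that already carry them.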

class nemo_automodel.components.models.mistral4.state_dict_adapter.Mistral4StateDictAdapter(
config,
moe_config: nemo_automodel.components.moe.config.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
dtype: torch.dtype = torch.float32,
)

Bases: nemo_automodel.components.checkpoint.state_dict_adapter.StateDictAdapter

State dict adapter for Mistral 4 text-only (CausalLM).

Handles:

  1. Stripping language_model. prefix from HF keys

  2. FP8 dequantization (per-tensor and block-wise)

  3. Aggregated expert weight conversion (3D tensors → native format)

  4. Removing activation scale keys
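Step 4 drops activation scales, presumably because they only matter for FP8 activation quantization at inference, and a dequantized model has no parameter to receive them. The documentation does not name these keys, so the suffixes below are hypothetical placeholders. The sketch only shows the filtering itself: weights and their weight scales are kept for the dequantization step.

```python
from typing import Any

# Hypothetical suffixes; the actual activation-scale key names are not documented here.
ACTIVATION_SCALE_SUFFIXES = (".input_scale", ".activation_scale")


def drop_activation_scales_sketch(sd: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of sd without activation-scale entries."""
    return {k: v for k, v in sd.items() if not k.endswith(ACTIVATION_SCALE_SUFFIXES)}
```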

Initialization

_strip_prefix(
state_dict: dict[str, Any],
) -> dict[str, Any]

Strip language_model. prefix from all keys.
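With _HF_PREFIX = 'language_model.', stripping turns language_model.model.X into model.X and language_model.lm_head.X into lm_head.X. A minimal sketch, assuming keys that lack the prefix are passed through unchanged; strip_prefix_sketch is a hypothetical name.

```python
from typing import Any

HF_PREFIX = "language_model."  # the documented value of _HF_PREFIX


def strip_prefix_sketch(sd: dict[str, Any]) -> dict[str, Any]:
    # str.removeprefix (Python 3.9+) leaves keys without the prefix unchanged.
    return {key.removeprefix(HF_PREFIX): value for key, value in sd.items()}
```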

from_hf(
hf_state_dict: dict[str, Any],
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
**kwargs,
) -> dict[str, Any]

to_hf(
state_dict: dict[str, Any],
exclude_key_regex: Optional[str] = None,
quantization: bool = False,
**kwargs,
) -> dict[str, Any]

convert_single_tensor_to_hf(
fqn: str,
tensor: Any,
**kwargs,
) -> list[tuple[str, Any]]

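For the text-only adapter, from_hf, to_hf, and convert_single_tensor_to_hf are documented only by their signatures. The sketch below assumes that, with quantization=False, to_hf inverts the load path. Expert tensors are renamed and transposed back, the language_model. prefix is re-added, and keys matching exclude_key_regex (tested here with re.match on the HF-side name) are dropped. Each tensor maps to a single (key, tensor) pair, consistent with the list return type. All of this is an assumption, the names are hypothetical, and the FP8 re-quantization path is not sketched.

```python
import re
from typing import Any, Optional

HF_PREFIX = "language_model."  # the documented value of _HF_PREFIX
# Assumed inverse of the load-time expert renaming.
NATIVE_TO_HF_EXPERTS = {
    ".mlp.experts.gate_and_up_projs": ".mlp.experts.gate_up_proj",
    ".mlp.experts.down_projs": ".mlp.experts.down_proj",
}


def _transpose(matrix: list[list[Any]]) -> list[list[Any]]:
    return [list(row) for row in zip(*matrix)]


def convert_single_tensor_to_hf_sketch(fqn: str, tensor: Any) -> list[tuple[str, Any]]:
    for native_suffix, hf_suffix in NATIVE_TO_HF_EXPERTS.items():
        if fqn.endswith(native_suffix):
            fqn = fqn[: -len(native_suffix)] + hf_suffix
            # Swap the last two dims of each expert back (torch: tensor.transpose(1, 2)).
            tensor = [_transpose(expert) for expert in tensor]
            break
    return [(HF_PREFIX + fqn, tensor)]


def to_hf_sketch(state_dict: dict[str, Any], exclude_key_regex: Optional[str] = None) -> dict[str, Any]:
    hf: dict[str, Any] = {}
    for fqn, tensor in state_dict.items():
        for key, value in convert_single_tensor_to_hf_sketch(fqn, tensor):
            # Assumption: exclusion is checked against HF-side names with re.match.
            if exclude_key_regex and re.match(exclude_key_regex, key):
                continue
            hf[key] = value
    return hf
```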
class nemo_automodel.components.models.mistral4.state_dict_adapter.Mistral4MultimodalStateDictAdapter(
config,
moe_config: nemo_automodel.components.moe.config.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
dtype: torch.dtype = torch.float32,
)

Bases: nemo_automodel.components.checkpoint.state_dict_adapter.StateDictAdapter

State dict adapter for the full multimodal Mistral 4 (ForConditionalGeneration).

Checkpoint key prefixes → native model key prefixes:

  • language_model.model.X → model.language_model.X (text backbone)

  • language_model.lm_head.X → lm_head.X (LM head)

  • vision_tower.X → model.vision_tower.X (Pixtral)

  • multi_modal_projector.X → model.multi_modal_projector.X

FP8 dequantization is applied only to text-model weights (vision/projector are not quantized). Expert weights are converted from aggregated 3D format to native format.

Initialization

_remap_keys_from_hf(
state_dict: dict[str, Any],
) -> dict[str, Any]

Remap checkpoint keys to native model keys.

_remap_keys_to_hf(key: str) -> str

Remap a single native key back to checkpoint format.
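The prefix table for this class can be expressed as a pair of hypothetical functions. This sketch assumes each rule is a plain prefix substitution and that keys matching no rule pass through unchanged. No two checkpoint prefixes (and no two native prefixes) overlap, so the rules can be tried in any order, and keys covered by the table round-trip.

```python
# (checkpoint prefix, native prefix) pairs from the class description.
PREFIX_MAP = [
    ("language_model.model.", "model.language_model."),
    ("language_model.lm_head.", "lm_head."),
    ("vision_tower.", "model.vision_tower."),
    ("multi_modal_projector.", "model.multi_modal_projector."),
]


def remap_key_from_hf_sketch(key: str) -> str:
    for ckpt, native in PREFIX_MAP:
        if key.startswith(ckpt):
            return native + key[len(ckpt):]
    return key


def remap_key_to_hf_sketch(key: str) -> str:
    for ckpt, native in PREFIX_MAP:
        if key.startswith(native):
            return ckpt + key[len(native):]
    return key
```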

from_hf(
hf_state_dict: dict[str, Any],
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
**kwargs,
) -> dict[str, Any]

Convert HF checkpoint to native format.

Pipeline:

  1. Remap checkpoint keys to native model keys

  2. Dequantize FP8 weights (text model only; vision/projector are not quantized)

  3. Convert aggregated expert weights to native format
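The three steps can be followed at the level of key names alone. In this hypothetical trace, values are ignored and only the text-backbone prefix rule is applied. Step 2 consumes each scale key because its partner is still named X.weight or X.gate_up_proj. Running the expert rename only after dequantization keeps that name-based pairing intact; this ordering rationale is an inference from the key names, not something the documentation states.

```python
SCALE_SUFFIX = "_scale_inv"
TEXT_PREFIX_HF, TEXT_PREFIX_NATIVE = "language_model.model.", "model.language_model."
EXPERT_RENAMES = {
    ".mlp.experts.gate_up_proj": ".mlp.experts.gate_and_up_projs",
    ".mlp.experts.down_proj": ".mlp.experts.down_projs",
}


def trace_from_hf_keys(hf_keys: list[str]) -> list[str]:
    # Step 1: remap the text-backbone prefix (the other prefix rules are omitted here).
    keys = [
        TEXT_PREFIX_NATIVE + k[len(TEXT_PREFIX_HF):] if k.startswith(TEXT_PREFIX_HF) else k
        for k in hf_keys
    ]
    # Step 2: dequantization folds each "<name>_scale_inv" into "<name>", so the scale key disappears.
    present = set(keys)
    keys = [k for k in keys if not (k.endswith(SCALE_SUFFIX) and k[: -len(SCALE_SUFFIX)] in present)]
    # Step 3: aggregated expert tensors get their native names (the real adapter also reshapes them).
    out = []
    for k in keys:
        for old, new in EXPERT_RENAMES.items():
            if k.endswith(old):
                k = k[: -len(old)] + new
                break
        out.append(k)
    return out
```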

to_hf(
state_dict: dict[str, Any],
exclude_key_regex: Optional[str] = None,
quantization: bool = False,
**kwargs,
) -> dict[str, Any]

convert_single_tensor_to_hf(
fqn: str,
tensor: Any,
**kwargs,
) -> list[tuple[str, Any]]