nemo_automodel.components.models.mistral4.state_dict_adapter


Module Contents

Classes

| Name | Description |
| --- | --- |
| Mistral4MultimodalStateDictAdapter | State dict adapter for the full multimodal Mistral 4 (ForConditionalGeneration). |
| Mistral4StateDictAdapter | State dict adapter for the text-only Mistral 4 model (CausalLM). |

Functions

| Name | Description |
| --- | --- |
| _convert_aggregated_experts | Convert aggregated expert weights from HF format to native format. |
| _dequantize_state_dict | Dequantize FP8 weights in-place. Handles both per-tensor and block-wise formats. |
| _inject_missing_gate_bias | Inject zero e_score_correction_bias for MoE layers that lack it. |
| _should_quantize_key | Check if a key should be quantized based on its name. |

Data

_HF_PREFIX

_NON_QUANTIZED_PATTERNS

logger

API

class nemo_automodel.components.models.mistral4.state_dict_adapter.Mistral4MultimodalStateDictAdapter(
config,
moe_config: nemo_automodel.components.moe.config.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
dtype: torch.dtype = torch.float32
)

Bases: StateDictAdapter

State dict adapter for the full multimodal Mistral 4 (ForConditionalGeneration).

Checkpoint key prefixes → native model key prefixes:

  • language_model.model.X → model.language_model.X (text backbone)
  • language_model.lm_head.X → lm_head.X (LM head)
  • vision_tower.X → model.vision_tower.X (Pixtral)
  • multi_modal_projector.X → model.multi_modal_projector.X
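The prefix mapping above can be sketched as a pair of pure-Python string functions. This is a minimal illustration, not the adapter's own _remap_keys_from_hf / _remap_keys_to_hf code: the names remap_from_hf, remap_to_hf and _PREFIX_MAP are hypothetical, and passing unmatched keys through unchanged is an assumption.

```python
# Hypothetical table of (checkpoint_prefix, native_prefix) pairs, taken from the mapping above.
_PREFIX_MAP = [
    ("language_model.model.", "model.language_model."),
    ("language_model.lm_head.", "lm_head."),
    ("vision_tower.", "model.vision_tower."),
    ("multi_modal_projector.", "model.multi_modal_projector."),
]


def remap_from_hf(key: str) -> str:
    """Map a checkpoint key to its native key; keys with no known prefix pass through."""
    for hf_prefix, native_prefix in _PREFIX_MAP:
        if key.startswith(hf_prefix):
            return native_prefix + key[len(hf_prefix):]
    return key


def remap_to_hf(key: str) -> str:
    """Inverse of remap_from_hf: map a native key back to checkpoint format."""
    for hf_prefix, native_prefix in _PREFIX_MAP:
        if key.startswith(native_prefix):
            return hf_prefix + key[len(native_prefix):]
    return key


print(remap_from_hf("language_model.model.layers.0.self_attn.q_proj.weight"))
```

Because no native prefix is a prefix of another, the reverse lookup is unambiguous, so a from/to round trip returns the original key.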

FP8 dequantization is applied only to text-model weights (vision/projector are not quantized). Expert weights are converted from aggregated 3D format to native format.

nemo_automodel.components.models.mistral4.state_dict_adapter.Mistral4MultimodalStateDictAdapter._remap_keys_from_hf(
state_dict: dict[str, typing.Any]
) -> dict[str, typing.Any]

Remap checkpoint keys to native model keys.

nemo_automodel.components.models.mistral4.state_dict_adapter.Mistral4MultimodalStateDictAdapter._remap_keys_to_hf(
key: str
) -> str

Remap a single native key back to checkpoint format.

nemo_automodel.components.models.mistral4.state_dict_adapter.Mistral4MultimodalStateDictAdapter.convert_single_tensor_to_hf(
fqn: str,
tensor: typing.Any,
**kwargs
) -> list[tuple[str, typing.Any]]
nemo_automodel.components.models.mistral4.state_dict_adapter.Mistral4MultimodalStateDictAdapter.from_hf(
hf_state_dict: dict[str, typing.Any],
device_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh] = None,
**kwargs
) -> dict[str, typing.Any]

Convert HF checkpoint to native format.

Pipeline:

  1. Remap checkpoint keys to native model keys
  2. Dequantize FP8 weights (text model only; vision/projector are not quantized)
  3. Convert aggregated expert weights to native format
nemo_automodel.components.models.mistral4.state_dict_adapter.Mistral4MultimodalStateDictAdapter.to_hf(
state_dict: dict[str, typing.Any],
exclude_key_regex: typing.Optional[str] = None,
quantization: bool = False,
**kwargs
) -> dict[str, typing.Any]
class nemo_automodel.components.models.mistral4.state_dict_adapter.Mistral4StateDictAdapter(
config,
moe_config: nemo_automodel.components.moe.config.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
dtype: torch.dtype = torch.float32
)

Bases: StateDictAdapter

State dict adapter for the text-only Mistral 4 model (CausalLM).

Handles:

  1. Stripping language_model. prefix from HF keys
  2. FP8 dequantization (per-tensor and block-wise)
  3. Aggregated expert weight conversion (3D tensors → native format)
  4. Removing activation scale keys
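Steps 1 and 4 are plain key manipulations and can be sketched in a few lines. This is a hedged illustration, not the adapter's code: strip_and_filter and ACTIVATION_SCALE_SUFFIXES are hypothetical names, and the activation-scale key suffixes are assumptions, since this page does not list the real key names. HF_PREFIX mirrors the module's _HF_PREFIX value.

```python
from typing import Any

HF_PREFIX = "language_model."  # same value as the module's _HF_PREFIX
# Hypothetical suffixes for activation-scale keys; the real checkpoint names may differ.
ACTIVATION_SCALE_SUFFIXES = ("activation_scale", "input_scale")


def strip_and_filter(state_dict: dict[str, Any]) -> dict[str, Any]:
    """Strip the HF prefix from every key and drop activation-scale entries."""
    out: dict[str, Any] = {}
    for key, value in state_dict.items():
        if key.endswith(ACTIVATION_SCALE_SUFFIXES):
            continue  # step 4: activation-scale keys are removed
        if key.startswith(HF_PREFIX):
            key = key[len(HF_PREFIX):]  # step 1: drop the "language_model." prefix
        out[key] = value
    return out
```

For example, {"language_model.model.norm.weight": w} becomes {"model.norm.weight": w}, while a key ending in activation_scale is dropped entirely.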
nemo_automodel.components.models.mistral4.state_dict_adapter.Mistral4StateDictAdapter._strip_prefix(
state_dict: dict[str, typing.Any]
) -> dict[str, typing.Any]

Strip language_model. prefix from all keys.

nemo_automodel.components.models.mistral4.state_dict_adapter.Mistral4StateDictAdapter.convert_single_tensor_to_hf(
fqn: str,
tensor: typing.Any,
**kwargs
) -> list[tuple[str, typing.Any]]
nemo_automodel.components.models.mistral4.state_dict_adapter.Mistral4StateDictAdapter.from_hf(
hf_state_dict: dict[str, typing.Any],
device_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh] = None,
**kwargs
) -> dict[str, typing.Any]
nemo_automodel.components.models.mistral4.state_dict_adapter.Mistral4StateDictAdapter.to_hf(
state_dict: dict[str, typing.Any],
exclude_key_regex: typing.Optional[str] = None,
quantization: bool = False,
**kwargs
) -> dict[str, typing.Any]
nemo_automodel.components.models.mistral4.state_dict_adapter._convert_aggregated_experts(
state_dict: dict[str, typing.Any]
) -> dict[str, typing.Any]

Convert aggregated expert weights from HF format to native format.

HF format (aggregated 3D tensors):

  • mlp.experts.gate_up_proj: [128, 2*moe_inter_dim, hidden_size]
  • mlp.experts.down_proj: [128, hidden_size, moe_inter_dim]
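To make the shapes concrete, the sketch below splits the aggregated tensors into per-expert 2D weights. The target layout is hypothetical: this page does not describe the adapter's native expert format, so the per-expert key names (experts.{e}.gate_proj.weight and so on) are invented for illustration. It also assumes the gate rows precede the up rows in gate_up_proj (not interleaved). The leading dimension is the expert count (128 in the shapes above; 4 in the test).

```python
import torch


def split_aggregated_experts(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Split aggregated 3D expert tensors into per-expert 2D weights (hypothetical layout)."""
    out: dict[str, torch.Tensor] = {}
    for key, value in state_dict.items():
        if key.endswith("mlp.experts.gate_up_proj"):
            base = key[: -len("gate_up_proj")]  # "...mlp.experts."
            # Assumption: rows [0, I) are the gate projection, rows [I, 2I) the up projection.
            gate, up = value.chunk(2, dim=1)  # each [E, I, H]
            for e in range(value.shape[0]):
                out[f"{base}{e}.gate_proj.weight"] = gate[e]  # [I, H]
                out[f"{base}{e}.up_proj.weight"] = up[e]  # [I, H]
        elif key.endswith("mlp.experts.down_proj"):
            base = key[: -len("down_proj")]
            for e in range(value.shape[0]):
                out[f"{base}{e}.down_proj.weight"] = value[e]  # [H, I]
        else:
            out[key] = value  # non-expert tensors are kept as-is
    return out
```

The per-expert slices are views into the original tensors, so the split itself copies no data.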

nemo_automodel.components.models.mistral4.state_dict_adapter._dequantize_state_dict(
state_dict: dict[str, typing.Any],
dtype: torch.dtype
) -> dict[str, typing.Any]

Dequantize FP8 weights in-place. Handles both per-tensor and block-wise formats.

Mistral 4 HF checkpoint has two FP8 patterns:

  • Standard weights: *.weight + *.weight_scale_inv (attention, shared experts)
  • Expert weights: mlp.experts.gate_up_proj + mlp.experts.gate_up_proj_scale_inv (no .weight suffix)
nemo_automodel.components.models.mistral4.state_dict_adapter._inject_missing_gate_bias(
state_dict: dict[str, typing.Any],
n_routed_experts: int
) -> dict[str, typing.Any]

Inject zero e_score_correction_bias for MoE layers that lack it.

Some checkpoints (e.g. vv4) don't include the gate bias; it starts at zero and is learned during training. The model always expects the key, so we inject torch.zeros(n_routed_experts) for any layer that has a gate weight but no bias.
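The injection rule can be sketched as follows. The key names are assumptions (a gate weight ending in mlp.gate.weight with its bias at mlp.gate.e_score_correction_bias); the page does not spell out the exact keys the real helper matches.

```python
import torch


def inject_missing_gate_bias(state_dict: dict[str, torch.Tensor], n_routed_experts: int) -> dict[str, torch.Tensor]:
    """Add a zero e_score_correction_bias next to any gate weight that lacks one."""
    for key in list(state_dict):  # snapshot keys, since we insert while iterating
        if key.endswith("mlp.gate.weight"):  # assumed gate-weight key suffix
            bias_key = key[: -len("weight")] + "e_score_correction_bias"
            if bias_key not in state_dict:
                state_dict[bias_key] = torch.zeros(n_routed_experts)
    return state_dict
```

Existing biases are left untouched, so checkpoints that already carry a trained bias are not overwritten.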

nemo_automodel.components.models.mistral4.state_dict_adapter._should_quantize_key(
key: str
) -> bool

Check if a key should be quantized based on its name.

Handles both standard keys (*.weight) and Mistral4 aggregated expert keys (*.gate_up_proj, *.down_proj), which don't have a .weight suffix. Only text model weights are FP8; the vision tower, projector, and lm_head are not.
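A name-based check of this kind might look like the sketch below. Only the first three patterns are copied from _NON_QUANTIZED_PATTERNS; the rest of that list is truncated on this page, so the lm_head, vision_tower and multi_modal_projector entries are added from the prose above and are assumptions, not the real list. should_quantize_key and the constant names are hypothetical.

```python
NON_QUANTIZED_PATTERNS = [
    "input_layernorm.weight",
    "post_attention_layernorm.weight",
    "norm.weight",
    # Assumed from the prose above, not copied from the real (truncated) list:
    "lm_head",
    "vision_tower",
    "multi_modal_projector",
]
# Standard keys end in ".weight"; aggregated expert keys have no ".weight" suffix.
QUANTIZABLE_SUFFIXES = (".weight", ".gate_up_proj", ".down_proj")


def should_quantize_key(key: str) -> bool:
    """Return True if `key` names a tensor that is stored in FP8 (sketch)."""
    if not key.endswith(QUANTIZABLE_SUFFIXES):
        return False  # e.g. scale tensors or biases
    return not any(pattern in key for pattern in NON_QUANTIZED_PATTERNS)
```

Substring matching keeps the sketch short but is coarse: any key containing norm.weight is excluded, which is the intent for layer norms.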

nemo_automodel.components.models.mistral4.state_dict_adapter._HF_PREFIX = 'language_model.'
nemo_automodel.components.models.mistral4.state_dict_adapter._NON_QUANTIZED_PATTERNS = ['input_layernorm.weight', 'post_attention_layernorm.weight', 'norm.weight', 'lm...
nemo_automodel.components.models.mistral4.state_dict_adapter.logger = logging.getLogger(__name__)