nemo_automodel.components.models.mistral3_vlm.state_dict_adapter


State-dict adapter for the Mistral-3.5 128B (dawn-ridge) FP8 VLM.

Plugs into the standard nemo_automodel checkpoint flow (nemo_automodel/components/checkpoint/checkpointing.py ~lines 510, 556) and handles FP8 dequantization during load/save:

  • The checkpoint's language_model Linear weights are stored as per-tensor FP8 with a scalar weight_scale_inv sibling (and an unused activation_scale sibling). On load, the adapter pairs each weight with its scale, dequantizes it to bf16 (w_bf16 = w_fp8.to(bf16) * weight_scale_inv), and drops the scale keys. The vision tower, multi_modal_projector, and lm_head are BF16 on disk and pass through unchanged.
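The load-time pairing can be sketched over a plain dict, so that it runs without torch. The helper name `dequantize_state_dict` and the scalar-multiply default for `dequant` are illustrative assumptions, not the adapter's actual code. With real tensors, `dequant` would cast the FP8 weight to bf16 before multiplying.

```python
from typing import Any, Callable, Dict

SCALE_SUFFIX = "_scale_inv"            # "...q_proj.weight" -> "...q_proj.weight_scale_inv"
ACT_SCALE_SUFFIX = "activation_scale"  # unused sibling, dropped on load


def dequantize_state_dict(
    hf_state_dict: Dict[str, Any],
    dequant: Callable[[Any, Any], Any] = lambda w, s: w * s,  # hypothetical hook
) -> Dict[str, Any]:
    """Hypothetical sketch: pair FP8 weights with their scales, dequantize, drop scale keys."""
    out: Dict[str, Any] = {}
    for key, value in hf_state_dict.items():
        # Scale entries are consumed together with their weight and never emitted.
        if key.endswith(SCALE_SUFFIX) or key.endswith(ACT_SCALE_SUFFIX):
            continue
        scale_key = key + SCALE_SUFFIX
        if scale_key in hf_state_dict:
            # Per-tensor FP8 weight: w_bf16 = w_fp8 * weight_scale_inv
            out[key] = dequant(value, hf_state_dict[scale_key])
        else:
            # BF16 tensors (vision tower, projector, lm_head) pass through unchanged.
            out[key] = value
    return out


sd = {
    "language_model.model.layers.0.self_attn.q_proj.weight": 2.0,
    "language_model.model.layers.0.self_attn.q_proj.weight_scale_inv": 0.5,
    "language_model.model.layers.0.self_attn.q_proj.activation_scale": 1.0,
    "vision_tower.patch_conv.weight": 3.0,
}
print(dequantize_state_dict(sd))
```

Keying the lookup on the weight (rather than on the scale) means a BF16 tensor without a sibling scale simply falls through, which mirrors the pass-through behaviour described above.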

The live HF VLM module keeps the body under model.* while the checkpoint stores text weights under language_model.model.* and top-level VLM components as vision_tower.* / multi_modal_projector.*. The LM head is also nested on disk as language_model.lm_head.weight while the runtime module exposes it as lm_head.weight.
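The name translation can be sketched as prefix rewriting plus a special case for the LM head. The two head keys match the _HF_LM_HEAD_KEY and _MODEL_LM_HEAD_KEY constants listed under Data; the runtime prefixes `model.language_model.`, `model.vision_tower.` and `model.multi_modal_projector.` are assumptions about the live module's layout, not taken from this page.

```python
_HF_LM_HEAD_KEY = "language_model.lm_head.weight"
_MODEL_LM_HEAD_KEY = "lm_head.weight"

# Assumed (runtime prefix, checkpoint prefix) pairs; hypothetical, for illustration only.
_PREFIX_PAIRS = (
    ("model.language_model.", "language_model.model."),
    ("model.vision_tower.", "vision_tower."),
    ("model.multi_modal_projector.", "multi_modal_projector."),
)


def native_to_hf(model_key: str) -> str:
    """Runtime parameter name -> checkpoint name (sketch)."""
    if model_key == _MODEL_LM_HEAD_KEY:
        return _HF_LM_HEAD_KEY  # head is nested under language_model on disk
    for native, hf in _PREFIX_PAIRS:
        if model_key.startswith(native):
            return hf + model_key[len(native):]
    return model_key


def hf_to_native(hf_key: str) -> str:
    """Checkpoint name -> runtime parameter name (sketch)."""
    if hf_key == _HF_LM_HEAD_KEY:
        return _MODEL_LM_HEAD_KEY
    for native, hf in _PREFIX_PAIRS:
        if hf_key.startswith(hf):
            return native + hf_key[len(hf):]
    return hf_key


key = "model.language_model.layers.0.mlp.down_proj.weight"
print(native_to_hf(key))  # language_model.model.layers.0.mlp.down_proj.weight
```

Because the two functions are exact inverses on these prefixes, names round-trip cleanly between save and load.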

Structurally modelled after nemo_automodel/components/models/deepseek_v3/state_dict_adapter.py.

Module Contents

Classes

Name | Description
--- | ---
Mistral3FP8StateDictAdapter | FP8 dequant adapter for the Mistral-3.5 128B dawn-ridge VLM.

Functions

Name | Description
--- | ---
_dequantize_from_fp8 | Dequantize a single FP8 weight using its per-tensor scalar scale.
_identity | -
_is_fp8_weight_key | Return True iff model_key names an FP8 Linear weight.
_vlm_full_hf_to_native | Map checkpoint VLM names back to runtime parameter names.
_vlm_full_native_to_hf | Map runtime VLM parameter names to checkpoint names.

Data

_HF_LM_HEAD_KEY

_MODEL_LM_HEAD_KEY

_NON_QUANTIZED_SUFFIXES

logger

API

class nemo_automodel.components.models.mistral3_vlm.state_dict_adapter.Mistral3FP8StateDictAdapter(
native_to_hf: typing.Callable[[str], str] = _identity,
hf_to_native: typing.Callable[[str], str] = _identity,
layout_name: str = 'vlm_full',
not_fp8_prefixes: tuple[str, ...] = ()
)

Bases: StateDictAdapter

FP8 dequant adapter for the Mistral-3.5 128B dawn-ridge VLM.

Keys round-trip through the identity mapping (HF state_dict and on-disk keys match for the full VLM). Only language_model layer weights are FP8; vision_tower, multi_modal_projector, and lm_head are BF16 and pass through unchanged via the not_fp8_prefixes / _NON_QUANTIZED_SUFFIXES filters.

_not_fp8_prefixes = tuple(not_fp8_prefixes)

nemo_automodel.components.models.mistral3_vlm.state_dict_adapter.Mistral3FP8StateDictAdapter.convert_single_tensor_to_hf(
fqn: str,
tensor: typing.Any,
**kwargs
) -> list[tuple[str, typing.Any]]

Per-tensor model → HF conversion used by Checkpointer.save_model.

nemo_automodel.components.models.mistral3_vlm.state_dict_adapter.Mistral3FP8StateDictAdapter.for_vlm_full() -> 'Mistral3FP8StateDictAdapter'
classmethod

Full-VLM path for Mistral3ForConditionalGeneration checkpoints.

The runtime module keeps VLM body modules under model.* but the checkpoint stores text weights under language_model.model.* and non-text component names at top level. The LM head has one extra quirk: the model exposes it at the top level (lm_head.weight) while the checkpoint nests it (language_model.lm_head.weight). Tied checkpoints (Ministral-3) never serialize the head, so the head translation is a harmless no-op there; untied checkpoints (Devstral-24B) rely on it to find the head during the DCP load.

Only the language_model layer weights are FP8; vision_tower, multi_modal_projector, and lm_head are BF16 on disk and must pass through without a _scale_inv placeholder. Otherwise DCP would fail trying to fetch a non-existent _scale_inv key.

nemo_automodel.components.models.mistral3_vlm.state_dict_adapter.Mistral3FP8StateDictAdapter.from_hf(
hf_state_dict: dict[str, typing.Any],
device_mesh: typing.Optional['DeviceMesh'] = None,
**kwargs
) -> dict[str, typing.Any]

Convert an HF-format (possibly FP8) state dict to model-native format.

nemo_automodel.components.models.mistral3_vlm.state_dict_adapter.Mistral3FP8StateDictAdapter.to_hf(
state_dict: dict[str, typing.Any],
exclude_key_regex: typing.Optional[str] = None,
quantization: bool = False,
**kwargs
) -> dict[str, typing.Any]

Convert a model-native state dict to HF (on-disk) layout.

When quantization=True, the weight placeholder is also cast to torch.float8_e4m3fn so that the DCP storage reader fetches the FP8 bytes verbatim from safetensors. A bf16 target would be silently cast on read and lose the scale multiply (see deepseek_v3/state_dict_adapter.py:220). A scalar _scale_inv placeholder is also emitted so that DCP loads it alongside the weight.

nemo_automodel.components.models.mistral3_vlm.state_dict_adapter._dequantize_from_fp8(
weight_fp8: torch.Tensor,
scale_inv: torch.Tensor,
target_dtype: torch.dtype = torch.bfloat16
) -> torch.Tensor

Dequantize a single FP8 weight using its per-tensor scalar scale.

The dawn-ridge 128B checkpoint uses per-tensor quantization (weight_block_size=None), so scale_inv is a 0-d scalar and dequant collapses to a simple multiply. The per-block formula (transformers.integrations.finegrained_fp8.Fp8Dequantize.convert, finegrained_fp8.py:867-906) is not needed here.

nemo_automodel.components.models.mistral3_vlm.state_dict_adapter._identity(
k: str
) -> str
nemo_automodel.components.models.mistral3_vlm.state_dict_adapter._is_fp8_weight_key(
model_key: str,
not_fp8_prefixes: tuple[str, ...] = ()
) -> bool

Return True iff model_key names an FP8 Linear weight.
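A plausible shape for this predicate, assuming it combines the not_fp8_prefixes filter with the _NON_QUANTIZED_SUFFIXES filter. The suffix tuple below holds only the first three entries shown under Data (the rest is truncated there), and requiring a `.weight` suffix is also an assumption.

```python
# Partial copy of _NON_QUANTIZED_SUFFIXES; the real tuple has more entries.
NON_QUANTIZED_SUFFIXES = (
    "embed_tokens.weight",
    "lm_head.weight",
    "input_layernorm.weight",
)


def is_fp8_weight_key(model_key: str, not_fp8_prefixes: tuple = ()) -> bool:
    """Hypothetical sketch of an FP8-Linear-weight predicate."""
    if not model_key.endswith(".weight"):
        return False  # scales, biases, etc. are never FP8 weights
    if any(model_key.startswith(p) for p in not_fp8_prefixes):
        return False  # e.g. BF16 vision / projector subtrees
    # Embeddings, norms and the LM head stay in BF16.
    return not any(model_key.endswith(s) for s in NON_QUANTIZED_SUFFIXES)


print(is_fp8_weight_key("model.language_model.layers.0.self_attn.q_proj.weight"))  # True
```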

nemo_automodel.components.models.mistral3_vlm.state_dict_adapter._vlm_full_hf_to_native(
hf_key: str
) -> str

Map checkpoint VLM names back to runtime parameter names.

nemo_automodel.components.models.mistral3_vlm.state_dict_adapter._vlm_full_native_to_hf(
model_key: str
) -> str

Map runtime VLM parameter names to checkpoint names.

nemo_automodel.components.models.mistral3_vlm.state_dict_adapter._HF_LM_HEAD_KEY = 'language_model.lm_head.weight'
nemo_automodel.components.models.mistral3_vlm.state_dict_adapter._MODEL_LM_HEAD_KEY = 'lm_head.weight'
nemo_automodel.components.models.mistral3_vlm.state_dict_adapter._NON_QUANTIZED_SUFFIXES = ('embed_tokens.weight', 'lm_head.weight', 'input_layernorm.weight', 'post_attent...
nemo_automodel.components.models.mistral3_vlm.state_dict_adapter.logger = logging.getLogger(__name__)