nemo_automodel.components.models.mistral3_vlm.state_dict_adapter
State-dict adapter for the Mistral-3.5 128B (dawn-ridge) FP8 VLM.
Plugs into the standard nemo_automodel checkpoint flow (nemo_automodel/components/checkpoint/checkpointing.py ~lines 510, 556) and handles FP8 dequantization during load/save:
- The checkpoint’s language_model Linear weights are stored as per-tensor FP8 with a scalar weight_scale_inv sibling (and an unused activation_scale sibling). The adapter pairs each weight with its scale on load, dequantizes to bf16 (w_bf16 = w_fp8.to(bf16) * scale), and drops the scale keys. Vision tower + multi_modal_projector + lm_head are BF16 on disk and pass through unchanged.
The live HF VLM module keeps the body under model.* while the checkpoint
stores text weights under language_model.model.* and top-level VLM
components as vision_tower.* / multi_modal_projector.*. The LM head is
also nested on disk as language_model.lm_head.weight while the runtime
module exposes it as lm_head.weight.
Structurally modelled after
nemo_automodel/components/models/deepseek_v3/state_dict_adapter.py.
Module Contents
Classes
Functions
Data
API
Bases: StateDictAdapter
FP8 dequant adapter for the Mistral-3.5 128B dawn-ridge VLM.
Keys round-trip as identity (HF state_dict and on-disk keys match for the
full VLM). Only language_model layer weights are FP8; vision_tower,
multi_modal_projector, and lm_head are BF16 and pass through unchanged
via the not_fp8_prefixes / _NON_QUANTIZED_SUFFIXES filters.
Per-tensor model → HF conversion used by Checkpointer.save_model.
Full-VLM path for Mistral3ForConditionalGeneration checkpoints.
The runtime module keeps VLM body modules under model.* but the
checkpoint stores text weights under language_model.model.* and
non-text component names at top level. The LM head has one extra
quirk: the model exposes it at the top level (lm_head.weight) while
the checkpoint nests it (language_model.lm_head.weight).
Tied checkpoints (Ministral-3) never serialize the head, so the head
translation is a harmless no-op there; untied checkpoints (Devstral-24B)
rely on it to find the head during the DCP load.
Only the language_model layer weights are FP8; vision / mm_projector /
lm_head are BF16 on disk and must be passed through without a scale_inv
placeholder — otherwise DCP would fail trying to fetch a non-existent
_scale_inv key.
Convert an HF-format (possibly FP8) state dict to model-native format.
Convert a model-native state dict to HF (on-disk) layout.
When quantization=True the weight placeholder is also cast to
torch.float8_e4m3fn so the DCP storage reader fetches FP8 bytes
verbatim from safetensors (a bf16 target would silently cast-on-read
and lose the scale multiply — see deepseek_v3/state_dict_adapter.py:220).
A scalar _scale_inv placeholder is also emitted so DCP pulls it
alongside the weight.
Dequantize a single FP8 weight using its per-tensor scalar scale.
The dawn-ridge 128B checkpoint uses per-tensor quantization
(weight_block_size=None), so scale_inv is a 0-d scalar and
dequant collapses to a simple multiply. The per-block formula
(transformers.integrations.finegrained_fp8.Fp8Dequantize.convert,
finegrained_fp8.py:867-906) is not needed here.
Return True iff model_key names an FP8 Linear weight.
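One way such a predicate could look is sketched below. The contents of both tuples are assumptions made for illustration (the source names the not_fp8_prefixes / _NON_QUANTIZED_SUFFIXES filters but does not list their values), and the runtime key layout under model.* is likewise assumed.

```python
# Assumed filter values; the real adapter's lists may differ.
NOT_FP8_PREFIXES = (
    "model.vision_tower.",
    "model.multi_modal_projector.",
    "lm_head.",
)
NON_QUANTIZED_SUFFIXES = (
    "norm.weight",          # covers input_layernorm, post_attention_layernorm, norm
    "embed_tokens.weight",
)


def is_fp8_weight(model_key: str) -> bool:
    """Illustrative check: True only for quantized Linear weights."""
    if not model_key.endswith(".weight"):
        return False
    if model_key.startswith(NOT_FP8_PREFIXES):
        return False
    return not model_key.endswith(NON_QUANTIZED_SUFFIXES)


checks = {
    key: is_fp8_weight(key)
    for key in (
        "model.language_model.layers.0.self_attn.q_proj.weight",
        "model.language_model.layers.0.input_layernorm.weight",
        "model.vision_tower.transformer.layers.0.attention.q_proj.weight",
        "lm_head.weight",
    )
}
```

Only the first key in `checks` maps to True; the norm weight, the vision tower weight, and the LM head are all filtered out.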
Map checkpoint VLM names back to runtime parameter names.
Map runtime VLM parameter names to checkpoint names.
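The two name mappings can be pictured as a single prefix table applied in either direction. The sketch below is hypothetical: the function names are invented for illustration, and the runtime prefixes (model.language_model.*, model.vision_tower.*, model.multi_modal_projector.*) are an assumption about how the body is laid out under model.*.

```python
# (runtime prefix, checkpoint prefix); the runtime column is an assumption.
_PREFIX_PAIRS = (
    ("model.language_model.", "language_model.model."),
    ("model.vision_tower.", "vision_tower."),
    ("model.multi_modal_projector.", "multi_modal_projector."),
    ("lm_head.", "language_model.lm_head."),
)


def runtime_to_checkpoint(key: str) -> str:
    """Rename a runtime parameter key to its on-disk name."""
    for runtime, ckpt in _PREFIX_PAIRS:
        if key.startswith(runtime):
            return ckpt + key[len(runtime):]
    return key  # unknown keys are left untouched


def checkpoint_to_runtime(key: str) -> str:
    """Rename an on-disk key back to the runtime parameter name."""
    for runtime, ckpt in _PREFIX_PAIRS:
        if key.startswith(ckpt):
            return runtime + key[len(ckpt):]
    return key


runtime_keys = [
    "model.language_model.layers.0.self_attn.q_proj.weight",
    "model.vision_tower.patch_conv.weight",
    "model.multi_modal_projector.linear_1.weight",
    "lm_head.weight",
]
disk_keys = [runtime_to_checkpoint(k) for k in runtime_keys]
```

Because each checkpoint prefix is distinct, applying `checkpoint_to_runtime` to `disk_keys` recovers `runtime_keys` exactly, and a checkpoint with no serialized head simply never hits the `lm_head.` row.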