nemo_automodel.components.models.mistral3_vlm.state_dict_adapter

State-dict adapter for the Mistral-3.5 128B (dawn-ridge) FP8 VLM.

Plugs into the standard nemo_automodel checkpoint flow (nemo_automodel/components/checkpoint/checkpointing.py ~lines 510, 556) and handles FP8 dequantization during load/save:

  • The checkpoint’s language_model Linear weights are stored as per-tensor FP8 with a scalar weight_scale_inv sibling (and an unused activation_scale sibling). The adapter pairs each weight with its scale on load, dequantizes to bf16 (w_bf16 = w_fp8.to(bf16) * weight_scale_inv), and drops the scale keys. The vision_tower, multi_modal_projector, and lm_head weights are BF16 on disk and pass through unchanged.

Keys round-trip identically: the HF VLM state_dict and the on-disk keys both use model.language_model.*, model.vision_tower.*, model.multi_modal_projector.*, lm_head.weight — no rename needed.

Structurally modelled after nemo_automodel/components/models/deepseek_v3/state_dict_adapter.py.

Module Contents

Classes

Mistral3FP8StateDictAdapter

FP8 dequant adapter for the Mistral-3.5 128B dawn-ridge VLM.

Functions

_is_fp8_weight_key

Return True iff model_key names an FP8 Linear weight.

_dequantize_from_fp8

Dequantize a single FP8 weight using its per-tensor scalar scale.

_identity

Data

API

nemo_automodel.components.models.mistral3_vlm.state_dict_adapter.logger

‘getLogger(…)’

nemo_automodel.components.models.mistral3_vlm.state_dict_adapter._NON_QUANTIZED_SUFFIXES

(‘embed_tokens.weight’, ‘lm_head.weight’, ‘input_layernorm.weight’, ‘post_attention_layernorm.weight…

nemo_automodel.components.models.mistral3_vlm.state_dict_adapter._is_fp8_weight_key(
model_key: str,
not_fp8_prefixes: tuple[str, ...] = (),
) -> bool

Return True iff model_key names an FP8 Linear weight.
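
A minimal sketch of the kind of filtering this predicate performs, based on the class description (FP8 weights are screened by not_fp8_prefixes and _NON_QUANTIZED_SUFFIXES). It is not the library function: is_fp8_weight_key_sketch and KNOWN_NON_QUANTIZED_SUFFIXES are hypothetical names, the suffix tuple copies only the four entries visible in the truncated _NON_QUANTIZED_SUFFIXES value above (the real tuple is longer), and the example prefixes in the usage are assumptions.

```python
# Hypothetical: only the four entries visible in the truncated
# _NON_QUANTIZED_SUFFIXES repr; the real tuple contains more.
KNOWN_NON_QUANTIZED_SUFFIXES: tuple[str, ...] = (
    "embed_tokens.weight",
    "lm_head.weight",
    "input_layernorm.weight",
    "post_attention_layernorm.weight",
)


def is_fp8_weight_key_sketch(
    model_key: str,
    not_fp8_prefixes: tuple[str, ...] = (),
) -> bool:
    """Hypothetical stand-in for _is_fp8_weight_key."""
    # Only Linear weights can be FP8; scale keys and other tensors never are.
    if not model_key.endswith(".weight"):
        return False
    # BF16 sub-modules (e.g. the vision tower) are excluded by prefix.
    # str.startswith(()) is False, so an empty prefix tuple excludes nothing.
    if model_key.startswith(not_fp8_prefixes):
        return False
    # Embeddings, norms, and lm_head are excluded by suffix.
    return not model_key.endswith(KNOWN_NON_QUANTIZED_SUFFIXES)


# Usage with assumed full-VLM prefixes:
vlm_prefixes = ("model.vision_tower.", "model.multi_modal_projector.")
print(is_fp8_weight_key_sketch("model.language_model.layers.3.mlp.down_proj.weight", vlm_prefixes))
```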

nemo_automodel.components.models.mistral3_vlm.state_dict_adapter._dequantize_from_fp8(
weight_fp8: torch.Tensor,
scale_inv: torch.Tensor,
target_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor

Dequantize a single FP8 weight using its per-tensor scalar scale.

The dawn-ridge 128B checkpoint uses per-tensor quantization (weight_block_size=None), so scale_inv is a 0-d scalar and dequant collapses to a simple multiply. The per-block formula (transformers.integrations.finegrained_fp8.Fp8Dequantize.convert, finegrained_fp8.py:867-906) is not needed here.

nemo_automodel.components.models.mistral3_vlm.state_dict_adapter._identity(k: str) -> str

Return k unchanged; the default key mapping for native_to_hf and hf_to_native.

class nemo_automodel.components.models.mistral3_vlm.state_dict_adapter.Mistral3FP8StateDictAdapter(
*,
native_to_hf: Callable[[str], str] = _identity,
hf_to_native: Callable[[str], str] = _identity,
layout_name: str = 'vlm_full',
not_fp8_prefixes: tuple[str, ...] = (),
)

Bases: nemo_automodel.components.checkpoint.state_dict_adapter.StateDictAdapter

FP8 dequant adapter for the Mistral-3.5 128B dawn-ridge VLM.

Keys round-trip unchanged (HF state_dict and on-disk keys match for the full VLM). Only language_model layer weights are FP8; vision_tower, multi_modal_projector, and lm_head are BF16 and pass through unchanged via the not_fp8_prefixes and _NON_QUANTIZED_SUFFIXES filters.

Initialization

classmethod for_vlm_full() -> nemo_automodel.components.models.mistral3_vlm.state_dict_adapter.Mistral3FP8StateDictAdapter

Full-VLM path for Mistral3ForConditionalGeneration checkpoints.

Keys round-trip identically between HF’s VLM state_dict() and disk for the dawn-ridge 128B checkpoint (both use model.language_model.*, model.vision_tower.*, model.multi_modal_projector.*, lm_head.weight). Only the language_model layer weights are FP8; the vision_tower, multi_modal_projector, and lm_head weights are BF16 on disk and must be passed through without a _scale_inv placeholder. Otherwise DCP would fail when it tries to fetch a nonexistent _scale_inv key.

to_hf(
state_dict: dict[str, Any],
exclude_key_regex: Optional[str] = None,
quantization: bool = False,
**kwargs,
) -> dict[str, Any]

Convert a model-native state dict to HF (on-disk) layout.

When quantization=True, each FP8 weight placeholder is cast to torch.float8_e4m3fn so the DCP storage reader copies the FP8 bytes verbatim from safetensors. A bf16 target would be silently cast on read, and the scale multiply would never be applied (see deepseek_v3/state_dict_adapter.py:220). A scalar _scale_inv placeholder is also emitted so DCP loads the scale alongside the weight.

from_hf(
hf_state_dict: dict[str, Any],
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
**kwargs,
) -> dict[str, Any]

Convert an HF-format (possibly FP8) state dict to model-native format.

convert_single_tensor_to_hf(
fqn: str,
tensor: Any,
**kwargs,
) -> list[tuple[str, Any]]

Per-tensor model-to-HF conversion used by Checkpointer.save_model.
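
The list[tuple[str, Any]] return type lets one native tensor map to zero, one, or several HF tensors, and the caller flattens the pairs into a single dict. A minimal sketch of that contract: both function names are hypothetical, and identity_single_tensor_to_hf assumes only the identity key mapping documented above (any extra work the real method does on save is not described here).

```python
from typing import Any, Callable


def identity_single_tensor_to_hf(fqn: str, tensor: Any, **kwargs: Any) -> list[tuple[str, Any]]:
    """Hypothetical per-tensor converter: keys round-trip unchanged."""
    return [(fqn, tensor)]


def convert_for_save(
    state_dict: dict[str, Any],
    convert_one: Callable[..., list[tuple[str, Any]]],
) -> dict[str, Any]:
    """Hypothetical save loop that flattens per-tensor conversions into one dict."""
    hf: dict[str, Any] = {}
    for fqn, tensor in state_dict.items():
        for hf_key, hf_tensor in convert_one(fqn, tensor):
            # Two native tensors must never claim the same on-disk key.
            if hf_key in hf:
                raise KeyError(f"duplicate HF key {hf_key!r}")
            hf[hf_key] = hf_tensor
    return hf
```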