nemo_automodel.components.models.mistral3_vlm.state_dict_adapter#
State-dict adapter for the Mistral-3.5 128B (dawn-ridge) FP8 VLM.
Plugs into the standard nemo_automodel checkpoint flow (nemo_automodel/components/checkpoint/checkpointing.py ~lines 510, 556) and handles FP8 dequantization during load/save:
The checkpoint’s language_model Linear weights are stored as per-tensor FP8 with a scalar
weight_scale_inv sibling (and an unused activation_scale sibling). The adapter pairs each weight with its scale on load, dequantizes to bf16 (w_bf16 = w_fp8.to(bf16) * scale), and drops the scale keys. The vision tower, multi_modal_projector, and lm_head are BF16 on disk and pass through unchanged.
Keys round-trip identically: the HF VLM state_dict and the on-disk keys
both use model.language_model.*, model.vision_tower.*,
model.multi_modal_projector.*, lm_head.weight — no rename needed.
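The load-side behaviour described above (pair each weight with its scale, multiply, drop the scale keys, pass everything else through) can be sketched in plain Python. This is a toy illustration, not the adapter's code: floats stand in for tensors, `pair_and_dequantize` is a hypothetical name, and the `layers.0.self_attn.q_proj` path is an assumed example of a language_model Linear key.

```python
def pair_and_dequantize(on_disk: dict[str, float]) -> dict[str, float]:
    """Toy load path: floats stand in for tensors (assumption for illustration)."""
    loaded = {}
    for key, value in on_disk.items():
        # Scale siblings are consumed or ignored; they never reach the model.
        if key.endswith(("_scale_inv", ".activation_scale")):
            continue
        scale = on_disk.get(key + "_scale_inv")
        # A weight with a scale sibling is treated as FP8 and rescaled;
        # a weight without one (e.g. lm_head) passes through unchanged.
        loaded[key] = value * scale if scale is not None else value
    return loaded


# Hypothetical key names; only the top-level prefixes come from the docs.
q_proj = "model.language_model.layers.0.self_attn.q_proj"
checkpoint = {
    q_proj + ".weight": 2.0,
    q_proj + ".weight_scale_inv": 0.5,
    q_proj + ".activation_scale": 1.0,  # unused sibling, dropped on load
    "lm_head.weight": 3.0,              # BF16 on disk, no scale sibling
}
loaded = pair_and_dequantize(checkpoint)
```

Because the keys are not renamed, the output dictionary uses the same names as the checkpoint; only the scale entries disappear.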
Structurally modelled after
nemo_automodel/components/models/deepseek_v3/state_dict_adapter.py.
Module Contents#
Classes#
Mistral3FP8StateDictAdapter | FP8 dequant adapter for the Mistral-3.5 128B dawn-ridge VLM. |
Functions#
_is_fp8_weight_key | Return True iff model_key names an FP8 Linear weight. |
_dequantize_from_fp8 | Dequantize a single FP8 weight using its per-tensor scalar scale. |
_identity | |
Data#
API#
- nemo_automodel.components.models.mistral3_vlm.state_dict_adapter.logger#
‘getLogger(…)’
- nemo_automodel.components.models.mistral3_vlm.state_dict_adapter._NON_QUANTIZED_SUFFIXES#
(‘embed_tokens.weight’, ‘lm_head.weight’, ‘input_layernorm.weight’, ‘post_attention_layernorm.weight…
- nemo_automodel.components.models.mistral3_vlm.state_dict_adapter._is_fp8_weight_key(
- model_key: str,
- not_fp8_prefixes: tuple[str, ...] = (),
Return True iff model_key names an FP8 Linear weight.
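One plausible shape for such a predicate is sketched below. The real implementation is not shown on this page, so everything here is an assumption: the function name `is_fp8_weight_key_sketch` is hypothetical, the suffix tuple contains only the four entries visible in the truncated `_NON_QUANTIZED_SUFFIXES` listing above, and the prefix values in the usage lines are illustrative.

```python
# Partial, assumed list: only the suffixes visible in the rendered docs.
NON_QUANTIZED_SUFFIXES_SKETCH = (
    "embed_tokens.weight",
    "lm_head.weight",
    "input_layernorm.weight",
    "post_attention_layernorm.weight",
)


def is_fp8_weight_key_sketch(
    model_key: str,
    not_fp8_prefixes: tuple[str, ...] = (),
) -> bool:
    """Guess at the filter: a '.weight' key that no exclusion rule matches."""
    if not model_key.endswith(".weight"):
        return False  # biases, scale siblings, buffers, ...
    if model_key.endswith(NON_QUANTIZED_SUFFIXES_SKETCH):
        return False  # embeddings, norms, lm_head stay in BF16
    # str.startswith(()) is False, so an empty tuple excludes nothing.
    return not model_key.startswith(not_fp8_prefixes)


# Illustrative prefixes for the BF16 sub-modules of the VLM.
bf16_prefixes = ("model.vision_tower.", "model.multi_modal_projector.")
is_linear = is_fp8_weight_key_sketch(
    "model.language_model.layers.0.mlp.up_proj.weight", bf16_prefixes
)
is_vision = is_fp8_weight_key_sketch(
    "model.vision_tower.patch_conv.weight", bf16_prefixes
)
```

Combining a suffix filter with a prefix filter matches the class description, which says the BF16 tensors pass through via the not_fp8_prefixes / _NON_QUANTIZED_SUFFIXES filters.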
- nemo_automodel.components.models.mistral3_vlm.state_dict_adapter._dequantize_from_fp8(
- weight_fp8: torch.Tensor,
- scale_inv: torch.Tensor,
- target_dtype: torch.dtype = torch.bfloat16,
Dequantize a single FP8 weight using its per-tensor scalar scale.
The dawn-ridge 128B checkpoint uses per-tensor quantization (weight_block_size=None), so scale_inv is a 0-d scalar and dequant collapses to a simple multiply. The per-block formula (transformers.integrations.finegrained_fp8.Fp8Dequantize.convert, finegrained_fp8.py:867-906) is not needed here.
- nemo_automodel.components.models.mistral3_vlm.state_dict_adapter._identity(k: str) → str#
- class nemo_automodel.components.models.mistral3_vlm.state_dict_adapter.Mistral3FP8StateDictAdapter(
- *,
- native_to_hf: Callable[[str], str] = _identity,
- hf_to_native: Callable[[str], str] = _identity,
- layout_name: str = 'vlm_full',
- not_fp8_prefixes: tuple[str, ...] = (),
Bases: nemo_automodel.components.checkpoint.state_dict_adapter.StateDictAdapter
FP8 dequant adapter for the Mistral-3.5 128B dawn-ridge VLM.
Keys round-trip identically (HF state_dict and on-disk keys match for the full VLM). Only language_model layer weights are FP8; vision_tower, multi_modal_projector, and lm_head are BF16 and pass through unchanged via the not_fp8_prefixes / _NON_QUANTIZED_SUFFIXES filters.
Initialization
- classmethod for_vlm_full() → nemo_automodel.components.models.mistral3_vlm.state_dict_adapter.Mistral3FP8StateDictAdapter#
Full-VLM path for Mistral3ForConditionalGeneration checkpoints.
Keys round-trip identically between HF’s VLM state_dict() and disk for the dawn-ridge-128B checkpoint (both use model.language_model.*, model.vision_tower.*, model.multi_modal_projector.*, lm_head.weight). Only the language_model layer weights are FP8; vision / mm_projector / lm_head are BF16 on disk and must be passed through without a scale_inv placeholder, otherwise DCP would fail trying to fetch a non-existent _scale_inv key.
- to_hf(
- state_dict: dict[str, Any],
- exclude_key_regex: Optional[str] = None,
- quantization: bool = False,
- **kwargs,
Convert a model-native state dict to HF (on-disk) layout.
When quantization=True the weight placeholder is also cast to torch.float8_e4m3fn so the DCP storage reader fetches FP8 bytes verbatim from safetensors (a bf16 target would silently cast-on-read and lose the scale multiply; see deepseek_v3/state_dict_adapter.py:220). A scalar _scale_inv placeholder is also emitted so DCP pulls it alongside the weight.
- from_hf(
- hf_state_dict: dict[str, Any],
- device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- **kwargs,
Convert an HF-format (possibly FP8) state dict to model-native format.
- convert_single_tensor_to_hf(
- fqn: str,
- tensor: Any,
- **kwargs,
Per-tensor model → HF conversion used by Checkpointer.save_model.