nemo_automodel.components.models.mistral3_vlm.model
FP8-native Mistral3 VLM (dawn-ridge / Mistral-3.5 128B).
Custom wrapper around HF’s Mistral3ForConditionalGeneration that:
- Inherits the full VLM architecture (vision_tower + multi_modal_projector + Ministral3 language_model) so image inputs flow through Pixtral.
- Attaches Mistral3FP8StateDictAdapter.for_vlm_full() so FP8 dequant runs inside the standard DCP load path. This avoids HF’s FineGrainedFP8 loader, which materializes the full BF16 model on every rank before the PP split and OOMs on 80 GB H100s.
- Attaches a one-shot forward pre-hook on every rotary submodule to recompute inv_freq on first call. This is needed because HF’s Ministral3 / Pixtral rotaries compute inv_freq in __init__, so meta-init + to_empty leaves the buffer holding uninitialized memory.
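As a rough illustration of "dequant inside the load path", the sketch below folds a per-tile scale into each quantized weight while the state dict is being converted, so no separate FP8-aware model loader is involved. Everything here is an assumption for illustration: the <name>.weight_scale_inv key convention, the block size, and the helper names dequantize_blockwise and dequantize_fp8_state_dict are hypothetical and are not taken from Mistral3FP8StateDictAdapter. Nested Python lists stand in for tensors so the example runs without PyTorch.

```python
from typing import Dict, List

Matrix = List[List[float]]  # nested lists stand in for 2-D tensors


def dequantize_blockwise(q: Matrix, scale_inv: Matrix, block: int) -> Matrix:
    """Scale every element by the factor of the (block x block) tile it sits in."""
    return [
        [q[r][c] * scale_inv[r // block][c // block] for c in range(len(q[r]))]
        for r in range(len(q))
    ]


def dequantize_fp8_state_dict(sd: Dict[str, Matrix], block: int = 128) -> Dict[str, Matrix]:
    """Hypothetical adapter step: fold '<name>.weight_scale_inv' into '<name>.weight'."""
    out: Dict[str, Matrix] = {}
    for key, value in sd.items():
        if key.endswith(".weight_scale_inv"):
            continue  # consumed together with its weight, never emitted on its own
        scale_key = key + "_scale_inv"  # 'x.weight' -> 'x.weight_scale_inv'
        if key.endswith(".weight") and scale_key in sd:
            out[key] = dequantize_blockwise(value, sd[scale_key], block)
        else:
            out[key] = value  # unquantized tensors (e.g. norms) pass through unchanged
    return out


toy = {
    "proj.weight": [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
    "proj.weight_scale_inv": [[0.5, 2.0]],  # one scale per 2x2 tile
    "norm.weight": [[1.0, 1.0]],
}
print(dequantize_fp8_state_dict(toy, block=2)["proj.weight"])
# [[0.5, 1.0, 6.0, 8.0], [2.5, 3.0, 14.0, 16.0]]
```

Because a transform like this works tensor by tensor as the checkpoint is read, it can run inside a sharded load on only the weights a rank receives, rather than on a fully materialized BF16 model.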
Module Contents
API
Bases: _HFMistral3ForConditionalGeneration
Full-VLM (vision + text) FP8 loader for Mistral3ForConditionalGeneration.
Used when the user instantiates through
NeMoAutoModelForImageTextToText.from_pretrained on an FP8-native
Mistral3 VLM checkpoint (e.g. dawn-ridge-128B).
Forward pass with memory-efficient fused cross-entropy (cut-CE) support.
Overrides HF’s Mistral3ForConditionalGeneration.forward so the
train_ft recipe can enable FusedLinearCrossEntropy. The recipe
only does so when (a) forward exposes a logits_to_keep parameter
and (b) calling the model returns an output that carries the FINAL hidden
states (full sequence) while logits cover only the kept positions.
HF’s stock forward gates hidden_states on a per-call
output_hidden_states kwarg (which the recipe does not pass) and emits
the full per-layer tuple. Here we instead resolve output_hidden_states
from the text sub-config and surface the inner model’s last_hidden_state
(the single [B, S, H] tensor fed to lm_head) directly, which is
what get_final_hidden_states consumes.
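The contract described above can be sketched with plain lists (batch dimension omitted). toy_forward, ToyOutput, supports_cut_ce and the text_config_output_hidden_states argument are hypothetical names: the sketch checks condition (a) with inspect.signature and shows condition (b) by returning the full final hidden state alongside logits computed only for the kept positions. The slice(-logits_to_keep, None) idiom follows the pattern HF causal-LM heads commonly use for an integer logits_to_keep; because -0 == 0, a value of 0 keeps every position.

```python
import inspect
from dataclasses import dataclass
from typing import List, Optional

Vector = List[float]  # one position's activations; the batch dimension is omitted


@dataclass
class ToyOutput:
    logits: List[Vector]                   # rows only for the kept positions
    hidden_states: Optional[List[Vector]]  # final hidden state for every position


def toy_forward(
    final_hidden: List[Vector],
    lm_head: List[Vector],
    logits_to_keep: int = 0,
    output_hidden_states: Optional[bool] = None,
    text_config_output_hidden_states: bool = True,
) -> ToyOutput:
    # Resolve the flag from the (text) config when the caller does not pass it.
    if output_hidden_states is None:
        output_hidden_states = text_config_output_hidden_states
    # slice(-0, None) == slice(0, None), so logits_to_keep=0 keeps every position.
    kept = final_hidden[slice(-logits_to_keep, None)]
    logits = [[sum(w * h for w, h in zip(row, vec)) for row in lm_head] for vec in kept]
    return ToyOutput(logits=logits, hidden_states=final_hidden if output_hidden_states else None)


def supports_cut_ce(forward) -> bool:
    """Condition (a): the forward signature exposes a logits_to_keep parameter."""
    return "logits_to_keep" in inspect.signature(forward).parameters


hidden = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # S=3, H=2
head = [[1.0, 2.0]]                            # V=1
out = toy_forward(hidden, head, logits_to_keep=1)
print(out.logits, len(out.hidden_states))  # [[3.0]] 3
```

Keeping the full final hidden state is what lets a fused linear-cross-entropy kernel compute the loss against lm_head chunk by chunk instead of materializing logits for every position and vocabulary entry.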
Parameters:
Input token IDs [B, S].
Optional image pixel values for the vision tower.
Optional attention mask.
Optional position indices.
Optional cached key/values.
Optional pre-computed embeddings.
Optional labels for loss computation.
Whether to use KV caching.
Number of final positions to compute logits for (0 = all positions, N = only the last N tokens).
Optional image sizes for the vision tower.
Whether to surface the final hidden states on the output (defaults to the text sub-config’s output_hidden_states).
Additional arguments forwarded to the base model.
Returns: Mistral3CausalLMOutputWithPast (transformers.models.mistral3.modeling_mistral3.Mistral3CausalLMOutputWithPast)
Claim FP8-native Mistral3 VLM configs.
Matches Mistral3Config (outer VLM) with a ministral3 text backbone
and quantization_config.quant_method == 'fp8'.
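A minimal sketch of the matching rule, assuming the outer config reports model_type "mistral3", the text sub-config reports "ministral3", and quantization_config may arrive either as a plain dict or as an attribute-style object. claims_fp8_mistral3_vlm and _get are hypothetical helpers, not the module’s actual classmethod; SimpleNamespace stands in for real config objects.

```python
from types import SimpleNamespace
from typing import Any


def _get(obj: Any, key: str, default: Any = None) -> Any:
    """Read a field from either a dict or an attribute-style config object."""
    if isinstance(obj, dict):
        return obj.get(key, default)
    return getattr(obj, key, default)  # also safe when obj is None


def claims_fp8_mistral3_vlm(config: Any) -> bool:
    """Hypothetical predicate mirroring the matching rule described above."""
    if _get(config, "model_type") != "mistral3":  # outer Mistral3Config
        return False
    if _get(_get(config, "text_config"), "model_type") != "ministral3":  # text backbone
        return False
    return _get(_get(config, "quantization_config"), "quant_method") == "fp8"


fp8_cfg = SimpleNamespace(
    model_type="mistral3",
    text_config=SimpleNamespace(model_type="ministral3"),
    quantization_config={"quant_method": "fp8"},
)
print(claims_fp8_mistral3_vlm(fp8_cfg))  # True
```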
One-shot forward pre-hook that recomputes this rotary module’s own
inv_freq on first call.
Attached per-rotary rather than on the outer VLM so that it fires correctly
under pipeline parallelism, where the outer model’s forward is never
called directly: the PP schedule dispatches each stage’s submodules
individually, and rotary modules run inside every attention layer.
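A framework-free sketch of the per-module one-shot hook pattern. ToyModule and RemovableHandle mimic the torch.nn.Module register_forward_pre_hook contract (the hook receives the module and its positional inputs before forward runs, and the returned handle’s remove() detaches it) so the example runs without PyTorch; ToyRotary, attach_one_shot_fix and the two-stage loop are hypothetical. Because every rotary carries its own hook, each is fixed the first time its stage calls it, even though no outer forward is ever invoked.

```python
from typing import Any, Callable, Dict, List, Tuple

Hook = Callable[[Any, Tuple[Any, ...]], None]


class RemovableHandle:
    """Mimics the handle returned by torch's register_forward_pre_hook."""

    def __init__(self, hooks: Dict[int, Hook], hook_id: int) -> None:
        self._hooks = hooks
        self._id = hook_id

    def remove(self) -> None:
        self._hooks.pop(self._id, None)


class ToyModule:
    """Framework-free stand-in for nn.Module's forward-pre-hook plumbing."""

    def __init__(self) -> None:
        self._pre_hooks: Dict[int, Hook] = {}
        self._next_id = 0

    def register_forward_pre_hook(self, hook: Hook) -> RemovableHandle:
        hook_id = self._next_id
        self._next_id += 1
        self._pre_hooks[hook_id] = hook
        return RemovableHandle(self._pre_hooks, hook_id)

    def __call__(self, *args: Any) -> Any:
        # Run pre-hooks first; iterate over a copy because a hook may remove itself.
        for hook in list(self._pre_hooks.values()):
            hook(self, args)
        return self.forward(*args)

    def forward(self, *args: Any) -> Any:
        raise NotImplementedError


class ToyRotary(ToyModule):
    def __init__(self, name: str) -> None:
        super().__init__()
        self.name = name

    def forward(self, x: float) -> float:
        return x  # a real rotary module would typically return cos/sin tables


def attach_one_shot_fix(module: ToyRotary, log: List[str]) -> RemovableHandle:
    def hook(mod: Any, args: Tuple[Any, ...]) -> None:
        log.append(f"fix {mod.name}")  # the real hook would recompute inv_freq here
        handle.remove()                # one-shot: detach after the first call

    handle = module.register_forward_pre_hook(hook)
    return handle


# Two PP stages; each stage calls only its own rotaries, never an outer forward.
log: List[str] = []
stages = [[ToyRotary("layer0.rotary")], [ToyRotary("layer1.rotary")]]
for stage in stages:
    for rot in stage:
        attach_one_shot_fix(rot, log)
for step in range(2):
    for stage in stages:
        for rot in stage:
            rot(0.0)
print(log)  # ['fix layer0.rotary', 'fix layer1.rotary']
```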
Background: HF’s Ministral3 / Pixtral rotary classes initialise
inv_freq (and related attributes) in their __init__. Under
accelerate.init_empty_weights that becomes a meta tensor, and the
subsequent to_empty(device) call leaves it holding uninitialised device
memory. Neither class exposes rope_init_fn as an attribute, so the
generic _reinit_non_persistent_buffers helper doesn’t match them. We
recover correctness by re-running the module’s own __init__ on the
target device outside the init_empty_weights context. Both Ministral3
(YaRN) and Pixtral (2D patch positions) produce the right values this
way, since we defer to each class’s authoritative init logic.
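A toy model of the recovery step. ToyRotary derives standard RoPE frequencies in __init__, and ToyScaledRotary applies its own, different rule, standing in for classes such as the YaRN or Pixtral rotaries whose formulas differ; all names are hypothetical. Rather than calling __init__ again on the existing object, this sketch builds a throwaway instance of type(module) from the same config and copies its inv_freq, which is one way of deferring to the class’s own init logic. Real HF rotary constructors may take additional arguments such as a device, and whether the actual hook copies or re-initializes in place is not specified here.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class ToyRopeConfig:
    head_dim: int = 8
    rope_theta: float = 10000.0


class ToyRotary:
    """Stand-in for a rotary class that derives inv_freq inside __init__."""

    def __init__(self, config: ToyRopeConfig) -> None:
        self.config = config
        # Standard RoPE frequencies: theta ** (-2i / d) for i in [0, d/2).
        self.inv_freq: List[float] = [
            1.0 / config.rope_theta ** (2 * i / config.head_dim)
            for i in range(config.head_dim // 2)
        ]


class ToyScaledRotary(ToyRotary):
    """A variant with its own init rule; the fix must not need to know it."""

    def __init__(self, config: ToyRopeConfig) -> None:
        super().__init__(config)
        self.inv_freq = [f / 4.0 for f in self.inv_freq]


def simulate_meta_then_to_empty(module: ToyRotary) -> None:
    """Mimic to_empty(): the buffer has the right shape but meaningless contents."""
    module.inv_freq = [math.nan] * len(module.inv_freq)


def recompute_inv_freq(module: ToyRotary) -> None:
    """Defer to the module's own class: rebuild from its config and copy inv_freq."""
    reference = type(module)(module.config)
    module.inv_freq = list(reference.inv_freq)


cfg = ToyRopeConfig(head_dim=4, rope_theta=100.0)
for cls in (ToyRotary, ToyScaledRotary):
    rot = cls(cfg)
    simulate_meta_then_to_empty(rot)
    recompute_inv_freq(rot)
    print(cls.__name__, rot.inv_freq)
```

Dispatching through type(module) is what keeps the helper generic: the same function restores both variants without encoding either formula.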