nemo_automodel.components.models.mistral3_vlm.model


FP8-native Mistral3 VLM (dawn-ridge / Mistral-3.5 128B).

Custom wrapper around HF’s Mistral3ForConditionalGeneration that:

  • Inherits the full VLM architecture (vision_tower + multi_modal_projector + Ministral3 language_model), so image inputs flow through Pixtral.
  • Attaches Mistral3FP8StateDictAdapter.for_vlm_full() so FP8 dequantization runs inside the standard DCP load path. This avoids HF’s FineGrainedFP8 loader, which materializes the full BF16 model on every rank before the pipeline-parallel (PP) split and runs out of memory on 80 GB H100 GPUs.
  • Attaches a one-shot forward pre-hook to every rotary submodule that recomputes inv_freq on its first call. This is needed because HF’s Ministral3 / Pixtral rotaries compute inv_freq in __init__, so meta-device init followed by to_empty leaves the buffer holding uninitialized memory.

Module Contents

Classes

| Name | Description |
| --- | --- |
| Mistral3FP8VLMForConditionalGeneration | Full-VLM (vision + text) FP8 loader for Mistral3ForConditionalGeneration. |

Functions

| Name | Description |
| --- | --- |
| _rotary_reinit_self_hook | One-shot forward pre-hook that recomputes this rotary module’s own inv_freq on first call. |

Data

logger

API

class nemo_automodel.components.models.mistral3_vlm.model.Mistral3FP8VLMForConditionalGeneration(
config: transformers.PretrainedConfig
)

Bases: _HFMistral3ForConditionalGeneration

Full-VLM (vision + text) FP8 loader for Mistral3ForConditionalGeneration.

Used when a user loads an FP8-native Mistral3 VLM checkpoint (e.g. dawn-ridge-128B) through NeMoAutoModelForImageTextToText.from_pretrained.

state_dict_adapter = Mistral3FP8StateDictAdapter.for_vlm_full()

nemo_automodel.components.models.mistral3_vlm.model.Mistral3FP8VLMForConditionalGeneration.forward(
input_ids: typing.Optional[torch.LongTensor] = None,
pixel_values: typing.Optional[torch.FloatTensor] = None,
attention_mask: typing.Optional[torch.Tensor] = None,
position_ids: typing.Optional[torch.LongTensor] = None,
past_key_values = None,
inputs_embeds: typing.Optional[torch.FloatTensor] = None,
labels: typing.Optional[torch.LongTensor] = None,
use_cache: typing.Optional[bool] = None,
logits_to_keep: typing.Union[int, torch.Tensor] = 0,
image_sizes: typing.Optional[torch.Tensor] = None,
output_hidden_states: typing.Optional[bool] = None,
**kwargs
) -> transformers.models.mistral3.modeling_mistral3.Mistral3CausalLMOutputWithPast

Forward pass with memory-efficient fused cross-entropy (cut-CE) support.

Overrides HF’s Mistral3ForConditionalGeneration.forward so the train_ft recipe can enable FusedLinearCrossEntropy. The recipe only does so when (a) forward exposes a logits_to_keep parameter and (b) calling the model returns an output that carries the FINAL hidden states (full sequence) while logits cover only the kept positions.
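The recipe needs full-sequence final hidden states because a fused linear cross-entropy consumes hidden states plus the lm_head weight instead of precomputed logits. That way the full [tokens, vocab] logits tensor never has to exist all at once. The sketch below is a plain chunked stand-in for that idea, not FusedLinearCrossEntropy itself. The function name and chunk size are hypothetical.

```python
import torch
import torch.nn.functional as F


def chunked_linear_cross_entropy(hidden, weight, labels, chunk_size=4, ignore_index=-100):
    """Mean cross-entropy of (hidden @ weight.T) against labels, chunk by chunk.

    hidden: [N, H] final hidden states, flattened over batch and sequence
    weight: [V, H] lm_head weight
    labels: [N] target ids; positions equal to ignore_index are skipped
    """
    total = hidden.new_zeros(())
    count = 0
    for start in range(0, hidden.shape[0], chunk_size):
        h = hidden[start:start + chunk_size]
        y = labels[start:start + chunk_size]
        logits = h @ weight.T  # only a [chunk, V] slice is materialized here
        total = total + F.cross_entropy(
            logits.float(), y, ignore_index=ignore_index, reduction="sum"
        )
        count += int((y != ignore_index).sum())
    return total / max(count, 1)
```

Unlike a real fused kernel, autograd here still keeps each chunk's logits for the backward pass. The sketch therefore illustrates the inputs and the math, not the training-time memory savings.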

HF’s stock forward gates hidden_states on a per-call output_hidden_states kwarg (which the recipe does not pass) and emits the full per-layer tuple. Here we instead resolve output_hidden_states from the text sub-config and surface the inner model’s last_hidden_state (the single [B, S, H] tensor fed to lm_head) directly, which is what get_final_hidden_states consumes.
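A toy model can make the two behaviors concrete. logits_to_keep slices positions before lm_head, while the single final hidden-state tensor comes back for the full sequence, with output_hidden_states falling back to a text sub-config. Everything here is a hypothetical stand-in, not the real Mistral3 classes: ToyCausalLM, the SimpleNamespace output, and placing the tensor in a hidden_states field. The slicing rule follows the int-or-tensor logits_to_keep convention used by HF causal LMs.

```python
from types import SimpleNamespace

import torch
from torch import nn


class ToyCausalLM(nn.Module):
    """HYPOTHETICAL stand-in for the forward override described above."""

    def __init__(self, vocab=11, hidden=8, output_hidden_states=False):
        super().__init__()
        # Stand-in for the text sub-config that supplies the default.
        self.text_config = SimpleNamespace(output_hidden_states=output_hidden_states)
        self.embed = nn.Embedding(vocab, hidden)
        self.body = nn.Linear(hidden, hidden)  # stand-in for the decoder stack
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    def forward(self, input_ids, logits_to_keep=0, output_hidden_states=None):
        if output_hidden_states is None:
            output_hidden_states = self.text_config.output_hidden_states
        last_hidden_state = self.body(self.embed(input_ids))  # [B, S, H]
        # int 0 -> slice(0, None) keeps every position; int N -> last N positions;
        # a 1-D tensor is used directly as sequence indices.
        if isinstance(logits_to_keep, int):
            keep = slice(-logits_to_keep, None)
        else:
            keep = logits_to_keep
        logits = self.lm_head(last_hidden_state[:, keep, :])
        return SimpleNamespace(
            logits=logits,
            # One [B, S, H] tensor, not a per-layer tuple.
            hidden_states=last_hidden_state if output_hidden_states else None,
        )
```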

Parameters:

input_ids
Optional[torch.LongTensor], defaults to None

Input token IDs [B, S].

pixel_values
Optional[torch.FloatTensor], defaults to None

Optional image pixel values for the vision tower.

attention_mask
Optional[torch.Tensor], defaults to None

Optional attention mask.

position_ids
Optional[torch.LongTensor], defaults to None

Optional position indices.

past_key_values
Defaults to None

Optional cached key/values.

inputs_embeds
Optional[torch.FloatTensor], defaults to None

Optional pre-computed embeddings.

labels
Optional[torch.LongTensor], defaults to None

Optional labels for loss computation.

use_cache
Optional[bool], defaults to None

Whether to use KV caching.

logits_to_keep
Union[int, torch.Tensor], defaults to 0

Number of final positions to compute logits for (0 = all positions, N = the last N tokens).

image_sizes
Optional[torch.Tensor], defaults to None

Optional image sizes for the vision tower.

output_hidden_states
Optional[bool], defaults to None

Whether to surface the final hidden states on the output (defaults to the text sub-config’s output_hidden_states).

**kwargs
Defaults to {}

Additional arguments forwarded to the base model.

Returns: Mistral3CausalLMOutputWithPast (transformers.models.mistral3.modeling_mistral3.Mistral3CausalLMOutputWithPast)

nemo_automodel.components.models.mistral3_vlm.model.Mistral3FP8VLMForConditionalGeneration.supports_config(
config: transformers.PretrainedConfig
) -> bool
classmethod

Claim FP8-native Mistral3 VLM configs.

Returns True for a Mistral3Config (outer VLM) that has a ministral3 text backbone and quantization_config.quant_method == 'fp8'.
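The matching rule can be restated as a small predicate. This sketch makes several assumptions: the outer config reports model_type "mistral3", the backbone is identified by text_config.model_type == "ministral3", and quantization_config may be either a dict or an attribute-style object. The real classmethod may check the config class directly. supports_config_sketch and _get are hypothetical names.

```python
from types import SimpleNamespace


def _get(obj, key, default=None):
    """Read `key` from either a dict or an attribute-style object."""
    if isinstance(obj, dict):
        return obj.get(key, default)
    return getattr(obj, key, default)


def supports_config_sketch(config) -> bool:
    """HYPOTHETICAL restatement of the documented matching rule."""
    if _get(config, "model_type") != "mistral3":  # outer VLM config
        return False
    text_config = _get(config, "text_config")
    if text_config is None or _get(text_config, "model_type") != "ministral3":
        return False
    quant = _get(config, "quantization_config")
    return quant is not None and _get(quant, "quant_method") == "fp8"
```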

nemo_automodel.components.models.mistral3_vlm.model._rotary_reinit_self_hook(
module,
args,
kwargs
)

One-shot forward pre-hook that recomputes this rotary module’s own inv_freq on first call.

Attached to each rotary module rather than to the outer VLM so that it fires correctly under pipeline parallelism (PP), where the outer model’s forward is never called directly: the PP schedule dispatches each stage’s sub-modules individually, and rotary modules run inside every attention layer.

Background: HF’s Ministral3 / Pixtral rotary classes initialize inv_freq (and related attributes) in their __init__. Under accelerate.init_empty_weights that buffer becomes a meta tensor, and the subsequent to_empty(device) call leaves it holding uninitialized device memory. Neither class exposes rope_init_fn as an attribute, so the generic _reinit_non_persistent_buffers helper does not match them. We recover correctness by re-running the module’s own __init__ on the target device, outside the init_empty_weights context. This yields the right values for both Ministral3 (YaRN) and Pixtral (2D patch positions), because it defers to each class’s authoritative init logic.
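The mechanism can be reproduced on a toy module. The sketch builds a rotary-like module on the meta device and materializes it with to_empty. It then registers a one-shot pre-hook with with_kwargs=True, which matches the (module, args, kwargs) signature above. On first call, the hook builds a fresh instance on the buffer's device and copies its buffers back. ToyRotary, init_kwargs, the _rotary_reinit_done flag, and attach_rotary_reinit_hooks are hypothetical. "Re-running __init__" is approximated here by constructing a fresh instance, because calling __init__ again on a live nn.Module would also reset its internal registries (buffers, hooks).

```python
import torch
from torch import nn


class ToyRotary(nn.Module):
    """HYPOTHETICAL rotary module that, like the HF classes described above,
    computes inv_freq in __init__ and stores it as a non-persistent buffer."""

    def __init__(self, dim=8, base=10000.0, device=None):
        super().__init__()
        # Remember constructor args so the hook can rebuild an identical instance.
        self.init_kwargs = {"dim": dim, "base": base}
        steps = torch.arange(0, dim, 2, dtype=torch.float32, device=device)
        self.register_buffer("inv_freq", 1.0 / (base ** (steps / dim)), persistent=False)

    def forward(self, x):
        return x * self.inv_freq


def rotary_reinit_hook(module, args, kwargs):
    """One-shot pre-hook: recompute buffers with the class's own init logic."""
    if getattr(module, "_rotary_reinit_done", False):
        return None  # already fixed; leave args/kwargs unchanged
    device = module.inv_freq.device
    # Build a fresh instance on the real device (outside any meta-device context)
    # and copy its correctly computed buffers into the live module.
    fresh = type(module)(**module.init_kwargs, device=device)
    for name, buf in fresh.named_buffers(recurse=False):
        getattr(module, name).copy_(buf)
    module._rotary_reinit_done = True
    return None


def attach_rotary_reinit_hooks(model):
    """Register the hook on every rotary submodule, not on the outer model."""
    for sub in model.modules():
        if isinstance(sub, ToyRotary):
            sub.register_forward_pre_hook(rotary_reinit_hook, with_kwargs=True)
```

Because the hook lives on the rotary module itself, it fires whenever that module runs. This is what lets it work when a pipeline schedule calls stage sub-modules directly.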

nemo_automodel.components.models.mistral3_vlm.model.logger = logging.getLogger(__name__)