nemo_automodel.components.models.mistral4.model


Module Contents

Classes

| Name | Description |
| --- | --- |
| Mistral3ForConditionalGeneration | Full multimodal Mistral 4: Pixtral vision + projector + Mistral4 MLA/MoE text backbone. |
| Mistral3Model | VLM wrapper composing vision tower + projector + Mistral4 text backend. |
| Mistral4Block | Block using Mistral4MLA instead of MLA. |
| Mistral4ForCausalLM | - |
| Mistral4MLA | MLA with Llama 4 attention scaling for Mistral 4. |
| Mistral4Model | - |
| Mistral4TextModelBackend | Backend-aware Mistral4 text model for use inside the multimodal wrapper. |

Functions

| Name | Description |
| --- | --- |
| _build_moe_config | Build MoEConfig from a Mistral4 text config. |
| _get_llama_4_attn_scale | Position-dependent attention scaling for long-context extrapolation (Llama 4 / Mistral 4). |

Data

ModelClass

_HF_MISTRAL3_AVAILABLE

API

class nemo_automodel.components.models.mistral4.model.Mistral3ForConditionalGeneration(
config,
moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
**kwargs
)

Bases: HFCheckpointingMixin, Module, MoEFSDPSyncMixin

Full multimodal Mistral 4: Pixtral vision + projector + Mistral4 MLA/MoE text backbone.

Follows the KimiK25VLForConditionalGeneration pattern: it inherits from nn.Module rather than HF PreTrainedModel to avoid FSDP conflicts.

image_token_index = getattr(config, 'image_token_index', 10)
model
moe_config = self.model.language_model.moe_config
pad_token_id = getattr(text_config, 'pad_token_id', -1) or -1
state_dict_adapter
vocab_size = text_config.vocab_size

nemo_automodel.components.models.mistral4.model.Mistral3ForConditionalGeneration.forward(
input_ids: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
pixel_values: torch.Tensor | None = None,
image_sizes: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: typing.Any
) -> torch.Tensor
nemo_automodel.components.models.mistral4.model.Mistral3ForConditionalGeneration.from_config(
config,
moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
**kwargs
)
classmethod
nemo_automodel.components.models.mistral4.model.Mistral3ForConditionalGeneration.from_pretrained(
pretrained_model_name_or_path: str,
*model_args,
**kwargs
)
classmethod
nemo_automodel.components.models.mistral4.model.Mistral3ForConditionalGeneration.get_input_embeddings()
nemo_automodel.components.models.mistral4.model.Mistral3ForConditionalGeneration.get_output_embeddings()
nemo_automodel.components.models.mistral4.model.Mistral3ForConditionalGeneration.initialize_weights(
buffer_device: torch.device | None = None,
dtype: torch.dtype = torch.bfloat16
) -> None
nemo_automodel.components.models.mistral4.model.Mistral3ForConditionalGeneration.set_input_embeddings(
value
)
nemo_automodel.components.models.mistral4.model.Mistral3ForConditionalGeneration.set_output_embeddings(
new_embeddings
)
nemo_automodel.components.models.mistral4.model.Mistral3ForConditionalGeneration.supports_config(
config
) -> bool
classmethod

Returns True only for configs whose text backbone is Mistral4 (MoE + MLA).
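A hypothetical sketch of the kind of predicate supports_config implements. It assumes the multimodal config exposes a `text_config` whose `model_type` is `"mistral4"` for a Mistral4 backbone; both attribute names and the string value are assumptions, and the real method may instead inspect MoE or MLA fields directly.

```python
from types import SimpleNamespace


def supports_config_sketch(config) -> bool:
    """Hypothetical check: accept only configs with a Mistral4 text backbone.

    Assumes `config.text_config.model_type == "mistral4"` identifies such a
    backbone; the attribute names and value are illustrative assumptions.
    """
    text_config = getattr(config, "text_config", None)
    return getattr(text_config, "model_type", None) == "mistral4"


mistral4_vlm = SimpleNamespace(text_config=SimpleNamespace(model_type="mistral4"))
dense_vlm = SimpleNamespace(text_config=SimpleNamespace(model_type="mistral"))
print(supports_config_sketch(mistral4_vlm))  # True
print(supports_config_sketch(dense_vlm))     # False
```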

nemo_automodel.components.models.mistral4.model.Mistral3ForConditionalGeneration.update_moe_gate_bias() -> None
class nemo_automodel.components.models.mistral4.model.Mistral3Model(
config,
vision_tower,
multi_modal_projector,
language_model
)

Bases: Module

VLM wrapper composing vision tower + projector + Mistral4 text backend.

Follows the KimiK25VLModel pattern: it is a plain nn.Module (not an HF PreTrainedModel) to avoid FSDP conflicts caused by PreTrainedModel's module registration hooks. The vision processing logic is replicated from HF Mistral3Model.
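Hugging Face Mistral3-style VLMs typically merge vision features by scattering the projector output into the text embedding sequence at the image placeholder positions. The sketch below assumes that approach, with the placeholder id defaulting to 10 to match the default of Mistral3ForConditionalGeneration.image_token_index; `merge_image_features` is a hypothetical helper, not part of this module.

```python
import torch


def merge_image_features(input_ids, inputs_embeds, image_features, image_token_index=10):
    """Scatter projected image features into the text embedding sequence.

    input_ids:      [batch, seq] token ids
    inputs_embeds:  [batch, seq, hidden] text embeddings
    image_features: [num_image_tokens, hidden] projector output, in sequence order
    """
    is_image = input_ids == image_token_index
    n_slots = int(is_image.sum())
    if n_slots != image_features.shape[0]:
        raise ValueError(f"{n_slots} image tokens but {image_features.shape[0]} feature rows")
    # Broadcast the placeholder mask over the hidden dimension.
    mask = is_image.unsqueeze(-1).expand_as(inputs_embeds)
    # masked_scatter fills the masked slots in row-major order from image_features.
    return inputs_embeds.masked_scatter(mask, image_features.to(inputs_embeds.dtype))


ids = torch.tensor([[1, 10, 10, 2]])  # two image placeholder tokens
embeds = torch.zeros(1, 4, 3)
feats = torch.ones(2, 3)
out = merge_image_features(ids, embeds, feats)
print(out[0, :, 0])  # tensor([0., 1., 1., 0.])
```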

nemo_automodel.components.models.mistral4.model.Mistral3Model._get_image_features(
pixel_values,
image_sizes,
vision_feature_layer = -1
)

Encodes images through the vision tower and projector (adapted from HF Mistral3Model).

nemo_automodel.components.models.mistral4.model.Mistral3Model.forward(
input_ids = None,
pixel_values = None,
attention_mask = None,
position_ids = None,
past_key_values = None,
inputs_embeds = None,
image_sizes = None,
padding_mask = None,
**kwargs
)
nemo_automodel.components.models.mistral4.model.Mistral3Model.get_input_embeddings()
class nemo_automodel.components.models.mistral4.model.Mistral4Block(
layer_idx,
config,
moe_config,
backend
)

Bases: Block

Block using Mistral4MLA instead of MLA.

self_attn = Mistral4MLA(config, backend)

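One plausible way to express "same block, different attention" is to subclass the base block and replace only its `self_attn` submodule. The toy classes below (`ToyMLA`, `ToyMistral4MLA`, `ToyBlock`, `ToyMistral4Block`) are hypothetical stand-ins; the real Block constructor also takes `layer_idx`, `moe_config`, and `backend`.

```python
import torch
from torch import nn


class ToyMLA(nn.Module):
    """Hypothetical stand-in for the base MLA attention module."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)


class ToyMistral4MLA(ToyMLA):
    """Hypothetical stand-in; the real class adds post-RoPE q_pe scaling."""


class ToyBlock(nn.Module):
    """Hypothetical stand-in for the shared decoder Block."""

    def __init__(self, dim: int):
        super().__init__()
        self.self_attn = ToyMLA(dim)
        self.mlp = nn.Linear(dim, dim)

    def forward(self, x):
        x = x + self.self_attn(x)
        return x + self.mlp(x)


class ToyMistral4Block(ToyBlock):
    def __init__(self, dim: int):
        super().__init__(dim)
        # Swap only the attention submodule; residual wiring and MLP are inherited.
        self.self_attn = ToyMistral4MLA(dim)


block = ToyMistral4Block(8)
print(type(block.self_attn).__name__)  # ToyMistral4MLA
print(block(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```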
class nemo_automodel.components.models.mistral4.model.Mistral4ForCausalLM(
config,
moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
**kwargs
)

Bases: HFCheckpointingMixin, Module, MoEFSDPSyncMixin

backend = backend or BackendConfig()
lm_head
model
state_dict_adapter

nemo_automodel.components.models.mistral4.model.Mistral4ForCausalLM.forward(
input_ids: torch.Tensor,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
logits_to_keep: typing.Union[int, torch.Tensor] = 0,
output_hidden_states: typing.Optional[bool] = None,
**attn_kwargs: typing.Any
) -> transformers.modeling_outputs.CausalLMOutputWithPast
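`logits_to_keep` appears to follow the Hugging Face convention: an int N > 0 computes logits only for the last N positions, 0 keeps all of them, and a tensor selects explicit positions. The sketch below assumes that convention; `compute_logits` is a hypothetical helper.

```python
import torch
from torch import nn


def compute_logits(hidden_states, lm_head, logits_to_keep=0):
    """Project hidden states to logits, keeping only the requested positions.

    int N > 0: last N positions; int 0: all positions, since slice(-0, None)
    equals slice(0, None); 1-D tensor: those explicit positions.
    """
    if isinstance(logits_to_keep, int):
        positions = slice(-logits_to_keep, None)
    else:
        positions = logits_to_keep
    return lm_head(hidden_states[:, positions, :])


lm_head = nn.Linear(16, 100, bias=False)
h = torch.randn(2, 5, 16)
print(compute_logits(h, lm_head).shape)                        # torch.Size([2, 5, 100])
print(compute_logits(h, lm_head, 1).shape)                     # torch.Size([2, 1, 100])
print(compute_logits(h, lm_head, torch.tensor([0, 3])).shape)  # torch.Size([2, 2, 100])
```

Keeping only the last position during generation avoids materializing a full `[batch, seq, vocab]` logits tensor.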
nemo_automodel.components.models.mistral4.model.Mistral4ForCausalLM.from_config(
config,
moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
**kwargs
)
classmethod
nemo_automodel.components.models.mistral4.model.Mistral4ForCausalLM.from_pretrained(
pretrained_model_name_or_path: str,
*model_args,
**kwargs
)
classmethod
nemo_automodel.components.models.mistral4.model.Mistral4ForCausalLM.get_input_embeddings()
nemo_automodel.components.models.mistral4.model.Mistral4ForCausalLM.get_output_embeddings()
nemo_automodel.components.models.mistral4.model.Mistral4ForCausalLM.initialize_weights(
buffer_device: torch.device | None = None,
dtype: torch.dtype = torch.bfloat16
) -> None
nemo_automodel.components.models.mistral4.model.Mistral4ForCausalLM.set_input_embeddings(
value
)
nemo_automodel.components.models.mistral4.model.Mistral4ForCausalLM.set_output_embeddings(
new_embeddings
)
nemo_automodel.components.models.mistral4.model.Mistral4ForCausalLM.update_moe_gate_bias() -> None
class nemo_automodel.components.models.mistral4.model.Mistral4MLA(
config,
backend: nemo_automodel.components.models.common.BackendConfig
)

Bases: MLA

MLA with Llama 4 attention scaling for Mistral 4.

Compared to DeepSeek V3 (DSV3) MLA, it adds position-dependent scaling to q_pe after RoPE, controlled by llama_4_scaling_beta. RoPE itself uses the same complex-number approach as DSV3.
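A minimal sketch of the post-RoPE step, assuming DSV3-style RoPE on interleaved channel pairs, `q_pe` laid out as `[batch, seq, heads, rope_dim]`, and a per-position scale of shape `[batch, seq, 1]`. The helper names and the plain (unscaled) rotary frequencies are assumptions; the real model may apply additional frequency scaling, and the scale values here are placeholders.

```python
import torch


def precompute_freqs_cis(rope_dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    """Unit complex rotations e^{i * pos * theta_k}, shape [seq_len, rope_dim // 2]."""
    inv_freq = 1.0 / (theta ** (torch.arange(0, rope_dim, 2).float() / rope_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)
    return torch.polar(torch.ones_like(angles), angles)


def apply_complex_rope(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate x of shape [batch, seq, heads, rope_dim] by freqs_cis [seq, rope_dim // 2]."""
    # Treat adjacent channels (x0, x1) as the complex number x0 + i*x1.
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # Broadcast the per-position rotation over batch and heads.
    rotated = x_c * freqs_cis[None, :, None, :]
    return torch.view_as_real(rotated).flatten(-2).to(x.dtype)


def scale_q_pe(q_pe: torch.Tensor, attn_scale: torch.Tensor) -> torch.Tensor:
    """Apply a per-position scale of shape [batch, seq, 1] to rotated q_pe."""
    # [batch, seq, 1] -> [batch, seq, 1, 1] so it broadcasts over heads and rope_dim.
    return q_pe * attn_scale.unsqueeze(2)


batch, seq, heads, rope_dim = 2, 6, 4, 8
q_pe = torch.randn(batch, seq, heads, rope_dim)
q_rot = apply_complex_rope(q_pe, precompute_freqs_cis(rope_dim, seq))
attn_scale = torch.full((batch, seq, 1), 1.5)  # placeholder values, not real scales
q_out = scale_q_pe(q_rot, attn_scale)
print(q_out.shape)  # torch.Size([2, 6, 4, 8])
```

Because the scale is applied after the rotation, it changes the magnitude of q_pe but not the relative-position phase that RoPE encodes.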

llama_4_orig_max_pos
llama_4_scaling_beta
nemo_automodel.components.models.mistral4.model.Mistral4MLA.forward(
x: torch.Tensor,
freqs_cis: torch.Tensor,
attention_mask: torch.Tensor | None = None,
**attn_kwargs: typing.Any
)
class nemo_automodel.components.models.mistral4.model.Mistral4Model(
config,
backend: nemo_automodel.components.models.common.BackendConfig,
moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
moe_overrides: dict | None = None
)

Bases: Module

embed_tokens
layers = torch.nn.ModuleDict()
max_seq_len = config.max_position_embeddings
moe_config
norm

nemo_automodel.components.models.mistral4.model.Mistral4Model.forward(
input_ids: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
**attn_kwargs: typing.Any
) -> torch.Tensor
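The forward pass is expected to run token embeddings (or caller-supplied `inputs_embeds`) through each entry of `layers` and then `norm`. The toy module below sketches that flow under simplifying assumptions: each layer is a residual linear map, LayerNorm stands in for the real norm, and rotary embeddings, masks, and MoE routing are omitted. Keying the `ModuleDict` by string layer index (an assumption here) lets layers be removed, for example when splitting a model into pipeline stages, without renumbering the rest. `ToyDecoderStack` is hypothetical.

```python
import torch
from torch import nn


class ToyDecoderStack(nn.Module):
    """Hypothetical toy with the embed_tokens -> layers -> norm layout."""

    def __init__(self, vocab_size: int, dim: int, n_layers: int):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, dim)
        # String keys: deleting a layer leaves the other indices unchanged.
        self.layers = nn.ModuleDict({str(i): nn.Linear(dim, dim) for i in range(n_layers)})
        self.norm = nn.LayerNorm(dim)  # stand-in for the real norm layer

    def forward(self, input_ids=None, inputs_embeds=None):
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("pass exactly one of input_ids or inputs_embeds")
        h = self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds
        for layer in self.layers.values():
            h = h + layer(h)  # residual update per layer
        return self.norm(h)


model = ToyDecoderStack(vocab_size=32, dim=8, n_layers=3)
hidden = model(input_ids=torch.randint(0, 32, (2, 5)))
print(hidden.shape)  # torch.Size([2, 5, 8])

del model.layers["1"]  # e.g. a stage that does not own layer 1
print(list(model.layers.keys()))  # ['0', '2']
```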
nemo_automodel.components.models.mistral4.model.Mistral4Model.init_weights(
buffer_device: torch.device | None = None
) -> None
nemo_automodel.components.models.mistral4.model.Mistral4Model.update_moe_gate_bias() -> None
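`update_moe_gate_bias` is not documented here. In DeepSeek-V3-style MoE, auxiliary-loss-free load balancing nudges a per-expert routing bias toward balance after each step. Assuming this method does something similar, one update can be sketched as below; `update_gate_bias` and `update_rate` are hypothetical, and a distributed implementation may need to aggregate token counts across ranks first.

```python
import torch


def update_gate_bias(bias: torch.Tensor, tokens_per_expert: torch.Tensor,
                     update_rate: float = 1e-3) -> torch.Tensor:
    """One auxiliary-loss-free balancing step in the DeepSeek-V3 style.

    bias:              [n_experts] bias added to router scores for top-k selection
    tokens_per_expert: [n_experts] tokens routed to each expert in the last step
    """
    load = tokens_per_expert.float()
    # Below-average experts get +update_rate, above-average get -update_rate,
    # and exactly-average experts are unchanged (sign(0) == 0).
    return bias + update_rate * torch.sign(load.mean() - load)


bias = update_gate_bias(torch.zeros(4), torch.tensor([10, 2, 4, 4]))
print(bias)
```

In DeepSeek-V3 the bias only influences which experts are selected; the gating weights that scale expert outputs are computed from the unbiased scores.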
class nemo_automodel.components.models.mistral4.model.Mistral4TextModelBackend(
config,
backend: nemo_automodel.components.models.common.BackendConfig,
moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
moe_overrides: dict | None = None
)

Bases: Module

Backend-aware Mistral4 text model for use inside the multimodal wrapper.

Wraps Mistral4Model in self.model, in the same way KimiK25VLLanguageModelBackend wraps DeepseekV3Model. As a result, embed_tokens, layers, and norm are accessed through @property aliases rather than as direct nn.Module children, which avoids FSDP double root initialization when the parallelizer wraps both embed_tokens and this module.
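The property-alias idea can be seen with toy modules. `ToyTextModel` and `ToyTextBackend` are hypothetical stand-ins for Mistral4Model and Mistral4TextModelBackend: the backend registers the inner model as a child alongside `lm_head`, and exposes `embed_tokens`, `layers`, and `norm` as read-only properties, so each submodule appears under exactly one path in the module tree. The exact aliases in NeMo Automodel may differ.

```python
from torch import nn


class ToyTextModel(nn.Module):
    """Hypothetical stand-in for Mistral4Model."""

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(32, 8)
        self.layers = nn.ModuleDict({"0": nn.Linear(8, 8)})
        self.norm = nn.LayerNorm(8)


class ToyTextBackend(nn.Module):
    """Hypothetical stand-in for Mistral4TextModelBackend."""

    def __init__(self):
        super().__init__()
        self.model = ToyTextModel()
        self.lm_head = nn.Linear(8, 32, bias=False)

    # Read-only aliases: attribute access works, but nothing new is registered
    # as a child module, so no submodule gets a second parent.
    @property
    def embed_tokens(self):
        return self.model.embed_tokens

    @property
    def layers(self):
        return self.model.layers

    @property
    def norm(self):
        return self.model.norm


backend = ToyTextBackend()
print([name for name, _ in backend.named_children()])  # ['model', 'lm_head']
print(backend.embed_tokens is backend.model.embed_tokens)  # True
```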

lm_head
model
moe_config = self.model.moe_config

nemo_automodel.components.models.mistral4.model.Mistral4TextModelBackend.forward(
input_ids: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
past_key_values = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
cache_position: torch.Tensor | None = None,
**kwargs: typing.Any
) -> transformers.modeling_outputs.BaseModelOutputWithPast
nemo_automodel.components.models.mistral4.model.Mistral4TextModelBackend.get_input_embeddings()
nemo_automodel.components.models.mistral4.model.Mistral4TextModelBackend.init_weights(
buffer_device: torch.device | None = None
)
nemo_automodel.components.models.mistral4.model.Mistral4TextModelBackend.set_input_embeddings(
value
)
nemo_automodel.components.models.mistral4.model._build_moe_config(
config,
moe_overrides: dict | None = None
) -> nemo_automodel.components.moe.config.MoEConfig

Build MoEConfig from a Mistral4 text config.
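The `moe_overrides` argument suggests a two-step build: map text-config fields onto MoE settings, then apply caller overrides. The sketch below assumes that flow using a hypothetical `MoEConfigSketch` dataclass; its field names and the DeepSeek-style text-config attributes it reads are illustrative assumptions, not the real MoEConfig schema, and the sample values are toy numbers.

```python
from dataclasses import dataclass, fields, replace
from types import SimpleNamespace


@dataclass(frozen=True)
class MoEConfigSketch:
    """Hypothetical stand-in for MoEConfig; field names are illustrative."""
    dim: int
    n_routed_experts: int
    n_shared_experts: int
    n_activated_experts: int
    moe_inter_dim: int


def build_moe_config_sketch(text_config, moe_overrides=None) -> MoEConfigSketch:
    """Map (assumed) DeepSeek-style text-config fields, then apply overrides."""
    base = MoEConfigSketch(
        dim=text_config.hidden_size,
        n_routed_experts=text_config.n_routed_experts,
        n_shared_experts=text_config.n_shared_experts,
        n_activated_experts=text_config.num_experts_per_tok,
        moe_inter_dim=text_config.moe_intermediate_size,
    )
    if not moe_overrides:
        return base
    # Reject typos instead of silently ignoring them.
    unknown = set(moe_overrides) - {f.name for f in fields(MoEConfigSketch)}
    if unknown:
        raise KeyError(f"unknown MoE override(s): {sorted(unknown)}")
    return replace(base, **moe_overrides)


toy_text_config = SimpleNamespace(  # toy values, not Mistral 4's real sizes
    hidden_size=256, n_routed_experts=8, n_shared_experts=1,
    num_experts_per_tok=2, moe_intermediate_size=128,
)
moe = build_moe_config_sketch(toy_text_config, {"n_activated_experts": 4})
print(moe.n_activated_experts, moe.n_routed_experts)  # 4 8
```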

nemo_automodel.components.models.mistral4.model._get_llama_4_attn_scale(
position_ids: torch.Tensor,
beta: float,
max_position_embeddings: int
) -> torch.Tensor

Position-dependent attention scaling for long-context extrapolation (Llama 4 / Mistral 4).
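A minimal sketch, assuming the scale has the form 1 + beta * log(1 + floor(pos / max_position_embeddings)) and is returned with a trailing singleton dimension for broadcasting. Under this assumption the scale is exactly 1 inside the original context window and grows logarithmically with each additional multiple of it. `llama4_attn_scale` is a hypothetical name, and beta = 0.1 with a 4096-token window are toy values.

```python
import torch


def llama4_attn_scale(position_ids: torch.Tensor, beta: float,
                      max_position_embeddings: int) -> torch.Tensor:
    """Assumed form: 1 + beta * log(1 + floor(pos / max_position_embeddings)).

    position_ids: [batch, seq] integer positions
    returns:      [batch, seq, 1] so it broadcasts over the feature dimension
    """
    blocks = torch.floor(position_ids.float() / max_position_embeddings)
    return (1.0 + beta * torch.log1p(blocks)).unsqueeze(-1)


pos = torch.tensor([[0, 4095, 4096, 8192, 12288]])
scale = llama4_attn_scale(pos, beta=0.1, max_position_embeddings=4096)
print(scale.squeeze(-1))
```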

nemo_automodel.components.models.mistral4.model.ModelClass = Mistral4ForCausalLM
nemo_automodel.components.models.mistral4.model._HF_MISTRAL3_AVAILABLE = True
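The `_HF_MISTRAL3_AVAILABLE` flag appears to record whether the installed transformers provides Mistral3 support. A hypothetical equivalent check is sketched below; the real module may use a guarded `try`/`except ImportError` import instead, and `hf_module_available` is not part of this module.

```python
import importlib.util


def hf_module_available(module_name: str) -> bool:
    """True if `module_name` can be located; parent packages are imported, the module itself is not."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # ModuleNotFoundError (an ImportError) is raised when a parent package is missing.
        return False


HF_MISTRAL3_AVAILABLE_SKETCH = hf_module_available("transformers.models.mistral3")
print(HF_MISTRAL3_AVAILABLE_SKETCH)
```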