nemo_automodel.components.models.mistral3.model

View as Markdown

Module Contents

Classes

Functions

| Name | Description |
| --- | --- |
| `_get_llama_4_attn_scale` | - |
| `_register_ministral3_with_transformers` | Register Ministral3Config and its models with the transformers Auto classes. |
| `apply_rotary_pos_emb` | - |
| `eager_attention_forward` | - |
| `repeat_kv` | - |
| `rotate_half` | - |

Data

ModelClass

logger

API

class nemo_automodel.components.models.mistral3.model.GradientCheckpointingLayer()

Bases: Module

nemo_automodel.components.models.mistral3.model.GradientCheckpointingLayer.forward(
*args,
**kwargs
)
class nemo_automodel.components.models.mistral3.model.Ministral3Attention(
config: nemo_automodel.components.models.mistral3.model.Ministral3Config,
layer_idx: int
)

Bases: Module

attention_dropout
= config.attention_dropout
head_dim
k_proj
num_key_value_groups
o_proj
q_proj
scaling
= self.head_dim ** -0.5
v_proj
nemo_automodel.components.models.mistral3.model.Ministral3Attention.forward(
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: typing.Optional[torch.Tensor],
past_key_values: typing.Optional[transformers.cache_utils.Cache] = None,
cache_position: typing.Optional[torch.LongTensor] = None,
**kwargs: transformers.processing_utils.Unpack[transformers.modeling_flash_attention_utils.FlashAttentionKwargs]
) -> tuple[torch.Tensor, typing.Optional[torch.Tensor]]
class nemo_automodel.components.models.mistral3.model.Ministral3CausalLMOutputWithPast()
Dataclass

Bases: CausalLMOutputWithPast

class nemo_automodel.components.models.mistral3.model.Ministral3Config(
vocab_size: typing.Optional[int] = 131072,
hidden_size: typing.Optional[int] = 4096,
intermediate_size: typing.Optional[int] = 14336,
num_hidden_layers: typing.Optional[int] = 34,
num_attention_heads: typing.Optional[int] = 32,
num_key_value_heads: typing.Optional[int] = 8,
head_dim: typing.Optional[int] = 128,
hidden_act: typing.Optional[str] = 'silu',
max_position_embeddings: typing.Optional[int] = 262144,
initializer_range: typing.Optional[float] = 0.02,
rms_norm_eps: typing.Optional[float] = 1e-05,
use_cache: typing.Optional[bool] = True,
pad_token_id: typing.Optional[int] = 11,
bos_token_id: typing.Optional[int] = 1,
eos_token_id: typing.Optional[int] = 2,
tie_word_embeddings: typing.Optional[bool] = False,
rope_parameters: typing.Optional[dict] = None,
sliding_window: typing.Optional[int] = None,
attention_dropout: typing.Optional[float] = 0.0,
**kwargs
)

Bases: PretrainedConfig

Configuration for the Ministral3 text decoder.

base_model_pp_plan
base_model_tp_plan
head_dim
keys_to_ignore_at_inference
= ['past_key_values']
model_type
= 'ministral3'
rope_scaling
rope_theta
= self.rope_parameters.get('rope_theta', 1000000.0)
class nemo_automodel.components.models.mistral3.model.Ministral3DecoderLayer(
config: nemo_automodel.components.models.mistral3.model.Ministral3Config,
layer_idx: int
)

Bases: GradientCheckpointingLayer

hidden_size
= config.hidden_size
input_layernorm
mlp
= Ministral3MLP(config)
post_attention_layernorm
self_attn
nemo_automodel.components.models.mistral3.model.Ministral3DecoderLayer.forward(
hidden_states: torch.Tensor,
attention_mask: typing.Optional[torch.Tensor] = None,
position_ids: typing.Optional[torch.LongTensor] = None,
past_key_values: typing.Optional[transformers.cache_utils.Cache] = None,
use_cache: typing.Optional[bool] = False,
cache_position: typing.Optional[torch.LongTensor] = None,
position_embeddings: typing.Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs]
) -> torch.Tensor
class nemo_automodel.components.models.mistral3.model.Ministral3ForCausalLM(
config: nemo_automodel.components.models.mistral3.model.Ministral3Config
)

Bases: HFCheckpointingMixin, Ministral3PreTrainedModel, GenerationMixin

_pp_plan
= {'lm_head': (['hidden_states'], ['logits'])}
_tied_weights_keys
= {'lm_head.weight': 'model.embed_tokens.weight'}
_tp_plan
= {'lm_head': 'colwise_rep'}
lm_head
model
= Ministral3Model(config)
vocab_size
= config.vocab_size
nemo_automodel.components.models.mistral3.model.Ministral3ForCausalLM.forward(
input_ids: typing.Optional[torch.LongTensor] = None,
attention_mask: typing.Optional[torch.Tensor] = None,
position_ids: typing.Optional[torch.LongTensor] = None,
past_key_values: typing.Optional[transformers.cache_utils.Cache] = None,
inputs_embeds: typing.Optional[torch.FloatTensor] = None,
labels: typing.Optional[torch.LongTensor] = None,
use_cache: typing.Optional[bool] = None,
cache_position: typing.Optional[torch.LongTensor] = None,
logits_to_keep: typing.Union[int, torch.Tensor] = 0,
output_hidden_states: typing.Optional[bool] = None,
**kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs]
) -> transformers.modeling_outputs.CausalLMOutputWithPast
nemo_automodel.components.models.mistral3.model.Ministral3ForCausalLM.get_input_embeddings()
nemo_automodel.components.models.mistral3.model.Ministral3ForCausalLM.get_output_embeddings()
nemo_automodel.components.models.mistral3.model.Ministral3ForCausalLM.set_input_embeddings(
value
)
nemo_automodel.components.models.mistral3.model.Ministral3ForCausalLM.set_output_embeddings(
new_embeddings
)
class nemo_automodel.components.models.mistral3.model.Ministral3MLP(
config: nemo_automodel.components.models.mistral3.model.Ministral3Config
)

Bases: Module

act_fn
= ACT2FN[config.hidden_act]
down_proj
gate_proj
hidden_size
= config.hidden_size
intermediate_size
= config.intermediate_size
up_proj
nemo_automodel.components.models.mistral3.model.Ministral3MLP.forward(
x
)
class nemo_automodel.components.models.mistral3.model.Ministral3Model(
config: nemo_automodel.components.models.mistral3.model.Ministral3Config
)

Bases: Ministral3PreTrainedModel

embed_tokens
layers
norm
padding_idx
= config.pad_token_id
rotary_emb
= Ministral3RotaryEmbedding(config=config)
vocab_size
= config.vocab_size
nemo_automodel.components.models.mistral3.model.Ministral3Model.forward(
input_ids: typing.Optional[torch.LongTensor] = None,
attention_mask: typing.Optional[torch.Tensor] = None,
position_ids: typing.Optional[torch.LongTensor] = None,
past_key_values: typing.Optional[transformers.cache_utils.Cache] = None,
inputs_embeds: typing.Optional[torch.FloatTensor] = None,
use_cache: typing.Optional[bool] = None,
cache_position: typing.Optional[torch.LongTensor] = None,
**kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs]
) -> transformers.modeling_outputs.BaseModelOutputWithPast
class nemo_automodel.components.models.mistral3.model.Ministral3ModelOutputWithPast(
image_hidden_states: typing.Optional[torch.FloatTensor] = None
)
Dataclass

Bases: BaseModelOutputWithPast

image_hidden_states
Optional[FloatTensor] = None
class nemo_automodel.components.models.mistral3.model.Ministral3PreTrainedModel()

Bases: PreTrainedModel

_can_record_outputs
= {}
_no_split_modules
= ['Ministral3DecoderLayer']
_skip_keys_device_placement
= ['past_key_values']
base_model_prefix
= 'model'
config
Ministral3Config
class nemo_automodel.components.models.mistral3.model.Ministral3RMSNorm(
hidden_size,
eps = 1e-06
)

Bases: Module

weight
= nn.Parameter(torch.ones(hidden_size))
nemo_automodel.components.models.mistral3.model.Ministral3RMSNorm.forward(
hidden_states
)
class nemo_automodel.components.models.mistral3.model.Ministral3RotaryEmbedding(
config: nemo_automodel.components.models.mistral3.model.Ministral3Config,
device = None
)

Bases: Module

inv_freq
Tensor
max_seq_len_cached
= config.max_position_embeddings
original_max_seq_len
= config.max_position_embeddings
rope_type
nemo_automodel.components.models.mistral3.model.Ministral3RotaryEmbedding.compute_default_rope_parameters(
config: typing.Optional[nemo_automodel.components.models.mistral3.model.Ministral3Config] = None,
device: typing.Optional[torch.device] = None,
seq_len: typing.Optional[int] = None
) -> tuple[torch.Tensor, float]
staticmethod
nemo_automodel.components.models.mistral3.model.Ministral3RotaryEmbedding.forward(
x,
position_ids
)
nemo_automodel.components.models.mistral3.model._get_llama_4_attn_scale(
positions_ids: torch.Tensor,
beta: float,
max_position_embeddings: int
) -> torch.Tensor
nemo_automodel.components.models.mistral3.model._register_ministral3_with_transformers()

Register Ministral3Config and its models with the transformers Auto classes.

This uses the official transformers registration API. Registration is idempotent (re-registering the same config/model is a no-op in recent transformers versions).

nemo_automodel.components.models.mistral3.model.apply_rotary_pos_emb(
q,
k,
cos,
sin,
position_ids = None,
unsqueeze_dim = 1
)
nemo_automodel.components.models.mistral3.model.eager_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: typing.Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs]
)
nemo_automodel.components.models.mistral3.model.repeat_kv(
hidden_states: torch.Tensor,
n_rep: int
) -> torch.Tensor
nemo_automodel.components.models.mistral3.model.rotate_half(
x
)
nemo_automodel.components.models.mistral3.model.ModelClass = Ministral3ForCausalLM
nemo_automodel.components.models.mistral3.model.logger = logging.get_logger(__name__)