nemo_automodel.components.models.mistral3.model
nemo_automodel.components.models.mistral3.model
Module Contents
Classes
| Name | Description |
|---|---|
| GradientCheckpointingLayer | - |
| Ministral3Attention | - |
| Ministral3CausalLMOutputWithPast | - |
| Ministral3Config | Configuration for Ministral3 text decoder. |
| Ministral3DecoderLayer | - |
| Ministral3ForCausalLM | - |
| Ministral3MLP | - |
| Ministral3Model | - |
| Ministral3ModelOutputWithPast | - |
| Ministral3PreTrainedModel | - |
| Ministral3RMSNorm | - |
| Ministral3RotaryEmbedding | - |
Functions
| Name | Description |
|---|---|
| _get_llama_4_attn_scale | - |
| _register_ministral3_with_transformers | Register Ministral3Config and models with transformers Auto classes. |
| apply_rotary_pos_emb | - |
| eager_attention_forward | - |
| repeat_kv | - |
| rotate_half | - |
Data
API
class nemo_automodel.components.models.mistral3.model.GradientCheckpointingLayer()
Bases: Module
nemo_automodel.components.models.mistral3.model.GradientCheckpointingLayer.forward( args = (), kwargs = {} )
class nemo_automodel.components.models.mistral3.model.Ministral3Attention( config: nemo_automodel.components.models.mistral3.model.Ministral3Config, layer_idx: int )
Bases: Module
attention_dropout
= config.attention_dropout
head_dim
k_proj
num_key_value_groups
o_proj
q_proj
scaling
= self.head_dim ** -0.5
v_proj
nemo_automodel.components.models.mistral3.model.Ministral3Attention.forward( hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: typing.Optional[torch.Tensor], past_key_values: typing.Optional[transformers.cache_utils.Cache] = None, cache_position: typing.Optional[torch.LongTensor] = None, kwargs: transformers.processing_utils.Unpack[transformers.modeling_flash_attention_utils.FlashAttentionKwargs] = {} ) -> tuple[torch.Tensor, typing.Optional[torch.Tensor]]
class nemo_automodel.components.models.mistral3.model.Ministral3CausalLMOutputWithPast()
Dataclass
Bases: CausalLMOutputWithPast
class nemo_automodel.components.models.mistral3.model.Ministral3Config( vocab_size: typing.Optional[int] = 131072, hidden_size: typing.Optional[int] = 4096, intermediate_size: typing.Optional[int] = 14336, num_hidden_layers: typing.Optional[int] = 34, num_attention_heads: typing.Optional[int] = 32, num_key_value_heads: typing.Optional[int] = 8, head_dim: typing.Optional[int] = 128, hidden_act: typing.Optional[str] = 'silu', max_position_embeddings: typing.Optional[int] = 262144, initializer_range: typing.Optional[float] = 0.02, rms_norm_eps: typing.Optional[float] = 1e-05, use_cache: typing.Optional[bool] = True, pad_token_id: typing.Optional[int] = 11, bos_token_id: typing.Optional[int] = 1, eos_token_id: typing.Optional[int] = 2, tie_word_embeddings: typing.Optional[bool] = False, rope_parameters: typing.Optional[dict] = None, sliding_window: typing.Optional[int] = None, attention_dropout: typing.Optional[float] = 0.0, kwargs = {} )
Bases: PretrainedConfig
Configuration for Ministral3 text decoder.
base_model_pp_plan
base_model_tp_plan
head_dim
keys_to_ignore_at_inference
= ['past_key_values']
model_type
= 'ministral3'
rope_scaling
rope_theta
= self.rope_parameters.get('rope_theta', 1000000.0)
class nemo_automodel.components.models.mistral3.model.Ministral3DecoderLayer( config: nemo_automodel.components.models.mistral3.model.Ministral3Config, layer_idx: int )
Bases: GradientCheckpointingLayer
hidden_size
= config.hidden_size
input_layernorm
mlp
= Ministral3MLP(config)
post_attention_layernorm
self_attn
nemo_automodel.components.models.mistral3.model.Ministral3DecoderLayer.forward( hidden_states: torch.Tensor, attention_mask: typing.Optional[torch.Tensor] = None, position_ids: typing.Optional[torch.LongTensor] = None, past_key_values: typing.Optional[transformers.cache_utils.Cache] = None, use_cache: typing.Optional[bool] = False, cache_position: typing.Optional[torch.LongTensor] = None, position_embeddings: typing.Optional[tuple[torch.Tensor, torch.Tensor]] = None, kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs] = {} ) -> torch.Tensor
class nemo_automodel.components.models.mistral3.model.Ministral3ForCausalLM( config: nemo_automodel.components.models.mistral3.model.Ministral3Config )
Bases: HFCheckpointingMixin, Ministral3PreTrainedModel, GenerationMixin
_pp_plan
= {'lm_head': (['hidden_states'], ['logits'])}
_tied_weights_keys
= {'lm_head.weight': 'model.embed_tokens.weight'}
_tp_plan
= {'lm_head': 'colwise_rep'}
lm_head
model
= Ministral3Model(config)
vocab_size
= config.vocab_size
nemo_automodel.components.models.mistral3.model.Ministral3ForCausalLM.forward( input_ids: typing.Optional[torch.LongTensor] = None, attention_mask: typing.Optional[torch.Tensor] = None, position_ids: typing.Optional[torch.LongTensor] = None, past_key_values: typing.Optional[transformers.cache_utils.Cache] = None, inputs_embeds: typing.Optional[torch.FloatTensor] = None, labels: typing.Optional[torch.LongTensor] = None, use_cache: typing.Optional[bool] = None, cache_position: typing.Optional[torch.LongTensor] = None, logits_to_keep: typing.Union[int, torch.Tensor] = 0, output_hidden_states: typing.Optional[bool] = None, kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs] = {} ) -> transformers.modeling_outputs.CausalLMOutputWithPast
nemo_automodel.components.models.mistral3.model.Ministral3ForCausalLM.get_input_embeddings()
nemo_automodel.components.models.mistral3.model.Ministral3ForCausalLM.get_output_embeddings()
nemo_automodel.components.models.mistral3.model.Ministral3ForCausalLM.set_input_embeddings( value )
nemo_automodel.components.models.mistral3.model.Ministral3ForCausalLM.set_output_embeddings( new_embeddings )
class nemo_automodel.components.models.mistral3.model.Ministral3MLP( config: nemo_automodel.components.models.mistral3.model.Ministral3Config )
Bases: Module
act_fn
= ACT2FN[config.hidden_act]
down_proj
gate_proj
hidden_size
= config.hidden_size
intermediate_size
= config.intermediate_size
up_proj
nemo_automodel.components.models.mistral3.model.Ministral3MLP.forward( x )
class nemo_automodel.components.models.mistral3.model.Ministral3Model( config: nemo_automodel.components.models.mistral3.model.Ministral3Config )
Bases: Ministral3PreTrainedModel
embed_tokens
layers
norm
padding_idx
= config.pad_token_id
rotary_emb
= Ministral3RotaryEmbedding(config=config)
vocab_size
= config.vocab_size
nemo_automodel.components.models.mistral3.model.Ministral3Model.forward( input_ids: typing.Optional[torch.LongTensor] = None, attention_mask: typing.Optional[torch.Tensor] = None, position_ids: typing.Optional[torch.LongTensor] = None, past_key_values: typing.Optional[transformers.cache_utils.Cache] = None, inputs_embeds: typing.Optional[torch.FloatTensor] = None, use_cache: typing.Optional[bool] = None, cache_position: typing.Optional[torch.LongTensor] = None, kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs] = {} ) -> transformers.modeling_outputs.BaseModelOutputWithPast
class nemo_automodel.components.models.mistral3.model.Ministral3ModelOutputWithPast( image_hidden_states: typing.Optional[torch.FloatTensor] = None )
Dataclass
Bases: BaseModelOutputWithPast
image_hidden_states
Optional[FloatTensor] = None
class nemo_automodel.components.models.mistral3.model.Ministral3PreTrainedModel()
Bases: PreTrainedModel
_can_record_outputs
= {}
_no_split_modules
= ['Ministral3DecoderLayer']
_skip_keys_device_placement
= ['past_key_values']
base_model_prefix
= 'model'
config
Ministral3Config
class nemo_automodel.components.models.mistral3.model.Ministral3RMSNorm( hidden_size, eps = 1e-06 )
Bases: Module
weight
= nn.Parameter(torch.ones(hidden_size))
nemo_automodel.components.models.mistral3.model.Ministral3RMSNorm.forward( hidden_states )
class nemo_automodel.components.models.mistral3.model.Ministral3RotaryEmbedding( config: nemo_automodel.components.models.mistral3.model.Ministral3Config, device = None )
Bases: Module
inv_freq
Tensor
max_seq_len_cached
= config.max_position_embeddings
original_max_seq_len
= config.max_position_embeddings
rope_type
nemo_automodel.components.models.mistral3.model.Ministral3RotaryEmbedding.compute_default_rope_parameters( config: typing.Optional[nemo_automodel.components.models.mistral3.model.Ministral3Config] = None, device: typing.Optional[torch.device] = None, seq_len: typing.Optional[int] = None ) -> tuple[torch.Tensor, float]
staticmethod
nemo_automodel.components.models.mistral3.model.Ministral3RotaryEmbedding.forward( x, position_ids )
nemo_automodel.components.models.mistral3.model._get_llama_4_attn_scale( positions_ids: torch.Tensor, beta: float, max_position_embeddings: int ) -> torch.Tensor
nemo_automodel.components.models.mistral3.model._register_ministral3_with_transformers()
Register Ministral3Config and models with transformers Auto classes.
This uses the official transformers registration API. Registration is idempotent (re-registering the same config/model is a no-op in recent transformers versions).
nemo_automodel.components.models.mistral3.model.apply_rotary_pos_emb( q, k, cos, sin, position_ids = None, unsqueeze_dim = 1 )
nemo_automodel.components.models.mistral3.model.eager_attention_forward( module: torch.nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: typing.Optional[torch.Tensor], scaling: float, dropout: float = 0.0, kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs] = {} )
nemo_automodel.components.models.mistral3.model.repeat_kv( hidden_states: torch.Tensor, n_rep: int ) -> torch.Tensor
nemo_automodel.components.models.mistral3.model.rotate_half( x )
nemo_automodel.components.models.mistral3.model.ModelClass = Ministral3ForCausalLM
nemo_automodel.components.models.mistral3.model.logger = logging.get_logger(__name__)