nemo_automodel.components.models.mistral4.model#
Module Contents#
Classes#
- Mistral4MLA: MLA with Llama 4 attention scaling for Mistral 4.
- Mistral4Block: Block using Mistral4MLA instead of MLA.
Functions#
- _get_llama_4_attn_scale: Position-dependent attention scaling for long-context extrapolation (Llama 4 / Mistral 4).
- _build_moe_config: Build MoEConfig from a Mistral4 text config.
Data#
- ModelClass
API#
- nemo_automodel.components.models.mistral4.model._get_llama_4_attn_scale(
- position_ids: torch.Tensor,
- beta: float,
- max_position_embeddings: int,
- )#
Position-dependent attention scaling for long-context extrapolation (Llama 4 / Mistral 4).
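The exact formula lives inside the module, but the general shape of such a position-dependent scale can be sketched as follows. This is a minimal illustration only: the `log1p` ramp and the roles assigned to `beta` and `max_position_embeddings` are assumptions, not the library's implementation, and the real function operates on `torch.Tensor` position ids rather than scalars.

```python
import math

def llama4_attn_scale(position: int, beta: float, max_position_embeddings: int) -> float:
    # Illustrative sketch: unity at position 0, growing logarithmically as the
    # position exceeds the training context, so long-range attention logits
    # are gently amplified instead of decaying. The actual
    # _get_llama_4_attn_scale may use a different functional form.
    return 1.0 + beta * math.log1p(position / max_position_embeddings)
```

Under this sketch the scale is exactly 1.0 at position 0 and increases slowly and monotonically with position, which is the qualitative behavior long-context extrapolation schemes of this kind aim for.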
- class nemo_automodel.components.models.mistral4.model.Mistral4MLA(
- config,
- backend: nemo_automodel.components.models.common.BackendConfig,
- )#
Bases: nemo_automodel.components.models.deepseek_v3.layers.MLA

MLA with Llama 4 attention scaling for Mistral 4.
Compared to DeepSeek V3 MLA, adds position-dependent scaling to q_pe after RoPE (llama_4_scaling_beta). RoPE itself uses the same complex-number approach as DSV3.
Initialization
- forward(
- x: torch.Tensor,
- freqs_cis: torch.Tensor,
- attention_mask: torch.Tensor | None = None,
- **attn_kwargs: Any,
- )#
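The order of operations described above (RoPE first, then position-dependent scaling of the query positional components only) can be sketched roughly as below. The names, the plain-list representation, and the scale formula are assumptions for illustration; the real module works on rotary-embedded tensors inside the MLA forward pass.

```python
import math

def scale_q_pe_after_rope(
    q_pe: list[float], position: int, beta: float, max_position_embeddings: int
) -> list[float]:
    # q_pe is assumed to already have RoPE applied. Only the query side is
    # scaled; k_pe is left untouched, so the q·k attention logits grow with
    # the query position rather than with both positions.
    scale = 1.0 + beta * math.log1p(position / max_position_embeddings)
    return [q * scale for q in q_pe]
```

At position 0 the sketch leaves q_pe unchanged, matching the idea that scaling only matters for long-range extrapolation.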
- class nemo_automodel.components.models.mistral4.model.Mistral4Block(layer_idx, config, moe_config, backend)#
Bases: nemo_automodel.components.models.deepseek_v3.model.Block

Block using Mistral4MLA instead of MLA.
Initialization
- nemo_automodel.components.models.mistral4.model._build_moe_config(
- config,
- )#
Build MoEConfig from a Mistral4 text config.
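A helper of this kind typically copies expert-count and dimension fields from the Hugging Face text config onto the MoE dataclass. The sketch below is a stand-in: the attribute names on the text config (`num_local_experts`, `num_experts_per_tok`, `intermediate_size`) and the `MoEConfigSketch` fields are assumptions for illustration, not the actual MoEConfig schema.

```python
from dataclasses import dataclass

@dataclass
class MoEConfigSketch:
    # Hypothetical subset of an MoE config; field names are illustrative.
    n_routed_experts: int
    n_activated_experts: int
    dim: int
    moe_inter_dim: int

def build_moe_config(text_cfg) -> MoEConfigSketch:
    # Map (assumed) HF text-config attributes onto the MoE dataclass,
    # with defaults where an attribute is absent.
    return MoEConfigSketch(
        n_routed_experts=getattr(text_cfg, "num_local_experts", 8),
        n_activated_experts=getattr(text_cfg, "num_experts_per_tok", 2),
        dim=text_cfg.hidden_size,
        moe_inter_dim=getattr(text_cfg, "intermediate_size", 4 * text_cfg.hidden_size),
    )

class _TextCfg:
    # Minimal fake text config for demonstration.
    hidden_size = 1024
    num_local_experts = 16
    num_experts_per_tok = 2

moe_cfg = build_moe_config(_TextCfg())
```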
- class nemo_automodel.components.models.mistral4.model.Mistral4Model(
- config,
- backend: nemo_automodel.components.models.common.BackendConfig,
- *,
- moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
- )#
Bases: torch.nn.Module

Initialization
- forward(
- input_ids: torch.Tensor | None = None,
- *,
- inputs_embeds: torch.Tensor | None = None,
- position_ids: torch.Tensor | None = None,
- attention_mask: torch.Tensor | None = None,
- padding_mask: torch.Tensor | None = None,
- **attn_kwargs: Any,
- )#
- update_moe_gate_bias() → None#
- init_weights(buffer_device: torch.device | None = None) → None#
- class nemo_automodel.components.models.mistral4.model.Mistral4ForCausalLM(
- config,
- moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
- backend: nemo_automodel.components.models.common.BackendConfig | None = None,
- **kwargs,
- )#
Bases: nemo_automodel.components.models.common.hf_checkpointing_mixin.HFCheckpointingMixin, torch.nn.Module, nemo_automodel.components.moe.fsdp_mixin.MoEFSDPSyncMixin

- classmethod from_config(
- config,
- moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
- backend: nemo_automodel.components.models.common.BackendConfig | None = None,
- **kwargs,
- )#
- classmethod from_pretrained(
- pretrained_model_name_or_path: str,
- *model_args,
- **kwargs,
- )#
- get_input_embeddings()#
- set_input_embeddings(value)#
- get_output_embeddings()#
- set_output_embeddings(new_embeddings)#
- forward(
- input_ids: torch.Tensor,
- *,
- position_ids: torch.Tensor | None = None,
- attention_mask: torch.Tensor | None = None,
- padding_mask: torch.Tensor | None = None,
- **attn_kwargs: Any,
- )#
- update_moe_gate_bias() → None#
- initialize_weights(
- buffer_device: torch.device | None = None,
- dtype: torch.dtype = torch.bfloat16,
- )#
- nemo_automodel.components.models.mistral4.model.ModelClass#
None