nemo_automodel.components.models.mistral4.model

Module Contents

Classes

Mistral4MLA

MLA with Llama 4 attention scaling for Mistral 4.

Mistral4Block

Block using Mistral4MLA instead of MLA.

Mistral4Model

Mistral4ForCausalLM

Functions

_get_llama_4_attn_scale

Position-dependent attention scaling for long-context extrapolation (Llama 4 / Mistral 4).

_build_moe_config

Build MoEConfig from a Mistral4 text config.

Data

ModelClass

API

nemo_automodel.components.models.mistral4.model._get_llama_4_attn_scale(
position_ids: torch.Tensor,
beta: float,
max_position_embeddings: int,
) → torch.Tensor

Position-dependent attention scaling for long-context extrapolation (Llama 4 / Mistral 4).
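The docstring does not give the formula. Below is a minimal pure-Python sketch, assuming the scale follows the Llama 4-style form 1 + beta * log(1 + floor(position / max_position_embeddings)), computed per position. attn_scale_sketch is a hypothetical helper, not this module's function; the real function operates on torch tensors.

```python
import math
from typing import List


def attn_scale_sketch(
    position_ids: List[int], beta: float, max_position_embeddings: int
) -> List[float]:
    """Hypothetical helper: per-position query scale, ASSUMING the form
    1 + beta * log(1 + floor(pos / max_position_embeddings))."""
    scales = []
    for pos in position_ids:
        # Number of full context windows this position lies beyond.
        windows = math.floor(pos / max_position_embeddings)
        # Grows logarithmically; positions inside the first window get exactly 1.0.
        scales.append(1.0 + beta * math.log(1.0 + windows))
    return scales


print(attn_scale_sketch([0, 4095, 4096, 12288], beta=0.1, max_position_embeddings=4096))
```

Under this assumption, positions below max_position_embeddings keep a scale of exactly 1.0, and the scale then grows with the logarithm of the number of full windows beyond that, which is what lets attention stay sharp at lengths longer than the training context.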

class nemo_automodel.components.models.mistral4.model.Mistral4MLA(
config,
backend: nemo_automodel.components.models.common.BackendConfig,
)

Bases: nemo_automodel.components.models.deepseek_v3.layers.MLA

MLA with Llama 4 attention scaling for Mistral 4.

Compared with the DeepSeek V3 MLA, this class multiplies q_pe by a position-dependent scale (controlled by llama_4_scaling_beta) after RoPE is applied. RoPE itself uses the same complex-number formulation as DeepSeek V3.
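To make the order of operations concrete, here is a hedged sketch of the rotary part only. Consecutive (even, odd) entries of q_pe are viewed as complex numbers, rotated by e^{i * pos * freq} as in complex-number RoPE, and only then multiplied by the position-dependent scale. rope_then_scale_sketch, its arguments, and the theta-based frequency schedule are illustrative assumptions; the real layer works on batched tensors with a precomputed freqs_cis.

```python
import cmath
import math
from typing import List


def rope_then_scale_sketch(
    q_pe: List[float], pos: int, theta: float, scale: float
) -> List[float]:
    """Hypothetical sketch: complex-number RoPE on one q_pe vector, then scaling."""
    d = len(q_pe)
    if d % 2 != 0:
        raise ValueError("rotary dimension must be even")
    out: List[float] = []
    for i in range(d // 2):
        # View each consecutive (even, odd) pair as a single complex number.
        z = complex(q_pe[2 * i], q_pe[2 * i + 1])
        # Standard RoPE frequency for pair i: theta ** (-2i / d).
        freq = theta ** (-2.0 * i / d)
        # Rotate by the unit complex number e^{i * pos * freq} (one freqs_cis entry).
        z *= cmath.exp(1j * pos * freq)
        # Position-dependent scaling is applied AFTER the rotation.
        z *= scale
        out.extend([z.real, z.imag])
    return out


print(rope_then_scale_sketch([1.0, 0.0, 0.5, -0.5], pos=7, theta=10000.0, scale=1.5))
```

Because the rotation is norm-preserving, the scale is the only step that changes the magnitude of q_pe, so it directly rescales the rotary contribution to the attention logits.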

Initialization

forward(
x: torch.Tensor,
freqs_cis: torch.Tensor,
attention_mask: torch.Tensor | None = None,
**attn_kwargs: Any,
)
class nemo_automodel.components.models.mistral4.model.Mistral4Block(layer_idx, config, moe_config, backend)

Bases: nemo_automodel.components.models.deepseek_v3.model.Block

Block using Mistral4MLA instead of MLA.

Initialization
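The page only says that this block uses Mistral4MLA in place of MLA. One common way to express "same block, different attention" is to build the parent block as usual and then replace its attention module. The sketch below uses hypothetical stand-in classes (MLASketch, BlockSketch, and so on), not the real Block and MLA; whether Mistral4Block does exactly this is an assumption.

```python
from typing import Dict, List


class MLASketch:
    """Hypothetical stand-in for the base attention layer."""

    def __init__(self, config: Dict[str, float]):
        self.config = config

    def __call__(self, x: List[float]) -> List[float]:
        return list(x)  # placeholder: identity "attention"


class Mistral4MLASketch(MLASketch):
    """Hypothetical stand-in for the Mistral 4 attention variant."""

    def __call__(self, x: List[float]) -> List[float]:
        return [v * self.config["scale"] for v in x]


class BlockSketch:
    def __init__(self, layer_idx: int, config: Dict[str, float]):
        self.layer_idx = layer_idx
        self.attn = MLASketch(config)
        # Norms and the MLP/MoE sub-layer would also be built here.

    def forward(self, x: List[float]) -> List[float]:
        # Residual connection around the attention sub-layer.
        return [a + b for a, b in zip(x, self.attn(x))]


class Mistral4BlockSketch(BlockSketch):
    def __init__(self, layer_idx: int, config: Dict[str, float]):
        super().__init__(layer_idx, config)
        # Reuse everything from the parent block, but swap the attention module.
        self.attn = Mistral4MLASketch(config)
```

Swapping after super().__init__() keeps the parent's forward logic untouched, so only the attention behavior differs between the two blocks.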

nemo_automodel.components.models.mistral4.model._build_moe_config(
config,
) → nemo_automodel.components.moe.config.MoEConfig

Build MoEConfig from a Mistral4 text config.
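A sketch of what such a builder might look like, assuming the text config exposes DeepSeek-V3-style attribute names (hidden_size, moe_intermediate_size, n_routed_experts, n_shared_experts, num_experts_per_tok, routed_scaling_factor, norm_topk_prob). MoEConfigSketch is a hypothetical dataclass holding a few illustrative fields; it is not the real MoEConfig, whose full field set is not shown on this page.

```python
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass
class MoEConfigSketch:
    """Hypothetical subset of MoE hyperparameters (NOT the real MoEConfig)."""

    dim: int
    moe_inter_dim: int
    n_routed_experts: int
    n_shared_experts: int
    n_activated_experts: int
    route_scale: float
    norm_topk_prob: bool


def build_moe_config_sketch(text_config) -> MoEConfigSketch:
    # Attribute names on text_config are ASSUMED to follow DeepSeek-V3-style HF configs.
    return MoEConfigSketch(
        dim=text_config.hidden_size,
        moe_inter_dim=text_config.moe_intermediate_size,
        n_routed_experts=text_config.n_routed_experts,
        n_shared_experts=text_config.n_shared_experts,
        n_activated_experts=text_config.num_experts_per_tok,
        # Optional fields fall back to assumed defaults.
        route_scale=getattr(text_config, "routed_scaling_factor", 1.0),
        norm_topk_prob=getattr(text_config, "norm_topk_prob", True),
    )


# Toy values for illustration only.
toy = SimpleNamespace(
    hidden_size=64, moe_intermediate_size=32,
    n_routed_experts=8, n_shared_experts=1, num_experts_per_tok=2,
)
print(build_moe_config_sketch(toy))
```

Keeping the translation in one function means the rest of the model only needs to understand the MoE config, not the upstream config's naming.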

class nemo_automodel.components.models.mistral4.model.Mistral4Model(
config,
backend: nemo_automodel.components.models.common.BackendConfig,
*,
moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
)

Bases: torch.nn.Module

Initialization

forward(
input_ids: torch.Tensor | None = None,
*,
inputs_embeds: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
**attn_kwargs: Any,
) → torch.Tensor
update_moe_gate_bias() → None
init_weights(buffer_device: torch.device | None = None) → None
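This page does not describe update_moe_gate_bias. If it follows the auxiliary-loss-free load balancing used by DeepSeek V3 (an assumption, based on the deepseek_v3 base classes this module builds on), each expert's routing bias is nudged by a fixed step after a training step: raised for underloaded experts and lowered for overloaded ones. A minimal sketch with a hypothetical helper:

```python
from typing import List


def update_gate_bias_sketch(
    bias: List[float], tokens_per_expert: List[int], update_factor: float
) -> List[float]:
    """Hypothetical DeepSeek-V3-style aux-loss-free balancing step (ASSUMED scheme)."""
    mean_load = sum(tokens_per_expert) / len(tokens_per_expert)
    new_bias = []
    for b, load in zip(bias, tokens_per_expert):
        if load < mean_load:
            # Underloaded expert: raise its bias so the router selects it more often.
            b += update_factor
        elif load > mean_load:
            # Overloaded expert: lower its bias.
            b -= update_factor
        new_bias.append(b)
    return new_bias


print(update_gate_bias_sketch([0.0, 0.0, 0.0, 0.0], [10, 2, 4, 4], update_factor=0.001))
```

In that scheme the bias only affects which experts are selected, not the gating weights, which is why balancing can happen without an auxiliary loss term.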
class nemo_automodel.components.models.mistral4.model.Mistral4ForCausalLM(
config,
moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
**kwargs,
)

Bases: nemo_automodel.components.models.common.hf_checkpointing_mixin.HFCheckpointingMixin, torch.nn.Module, nemo_automodel.components.moe.fsdp_mixin.MoEFSDPSyncMixin

classmethod from_config(
config,
moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
**kwargs,
)
classmethod from_pretrained(
pretrained_model_name_or_path: str,
*model_args,
**kwargs,
)
get_input_embeddings()
set_input_embeddings(value)
get_output_embeddings()
set_output_embeddings(new_embeddings)
forward(
input_ids: torch.Tensor,
*,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
**attn_kwargs: Any,
) → torch.Tensor
update_moe_gate_bias() → None
initialize_weights(
buffer_device: torch.device | None = None,
dtype: torch.dtype = torch.bfloat16,
) → None
nemo_automodel.components.models.mistral4.model.ModelClass

None