nemo_automodel.components.models.minimax_m3_vl.model

MiniMax M3 (mixed sparse/dense MoE) text backbone.

Stage 1 implements MiniMaxM3TextModel and the standalone MiniMaxM3SparseForCausalLM so the language path can be parity-tested against the sglang reference before the vision tower / VLM wrapper (Stage 3) embeds the text model as language_model.

Module Contents

Classes

| Name | Description |
| --- | --- |
| MiniMaxM3CausalLMOutput | Forward output carrying the primary logits and optional per-depth MTP logits. |
| MiniMaxM3SparseForCausalLM | Standalone M3 text backbone for causal LM (Stage 1 parity target). |
| MiniMaxM3SparseForConditionalGeneration | MiniMax M3 VL: CLIP-style vision tower + projector/merger + M3 text backbone. |
| MiniMaxM3TextModel | Embedding + decoder stack + final norm for the M3 text backbone. |

Functions

| Name | Description |
| --- | --- |
| build_moe_config | Build the routed-expert MoEConfig for the M3 backbone. |

Data

ModelClass

API

class nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3CausalLMOutput(
logits: torch.Tensor,
mtp_per_depth_logits: list[torch.Tensor] | None = None
)
Dataclass

Forward output carrying the primary logits and optional per-depth MTP logits.

logits
Tensor
mtp_per_depth_logits
list[Tensor] | None = None
class nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3SparseForCausalLM(
config: typing.Any,
moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
kwargs = {}
)

Bases: HFCheckpointingMixin, Module, MoEFSDPSyncMixin

Standalone M3 text backbone for causal LM (Stage 1 parity target).

backend
= backend or BackendConfig()
lm_head
model
state_dict_adapter
nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3SparseForCausalLM.forward(
input_ids: torch.Tensor,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
attn_kwargs: typing.Any = {}
) -> torch.Tensor
nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3SparseForCausalLM.from_config(
config: typing.Any,
moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
kwargs = {}
)
classmethod
nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3SparseForCausalLM.from_pretrained(
pretrained_model_name_or_path: str,
model_args = (),
kwargs = {}
)
classmethod
nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3SparseForCausalLM.get_input_embeddings()
nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3SparseForCausalLM.get_output_embeddings()
nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3SparseForCausalLM.initialize_weights(
buffer_device: torch.device | None = None,
dtype: torch.dtype = torch.bfloat16
) -> None
nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3SparseForCausalLM.set_input_embeddings(
value
)
nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3SparseForCausalLM.set_output_embeddings(
new_embeddings
)
class nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3SparseForConditionalGeneration(
config: typing.Any,
moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
kwargs = {}
)

Bases: HFCheckpointingMixin, Module, MoEFSDPSyncMixin

MiniMax M3 VL: CLIP-style vision tower + projector/merger + M3 text backbone.

Vision features (vision_tower(pixel_values, grid_thw)) are spliced into the text embeddings at image_token_index / video_token_index positions, then run through the (sparse/dense MoE) language model + lm_head.
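The splice described above can be illustrated with a minimal sketch. It is not the module's `_splice_multimodal` implementation: plain Python lists stand in for tensors, a single unbatched sequence is assumed, and the helper name `splice_features` and the placeholder id `9` are hypothetical.

```python
def splice_features(inputs_embeds, input_ids, vision_embeds, token_index):
    """Return a copy of inputs_embeds with placeholder rows replaced.

    Each position where input_ids equals token_index receives the next
    vision feature row, in order. All other rows are kept unchanged.
    """
    n_slots = sum(1 for tok in input_ids if tok == token_index)
    if n_slots != len(vision_embeds):
        raise ValueError(
            f"{n_slots} placeholder tokens but {len(vision_embeds)} vision features"
        )
    features = iter(vision_embeds)
    return [
        next(features) if tok == token_index else emb
        for tok, emb in zip(input_ids, inputs_embeds)
    ]


IMAGE_TOKEN = 9  # hypothetical placeholder id, stands in for config.image_token_index
input_ids = [1, IMAGE_TOKEN, IMAGE_TOKEN, 2]
text_embeds = [[0.1, 0.1], [0.0, 0.0], [0.0, 0.0], [0.2, 0.2]]
vision_embeds = [[1.0, 1.0], [2.0, 2.0]]

spliced = splice_features(text_embeds, input_ids, vision_embeds, IMAGE_TOKEN)
```

The count check mirrors a common failure mode in this kind of splice: the number of placeholder tokens in the prompt must match the number of feature rows produced by the vision tower.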

_keep_in_fp32_modules
= ['rotary_emb', 'inv_freq']
_pp_keep_self_forward
bool = True
backend
= backend or BackendConfig()
image_token_index
= config.image_token_index
lm_head
model
state_dict_adapter
video_token_index
= config.video_token_index
vision_tower
vocab_size
= text_config.vocab_size
nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3SparseForConditionalGeneration._is_pipeline_parallel_stage() -> bool

True when this is a partial pipeline stage (some text modules nulled).

nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3SparseForConditionalGeneration._splice_multimodal(
inputs_embeds: torch.Tensor,
input_ids: torch.Tensor,
pixel_values: torch.Tensor,
grid_thw,
token_index: int
) -> torch.Tensor
nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3SparseForConditionalGeneration._to_grid_list(
grid_thw
) -> list[list[int]]
staticmethod
nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3SparseForConditionalGeneration.customize_pipeline_stage_modules(
module_names_per_stage: list[list[str]],
layers_prefix: str,
text_model: torch.nn.Module | None = None
) -> list[list[str]]

Rewrite auto-generated pipeline FQNs to M3’s real module paths.

M3’s text stack lives directly under self.model and the vision tower is a top-level sibling (vision_tower). The framework, seeing the language_model property, derives a nested model.language_model. prefix for the text modules and a model. prefix for the multimodal encoders. Map both back to M3’s actual paths so per-stage module nulling keeps/drops the correct submodules.
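The remapping described above might look roughly like the sketch below. It only illustrates the prefix rewrite; the helper name `rewrite_stage_fqns`, its parameters, and the example module names are assumptions, not the method's actual signature or the framework's real output.

```python
def rewrite_stage_fqns(
    module_names_per_stage,
    text_prefix="model.language_model.",
    encoder_names=("vision_tower",),
):
    """Map auto-derived FQNs back to a layout where the text stack sits
    directly under `model` and the encoders are top-level siblings."""
    rewritten = []
    for stage in module_names_per_stage:
        names = []
        for name in stage:
            if name.startswith(text_prefix):
                # model.language_model.layers.0 -> model.layers.0
                name = "model." + name[len(text_prefix):]
            elif name.startswith("model."):
                rest = name[len("model."):]
                if rest.split(".")[0] in encoder_names:
                    # model.vision_tower -> vision_tower
                    name = rest
            names.append(name)
        rewritten.append(names)
    return rewritten


stages = [
    ["model.vision_tower", "model.language_model.embed_tokens", "model.language_model.layers.0"],
    ["model.language_model.layers.1", "model.language_model.norm", "lm_head"],
]
fixed = rewrite_stage_fqns(stages)
```

Names that match neither pattern, such as `lm_head` here, pass through unchanged.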

nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3SparseForConditionalGeneration.forward(
input_ids: torch.Tensor | None = None,
pixel_values: torch.Tensor | None = None,
image_grid_thw = None,
pixel_values_videos: torch.Tensor | None = None,
video_grid_thw = None,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
kwargs: typing.Any = {}
) -> torch.Tensor
nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3SparseForConditionalGeneration.from_config(
config: typing.Any,
moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
kwargs = {}
)
classmethod
nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3SparseForConditionalGeneration.from_pretrained(
pretrained_model_name_or_path: str,
model_args = (),
kwargs = {}
)
classmethod
nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3SparseForConditionalGeneration.get_input_embeddings()
nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3SparseForConditionalGeneration.get_output_embeddings()
nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3SparseForConditionalGeneration.get_pipeline_stage_metas(
is_first: bool,
microbatch_size: int,
seq_len: int,
dtype: torch.dtype
) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]

Per-stage input/output meta tensors for the PP schedule’s shape inference.

First stage consumes token ids [mb, seq]; later stages consume hidden states [mb, seq, hidden]. The final stage (owning lm_head) emits logits [mb, seq, vocab]; earlier stages emit hidden states.
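The shape rule above can be written down as a small sketch. The documented method returns meta tensors and takes `is_first`, `microbatch_size`, `seq_len`, and `dtype`; the version below returns plain shape tuples instead, and the `is_last`, `hidden_size`, and `vocab_size` parameters are assumptions introduced for illustration.

```python
def stage_meta_shapes(is_first, is_last, microbatch_size, seq_len, hidden_size, vocab_size):
    """Return (input_shape, output_shape) for one pipeline stage."""
    hidden = (microbatch_size, seq_len, hidden_size)
    # The first stage consumes token ids; later stages consume hidden states.
    in_shape = (microbatch_size, seq_len) if is_first else hidden
    # The stage that owns lm_head emits logits; earlier stages emit hidden states.
    out_shape = (microbatch_size, seq_len, vocab_size) if is_last else hidden
    return in_shape, out_shape


first = stage_meta_shapes(True, False, 2, 8, 16, 100)
middle = stage_meta_shapes(False, False, 2, 8, 16, 100)
last = stage_meta_shapes(False, True, 2, 8, 16, 100)
```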

nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3SparseForConditionalGeneration.initialize_weights(
buffer_device: torch.device | None = None,
dtype: torch.dtype = torch.bfloat16
) -> None
nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3SparseForConditionalGeneration.set_input_embeddings(
value
)
class nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3TextModel(
config: typing.Any,
backend: nemo_automodel.components.models.common.BackendConfig,
moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None
)

Bases: Module

Embedding + decoder stack + final norm for the M3 text backbone.

embed_tokens
head_dim
layers
= torch.nn.ModuleDict()
max_seq_len
= config.max_position_embeddings
moe_config
= moe_config or build_moe_config(config, dtype)
mtp
norm
rotary_emb
nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3TextModel.forward(
input_ids: torch.Tensor | None,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
attn_kwargs: typing.Any = {}
) -> torch.Tensor
nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3TextModel.init_weights(
buffer_device: torch.device | None = None
) -> None
nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3TextModel.make_freqs_cis(
position_ids: torch.Tensor,
attn_kwargs: typing.Any = {}
) -> torch.Tensor
nemo_automodel.components.models.minimax_m3_vl.model.MiniMaxM3TextModel.mtp_logits(
hidden_states: torch.Tensor,
input_ids: torch.Tensor,
lm_head: torch.nn.Module,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
attn_kwargs: typing.Any = {}
) -> list[torch.Tensor]

Per-depth MTP logits from the final hidden states (shares lm_head).

nemo_automodel.components.models.minimax_m3_vl.model.build_moe_config(
config: typing.Any,
dtype: torch.dtype
) -> nemo_automodel.components.moe.layers.MoEConfig

Build the routed-expert MoEConfig for the M3 backbone.

Shared experts are handled in ...layers.Block (SwiGLU-OAI), so n_shared_experts is 0 here. Routed experts use the swigluoai activation, gate * sigmoid(alpha * gate) * (up + 1), applied to the concatenated grouped gate/up projection produced by MoESplitExpertsStateDictMixin.
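As a worked illustration of the activation formula above, here is a scalar sketch in plain Python. It takes `gate` and `up` as already-separated values; how the concatenated gate/up projection is laid out and split is not covered, and `alpha` is left as a required argument because its value is not stated here. The function names are hypothetical.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def swigluoai(gate, up, alpha):
    """Compute gate * sigmoid(alpha * gate) * (up + 1) for scalar inputs."""
    return gate * sigmoid(alpha * gate) * (up + 1.0)


def swigluoai_rows(gates, ups, alpha):
    """Apply the activation elementwise to two equal-length lists."""
    if len(gates) != len(ups):
        raise ValueError("gate and up must have the same length")
    return [swigluoai(g, u, alpha) for g, u in zip(gates, ups)]


zero_gate = swigluoai(0.0, 3.0, alpha=1.0)  # a zero gate always yields 0.0
unit_gate = swigluoai(1.0, 0.0, alpha=1.0)  # equals sigmoid(1.0)
```

Because of the `(up + 1)` term, an `up` value of zero passes the gated signal through unchanged rather than zeroing it, which differs from the plain SwiGLU product.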

nemo_automodel.components.models.minimax_m3_vl.model.ModelClass = MiniMaxM3SparseForConditionalGeneration