nemo_automodel.components.models.minimax_m3_vl.mtp


MiniMax M3 multi-token prediction (MTP), DeepSeek-V3 style.

The checkpoint carries a single MTP module under model.mtp.layers.0. It consists of enorm and hnorm (Gemma RMSNorms applied to the next-token embedding and to the previous hidden state, respectively), eh_proj (a Linear layer mapping 2*hidden -> hidden over their concatenation), transformer_layer (a full decoder block with MoE and sparse attention), and final_layernorm. There is no separate output projection; the prediction head is the main model's shared lm_head.
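The enorm and hnorm modules are described as Gemma RMSNorms. The following is a minimal sketch of that normalization, assuming the usual Gemma convention (scale by 1 + weight with a zero-initialised weight, computed in float32). The class name GemmaRMSNorm and the eps default are illustrative assumptions, not taken from this module.

```python
import torch
import torch.nn as nn


class GemmaRMSNorm(nn.Module):
    """Gemma-style RMSNorm sketch: scales by (1 + weight) instead of weight.

    Hypothetical illustration; eps and the float32 upcast are assumptions.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        # Zero-initialised so that (1 + weight) starts as the identity scale.
        self.weight = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalise in float32 for numerical stability, then cast back.
        x_f = x.float()
        x_f = x_f * torch.rsqrt(x_f.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x_f * (1.0 + self.weight.float())).to(x.dtype)


if __name__ == "__main__":
    norm = GemmaRMSNorm(8)
    y = norm(torch.randn(2, 3, 8))
    print(tuple(y.shape))  # (2, 3, 8)
```

The only difference from a standard RMSNorm is the (1 + weight) scale, which lets a zero weight act as "no rescaling".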

sglang skips MTP when loading the checkpoint (inference-only), so the reference for this implementation is the DeepSeek-V3 MTP algorithm.
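In the DeepSeek-V3 formulation, MTP depth k at position i combines the depth k-1 hidden state h_i with the embedding of token t[i+k] and predicts t[i+k+1]; depth 0 is the main next-token head. The indexing can be checked with a small plain-Python sketch. mtp_alignment is a hypothetical helper for illustration, not part of this module.

```python
def mtp_alignment(tokens: list[int], depth: int) -> list[tuple[int, int, int]]:
    """Return (position i, input token t[i+k], target token t[i+k+1]) for depth k.

    Depth 0 is the main head (h_i predicts t[i+1]); MTP depths start at k = 1
    and combine h_i from depth k-1 with the embedding of t[i+k].
    """
    if depth < 1:
        raise ValueError("MTP depths start at 1; depth 0 is the main head")
    # Only positions whose target t[i+k+1] still exists are valid.
    return [
        (i, tokens[i + depth], tokens[i + depth + 1])
        for i in range(len(tokens) - depth - 1)
    ]


if __name__ == "__main__":
    toks = [10, 11, 12, 13, 14]
    print(mtp_alignment(toks, 1))  # [(0, 11, 12), (1, 12, 13), (2, 13, 14)]
```

Each extra depth shortens the usable sequence by one position, which is why MTP losses are computed only over the valid prefix.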

Module Contents

Classes

| Name | Description |
| --- | --- |
| MiniMaxM3MTP | Stack of MTP depths (M3 ships a single depth). |
| MiniMaxM3MTPBlock | One MTP depth: eh_proj(cat[enorm(emb), hnorm(h)]) -> Block -> final_layernorm. |

API

class nemo_automodel.components.models.minimax_m3_vl.mtp.MiniMaxM3MTP(
config: typing.Any,
moe_config: nemo_automodel.components.moe.layers.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
num_modules: int
)

Bases: Module

Stack of MTP depths (M3 ships a single depth).

layers
nemo_automodel.components.models.minimax_m3_vl.mtp.MiniMaxM3MTP.forward(
hidden_states: torch.Tensor,
input_ids: torch.Tensor,
embed_fn,
lm_head: torch.nn.Module,
freqs_cis: torch.Tensor,
block_kwargs: typing.Any = {}
) -> list[torch.Tensor]

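Return a list of logits tensors, one per MTP depth, where depth k predicts the token k positions beyond the usual next token. All depths use the main model's shared lm_head.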
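A minimal sketch of the per-depth loop, assuming each depth shifts input_ids left by one before embedding (as in common DeepSeek-V3 MTP implementations) and that wrapped-around positions are masked out of the loss elsewhere. ToyMTPStack and ConcatProjBlock are hypothetical, and the real forward's freqs_cis and block_kwargs arguments are omitted.

```python
import torch
import torch.nn as nn


class ConcatProjBlock(nn.Module):
    """Hypothetical minimal depth: project cat[emb, h] back to hidden size."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor, embeds: torch.Tensor) -> torch.Tensor:
        return self.eh_proj(torch.cat([embeds, hidden_states], dim=-1))


class ToyMTPStack(nn.Module):
    """Hypothetical stand-in for the depth loop in MiniMaxM3MTP.forward."""

    def __init__(self, blocks: nn.ModuleList) -> None:
        super().__init__()
        self.layers = blocks

    def forward(self, hidden_states, input_ids, embed_fn, lm_head) -> list[torch.Tensor]:
        logits_per_depth = []
        h, ids = hidden_states, input_ids
        for block in self.layers:
            # Shift tokens left by one: at depth k, position i now holds t[i + k].
            # torch.roll wraps tokens around, so positions near the end have no
            # valid target and must be masked from the loss (assumption).
            ids = torch.roll(ids, shifts=-1, dims=-1)
            # Each depth consumes the hidden state produced by the previous depth.
            h = block(h, embed_fn(ids))
            # Every depth reuses the main model's shared lm_head.
            logits_per_depth.append(lm_head(h))
        return logits_per_depth


if __name__ == "__main__":
    torch.manual_seed(0)
    vocab, hidden = 32, 16
    embed = nn.Embedding(vocab, hidden)
    lm_head = nn.Linear(hidden, vocab, bias=False)
    stack = ToyMTPStack(nn.ModuleList([ConcatProjBlock(hidden)]))
    out = stack(torch.randn(2, 5, hidden), torch.randint(0, vocab, (2, 5)), embed, lm_head)
    print(len(out), tuple(out[0].shape))  # 1 (2, 5, 32)
```

Because lm_head is passed in rather than owned, the MTP stack adds no vocabulary-sized parameters of its own.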

nemo_automodel.components.models.minimax_m3_vl.mtp.MiniMaxM3MTP.init_weights(
buffer_device: torch.device
) -> None
class nemo_automodel.components.models.minimax_m3_vl.mtp.MiniMaxM3MTPBlock(
config: typing.Any,
moe_config: nemo_automodel.components.moe.layers.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig
)

Bases: Module

One MTP depth: eh_proj(cat[enorm(emb), hnorm(h)]) -> Block -> final_layernorm.

transformer_layer is a full M3 decoder block. It is constructed with the index of the last decoder layer, which in M3 is always an MoE + sparse-attention layer, so the shared Block class (...layers.Block) builds the routed MoE and the sparse-attention indexer automatically.
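The design relies on the block factory deciding what to build from the layer index alone, so passing the last index is enough to obtain an MoE + sparse-attention block. A hypothetical sketch of that pattern follows; ToyM3Config, its per-layer sets, block_features, and mtp_layer_idx are illustrative names, not the real config fields or Block API.

```python
from dataclasses import dataclass, field


@dataclass
class ToyM3Config:
    # Hypothetical fields; the real M3 config may encode layer types differently.
    num_hidden_layers: int
    moe_layers: set[int] = field(default_factory=set)
    sparse_attention_layers: set[int] = field(default_factory=set)


def block_features(config: ToyM3Config, layer_idx: int) -> dict[str, bool]:
    """What a shared, index-driven block factory would build at layer_idx."""
    return {
        "moe": layer_idx in config.moe_layers,
        "sparse_attention": layer_idx in config.sparse_attention_layers,
    }


def mtp_layer_idx(config: ToyM3Config) -> int:
    # Reuse the last decoder index so the factory yields MoE + sparse attention.
    return config.num_hidden_layers - 1


if __name__ == "__main__":
    cfg = ToyM3Config(num_hidden_layers=4, moe_layers={1, 2, 3}, sparse_attention_layers={1, 3})
    print(block_features(cfg, mtp_layer_idx(cfg)))  # {'moe': True, 'sparse_attention': True}
```

Reusing an existing index avoids a separate MTP-specific code path in the block constructor, at the cost of coupling the MTP block's type to the last decoder layer's configuration.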

eh_proj
enorm
final_layernorm
hnorm
transformer_layer
nemo_automodel.components.models.minimax_m3_vl.mtp.MiniMaxM3MTPBlock.forward(
hidden_states: torch.Tensor,
embed_input: torch.Tensor,
freqs_cis: torch.Tensor,
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
attn_kwargs: typing.Any = {}
) -> torch.Tensor
nemo_automodel.components.models.minimax_m3_vl.mtp.MiniMaxM3MTPBlock.init_weights(
buffer_device: torch.device
) -> None