nemo_automodel.components.models.common.mtp.mtp

Model-agnostic MTP scaffolding: depth iteration, token rolling, and loss.

Module Contents

Classes

MTPConfig

Runtime configuration for the MTP block.

MTPModule

Multi-Token Prediction block.

Functions

roll_tensor

Roll a tensor along dim by shifts and zero the wrapped slice.

get_mtp_loss_scaling_factor

Return the model’s configured MTP auxiliary-loss scaling factor.

API

nemo_automodel.components.models.common.mtp.mtp.roll_tensor(
t: torch.Tensor,
shifts: int = -1,
dim: int = -1,
) -> torch.Tensor

Roll a tensor along dim by shifts and zero the wrapped slice.

Used to shift input_ids, position_ids, and labels left by one position per MTP depth. Supports the single-GPU path only: there is no handling for context parallelism (CP) or packed sequences.

Parameters:
  • t – Input tensor.

  • shifts – Number of positions to shift (negative = left shift).

  • dim – Dimension to roll along.

Returns:

A new tensor in which the trailing |shifts| positions along dim are zero-filled, so no values wrap around from the front.
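As a rough illustration of the shift-and-zero-fill rule above, the sketch below applies it to a plain 1-D Python list. The helper name roll_and_zero is hypothetical; the real function operates on torch tensors along an arbitrary dim, and only the index arithmetic is meant to carry over.

```python
def roll_and_zero(values: list[int], shifts: int = -1) -> list[int]:
    """Shift `values` by `shifts` positions without wrap-around.

    Hypothetical 1-D stand-in for roll_tensor: element i moves to i + shifts,
    and any slot whose source would have wrapped around is filled with 0.
    """
    n = len(values)
    out = [0] * n  # start from zeros so vacated slots stay zero-filled
    for i, value in enumerate(values):
        j = i + shifts
        if 0 <= j < n:  # drop elements that would wrap past either end
            out[j] = value
    return out


# Left shift by one, as applied once per MTP depth: the last slot is zeroed.
print(roll_and_zero([11, 12, 13, 14]))             # [12, 13, 14, 0]
# Two cumulative left shifts (second depth): the trailing two slots are zeroed.
print(roll_and_zero([11, 12, 13, 14], shifts=-2))  # [13, 14, 0, 0]
```

Rolling twice by -1 gives the same result as one roll by -2, which is why cumulative per-depth rolling lines up slot t with token t + d at depth d.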

nemo_automodel.components.models.common.mtp.mtp.get_mtp_loss_scaling_factor(
model: torch.nn.Module,
default: float = 0.1,
) -> float

Return the model’s configured MTP auxiliary-loss scaling factor.

class nemo_automodel.components.models.common.mtp.mtp.MTPConfig

Runtime configuration for the MTP block.

Attributes:
  • num_layers – Number of MTP forward iterations (D). 0 disables MTP. Equivalent to Megatron's --mtp-num-layers.

  • layer_pattern – Per-depth inner-block pattern, e.g. "*E" for one attention sublayer and one MoE sublayer per depth.

  • loss_scaling_factor – Coefficient for the MTP auxiliary loss (default 0.1). The per-depth CE losses are summed and scaled by loss_scaling_factor / num_layers, so the effective weight of each depth is loss_scaling_factor / num_layers.

  • use_repeated_layer – When True, build a single physical depth's worth of sublayers and reuse it for all num_layers forward iterations (weight-tied across depths). Equivalent to Megatron's --mtp-use-repeated-layer.
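To make the loss_scaling_factor arithmetic concrete, here is a minimal sketch of how the per-depth CE losses might be folded into the training loss. It assumes the auxiliary term is simply added to the main language-model loss and that per-depth losses arrive as plain floats; combine_mtp_loss is a hypothetical helper, not part of this module.

```python
def combine_mtp_loss(
    main_loss: float,
    per_depth_ce: list[float],
    loss_scaling_factor: float = 0.1,
) -> float:
    """Add the scaled MTP auxiliary loss to the main LM loss (illustrative sketch)."""
    num_layers = len(per_depth_ce)
    if num_layers == 0:  # num_layers == 0 means MTP is disabled
        return main_loss
    # Each depth gets weight loss_scaling_factor / num_layers.
    per_depth_weight = loss_scaling_factor / num_layers
    return main_loss + per_depth_weight * sum(per_depth_ce)


# Two depths with CE losses 1.0 and 3.0: each depth is weighted by 0.1 / 2 = 0.05,
# so the auxiliary term is 0.05 * (1.0 + 3.0) = 0.2 on top of the main loss.
total = combine_mtp_loss(2.0, [1.0, 3.0])
```

Dividing by num_layers keeps the total auxiliary contribution on the same scale regardless of how many depths are configured.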

num_layers: int = 0

layer_pattern: str = <Multiline-String>

loss_scaling_factor: float = 0.1

use_repeated_layer: bool = False

property pattern_length: int
property num_physical_depths: int
property total_sublayers: int
property enabled: bool
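The four properties are listed without descriptions. The dataclass below is a hedged guess at how they relate to the fields: enabled follows from "0 disables MTP", total_sublayers matches the num_physical_depths * pattern_length layout described for MTPModule, and num_physical_depths collapses to one depth when use_repeated_layer is set. Treating pattern_length as one sublayer per pattern symbol is an assumption, and MTPConfigSketch is a hypothetical name.

```python
from dataclasses import dataclass


@dataclass
class MTPConfigSketch:
    """Hypothetical mirror of MTPConfig, used only to illustrate the properties."""

    layer_pattern: str  # e.g. "*E"; the real default is not shown on this page
    num_layers: int = 0
    loss_scaling_factor: float = 0.1
    use_repeated_layer: bool = False

    @property
    def pattern_length(self) -> int:
        # Assumption: one sublayer per pattern symbol.
        return len(self.layer_pattern)

    @property
    def num_physical_depths(self) -> int:
        # Weight-tied mode builds one physical depth and reuses it.
        if self.use_repeated_layer:
            return min(self.num_layers, 1)
        return self.num_layers

    @property
    def total_sublayers(self) -> int:
        # Length of the flat sublayer list held by the MTP block.
        return self.num_physical_depths * self.pattern_length

    @property
    def enabled(self) -> bool:
        return self.num_layers > 0


untied = MTPConfigSketch("*E", num_layers=3)                        # 3 depths x 2 sublayers
tied = MTPConfigSketch("*E", num_layers=3, use_repeated_layer=True) # 1 depth reused 3 times
```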
class nemo_automodel.components.models.common.mtp.mtp.MTPModule(
mtp_config: nemo_automodel.components.models.common.mtp.mtp.MTPConfig,
block_types_per_sublayer: list[str],
sublayer_factory: Callable[..., torch.nn.Module],
)

Bases: torch.nn.Module

Multi-Token Prediction block.

Holds a flat nn.ModuleList of sublayers (length num_physical_depths * pattern_length) where the first sublayer of each physical depth carries the fusion modules (enorm, hnorm, eh_proj) and the last sublayer of each physical depth carries final_layernorm. This flat layout matches the HuggingFace export format used by Nemotron-V3 (mtp.layers.{i}.*).

The model-specific sublayer construction (which decoder block to use, how to handle MoE / attention / Mamba) is delegated to the caller via sublayer_factory.

Parameters:
  • mtp_config – MTPConfig describing depth and pattern.

  • block_types_per_sublayer – List of block-type strings (one per inner sublayer position), length must equal mtp_config.pattern_length. Caller is responsible for parsing the model-specific symbol convention; this module does not interpret symbols.

  • sublayer_factory – Callable factory(global_idx, depth, sublayer_idx, block_type, has_fusion, has_final_norm) -> nn.Module constructing one sublayer. The returned module must be callable as sublayer(hidden_states, **kwargs) -> Tensor and, when has_fusion=True, expose attributes enorm, hnorm, eh_proj. When has_final_norm=True it must expose final_layernorm.
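The flat layout and the factory contract can be sketched without torch by enumerating the arguments each sublayer_factory call would receive. The symbol table below is hypothetical and only covers the two symbols documented for layer_pattern ("*" for attention, "E" for MoE); real callers use their own model-specific convention. Passing the arguments by keyword, using the parameter names from the documented factory signature, is also an assumption.

```python
# Hypothetical symbol table; only "*" and "E" are documented for layer_pattern.
SYMBOLS = {"*": "attention", "E": "moe"}


def parse_pattern(layer_pattern: str) -> list[str]:
    """Caller-side parsing of a pattern string into block-type strings."""
    return [SYMBOLS[ch] for ch in layer_pattern]


def sublayer_layout(num_physical_depths: int, block_types: list[str]) -> list[dict]:
    """Enumerate factory arguments for the flat sublayer list, in order."""
    pattern_length = len(block_types)
    calls = []
    for depth in range(num_physical_depths):
        for sublayer_idx, block_type in enumerate(block_types):
            calls.append(
                dict(
                    global_idx=depth * pattern_length + sublayer_idx,
                    depth=depth,
                    sublayer_idx=sublayer_idx,
                    block_type=block_type,
                    # First sublayer of a depth owns enorm / hnorm / eh_proj.
                    has_fusion=sublayer_idx == 0,
                    # Last sublayer of a depth owns final_layernorm.
                    has_final_norm=sublayer_idx == pattern_length - 1,
                )
            )
    return calls


layout = sublayer_layout(2, parse_pattern("*E"))
# A real factory would then be invoked once per entry, e.g. sublayer_factory(**call).
```

With a single-symbol pattern, the only sublayer of each depth carries both the fusion modules and final_layernorm.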


property num_depths: int
property pattern_length: int
forward(
input_ids: torch.LongTensor,
hidden_states: torch.Tensor,
embed_fn: Callable[[torch.LongTensor], torch.Tensor],
position_ids: torch.LongTensor | None = None,
**block_kwargs,
) -> list[torch.Tensor]

Iterate over MTP depths and return per-depth hidden states.

Parameters:
  • input_ids – Token ids [B, S] (or [T] in THD). Rolled cumulatively left by 1 per depth.

  • hidden_states – Output of the main model’s final norm (h_0); shape matches the model’s residual stream.

  • embed_fn – Callable applied to rolled input_ids to produce the future-token embedding (typically the model’s input embedding layer).

  • position_ids – Position ids matching input_ids. When supplied, rolled cumulatively per depth in lockstep with input_ids (so slot t carries the original position of the rolled token) and forwarded to each sublayer via block_kwargs. Required for RoPE-using sublayers; ignored by sublayers that don’t consume it.

  • **block_kwargs – Forwarded to each sublayer’s __call__ (e.g. attention_mask).

Returns:

List of length num_depths containing the hidden state produced at each depth.
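The control flow of forward can be sketched on plain lists. In this hedged sketch, depth_fns is a hypothetical list of per-depth callables standing in for the fusion modules, sublayers, and final_layernorm of one depth; with use_repeated_layer=True the same callable would simply appear num_layers times. Chaining hidden_states from one depth into the next is an assumption consistent with the input being called h_0, and the fusion details inside each depth are deliberately omitted.

```python
from typing import Callable, Optional


def roll_left(seq: list[int]) -> list[int]:
    """Shift left by one and zero-fill the last slot (1-D analogue of shifts=-1)."""
    return seq[1:] + [0] if seq else []


def mtp_forward_sketch(
    input_ids: list[int],
    hidden_states: list[float],
    embed_fn: Callable[[list[int]], list[float]],
    depth_fns: list[Callable[..., list[float]]],
    position_ids: Optional[list[int]] = None,
) -> list[list[float]]:
    """Hypothetical outline of the MTP depth loop; returns one output per depth."""
    outputs = []
    for depth_fn in depth_fns:
        # Roll cumulatively: after d iterations, slot t holds token t + d.
        input_ids = roll_left(input_ids)
        if position_ids is not None:
            # Roll in lockstep so slot t keeps the original position of its token.
            position_ids = roll_left(position_ids)
        future_emb = embed_fn(input_ids)
        # Assumption: each depth consumes the previous depth's output.
        hidden_states = depth_fn(hidden_states, future_emb, position_ids=position_ids)
        outputs.append(hidden_states)
    return outputs


# Toy depth: add the future-token "embedding" to the incoming hidden state.
add_depth = lambda h, e, position_ids=None: [a + b for a, b in zip(h, e)]
outs = mtp_forward_sketch([1, 2, 3], [0.0, 0.0, 0.0], lambda ids: [float(i) for i in ids], [add_depth] * 2)
```

Reusing add_depth for both entries mirrors the weight-tied use_repeated_layer mode; the returned list has one entry per depth, matching num_depths.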