nemo_automodel.components.models.common.mtp.mtp#
Model-agnostic MTP scaffolding: depth iteration, token rolling, and loss.
Module Contents#
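Classes#
MTPConfig | Runtime configuration for the MTP block. |
MTPModule | Multi-Token Prediction block. |
Functions#
roll_tensor | Roll a tensor along dim by shifts and zero the wrapped slice. |
get_mtp_loss_scaling_factor | Return the model’s configured MTP auxiliary-loss scaling factor. |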
API#
- nemo_automodel.components.models.common.mtp.mtp.roll_tensor(
- t: torch.Tensor,
- shifts: int = -1,
- dim: int = -1,
Roll a tensor along dim by shifts and zero the wrapped slice.
Used to shift input_ids / position_ids / labels left by one position per MTP depth. Single-GPU path only (no CP / packed-sequence handling).
- Parameters:
t – Input tensor.
shifts – Number of positions to shift (negative = left shift).
dim – Dimension to roll along.
- Returns:
New tensor with the trailing |shifts| positions along dim zero-filled (i.e. no real wrap-around).
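The roll-and-zero behaviour described above can be illustrated without any framework. The following is a minimal list-based sketch, not the library's implementation: `roll_and_zero` is a hypothetical name, it handles only 1-D lists, and the treatment of positive shifts (zero-filling the leading slots) is an assumption, since the text above only describes the left-shift case.

```python
def roll_and_zero(values, shifts=-1):
    """Shift a 1-D list by `shifts` positions and zero-fill the vacated slots.

    List-based stand-in for the tensor version: a negative `shifts` moves
    elements left, and nothing wraps around to the other end.
    """
    n = len(values)
    k = min(abs(shifts), n)
    if k == 0:
        return list(values)
    if shifts < 0:
        # Left shift: drop the first k items, pad k zeros at the end.
        return list(values[k:]) + [0] * k
    # Right shift (assumed behaviour): pad k zeros at the front, drop the last k.
    return [0] * k + list(values[:n - k])


input_ids = [10, 11, 12, 13, 14]
print(roll_and_zero(input_ids))      # [11, 12, 13, 14, 0]
print(roll_and_zero(input_ids, -2))  # [12, 13, 14, 0, 0]
```

A plain circular roll would instead move `10` into the last slot, which would pair the final position with a token from the start of the sequence; zero-filling avoids that.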
- nemo_automodel.components.models.common.mtp.mtp.get_mtp_loss_scaling_factor(
- model: torch.nn.Module,
- default: float = 0.1,
Return the model’s configured MTP auxiliary-loss scaling factor.
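As a rough illustration of how such a scaling factor might be used, the sketch below follows the weighting documented for MTPConfig.loss_scaling_factor: a coefficient on the summed per-depth CE loss, giving an effective per-depth weight of loss_scaling_factor / num_layers. The helper name `scaled_mtp_loss`, the sample loss values, and the step of adding the auxiliary term to the main loss are assumptions made for the example.

```python
def scaled_mtp_loss(per_depth_losses, loss_scaling_factor=0.1):
    """Combine per-depth CE losses into one auxiliary term (hypothetical helper)."""
    num_layers = len(per_depth_losses)
    if num_layers == 0:
        return 0.0  # MTP disabled: no auxiliary term
    # Each depth contributes with weight loss_scaling_factor / num_layers.
    per_depth_weight = loss_scaling_factor / num_layers
    return per_depth_weight * sum(per_depth_losses)


main_loss = 2.0                       # made-up next-token loss
aux_loss = scaled_mtp_loss([2.4, 2.8])  # two MTP depths, made-up values
total_loss = main_loss + aux_loss     # assumed: auxiliary term is simply added
print(round(aux_loss, 4), round(total_loss, 4))
```

With two depths and the default factor of 0.1, each depth's loss is weighted by 0.05, so the auxiliary term stays small relative to the main loss regardless of how many depths are configured.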
- class nemo_automodel.components.models.common.mtp.mtp.MTPConfig#
Runtime configuration for the MTP block.
.. attribute:: num_layers
Number of MTP forward iterations (D). 0 disables MTP. Equivalent to Megatron’s --mtp-num-layers.
.. attribute:: layer_pattern
Per-depth inner-block pattern, e.g. "*E" for one attention + one MoE sublayer per depth.
.. attribute:: loss_scaling_factor
Coefficient applied to the summed per-depth CE loss (default 0.1). The effective per-depth weight is loss_scaling_factor / num_layers.
.. attribute:: use_repeated_layer
When True, build a single physical depth’s worth of sublayers and reuse it for all num_layers forward iterations (weight-tied across depths). Equivalent to Megatron’s --mtp-use-repeated-layer.
- num_layers: int#
0
- layer_pattern: str = <Multiline-String>#
- loss_scaling_factor: float#
0.1
- use_repeated_layer: bool#
False
- property pattern_length: int#
- property num_physical_depths: int#
- property total_sublayers: int#
- property enabled: bool#
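The derived properties are listed without descriptions, so the sketch below shows one plausible reading of them based on the attribute documentation. `MTPConfigSketch` is a hypothetical stand-in, not the real class: the property bodies are assumptions inferred from the text (for example, that pattern_length counts the symbols in layer_pattern), and the "*E" default is only the example pattern from the docs, since the real default is not shown on this page.

```python
from dataclasses import dataclass


@dataclass
class MTPConfigSketch:
    """Illustrative stand-in; field names mirror the documented attributes."""

    num_layers: int = 0
    layer_pattern: str = "*E"  # example pattern from the docs, not the real default
    loss_scaling_factor: float = 0.1
    use_repeated_layer: bool = False

    @property
    def enabled(self) -> bool:
        # Documented: num_layers == 0 disables MTP.
        return self.num_layers > 0

    @property
    def pattern_length(self) -> int:
        # Assumed: one sublayer per pattern symbol.
        return len(self.layer_pattern)

    @property
    def num_physical_depths(self) -> int:
        # Assumed: weight tying collapses all depths onto one physical copy.
        if not self.enabled:
            return 0
        return 1 if self.use_repeated_layer else self.num_layers

    @property
    def total_sublayers(self) -> int:
        # Matches the documented flat-list length of MTPModule.
        return self.num_physical_depths * self.pattern_length


untied = MTPConfigSketch(num_layers=2)
tied = MTPConfigSketch(num_layers=2, use_repeated_layer=True)
print(untied.total_sublayers, tied.total_sublayers)  # 4 2
```

Under this reading, use_repeated_layer changes how many sublayers are allocated but not how many forward iterations run.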
- class nemo_automodel.components.models.common.mtp.mtp.MTPModule(
- mtp_config: nemo_automodel.components.models.common.mtp.mtp.MTPConfig,
- block_types_per_sublayer: list[str],
- sublayer_factory: Callable[..., torch.nn.Module],
Bases: torch.nn.Module
Multi-Token Prediction block.
Holds a flat nn.ModuleList of sublayers (length num_physical_depths * pattern_length) where the first sublayer of each physical depth carries the fusion modules (enorm, hnorm, eh_proj) and the last sublayer of each physical depth carries final_layernorm. This flat layout matches the HuggingFace export format used by Nemotron-V3 (mtp.layers.{i}.*).
The model-specific sublayer construction (which decoder block to use, how to handle MoE / attention / Mamba) is delegated to the caller via sublayer_factory.
- Parameters:
mtp_config – MTPConfig describing depth and pattern.
block_types_per_sublayer – List of block-type strings (one per inner sublayer position); its length must equal mtp_config.pattern_length. The caller is responsible for parsing the model-specific symbol convention; this module does not interpret symbols.
sublayer_factory – Callable factory(global_idx, depth, sublayer_idx, block_type, has_fusion, has_final_norm) -> nn.Module constructing one sublayer. The returned module must be callable as sublayer(hidden_states, **kwargs) -> Tensor and, when has_fusion=True, expose attributes enorm, hnorm, eh_proj. When has_final_norm=True it must expose final_layernorm.
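To make the flat layout concrete, the sketch below enumerates the arguments a sublayer_factory might receive for each slot of the flat list. It is an illustration only: `plan_sublayers` is a hypothetical helper, the block-type strings "attention" and "moe" are placeholders, and the depth-major ordering of global_idx is an assumption consistent with the description above rather than something the page states outright.

```python
def plan_sublayers(num_physical_depths, block_types_per_sublayer):
    """Yield the keyword arguments one sublayer_factory call would receive."""
    pattern_length = len(block_types_per_sublayer)
    for depth in range(num_physical_depths):
        for sublayer_idx, block_type in enumerate(block_types_per_sublayer):
            yield {
                # Assumed depth-major flattening into a single index.
                "global_idx": depth * pattern_length + sublayer_idx,
                "depth": depth,
                "sublayer_idx": sublayer_idx,
                "block_type": block_type,
                # First sublayer of a depth carries enorm / hnorm / eh_proj.
                "has_fusion": sublayer_idx == 0,
                # Last sublayer of a depth carries final_layernorm.
                "has_final_norm": sublayer_idx == pattern_length - 1,
            }


specs = list(plan_sublayers(2, ["attention", "moe"]))
for spec in specs:
    print(spec)
```

For two physical depths and a two-symbol pattern this produces four slots, with fusion modules on slots 0 and 2 and the final norm on slots 1 and 3. With a one-symbol pattern, the single sublayer of each depth would carry both.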
Initialization
- property num_depths: int#
- property pattern_length: int#
- forward(
- input_ids: torch.LongTensor,
- hidden_states: torch.Tensor,
- embed_fn: Callable[[torch.LongTensor], torch.Tensor],
- position_ids: torch.LongTensor | None = None,
- **block_kwargs,
Iterate over MTP depths and return per-depth hidden states.
- Parameters:
input_ids – Token ids [B, S] (or [T] in THD). Rolled cumulatively left by 1 per depth.
hidden_states – Output of the main model’s final norm (h_0); shape matches the model’s residual stream.
embed_fn – Callable applied to rolled input_ids to produce the future-token embedding (typically the model’s input embedding layer).
position_ids – Position ids matching input_ids. When supplied, rolled cumulatively per depth in lockstep with input_ids (so slot t carries the original position of the rolled token) and forwarded to each sublayer via block_kwargs. Required for RoPE-using sublayers; ignored by sublayers that don’t consume it.
**block_kwargs – Forwarded to each sublayer’s __call__ (e.g. attention_mask).
- Returns:
List of length num_depths containing the hidden state produced at each depth.
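The cumulative rolling of input_ids and position_ids can be traced with plain lists. This is a sketch of the bookkeeping only, under the assumption that the first depth already sees inputs rolled once; it does not run any sublayers, and both helper names are hypothetical.

```python
def shift_left_zero(values):
    """Shift a 1-D list left by one and zero-fill the last slot."""
    return list(values[1:]) + [0] if values else []


def rolled_inputs_per_depth(input_ids, position_ids, num_depths):
    """Return the (input_ids, position_ids) pair seen at each MTP depth."""
    per_depth = []
    ids, pos = list(input_ids), list(position_ids)
    for _ in range(num_depths):
        # Cumulative: each depth rolls the previous depth's result again,
        # keeping token ids and their original positions in lockstep.
        ids, pos = shift_left_zero(ids), shift_left_zero(pos)
        per_depth.append((ids, pos))
    return per_depth


trace = rolled_inputs_per_depth([10, 11, 12, 13], [0, 1, 2, 3], num_depths=2)
for depth, (ids, pos) in enumerate(trace, start=1):
    print(depth, ids, pos)
# 1 [11, 12, 13, 0] [1, 2, 3, 0]
# 2 [12, 13, 0, 0] [2, 3, 0, 0]
```

At depth 2, slot 0 holds token 12 together with its original position 2, which is the lockstep property the position_ids description refers to.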