nemo_automodel.components.models.nemotron_v3.mtp

NemotronV3-specific Multi-Token Prediction wiring.

Glue between the model-agnostic :mod:nemo_automodel.components.models.common.mtp scaffolding and the NemotronV3 decoder block. Each MTP sublayer is a :class:NemotronV3Block configured for the requested per-depth block type ("attention" or "moe") plus, when relevant, the depth-level fusion modules (enorm, hnorm, eh_proj) and final_layernorm.

The internal parameter naming mirrors HuggingFace’s flat mtp.layers.{global_idx}.* convention used by the released Super V3 checkpoint, so the state-dict adapter performs an effectively 1-to-1 mapping.

Module Contents

Classes

| Name | Description |
| --- | --- |
| NemotronV3MTPSublayer | One MTP sublayer for NemotronV3. |

Functions

| Name | Description |
| --- | --- |
| _resolve_block_types_per_sublayer | Resolve the per-depth MTP block-type list from either HF field. |
| build_mtp_config_from_hf | Construct an :class:MTPConfig from an HF NemotronH config. |
| build_nemotron_v3_mtp | Construct the NemotronV3 MTP block. |
| parse_mtp_layer_pattern | Parse a NemotronH MTP layer pattern (e.g. "*E") into block types. |

Data

_PATTERN_SYMBOL_TO_BLOCK_TYPE

_VALID_BLOCK_TYPES

API

class nemo_automodel.components.models.nemotron_v3.mtp.NemotronV3MTPSublayer(
config,
layer_idx: int,
block_type: str,
moe_config = None,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
has_fusion: bool = False,
has_final_norm: bool = False,
dtype: torch.dtype = torch.bfloat16
)

Bases: NemotronV3Block

One MTP sublayer for NemotronV3.

Inherits :class:NemotronV3Block so it has the same norm + mixer + residual structure as a main-backbone layer; optionally adds the fusion modules (enorm/hnorm/eh_proj) on the first sublayer of each depth and final_layernorm on the last sublayer of each depth.

eh_proj
enorm
final_layernorm
hnorm
nemo_automodel.components.models.nemotron_v3.mtp.NemotronV3MTPSublayer.forward(
hidden_states: torch.Tensor,
embed_input: torch.Tensor | None = None,
kwargs = {}
) -> torch.Tensor

Run optional fusion (first sublayer of a depth), the base block, and optional final_layernorm (last sublayer of a depth).

Keeping the fusion + final-norm calls inside the sublayer’s own forward ensures FSDP2’s pre-forward unshard hook fires for every parameter we touch, so children like enorm/hnorm/eh_proj/final_layernorm are never accessed while their weights are still sharded DTensors.

nemo_automodel.components.models.nemotron_v3.mtp.NemotronV3MTPSublayer.init_weights(
buffer_device: torch.device | None = None
) -> None

Initialize sublayer weights, including fusion modules when present.
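
The placement rule for the optional modules (fusion on the first sublayer of a depth, final_layernorm on the last) can be illustrated with a small, torch-free sketch. `SublayerPlan` and `plan_sublayers` are hypothetical names that are not part of the module, and the `global_idx` formula (depth-major, contiguous) is an assumption about how the flat `mtp.layers.{global_idx}` numbering is laid out.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SublayerPlan:
    """Hypothetical record of the flags one MTP sublayer would be built with."""

    global_idx: int
    depth: int
    block_type: str
    has_fusion: bool
    has_final_norm: bool


def plan_sublayers(num_depths: int, block_types: list[str]) -> list[SublayerPlan]:
    """Enumerate sublayers depth by depth and mark where optional modules go."""
    per_depth = len(block_types)
    plans = []
    for depth in range(num_depths):
        for local_idx, block_type in enumerate(block_types):
            plans.append(
                SublayerPlan(
                    # Assumption: flat index counts sublayers in depth-major order.
                    global_idx=depth * per_depth + local_idx,
                    depth=depth,
                    block_type=block_type,
                    # Fusion modules sit on the first sublayer of each depth.
                    has_fusion=local_idx == 0,
                    # final_layernorm sits on the last sublayer of each depth.
                    has_final_norm=local_idx == per_depth - 1,
                )
            )
    return plans


for plan in plan_sublayers(num_depths=2, block_types=["attention", "moe"]):
    print(plan)
```

With a single-sublayer depth, the same sublayer would carry both flags, since it is simultaneously the first and the last sublayer of its depth.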

nemo_automodel.components.models.nemotron_v3.mtp._resolve_block_types_per_sublayer(
config
) -> list[str] | None

Resolve the per-depth MTP block-type list from either HF field.

Super-V3 ships mtp_hybrid_override_pattern (symbol-string form like "*E"). Newer NemotronH variants ship mtp_layers_block_type (list-of-strings form like ["attention", "moe"]). Either is accepted.

Parameters:

config

HF NemotronH config.

Returns: list[str] | None

Parsed list of block-type names, or None when neither field is set.

Raises:

  • ValueError: If mtp_layers_block_type contains an unknown block type.
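
As a rough illustration of the two accepted config shapes, the sketch below resolves block types from a stand-in config object. It is not the library's implementation: `resolve_block_types_sketch` is a hypothetical name, `SimpleNamespace` stands in for the HF config, and checking the symbol-string field before the list field is an assumption, since the precedence is not documented here.

```python
from __future__ import annotations

from types import SimpleNamespace

# Symbol table mirroring the module-level _PATTERN_SYMBOL_TO_BLOCK_TYPE constant.
SYMBOL_TO_BLOCK_TYPE = {"M": "mamba", "*": "attention", "-": "mlp", "E": "moe"}
VALID_BLOCK_TYPES = frozenset(SYMBOL_TO_BLOCK_TYPE.values())


def resolve_block_types_sketch(config) -> list[str] | None:
    """Return block-type names from whichever MTP field the config carries."""
    pattern = getattr(config, "mtp_hybrid_override_pattern", None)
    if pattern:
        # Symbol-string form, e.g. "*E". Assumption: this field is checked first.
        return [SYMBOL_TO_BLOCK_TYPE[symbol] for symbol in pattern]

    block_types = getattr(config, "mtp_layers_block_type", None)
    if block_types:
        # List-of-strings form, e.g. ["attention", "moe"]; validate each name.
        unknown = [name for name in block_types if name not in VALID_BLOCK_TYPES]
        if unknown:
            raise ValueError(f"Unknown MTP block types: {unknown}")
        return list(block_types)

    # Neither field is set.
    return None


symbol_form = SimpleNamespace(mtp_hybrid_override_pattern="*E")
list_form = SimpleNamespace(mtp_layers_block_type=["attention", "moe"])
print(resolve_block_types_sketch(symbol_form))
print(resolve_block_types_sketch(list_form))
```

Both stand-in configs resolve to the same list, which is the point of accepting either field.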
nemo_automodel.components.models.nemotron_v3.mtp.build_mtp_config_from_hf(
config,
loss_scaling_factor: float = 0.1,
num_nextn_predict_layers: int | None = None,
use_repeated_layer: bool = False
) -> nemo_automodel.components.models.common.mtp.MTPConfig

Construct an :class:MTPConfig from an HF NemotronH config.

Reads num_nextn_predict_layers and resolves the per-depth pattern from either mtp_hybrid_override_pattern (Super-V3 symbol-string form) or mtp_layers_block_type (list-of-strings form). Returns a disabled config (num_layers=0) when MTP is not configured.

When the pattern source is the list form, :attr:MTPConfig.layer_pattern is set to a sentinel string of "X" characters whose length matches the list. The actual block-type names are passed separately to :func:build_nemotron_v3_mtp via its block_types kwarg.

Parameters:

config

HF NemotronH config.

loss_scaling_factor
float, defaults to 0.1

Auxiliary-loss weight applied to the summed per-depth cross-entropy (CE) loss. Not stored on the HF config; override it programmatically when constructing the model.

num_nextn_predict_layers
int | None, defaults to None

Optional override for the HF config’s num_nextn_predict_layers field. When None, uses the value from config. Set explicitly when the trained model used weight-tied MTP iterations (use_repeated_layer=True) and the HF export only retains the physical depth count.

use_repeated_layer
bool, defaults to False

When True, build only one physical MTP depth and reuse it across all iterations. Mirrors Megatron’s --mtp-use-repeated-layer.

Returns: MTPConfig

An :class:MTPConfig instance.
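
The behaviour described for this function (disabled config when MTP is absent, pass-through of the symbol string, "X" sentinel for the list form) might look roughly like the sketch below. `MTPConfigSketch` is a hypothetical stand-in dataclass, not the real MTPConfig, and its field names other than num_layers and layer_pattern are assumptions. Treating a missing pattern or a non-positive layer count as "not configured" is also an assumption.

```python
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass
class MTPConfigSketch:
    """Hypothetical stand-in for MTPConfig; only the fields this sketch needs."""

    num_layers: int = 0
    layer_pattern: str = ""
    loss_scaling_factor: float = 0.1
    use_repeated_layer: bool = False


def build_mtp_config_sketch(
    config,
    loss_scaling_factor: float = 0.1,
    num_nextn_predict_layers=None,
    use_repeated_layer: bool = False,
) -> MTPConfigSketch:
    # An explicit override wins; otherwise read the HF field (missing counts as 0).
    if num_nextn_predict_layers is None:
        num_nextn_predict_layers = getattr(config, "num_nextn_predict_layers", 0) or 0

    pattern = getattr(config, "mtp_hybrid_override_pattern", None)
    block_types = getattr(config, "mtp_layers_block_type", None)

    # Assumption: MTP counts as "not configured" without a depth count or a pattern.
    if num_nextn_predict_layers <= 0 or not (pattern or block_types):
        return MTPConfigSketch(num_layers=0)

    # Symbol-string form is kept as-is; list form becomes a length-only sentinel.
    layer_pattern = pattern if pattern else "X" * len(block_types)
    return MTPConfigSketch(
        num_layers=num_nextn_predict_layers,
        layer_pattern=layer_pattern,
        loss_scaling_factor=loss_scaling_factor,
        use_repeated_layer=use_repeated_layer,
    )


list_form = SimpleNamespace(num_nextn_predict_layers=1, mtp_layers_block_type=["attention", "moe"])
print(build_mtp_config_sketch(list_form))
```

Because the sentinel only records how many sublayers each depth has, a caller using the list form would still need to hand the block-type names to the builder separately.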

nemo_automodel.components.models.nemotron_v3.mtp.build_nemotron_v3_mtp(
config,
mtp_config: nemo_automodel.components.models.common.mtp.MTPConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
moe_config,
dtype: torch.dtype,
block_types: list[str] | None = None
) -> nemo_automodel.components.models.common.mtp.MTPModule

Construct the NemotronV3 MTP block.

Parameters:

config

HF NemotronH config.

mtp_config
MTPConfig

Parsed MTP runtime config.

backend
BackendConfig

Backend configuration shared with the main backbone.

moe_config

MoE configuration shared with the main backbone (required when the MTP pattern contains MoE sublayers).

dtype
torch.dtype

Target dtype for newly created linear modules.

block_types
list[str] | None, defaults to None

Optional pre-parsed list of block-type names (one per inner sublayer). When supplied, bypasses :func:parse_mtp_layer_pattern on mtp_config.layer_pattern. Required when mtp_config.layer_pattern is a length-only sentinel (e.g. produced from mtp_layers_block_type).

Returns: MTPModule

A configured :class:MTPModule. Caller should not invoke this when

nemo_automodel.components.models.nemotron_v3.mtp.parse_mtp_layer_pattern(
pattern: str
) -> list[str]

Parse a NemotronH MTP layer pattern (e.g. "*E") into block types.

Parameters:

pattern
str

Pattern string using symbols M (mamba), * (attention), - (mlp), E (moe).

Returns: list[str]

List of block-type names ("mamba", "attention", "mlp", "moe").

Raises:

  • ValueError: If the pattern is empty or contains unknown symbols.
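
A minimal sketch of the parsing contract described above, using the symbol table listed under Data. `parse_pattern_sketch` is a hypothetical re-implementation for illustration only, and the error messages are invented.

```python
# Mapping copied from the module's _PATTERN_SYMBOL_TO_BLOCK_TYPE constant.
PATTERN_SYMBOL_TO_BLOCK_TYPE = {"M": "mamba", "*": "attention", "-": "mlp", "E": "moe"}


def parse_pattern_sketch(pattern: str) -> list[str]:
    """Translate each pattern symbol into its block-type name."""
    if not pattern:
        raise ValueError("MTP layer pattern must not be empty")
    # Collect every symbol that has no entry in the mapping.
    unknown = sorted(set(pattern) - set(PATTERN_SYMBOL_TO_BLOCK_TYPE))
    if unknown:
        raise ValueError(f"Unknown MTP pattern symbols: {unknown}")
    return [PATTERN_SYMBOL_TO_BLOCK_TYPE[symbol] for symbol in pattern]


print(parse_pattern_sketch("*E"))  # ['attention', 'moe']
```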
nemo_automodel.components.models.nemotron_v3.mtp._PATTERN_SYMBOL_TO_BLOCK_TYPE = {'M': 'mamba', '*': 'attention', '-': 'mlp', 'E': 'moe'}
nemo_automodel.components.models.nemotron_v3.mtp._VALID_BLOCK_TYPES = frozenset(_PATTERN_SYMBOL_TO_BLOCK_TYPE.values())