nemo_automodel.components.models.step3p7.mtp#

Step3 Multi-Token Prediction (MTP) blocks.

Step checkpoints store the MTP depths after the main decoder layers, as model.layers.{num_hidden_layers + depth}.*. Each depth has the same structure as a decoder block, plus fusion modules (enorm, hnorm, eh_proj) and an MTP-local shared head under transformer.shared_head.

Module Contents#

Classes#

Step3p5MTPSharedHead

Per-depth Step MTP prediction head.

Step3p5MTPBlock

One Step MTP prediction depth.

Step3p5MTPModule

Stack of Step MTP depths.

Functions#

_get_indexed_value

_ensure_indexed

_make_mtp_block_config

Return a shallow config copy patched for a dense sliding-attention MTP layer.

build_mtp_config_from_hf

Build Step MTP runtime config from HF-style config fields.

build_step3p5_mtp

Construct Step MTP depths.

API#

nemo_automodel.components.models.step3p7.mtp._get_indexed_value(
values: Any,
index: int,
default: Any,
) -> Any#
nemo_automodel.components.models.step3p7.mtp._ensure_indexed(
values: Any,
index: int,
value: Any,
) -> list[Any]#
nemo_automodel.components.models.step3p7.mtp._make_mtp_block_config(
config: Any,
layer_idx: int,
depth: int,
) -> Any#

Return a shallow config copy patched for a dense sliding-attention MTP layer.

class nemo_automodel.components.models.step3p7.mtp.Step3p5MTPSharedHead(
config: Any,
backend: nemo_automodel.components.models.common.BackendConfig,
dtype: torch.dtype,
)#

Bases: torch.nn.Module

Per-depth Step MTP prediction head.

Initialization

forward(hidden_states: torch.Tensor) -> torch.Tensor#
init_weights(buffer_device: torch.device) -> None#
class nemo_automodel.components.models.step3p7.mtp.Step3p5MTPBlock(
config: Any,
layer_idx: int,
depth: int,
moe_config: nemo_automodel.components.moe.config.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
dtype: torch.dtype,
)#

Bases: nemo_automodel.components.models.step3p5.model.Block

One Step MTP prediction depth.

Initialization

forward(
hidden_states: torch.Tensor,
*,
embed_input: torch.Tensor,
freqs_cis: torch.Tensor,
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
**attn_kwargs: Any,
) -> tuple[torch.Tensor, torch.Tensor]#
init_weights(buffer_device: torch.device) -> None#
class nemo_automodel.components.models.step3p7.mtp.Step3p5MTPModule(
config: Any,
mtp_config: nemo_automodel.components.models.common.mtp.MTPConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
moe_config: nemo_automodel.components.moe.config.MoEConfig,
dtype: torch.dtype,
)#

Bases: torch.nn.Module

Stack of Step MTP depths.

Initialization

property num_depths: int#
forward(
hidden_states: torch.Tensor,
*,
freqs_cis: torch.Tensor,
input_ids: torch.LongTensor | None = None,
embed_fn=None,
embed_inputs: tuple[torch.Tensor, ...] | list[torch.Tensor] | None = None,
position_ids: torch.LongTensor | None = None,
**block_kwargs: Any,
) -> list[torch.Tensor]#
nemo_automodel.components.models.step3p7.mtp.build_mtp_config_from_hf(
config: Any,
*,
loss_scaling_factor: float = 0.1,
) -> nemo_automodel.components.models.common.mtp.MTPConfig#

Build Step MTP runtime config from HF-style config fields.

nemo_automodel.components.models.step3p7.mtp.build_step3p5_mtp(
config: Any,
mtp_config: nemo_automodel.components.models.common.mtp.MTPConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
moe_config: nemo_automodel.components.moe.config.MoEConfig,
dtype: torch.dtype,
) -> nemo_automodel.components.models.step3p7.mtp.Step3p5MTPModule#

Construct Step MTP depths.