nemo_automodel.components.models.step3p7.mtp

Step3 Multi-Token Prediction blocks.

Step checkpoints store MTP depths after the main decoder layers as model.layers.{num_hidden_layers + depth}.*. Each depth has the same decoder block structure plus fusion modules (enorm, hnorm, eh_proj) and an MTP-local shared head under transformer.shared_head.
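
The layer-index convention above can be sketched as a small key-partitioning helper. This is an illustration only: the function names are hypothetical and not part of this module, and it assumes `depth` is zero-based, so the first MTP depth sits at index `num_hidden_layers`.

```python
def mtp_key_prefix(num_hidden_layers: int, depth: int) -> str:
    """Checkpoint key prefix for one MTP depth (assumes zero-based depth)."""
    if depth < 0:
        raise ValueError("depth must be non-negative")
    return f"model.layers.{num_hidden_layers + depth}."


def split_checkpoint_keys(keys, num_hidden_layers):
    """Partition keys into main-decoder keys and MTP keys grouped by depth."""
    main, mtp = [], {}
    for key in keys:
        parts = key.split(".")
        is_layer_key = (
            len(parts) > 2
            and parts[:2] == ["model", "layers"]
            and parts[2].isdigit()
        )
        if is_layer_key and int(parts[2]) >= num_hidden_layers:
            # Layers at or past num_hidden_layers are treated as MTP depths.
            depth = int(parts[2]) - num_hidden_layers
            mtp.setdefault(depth, []).append(key)
        else:
            main.append(key)
    return main, mtp


if __name__ == "__main__":
    keys = [
        "model.layers.0.self_attn.q_proj.weight",
        "model.layers.4.enorm.weight",
        "model.layers.4.transformer.shared_head.norm.weight",
    ]
    print(split_checkpoint_keys(keys, num_hidden_layers=4))
```

With `num_hidden_layers=4`, the keys under `model.layers.4.` are grouped as depth 0 and everything else is left with the main decoder.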

Module Contents

Classes

| Name | Description |
| --- | --- |
| Step3p5MTPBlock | One Step MTP prediction depth. |
| Step3p5MTPModule | Stack of Step MTP depths. |
| Step3p5MTPSharedHead | Per-depth Step MTP prediction head. |

Functions

| Name | Description |
| --- | --- |
| _ensure_indexed | - |
| _get_indexed_value | - |
| _make_mtp_block_config | Return a shallow config copy patched for a dense sliding-attention MTP layer. |
| build_mtp_config_from_hf | Build Step MTP runtime config from HF-style config fields. |
| build_step3p5_mtp | Construct Step MTP depths. |

API

class nemo_automodel.components.models.step3p7.mtp.Step3p5MTPBlock(
config: typing.Any,
layer_idx: int,
depth: int,
moe_config: nemo_automodel.components.moe.config.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
dtype: torch.dtype
)

Bases: Block

One Step MTP prediction depth.

eh_proj
enorm
hnorm
transformer = nn.Module()

nemo_automodel.components.models.step3p7.mtp.Step3p5MTPBlock.forward(
hidden_states: torch.Tensor,
embed_input: torch.Tensor,
freqs_cis: torch.Tensor,
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
attn_kwargs: typing.Any = {}
) -> tuple[torch.Tensor, torch.Tensor]
nemo_automodel.components.models.step3p7.mtp.Step3p5MTPBlock.init_weights(
buffer_device: torch.device
) -> None
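
The page lists the fusion modules (`enorm`, `hnorm`, `eh_proj`) but does not say how `forward` combines `hidden_states` and `embed_input`. The sketch below shows one common MTP fusion pattern for a single token, in plain Python so it runs without dependencies. Everything about it is an assumption rather than this block's confirmed behavior: RMS normalization for both norms, concatenation with the embedding half first, and a bias-free projection.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """RMS-normalize a vector, then apply a per-feature scale."""
    scale = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * scale * w for v, w in zip(x, weight)]


def linear(x, weight_rows):
    """Bias-free linear map; weight_rows is [out_features][in_features]."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight_rows]


def fuse(hidden, embed, enorm_w, hnorm_w, eh_proj_w):
    """Assumed fusion: eh_proj(concat(enorm(embed), hnorm(hidden)))."""
    e = rms_norm(embed, enorm_w)
    h = rms_norm(hidden, hnorm_w)
    # Concatenation doubles the feature size; eh_proj maps it back down.
    return linear(e + h, eh_proj_w)


if __name__ == "__main__":
    ones = [1.0, 1.0]
    # Projection that picks the first feature of each normalized half.
    eh_proj_w = [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]
    print(fuse([1.0, 0.0], [3.0, 4.0], ones, ones, eh_proj_w))
```

In a pattern like this, the fused vector would then pass through the depth's decoder block, which is consistent with `forward` also taking `freqs_cis` and attention arguments.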
class nemo_automodel.components.models.step3p7.mtp.Step3p5MTPModule(
config: typing.Any,
mtp_config: nemo_automodel.components.models.common.mtp.MTPConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
moe_config: nemo_automodel.components.moe.config.MoEConfig,
dtype: torch.dtype
)

Bases: Module

Stack of Step MTP depths.

layers
num_depths: int

nemo_automodel.components.models.step3p7.mtp.Step3p5MTPModule.forward(
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
input_ids: torch.LongTensor | None = None,
embed_fn = None,
embed_inputs: tuple[torch.Tensor, ...] | list[torch.Tensor] | None = None,
position_ids: torch.LongTensor | None = None,
block_kwargs: typing.Any = {}
) -> list[torch.Tensor]
class nemo_automodel.components.models.step3p7.mtp.Step3p5MTPSharedHead(
config: typing.Any,
backend: nemo_automodel.components.models.common.BackendConfig,
dtype: torch.dtype
)

Bases: Module

Per-depth Step MTP prediction head.

norm
output
nemo_automodel.components.models.step3p7.mtp.Step3p5MTPSharedHead.forward(
hidden_states: torch.Tensor
) -> torch.Tensor
nemo_automodel.components.models.step3p7.mtp.Step3p5MTPSharedHead.init_weights(
buffer_device: torch.device
) -> None
nemo_automodel.components.models.step3p7.mtp._ensure_indexed(
values: typing.Any,
index: int,
value: typing.Any
) -> list[typing.Any]
nemo_automodel.components.models.step3p7.mtp._get_indexed_value(
values: typing.Any,
index: int,
default: typing.Any
) -> typing.Any
nemo_automodel.components.models.step3p7.mtp._make_mtp_block_config(
config: typing.Any,
layer_idx: int,
depth: int
) -> typing.Any

Return a shallow config copy patched for a dense sliding-attention MTP layer.
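
Patching a shallow copy has one pitfall: list-valued fields are shared with the original config, so they must be replaced rather than mutated in place. The sketch below illustrates that with hypothetical stand-ins. `pad_and_set` and `patched_copy` are not this module's functions, and the `layer_types` field and the `"sliding_attention"` value are assumed names chosen for illustration.

```python
import copy
from types import SimpleNamespace


def pad_and_set(values, index, value):
    """Return a NEW list with values[index] = value, padding with value if short."""
    out = list(values) if values is not None else []
    if len(out) <= index:
        out.extend([value] * (index + 1 - len(out)))
    out[index] = value
    return out


def patched_copy(config, layer_idx):
    """Shallow-copy config and override one per-layer entry on the copy only."""
    patched = copy.copy(config)  # attributes still reference the same objects
    # Assign a fresh list instead of mutating the list shared with `config`.
    patched.layer_types = pad_and_set(
        getattr(config, "layer_types", None), layer_idx, "sliding_attention"
    )
    return patched


if __name__ == "__main__":
    cfg = SimpleNamespace(layer_types=["full_attention"] * 2, hidden_size=8)
    mtp_cfg = patched_copy(cfg, layer_idx=2)
    print(cfg.layer_types)      # original list is untouched
    print(mtp_cfg.layer_types)  # copy has an entry for the MTP layer index
```

Because MTP layers are indexed past `num_hidden_layers`, a per-layer list sized for the main decoder would be too short for `layer_idx`, which is why the stand-in pads before assigning.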

nemo_automodel.components.models.step3p7.mtp.build_mtp_config_from_hf(
config: typing.Any,
loss_scaling_factor: float = 0.1
) -> nemo_automodel.components.models.common.mtp.MTPConfig

Build Step MTP runtime config from HF-style config fields.
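
The page does not say how `loss_scaling_factor` is applied. One widely used MTP convention adds the mean of the per-depth losses, scaled by the factor, to the main next-token loss. The sketch below shows that convention with a hypothetical helper name; it is an assumption about usage, not this library's confirmed training loss.

```python
def combine_losses(main_loss, mtp_losses, loss_scaling_factor=0.1):
    """main_loss + factor * mean(per-depth MTP losses); main_loss alone if none."""
    if not mtp_losses:
        return main_loss
    return main_loss + loss_scaling_factor * sum(mtp_losses) / len(mtp_losses)


if __name__ == "__main__":
    # Two MTP depths with losses 1.0 and 3.0 contribute 0.1 * 2.0 on top of 2.0.
    print(combine_losses(2.0, [1.0, 3.0]))
```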

nemo_automodel.components.models.step3p7.mtp.build_step3p5_mtp(
config: typing.Any,
mtp_config: nemo_automodel.components.models.common.mtp.MTPConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
moe_config: nemo_automodel.components.moe.config.MoEConfig,
dtype: torch.dtype
) -> nemo_automodel.components.models.step3p7.mtp.Step3p5MTPModule

Construct Step MTP depths.