nemo_automodel.components.models.deepseek_v4.mtp


DeepSeek V4 Multi-Token Prediction (MTP) blocks.

The released DSV4-Flash checkpoint stores the MTP weights under mtp.{depth}.*. Each MTP depth mirrors the reference MTPBlock:

  • fuse the future-token embedding and the backbone hyper-connection (HC) stream with e_proj(embed) + h_proj(hidden);
  • run one HC-enabled DSV4 attention + MoE block;
  • collapse the HC stream with an MTP-local hc_head and norm before the shared LM head computes the auxiliary CE loss.
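The three steps above can be traced on toy numbers. The sketch below is a minimal, pure-Python illustration for a single token, not the NeMo Automodel implementation. It assumes (hypothetically) that e_proj(embed) is broadcast-added to every HC copy, that hc_head acts as one mixing weight per HC copy, and that norm is a scale-free RMSNorm. The attention + MoE block is replaced by a caller-supplied stub, and the block's enorm/hnorm layers are omitted because their exact placement is not documented here.

```python
from typing import Callable, List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, x: Vector) -> Vector:
    # y = W x for a row-major weight matrix
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def rms_norm(x: Vector, eps: float = 1e-6) -> Vector:
    # Hypothetical stand-in for the MTP-local `norm` (no learned scale)
    rms = (sum(v * v for v in x) / len(x) + eps) ** 0.5
    return [v / rms for v in x]


def mtp_depth_toy(
    hc_stream: List[Vector],  # one token: hc_mult copies of an H-dim state
    embed: Vector,  # future-token embedding, H-dim
    e_proj: Matrix,
    h_proj: Matrix,
    block: Callable[[List[Vector]], List[Vector]],  # stub for attention + MoE
    hc_head: Vector,  # assumed: one mixing weight per HC copy
) -> Tuple[List[Vector], Vector]:
    e = matvec(e_proj, embed)
    # 1) fuse e_proj(embed) + h_proj(hidden); e is added to every HC copy (assumption)
    fused = [[a + b for a, b in zip(matvec(h_proj, h), e)] for h in hc_stream]
    # 2) one HC-enabled attention + MoE block (stubbed by the caller)
    next_stream = block(fused)
    # 3) collapse the HC copies with hc_head weights, then normalize
    hidden = len(embed)
    collapsed = [
        sum(w * copy[i] for w, copy in zip(hc_head, next_stream)) for i in range(hidden)
    ]
    return next_stream, rms_norm(collapsed)
```

With identity projections and an identity block, two HC copies [1, 0] and [0, 1] plus the embedding [1, 1] fuse to [2, 1] and [1, 2]; equal hc_head weights collapse them to [1.5, 1.5], which normalizes to roughly [1, 1].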

Module Contents

Classes

| Name | Description |
| --- | --- |
| DeepseekV4MTPBlock | One DSV4 MTP depth. |
| DeepseekV4MTPModule | DSV4 MTP stack, one DeepseekV4MTPBlock per prediction depth. |

Functions

| Name | Description |
| --- | --- |
| build_deepseek_v4_mtp | Construct DSV4 MTP blocks. |
| build_mtp_config_from_hf | Build an MTPConfig from a DeepseekV4Config. |

API

class nemo_automodel.components.models.deepseek_v4.mtp.DeepseekV4MTPBlock(
config,
layer_idx: int,
moe_config: nemo_automodel.components.moe.config.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
dtype: torch.dtype,
rotary_emb,
rotary_emb_compress
)

Bases: Module

One DSV4 MTP depth.

Parameters:

config

Main DSV4 config.

layer_idx
int

Global layer index used by the attention implementation.

moe_config
MoEConfig

Shared MoE config.

backend
BackendConfig

BackendConfig for kernels/modules.

dtype
torch.dtype

Model dtype.

rotary_emb

Shared main rotary embedding module.

rotary_emb_compress

Shared compressor rotary embedding module.

Attributes:

attn_hc = DeepseekV4HyperConnection(**hc_kwargs)
e_proj
enorm
ffn_hc = DeepseekV4HyperConnection(**hc_kwargs)
h_proj
hc_head
hnorm
input_layernorm
mlp = MoE(moe_config, backend)
norm
post_attention_layernorm
self_attn

nemo_automodel.components.models.deepseek_v4.mtp.DeepseekV4MTPBlock.forward(
hidden_states: torch.Tensor,
embed_input: torch.Tensor,
input_ids: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
attn_kwargs = {}
) -> tuple[torch.Tensor, torch.Tensor]

Run one MTP depth.

Parameters:

hidden_states
torch.Tensor

HC stream [B, S, hc_mult, H].

embed_input
torch.Tensor

Future-token embeddings [B, S, H].

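Returns: tuple[torch.Tensor, torch.Tensor]

Tuple of (next_hc_stream, prediction_hidden), where next_hc_stream is the updated HC stream passed on to the next MTP depth and prediction_hidden is the hidden state, collapsed by hc_head and norm, that the shared LM head uses to compute the auxiliary CE loss.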

nemo_automodel.components.models.deepseek_v4.mtp.DeepseekV4MTPBlock.init_weights(
buffer_device: torch.device | None = None
) -> None
class nemo_automodel.components.models.deepseek_v4.mtp.DeepseekV4MTPModule(
config,
mtp_config: nemo_automodel.components.models.common.mtp.MTPConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
moe_config: nemo_automodel.components.moe.config.MoEConfig,
dtype: torch.dtype,
rotary_emb,
rotary_emb_compress
)

Bases: Module

DSV4 MTP stack, one DeepseekV4MTPBlock per prediction depth.

layers
num_depths
int
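Because the checkpoint stores each depth under mtp.{depth}.* and layers holds one block per depth, checkpoint keys can be grouped by their depth index, and num_depths follows from the largest index. The sketch below is a hypothetical helper, not the loader NeMo Automodel uses; the key suffixes in the example (e_proj.weight, h_proj.weight) are assumed to mirror the block attribute names.

```python
import re
from collections import defaultdict
from typing import Dict, Iterable, List

# Matches keys such as "mtp.0.e_proj.weight"; suffixes are illustrative only.
_MTP_KEY = re.compile(r"^mtp\.(\d+)\.(.+)$")


def group_mtp_keys(keys: Iterable[str]) -> Dict[int, List[str]]:
    """Group checkpoint keys under mtp.{depth}.* by depth, stripping the prefix."""
    groups: Dict[int, List[str]] = defaultdict(list)
    for key in keys:
        match = _MTP_KEY.match(key)
        if match:  # backbone keys without the "mtp." prefix are skipped
            groups[int(match.group(1))].append(match.group(2))
    return dict(groups)


def infer_num_depths(keys: Iterable[str]) -> int:
    """Number of MTP depths implied by the highest mtp.{depth} index."""
    depths = group_mtp_keys(keys)
    return max(depths) + 1 if depths else 0
```

Grouping by the integer index (rather than by string prefix) keeps mtp.1.* and mtp.10.* apart and lets the depth map directly onto a position in layers.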
nemo_automodel.components.models.deepseek_v4.mtp.DeepseekV4MTPModule.forward(
hidden_states: torch.Tensor,
input_ids: torch.LongTensor | None = None,
embed_fn = None,
embed_inputs: tuple[torch.Tensor, ...] | list[torch.Tensor] | None = None,
position_ids: torch.LongTensor | None = None,
block_kwargs = {}
) -> list[torch.Tensor]
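DeepseekV4MTPModule.forward returns one tensor per depth. The sketch below shows, on plain Python lists, how per-depth inputs could be prepared and how depths could be chained. It assumes the DeepSeek-V3 convention (at position i, depth d embeds token i+d+1 and predicts token i+d+2) and that each block returns (next_hc_stream, prediction_hidden) with next_hc_stream feeding the next depth. The helpers mtp_shifted_ids and run_mtp_depths, and the pad/ignore values, are hypothetical and not part of the documented API.

```python
from typing import Any, Callable, List, Sequence, Tuple


def mtp_shifted_ids(
    input_ids: List[int], num_depths: int, pad_id: int = 0, ignore_index: int = -100
) -> List[Tuple[List[int], List[int]]]:
    """Per depth d (0-indexed), return (embed_ids, label_ids) for one sequence."""
    n = len(input_ids)
    out = []
    for d in range(num_depths):
        # Future token embedded at depth d; positions past the end get pad_id
        embed_ids = [input_ids[i + d + 1] if i + d + 1 < n else pad_id for i in range(n)]
        # Token predicted at depth d; positions past the end are ignored by the loss
        label_ids = [input_ids[i + d + 2] if i + d + 2 < n else ignore_index for i in range(n)]
        out.append((embed_ids, label_ids))
    return out


def run_mtp_depths(
    hc_stream: Any,
    embed_inputs: Sequence[Any],
    blocks: Sequence[Callable[[Any, Any], Tuple[Any, Any]]],
) -> List[Any]:
    """Chain depths: each block's next_hc_stream becomes the next block's input."""
    if len(embed_inputs) != len(blocks):
        raise ValueError("need one embedding input per depth")
    predictions = []
    for block, embed in zip(blocks, embed_inputs):
        hc_stream, prediction = block(hc_stream, embed)
        predictions.append(prediction)
    return predictions
```

Under these assumptions, input_ids [10, 11, 12, 13] with two depths give depth 0 embedding ids [11, 12, 13, 0] and labels [12, 13, -100, -100].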
nemo_automodel.components.models.deepseek_v4.mtp.build_deepseek_v4_mtp(
config,
mtp_config: nemo_automodel.components.models.common.mtp.MTPConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
moe_config: nemo_automodel.components.moe.config.MoEConfig,
dtype: torch.dtype,
rotary_emb,
rotary_emb_compress
) -> nemo_automodel.components.models.deepseek_v4.mtp.DeepseekV4MTPModule

Construct the DSV4 MTP blocks and return them as a DeepseekV4MTPModule.

nemo_automodel.components.models.deepseek_v4.mtp.build_mtp_config_from_hf(
config,
loss_scaling_factor: float = 0.1
) -> nemo_automodel.components.models.common.mtp.MTPConfig

Build an MTPConfig from a DeepseekV4Config.
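build_mtp_config_from_hf defaults loss_scaling_factor to 0.1. A minimal sketch of how such a factor is commonly applied, following the DeepSeek-V3 MTP objective (the per-depth CE losses are averaged, scaled, and added to the main loss), is shown below. Whether MTPConfig averages or sums across depths is an assumption here, and combine_mtp_loss is a hypothetical helper.

```python
from typing import Sequence


def combine_mtp_loss(
    main_loss: float, mtp_losses: Sequence[float], loss_scaling_factor: float = 0.1
) -> float:
    """Main CE loss plus the scaled mean of per-depth MTP CE losses (assumed form)."""
    if not mtp_losses:
        return main_loss
    return main_loss + loss_scaling_factor * sum(mtp_losses) / len(mtp_losses)
```

For example, a main loss of 2.0 with depth losses [1.0, 3.0] gives 2.0 + 0.1 * 2.0 = 2.2; averaging keeps the auxiliary term on the same scale regardless of the number of depths.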