nemo_automodel.components.models.deepseek_v4.mtp

DeepSeek V4 Multi-Token Prediction (MTP) blocks.

The released DSV4-Flash checkpoint stores the MTP weights under mtp.{depth}.*. Each MTP depth mirrors the reference MTPBlock:

  • fuse the future-token embedding and the backbone HC stream with e_proj(embed) + h_proj(hidden);

  • run one HC-enabled DSV4 attention + MoE block;

  • collapse the HC stream with an MTP-local hc_head and norm before the shared LM head computes the auxiliary CE loss.
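The "future-token embedding" in the first step refers to a token ahead of the current position. The sketch below spells out which token each depth embeds and which token it is trained to predict. It assumes the DeepSeek-V3 MTP convention (at position t, depth k embeds token t + k and predicts token t + k + 1); this page does not state the exact DSV4 offsets. mtp_alignment is a hypothetical helper, and for clarity it drops positions without a label, whereas training code often keeps length S and masks the tail instead.

```python
def mtp_alignment(input_ids: list[int], depth: int) -> tuple[list[int], list[int]]:
    """Return (future_token_ids, labels) for a 1-indexed MTP depth.

    Assumption (DeepSeek-V3 convention): at position t, depth k consumes the
    embedding of token t + k and is trained to predict token t + k + 1.
    Positions whose label would fall past the end of the sequence are dropped.
    """
    if depth < 1:
        raise ValueError("depth is 1-indexed")
    n = len(input_ids) - depth - 1  # number of positions that still have a label
    if n <= 0:
        return [], []
    future = input_ids[depth : depth + n]      # tokens fed to the embedding
    labels = input_ids[depth + 1 : depth + 1 + n]  # tokens the LM head must predict
    return future, labels


ids = [10, 11, 12, 13, 14]
print(mtp_alignment(ids, 1))  # ([11, 12, 13], [12, 13, 14])
print(mtp_alignment(ids, 2))  # ([12, 13], [13, 14])
```

Each extra depth shifts both the embedded tokens and the labels one step further right, so deeper blocks see fewer supervised positions.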

Module Contents

Classes

DeepseekV4MTPBlock

One DSV4 MTP depth.

DeepseekV4MTPModule

DSV4 MTP stack, one DeepseekV4MTPBlock per prediction depth.

Functions

build_mtp_config_from_hf

Build an MTPConfig from a DeepseekV4Config.

build_deepseek_v4_mtp

Construct DSV4 MTP blocks.

API

class nemo_automodel.components.models.deepseek_v4.mtp.DeepseekV4MTPBlock(
config,
layer_idx: int,
moe_config: nemo_automodel.components.moe.config.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
dtype: torch.dtype,
rotary_emb,
rotary_emb_compress,
)

Bases: torch.nn.Module

One DSV4 MTP depth.

Parameters:
  • config – Main DSV4 config.

  • layer_idx – Global layer index used by the attention implementation.

  • moe_config – Shared MoE config.

  • backend – BackendConfig for kernels/modules.

  • dtype – Model dtype.

  • rotary_emb – Shared main rotary embedding module.

  • rotary_emb_compress – Shared compressor rotary embedding module.


forward(
hidden_states: torch.Tensor,
*,
embed_input: torch.Tensor,
input_ids: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
**attn_kwargs,
) → tuple[torch.Tensor, torch.Tensor]

Run one MTP depth.

Parameters:
  • hidden_states – HC stream [B, S, hc_mult, H].

  • embed_input – Future-token embeddings [B, S, H].

Returns:

Tuple (next_hc_stream, prediction_hidden), where prediction_hidden has shape [B, S, H] and should be projected by the shared LM head to compute the MTP loss.
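The shape contract of forward() can be checked on plain shape tuples before wiring up real tensors. This is a sketch: mtp_block_shapes is a hypothetical helper, and the next_hc_stream shape is an assumption (it is fed to the next depth as hidden_states, so it is taken to keep the [B, S, hc_mult, H] layout); only the input shapes and the [B, S, H] prediction_hidden shape come from the documentation above.

```python
def mtp_block_shapes(hidden_shape: tuple, embed_shape: tuple) -> tuple[tuple, tuple]:
    """Validate the documented input shapes and return the expected output shapes."""
    if len(hidden_shape) != 4:
        raise ValueError("hidden_states must be [B, S, hc_mult, H]")
    if len(embed_shape) != 3:
        raise ValueError("embed_input must be [B, S, H]")
    b, s, hc_mult, h = hidden_shape
    if tuple(embed_shape) != (b, s, h):
        raise ValueError(f"embed_input {tuple(embed_shape)} does not match (B, S, H) = {(b, s, h)}")
    next_hc_stream = (b, s, hc_mult, h)  # assumption: same layout as the input HC stream
    prediction_hidden = (b, s, h)        # documented: [B, S, H], fed to the shared LM head
    return next_hc_stream, prediction_hidden


print(mtp_block_shapes((2, 8, 4, 16), (2, 8, 16)))  # ((2, 8, 4, 16), (2, 8, 16))
```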

init_weights(buffer_device: torch.device | None = None) → None

class nemo_automodel.components.models.deepseek_v4.mtp.DeepseekV4MTPModule(
config,
mtp_config: nemo_automodel.components.models.common.mtp.MTPConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
moe_config: nemo_automodel.components.moe.config.MoEConfig,
dtype: torch.dtype,
rotary_emb,
rotary_emb_compress,
)

Bases: torch.nn.Module

DSV4 MTP stack, one DeepseekV4MTPBlock per prediction depth.


property num_depths: int

forward(
hidden_states: torch.Tensor,
input_ids: torch.LongTensor | None = None,
embed_fn=None,
embed_inputs: tuple[torch.Tensor, ...] | list[torch.Tensor] | None = None,
position_ids: torch.LongTensor | None = None,
**block_kwargs,
) → list[torch.Tensor]

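The forward signature suggests that the stack threads the HC stream from one depth to the next and returns one prediction state per depth, taking either precomputed embed_inputs or an embed_fn applied to input_ids. The pure-Python sketch below shows that control flow under stated assumptions: run_mtp_stack and shift_left are hypothetical helpers, lists and integers stand in for tensors, and the shift-by-depth rule with pad id 0 is an assumed convention, not documented behavior of DeepseekV4MTPModule. position_ids and block_kwargs are omitted.

```python
from typing import Callable, Optional, Sequence


def shift_left(ids: list[int], k: int, pad_id: int = 0) -> list[int]:
    # Drop the first k tokens and right-pad so the sequence length stays S.
    return ids[k:] + [pad_id] * min(k, len(ids))


def run_mtp_stack(
    blocks: Sequence[Callable],
    hc_stream,
    input_ids: Optional[list[int]] = None,
    embed_fn: Optional[Callable] = None,
    embed_inputs: Optional[Sequence] = None,
) -> list:
    if embed_inputs is None:
        if embed_fn is None or input_ids is None:
            raise ValueError("pass embed_inputs, or both input_ids and embed_fn")
        # Assumed convention: depth k (1-indexed) embeds the token k steps ahead.
        embed_inputs = [embed_fn(shift_left(input_ids, k)) for k in range(1, len(blocks) + 1)]
    if len(embed_inputs) != len(blocks):
        raise ValueError("need exactly one embedding input per depth")
    predictions = []
    for block, embed in zip(blocks, embed_inputs):
        # Each depth returns the next HC stream and a prediction state for the LM head.
        hc_stream, prediction_hidden = block(hc_stream, embed_input=embed)
        predictions.append(prediction_hidden)
    return predictions


# Toy usage: an integer stands in for the HC stream, lists for embeddings.
def make_block(i: int) -> Callable:
    def block(hc, *, embed_input):
        return hc + 1, (i, hc, embed_input)
    return block


preds = run_mtp_stack([make_block(0), make_block(1)], 0,
                      input_ids=[1, 2, 3, 4], embed_fn=lambda ids: [10 * x for x in ids])
print(preds)  # [(0, 0, [20, 30, 40, 0]), (1, 1, [30, 40, 0, 0])]
```

The second block receives the HC stream produced by the first (hc == 1), which is why the returned tuple of DeepseekV4MTPBlock.forward carries next_hc_stream separately from prediction_hidden.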
nemo_automodel.components.models.deepseek_v4.mtp.build_mtp_config_from_hf(
config,
*,
loss_scaling_factor: float = 0.1,
) → nemo_automodel.components.models.common.mtp.MTPConfig

Build an MTPConfig from a DeepseekV4Config.
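loss_scaling_factor (default 0.1) weights the auxiliary MTP loss against the main language-modeling loss. A minimal sketch, assuming the DeepSeek-V3 weighting in which the factor multiplies the mean of the per-depth cross-entropy losses; combine_mtp_loss is a hypothetical helper, and the reduction used by the code that consumes MTPConfig may differ.

```python
def combine_mtp_loss(main_loss: float, mtp_losses: list[float],
                     loss_scaling_factor: float = 0.1) -> float:
    """Total loss = main CE + loss_scaling_factor * mean(per-depth MTP CE).

    Assumption: DeepSeek-V3-style (lambda / D) * sum_k L_k weighting.
    """
    if not mtp_losses:
        return main_loss  # no MTP depths: the auxiliary term vanishes
    return main_loss + loss_scaling_factor * sum(mtp_losses) / len(mtp_losses)


print(combine_mtp_loss(2.0, [1.0, 3.0]))  # 2.0 + 0.1 * 2.0
```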

nemo_automodel.components.models.deepseek_v4.mtp.build_deepseek_v4_mtp(
config,
mtp_config: nemo_automodel.components.models.common.mtp.MTPConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
moe_config: nemo_automodel.components.moe.config.MoEConfig,
dtype: torch.dtype,
rotary_emb,
rotary_emb_compress,
) → nemo_automodel.components.models.deepseek_v4.mtp.DeepseekV4MTPModule

Construct DSV4 MTP blocks.
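Because the released checkpoint stores each depth under mtp.{depth}.*, its keys can be split per depth, for example to check that the checkpoint provides as many depths as the module built here expects. A sketch: group_mtp_keys is a hypothetical helper, and the key suffixes in the usage example (e_proj, h_proj, norm) are illustrative names borrowed from the block description at the top of this page, not the exact checkpoint keys.

```python
import re
from collections import defaultdict

# Matches "mtp.<depth>.<rest>" at the start of a key.
_MTP_KEY = re.compile(r"^mtp\.(\d+)\.(.+)$")


def group_mtp_keys(keys: list[str]) -> dict[int, list[str]]:
    """Split checkpoint keys of the form mtp.{depth}.<rest> by depth."""
    groups: dict[int, list[str]] = defaultdict(list)
    for key in keys:
        match = _MTP_KEY.match(key)
        if match:  # non-MTP keys (embeddings, backbone layers) are skipped
            groups[int(match.group(1))].append(match.group(2))
    return dict(sorted(groups.items()))


keys = [
    "embed.weight",
    "mtp.0.e_proj.weight",
    "mtp.0.h_proj.weight",
    "mtp.1.norm.weight",
    "layers.0.attn.wq.weight",
]
groups = group_mtp_keys(keys)
print(groups)       # {0: ['e_proj.weight', 'h_proj.weight'], 1: ['norm.weight']}
print(len(groups))  # number of depths present, to compare against num_depths
```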