nemo_automodel.components.models.deepseek_v4.mtp
DeepSeek V4 Multi-Token Prediction (MTP) blocks.
The released DSV4-Flash checkpoint stores MTP under mtp.{depth}.*. Each MTP depth mirrors the reference MTPBlock:
- fuse the future-token embedding and the backbone HC stream with e_proj(embed) + h_proj(hidden);
- run one HC-enabled DSV4 attention + MoE block;
- collapse the HC stream with an MTP-local hc_head and norm before the shared LM head computes the auxiliary cross-entropy (CE) loss.
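The mtp.{depth}.* layout described above can be illustrated with a small helper that groups checkpoint keys by depth. This is a minimal sketch, not part of the library: `group_mtp_keys` is a hypothetical name, and the example parameter suffixes (such as `e_proj.weight`) are assumed for illustration rather than taken from the released checkpoint.

```python
import re
from collections import defaultdict

# Matches keys of the form "mtp.<depth>.<rest>", e.g. "mtp.0.e_proj.weight".
_MTP_KEY = re.compile(r"^mtp\.(\d+)\.(.+)$")


def group_mtp_keys(keys):
    """Group checkpoint keys under mtp.{depth}.* by integer depth (hypothetical helper)."""
    grouped = defaultdict(list)
    for key in keys:
        match = _MTP_KEY.match(key)
        if match:
            grouped[int(match.group(1))].append(match.group(2))
    return dict(grouped)


# Illustrative key names only; the real checkpoint may use different suffixes.
example_keys = [
    "mtp.0.e_proj.weight",
    "mtp.0.h_proj.weight",
    "mtp.1.e_proj.weight",
    "embed.weight",  # not an MTP key, so it is skipped
]
grouped = group_mtp_keys(example_keys)
```

Keys that do not start with the `mtp.` prefix are ignored, so the helper can be run over a full state dict.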
Module Contents
Classes
| DeepseekV4MTPBlock | One DSV4 MTP depth. |
| DeepseekV4MTPModule | DSV4 MTP stack, one DeepseekV4MTPBlock per prediction depth. |
Functions
| build_mtp_config_from_hf | Build an MTPConfig from a DeepseekV4Config. |
| build_deepseek_v4_mtp | Construct DSV4 MTP blocks. |
API
class nemo_automodel.components.models.deepseek_v4.mtp.DeepseekV4MTPBlock(
    config,
    layer_idx: int,
    moe_config: nemo_automodel.components.moe.config.MoEConfig,
    backend: nemo_automodel.components.models.common.BackendConfig,
    dtype: torch.dtype,
    rotary_emb,
    rotary_emb_compress,
)
Bases: torch.nn.Module
One DSV4 MTP depth.
- Parameters:
config – Main DSV4 config.
layer_idx – Global layer index used by the attention implementation.
moe_config – Shared MoE config.
backend – BackendConfig for kernels/modules.
dtype – Model dtype.
rotary_emb – Shared main rotary embedding module.
rotary_emb_compress – Shared compressor rotary embedding module.
Initialization
forward(
    hidden_states: torch.Tensor,
    *,
    embed_input: torch.Tensor,
    input_ids: torch.Tensor | None = None,
    position_ids: torch.Tensor | None = None,
    attention_mask: torch.Tensor | None = None,
    padding_mask: torch.Tensor | None = None,
    **attn_kwargs,
)
Run one MTP depth.
- Parameters:
hidden_states – HC stream [B, S, hc_mult, H].
embed_input – Future-token embeddings [B, S, H].
- Returns:
Tuple of (next_hc_stream, prediction_hidden), where prediction_hidden is [B, S, H] and should be projected by the shared LM head for the MTP loss.
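The tensor shapes documented above can be traced with a toy stand-in for one MTP depth. This is a rough sketch under stated assumptions, not the library implementation: `ToyMTPDepth` is a hypothetical class, the attention + MoE block is replaced by an identity, `hc_head` is modeled as a learned weighting over the hc_mult streams, LayerNorm stands in for whatever norm the model uses, and the embedding projection is assumed to broadcast across the hc_mult axis.

```python
import torch
from torch import nn


class ToyMTPDepth(nn.Module):
    """Hypothetical stand-in for one MTP depth; shapes only, not the real block."""

    def __init__(self, hidden: int, hc_mult: int):
        super().__init__()
        self.e_proj = nn.Linear(hidden, hidden, bias=False)
        self.h_proj = nn.Linear(hidden, hidden, bias=False)
        # Placeholder for the HC-enabled attention + MoE block.
        self.block = nn.Identity()
        # Assumed form of hc_head: one weight per HC stream.
        self.hc_head = nn.Parameter(torch.full((hc_mult,), 1.0 / hc_mult))
        # Stand-in norm; the real model may use a different one.
        self.norm = nn.LayerNorm(hidden)

    def forward(self, hidden_states, *, embed_input):
        # hidden_states: [B, S, hc_mult, H]; embed_input: [B, S, H]
        # Assumption: the projected embedding is broadcast over hc_mult.
        fused = self.e_proj(embed_input).unsqueeze(2) + self.h_proj(hidden_states)
        next_hc_stream = self.block(fused)  # [B, S, hc_mult, H]
        # Collapse the hc_mult axis into a single hidden vector per position.
        collapsed = torch.einsum("bsmh,m->bsh", next_hc_stream, self.hc_head)
        prediction_hidden = self.norm(collapsed)  # [B, S, H]
        return next_hc_stream, prediction_hidden


B, S, hc_mult, H = 2, 5, 4, 8
depth = ToyMTPDepth(H, hc_mult)
next_hc, pred = depth(torch.randn(B, S, hc_mult, H), embed_input=torch.randn(B, S, H))
```

In this sketch `next_hc` keeps the HC layout so it could feed a further depth, while `pred` is the tensor that would go to the shared LM head.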
init_weights(buffer_device: torch.device | None = None) → None
class nemo_automodel.components.models.deepseek_v4.mtp.DeepseekV4MTPModule(
    config,
    mtp_config: nemo_automodel.components.models.common.mtp.MTPConfig,
    backend: nemo_automodel.components.models.common.BackendConfig,
    moe_config: nemo_automodel.components.moe.config.MoEConfig,
    dtype: torch.dtype,
    rotary_emb,
    rotary_emb_compress,
)
Bases: torch.nn.Module
DSV4 MTP stack, one DeepseekV4MTPBlock per prediction depth.
Initialization
property num_depths: int
forward(
    hidden_states: torch.Tensor,
    input_ids: torch.LongTensor | None = None,
    embed_fn=None,
    embed_inputs: tuple[torch.Tensor, ...] | list[torch.Tensor] | None = None,
    position_ids: torch.LongTensor | None = None,
    **block_kwargs,
)
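The signature accepts either precomputed embed_inputs or input_ids together with an embed_fn, and each block expects future-token embeddings. One plausible way to derive per-depth future token ids is to shift the sequence left by one more position per depth. The sketch below assumes that convention; `roll_left` and `future_token_ids` are hypothetical helpers, and the padding value and exact shift used by the library are not confirmed by this page.

```python
def roll_left(row, shift, pad_id=0):
    """Shift a token sequence left by `shift`, padding the tail with pad_id."""
    if shift <= 0:
        return list(row)
    # Keep the length fixed even when shift exceeds the sequence length.
    return list(row[shift:]) + [pad_id] * min(shift, len(row))


def future_token_ids(input_ids, num_depths, pad_id=0):
    """Return one shifted batch per depth: depth d sees tokens d + 1 steps ahead.

    Assumed convention for illustration only.
    """
    return [
        [roll_left(row, d + 1, pad_id) for row in input_ids]
        for d in range(num_depths)
    ]


batch = [[10, 11, 12, 13]]
future = future_token_ids(batch, num_depths=2)
```

Under this assumption, each shifted batch would be passed through embed_fn to obtain the [B, S, H] tensor for the matching depth.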
nemo_automodel.components.models.deepseek_v4.mtp.build_mtp_config_from_hf(
    config,
    *,
    loss_scaling_factor: float = 0.1,
)
Build an MTPConfig from a DeepseekV4Config.
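The role of loss_scaling_factor can be sketched as a weight on the auxiliary MTP loss when it is added to the main language-modeling loss. The combination rule below (main loss plus the scaled mean of the per-depth losses) is an assumption for illustration, and `combine_losses` is a hypothetical helper; this page does not state how the library actually applies the factor.

```python
def combine_losses(main_loss, mtp_losses, loss_scaling_factor=0.1):
    """Add the scaled mean of per-depth MTP losses to the main loss.

    Assumed combination rule; the library's own formula may differ.
    """
    if not mtp_losses:
        return main_loss
    mean_mtp = sum(mtp_losses) / len(mtp_losses)
    return main_loss + loss_scaling_factor * mean_mtp


# Main loss 2.0 with two MTP depths whose losses average 4.0.
total = combine_losses(2.0, [3.0, 5.0])
```

With the default factor of 0.1, the auxiliary term contributes a tenth of the average per-depth loss, so it nudges training without dominating the main objective.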
nemo_automodel.components.models.deepseek_v4.mtp.build_deepseek_v4_mtp(
    config,
    mtp_config: nemo_automodel.components.models.common.mtp.MTPConfig,
    backend: nemo_automodel.components.models.common.BackendConfig,
    moe_config: nemo_automodel.components.moe.config.MoEConfig,
    dtype: torch.dtype,
    rotary_emb,
    rotary_emb_compress,
)
Construct DSV4 MTP blocks.