core.transformer.multi_token_prediction#

Module Contents#

Classes#

MTPLossLoggingHelper

Helper class for logging MTP losses.

MultiTokenPredictionLayerSubmodules

Dataclass for specifying the submodules of a MultiTokenPrediction module.

MTPLossAutoScaler

An AutoScaler that triggers the backward pass and scales the grad for mtp loss.

MultiTokenPredictionLayer

The implementation for Multi-Token Prediction (MTP) which extends the prediction scope to multiple future tokens at each position.

MultiTokenPredictionBlockSubmodules

Dataclass for specifying the submodules of a multi token prediction block.

MultiTokenPredictionBlock

The implementation for Multi-Token Prediction (MTP) which extends the prediction scope to multiple future tokens at each position.

Functions#

tie_word_embeddings_state_dict

Tie the embedding of the MTP processing stage in a given sharded state dict.

tie_output_layer_state_dict

Tie the output layer of the MTP processing stage in a given sharded state dict.

roll_tensor

Roll the tensor input along the sequence dimension with Context Parallelism (CP) support.

_roll_tensor_packed_seq

Roll tensor with packed sequence support, respecting sequence boundaries.

get_mtp_layer_spec

Get the MTP layer spec.

get_mtp_layer_spec_for_backend

Get the MTP layer spec.

get_mtp_layer_offset

Get the offset of the MTP layer.

get_mtp_num_layers_to_build

Get the number of MTP layers to build.

_get_mtp_block_submodules

Retrieve or construct MultiTokenPredictionBlockSubmodules based on the provided specification.

Data#

API#

core.transformer.multi_token_prediction.SUPPORTED_ATTN_MASK#

None

core.transformer.multi_token_prediction.tie_word_embeddings_state_dict(
sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
word_emb_weight: torch.Tensor,
word_emb_weight_key: str,
tp_group: torch.distributed.ProcessGroup = None,
dp_cp_group: torch.distributed.ProcessGroup = None,
) → None#

Tie the embedding of the MTP processing stage in a given sharded state dict.

Parameters:
  • sharded_state_dict (ShardedStateDict) – state dict with the weight to tie.

  • word_emb_weight (Tensor) – weight of the word embedding.

  • word_emb_weight_key (str) – key of the word embedding in the sharded state dict.

  • tp_group (torch.distributed.ProcessGroup) – The tensor parallel group.

  • dp_cp_group (torch.distributed.ProcessGroup) – The dp-cp communication group.

Returns: None, acts in-place

core.transformer.multi_token_prediction.tie_output_layer_state_dict(
sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
output_layer_weight: torch.Tensor,
output_layer_weight_key: str,
tp_group: torch.distributed.ProcessGroup = None,
dp_cp_group: torch.distributed.ProcessGroup = None,
) → None#

Tie the output layer of the MTP processing stage in a given sharded state dict.

Parameters:
  • sharded_state_dict (ShardedStateDict) – state dict with the weight to tie.

  • output_layer_weight (Tensor) – weight of the output layer.

  • output_layer_weight_key (str) – key of the output layer in the sharded state dict.

  • tp_group (torch.distributed.ProcessGroup) – The tensor parallel group.

  • dp_cp_group (torch.distributed.ProcessGroup) – The dp-cp communication group.

Returns: None, acts in-place
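
The two tying helpers above act in-place on an already-built sharded state dict. A minimal usage sketch follows; the model attributes, state-dict keys, and process-group handles are illustrative assumptions, not names taken from the library:

from megatron.core.transformer.multi_token_prediction import (
    tie_output_layer_state_dict,
    tie_word_embeddings_state_dict,
)

def tie_mtp_weights(model, sharded_state_dict, tp_group, dp_cp_group):
    # Tie the shared word embedding used by the MTP stage (key name is illustrative).
    tie_word_embeddings_state_dict(
        sharded_state_dict=sharded_state_dict,
        word_emb_weight=model.embedding.word_embeddings.weight,
        word_emb_weight_key="embedding.word_embeddings.weight",
        tp_group=tp_group,
        dp_cp_group=dp_cp_group,
    )
    # Tie the shared output head used by the MTP stage (key name is illustrative).
    tie_output_layer_state_dict(
        sharded_state_dict=sharded_state_dict,
        output_layer_weight=model.output_layer.weight,
        output_layer_weight_key="output_layer.weight",
        tp_group=tp_group,
        dp_cp_group=dp_cp_group,
    )
    # Both helpers modify sharded_state_dict in-place and return None.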

core.transformer.multi_token_prediction.roll_tensor(
tensor,
shifts=-1,
dims=-1,
cp_group=None,
packed_seq_params=None,
)#

Roll the tensor input along the sequence dimension with Context Parallelism (CP) support.

This function extends the original roll_tensor to support Context Parallelism, which allows MTP to work with CP > 1. When CP is enabled, the sequence dimension is split across CP ranks, and tensor rolling requires communication between adjacent CP ranks to properly handle the boundary conditions.

For CP=1 (default behavior): uses standard torch.roll with zero padding.

For CP>1: splits the tensor into chunks, performs rolling within each chunk, then exchanges boundary elements between adjacent CP ranks to maintain sequence continuity.

For packed sequences: Respects sequence boundaries when rolling to avoid mixing tokens from different sequences.

Parameters:
  • tensor (Tensor) – The input tensor to roll.

  • shifts (int) – The shift of the tensor (typically -1 for MTP).

  • dims (int) – The dimension to roll (typically -1 for sequence dimension).

  • cp_group (ProcessGroup) – The context parallelism process group. If None or size=1, falls back to standard rolling behavior.

  • packed_seq_params (PackedSeqParams) – Parameters for packed sequence processing. If provided, respects sequence boundaries.

Returns:

(rolled_tensor, sum_of_rolled_tensor)

Return type:

tuple
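
To make the CP=1 (single-rank) behavior concrete, the following self-contained sketch re-implements only the fallback path described above, i.e. a torch.roll followed by zero padding of the wrapped position; it is an illustration, not the library function itself:

import torch

def roll_with_zero_pad(tensor: torch.Tensor, dims: int = -1):
    # Illustrative CP=1 fallback for shifts=-1 (the typical MTP case): shift everything
    # one step left along the sequence dim and zero the last slot instead of letting
    # the first token wrap around.
    rolled = torch.roll(tensor, shifts=-1, dims=dims)
    index = torch.tensor([tensor.size(dims) - 1], device=tensor.device)
    rolled.index_fill_(dims, index, 0)
    return rolled, rolled.sum()

labels = torch.tensor([[1, 2, 3, 4]])
rolled, total = roll_with_zero_pad(labels)  # rolled == [[2, 3, 4, 0]], total == 9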

core.transformer.multi_token_prediction._roll_tensor_packed_seq(
tensor,
shifts,
dims,
packed_seq_params,
cp_group=None,
)#

Roll tensor with packed sequence support. This function handles rolling for packed sequences by respecting sequence boundaries.
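
As an illustration of what respecting sequence boundaries means, the self-contained sketch below rolls each packed segment independently given its cumulative sequence lengths; it mimics the idea only and is not the library helper:

import torch

def roll_packed_segments(tensor: torch.Tensor, cu_seqlens: torch.Tensor):
    # Roll every packed segment separately (by -1, the typical MTP shift) along the
    # last dim and zero-pad its boundary so tokens never leak across sequences.
    rolled = tensor.clone()
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        segment = torch.roll(tensor[..., start:end], shifts=-1, dims=-1)
        segment[..., -1] = 0
        rolled[..., start:end] = segment
    return rolled, rolled.sum()

# Two sequences of lengths 3 and 2 packed into a single row.
packed = torch.tensor([[1, 2, 3, 7, 8]])
cu_seqlens = torch.tensor([0, 3, 5])
rolled, total = roll_packed_segments(packed, cu_seqlens)  # rolled == [[2, 3, 0, 8, 0]]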

class core.transformer.multi_token_prediction.MTPLossLoggingHelper#

Helper class for logging MTP losses.

tracker#

None

static save_loss_to_tracker(
loss: torch.Tensor,
layer_number: int,
num_layers: int,
reduce_group: torch.distributed.ProcessGroup = None,
avg_group: torch.distributed.ProcessGroup = None,
)#

Save the mtp loss for logging.

Parameters:
  • loss (torch.Tensor) – The loss tensor.

  • layer_number (int) – Layer index of the loss.

  • num_layers (int) – The number of total layers.

  • reduce_group (torch.distributed.ProcessGroup) – The group for reducing the loss.

  • avg_group (torch.distributed.ProcessGroup) – The group for averaging the loss.

clean_loss_in_tracker()#

Clear the mtp losses.

reduce_loss_in_tracker()#

Collect and reduce the mtp losses across ranks.

track_mtp_metrics(
iteration,
writer,
wandb_writer=None,
total_loss_dict=None,
)#

Track the Multi-Token Prediction (MTP) metrics for logging.
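
A hedged sketch of the intended call order: each MTP layer stashes its loss during the forward pass, and at a logging interval the losses are reduced, reported via track_mtp_metrics, and cleared. Group handles are omitted here and the exact call sites in a real training loop may differ:

import torch

from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper

# Inside the forward pass of each MTP layer: save that layer's loss for logging.
MTPLossLoggingHelper.save_loss_to_tracker(
    loss=torch.tensor(2.5),
    layer_number=0,
    num_layers=1,
)

# At a logging step: collect and reduce the stashed losses across ranks, report
# them (e.g. via track_mtp_metrics(iteration, writer, ...)), then reset the tracker.
helper = MTPLossLoggingHelper()
helper.reduce_loss_in_tracker()
helper.clean_loss_in_tracker()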

class core.transformer.multi_token_prediction.MultiTokenPredictionLayerSubmodules#

Dataclass for specifying the submodules of a MultiTokenPrediction module.

Parameters:
  • hnorm (Union[ModuleSpec, type]) – Specification or instance of the hidden states normalization to be applied.

  • enorm (Union[ModuleSpec, type]) – Specification or instance of the embedding normalization to be applied.

  • eh_proj (Union[ModuleSpec, type]) – Specification or instance of the linear projection to be applied.

  • transformer_layer (Union[ModuleSpec, type]) – Specification or instance of the transformer block to be applied.

enorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

hnorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

eh_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

transformer_layer: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

layer_norm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

core.transformer.multi_token_prediction.get_mtp_layer_spec(
transformer_layer_spec: megatron.core.transformer.spec_utils.ModuleSpec,
use_transformer_engine: bool,
) → megatron.core.transformer.spec_utils.ModuleSpec#

Get the MTP layer spec.

Returns:

Module specification with TE modules

Return type:

ModuleSpec
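
A short usage sketch follows: the MTP layer spec is typically derived from the same transformer layer spec used for the main decoder stack. The GPT layer-spec helper imported below is assumed to be available at this path in Megatron Core:

from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_layer_with_transformer_engine_spec,
)
from megatron.core.transformer.multi_token_prediction import get_mtp_layer_spec

# Reuse the decoder's transformer layer spec for the MTP transformer block.
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec()
mtp_layer_spec = get_mtp_layer_spec(
    transformer_layer_spec=transformer_layer_spec,
    use_transformer_engine=True,
)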

core.transformer.multi_token_prediction.get_mtp_layer_spec_for_backend(
transformer_layer_spec: megatron.core.transformer.spec_utils.ModuleSpec,
backend: megatron.core.models.backends.BackendSpecProvider,
) → megatron.core.transformer.spec_utils.ModuleSpec#

Get the MTP layer spec.

Returns:

Module specification with modules from the backend.

Return type:

ModuleSpec

core.transformer.multi_token_prediction.get_mtp_layer_offset(
config: megatron.core.transformer.transformer_config.TransformerConfig,
) → int#

Get the offset of the MTP layer.

core.transformer.multi_token_prediction.get_mtp_num_layers_to_build(
config: megatron.core.transformer.transformer_config.TransformerConfig,
vp_stage: Optional[int] = None,
pp_rank: Optional[int] = None,
) → int#

Get the number of MTP layers to build.

class core.transformer.multi_token_prediction.MTPLossAutoScaler#

Bases: torch.autograd.Function

An AutoScaler that triggers the backward pass and scales the grad for mtp loss.

main_loss_backward_scale: torch.Tensor#

‘tensor(…)’

static forward(ctx, output: torch.Tensor, mtp_loss: torch.Tensor)#

Preserve the mtp loss by storing it in the context to avoid garbage collection.

Parameters:
  • output (torch.Tensor) – The output tensor.

  • mtp_loss (torch.Tensor) – The mtp loss tensor.

Returns:

The output tensor.

Return type:

torch.Tensor

static backward(ctx, grad_output: torch.Tensor)#

Compute and scale the gradient for mtp loss.

Parameters:

grad_output (torch.Tensor) – The gradient of the output.

Returns:

The gradient of the output and the scaled mtp loss gradient.

Return type:

Tuple[torch.Tensor, torch.Tensor]

static set_loss_scale(scale: torch.Tensor)#

Set the scale of the mtp loss.

Parameters:

scale (torch.Tensor) – The scale value to set. Please ensure that the scale passed in matches the scale of the main_loss.
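
A minimal sketch of how this autograd function is typically wired in; the tensors below are toy placeholders standing in for the model output and the computed MTP loss:

import torch

from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler

output = torch.randn(4, 2, 8, requires_grad=True)  # toy stand-in for hidden states / logits
mtp_loss = output.pow(2).mean()                    # toy stand-in for the MTP loss

# Keep the MTP loss gradient on the same scale as the main loss gradient.
MTPLossAutoScaler.set_loss_scale(torch.tensor(1.0))

# forward() returns `output` unchanged but stores `mtp_loss` in the autograd context;
# backward() then emits the scaled gradient for the MTP loss alongside grad_output.
output = MTPLossAutoScaler.apply(output, mtp_loss)
output.sum().backward()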

class core.transformer.multi_token_prediction.MultiTokenPredictionLayer(
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: core.transformer.multi_token_prediction.MultiTokenPredictionLayerSubmodules,
layer_number: int = 1,
vp_stage: Optional[int] = None,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
)#

Bases: megatron.core.transformer.module.MegatronModule

The implementation for Multi-Token Prediction (MTP) which extends the prediction scope to multiple future tokens at each position.

This MTP implementation sequentially predicts additional tokens and keeps the complete causal chain at each prediction depth, using D sequential modules to predict D additional tokens.

The k-th MTP module consists of a shared embedding layer, a projection matrix, a Transformer block, and a shared output head.

For the i-th input token at the (k - 1)-th prediction depth, we first combine the representation of the i-th token and the embedding of the (i + k)-th token with a linear projection. The combined representation serves as the input of the Transformer block at the k-th depth to produce the output representation.

For more information, please refer to the DeepSeek-V3 Technical Report: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf
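
To make the per-depth computation concrete, the following schematic (an illustration under the assumptions stated in the comments, not the actual implementation) shows how the submodules from MultiTokenPredictionLayerSubmodules combine the previous-depth hidden states with the next token's embedding:

import torch

def mtp_depth_step(hidden_states, next_token_embedding, enorm, hnorm, eh_proj, transformer_layer):
    # Schematic of one MTP depth with [s, b, h] activations: `hidden_states` is the
    # representation of token i from the previous depth, `next_token_embedding` is the
    # shared embedding of token i+1 (input ids rolled by one position).
    # The concatenation order below is illustrative.
    combined = torch.cat([hnorm(hidden_states), enorm(next_token_embedding)], dim=-1)
    projected = eh_proj(combined)  # project [s, b, 2h] back to [s, b, h]
    # The projected representation is the input of this depth's Transformer block.
    return transformer_layer(projected)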

Initialization

_get_embeddings(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
embedding: Callable,
hidden_states: torch.Tensor,
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
)#

Preprocesses input data for the Multi-Token Prediction (MTP) layers.

This function computes the decoder input and sends updated input_ids and position_ids to the next layer.

Parameters:
  • input_ids (torch.Tensor) – The input token IDs.

  • position_ids (torch.Tensor) – The position IDs corresponding to the input tokens.

  • embedding (Callable) – The embedding module from gpt model to compute the decoder input.

  • hidden_states (torch.Tensor) – hidden states tensor of shape [s, b, h] where s is the sequence length, b is the batch size, and h is the hidden size.

  • packed_seq_params (PackedSeqParams) – Parameters for packed sequence processing.

_concat_embeddings(
hidden_states: torch.Tensor,
decoder_input: torch.Tensor,
)#

Concatenate the tokens before sending to transformer layer.

_proj_and_transformer_layer(
hidden_states: torch.Tensor,
decoder_input: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
context_mask: Optional[torch.Tensor] = None,
rotary_pos_emb: Optional[torch.Tensor] = None,
rotary_pos_cos: Optional[torch.Tensor] = None,
rotary_pos_sin: Optional[torch.Tensor] = None,
attention_bias: Optional[torch.Tensor] = None,
inference_params: Optional[megatron.core.InferenceParams] = None,
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
sequence_len_offset: Optional[torch.Tensor] = None,
) → torch.Tensor#

Concatenates embeddings with hidden states and then applies transformer layer forward.

_postprocess(hidden_states: torch.Tensor)#

Postprocesses the output of the transformer layers.

_checkpointed_forward(forward_func, *args, **kwargs)#
forward(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
context: torch.Tensor = None,
context_mask: torch.Tensor = None,
rotary_pos_emb: torch.Tensor = None,
rotary_pos_cos: torch.Tensor = None,
rotary_pos_sin: torch.Tensor = None,
attention_bias: torch.Tensor = None,
inference_params: megatron.core.InferenceParams = None,
packed_seq_params: megatron.core.packed_seq_params.PackedSeqParams = None,
sequence_len_offset: torch.Tensor = None,
embedding=None,
)#

Execute the forward pass through the Multi-Token Prediction (MTP) layer.

Parameters:
  • input_ids (Tensor) – Input token IDs.

  • position_ids (Tensor) – Positional IDs of the input tokens.

  • hidden_states (Tensor) – Hidden states tensor of shape [s, b, h] where s is the sequence length, b is the batch size, and h is the hidden size.

  • attention_mask (Tensor) – Boolean tensor of shape [1, 1, s, s] for masking self-attention.

  • context (Tensor, optional) – Context tensor for cross-attention, if applicable.

  • context_mask (Tensor, optional) – Mask for cross-attention context, if applicable.

  • rotary_pos_emb (Tensor, optional) – Rotary positional embeddings.

  • rotary_pos_cos (Tensor, optional) – Cosine component of rotary positional embeddings.

  • rotary_pos_sin (Tensor, optional) – Sine component of rotary positional embeddings.

  • sequence_len_offset (Tensor, optional) – Offset for sequence length, if applicable.

  • embedding (Callable) – The embedding module from gpt model to compute the decoder input.

Returns:

The output hidden states tensor of shape [s, b, h], and optionally the updated context tensor if cross-attention is used.

Return type:

Union[Tensor, Tuple[Tensor, Tensor]]

sharded_state_dict(
prefix: str = '',
sharded_offsets: tuple = (),
metadata: Optional[dict] = None,
) → megatron.core.dist_checkpointing.mapping.ShardedStateDict#

Generate a sharded state dictionary for the multi token prediction layer.

Parameters:
  • prefix (str, optional) – Prefix to be added to all keys in the state dict.

  • sharded_offsets (tuple, optional) – Tuple of sharding offsets.

  • metadata (Optional[dict], optional) – Additional metadata for sharding.

Returns:

A dictionary containing the sharded state of the multi token prediction layer.

Return type:

ShardedStateDict

class core.transformer.multi_token_prediction.MultiTokenPredictionBlockSubmodules#

Dataclass for specifying the submodules of a multi token prediction block.

This class defines the structure for configuring the layers, allowing for flexible and customizable architecture designs.

Parameters:

layer_specs (List[ModuleSpec], optional) – A list of module specifications for the layers within the multi token prediction block. Each specification typically defines a complete multi token prediction layer (e.g., shared embedding, projection matrix, transformer block, shared output head).

layer_specs: List[megatron.core.transformer.spec_utils.ModuleSpec]#

None

core.transformer.multi_token_prediction._get_mtp_block_submodules(
config: megatron.core.transformer.transformer_config.TransformerConfig,
spec: Union[core.transformer.multi_token_prediction.MultiTokenPredictionBlockSubmodules, megatron.core.transformer.spec_utils.ModuleSpec],
) → core.transformer.multi_token_prediction.MultiTokenPredictionBlockSubmodules#

Retrieve or construct MultiTokenPredictionBlockSubmodules based on the provided specification.

Parameters:
  • config (TransformerConfig) – Transformer config object.

  • spec (Union[MultiTokenPredictionBlockSubmodules, ModuleSpec]) – Specification of the multi token prediction block submodules.

Returns:

The submodules for the multi token prediction block.

Return type:

MultiTokenPredictionBlockSubmodules

class core.transformer.multi_token_prediction.MultiTokenPredictionBlock(
config: megatron.core.transformer.transformer_config.TransformerConfig,
spec: Union[megatron.core.transformer.transformer_block.TransformerBlockSubmodules, megatron.core.transformer.spec_utils.ModuleSpec],
vp_stage: Optional[int] = None,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
)#

Bases: megatron.core.transformer.module.MegatronModule

The implementation for Multi-Token Prediction (MTP) which extends the prediction scope to multiple future tokens at each position.

This MTP implementation sequentially predicts additional tokens and keeps the complete causal chain at each prediction depth, using D sequential modules to predict D additional tokens.

The k-th MTP module consists of a shared embedding layer, a projection matrix, a Transformer block, and a shared output head.

For the i-th input token at the (k - 1)-th prediction depth, we first combine the representation of the i-th token and the embedding of the (i + k)-th token with a linear projection. The combined representation serves as the input of the Transformer block at the k-th depth to produce the output representation.

For more information, please refer to the DeepSeek-V3 Technical Report: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf

Initialization

_build_layers(pg_collection)#
forward(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
context: torch.Tensor = None,
context_mask: torch.Tensor = None,
rotary_pos_emb: torch.Tensor = None,
rotary_pos_cos: torch.Tensor = None,
rotary_pos_sin: torch.Tensor = None,
attention_bias: torch.Tensor = None,
inference_params: megatron.core.InferenceParams = None,
packed_seq_params: megatron.core.packed_seq_params.PackedSeqParams = None,
sequence_len_offset: torch.Tensor = None,
extra_block_kwargs: dict = None,
embedding=None,
) → torch.Tensor#

Perform the forward pass through all of the MTP modules.

Parameters:
  • hidden_states (Tensor) – Hidden states for input token with the shape [s, b, h] where s is the sequence length, b is the batch size, and h is the hidden size.

  • attention_mask (Tensor) – Boolean tensor of shape [1, 1, s, s] for masking self-attention.

Returns:

The mtp loss tensor of shape [b, s].

Return type:

(Tensor)

sharded_state_dict(
prefix: str = '',
sharded_offsets: tuple = (),
metadata: Optional[dict] = None,
) → megatron.core.dist_checkpointing.mapping.ShardedStateDict#

Generate a sharded state dictionary for the multi token prediction module.

Parameters:
  • prefix (str, optional) – Prefix to be added to all keys in the state dict.

  • sharded_offsets (tuple, optional) – Tuple of sharding offsets.

  • metadata (Optional[dict], optional) – Additional metadata for sharding.

Returns:

A dictionary containing the sharded state of the multi token prediction module.

Return type:

ShardedStateDict