nemo_automodel.components.loss.mtp

Module Contents

Classes

PipelineCausalLMLoss

Pipeline schedule loss that can add MTP auxiliary CE on the last stage.

Functions

calculate_mtp_loss

Compute the DeepSeek-V3 Multi-Token Prediction auxiliary loss.

API

nemo_automodel.components.loss.mtp.calculate_mtp_loss(
loss_fn,
*,
mtp_per_depth_h: list[torch.Tensor] | None = None,
mtp_per_depth_logits: list[torch.Tensor] | None = None,
labels: torch.Tensor,
model: torch.nn.Module,
scaling_factor: float = 0.1,
num_label_tokens: Optional[int] = None,
ignore_index: int = -100,
) -> torch.Tensor

Compute the DeepSeek-V3 Multi-Token Prediction auxiliary loss.

Each depth’s CE is dispatched through calculate_loss with the same loss class as the main path, so MTP inherits the memory and numerical characteristics of FusedLinearCrossEntropy / MaskedCrossEntropy.

Parameters:
  • loss_fn – Configured per-token loss (the same instance the main path uses).

  • mtp_per_depth_h – Per-depth hidden states from the model’s MTP head, one [B, S, H] tensor per depth.

  • labels – Original (unshifted) labels.

  • model – The wrapped model; used to fetch the shared LM head when the loss class needs materialized logits (non-FusedLinearCE path).

  • scaling_factor – Coefficient applied to the summed per-depth CE.

  • num_label_tokens – Total non-ignore label tokens (forwarded to the base loss for sum-reduction normalization).

  • ignore_index – Label value that the CE loss ignores. At depth k it is written into the trailing k+1 positions of the rolled labels, which have no valid target.

Returns:

Scalar MTP loss with autograd graph.
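
The label handling and scaling described above can be sketched in plain PyTorch. This is an illustrative approximation, not the library's implementation: the helper names roll_labels_for_depth and sketch_mtp_loss are hypothetical, the left shift of k+1 at depth k is read off the ignore_index description, and torch.nn.functional.cross_entropy on materialized logits stands in for the configured loss_fn. Whether the real function also divides by the number of depths is not stated here, so the sketch only sums.

```python
import torch
import torch.nn.functional as F


def roll_labels_for_depth(labels: torch.Tensor, depth: int, ignore_index: int = -100) -> torch.Tensor:
    """Hypothetical helper: shift labels left by depth + 1 and mask the tail."""
    shift = depth + 1
    # torch.roll returns a new tensor, so the caller's labels stay untouched.
    rolled = torch.roll(labels, shifts=-shift, dims=-1)
    # The last `shift` positions wrapped around from the front; they have no
    # valid target, so they are excluded from the loss.
    rolled[..., -shift:] = ignore_index
    return rolled


def sketch_mtp_loss(
    per_depth_logits: list[torch.Tensor],
    labels: torch.Tensor,
    num_label_tokens: int,
    scaling_factor: float = 0.1,
    ignore_index: int = -100,
) -> torch.Tensor:
    """Hypothetical sketch: scaled sum of per-depth, sum-reduced cross-entropy."""
    total = per_depth_logits[0].new_zeros(())
    for depth, logits in enumerate(per_depth_logits):
        targets = roll_labels_for_depth(labels, depth, ignore_index)
        ce_sum = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)).float(),
            targets.reshape(-1),
            ignore_index=ignore_index,
            reduction="sum",
        )
        total = total + ce_sum
    # Assumption: the summed CE is normalized by the caller-provided token count.
    return scaling_factor * total / num_label_tokens


labels = torch.tensor([[1, 2, 3, 4]])
depth0_targets = roll_labels_for_depth(labels, depth=0)
depth1_targets = roll_labels_for_depth(labels, depth=1)

# Two depths of uniform logits over a 5-token vocabulary, shape [B, S, V].
logits = [torch.zeros(1, 4, 5, requires_grad=True) for _ in range(2)]
loss = sketch_mtp_loss(logits, labels, num_label_tokens=5)
```

In this toy setup depth 0 keeps three valid targets and depth 1 keeps two, so passing num_label_tokens=5 makes the normalized sum equal the per-token CE of a uniform prediction before scaling.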

class nemo_automodel.components.loss.mtp.PipelineCausalLMLoss(loss_fn: torch.nn.Module, model: torch.nn.Module)

Bases: torch.nn.Module

Pipeline schedule loss that can add MTP auxiliary CE on the last stage.

Initialization

forward(output, labels: torch.Tensor) -> torch.Tensor