nemo_automodel.components.loss.mtp
Module Contents
Classes
Pipeline schedule loss that can add MTP auxiliary CE on the last stage.
Functions
Compute the DeepSeek-V3 Multi-Token Prediction auxiliary loss.
API
nemo_automodel.components.loss.mtp.calculate_mtp_loss(
    loss_fn,
    *,
    mtp_per_depth_h: list[torch.Tensor] | None = None,
    mtp_per_depth_logits: list[torch.Tensor] | None = None,
    labels: torch.Tensor,
    model: torch.nn.Module,
    scaling_factor: float = 0.1,
    num_label_tokens: Optional[int] = None,
    ignore_index: int = -100,
)
Compute the DeepSeek-V3 Multi-Token Prediction auxiliary loss.
Each depth’s CE is dispatched through calculate_loss with the same loss class as the main path, so MTP inherits the memory and numerical characteristics of FusedLinearCrossEntropy / MaskedCrossEntropy.
- Parameters:
loss_fn – Configured per-token loss class (same instance the main path uses).
mtp_per_depth_h – Per-depth hidden states from the model’s MTP head, one [B, S, H] tensor per depth.
labels – Original (unshifted) labels.
model – The wrapped model; used to fetch the shared LM head when the loss class needs materialized logits (non-FusedLinearCE path).
scaling_factor – Coefficient applied to the summed per-depth CE.
num_label_tokens – Total non-ignore label tokens (forwarded to the base loss for sum-reduction normalization).
ignore_index – Label value masked out of the CE loss for the trailing k+1 rolled positions at depth k.
- Returns:
Scalar MTP loss with autograd graph.