nemo_automodel.components.loss.mtp
Module Contents
Classes
Functions
API
Typed config for the Multi-Token-Prediction auxiliary loss.
MTP is active only when the model emits per-depth outputs; this config
carries just the loss hyperparameters. scaling_factor=None keeps the
model-provided value (out.mtp_loss_scaling_factor /
get_mtp_loss_scaling_factor); set a value to override it.
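The None-means-defer rule for scaling_factor can be illustrated with a minimal sketch. The class and function names below are hypothetical stand-ins, not the module's actual API; only the scaling_factor field and its None semantics come from the description above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MTPConfigSketch:  # hypothetical name, not the library's class
    # None means "defer to the value the model reports".
    scaling_factor: Optional[float] = None


def resolve_scaling_factor(cfg: MTPConfigSketch, model_value: float) -> float:
    """Return the configured override if set, else the model-provided value."""
    # Compare against None explicitly so an override of 0.0 is honored.
    return model_value if cfg.scaling_factor is None else cfg.scaling_factor
```

Checking against None rather than truthiness matters here: an explicit override of 0.0 would otherwise fall back to the model value.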
Build the pipeline-schedule, MTP-aware loss for loss_fn/model.
Bases: Module
Pipeline schedule loss that can add MTP auxiliary CE on the last stage.
The per-microbatch seq_idx is read from a trailing element of the
last-stage output tuple: the model appends a [B, S] int32 tensor
when MTP is enabled. This binds each microbatch’s seq_idx to its loss
call via the PP runtime’s output→loss contract, so the wiring is
schedule-agnostic. Legacy cu_seqlens (THD path) remains as a fallback for
models that don’t emit a seq_idx tail.
Detect and strip a trailing per-microbatch seq_idx from output.
Convention: with MTP enabled the last-stage output is
(logits, *mtp_per_depth_h, seq_idx), where the tail is a [B, S] int32
tensor, so its dtype alone distinguishes it from the preceding elements.
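The dtype-based convention above could be implemented along the following lines. This is a sketch under stated assumptions, not the module's code: the function name split_seq_idx_tail is hypothetical, and it assumes the remaining output elements are floating-point tensors.

```python
import torch


def split_seq_idx_tail(output):
    """Return (output_without_tail, seq_idx), with seq_idx=None if no tail is found.

    Hypothetical helper: the trailing element is treated as seq_idx only
    when it is an int32 tensor, following the convention described above.
    """
    if isinstance(output, (tuple, list)) and len(output) > 1:
        tail = output[-1]
        if torch.is_tensor(tail) and tail.dtype == torch.int32:
            return tuple(output[:-1]), tail
    # No tuple, or no int32 tail: pass the output through unchanged.
    return output, None
```

A caller would then unpack the stripped tuple as (logits, *mtp_per_depth_h) and fall back to cu_seqlens when the returned seq_idx is None.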
Compute the DeepSeek-V3 Multi-Token Prediction auxiliary loss.
Each depth’s CE is dispatched through calculate_loss with the
same loss class as the main path, so MTP inherits the memory and numerical
characteristics of FusedLinearCrossEntropy or MaskedCrossEntropy.
Parameters:
Configured per-token loss class (same instance the main path uses).
Per-depth hidden states from the model’s MTP head,
one [B, S, H] tensor per depth.
Original (unshifted) labels.
The wrapped model; used to fetch the shared LM head when the loss class needs materialized logits (non-FusedLinearCE path).
Coefficient applied to the summed per-depth CE.
Total non-ignore label tokens (forwarded to the base loss for sum-reduction normalization).
Label value masked out of the CE loss for the trailing
k+1 rolled positions at depth k.
Optional cumulative sequence lengths [num_seqs+1]
(THD-pack layout). When supplied and seq_idx is not, builds
a per-token sub-sequence index via searchsorted. Without packing
this can be omitted.
Optional per-token sub-sequence index [B, S] (or [S]).
Equality classes are what matter; absolute values can be any
ints. Takes precedence over cu_seqlens. Used to mask label
rolls whose source position lies in a different sub-sequence.
Optional caller-materialized LM-head weight. Supplying this
lets the main loss and all MTP depths share one DTensor
full_tensor() gather on the FusedLinearCrossEntropy path.
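The cu_seqlens and seq_idx descriptions above can be made concrete with a small sketch. It assumes depth k is 0-indexed and rolls labels left by k+1, which matches the "trailing k+1 rolled positions" wording. Both function names are hypothetical, and the default ignore value of -100 is an assumption borrowed from PyTorch's cross-entropy default.

```python
import torch


def seq_idx_from_cu_seqlens(cu_seqlens: torch.Tensor, total_tokens: int) -> torch.Tensor:
    """Map each packed token position to the index of its sub-sequence."""
    # Match dtypes so searchsorted sees consistent inputs.
    positions = torch.arange(
        total_tokens, dtype=cu_seqlens.dtype, device=cu_seqlens.device
    )
    # right=True counts the boundaries <= position; minus 1 gives the segment index.
    return torch.searchsorted(cu_seqlens, positions, right=True) - 1


def roll_and_mask_labels(labels, seq_idx, depth, ignore_index=-100):
    """Roll labels left by depth+1 and mask targets that are not valid."""
    shift = depth + 1
    rolled = torch.roll(labels, shifts=-shift, dims=-1)
    source_idx = torch.roll(seq_idx, shifts=-shift, dims=-1)
    # A target is invalid if it was taken from a different sub-sequence ...
    invalid = source_idx != seq_idx
    # ... or if it wrapped around from the start of the row.
    invalid[..., -shift:] = True
    return rolled.masked_fill(invalid, ignore_index)
```

Only equality between seq_idx entries is used, which is why the absolute values can be arbitrary integers.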
Returns: torch.Tensor
Scalar MTP loss with autograd graph.
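As a rough illustration of the overall computation, the sketch below sums a per-depth cross-entropy and applies the scaling coefficient. It is not the module's implementation: plain torch.nn.functional.cross_entropy stands in for calculate_loss, lm_head is any callable that maps hidden states to logits, sub-sequence masking via seq_idx is omitted, and the function and argument names are hypothetical.

```python
import torch
import torch.nn.functional as F


def mtp_loss_sketch(hidden_per_depth, labels, lm_head, scaling_factor,
                    num_label_tokens, ignore_index=-100):
    """Scaled sum of per-depth CE over rolled labels (illustrative only)."""
    total = hidden_per_depth[0].new_zeros(())
    for depth, hidden in enumerate(hidden_per_depth):
        shift = depth + 1
        # torch.roll returns a new tensor, so the caller's labels are untouched.
        target = torch.roll(labels, shifts=-shift, dims=-1)
        target[..., -shift:] = ignore_index  # mask wrapped-around positions
        logits = lm_head(hidden)  # [B, S, V]; assumed shared LM head
        ce_sum = F.cross_entropy(
            logits.flatten(0, 1).float(),
            target.flatten(),
            ignore_index=ignore_index,
            reduction="sum",
        )
        # Normalize the summed CE by the total non-ignored label count.
        total = total + ce_sum / num_label_tokens
    return scaling_factor * total
```

The result is a zero-dimensional tensor that stays attached to the autograd graph, so it can be added to the main loss before calling backward.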