nemo_automodel.components.loss.mtp


Module Contents

Classes

| Name | Description |
| --- | --- |
| MTPLossConfig | Typed config for the Multi-Token-Prediction auxiliary loss. |
| PipelineCausalLMLoss | Pipeline schedule loss that can add MTP auxiliary CE on the last stage. |

Functions

| Name | Description |
| --- | --- |
| calculate_mtp_loss | Compute the DeepSeek-V3 Multi-Token Prediction auxiliary loss. |

API

class nemo_automodel.components.loss.mtp.MTPLossConfig(
scaling_factor: float | None = None,
ignore_index: int = -100
)
Dataclass

Typed config for the Multi-Token-Prediction auxiliary loss.

MTP is active only when the model emits per-depth outputs; this config only carries its hyperparameters. scaling_factor=None keeps the model-provided value (out.mtp_loss_scaling_factor / get_mtp_loss_scaling_factor); set it to a float to override that value.
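The override rule above can be illustrated with a small stand-in. This is only a sketch: `MTPLossConfigSketch` and `resolve_scaling_factor` are hypothetical names that mirror the two documented fields, not the library's own class or resolution code.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MTPLossConfigSketch:
    """Hypothetical stand-in mirroring the two documented fields."""

    scaling_factor: Optional[float] = None
    ignore_index: int = -100


def resolve_scaling_factor(cfg: MTPLossConfigSketch, model_value: float) -> float:
    """Return the config override if one is set, else the model-provided value."""
    # Compare against None explicitly so an override of 0.0 is still honored.
    return model_value if cfg.scaling_factor is None else cfg.scaling_factor


model_value = 0.3  # assumed value reported by the model, for illustration only
default_cfg = MTPLossConfigSketch()
override_cfg = MTPLossConfigSketch(scaling_factor=0.1)

kept = resolve_scaling_factor(default_cfg, model_value)
overridden = resolve_scaling_factor(override_cfg, model_value)
```

Here `kept` takes the model-provided value and `overridden` takes the configured one.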

ignore_index
int = -100
scaling_factor
float | None = None

nemo_automodel.components.loss.mtp.MTPLossConfig.build(
loss_fn: torch.nn.Module,
model: torch.nn.Module
) -> nemo_automodel.components.loss.mtp.PipelineCausalLMLoss

Build the pipeline-schedule, MTP-aware loss for loss_fn/model.

class nemo_automodel.components.loss.mtp.PipelineCausalLMLoss(
loss_fn: torch.nn.Module,
model: torch.nn.Module,
scaling_factor: float | None = None,
ignore_index: int = -100
)

Bases: Module

Pipeline schedule loss that can add MTP auxiliary CE on the last stage.

Per-microbatch seq_idx is read from a trailing element of the last-stage output tuple: the model appends a [B, S] int32 tail when MTP is enabled. This binds each microbatch’s seq_idx to its loss call through the PP runtime’s output-to-loss contract, so the wiring is schedule-agnostic. The legacy cu_seqlens attribute (THD path) is a fallback for models that don’t emit a seq_idx tail.

cu_seqlens
Optional[Tensor] = None

nemo_automodel.components.loss.mtp.PipelineCausalLMLoss._extract_seq_idx_tail(
output
) -> tuple[typing.Optional[torch.Tensor], object]
staticmethod

Detect and strip a trailing per-microbatch seq_idx from output.

Convention: with MTP enabled, the last-stage output is (logits, *mtp_per_depth_h, seq_idx), where seq_idx is a [B, S] int32 tail. The dtype alone identifies the tail.
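A minimal sketch of the dtype-based detection described above, assuming the tail is the only int32 tensor at the end of the tuple. `split_seq_idx_tail` is a hypothetical helper written for illustration; the actual `_extract_seq_idx_tail` may apply additional checks.

```python
import torch


def split_seq_idx_tail(output):
    """Return (seq_idx or None, output with the tail removed).

    Hypothetical stand-in: treats a trailing int32 tensor as the seq_idx tail.
    """
    if isinstance(output, (tuple, list)) and len(output) > 1:
        tail = output[-1]
        if isinstance(tail, torch.Tensor) and tail.dtype == torch.int32:
            return tail, tuple(output[:-1])
    # No tail found: hand the output back unchanged.
    return None, output


B, S, H, V = 2, 4, 8, 16
logits = torch.randn(B, S, V)
mtp_per_depth_h = [torch.randn(B, S, H)]
seq_idx = torch.zeros(B, S, dtype=torch.int32)

found, rest = split_seq_idx_tail((logits, *mtp_per_depth_h, seq_idx))
missing, untouched = split_seq_idx_tail((logits, *mtp_per_depth_h))
```

In the first call the tail is stripped and returned; in the second there is no int32 tail, so the tuple comes back as-is.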

nemo_automodel.components.loss.mtp.PipelineCausalLMLoss.forward(
output,
labels: torch.Tensor
) -> torch.Tensor

nemo_automodel.components.loss.mtp.calculate_mtp_loss(
loss_fn,
mtp_per_depth_h: list[torch.Tensor] | None = None,
mtp_per_depth_logits: list[torch.Tensor] | None = None,
labels: torch.Tensor,
model: torch.nn.Module,
scaling_factor: float = 0.1,
num_label_tokens: typing.Optional[int] = None,
ignore_index: int = -100,
cu_seqlens: typing.Optional[torch.Tensor] = None,
seq_idx: typing.Optional[torch.Tensor] = None,
lm_weight: typing.Optional[torch.Tensor] = None
) -> torch.Tensor

Compute the DeepSeek-V3 Multi-Token Prediction auxiliary loss.

Each depth’s CE is dispatched through calculate_loss with the same loss class as the main path, so MTP inherits the memory and numerical characteristics of FusedLinearCrossEntropy / MaskedCrossEntropy.

Parameters:

loss_fn

Configured per-token loss class (same instance the main path uses).

mtp_per_depth_h
list[torch.Tensor] | None, defaults to None

Per-depth hidden states from the model’s MTP head, one [B, S, H] tensor per depth.

labels
torch.Tensor

Original (unshifted) labels.

model
nn.Module

The wrapped model; used to fetch the shared LM head when the loss class needs materialized logits (non-FusedLinearCE path).

scaling_factor
float, defaults to 0.1

Coefficient applied to the summed per-depth CE.

num_label_tokens
Optional[int], defaults to None

Total non-ignore label tokens (forwarded to the base loss for sum-reduction normalization).

ignore_index
int, defaults to -100

Label value that the CE loss ignores. At depth k, the trailing k+1 positions of the rolled labels are set to this value.
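The rolling behaviour can be sketched on a single row of labels using plain Python lists. This assumes depth k is 0-indexed and that labels are shifted left by k+1; `roll_labels` is a hypothetical helper, not a function from this module.

```python
def roll_labels(labels, depth, ignore_index=-100):
    """Shift one row of labels left by depth + 1 and pad the tail.

    Hypothetical illustration of the documented masking: the trailing
    depth + 1 positions have no target, so they receive ignore_index.
    """
    shift = min(depth + 1, len(labels))
    return labels[shift:] + [ignore_index] * shift


labels = [10, 11, 12, 13, 14]
depth0 = roll_labels(labels, depth=0)  # one trailing position masked
depth1 = roll_labels(labels, depth=1)  # two trailing positions masked
```

The output row keeps the input length, so it can be compared position by position with the per-depth predictions.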

cu_seqlens
Optional[torch.Tensor], defaults to None

Optional cumulative sequence lengths [num_seqs+1] (THD-pack layout). When supplied without seq_idx, a per-token sub-sequence index is built from it via searchsorted. Can be omitted when sequences are not packed.
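As a framework-free illustration of the searchsorted step, the sketch below uses `bisect_right` from the standard library, which plays the same role as a right-sided `torch.searchsorted`. The helper name is hypothetical and the exact call used by the library is not shown in this reference.

```python
from bisect import bisect_right


def seq_idx_from_cu_seqlens(cu_seqlens, total_tokens):
    """Map each token position to the index of its packed sub-sequence.

    Token t belongs to sequence i when cu_seqlens[i] <= t < cu_seqlens[i + 1].
    """
    return [bisect_right(cu_seqlens, t) - 1 for t in range(total_tokens)]


# Three packed sequences of lengths 3, 2, and 4.
cu_seqlens = [0, 3, 5, 9]
seq_idx = seq_idx_from_cu_seqlens(cu_seqlens, cu_seqlens[-1])
```

Each boundary in `cu_seqlens` starts a new index, so tokens 0 to 2 map to 0, tokens 3 to 4 map to 1, and the rest map to 2.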

seq_idx
Optional[torch.Tensor], defaults to None

Optional per-token sub-sequence index [B, S] (or [S]). Only equality between entries matters; the absolute values can be arbitrary integers. Takes precedence over cu_seqlens. Used to mask rolled labels whose source position lies in a different sub-sequence.
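The cross-sequence masking can be sketched on one packed row with plain lists. This assumes a left roll by depth + 1 at 0-indexed depth; `mask_cross_sequence` is a hypothetical helper for illustration only.

```python
def mask_cross_sequence(rolled, seq_idx, depth, ignore_index=-100):
    """Mask rolled labels whose source token sits in another sub-sequence.

    Position t took its label from position t + depth + 1; if that source
    is out of range or has a different seq_idx, the label is ignored.
    """
    shift = depth + 1
    out = list(rolled)
    for t in range(len(out)):
        src = t + shift
        if src >= len(seq_idx) or seq_idx[src] != seq_idx[t]:
            out[t] = ignore_index
    return out


labels = [10, 11, 12, 20, 21]
seq_idx = [7, 7, 7, 3, 3]      # arbitrary ids; only equality matters
rolled = labels[1:] + [-100]   # depth 0: shift left by one
masked = mask_cross_sequence(rolled, seq_idx, depth=0)
```

Position 2 would otherwise be trained to predict the first token of the next packed sequence, so it is masked along with the trailing position.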

lm_weight
Optional[torch.Tensor], defaults to None

Optional caller-materialized LM-head weight. Supplying this lets the main loss and all MTP depths share one DTensor full_tensor() gather on the FusedLinearCrossEntropy path.

Returns: torch.Tensor

Scalar MTP loss with autograd graph.