nemo_automodel.components.loss.dllm_loss

Loss functions for diffusion LLM (dLLM) training.

All loss classes return :class:DLLMLossOutput so the recipe can handle them uniformly without branching on model type.

Module Contents

Classes

| Name | Description |
| --- | --- |
| BlockDiffusionCrossEntropyLoss | Flat cross-entropy loss for block-diffusion (diffusion_gemma) training. |
| DFlashDecayLoss | Position-decay cross-entropy loss for DFlash draft model training. |
| DLLMLossOutput | Unified return type for all dLLM loss functions. |
| HybridDiffusionLLMLoss | Combined diffusion + optional AR loss for hybrid diffusion LLM models. |
| MDLMCrossEntropyLoss | Cross-entropy loss for MDLM training. |

Functions

| Name | Description |
| --- | --- |
| _compute_per_token_nll | Compute per-token negative log-likelihood, shape [B, L]. |
| encoder_ar_loss | Autoregressive next-token CE on the encoder’s causal logits. |

API

class nemo_automodel.components.loss.dllm_loss.BlockDiffusionCrossEntropyLoss(
fp32_upcast: bool = True
)

Bases: Module

Flat cross-entropy loss for block-diffusion (diffusion_gemma) training.

The diffusion_gemma checkpoint uses uniform random-token (D3PM-uniform) corruption, not absorbing [MASK] corruption. Its loss is the plain mean cross-entropy over all supervised canvas positions, corrupted and uncorrupted alike: the loss support is the full selected canvas (target_mask = canvas_mask) and is NOT noise-gated. noise_mask is accepted for diagnostics only and does not affect the loss support:

.. math:: \text{loss} = \frac{\sum_{i \in \text{supervised (canvas)}} \text{CE}_i}{N}

where N is the supervised canvas-token count. There is no 1/p (1/t) reweighting (that is the absorbing-kernel ELBO weight, which does not apply to the uniform kernel) and no autoregressive term. Flatness is a property of this class, not of a caller passing p_mask = 1.
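
The flat loss above can be sketched in plain Python. This is a minimal illustration, not the library implementation (which operates on tensors): the function name `flat_block_diffusion_loss`, the list-based inputs, and the guard against an empty mask are all assumptions made for the example.

```python
def flat_block_diffusion_loss(token_ce, loss_mask, noise_mask=None):
    """Mean CE over supervised canvas positions.

    noise_mask is accepted but deliberately unused, mirroring the
    "not noise-gated" behaviour described above.
    """
    total = 0.0
    count = 0
    for ce_row, mask_row in zip(token_ce, loss_mask):
        for ce, supervised in zip(ce_row, mask_row):
            if supervised:
                total += ce
                count += 1
    # Assumption: avoid division by zero when nothing is supervised.
    return total / max(count, 1)


token_ce = [[2.0, 1.0, 4.0, 3.0]]   # per-token cross-entropy, shape [B, L]
loss_mask = [[1, 1, 1, 0]]          # last position is not supervised
noise_mask = [[1, 0, 0, 0]]         # only the first position was corrupted

loss = flat_block_diffusion_loss(token_ce, loss_mask, noise_mask)
loss_other_noise = flat_block_diffusion_loss(token_ce, loss_mask, [[0, 1, 1, 0]])
```

Both calls give the same value, because uncorrupted supervised positions contribute exactly like corrupted ones and no 1/p weight is applied.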

The signature matches :class:MDLMCrossEntropyLoss / :class:HybridDiffusionLLMLoss so the recipe can call it uniformly; the p_mask / causal_logits / loss_mask_ar / num_ar_tokens arguments are accepted but ignored.

nemo_automodel.components.loss.dllm_loss.BlockDiffusionCrossEntropyLoss.forward(
logits: torch.Tensor,
target_ids: torch.Tensor,
noise_mask: torch.Tensor,
p_mask: torch.Tensor,
loss_mask: torch.Tensor,
loss_mask_ar: typing.Optional[torch.Tensor] = None,
num_diffusion_tokens: typing.Optional[int] = None,
num_ar_tokens: typing.Optional[int] = None,
causal_logits: typing.Optional[torch.Tensor] = None
) -> nemo_automodel.components.loss.dllm_loss.DLLMLossOutput

Compute the flat block-diffusion cross-entropy loss.

Parameters:

logits
torch.Tensor

Model output logits over the canvas, shape [B, L, V].

target_ids
torch.Tensor

Clean (uncorrupted) canvas token IDs, shape [B, L].

noise_mask
torch.Tensor

Boolean mask of corrupted positions, shape [B, L].

p_mask
torch.Tensor

Ignored (flat loss has no per-token weight).

loss_mask
torch.Tensor

Supervised positions mask, shape [B, L].

num_diffusion_tokens
Optional[int], defaults to None

If provided, the global supervised canvas-token count N used as the normalization denominator (summed across grad-acc microbatches). If None, normalizes by the local supervised count in this microbatch.

Returns: DLLMLossOutput

:class:DLLMLossOutput where total_loss == dllm_loss (no AR).

class nemo_automodel.components.loss.dllm_loss.DFlashDecayLoss(
loss_gamma: typing.Optional[float] = 7.0,
use_fused_linear_ce: bool = False,
chunk_size: int = 1024,
normalize: str = 'tokens'
)

Bases: Module

Position-decay cross-entropy loss for DFlash draft model training.

Implements Eq. 4 of the DFlash paper:

.. math:: w_k = \exp\!\left(-\frac{k-1}{\gamma}\right), \quad k = 1, \dots, T

where k indexes the predicted positions within a block (k=0 is the clean anchor and is not predicted; k=1 is the first masked position).

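With normalize="tokens" (the default), the decay-weighted sum is divided by num_tokens; pass a global all-reduced count so that normalisation is consistent across DP replicas and gradient-accumulation steps. With normalize="mean", the loss is normalised by the sum of effective weights (w_k * block_mask).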

Paper default γ values (Appendix A.1):

  • block size 16 → γ = 7
  • block size 10 → γ = 5
  • block size 8 → γ = 4
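
As a worked illustration of Eq. 4 with the γ values listed above, the following plain-Python sketch computes the weights for one block and records the weight of the last predicted position. The helper name `decay_weights` is hypothetical, and the number of predicted positions per block is taken as block_size - 1 because k=0 is the unpredicted anchor.

```python
import math

# (block size -> gamma) pairs from the list above.
PAPER_DEFAULTS = {16: 7.0, 10: 5.0, 8: 4.0}


def decay_weights(num_predicted, gamma):
    """w_k = exp(-(k - 1) / gamma) for k = 1..num_predicted."""
    return [math.exp(-(k - 1) / gamma) for k in range(1, num_predicted + 1)]


last_weight = {}
for block_size, gamma in PAPER_DEFAULTS.items():
    weights = decay_weights(block_size - 1, gamma)
    last_weight[block_size] = weights[-1]
    print(f"block_size={block_size} gamma={gamma} "
          f"w_1={weights[0]:.3f} w_last={weights[-1]:.3f}")
```

The first predicted position always has weight 1, and later positions in the block are down-weighted exponentially.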

Parameters:

use_fused_linear_ce
bool, defaults to False

When True, compute the per-token NLL with the chunked linear-CE path (:meth:forward_fused). This path projects the LM head and runs cross-entropy in position chunks, each wrapped in :func:torch.utils.checkpoint, so the full [B, T, vocab] logits tensor is never materialised and peak logit memory is one chunk. This keeps a large num_blocks_per_sample (e.g. the paper default of 512) within memory on full-vocab targets.

We deliberately do NOT use liger_kernel’s LigerFusedLinearCrossEntropyLoss here: its custom autograd Function computes grad_input eagerly in forward and only integrates with FSDP via the model-patching redirection (apply_liger_kernel_to_*). Used standalone under FSDP2, the gradient does not reach the sharded model parameters (grad_norm is 0). The chunked path is plain autograd, so FSDP2 handles it correctly.

chunk_size
int, defaults to 1024

Number of predicted positions per chunk in the chunked linear-CE path. Smaller = lower peak memory, more recompute.

normalize
str, defaults to 'tokens'

Loss denominator. "tokens" (default) divides the decay-weighted sum by num_tokens, a global all-reduced count that keeps the loss consistent across DP replicas and grad-accum. "mean" divides by the effective weight sum (w_k * block_mask).sum() for a per-call decay-weighted mean.

loss_gamma
Optional[float], defaults to 7.0

Decay parameter γ. None disables decay (all predicted positions weighted equally).
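
The two normalize modes described above differ only in the denominator. The sketch below illustrates this for a single flattened sequence in plain Python; `reduce_loss` is a hypothetical name, and the fallback to the local valid-token count when num_tokens is None is an assumption, not documented behaviour.

```python
def reduce_loss(token_nll, weights, block_mask, normalize="tokens", num_tokens=None):
    """Decay-weighted sum of per-token NLL, divided by the chosen denominator."""
    effective = [w * m for w, m in zip(weights, block_mask)]
    weighted_sum = sum(n * e for n, e in zip(token_nll, effective))
    if normalize == "tokens":
        # Assumption: use the local valid count if no global count is given.
        denom = num_tokens if num_tokens is not None else sum(block_mask)
    elif normalize == "mean":
        denom = sum(effective)
    else:
        raise ValueError(f"unknown normalize mode: {normalize!r}")
    return weighted_sum / denom


token_nll = [2.0, 2.0, 2.0]
weights = [1.0, 0.5, 0.25]
block_mask = [1, 1, 0]          # last position is padding

loss_tokens = reduce_loss(token_nll, weights, block_mask, "tokens", num_tokens=4)
loss_mean = reduce_loss(token_nll, weights, block_mask, "mean")
```

With "mean", a constant per-token NLL is returned unchanged; with "tokens", the result also depends on the decay weights and on the global count that is passed in.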

chunk_size = int(chunk_size)
loss_gamma = None if loss_gamma is None else float(loss_gamma)
use_fused_linear_ce = bool(use_fused_linear_ce)

nemo_automodel.components.loss.dllm_loss.DFlashDecayLoss._chunk_nll(
hidden_chunk: torch.Tensor,
lm_head_weight: torch.Tensor,
lm_head_bias: typing.Optional[torch.Tensor],
target_chunk: torch.Tensor
) -> typing.Tuple[torch.Tensor, torch.Tensor]
staticmethod

Project one position chunk; return its per-token NLL and argmax-matches.

Wrapped in :func:torch.utils.checkpoint by the caller, so the [chunk, vocab] logits are recomputed in backward rather than held. The argmax is non-differentiable, so it adds no backward cost.

nemo_automodel.components.loss.dllm_loss.DFlashDecayLoss._decay_weights(
T: int,
block_size: typing.Optional[int],
device,
dtype
) -> torch.Tensor

Eq. 4 weights for T predicted positions, resetting per block.

Returns all-ones (uniform) when loss_gamma is None (decay disabled).
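
A minimal plain-Python sketch of the per-block reset, assuming the T predicted positions are laid out as N consecutive blocks of block_size - 1 positions each. The function name `decay_weights_per_block` is hypothetical; the library method returns a tensor on the requested device and dtype.

```python
import math


def decay_weights_per_block(T, block_size, gamma):
    """Eq. 4 weights for T positions, restarting at every block boundary."""
    if gamma is None:
        return [1.0] * T                      # decay disabled: uniform weights
    # Without a block size, treat all T positions as one block.
    per_block = T if block_size is None else block_size - 1
    # i % per_block is the 0-based offset inside the block, i.e. k - 1.
    return [math.exp(-(i % per_block) / gamma) for i in range(T)]


w = decay_weights_per_block(T=6, block_size=4, gamma=7.0)   # two blocks of 3
uniform = decay_weights_per_block(T=4, block_size=3, gamma=None)
single = decay_weights_per_block(T=3, block_size=None, gamma=7.0)
```

Positions 0 and 3 both start a block, so both receive weight 1.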

nemo_automodel.components.loss.dllm_loss.DFlashDecayLoss._draft_acc_per_pos(
correct: torch.Tensor,
block_mask: torch.Tensor,
block_size: typing.Optional[int]
) -> typing.Tuple[typing.Optional[torch.Tensor], typing.Optional[torch.Tensor]]
staticmethod

Per-rank (correct, count) sums per block offset k=1..block_size-1.

correct is a [B, T] bool/float tensor of argmax matches and block_mask excludes padding (T = N * (block_size - 1) when block_size is provided). Reshape to [B, N, block_size-1] and sum over (B, N) to get per-offset counts of shape [block_size-1]. Returns (None, None) when block_size is unknown (single-block / legacy path).
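
The reshape-and-sum described above is equivalent to bucketing every position by its offset inside its block. A plain-Python sketch under that reading follows; `draft_acc_per_pos` is a hypothetical stand-in, and masking the matches with block_mask before summing is an assumption about how padding is excluded.

```python
def draft_acc_per_pos(correct, block_mask, block_size):
    """Return per-offset (correct, count) sums, or (None, None) without a block size."""
    if block_size is None:
        return None, None
    width = block_size - 1                     # predicted positions per block
    correct_sum = [0.0] * width
    count_sum = [0.0] * width
    for row_correct, row_mask in zip(correct, block_mask):
        for t, (c, m) in enumerate(zip(row_correct, row_mask)):
            offset = t % width                 # 0-based offset, i.e. k - 1
            correct_sum[offset] += float(c) * float(m)
            count_sum[offset] += float(m)
    return correct_sum, count_sum


correct = [[1, 0, 1, 1]]      # B = 1, N = 2 blocks, block_size = 3, so T = 4
block_mask = [[1, 1, 1, 0]]   # final position is padding
correct_per_pos, count_per_pos = draft_acc_per_pos(correct, block_mask, block_size=3)
```

Dividing the two lists elementwise (after any cross-rank reduction) gives the draft accuracy at each block offset.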

nemo_automodel.components.loss.dllm_loss.DFlashDecayLoss._reduce(
token_nll: torch.Tensor,
block_mask: torch.Tensor,
num_tokens: typing.Optional[int],
block_size: typing.Optional[int],
draft_correct_per_pos: typing.Optional[torch.Tensor] = None,
draft_count_per_pos: typing.Optional[torch.Tensor] = None
) -> nemo_automodel.components.loss.dllm_loss.DLLMLossOutput

Apply decay weights + block mask, sum, and normalise.

nemo_automodel.components.loss.dllm_loss.DFlashDecayLoss.forward(
logits: torch.Tensor,
target_ids: torch.Tensor,
block_mask: torch.Tensor,
num_tokens: typing.Optional[int] = None,
block_size: typing.Optional[int] = None
) -> nemo_automodel.components.loss.dllm_loss.DLLMLossOutput

Compute the DFlash decay-weighted loss from pre-computed logits.

Parameters:

logits
torch.Tensor

Draft model logits for the predicted block positions, shape [B, T, V] where T = N * (block_size - 1).

target_ids
torch.Tensor

Ground-truth token IDs, shape [B, T].

block_mask
torch.Tensor

Float/bool valid-position mask, shape [B, T]. Zero entries (padding) are excluded from the loss.

num_tokens
Optional[int], defaults to None

Optional global token count for loss normalisation.

block_size
Optional[int], defaults to None

When provided, the decay weights reset at each block boundary so that every block’s first predicted position has weight 1. Required for multi-block training (N > 1).

Returns: DLLMLossOutput

:class:DLLMLossOutput.

nemo_automodel.components.loss.dllm_loss.DFlashDecayLoss.forward_fused(
hidden: torch.Tensor,
lm_head_weight: torch.Tensor,
target_ids: torch.Tensor,
block_mask: torch.Tensor,
num_tokens: typing.Optional[int] = None,
block_size: typing.Optional[int] = None,
lm_head_bias: typing.Optional[torch.Tensor] = None
) -> nemo_automodel.components.loss.dllm_loss.DLLMLossOutput

Chunked linear-CE: never materialises the full logits tensor.

Projects the LM head + cross-entropy in chunks of chunk_size predicted positions, each wrapped in :func:torch.utils.checkpoint so the [chunk, vocab] logits are recomputed in backward instead of held — peak logit memory is one chunk, not [B*T, vocab]. Pure autograd, so the gradient flows correctly through FSDP2 (unlike a standalone liger fused-CE Function).
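
The chunking idea can be illustrated without tensors. The plain-Python sketch below projects hidden states to logits one chunk of positions at a time and checks that the per-token NLL matches a single-chunk pass. It omits gradient checkpointing entirely, flattens the batch dimension, and uses hypothetical helper names (`chunk_nll`, `chunked_nll`), so it shows only the memory-shaping structure, not the library's implementation.

```python
import math


def chunk_nll(hidden_chunk, weight, bias, target_chunk):
    """Project one chunk of positions to logits and return per-token NLL."""
    nll = []
    for h, target in zip(hidden_chunk, target_chunk):
        # logits[v] = <weight[v], h> (+ bias[v]); weight has shape [V, D].
        logits = [sum(w_d * h_d for w_d, h_d in zip(row, h)) for row in weight]
        if bias is not None:
            logits = [x + b for x, b in zip(logits, bias)]
        peak = max(logits)                     # stabilise the log-sum-exp
        log_z = peak + math.log(sum(math.exp(x - peak) for x in logits))
        nll.append(log_z - logits[target])
    return nll


def chunked_nll(hidden, weight, targets, chunk_size, bias=None):
    """Only chunk_size rows of logits exist at any one time."""
    out = []
    for start in range(0, len(hidden), chunk_size):
        stop = start + chunk_size
        out.extend(chunk_nll(hidden[start:stop], weight, bias, targets[start:stop]))
    return out


hidden = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # 3 positions, D = 2
weight = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]   # V = 3, D = 2
targets = [0, 1, 2]

full = chunked_nll(hidden, weight, targets, chunk_size=3)
chunked = chunked_nll(hidden, weight, targets, chunk_size=2)
```

The chunk size changes how many logit rows are alive at once, not the resulting loss values.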

Parameters:

hidden
torch.Tensor

Draft hidden states for the predicted positions, shape [B, T, D] (D = model dim, NOT vocab).

lm_head_weight
torch.Tensor

LM-head projection weight, shape [V, D].

target_ids
torch.Tensor

Ground-truth token IDs, shape [B, T].

block_mask
torch.Tensor

Valid-position mask, shape [B, T].

num_tokens / block_size

Same as in :meth:forward.

lm_head_bias
Optional[torch.Tensor], defaults to None

Optional LM-head bias, shape [V].

Returns: DLLMLossOutput

:class:DLLMLossOutput.

class nemo_automodel.components.loss.dllm_loss.DLLMLossOutput()

Bases: NamedTuple

Unified return type for all dLLM loss functions.

dllm_loss: Tensor
draft_correct_per_pos: Optional[Tensor] = None
draft_count_per_pos: Optional[Tensor] = None
total_loss: Tensor

class nemo_automodel.components.loss.dllm_loss.HybridDiffusionLLMLoss(
alpha: float = 1.0,
fp32_upcast: bool = True
)

Bases: Module

Combined diffusion + optional AR loss for hybrid diffusion LLM models.

Used by Nemotron-Labs-Diffusion. The diffusion component computes MDLM-style loss at noise-masked positions, weighted by 1/p_mask. An optional autoregressive (AR) component adds standard cross-entropy at AR positions (the causal branch of model output).

Total loss = alpha * diffusion_loss + ar_loss.
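
As a small worked example of the combination rule above, the sketch below takes an already-computed diffusion loss and an optional AR term. The function name `hybrid_total_loss`, the flat list inputs, and the fallback to the local AR-token count are assumptions for illustration only.

```python
def hybrid_total_loss(diffusion_loss, ar_token_nll=None, loss_mask_ar=None,
                      num_ar_tokens=None, alpha=1.0):
    """Return (total_loss, ar_loss) with total = alpha * diffusion_loss + ar_loss."""
    if ar_token_nll is None or loss_mask_ar is None:
        return alpha * diffusion_loss, 0.0     # no AR mask: diffusion term only
    ar_sum = sum(n * m for n, m in zip(ar_token_nll, loss_mask_ar))
    # Assumption: use the local AR-token count when no global count is given.
    denom = num_ar_tokens if num_ar_tokens is not None else max(sum(loss_mask_ar), 1)
    ar_loss = ar_sum / denom
    return alpha * diffusion_loss + ar_loss, ar_loss


total, ar = hybrid_total_loss(
    diffusion_loss=2.0,
    ar_token_nll=[1.0, 3.0, 5.0],
    loss_mask_ar=[1, 1, 0],
    alpha=0.5,
)
diffusion_only, no_ar = hybrid_total_loss(diffusion_loss=2.0, alpha=0.5)
```

Note that alpha scales only the diffusion term; the AR term enters with weight 1.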

nemo_automodel.components.loss.dllm_loss.HybridDiffusionLLMLoss.forward(
logits: torch.Tensor,
target_ids: torch.Tensor,
noise_mask: torch.Tensor,
p_mask: torch.Tensor,
loss_mask: torch.Tensor,
loss_mask_ar: typing.Optional[torch.Tensor] = None,
num_diffusion_tokens: typing.Optional[int] = None,
num_ar_tokens: typing.Optional[int] = None,
causal_logits: typing.Optional[torch.Tensor] = None
) -> nemo_automodel.components.loss.dllm_loss.DLLMLossOutput

Compute the hybrid diffusion + AR loss.

Parameters:

logits
torch.Tensor

Model output logits, shape [B, L, V] or [B, L+L_ar, V] if the model produces both diffusion and AR logits in a single concatenated tensor (legacy path).

target_ids
torch.Tensor

Clean token IDs, shape [B, L].

noise_mask
torch.Tensor

Boolean mask of corrupted positions, shape [B, L].

p_mask
torch.Tensor

Per-position masking probability, shape [B, L].

loss_mask
torch.Tensor

Diffusion loss mask (supervised positions), shape [B, L].

loss_mask_ar
Optional[torch.Tensor], defaults to None

AR loss mask, shape [B, L]. If None, no AR loss.

num_diffusion_tokens
Optional[int], defaults to None

Total diffusion label tokens for normalization.

num_ar_tokens
Optional[int], defaults to None

Total AR label tokens for normalization.

causal_logits
Optional[torch.Tensor], defaults to None

Optional separate AR logits, shape [B, L, V]. When provided, avoids the concat/split of the legacy layout.

Returns: DLLMLossOutput

:class:DLLMLossOutput with the combined total_loss and the pure diffusion loss in dllm_loss.

class nemo_automodel.components.loss.dllm_loss.MDLMCrossEntropyLoss(
fp32_upcast: bool = True
)

Bases: Module

Cross-entropy loss for MDLM training.

Matches the reference dllm framework (dllm/core/trainers/mdlm.py):

.. math:: \text{loss} = \frac{\sum_{i \in \text{masked}} \text{CE}_i \cdot w(t)}{\sum \text{maskable}}

where :math:w(t) = 1/t for the scheduler weight type (linear schedule).
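
The formula above can be sketched for one flattened sequence in plain Python. This assumes that p_mask supplies t per position (so the weight is 1 / p_mask), that "maskable" means the positions selected by loss_mask, and that per-token cross-entropy is already available; `mdlm_loss` is a hypothetical name, not the library function.

```python
def mdlm_loss(token_ce, noise_mask, p_mask, loss_mask, num_diffusion_tokens=None):
    """Sum CE_i / p_i over masked supervised positions, divided by the maskable count."""
    weighted = 0.0
    for ce, noised, p, supervised in zip(token_ce, noise_mask, p_mask, loss_mask):
        if noised and supervised:
            weighted += ce / p                 # w(t) = 1 / t
    # Assumption: fall back to the local maskable count without a global one.
    if num_diffusion_tokens is not None:
        denom = num_diffusion_tokens
    else:
        denom = max(sum(loss_mask), 1)
    return weighted / denom


token_ce = [2.0, 1.0, 4.0, 3.0]
noise_mask = [1, 0, 1, 0]        # positions 0 and 2 were masked
p_mask = [0.5, 0.5, 0.5, 0.5]    # masking probability t = 0.5
loss_mask = [1, 1, 1, 0]         # three maskable positions

loss = mdlm_loss(token_ce, noise_mask, p_mask, loss_mask)
```

Only masked positions contribute to the numerator, while every maskable position counts in the denominator.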

nemo_automodel.components.loss.dllm_loss.MDLMCrossEntropyLoss.forward(
logits: torch.Tensor,
target_ids: torch.Tensor,
noise_mask: torch.Tensor,
p_mask: torch.Tensor,
loss_mask: torch.Tensor,
loss_mask_ar: typing.Optional[torch.Tensor] = None,
num_diffusion_tokens: typing.Optional[int] = None,
num_ar_tokens: typing.Optional[int] = None,
causal_logits: typing.Optional[torch.Tensor] = None
) -> nemo_automodel.components.loss.dllm_loss.DLLMLossOutput

Compute the MDLM cross-entropy loss.

Parameters:

logits
torch.Tensor

Model output logits, shape [B, L, V].

target_ids
torch.Tensor

Clean (uncorrupted) token IDs, shape [B, L].

noise_mask
torch.Tensor

Boolean mask of corrupted positions, shape [B, L].

p_mask
torch.Tensor

Per-position masking probability, shape [B, L].

loss_mask
torch.Tensor

Supervised positions mask, shape [B, L].

num_diffusion_tokens
Optional[int], defaults to None

If provided, used for global normalization (total supervised tokens across all grad-acc microbatches).

Returns: DLLMLossOutput

:class:DLLMLossOutput where total_loss == dllm_loss.
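
To see why a global num_diffusion_tokens matters under gradient accumulation, consider the following plain-Python arithmetic sketch. It leaves out the 1/t weight for brevity, and `microbatch_loss` is a hypothetical helper rather than part of the module.

```python
def microbatch_loss(token_losses, mask, num_tokens=None):
    """Masked sum of per-token losses divided by a global or local count."""
    masked_sum = sum(l for l, m in zip(token_losses, mask) if m)
    denom = num_tokens if num_tokens is not None else max(sum(mask), 1)
    return masked_sum / denom


mb1 = ([1.0, 2.0, 3.0], [1, 1, 1])   # 3 supervised tokens
mb2 = ([6.0, 0.0, 0.0], [1, 0, 0])   # 1 supervised token
global_count = 4

# Global normalisation: microbatch losses add up to the true token-level mean.
global_loss = (microbatch_loss(*mb1, num_tokens=global_count)
               + microbatch_loss(*mb2, num_tokens=global_count))

# Local normalisation followed by averaging over-weights the short microbatch.
local_average = (microbatch_loss(*mb1) + microbatch_loss(*mb2)) / 2
```

Here the true mean over the four supervised tokens is 12 / 4, which the globally normalised sum reproduces, whereas averaging locally normalised losses does not.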

nemo_automodel.components.loss.dllm_loss._compute_per_token_nll(
logits: torch.Tensor,
target_ids: torch.Tensor
) -> torch.Tensor

Compute per-token negative log-likelihood, shape [B, L].
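
For reference, per-token negative log-likelihood is the log-sum-exp of the logits minus the logit of the target token. A plain-Python sketch on nested lists follows; `per_token_nll` is a hypothetical stand-in for the tensor-based helper and says nothing about how the library handles dtype upcasting.

```python
import math


def per_token_nll(logits, target_ids):
    """logits: [B][L][V] nested lists, target_ids: [B][L]; returns NLL of shape [B][L]."""
    out = []
    for row_logits, row_targets in zip(logits, target_ids):
        row = []
        for token_logits, target in zip(row_logits, row_targets):
            peak = max(token_logits)           # stabilise the log-sum-exp
            log_z = peak + math.log(sum(math.exp(x - peak) for x in token_logits))
            row.append(log_z - token_logits[target])
        out.append(row)
    return out


logits = [[[0.0, 0.0], [2.0, 0.0]]]   # B = 1, L = 2, V = 2
target_ids = [[0, 1]]
nll = per_token_nll(logits, target_ids)
```

The first position has uniform logits, so its NLL is log 2; the second targets the less likely token and therefore has a larger NLL.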

nemo_automodel.components.loss.dllm_loss.encoder_ar_loss(
encoder_logits: torch.Tensor,
input_ids: torch.Tensor,
valid_mask: typing.Optional[torch.Tensor] = None,
num_tokens: typing.Optional[int] = None
) -> torch.Tensor

Autoregressive next-token CE on the encoder’s causal logits.

The co-trained encoder loss for diffusion_gemma SFT: a standard causal LM cross-entropy over the clean full sequence, scored where both the current and next position are valid (non-pad).

Parameters:

encoder_logits
torch.Tensor

Encoder logits over the clean sequence, [B, S, V].

input_ids
torch.Tensor

Clean token IDs, [B, S].

valid_mask
Optional[torch.Tensor], defaults to None

Boolean non-pad mask [B, S]. If None, all positions count.

num_tokens
Optional[int], defaults to None

Optional global denominator (summed across grad-acc microbatches); defaults to the local valid next-token count.

Returns: torch.Tensor

Scalar AR loss (mean CE over valid next-token positions).
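
The shift-by-one scoring and the validity rule can be sketched for a single sequence in plain Python. The name `ar_next_token_loss` and the list-based inputs are illustrative assumptions; the library function works on batched tensors.

```python
import math


def ar_next_token_loss(logits, input_ids, valid_mask=None, num_tokens=None):
    """logits: [S][V], input_ids: [S]. Position s predicts token s + 1."""
    seq_len = len(input_ids)
    if valid_mask is None:
        valid_mask = [True] * seq_len
    total, count = 0.0, 0
    for s in range(seq_len - 1):
        # Score only where both the current and the next position are valid.
        if not (valid_mask[s] and valid_mask[s + 1]):
            continue
        row = logits[s]
        peak = max(row)
        log_z = peak + math.log(sum(math.exp(x - peak) for x in row))
        total += log_z - row[input_ids[s + 1]]
        count += 1
    denom = num_tokens if num_tokens is not None else max(count, 1)
    return total / denom


uniform_logits = [[0.0, 0.0, 0.0, 0.0] for _ in range(4)]   # S = 4, V = 4
input_ids = [0, 1, 2, 3]
valid_mask = [True, True, True, False]                      # last token is padding

loss = ar_next_token_loss(uniform_logits, input_ids, valid_mask)
loss_global = ar_next_token_loss(uniform_logits, input_ids, valid_mask, num_tokens=4)
```

With the last token padded, only two next-token pairs are scored; passing num_tokens replaces that local count with the supplied global denominator.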