nemo_automodel.components.loss.dllm_loss
Loss functions for diffusion LLM (dLLM) training.
Both loss classes return :class:`DLLMLossOutput` so the recipe can handle them
uniformly without branching on model type.
Module Contents
Classes
DLLMLossOutput | Unified return type for all dLLM loss functions.
MDLMCrossEntropyLoss | Cross-entropy loss for MDLM training.
HybridDiffusionLLMLoss | Combined diffusion + optional AR loss for hybrid diffusion LLM models.
Functions
_compute_per_token_nll | Compute per-token negative log-likelihood, shape [B, L].
API
- nemo_automodel.components.loss.dllm_loss._compute_per_token_nll(
- logits: torch.Tensor,
- target_ids: torch.Tensor,
Compute per-token negative log-likelihood, shape [B, L].
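As a rough illustration of what a per-token NLL of shape [B, L] looks like, the sketch below uses torch.nn.functional.cross_entropy with reduction="none". The function name per_token_nll_sketch and the unconditional float32 upcast are assumptions made for this example; it is not the library's _compute_per_token_nll.

```python
import torch
import torch.nn.functional as F


def per_token_nll_sketch(logits: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in: per-token NLL with shape [B, L]."""
    batch, seq_len, vocab = logits.shape
    # Assumption: upcast to float32 before the softmax for numerical stability.
    flat_logits = logits.float().reshape(batch * seq_len, vocab)
    flat_targets = target_ids.reshape(batch * seq_len)
    # reduction="none" keeps one loss value per token instead of averaging.
    nll = F.cross_entropy(flat_logits, flat_targets, reduction="none")
    return nll.reshape(batch, seq_len)
```

With uniform logits over V classes, every entry of the result equals log(V), which is a convenient sanity check.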
- class nemo_automodel.components.loss.dllm_loss.DLLMLossOutput
Bases: typing.NamedTuple
Unified return type for all dLLM loss functions.
.. attribute:: total_loss
Loss used for backward (may include AR component).
.. attribute:: dllm_loss
Pure diffusion loss for logging/metrics.
- total_loss: torch.Tensor
- dllm_loss: torch.Tensor
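The split between the two fields can be sketched as follows: the recipe backpropagates through total_loss and logs dllm_loss. LossOutputSketch and training_step_sketch are hypothetical names that mirror the two documented fields; they are not part of the library.

```python
from typing import NamedTuple

import torch


class LossOutputSketch(NamedTuple):
    """Hypothetical mirror of the two documented DLLMLossOutput fields."""

    total_loss: torch.Tensor  # used for backward (may include an AR component)
    dllm_loss: torch.Tensor   # pure diffusion loss, for logging/metrics


def training_step_sketch(out: LossOutputSketch) -> float:
    # Gradients flow through the total loss only.
    out.total_loss.backward()
    # The diffusion loss is detached and reported as a plain float.
    return out.dllm_loss.detach().item()


# Toy scalars standing in for real loss values.
w = torch.tensor(2.0, requires_grad=True)
diffusion = w * 3.0
ar = w * 0.5
out = LossOutputSketch(total_loss=diffusion + ar, dllm_loss=diffusion)
logged = training_step_sketch(out)
```

Because both loss classes return the same tuple type, a step function like this one needs no branch on the model type.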
- class nemo_automodel.components.loss.dllm_loss.MDLMCrossEntropyLoss(fp32_upcast: bool = True)
Bases: torch.nn.Module
Cross-entropy loss for MDLM training.
Matches the reference dllm framework (dllm/core/trainers/mdlm.py):
.. math:: \text{loss} = \frac{\sum_{i \in \text{masked}} \text{CE}_i \cdot w(t)}{\sum \text{maskable}}
where :math:`w(t) = 1/t` for the ``scheduler`` weight type (linear schedule).
Initialization
- forward(
- logits: torch.Tensor,
- target_ids: torch.Tensor,
- noise_mask: torch.Tensor,
- p_mask: torch.Tensor,
- loss_mask: torch.Tensor,
- loss_mask_ar: Optional[torch.Tensor] = None,
- num_diffusion_tokens: Optional[int] = None,
- num_ar_tokens: Optional[int] = None,
- causal_logits: Optional[torch.Tensor] = None,
Compute the MDLM cross-entropy loss.
- Parameters:
logits – Model output logits, shape [B, L, V].
target_ids – Clean (uncorrupted) token IDs, shape [B, L].
noise_mask – Boolean mask of corrupted positions, shape [B, L].
p_mask – Per-position masking probability, shape [B, L].
loss_mask – Supervised positions mask, shape [B, L].
num_diffusion_tokens – If provided, used for global normalization (total supervised tokens across all gradient-accumulation microbatches).
- Returns:
:class:`DLLMLossOutput` where ``total_loss == dllm_loss``.
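The formula and parameters above can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the library's forward: the name mdlm_loss_sketch is hypothetical, p_mask is assumed to play the role of t in w(t) = 1/t, "maskable" is assumed to mean the positions selected by loss_mask, and the clamp on p_mask is added only to avoid division by zero.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def mdlm_loss_sketch(
    logits: torch.Tensor,      # [B, L, V]
    target_ids: torch.Tensor,  # [B, L]
    noise_mask: torch.Tensor,  # [B, L], bool
    p_mask: torch.Tensor,      # [B, L], masking probability per position
    loss_mask: torch.Tensor,   # [B, L], supervised positions
    num_diffusion_tokens: Optional[int] = None,
) -> torch.Tensor:
    """Hypothetical sketch of a 1/t-weighted masked cross-entropy."""
    batch, seq_len, vocab = logits.shape
    nll = F.cross_entropy(
        logits.float().reshape(-1, vocab), target_ids.reshape(-1), reduction="none"
    ).reshape(batch, seq_len)

    # Only positions that are both corrupted and supervised contribute.
    active = (noise_mask.bool() & loss_mask.bool()).float()
    # Assumption: w(t) = 1/t with t taken from p_mask.
    weights = active / p_mask.float().clamp_min(1e-8)
    numerator = (nll * weights).sum()

    # Global normalization if a token count is supplied; otherwise fall back
    # to the local count of supervised positions (an assumption of this sketch).
    if num_diffusion_tokens is not None:
        denom = float(num_diffusion_tokens)
    else:
        denom = float(loss_mask.sum().clamp_min(1))
    return numerator / denom
```

Passing num_diffusion_tokens as the total over all gradient-accumulation microbatches means the per-microbatch losses sum to the globally normalized value, rather than each microbatch being averaged over its own token count.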
- class nemo_automodel.components.loss.dllm_loss.HybridDiffusionLLMLoss(alpha: float = 1.0, fp32_upcast: bool = True)
Bases: torch.nn.Module
Combined diffusion + optional AR loss for hybrid diffusion LLM models.
Used by Nemotron-Labs-Diffusion. The diffusion component computes MDLM-style loss at noise-masked positions, weighted by 1/p_mask. An optional autoregressive (AR) component adds standard cross-entropy at AR positions (the causal branch of model output).
Total loss = alpha * diffusion_loss + ar_loss.
Initialization
Initialize the hybrid loss.
- Parameters:
alpha – Weight for the diffusion loss component.
fp32_upcast – If True, upcast logits to float32 for numerical stability.
- forward(
- logits: torch.Tensor,
- target_ids: torch.Tensor,
- noise_mask: torch.Tensor,
- p_mask: torch.Tensor,
- loss_mask: torch.Tensor,
- loss_mask_ar: Optional[torch.Tensor] = None,
- num_diffusion_tokens: Optional[int] = None,
- num_ar_tokens: Optional[int] = None,
- causal_logits: Optional[torch.Tensor] = None,
Compute the hybrid diffusion + AR loss.
- Parameters:
logits – Model output logits, shape [B, L, V] or [B, L+L_ar, V] if the model produces both diffusion and AR logits in a single concatenated tensor (legacy path).
target_ids – Clean token IDs, shape [B, L].
noise_mask – Boolean mask of corrupted positions, shape [B, L].
p_mask – Per-position masking probability, shape [B, L].
loss_mask – Diffusion loss mask (supervised positions), shape [B, L].
loss_mask_ar – AR loss mask, shape [B, L]. If None, no AR loss.
num_diffusion_tokens – Total diffusion label tokens for normalization.
num_ar_tokens – Total AR label tokens for normalization.
causal_logits – Optional separate AR logits, shape [B, L, V]. When provided, avoids the concat/split of the legacy layout.
- Returns:
:class:`DLLMLossOutput` with combined ``total_loss`` and the pure (alpha-weighted) diffusion loss exposed as ``dllm_loss``.
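A minimal sketch of the combination Total loss = alpha * diffusion_loss + ar_loss, covering only the separate causal_logits path. The names hybrid_loss_sketch and _nll are hypothetical, normalization falls back to local mask counts instead of num_diffusion_tokens / num_ar_tokens, and the AR targets are assumed to be already aligned with the causal logits (no shift is applied here). The library's actual behavior may differ on each of these points.

```python
from typing import Optional, Tuple

import torch
import torch.nn.functional as F


def _nll(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Per-token negative log-likelihood, shape [B, L].
    batch, seq_len, vocab = logits.shape
    flat = F.cross_entropy(
        logits.float().reshape(-1, vocab), targets.reshape(-1), reduction="none"
    )
    return flat.reshape(batch, seq_len)


def hybrid_loss_sketch(
    logits: torch.Tensor,
    target_ids: torch.Tensor,
    noise_mask: torch.Tensor,
    p_mask: torch.Tensor,
    loss_mask: torch.Tensor,
    loss_mask_ar: Optional[torch.Tensor] = None,
    causal_logits: Optional[torch.Tensor] = None,
    alpha: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch returning (total_loss, dllm_loss)."""
    # Diffusion term: 1/p_mask-weighted CE at noised, supervised positions.
    active = (noise_mask.bool() & loss_mask.bool()).float()
    weighted = _nll(logits, target_ids) * active / p_mask.float().clamp_min(1e-8)
    dllm_loss = alpha * weighted.sum() / loss_mask.sum().clamp_min(1)

    total_loss = dllm_loss
    if loss_mask_ar is not None and causal_logits is not None:
        # AR term: plain CE on the causal branch (targets assumed pre-aligned).
        ar_mask = loss_mask_ar.float()
        ar_loss = (_nll(causal_logits, target_ids) * ar_mask).sum()
        ar_loss = ar_loss / ar_mask.sum().clamp_min(1)
        total_loss = dllm_loss + ar_loss
    return total_loss, dllm_loss
```

When loss_mask_ar is None the two returned values coincide, which matches the MDLM case where total_loss == dllm_loss.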