nemo_automodel.components.loss.dllm_loss

Loss functions for diffusion LLM (dLLM) training.

Both loss classes return DLLMLossOutput, so the recipe can handle them uniformly without branching on model type.

Module Contents

Classes

DLLMLossOutput

Unified return type for all dLLM loss functions.

MDLMCrossEntropyLoss

Cross-entropy loss for MDLM training.

HybridDiffusionLLMLoss

Combined diffusion + optional AR loss for hybrid diffusion LLM models.

Functions

_compute_per_token_nll

Compute per-token negative log-likelihood, shape [B, L].

API

nemo_automodel.components.loss.dllm_loss._compute_per_token_nll(
logits: torch.Tensor,
target_ids: torch.Tensor,
) -> torch.Tensor

Compute per-token negative log-likelihood, shape [B, L].
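A minimal sketch of what a per-token NLL of shape [B, L] amounts to, assuming it is the unreduced token-level cross-entropy between the logits and the target IDs. The helper name per_token_nll is hypothetical and stands in for the private function; it is not its actual source.

```python
import torch
import torch.nn.functional as F


def per_token_nll(logits: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in: unreduced cross-entropy, one value per token."""
    B, L, V = logits.shape
    # cross_entropy wants [N, V] logits and [N] targets; reduction="none" keeps every token
    nll = F.cross_entropy(logits.reshape(B * L, V), target_ids.reshape(B * L), reduction="none")
    return nll.view(B, L)


logits = torch.randn(2, 5, 11)
target_ids = torch.randint(0, 11, (2, 5))
nll = per_token_nll(logits, target_ids)
print(nll.shape)  # torch.Size([2, 5])
```

Each entry equals -log softmax(logits)[target], so masking and weighting can be applied per position afterwards.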

class nemo_automodel.components.loss.dllm_loss.DLLMLossOutput

Bases: typing.NamedTuple

Unified return type for all dLLM loss functions.

Attributes:
  • total_loss (torch.Tensor) – Loss used for backward (may include the AR component).

  • dllm_loss (torch.Tensor) – Pure diffusion loss, for logging and metrics.
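The point of a single return type is that one training-step code path serves every loss class. The sketch below illustrates that pattern with a hypothetical mirror class, LossOutputSketch, that has the same two fields, and a hypothetical training_step helper; neither is part of the library, and the real recipe may differ.

```python
from typing import NamedTuple

import torch


class LossOutputSketch(NamedTuple):
    """Hypothetical stand-in with the same two fields as DLLMLossOutput."""

    total_loss: torch.Tensor  # the tensor backpropagated through
    dllm_loss: torch.Tensor   # the diffusion-only value that gets logged


def training_step(out: LossOutputSketch) -> dict:
    # One path for every loss class: backprop total_loss, log both fields
    out.total_loss.backward()
    return {"loss": out.total_loss.item(), "dllm_loss": out.dllm_loss.item()}


w = torch.tensor(1.0, requires_grad=True)
diffusion = 2.0 * w
ar = 0.5 * w
# Hybrid-style result: total includes an AR term, dllm_loss does not
metrics = training_step(LossOutputSketch(total_loss=diffusion + ar, dllm_loss=diffusion.detach()))
print(metrics)           # {'loss': 2.5, 'dllm_loss': 2.0}
print(w.grad.item())     # 2.5
```

For an MDLM-style loss the two fields would simply hold the same value, and training_step needs no change.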

class nemo_automodel.components.loss.dllm_loss.MDLMCrossEntropyLoss(fp32_upcast: bool = True)

Bases: torch.nn.Module

Cross-entropy loss for MDLM training.

Matches the reference dllm framework (dllm/core/trainers/mdlm.py):

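$$\text{loss} = \frac{\sum_{i \in \text{masked}} \text{CE}_i \cdot w(t)}{N_{\text{maskable}}}$$

where CE_i is the cross-entropy at position i, the numerator runs over the noise-masked positions, N_maskable is the number of maskable (supervised) positions, and w(t) = 1/t for the scheduler weight type (linear schedule). Under a linear schedule each token is masked with probability t, so w(t) is the same 1/p_mask factor supplied through p_mask.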
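The formula can be sketched in PyTorch as follows. This is a minimal illustration, not the module's implementation: mdlm_loss_sketch is a hypothetical name, it assumes a linear schedule (so each position's weight is 1/p_mask), and it assumes that only positions that are both noise-masked and supervised contribute to the numerator.

```python
import torch
import torch.nn.functional as F


def mdlm_loss_sketch(logits, target_ids, noise_mask, p_mask, loss_mask):
    """Hypothetical: sum of 1/p_mask-weighted CE over masked positions / number of supervised positions."""
    # Unreduced CE of shape [B, L]; logits upcast to float32 (fp32_upcast-style choice)
    ce = F.cross_entropy(logits.float().transpose(1, 2), target_ids, reduction="none")
    # Assumption: a position counts only if it is both corrupted and supervised
    masked = noise_mask & loss_mask.bool()
    # Boolean indexing avoids dividing by p_mask at positions that are not masked
    numer = (ce[masked] / p_mask[masked]).sum()
    denom = loss_mask.sum().clamp(min=1)  # N_maskable
    return numer / denom


torch.manual_seed(0)
B, L, V = 2, 6, 10
logits = torch.randn(B, L, V)
target_ids = torch.randint(0, V, (B, L))
t = torch.rand(B, 1).clamp(min=0.05)    # one diffusion time per sequence
p_mask = t.expand(B, L)                 # linear schedule: masking probability equals t
noise_mask = torch.rand(B, L) < p_mask  # corrupt each token with probability t
loss_mask = torch.ones(B, L, dtype=torch.bool)
loss = mdlm_loss_sketch(logits, target_ids, noise_mask, p_mask, loss_mask)
```

Dividing by the count of supervised positions rather than masked ones keeps the denominator independent of how many tokens happened to be corrupted; the 1/t weight compensates in expectation.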

Initialization

Parameters:
  • fp32_upcast – If True, upcast logits to float32 for numerical stability.

forward(
logits: torch.Tensor,
target_ids: torch.Tensor,
noise_mask: torch.Tensor,
p_mask: torch.Tensor,
loss_mask: torch.Tensor,
loss_mask_ar: Optional[torch.Tensor] = None,
num_diffusion_tokens: Optional[int] = None,
num_ar_tokens: Optional[int] = None,
causal_logits: Optional[torch.Tensor] = None,
) -> nemo_automodel.components.loss.dllm_loss.DLLMLossOutput

Compute the MDLM cross-entropy loss.

Parameters:
  • logits – Model output logits, shape [B, L, V].

  • target_ids – Clean (uncorrupted) token IDs, shape [B, L].

  • noise_mask – Boolean mask of corrupted positions, shape [B, L].

  • p_mask – Per-position masking probability, shape [B, L].

  • loss_mask – Supervised positions mask, shape [B, L].

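  • num_diffusion_tokens – If provided, used for global normalization (the total number of supervised tokens across all gradient-accumulation microbatches).

  • loss_mask_ar, num_ar_tokens, causal_logits – Not used by this loss; accepted so that both loss classes share the same forward signature.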

Returns:

DLLMLossOutput where total_loss == dllm_loss.
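Why a global token count matters under gradient accumulation can be shown with plain arithmetic. The sketch below assumes each microbatch contributes its weighted-loss sum divided by num_diffusion_tokens; the per-token values are made up for illustration.

```python
import torch

# Per-token weighted losses for two gradient-accumulation microbatches (illustrative values)
micro_losses = [torch.tensor([1.0, 5.0]), torch.tensor([2.0, 2.0, 2.0, 2.0])]
micro_counts = [2, 4]                      # supervised tokens per microbatch
num_diffusion_tokens = sum(micro_counts)   # global count across all microbatches: 6

# Global normalization: every microbatch divides its sum by the same global count
global_norm = sum(m.sum() / num_diffusion_tokens for m in micro_losses)

# Local normalization: each microbatch uses its own count, then microbatches are averaged
local_norm = sum(m.sum() / c for m, c in zip(micro_losses, micro_counts)) / len(micro_losses)

# Reference: the mean over all tokens as if they were one large batch
full_batch = torch.cat(micro_losses).mean()
print(global_norm.item(), local_norm.item(), full_batch.item())
```

Global normalization reproduces the full-batch mean (14/6), while local normalization gives 2.5 because it over-weights tokens in the smaller microbatch.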

class nemo_automodel.components.loss.dllm_loss.HybridDiffusionLLMLoss(alpha: float = 1.0, fp32_upcast: bool = True)

Bases: torch.nn.Module

Combined diffusion + optional AR loss for hybrid diffusion LLM models.

Used by Nemotron-Labs-Diffusion. The diffusion component computes an MDLM-style loss at noise-masked positions, weighted by 1/p_mask. An optional autoregressive (AR) component adds standard cross-entropy at AR positions (the causal branch of the model output).

Total loss = alpha * diffusion_loss + ar_loss.
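A minimal sketch of how the two terms could combine, written as a hypothetical helper hybrid_total_sketch that takes an already-computed diffusion loss. It assumes causal_logits[:, i] already predicts target_ids[:, i] (any shift is taken to happen upstream) and that the AR term is normalized by num_ar_tokens when given, else by the local count of AR positions. It returns the pair (total_loss, dllm_loss).

```python
import torch
import torch.nn.functional as F


def hybrid_total_sketch(diffusion_loss, causal_logits, target_ids, loss_mask_ar,
                        alpha=1.0, num_ar_tokens=None):
    """Hypothetical: total = alpha * diffusion_loss + ar_loss; returns (total, dllm)."""
    dllm = alpha * diffusion_loss
    if loss_mask_ar is None:
        # No AR mask: the total is just the weighted diffusion loss
        return dllm, dllm
    # Standard per-token CE on the causal branch (assumes targets are pre-aligned)
    ce = F.cross_entropy(causal_logits.float().transpose(1, 2), target_ids, reduction="none")
    denom = num_ar_tokens if num_ar_tokens is not None else loss_mask_ar.sum().clamp(min=1)
    ar_loss = (ce * loss_mask_ar).sum() / denom
    return dllm + ar_loss, dllm


B, L, V = 2, 6, 10
causal_logits = torch.randn(B, L, V)
target_ids = torch.randint(0, V, (B, L))
loss_mask_ar = torch.rand(B, L) > 0.5
total, dllm = hybrid_total_sketch(torch.tensor(1.3), causal_logits, target_ids, loss_mask_ar, alpha=0.5)
```

Note that alpha scales only the diffusion term; the AR term enters the total unweighted.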

Initialization

Initialize the hybrid loss.

Parameters:
  • alpha – Weight for the diffusion loss component.

  • fp32_upcast – If True, upcast logits to float32 for numerical stability.

forward(
logits: torch.Tensor,
target_ids: torch.Tensor,
noise_mask: torch.Tensor,
p_mask: torch.Tensor,
loss_mask: torch.Tensor,
loss_mask_ar: Optional[torch.Tensor] = None,
num_diffusion_tokens: Optional[int] = None,
num_ar_tokens: Optional[int] = None,
causal_logits: Optional[torch.Tensor] = None,
) -> nemo_automodel.components.loss.dllm_loss.DLLMLossOutput

Compute the hybrid diffusion + AR loss.

Parameters:
  • logits – Model output logits, shape [B, L, V] or [B, L+L_ar, V] if the model produces both diffusion and AR logits in a single concatenated tensor (legacy path).

  • target_ids – Clean token IDs, shape [B, L].

  • noise_mask – Boolean mask of corrupted positions, shape [B, L].

  • p_mask – Per-position masking probability, shape [B, L].

  • loss_mask – Diffusion loss mask (supervised positions), shape [B, L].

  • loss_mask_ar – AR loss mask, shape [B, L]. If None, no AR loss.

  • num_diffusion_tokens – Total diffusion label tokens for normalization.

  • num_ar_tokens – Total AR label tokens for normalization.

  • causal_logits – Optional separate AR logits, shape [B, L, V]. When provided, avoids the concat/split of the legacy layout.

Returns:

DLLMLossOutput holding the combined total_loss, with the alpha-weighted diffusion loss alone (no AR term) exposed as dllm_loss.
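The two logits layouts accepted above can be told apart by shape. The sketch below uses a hypothetical helper, split_hybrid_logits, and assumes the legacy concatenated tensor is ordered as [diffusion | AR] along the sequence dimension; the parameter description does not state the order, so treat it as an assumption.

```python
import torch


def split_hybrid_logits(logits, seq_len, causal_logits=None):
    """Hypothetical: return (diffusion_logits, ar_logits) under an assumed [diffusion | AR] layout."""
    if causal_logits is not None:
        # Separate AR logits supplied: nothing to split
        return logits, causal_logits
    if logits.size(1) == seq_len:
        # Diffusion-only output: no AR logits present
        return logits, None
    # Legacy path: [B, L + L_ar, V] -> [B, L, V] and [B, L_ar, V]
    return logits[:, :seq_len], logits[:, seq_len:]


B, L, V = 2, 6, 10
packed = torch.randn(B, 2 * L, V)
diff_logits, ar_logits = split_hybrid_logits(packed, seq_len=L)
print(diff_logits.shape, ar_logits.shape)  # torch.Size([2, 6, 10]) torch.Size([2, 6, 10])
```

Supplying causal_logits separately also means the model never has to build the concatenated [B, L+L_ar, V] tensor in the first place.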