nemo_automodel.components.loss.dllm_loss#

Loss functions for diffusion LLM (dLLM) training.

Both loss classes return :class:`DLLMLossOutput` so the recipe can handle them uniformly without branching on model type.

Module Contents#

Classes#

DLLMLossOutput

Unified return type for all dLLM loss functions.

MDLMCrossEntropyLoss

Cross-entropy loss for MDLM training.

Functions#

_compute_per_token_nll

Compute per-token negative log-likelihood, shape [B, L].

API#

nemo_automodel.components.loss.dllm_loss._compute_per_token_nll(
logits: torch.Tensor,
target_ids: torch.Tensor,
) -> torch.Tensor#

Compute per-token negative log-likelihood, shape [B, L].
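The helper's body is not shown in this reference, but a per-token NLL of this shape can be sketched as follows. This is a hypothetical reimplementation (the function name and the fp32 upcast are illustrative assumptions, not the actual code):

```python
import torch
import torch.nn.functional as F


def compute_per_token_nll(logits: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor:
    """Per-token negative log-likelihood, shape [B, L] (illustrative sketch).

    Flattens to [B*L, V] so F.cross_entropy with reduction="none" yields one
    NLL per token, then restores the [B, L] shape.
    """
    B, L, V = logits.shape
    nll = F.cross_entropy(
        logits.reshape(B * L, V).float(),  # assumed upcast for numerical stability
        target_ids.reshape(B * L),
        reduction="none",
    )
    return nll.reshape(B, L)
```

With uniform logits over a vocabulary of size V, every entry equals log V, which makes the shape and reduction behavior easy to check.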

class nemo_automodel.components.loss.dllm_loss.DLLMLossOutput#

Bases: typing.NamedTuple

Unified return type for all dLLM loss functions.

.. attribute:: total_loss

Loss used for backward (may include an autoregressive (AR) component).

.. attribute:: dllm_loss

Pure diffusion loss for logging/metrics.

total_loss: torch.Tensor#

dllm_loss: torch.Tensor#
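Because :class:`DLLMLossOutput` is a plain :class:`typing.NamedTuple`, a training recipe can consume both fields the same way for every dLLM loss. A minimal sketch (the class body mirrors the two fields documented above; the example values are arbitrary):

```python
from typing import NamedTuple

import torch


class DLLMLossOutput(NamedTuple):
    """Unified return type: total_loss drives backward, dllm_loss is logged."""

    total_loss: torch.Tensor
    dllm_loss: torch.Tensor


# The recipe handles every loss uniformly, no branching on model type:
out = DLLMLossOutput(total_loss=torch.tensor(2.0), dllm_loss=torch.tensor(1.5))
loss_for_backward = out.total_loss  # may include an AR component
metric = out.dllm_loss.item()       # pure diffusion loss for logging
```

Since a NamedTuple is still a tuple, the output also unpacks positionally (`total, dllm = out`) for callers that prefer that style.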

class nemo_automodel.components.loss.dllm_loss.MDLMCrossEntropyLoss(fp32_upcast: bool = True)#

Bases: torch.nn.Module

Cross-entropy loss for MDLM training.

Matches the reference dllm framework (dllm/core/trainers/mdlm.py):

.. math:: \text{loss} = \frac{\sum_{i \in \text{masked}} \text{CE}_i \cdot w(t)}{\sum \text{maskable}}

where :math:`w(t) = 1/t` is the scheduler weight for the linear schedule.

Initialization

forward(
logits: torch.Tensor,
target_ids: torch.Tensor,
noise_mask: torch.Tensor,
p_mask: torch.Tensor,
loss_mask: torch.Tensor,
num_diffusion_tokens: Optional[int] = None,
) -> nemo_automodel.components.loss.dllm_loss.DLLMLossOutput#

Compute the MDLM cross-entropy loss.

Parameters:
  • logits – Model output logits, shape [B, L, V].

  • target_ids – Clean (uncorrupted) token IDs, shape [B, L].

  • noise_mask – Boolean mask of corrupted positions, shape [B, L].

  • p_mask – Per-position masking probability, shape [B, L].

  • loss_mask – Supervised positions mask, shape [B, L].

  • num_diffusion_tokens – If provided, used for global normalization (total supervised tokens across all grad-acc microbatches).

Returns:

:class:`DLLMLossOutput` where total_loss == dllm_loss.
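Putting the formula and the parameter descriptions together, the forward pass can be sketched as a free function. This is a hedged approximation, not the actual implementation: the function name, the `nll / p_mask` weighting (i.e. :math:`w(t) = 1/t` applied per position), and the fallback normalization by `loss_mask.sum()` are assumptions inferred from the documentation above:

```python
import torch
import torch.nn.functional as F


def mdlm_loss(logits, target_ids, noise_mask, p_mask, loss_mask,
              num_diffusion_tokens=None):
    """Illustrative sketch of the MDLM cross-entropy described above."""
    B, L, V = logits.shape
    # per-token NLL, shape [B, L]
    nll = F.cross_entropy(
        logits.reshape(B * L, V).float(),
        target_ids.reshape(B * L),
        reduction="none",
    ).reshape(B, L)
    # only corrupted AND supervised positions contribute to the numerator
    active = noise_mask & loss_mask
    weighted = (nll / p_mask) * active
    # assumed: global normalization when num_diffusion_tokens is provided,
    # otherwise normalize by supervised tokens in this microbatch
    denom = num_diffusion_tokens if num_diffusion_tokens is not None else loss_mask.sum()
    loss = weighted.sum() / denom
    return loss, loss  # total_loss == dllm_loss for pure MDLM
```

With uniform logits, every active token contributes log V / p_mask to the numerator, which makes the weighting and normalization easy to verify by hand.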