nemo_automodel.components.loss.masked_ce#

Module Contents#

Classes#

API#

class nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy(
fp32_upcast: bool = True,
ignore_index: int = -100,
reduction: str = 'sum',
)[source]#

Initialization

Masked cross-entropy loss.

Parameters:
  • fp32_upcast (bool) – if True, casts logits to float32 before computing the cross entropy. Default: True.

  • ignore_index (int) – label to ignore in CE calculation. Defaults to -100.

  • reduction (str) – type of reduction. Defaults to “sum”.

__call__(
logits: torch.Tensor,
labels: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) torch.Tensor[source]#

Compute the masked cross-entropy loss between logits and targets.

If a mask is provided, the loss is computed per element, multiplied by the mask, and then reduced according to reduction. If no mask is provided, the standard cross-entropy loss is used.

Parameters:
  • logits (torch.Tensor) – The predicted logits with shape [batch_size, seq_len, vocab_size], where vocab_size is the number of classes.

  • labels (torch.Tensor) – The ground truth class indices with shape [batch_size, seq_len].

  • mask (torch.Tensor, optional) – A tensor that masks the loss computation. Elements marked with 1 contribute to the loss; elements marked with 0 are ignored. Must be broadcastable to the shape of the loss. Defaults to None.

Returns:

The computed loss as a scalar tensor.

Return type:

torch.Tensor