nemo_automodel.loss.masked_ce#

Module Contents#

Functions#

masked_cross_entropy

Compute the masked cross-entropy loss between logits and targets.

API#

nemo_automodel.loss.masked_ce.masked_cross_entropy(
logits,
targets,
mask=None,
fp32_upcast=True,
ignore_index=-100,
reduction='mean',
)[source]#

Compute the masked cross-entropy loss between logits and targets.

If a mask is provided, the loss is computed per element, multiplied by the mask, and then averaged. If no mask is provided, the standard cross-entropy loss is used.

Parameters:
  • logits (torch.Tensor) – The predicted logits with shape (N, C) where C is the number of classes.

  • targets (torch.Tensor) – The ground truth class indices with shape (N,).

  • mask (torch.Tensor, optional) – A tensor that masks the loss computation. Elements marked with 1 contribute to the loss; elements marked with 0 are ignored. Must be broadcastable to the shape of the loss. Defaults to None.

  • fp32_upcast (bool, optional) – If True, casts logits to float32 before computing cross entropy. Defaults to True.

  • ignore_index (int) – label to ignore in CE calculation. Defaults to -100.

  • reduction (str) – type of reduction. Defaults to "mean".

Returns:

The computed loss as a scalar tensor.

Return type:

torch.Tensor
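The masked path described above (per-element loss, multiplied by the mask, then averaged over the kept elements) can be sketched in plain PyTorch. This is an illustrative reimplementation of the documented behavior, not the actual `nemo_automodel` source; the function name `masked_cross_entropy_sketch` is ours.

```python
# Illustrative sketch of the documented behavior; not the actual
# nemo_automodel.loss.masked_ce implementation.
import torch
import torch.nn.functional as F


def masked_cross_entropy_sketch(logits, targets, mask=None,
                                fp32_upcast=True, ignore_index=-100,
                                reduction="mean"):
    # Optionally upcast logits to float32 for numerical stability.
    if fp32_upcast:
        logits = logits.float()
    if mask is None:
        # No mask: fall back to standard cross entropy.
        return F.cross_entropy(logits, targets,
                               ignore_index=ignore_index,
                               reduction=reduction)
    # Masked path: per-element loss, zeroed where mask == 0,
    # then averaged over the kept elements.
    per_elem = F.cross_entropy(logits, targets,
                               ignore_index=ignore_index,
                               reduction="none")
    per_elem = per_elem * mask
    return per_elem.sum() / mask.sum()


logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))
mask = torch.tensor([1.0, 1.0, 0.0, 1.0])  # third element excluded
loss = masked_cross_entropy_sketch(logits, targets, mask)
```

Because masked elements are multiplied by zero before the average, perturbing the logits of a masked position leaves the loss unchanged.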