nemo_automodel.components.loss.masked_ce
Module Contents
Classes
API
Bases: Module
Cross-entropy loss that handles ignored or masked target positions.
Compute the masked cross-entropy loss between logits and targets.
If a mask is provided, the loss is computed per element, multiplied by the mask, and then averaged. If no mask is provided, the standard cross-entropy loss is used.
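A minimal sketch of the two branches described above, written in plain PyTorch. The function name `masked_cross_entropy` is hypothetical and is not this module's API. The sketch also assumes that "averaged" means dividing the masked sum by the number of kept positions; the actual class may normalize differently.

```python
import torch
import torch.nn.functional as F


def masked_cross_entropy(logits, targets, mask=None):
    """Hypothetical helper illustrating the documented behavior (not the library API).

    logits:  [batch_size, seq_len, vocab_size] float tensor
    targets: [batch_size, seq_len] integer class indices
    mask:    optional tensor broadcastable to [batch_size, seq_len]; 1 = keep
    """
    vocab_size = logits.size(-1)
    flat_logits = logits.reshape(-1, vocab_size)
    flat_targets = targets.reshape(-1)
    if mask is None:
        # No mask: standard mean-reduced cross-entropy over all positions.
        return F.cross_entropy(flat_logits, flat_targets)
    # Per-element loss, reshaped back to [batch_size, seq_len].
    per_token = F.cross_entropy(flat_logits, flat_targets, reduction="none")
    per_token = per_token.view(targets.shape)
    # Broadcast the mask to the loss shape; masked-out positions get weight 0.
    weights = torch.broadcast_to(mask, per_token.shape).to(per_token.dtype)
    # Assumption: average over kept positions only (guarding against an all-zero mask).
    return (per_token * weights).sum() / weights.sum().clamp(min=1)


# Example usage with random data.
torch.manual_seed(0)
logits = torch.randn(2, 4, 10)
targets = torch.randint(0, 10, (2, 4))
mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])
loss = masked_cross_entropy(logits, targets, mask)
print(loss.shape)  # torch.Size([])
```

With this normalization, zeroing a position in the mask gives the same result as setting its target to `-100` (the default `ignore_index` of `torch.nn.functional.cross_entropy`) and calling the unmasked branch.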
Parameters:
The predicted logits with shape [batch_size, seq_len, vocab_size], where vocab_size is the number of classes.
The ground truth class indices with shape [batch_size, seq_len].
A tensor that masks the loss computation. Positions marked with 1 contribute to the loss; all other positions are ignored. Must be broadcastable to the shape of the loss. Defaults to None.
Returns: torch.Tensor
The computed loss as a scalar tensor.
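The requirement that the mask be broadcastable to the loss shape can be illustrated in plain PyTorch. This sketch assumes the per-element loss has shape [batch_size, seq_len] and, as an assumption, averages over kept positions; the mask shapes below are illustrative choices, not taken from the library.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, vocab_size = 2, 4, 10
logits = torch.randn(batch_size, seq_len, vocab_size)
targets = torch.randint(0, vocab_size, (batch_size, seq_len))

# Per-element loss of shape [batch_size, seq_len].
# F.cross_entropy expects the class dimension second, hence the transpose.
per_token = F.cross_entropy(logits.transpose(1, 2), targets, reduction="none")
assert per_token.shape == (batch_size, seq_len)

# Masks of different shapes that all broadcast to [batch_size, seq_len].
full_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])  # per token
seq_mask = torch.tensor([[1], [0]])                      # keep only sequence 0
pos_mask = torch.tensor([0, 1, 1, 1])                    # drop position 0 in every sequence

for mask in (full_mask, seq_mask, pos_mask):
    weights = torch.broadcast_to(mask, per_token.shape).float()
    loss = (per_token * weights).sum() / weights.sum()
    assert loss.dim() == 0  # a scalar, matching the documented return value
```

A mask of shape [batch_size, 1] excludes whole sequences, while a mask of shape [seq_len] excludes the same positions in every sequence.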