# nemo_automodel.components.loss.masked_ce

## Module Contents

### Classes

## API
- class nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy(fp32_upcast: bool = True, ignore_index: int = -100, reduction: str = 'sum')
Initialization
Masked cross-entropy loss.
- Parameters:
fp32_upcast (bool) – if True, logits are cast to float32 before computing the cross entropy. Defaults to True.
ignore_index (int) – label value to ignore in the cross-entropy calculation. Defaults to -100.
reduction (str) – type of reduction to apply. Defaults to 'sum'.
- __call__(logits: torch.Tensor, labels: torch.Tensor, mask: Optional[torch.Tensor] = None)
Compute the masked cross-entropy loss between logits and targets.
If a mask is provided, the loss is computed per element, multiplied by the mask, and then averaged. If no mask is provided, the standard cross-entropy loss is used.
- Parameters:
logits (torch.Tensor) – The predicted logits with shape [batch_size, seq_len, vocab_size].
labels (torch.Tensor) – The ground-truth class indices with shape [batch_size, seq_len].
mask (torch.Tensor, optional) – A tensor that masks the loss computation. Positions where the mask is 1 contribute to the loss; all other positions are ignored. Must be broadcastable to the shape of the per-element loss. Defaults to None.
- Returns:
The computed loss as a scalar tensor.
- Return type:
torch.Tensor
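The masking behavior described above can be sketched with plain PyTorch ops. The function below (`masked_ce_sketch` is a hypothetical name, not part of this module) is an illustrative re-implementation of the documented semantics, not the actual `MaskedCrossEntropy` code:

```python
import torch
import torch.nn.functional as F

def masked_ce_sketch(logits, labels, mask=None, fp32_upcast=True, ignore_index=-100):
    """Illustrative sketch of the documented masked cross-entropy semantics."""
    if fp32_upcast:
        # upcast logits for numerical stability, as the fp32_upcast flag describes
        logits = logits.float()
    # F.cross_entropy expects [N, C]; flatten the batch and sequence dims
    flat_logits = logits.view(-1, logits.size(-1))
    flat_labels = labels.view(-1)
    if mask is None:
        # no mask: standard cross entropy, skipping ignore_index labels
        return F.cross_entropy(flat_logits, flat_labels, ignore_index=ignore_index)
    # with a mask: per-element loss, multiplied by the mask, then averaged
    per_token = F.cross_entropy(
        flat_logits, flat_labels, ignore_index=ignore_index, reduction="none"
    )
    flat_mask = mask.view(-1).to(per_token.dtype)
    return (per_token * flat_mask).sum() / flat_mask.sum().clamp(min=1)

# usage: mask out the first position of every sequence
logits = torch.randn(2, 4, 10)
labels = torch.randint(0, 10, (2, 4))
mask = torch.ones(2, 4)
mask[:, 0] = 0
loss = masked_ce_sketch(logits, labels, mask)  # scalar tensor
```

Masking a position with 0 is equivalent to setting its label to `ignore_index`: both exclude that token from the averaged loss.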