nemo_automodel.components.loss.masked_ce

Module Contents

Classes

Name | Description
MaskedCrossEntropy | Cross-entropy loss that handles ignored or masked target positions.

API

class nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy(
fp32_upcast: bool = True,
ignore_index: int = -100,
reduction: str = 'sum'
)

Bases: Module

Cross-entropy loss that handles ignored or masked target positions.
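
The constructor arguments can be illustrated with plain PyTorch. The sketch below assumes that `ignore_index` and `reduction` carry the same meaning as in `torch.nn.functional.cross_entropy`, and that `fp32_upcast=True` means the logits are cast to float32 before the loss is computed. Neither assumption is confirmed by this page; the tensor sizes are made up for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy sizes, chosen only for illustration.
batch_size, seq_len, vocab_size = 2, 4, 8
logits = torch.randn(batch_size, seq_len, vocab_size).to(torch.bfloat16)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
labels[0, :2] = -100  # positions to skip, e.g. prompt tokens

# Assumption: fp32_upcast=True corresponds to casting logits to float32 first.
flat_logits = logits.float().view(-1, vocab_size)
flat_labels = labels.view(-1)

# reduction="sum" adds up the per-token losses; ignored positions contribute 0.
loss_sum = F.cross_entropy(
    flat_logits, flat_labels, ignore_index=-100, reduction="sum"
)

# The same value, computed from the kept positions only.
keep = flat_labels != -100
loss_kept = F.cross_entropy(flat_logits[keep], flat_labels[keep], reduction="sum")
```

With a summed loss, the caller usually divides by the number of kept tokens afterwards to obtain a per-token average.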

nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy.forward(
logits: torch.Tensor,
labels: torch.Tensor,
mask: typing.Optional[torch.Tensor] = None,
num_label_tokens: typing.Optional[int] = None
) -> torch.Tensor

Compute the masked cross-entropy loss between logits and labels.

If a mask is provided, the loss is computed per element, multiplied by the mask, and then averaged. If no mask is provided, the standard cross-entropy loss is used.
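
The behaviour described above can be sketched in plain PyTorch. `masked_ce_sketch` is a hypothetical stand-in, not the library's code. It assumes that "averaged" means averaging over the positions the mask keeps; the constructor's default `reduction='sum'` suggests the actual reduction may differ.

```python
import torch
import torch.nn.functional as F


def masked_ce_sketch(logits, labels, mask=None):
    """Illustrative masked cross-entropy; not the library implementation."""
    vocab_size = logits.size(-1)
    # Per-element loss with shape [batch_size, seq_len].
    per_token = F.cross_entropy(
        logits.reshape(-1, vocab_size).float(),
        labels.reshape(-1),
        reduction="none",
    ).view_as(labels)
    if mask is None:
        # No mask: standard mean cross-entropy over all positions.
        return per_token.mean()
    # Broadcast the mask to the loss shape, zero out masked positions,
    # then average over the kept positions (assumption, see above).
    mask = mask.to(per_token.dtype).expand_as(per_token)
    return (per_token * mask).sum() / mask.sum().clamp(min=1)


# Usage with made-up shapes: 2 sequences, 4 tokens, vocabulary of 8.
torch.manual_seed(0)
logits = torch.randn(2, 4, 8)
labels = torch.randint(0, 8, (2, 4))
mask = torch.tensor([[0, 1, 1, 1], [1, 1, 0, 0]])
loss = masked_ce_sketch(logits, labels, mask)
```

Under this assumption the result matches writing `ignore_index` into the masked label positions and calling `F.cross_entropy` with `reduction="mean"`.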

Parameters:

logits
torch.Tensor

The predicted logits with shape [batch_size, seq_len, vocab_size], where vocab_size is the number of classes.

labels
torch.Tensor

The ground truth class indices with shape [batch_size, seq_len].

mask
typing.Optional[torch.Tensor], defaults to None

A tensor that masks the loss computation. Positions marked with 1 contribute to the loss; all other positions are ignored. Must be broadcastable to the shape of the per-element loss.
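
As an illustration of the broadcasting requirement, the sketch below uses a mask of shape [1, seq_len] against a per-element loss of shape [batch_size, seq_len]. The loss tensor is a stand-in filled with ones, and all sizes are hypothetical.

```python
import torch

batch_size, seq_len = 3, 5

# Stand-in for a per-element loss of shape [batch_size, seq_len].
per_token_loss = torch.ones(batch_size, seq_len)

# One mask row shared by every sequence: [1, seq_len] broadcasts
# to [batch_size, seq_len] under PyTorch's broadcasting rules.
shared_mask = torch.tensor([[0, 0, 1, 1, 1]])

masked = per_token_loss * shared_mask

# Count kept positions on the broadcast mask, not on the compact one.
kept = int(torch.broadcast_to(shared_mask, per_token_loss.shape).sum())
```

Counting on the broadcast mask matters when normalising: the compact mask has 3 ones, while 9 positions actually contribute to the loss.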

Returns: torch.Tensor

The computed loss as a scalar tensor.