nemo_automodel.components.loss.chunked_ce

Module Contents

Classes

| Name | Description |
| --- | --- |
| ChunkedCrossEntropy | Cross-entropy loss computed over sequence chunks. |

Functions

| Name | Description |
| --- | --- |
| compute_cross_entropy | Computes the cross-entropy loss between logits and targets. |

Data

_compiled_compute_cross_entropy

API

class nemo_automodel.components.loss.chunked_ce.ChunkedCrossEntropy(
chunk_len: int = 32,
compile: bool = True,
ignore_index: int = -100,
reduction: str = 'sum'
)

Bases: Module

Cross-entropy loss computed over sequence chunks.

nemo_automodel.components.loss.chunked_ce.ChunkedCrossEntropy.forward(
logits: torch.Tensor,
labels: torch.Tensor,
mask: typing.Optional[torch.Tensor] = None,
num_label_tokens: typing.Optional[int] = None
) -> torch.Tensor

Computes the cross-entropy loss over chunks of chunk_len sequence positions, so that long sequences are processed piece by piece rather than all at once.
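The chunked computation can be illustrated with a minimal pure-Python sketch. It assumes the total loss is the sum of per-position losses over all chunks (consistent with the documented sum return value) and that positions whose label equals ignore_index are skipped. It also treats num_label_tokens as an optional divisor of the summed loss; that behavior is an assumption, not something documented here. The helper names chunked_cross_entropy and token_cross_entropy are hypothetical.

```python
import math
from typing import List, Optional, Sequence

IGNORE_INDEX = -100


def token_cross_entropy(logits_row: Sequence[float], target: int) -> float:
    """Negative log-softmax of the target class, computed with log-sum-exp for stability."""
    m = max(logits_row)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits_row))
    return log_z - logits_row[target]


def chunked_cross_entropy(
    logits: List[List[List[float]]],  # [batch_size, seq_len, vocab_size]
    labels: List[List[int]],          # [batch_size, seq_len]
    chunk_len: int = 32,
    ignore_index: int = IGNORE_INDEX,
    num_label_tokens: Optional[int] = None,
) -> float:
    """Hypothetical sketch: sum per-position losses, one chunk of positions at a time."""
    seq_len = len(labels[0])
    num_chunks = math.ceil(seq_len / chunk_len)
    total = 0.0
    for c in range(num_chunks):
        # Positions [start, end) form chunk c; the last chunk may be shorter.
        start, end = c * chunk_len, min((c + 1) * chunk_len, seq_len)
        for b in range(len(labels)):
            for t in range(start, end):
                y = labels[b][t]
                if y != ignore_index:  # ignored positions contribute nothing
                    total += token_cross_entropy(logits[b][t], y)
    # Assumption: num_label_tokens normalizes the summed loss (not confirmed by the docs).
    if num_label_tokens is not None:
        total /= num_label_tokens
    return total
```

Because the sum decomposes over chunks, the result does not depend on chunk_len. Presumably the benefit in a tensor implementation is that only one chunk of logits has to be materialized (for example, upcast to float32) at a time; this sketch only shows that the arithmetic is unchanged.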

Parameters:

logits
torch.Tensor

Model output logits of shape [batch_size, seq_len, vocab_size].

labels
torch.Tensor

Ground-truth labels of shape [batch_size, seq_len].

mask
typing.Optional[torch.Tensor], defaults to None

Boolean mask indicating valid positions (1) and positions to ignore (0). Defaults to None.
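One plausible way to honor mask is to fold it into the labels, replacing every position where the mask is 0 with ignore_index so that the loss skips it. This is an assumption for illustration, not a description of the actual implementation, and apply_mask_to_labels is a hypothetical helper operating on nested lists.

```python
from typing import List

IGNORE_INDEX = -100


def apply_mask_to_labels(
    labels: List[List[int]],  # [batch_size, seq_len]
    mask: List[List[int]],    # same shape; 1 = valid position, 0 = ignore
    ignore_index: int = IGNORE_INDEX,
) -> List[List[int]]:
    """Hypothetical helper: return a copy of labels with masked-out positions set to ignore_index."""
    if len(labels) != len(mask) or any(len(r) != len(m) for r, m in zip(labels, mask)):
        raise ValueError("labels and mask must have the same shape")
    return [
        [y if keep else ignore_index for y, keep in zip(row, mask_row)]
        for row, mask_row in zip(labels, mask)
    ]
```

With tensors, the same idea would be expressed as labels.masked_fill(mask == 0, ignore_index).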

Returns: torch.Tensor

torch.Tensor: The sum of cross-entropy losses over the sequence.
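Written out, the summed loss (before any normalization) for logits z and labels y can be expressed as follows, with B = batch_size, T = seq_len, V = vocab_size, and k = chunk_len. The notation is introduced here for illustration; the second equation shows that grouping positions into chunks leaves the sum unchanged.

```latex
\ell_{b,t} =
\begin{cases}
  \log \sum_{v=1}^{V} e^{z_{b,t,v}} - z_{b,t,y_{b,t}} & \text{if } y_{b,t} \neq \texttt{ignore\_index} \\
  0 & \text{otherwise}
\end{cases}

\mathcal{L} = \sum_{b=1}^{B} \sum_{t=1}^{T} \ell_{b,t}
            = \sum_{b=1}^{B} \sum_{c=0}^{\lceil T/k \rceil - 1}
              \sum_{t = ck + 1}^{\min((c+1)k,\, T)} \ell_{b,t}
```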

nemo_automodel.components.loss.chunked_ce.compute_cross_entropy(
logits: torch.Tensor,
targets: torch.Tensor,
ignore_index = -100,
reduction = 'sum'
)

Computes the cross-entropy loss between logits and targets.

Parameters:

logits
torch.Tensor

Model predictions of shape (sequence_length, num_classes).

targets
torch.Tensor

Ground-truth labels of shape (sequence_length,).

ignore_index
int, defaults to -100

Target value that is ignored when computing the loss. Defaults to -100.

Returns:

torch.Tensor: The sum of cross-entropy losses over the sequence.
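For a single chunk, the documented behavior of compute_cross_entropy can be mirrored by a small pure-Python reference. The 'none' and 'mean' branches follow the semantics of torch.nn.functional.cross_entropy (ignored positions get a loss of 0, and 'mean' divides by the number of non-ignored targets). Whether this module supports reductions other than the default 'sum' is not stated here, so treat those branches as assumptions. The name compute_cross_entropy_ref is hypothetical.

```python
import math
from typing import List, Sequence, Union


def compute_cross_entropy_ref(
    logits: Sequence[Sequence[float]],  # (sequence_length, num_classes)
    targets: Sequence[int],             # (sequence_length,)
    ignore_index: int = -100,
    reduction: str = "sum",
) -> Union[float, List[float]]:
    """Hypothetical reference: per-position cross-entropy, then reduce."""
    if len(logits) != len(targets):
        raise ValueError("logits and targets must have the same sequence length")
    losses: List[float] = []
    for row, y in zip(logits, targets):
        if y == ignore_index:
            losses.append(0.0)  # ignored positions contribute zero loss
            continue
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))  # stable log-sum-exp
        losses.append(log_z - row[y])
    if reduction == "none":
        return losses
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        # Assumed torch-like semantics: average over non-ignored targets only.
        n_valid = sum(1 for y in targets if y != ignore_index)
        return sum(losses) / n_valid if n_valid else float("nan")
    raise ValueError(f"unsupported reduction: {reduction!r}")
```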

nemo_automodel.components.loss.chunked_ce._compiled_compute_cross_entropy = None
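A module-level attribute initialized to None, such as _compiled_compute_cross_entropy, is commonly used as a lazily populated cache: the compiled function is built on first use and reused afterwards. The sketch below shows that pattern with an injectable compiler so it runs without PyTorch. The names get_loss_fn and compute_cross_entropy_stub are hypothetical, and it is an assumption that the compile flag of ChunkedCrossEntropy drives this cache.

```python
from typing import Callable, Optional


def compute_cross_entropy_stub(x: float) -> float:
    """Hypothetical stand-in for the real loss function."""
    return x


# Module-level cache, initially empty (mirrors the `= None` attribute above).
_compiled_fn: Optional[Callable[[float], float]] = None


def get_loss_fn(compile: bool, compiler: Callable[[Callable], Callable]) -> Callable:
    """Return the eager function, or compile it once and reuse the cached result."""
    global _compiled_fn
    if not compile:
        return compute_cross_entropy_stub
    if _compiled_fn is None:
        # First call with compile=True: build and cache the compiled function.
        _compiled_fn = compiler(compute_cross_entropy_stub)
    return _compiled_fn
```

With PyTorch 2.x, the compiler argument could be torch.compile, which avoids recompiling the loss every time a new loss module is constructed.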