nemo_automodel.components.loss.te_parallel_ce


Module Contents

Classes

Name | Description
CrossEntropyFunction | This class implements a custom autograd function for the Cross Entropy loss.
TEParallelCrossEntropy | TransformerEngine tensor-parallel cross-entropy loss wrapper.

Data

HAVE_DTENSOR

HAVE_TE_PARALLEL_CE

MISSING_TE_PARALLEL_CE_MSG

parallel_cross_entropy

API

class nemo_automodel.components.loss.te_parallel_ce.CrossEntropyFunction()

Bases: Function

This class implements a custom autograd function for the Cross Entropy loss. The input tensor can be in BF16 or FP32, but the loss and gradient calculations happen in FP32 only. The returned loss is always in FP32, and the input gradients are cast to the datatype of the input.
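The dtype contract described here can be illustrated with a small custom autograd function in plain PyTorch. This is a minimal sketch, not the TransformerEngine kernel: the class name `Fp32CrossEntropySketch` is hypothetical, and the sketch leaves out label smoothing, `ignore_idx`, and distributed reduction.

```python
import torch


class Fp32CrossEntropySketch(torch.autograd.Function):
    """Hypothetical illustration: FP32 math, gradient returned in the input dtype."""

    @staticmethod
    def forward(ctx, logits, target):
        # Upcast once so the softmax and the loss are computed in FP32.
        log_probs = torch.log_softmax(logits.float(), dim=-1)
        loss = -log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)
        ctx.save_for_backward(log_probs, target)
        ctx.input_dtype = logits.dtype
        return loss  # per-token loss, always FP32

    @staticmethod
    def backward(ctx, grad_output):
        log_probs, target = ctx.saved_tensors
        # d(loss)/d(logits) = softmax(logits) - one_hot(target), computed in FP32.
        grad = log_probs.exp()
        minus_one = torch.full_like(grad[..., :1], -1.0)
        grad.scatter_add_(-1, target.unsqueeze(-1), minus_one)
        grad = grad * grad_output.unsqueeze(-1)
        # One entry per forward input: a tensor for logits, None for target.
        return grad.to(ctx.input_dtype), None


logits = torch.randn(2, 3, 8, dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, 8, (2, 3))
loss = Fp32CrossEntropySketch.apply(logits, target)
loss.sum().backward()
```

In this sketch the loss comes back as FP32 even though the logits are BF16, and `logits.grad` has the same dtype as `logits`.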

nemo_automodel.components.loss.te_parallel_ce.CrossEntropyFunction.backward(
ctx,
grad_output
)
staticmethod

The backward pass of the Cross Entropy loss.

Parameters:

ctx

The context object with saved tensors.

grad_output
tensor

The tensor containing the gradient of the loss with respect to the output.

Returns: tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.

nemo_automodel.components.loss.te_parallel_ce.CrossEntropyFunction.forward(
ctx,
_input,
target,
label_smoothing = 0.0,
reduce_loss = False,
dist_process_group = None,
ignore_idx = -100
)
staticmethod

The forward pass of the Cross Entropy loss. If dist_process_group is passed for distributed loss calculation, the input on each rank should have shape (*, V/world_size). Every rank must receive an equally sized shard along the V dimension.

Parameters:

ctx

The context object.

_input
tensor

The input tensor of shape (B, SQ, V) or (SQ, B, V), where B is the batch size, SQ is the sequence length, and V is the vocab size.

target
tensor

The target tensor of shape (B, SQ) or (SQ, B), where each value is in [0, V-1].

label_smoothing
float. Defaults to 0.0

The amount of smoothing when computing the loss, where 0.0 means no smoothing.

reduce_loss
bool. Defaults to False

If True, returns the loss averaged over the B*SQ dimension.

dist_process_group
torch.distributed.ProcessGroup. Defaults to None

The distributed process group the loss computation is split across. None when running on a single device.

ignore_idx
int. Defaults to -100

The target index for which the loss and gradients are set to zero.

Returns: tensor: The computed loss.
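To make the sharded-input requirement concrete, the following pure-Python sketch simulates how a cross-entropy value for one token can be assembled from equal vocab shards. It is an illustration of the general vocab-parallel technique under stated assumptions, not the TransformerEngine implementation: the function names are hypothetical, the collective operations are simulated with ordinary `max` and `sum`, and rank `r` is assumed to own the contiguous vocab slice starting at `r * V / world_size`.

```python
import math


def sharded_cross_entropy(shards, target):
    """Cross entropy for one token; shards[r] holds rank r's V/world_size logits."""
    shard_size = len(shards[0])
    assert all(len(s) == shard_size for s in shards), "shards must be equal-sized"

    # Step 1: each rank takes a local max; a MAX all-reduce gives the global max.
    global_max = max(max(s) for s in shards)

    # Step 2: each rank sums exp(logit - max); a SUM all-reduce gives the total.
    global_sum = sum(sum(math.exp(x - global_max) for x in s) for s in shards)

    # Step 3: only the rank that owns the target index contributes its logit.
    owner, local_idx = divmod(target, shard_size)
    target_logit = shards[owner][local_idx]

    # loss = logsumexp(logits) - logits[target]
    return math.log(global_sum) + global_max - target_logit


def dense_cross_entropy(logits, target):
    """Single-device reference for comparison."""
    m = max(logits)
    return math.log(sum(math.exp(x - m) for x in logits)) + m - logits[target]


logits = [2.0, -1.0, 0.5, 3.0, 0.0, 1.5, -2.0, 0.25]  # V = 8
shards = [logits[:4], logits[4:]]                      # world_size = 2
sharded = sharded_cross_entropy(shards, target=5)
dense = dense_cross_entropy(logits, target=5)
```

The sharded and dense results agree up to floating-point rounding, which is why each rank only ever needs its own (*, V/world_size) slice of the logits.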

class nemo_automodel.components.loss.te_parallel_ce.TEParallelCrossEntropy(
ignore_index: int = -100,
reduction: str = 'sum',
tp_group: typing.Optional[torch.distributed.ProcessGroup] = None
)

TransformerEngine tensor-parallel cross-entropy loss wrapper.

nemo_automodel.components.loss.te_parallel_ce.TEParallelCrossEntropy.__call__(
logits: torch.Tensor,
labels: torch.Tensor,
mask: typing.Optional[torch.Tensor] = None,
num_label_tokens: typing.Optional[int] = None
) -> torch.Tensor

Compute parallel cross entropy loss that matches PyTorch’s cross_entropy behavior.

Parameters:

logits
torch.Tensor

Input logits. Shape: [B, T, V]

labels
torch.Tensor

Target labels. Shape: [B, T]

mask
Optional[torch.Tensor]. Defaults to None

Mask to apply to the loss. Shape: [B, T]

num_label_tokens
Optional[int]. Defaults to None

The number of non-padding tokens.

Returns: torch.Tensor

Computed loss tensor
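For checking results on a single device, a reference built on `torch.nn.functional.cross_entropy` can be useful. The sketch below is an assumption about the intended semantics, not the wrapper's actual code: it assumes that positions where `mask` is 0 are excluded by setting their label to `ignore_index`, that the loss is summed (the default `reduction`), and that `num_label_tokens`, when given, divides the summed loss. The helper name `reference_masked_ce` is hypothetical.

```python
import torch
import torch.nn.functional as F


def reference_masked_ce(logits, labels, mask=None, num_label_tokens=None,
                        ignore_index=-100):
    """Hypothetical single-device reference using plain PyTorch cross_entropy."""
    labels = labels.clone()  # leave the caller's labels untouched
    if mask is not None:
        # Assumption: mask == 0 marks positions that must not contribute.
        labels[mask == 0] = ignore_index
    vocab = logits.size(-1)
    loss = F.cross_entropy(
        logits.float().reshape(-1, vocab),  # [B*T, V], computed in FP32
        labels.reshape(-1),                 # [B*T]
        ignore_index=ignore_index,
        reduction="sum",
    )
    if num_label_tokens is not None:
        # Assumption: normalize the summed loss by the non-padding token count.
        loss = loss / num_label_tokens
    return loss


B, T, V = 2, 4, 6
logits = torch.randn(B, T, V)
labels = torch.randint(0, V, (B, T))
mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])
loss = reference_masked_ce(logits, labels, mask=mask,
                           num_label_tokens=int(mask.sum()))
```

Under these assumptions the result equals the mean cross entropy over the unmasked positions, which gives a simple target to compare a tensor-parallel run against.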

nemo_automodel.components.loss.te_parallel_ce.HAVE_DTENSOR = True
nemo_automodel.components.loss.te_parallel_ce.HAVE_TE_PARALLEL_CE = HAVE_TRITON
nemo_automodel.components.loss.te_parallel_ce.MISSING_TE_PARALLEL_CE_MSG = MISSING_TRITON_MSG

Cross Entropy Loss API from NVIDIA’s TransformerEngine, available under the Apache License 2.0: https://github.com/NVIDIA/TransformerEngine

nemo_automodel.components.loss.te_parallel_ce.parallel_cross_entropy = CrossEntropyFunction.apply