nemo_automodel.components.loss.te_parallel_ce
Module Contents
Classes
Data
API
Bases: Function
This class implements a custom autograd function for the cross-entropy loss. The input tensor can be BF16 or FP32; the loss and gradient calculations happen in FP32 only. The returned loss is always FP32, and the input gradients are cast to the data type of the input.
The backward pass of the Cross Entropy loss.
Parameters:
ctx: The context object with saved tensors.
grad_output (tensor): The gradient of the loss with respect to the output.
Returns: tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
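To illustrate what the backward pass returns for the logits, the following is a minimal sketch in plain Python (no PyTorch), assuming the textbook cross-entropy gradient: softmax of the logits minus the one-hot target, scaled by grad_output. The helper names are hypothetical, and the sketch does not reproduce the fused implementation documented here. A finite-difference check confirms the formula.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a single row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Per-token cross entropy: logsumexp(logits) - logits[target]."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def cross_entropy_grad(logits, target, grad_output=1.0):
    """Gradient w.r.t. the logits: grad_output * (softmax - one_hot(target))."""
    probs = softmax(logits)
    return [
        grad_output * (p - (1.0 if i == target else 0.0))
        for i, p in enumerate(probs)
    ]


logits = [2.0, -1.0, 0.5, 0.0]
target = 2
analytic = cross_entropy_grad(logits, target)

# Central finite differences as an independent check of the analytic gradient.
eps = 1e-6
numeric = []
for i in range(len(logits)):
    up = list(logits)
    down = list(logits)
    up[i] += eps
    down[i] -= eps
    numeric.append(
        (cross_entropy(up, target) - cross_entropy(down, target)) / (2 * eps)
    )
```

The target gets no gradient in this sketch because it is an integer index, which is consistent with the returned tuple containing None for non-differentiable inputs.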
The forward pass of the cross-entropy loss. If dist_process_group is passed for distributed loss calculation, the input on each rank should have shape (*, V/world_size). Each rank must receive an equal-sized shard along the V dimension.
Parameters:
ctx: The context object.
_input (tensor): The input tensor of shape (B, SQ, V) or (SQ, B, V), where B is the batch size, SQ is the sequence length, and V is the vocabulary size.
target (tensor): The target tensor of shape (B, SQ) or (SQ, B), where each value is in [0, V-1].
label_smoothing (float): The amount of smoothing applied when computing the loss; 0.0 means no smoothing.
reduce_loss (bool): If true, returns the loss averaged across the B*SQ dimension.
dist_process_group (torch.distributed.ProcessGroup): The distributed process group the loss computation is split across; None if running on a single device.
ignore_idx (int): The target index for which the loss and gradients are set to zero.
Returns: tensor: The computed loss.
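The sharded input layout described above can be illustrated with a small sketch in plain Python. It assumes the commonly used vocab-parallel scheme (a max reduction, a sum reduction, and a lookup of the target logit on the rank that owns it); this is an assumption about the general technique, not a description of the actual kernels. Ranks are simulated as list slices, the collectives are replaced by ordinary max and sum, and all function names are hypothetical.

```python
import math


def full_cross_entropy(logits, target):
    """Reference loss computed on the unsharded vocabulary."""
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits)) - logits[target]


def sharded_cross_entropy(shards, target):
    """shards[r] is rank r's equal-sized slice of the vocabulary dimension."""
    shard_size = len(shards[0])
    assert all(len(s) == shard_size for s in shards), "shards must be equal-sized"

    # Step 1: each rank takes a local max; a MAX all-reduce would combine them.
    global_max = max(max(s) for s in shards)

    # Step 2: each rank sums exp(x - global_max); a SUM all-reduce combines them.
    global_sum = sum(sum(math.exp(x - global_max) for x in s) for s in shards)

    # Step 3: only the rank that owns the target index contributes its logit;
    # every other rank contributes 0.0 before a SUM all-reduce.
    target_logit = 0.0
    for rank, shard in enumerate(shards):
        start = rank * shard_size
        if start <= target < start + shard_size:
            target_logit += shard[target - start]

    return global_max + math.log(global_sum) - target_logit


logits = [0.3, -1.2, 2.5, 0.0, 1.1, -0.7, 0.9, 1.8]  # V = 8
world_size = 4
shard_size = len(logits) // world_size
shards = [logits[r * shard_size:(r + 1) * shard_size] for r in range(world_size)]

target = 5
loss_sharded = sharded_cross_entropy(shards, target)
loss_full = full_cross_entropy(logits, target)
```

The equal-shard requirement is what lets each rank derive its vocabulary offset as rank * shard_size without exchanging extra metadata.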
TransformerEngine tensor-parallel cross-entropy loss wrapper.
Compute parallel cross entropy loss that matches PyTorch’s cross_entropy behavior.
Parameters:
Input logits. Shape: [B, T, V]
Target labels. Shape: [B, T]
Mask to apply to the loss. Shape: [B, T]
The number of non-padding tokens.
Returns: torch.Tensor
Computed loss tensor
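How a loss mask and a non-padding token count might combine can be sketched in plain Python. The sketch assumes that masked-out positions are treated like ignored labels, that the summed loss is divided by the supplied token count, and that a mean over the counted tokens is used when no count is given. The parameter names mask and num_label_tokens, the helper names, and the -100 ignore value (PyTorch's default ignore_index) are assumptions for illustration; none of them is confirmed by this page.

```python
import math

IGNORE_INDEX = -100  # assumption: PyTorch's default ignore_index


def token_loss(logits, label):
    """Cross entropy for one token; ignored labels contribute zero loss."""
    if label == IGNORE_INDEX:
        return 0.0
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits)) - logits[label]


def masked_loss(logits, labels, mask=None, num_label_tokens=None):
    """logits: [B][T][V], labels: [B][T], mask: [B][T] of 0/1 (hypothetical names)."""
    total = 0.0
    counted = 0
    for b in range(len(labels)):
        for t in range(len(labels[b])):
            label = labels[b][t]
            # A zero in the mask turns the position into an ignored label.
            if mask is not None and not mask[b][t]:
                label = IGNORE_INDEX
            if label != IGNORE_INDEX:
                counted += 1
            total += token_loss(logits[b][t], label)
    # Divide by the supplied token count, else by the tokens actually counted.
    denom = num_label_tokens if num_label_tokens is not None else counted
    return total / max(denom, 1)


logits = [[[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]]]  # B=1, T=3, V=2
labels = [[0, 0, 1]]
mask = [[1, 1, 0]]  # the last position is masked out

mean_loss = masked_loss(logits, labels, mask=mask)
scaled_loss = masked_loss(logits, labels, mask=mask, num_label_tokens=4)
```

Passing an explicit token count is useful when the denominator should reflect tokens outside the current micro-batch, for example under gradient accumulation.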
Cross Entropy Loss API from NVIDIA’s TransformerEngine, available under the Apache License 2.0: https://github.com/NVIDIA/TransformerEngine