# Source code for nemo_automodel.components.loss.te_parallel_ce

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch

from nemo_automodel.components.loss.triton.te_cross_entropy import (
    HAVE_TRITON,
    cross_entropy_backward,
    cross_entropy_forward,
)
from nemo_automodel.shared.import_utils import MISSING_TRITON_MSG

# Availability flag for the TE parallel cross entropy path; it simply mirrors
# the Triton availability flag exported by the kernel module.
HAVE_TE_PARALLEL_CE = HAVE_TRITON
# Message used for the ImportError raised when the flag above is False.
MISSING_TE_PARALLEL_CE_MSG = MISSING_TRITON_MSG


"""Cross Entropy Loss API from NVIDIA's TransformerEngine, available under the Apache License 2.0:
https://github.com/NVIDIA/TransformerEngine"""


class CrossEntropyFunction(torch.autograd.Function):
    """
    Custom autograd function for the Cross Entropy loss.

    The input tensor can be in BF16/FP32, while the loss and gradient calculation
    happens in FP32 only. The returned loss is always in FP32, and the input
    gradients are cast to the datatype of the input.
    """

    @staticmethod
    def forward(
        ctx,
        _input,
        target,
        label_smoothing=0.0,
        reduce_loss=False,
        dist_process_group=None,
        ignore_idx=-100,
    ):
        """
        Run the forward pass of the Cross Entropy loss.

        If dist_process_group is passed for distributed loss calculation, the input to each
        distributed rank should be (*, V/world_size). Note that each of the ranks should
        get equal shards along the V dimension.

        Parameters:
            ctx: The autograd context object.
            _input (tensor): The input tensor of shape (B, SQ, V) or (SQ, B, V) where B is
                batch size, SQ is sequence length, V is vocab size.
            target (tensor): The target tensor of shape (B, SQ) or (SQ, B) where each value
                is in [0, V-1].
            label_smoothing (float): The amount of smoothing when computing the loss,
                where 0.0 means no smoothing.
            reduce_loss (bool): If true, returns the averaged loss across the B*SQ dimension.
            dist_process_group (torch.dist.ProcessGroup): The distributed process group the
                loss computation is split across, None if on 1 device.
            ignore_idx (int): The index for which loss and gradients are made to zero.

        Returns:
            tensor: The computed loss.
        """
        # The kernel returns the loss plus a tensor that backward() consumes later.
        # NOTE(review): presumably this second tensor already carries gradient
        # information computed by the kernel - confirm against cross_entropy_forward.
        loss, saved = cross_entropy_forward(
            _input,
            target,
            label_smoothing,
            reduce_loss,
            dist_process_group,
            ignore_idx,
        )
        ctx.save_for_backward(saved.detach())
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        """
        Run the backward pass of the Cross Entropy loss.

        Parameters:
            ctx: The autograd context object holding the tensor saved by forward().
            grad_output (tensor): The gradient of the loss with respect to the output.

        Returns:
            tuple: Gradients with respect to each forward() argument. Only the first
            entry (for ``_input``) is a tensor; the remaining five are None.
        """
        (saved,) = ctx.saved_tensors
        grad_input = cross_entropy_backward(saved, grad_output)
        # One slot per forward() argument after ctx: _input, target, label_smoothing,
        # reduce_loss, dist_process_group, ignore_idx. The original TransformerEngine
        # version was modified to also return None for the ignore_idx argument.
        return grad_input, None, None, None, None, None
# Functional alias: calling parallel_cross_entropy(...) invokes
# CrossEntropyFunction through torch.autograd.Function.apply.
parallel_cross_entropy = CrossEntropyFunction.apply
class TEParallelCrossEntropy:
    """Callable loss wrapper around ``parallel_cross_entropy``."""

    # Reduction modes accepted by __call__.
    _VALID_REDUCTIONS = ("none", "mean", "sum")

    def __init__(
        self,
        ignore_index: int = -100,
        reduction: str = "sum",
        tp_group: Optional[torch.distributed.ProcessGroup] = None,
    ):
        """
        Cross entropy loss module based on TransformerEngine's parallel cross entropy triton kernel.

        Args:
            ignore_index (int): Target value that is ignored when computing the loss. Defaults to -100.
            reduction (str): Type of reduction ('none', 'mean', 'sum'). Defaults to "sum".
            tp_group (Optional[torch.distributed.ProcessGroup]): Process group for tensor parallelism.
                Defaults to None.
        """
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.tp_group = tp_group

    def __call__(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        num_label_tokens: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Compute parallel cross entropy loss that matches PyTorch's cross_entropy behavior.

        Args:
            logits: Input logits. Shape: [B, T, V]
            labels: Target labels. Shape: [B, T]. The caller's tensor is not modified.
            mask: Mask to apply to the loss. Shape: [B, T]. Positions where the mask
                equals 0 are replaced by ``ignore_index`` before the loss is computed.
            num_label_tokens (int): The number of non-padding tokens. Only used when
                ``reduction == "sum"``, in which case the summed loss is divided by it.

        Returns:
            Computed loss tensor: the kernel output as-is for 'none' and 'mean', or the
            sum of the kernel output (optionally divided by ``num_label_tokens``) for 'sum'.

        Raises:
            ImportError: If the Triton-based kernel is unavailable.
            ValueError: If ``self.reduction`` is not one of 'none', 'mean', 'sum'.
        """
        if not HAVE_TE_PARALLEL_CE:
            raise ImportError(MISSING_TE_PARALLEL_CE_MSG)

        # Reject a bad reduction before launching the kernel instead of after it.
        if self.reduction not in self._VALID_REDUCTIONS:
            raise ValueError(f"Invalid reduction: {self.reduction}. Must be one of 'none', 'mean', 'sum'")

        if mask is not None:
            with torch.no_grad():
                if mask.device != labels.device:
                    mask = mask.to(labels.device)
                # Out-of-place fill so the caller's labels tensor is left untouched.
                labels = labels.masked_fill(mask == 0, self.ignore_index)

        # Only 'mean' asks the kernel itself to reduce; 'sum' is reduced below.
        reduce_loss = self.reduction == "mean"

        # Compute TE parallel cross entropy
        te_loss = parallel_cross_entropy(logits, labels, 0.0, reduce_loss, self.tp_group, self.ignore_index)

        if self.reduction != "sum":
            # 'none' and 'mean' return the kernel output unchanged.
            return te_loss

        loss = te_loss.sum()
        if num_label_tokens is not None:
            loss = loss / num_label_tokens
        return loss