# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
# IMPORTANT: This Apple software is supplied to you by Apple
# Inc. ("Apple") in consideration of your agreement to the following
# terms, and your use, installation, modification or redistribution of
# this Apple software constitutes acceptance of these terms. If you do
# not agree with these terms, please do not use, install, modify or
# redistribute this Apple software.
# In consideration of your agreement to abide by the following terms, and
# subject to these terms, Apple grants you a personal, non-exclusive
# license, under Apple's copyrights in this original Apple software (the
# "Apple Software"), to use, reproduce, modify and redistribute the Apple
# Software, with or without modifications, in source and/or binary forms;
# provided that if you redistribute the Apple Software in its entirety and
# without modifications, you must retain this notice and the following
# text and disclaimers in all such redistributions of the Apple Software.
# Neither the name, trademarks, service marks or logos of Apple Inc. may
# be used to endorse or promote products derived from the Apple Software
# without specific prior written permission from Apple. Except as
# expressly stated in this notice, no other rights or licenses, express or
# implied, are granted by Apple herein, including but not limited to any
# patent rights that may be infringed by your derivative works or by other
# works in which the Apple Software may be incorporated.
# The Apple Software is provided by Apple on an "AS IS" basis. APPLE
# MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
# THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
# OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
# IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
# MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
# AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
# STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
# -------------------------------------------------------------------------------
# SOFTWARE DISTRIBUTED WITH CUT CROSS ENTROPY:
# The Cut Cross Entropy software includes a number of subcomponents with separate
# copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.md.
# -------------------------------------------------------------------------------
from typing import Optional

import torch

from nemo_automodel.shared.import_utils import MISSING_CUT_CROSS_ENTROPY_MSG

try:
    import cut_cross_entropy.tl_utils as tl_utils
    from cut_cross_entropy import linear_cross_entropy

    HAVE_CUT_CROSS_ENTROPY = True
except ImportError:
    HAVE_CUT_CROSS_ENTROPY = False

def new_is_triton_greater_or_equal(version_str):
    """
    Check whether the installed pytorch-triton version is greater than or equal to a given version.

    Args:
        version_str: Minimum required version string, e.g. "3.1.0".

    Returns:
        bool: True if the installed pytorch-triton version >= version_str; False if it
            is older or if the pytorch-triton distribution is not installed.
    """
    import pkg_resources

    try:
        pytorch_triton_version = pkg_resources.get_distribution("pytorch-triton").version
        current = pkg_resources.parse_version(pytorch_triton_version)
        required = pkg_resources.parse_version(version_str)
        print(f"Current pytorch-triton version: {pytorch_triton_version}, Required triton version: {version_str}")
        return current >= required
    except pkg_resources.DistributionNotFound:
        print("pytorch-triton not found")
        return False

def new_is_triton_greater_or_equal_3_2_0():
    """
    Check if the installed pytorch-triton version is greater than or equal to 3.1.0.

    Patched replacement for cut_cross_entropy.tl_utils.is_triton_greater_or_equal_3_2_0
    that checks the pytorch-triton distribution bundled with PyTorch rather than the
    standalone triton package. Despite the name (kept to match the patched function),
    it requires pytorch-triton >= 3.1.0.

    Returns:
        bool: True if pytorch-triton version >= 3.1.0
    """
    return new_is_triton_greater_or_equal("3.1.0")

if HAVE_CUT_CROSS_ENTROPY:
    # Monkey-patch cut_cross_entropy's Triton version checks so they look at the
    # pytorch-triton distribution that ships with PyTorch.
    tl_utils.is_triton_greater_or_equal = new_is_triton_greater_or_equal
    tl_utils.is_triton_greater_or_equal_3_2_0 = new_is_triton_greater_or_equal_3_2_0

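# Quick sanity check of the patched helper (a sketch; the printed version below is
# illustrative and depends on the locally installed pytorch-triton wheel):
#
#     >>> new_is_triton_greater_or_equal("3.0.0")
#     Current pytorch-triton version: 3.1.0, Required triton version: 3.0.0
#     True
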
def fused_linear_cross_entropy(
    hidden_states: torch.Tensor,
    lm_weight: torch.Tensor,
    labels: torch.Tensor,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    reduction: str = "mean",
    logit_softcapping: float = 0,
    accuracy_threshold: str = "auto",
):
    """
    Compute a fused linear + cross-entropy loss that matches PyTorch's cross_entropy behavior.

    Args:
        hidden_states: Input hidden states.
        lm_weight: Weight matrix of the linear (language-model head) transformation.
        labels: Target labels.
        num_items_in_batch: Number of valid tokens (where labels != ignore_index);
            computed from labels when None.
        ignore_index: Label value to ignore in the loss (default: -100).
        reduction: Reduction method, "mean" or "sum".
        logit_softcapping: Softcap value for the logits (0 means no capping).
        accuracy_threshold: Threshold for accuracy computation (unused in this implementation).

    Returns:
        torch.Tensor: The cross-entropy loss, reduced according to `reduction`.
    """
    if not HAVE_CUT_CROSS_ENTROPY:
        raise ImportError(MISSING_CUT_CROSS_ENTROPY_MSG)

    # linear_cross_entropy expects softcap=None when no capping is applied.
    if logit_softcapping == 0:
        logit_softcapping = None

    # Compute the loss with sum reduction and handle normalization ourselves.
    # shift=False and filter_eps=None match PyTorch's behavior: no label
    # shifting and no token filtering.
    loss = linear_cross_entropy(
        hidden_states,
        lm_weight,
        targets=labels,
        ignore_index=ignore_index,
        softcap=logit_softcapping,
        reduction="sum",
        shift=False,
        filter_eps=None,
    )

    # Match PyTorch's cross_entropy "mean" reduction: divide the summed loss by
    # the number of valid (non-ignored) tokens.
    if reduction == "mean":
        if num_items_in_batch is None:
            num_items_in_batch = torch.sum(labels != ignore_index).item()
        loss = loss / num_items_in_batch
    return loss
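

# A minimal usage sketch (illustrative only, not part of the module's API):
# cut_cross_entropy's Triton kernels require a GPU and 16-bit inputs, so the
# device, dtype, and shapes below are assumptions for demonstration purposes.
if __name__ == "__main__":
    if HAVE_CUT_CROSS_ENTROPY and torch.cuda.is_available():
        vocab_size, hidden_dim, num_tokens = 128, 64, 8
        hidden = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)
        weight = torch.randn(vocab_size, hidden_dim, device="cuda", dtype=torch.bfloat16)
        labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")
        labels[0] = -100  # an ignored token, excluded from the mean reduction

        loss = fused_linear_cross_entropy(hidden, weight, labels)

        # Reference: unfused PyTorch cross entropy over explicitly materialized
        # logits. The two values should agree up to bf16 rounding error.
        ref = torch.nn.functional.cross_entropy(
            hidden.float() @ weight.float().T, labels, ignore_index=-100
        )
        print(f"fused loss: {loss.item():.4f}, reference loss: {ref.item():.4f}")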