nemo_automodel.components.loss.te_parallel_ce#

Module Contents#

Classes#

CrossEntropyFunction

This class implements a custom autograd function for the Cross Entropy loss. The input tensor can be in BF16 or FP32; the loss and gradient calculation happens in FP32 only. The returned loss is always in FP32, and the input gradients are cast to the datatype of the input.

TEParallelCrossEntropy

Cross entropy loss module based on TransformerEngine’s parallel cross entropy triton kernel.

Data#

HAVE_TE_PARALLEL_CE

MISSING_TE_PARALLEL_CE_MSG

Cross Entropy Loss API from NVIDIA’s TransformerEngine, available under the Apache License 2.0: https://github.com/NVIDIA/TransformerEngine

parallel_cross_entropy

API#

nemo_automodel.components.loss.te_parallel_ce.HAVE_TE_PARALLEL_CE#

None

nemo_automodel.components.loss.te_parallel_ce.MISSING_TE_PARALLEL_CE_MSG#

None

Cross Entropy Loss API from NVIDIA’s TransformerEngine, available under the Apache License 2.0: https://github.com/NVIDIA/TransformerEngine
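A minimal sketch of the availability guard these two module attributes support, assuming `HAVE_TE_PARALLEL_CE` is truthy only when TransformerEngine’s kernel imported successfully (illustrative, not part of the documented API surface beyond the two names):

```python
from nemo_automodel.components.loss.te_parallel_ce import (
    HAVE_TE_PARALLEL_CE,
    MISSING_TE_PARALLEL_CE_MSG,
)

# Fail fast with the module's own message if TransformerEngine's
# parallel cross entropy kernel is unavailable in this environment.
if not HAVE_TE_PARALLEL_CE:
    raise ImportError(MISSING_TE_PARALLEL_CE_MSG)
```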

class nemo_automodel.components.loss.te_parallel_ce.CrossEntropyFunction(*args, **kwargs)[source]#

Bases: torch.autograd.Function

This class implements a custom autograd function for the Cross Entropy loss. The input tensor can be in BF16 or FP32; the loss and gradient calculation happens in FP32 only. The returned loss is always in FP32, and the input gradients are cast to the datatype of the input.

Initialization

static forward(
ctx,
_input,
target,
label_smoothing=0.0,
reduce_loss=False,
dist_process_group=None,
ignore_idx=-100,
)[source]#

The forward pass of the Cross Entropy loss. If dist_process_group is passed for distributed loss calculation, the input to each distributed rank should have shape (*, V/world_size). Note that each rank must receive an equal shard along the V dimension.

Parameters:
  • ctx – The context object.

  • _input (tensor) – The input tensor of shape (B, SQ, V) or (SQ, B, V), where B is batch size, SQ is sequence length, and V is vocab size.

  • target (tensor) – The target tensor of shape (B, SQ) or (SQ, B), where each value is in [0, V-1].

  • label_smoothing (float) – The amount of smoothing when computing the loss; 0.0 means no smoothing.

  • reduce_loss (bool) – If True, returns the loss averaged across the B*SQ dimension.

  • dist_process_group (torch.dist.ProcessGroup) – The distributed process group the loss computation is split across; None if on a single device.

  • ignore_idx (int) – The index for which loss and gradients are set to zero.

Returns: tensor: The computed loss.
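A hedged single-device sketch of calling the function through `torch.autograd.Function.apply`, with shapes taken from the parameter list above. It assumes TransformerEngine is installed and a CUDA device is available (the underlying kernel is Triton-based); the argument values are illustrative:

```python
import torch
from nemo_automodel.components.loss.te_parallel_ce import CrossEntropyFunction

B, SQ, V = 2, 8, 32                      # batch, sequence length, vocab size
logits = torch.randn(B, SQ, V, device="cuda", dtype=torch.bfloat16,
                     requires_grad=True)
target = torch.randint(0, V, (B, SQ), device="cuda")

# Positional arguments mirror forward(): label_smoothing=0.0,
# reduce_loss=False (per-token losses), no process group, ignore_idx=-100.
loss = CrossEntropyFunction.apply(logits, target, 0.0, False, None, -100)

# The loss is FP32; backward() routes through the custom backward below
# and produces input gradients in the input's dtype (BF16 here).
loss.sum().backward()
```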

static backward(ctx, grad_output)[source]#

The backward pass of the Cross Entropy loss.

Parameters:
  • ctx – The context object with saved tensors.

  • grad_output (tensor) – The tensor containing the gradient of the loss with respect to the output.

Returns: tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.

nemo_automodel.components.loss.te_parallel_ce.parallel_cross_entropy#

None

class nemo_automodel.components.loss.te_parallel_ce.TEParallelCrossEntropy(
ignore_index: int = -100,
reduction: str = 'sum',
tp_group: Optional[torch.distributed.ProcessGroup] = None,
)[source]#

Cross entropy loss module based on TransformerEngine’s parallel cross entropy triton kernel.

Initialization

Parameters:
  • ignore_index (int) – Target value that is ignored when computing the loss. Defaults to -100.

  • reduction (str) – Type of reduction (‘none’, ‘mean’, ‘sum’). Defaults to “sum”.

  • tp_group (Optional[torch.distributed.ProcessGroup]) – Process group for tensor parallelism. Defaults to None.
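A minimal construction sketch using the documented arguments; the tensor-parallel group is left as None for single-device use (under tensor parallelism, an actual torch.distributed process group sharding the vocab dimension would be passed instead):

```python
from nemo_automodel.components.loss.te_parallel_ce import TEParallelCrossEntropy

# Sum per-token losses and ignore the conventional -100 padding label.
loss_fn = TEParallelCrossEntropy(ignore_index=-100, reduction="sum", tp_group=None)
```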

__call__(
logits: torch.Tensor,
labels: torch.Tensor,
mask: Optional[torch.Tensor] = None,
num_label_tokens: Optional[int] = None,
) torch.Tensor[source]#

Compute parallel cross entropy loss that matches PyTorch’s cross_entropy behavior.

Parameters:
  • logits – Input logits. Shape: [B, T, V]

  • labels – Target labels. Shape: [B, T]

  • mask – Mask to apply to the loss. Shape: [B, T]

  • num_label_tokens (int) – The number of non-padding tokens.

Returns:

Computed loss tensor
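A hedged end-to-end sketch of the call signature above, assuming a CUDA device and TransformerEngine are available. `num_label_tokens` is derived from the labels here; how it interacts with the chosen reduction should be checked against the implementation:

```python
import torch
from nemo_automodel.components.loss.te_parallel_ce import TEParallelCrossEntropy

loss_fn = TEParallelCrossEntropy(reduction="sum")

B, T, V = 2, 16, 128
logits = torch.randn(B, T, V, device="cuda", dtype=torch.bfloat16)
labels = torch.randint(0, V, (B, T), device="cuda")
labels[:, :4] = -100                     # e.g. exclude prompt tokens

# Number of labels that actually contribute to the loss
# (the documented "number of non-padding tokens").
num_valid = int((labels != -100).sum())

loss = loss_fn(logits, labels, num_label_tokens=num_valid)
```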