nemo_automodel.components.loss.te_parallel_ce

Module Contents

Classes

CrossEntropyFunction

This class implements a custom autograd function for the Cross Entropy loss. The input tensor can be in BF16 or FP32, but the loss and gradient are computed in FP32 only. The returned loss is always FP32, and the input gradients are cast to the datatype of the input.

TEParallelCrossEntropy

Data

HAVE_TE_PARALLEL_CE

MISSING_TE_PARALLEL_CE_MSG

Cross Entropy Loss API from NVIDIA’s TransformerEngine, available under the Apache License 2.0: https://github.com/NVIDIA/TransformerEngine

parallel_cross_entropy

API

nemo_automodel.components.loss.te_parallel_ce.HAVE_TE_PARALLEL_CE

None

nemo_automodel.components.loss.te_parallel_ce.MISSING_TE_PARALLEL_CE_MSG

None

Cross Entropy Loss API from NVIDIA’s TransformerEngine, available under the Apache License 2.0: https://github.com/NVIDIA/TransformerEngine

class nemo_automodel.components.loss.te_parallel_ce.CrossEntropyFunction

Bases: torch.autograd.Function

This class implements a custom autograd function for the Cross Entropy loss. The input tensor can be in BF16 or FP32, but the loss and gradient are computed in FP32 only. The returned loss is always FP32, and the input gradients are cast to the datatype of the input.

static forward(
ctx,
_input,
target,
label_smoothing=0.0,
reduce_loss=False,
dist_process_group=None,
ignore_idx=-100,
)

The forward pass of the Cross Entropy loss. If dist_process_group is passed for distributed loss calculation, the input on each rank should have shape (*, V/world_size), and every rank must receive an equal-sized shard along the V dimension.
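To illustrate why equal vocabulary shards are enough to recover the full loss, the pure-Python sketch below simulates the ranks inside a single process for one token. It assumes each rank's shard is a plain list of logits and replaces the cross-rank collectives with ordinary max and sum calls; `sharded_cross_entropy` is a hypothetical name, and the actual Triton kernel's communication pattern may differ.

```python
import math


def sharded_cross_entropy(logit_shards, target):
    """Cross entropy for one token whose vocab logits are split across ranks.

    logit_shards[r] holds the logits for global vocab ids
    [r * shard_size, (r + 1) * shard_size). All shards must have equal length.
    """
    shard_size = len(logit_shards[0])
    assert all(len(s) == shard_size for s in logit_shards), "shards must be equal"

    # Step 1: each rank finds its local max; a MAX reduction across ranks
    # yields the global max used for numerical stability.
    global_max = max(max(s) for s in logit_shards)

    # Step 2: each rank sums exp(x - global_max) over its shard; a SUM
    # reduction across ranks yields the full softmax denominator.
    sum_exp = sum(sum(math.exp(x - global_max) for x in s) for s in logit_shards)
    logsumexp = global_max + math.log(sum_exp)

    # Step 3: only the rank that owns the target id holds its logit.
    owner, local_idx = divmod(target, shard_size)
    target_logit = logit_shards[owner][local_idx]

    return logsumexp - target_logit


full = [1.0, 2.0, 0.5, -1.0]
# Two "ranks", each holding V / world_size = 2 logits.
print(sharded_cross_entropy([full[:2], full[2:]], target=1))
```

Because the max and the exponential sum decompose over disjoint vocabulary slices, the sharded result equals the single-device loss up to floating-point rounding.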

Parameters:
  • ctx – The context object.

  • _input (tensor) – The input tensor of shape (B, SQ, V) or (SQ, B, V), where B is the batch size, SQ is the sequence length, and V is the vocabulary size.

  • target (tensor) – The target tensor of shape (B, SQ) or (SQ, B), where each value is in [0, V-1].

  • label_smoothing (float) – The amount of smoothing applied when computing the loss; 0.0 means no smoothing.

  • reduce_loss (bool) – If True, returns the loss averaged across the B*SQ dimension.

  • dist_process_group (torch.distributed.ProcessGroup) – The distributed process group the loss computation is split across, or None when running on a single device.

  • ignore_idx (int) – Target index whose loss and gradients are set to zero.

Returns: tensor: The computed loss.
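For checking results on small inputs, a plain-Python reference of the forward computation can help. This is a sketch, not the TransformerEngine kernel: `reference_cross_entropy` and `log_softmax` are hypothetical helpers, the label-smoothing formula follows PyTorch's convention (mixing the one-hot target with a uniform 1/V distribution) and is only assumed to match, and `reduce_loss=True` divides by all B*SQ positions, following the parameter description above.

```python
import math


def log_softmax(row):
    # Subtract the max before exponentiating for numerical stability.
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def reference_cross_entropy(logits, target, label_smoothing=0.0,
                            reduce_loss=False, ignore_idx=-100):
    """logits: N rows of V floats (B*SQ flattened); target: N ints."""
    losses = []
    for row, t in zip(logits, target):
        if t == ignore_idx:
            losses.append(0.0)  # ignored positions contribute zero loss
            continue
        logp = log_softmax(row)
        nll = -logp[t]
        # Assumed PyTorch-style smoothing: average -log p over the vocabulary.
        smooth = -sum(logp) / len(row)
        losses.append((1.0 - label_smoothing) * nll + label_smoothing * smooth)
    if reduce_loss:
        # Average over every B*SQ position, as the docstring describes.
        return sum(losses) / len(losses)
    return losses


print(reference_cross_entropy([[2.0, 1.0, 0.1], [0.5, 0.5, 0.5]], [0, -100]))
```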

static backward(ctx, grad_output)

The backward pass of the Cross Entropy loss.

Parameters:
  • ctx – The context object with saved tensors.

  • grad_output (tensor) – The tensor containing the gradient of the loss with respect to the output.

Returns: tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
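The gradient that the backward pass propagates to the logits has a closed form: for one token it is grad_output * (softmax(x) - q), where q is the (optionally smoothed) target distribution, and ignored tokens receive zero. The pure-Python sketch below shows that formula; `cross_entropy_grad` is a hypothetical helper, the smoothed target (1 - eps) * one_hot + eps / V assumes PyTorch's label-smoothing convention, and the real kernel computes this in FP32 on the GPU rather than in Python.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]


def cross_entropy_grad(row, target, grad_output=1.0,
                       label_smoothing=0.0, ignore_idx=-100):
    """Gradient of one token's (optionally smoothed) CE loss w.r.t. its logits."""
    V = len(row)
    if target == ignore_idx:
        return [0.0] * V  # ignored tokens receive zero gradient
    p = softmax(row)
    grad = []
    for i, p_i in enumerate(p):
        # Smoothed target distribution: (1 - eps) * one_hot + eps / V
        q_i = (1.0 - label_smoothing) * (1.0 if i == target else 0.0) + label_smoothing / V
        # Chain rule: scale by the incoming gradient of the loss.
        grad.append(grad_output * (p_i - q_i))
    return grad


print(cross_entropy_grad([0.3, -1.2, 2.0, 0.7], target=2))
```

Because both softmax(x) and q sum to one, each token's logit gradient sums to zero, which is a quick sanity check on any implementation.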

nemo_automodel.components.loss.te_parallel_ce.parallel_cross_entropy

None

class nemo_automodel.components.loss.te_parallel_ce.TEParallelCrossEntropy(
ignore_index: int = -100,
reduction: str = 'sum',
tp_group: Optional[torch.distributed.ProcessGroup] = None,
)

Initialization

Cross entropy loss module based on TransformerEngine’s parallel cross entropy Triton kernel.

Parameters:
  • ignore_index (int) – Target value that is ignored when computing the loss. Defaults to -100.

  • reduction (str) – Type of reduction (‘none’, ‘mean’, ‘sum’). Defaults to ‘sum’.

  • tp_group (Optional[torch.distributed.ProcessGroup]) – Process group for tensor parallelism. Defaults to None.
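The `ignore_index` and `reduction` options mirror PyTorch's `torch.nn.functional.cross_entropy`, where 'mean' divides by the number of non-ignored targets rather than by the total number of tokens. The stdlib-only sketch below spells out those PyTorch semantics; `ce_with_reduction` is a hypothetical helper, and it is an assumption (based on the claim that this module matches PyTorch's behavior) that TEParallelCrossEntropy reduces the same way.

```python
import math


def ce_with_reduction(logits, labels, ignore_index=-100, reduction="sum"):
    """PyTorch-style cross entropy over flattened tokens (hypothetical helper)."""
    losses = []
    for row, t in zip(logits, labels):
        if t == ignore_index:
            losses.append(0.0)  # ignored targets contribute nothing
            continue
        m = max(row)
        lse = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(lse - row[t])
    if reduction == "none":
        return losses
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        # PyTorch divides by the count of non-ignored targets, not len(labels).
        n_valid = sum(1 for t in labels if t != ignore_index)
        return sum(losses) / n_valid
    raise ValueError(f"unknown reduction: {reduction!r}")


logits = [[1.0, 0.0], [0.0, 1.0], [3.0, 3.0]]
labels = [0, -100, 1]
print(ce_with_reduction(logits, labels, reduction="mean"))
```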

__call__(
logits: torch.Tensor,
labels: torch.Tensor,
mask: Optional[torch.Tensor] = None,
num_label_tokens: Optional[int] = None,
) -> torch.Tensor

Compute parallel cross entropy loss that matches PyTorch’s cross_entropy behavior.

Parameters:
  • logits (torch.Tensor) – Input logits. Shape: [B, T, V]

  • labels (torch.Tensor) – Target labels. Shape: [B, T]

  • mask (Optional[torch.Tensor]) – Mask to apply to the loss. Shape: [B, T]

  • num_label_tokens (int) – The number of non-padding tokens.

Returns:

Computed loss tensor
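A common reason to pass a token count such as num_label_tokens is gradient accumulation: dividing each micro-batch's summed loss by the token count of the whole global batch yields the true per-token mean, whereas averaging per-micro-batch means does not. The sketch below illustrates that arithmetic only. `masked_token_loss` is hypothetical, it assumes a truthy mask value means "keep this token", and how TEParallelCrossEntropy actually combines mask and num_label_tokens should be confirmed in its source.

```python
def masked_token_loss(per_token_loss, mask, num_label_tokens=None):
    """Hypothetical helper: keep positions where mask is truthy, sum them,
    and optionally divide by a caller-supplied token count."""
    total = sum(loss for loss, keep in zip(per_token_loss, mask) if keep)
    if num_label_tokens is not None:
        total = total / num_label_tokens
    return total


# Two micro-batches with different numbers of label tokens (2 and 3).
mb1_loss, mb1_mask = [1.0, 3.0, 9.0], [1, 1, 0]
mb2_loss, mb2_mask = [4.0, 4.0, 4.0], [1, 1, 1]
total_tokens = 5

# Normalizing each partial sum by the global token count, then summing,
# gives the per-token mean over all 5 tokens: (4 + 12) / 5.
global_mean = (masked_token_loss(mb1_loss, mb1_mask, total_tokens)
               + masked_token_loss(mb2_loss, mb2_mask, total_tokens))
print(round(global_mean, 6))  # 3.2

# Averaging per-micro-batch means weights the micro-batches equally instead.
mean_of_means = (masked_token_loss(mb1_loss, mb1_mask, 2)
                 + masked_token_loss(mb2_loss, mb2_mask, 3)) / 2
print(mean_of_means)  # 3.0
```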