nemo_automodel.components.loss.te_parallel_ce#
Module Contents#
Classes#
CrossEntropyFunction: This class implements a custom autograd function for the Cross Entropy loss. The input tensor can be in BF16/FP32; the loss and gradient calculation happens in FP32 only. The returned loss is always in FP32, and the input gradients are upcast to the datatype of the input.
TEParallelCrossEntropy: Cross entropy loss module based on TransformerEngine's parallel cross entropy triton kernel.
Data#
Cross Entropy Loss API from NVIDIA's TransformerEngine, available under the Apache License 2.0: https://github.com/NVIDIA/TransformerEngine
API#
- nemo_automodel.components.loss.te_parallel_ce.HAVE_TE_PARALLEL_CE#
None
- nemo_automodel.components.loss.te_parallel_ce.MISSING_TE_PARALLEL_CE_MSG#
None
Cross Entropy Loss API from NVIDIA’s TransformerEngine, available under the Apache License 2.0: https://github.com/NVIDIA/TransformerEngine
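When TransformerEngine is not installed, the flag above can be used to fail fast. A minimal guard sketch, assuming `HAVE_TE_PARALLEL_CE` is a boolean and `MISSING_TE_PARALLEL_CE_MSG` a human-readable string (their values render as None above):

```python
from nemo_automodel.components.loss.te_parallel_ce import (
    HAVE_TE_PARALLEL_CE,
    MISSING_TE_PARALLEL_CE_MSG,
)

# Fail fast with the module's own message if TransformerEngine's
# parallel cross entropy kernel is unavailable.
if not HAVE_TE_PARALLEL_CE:
    raise ImportError(MISSING_TE_PARALLEL_CE_MSG)
```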
- class nemo_automodel.components.loss.te_parallel_ce.CrossEntropyFunction(*args, **kwargs)[source]#
Bases: torch.autograd.Function
This class implements a custom autograd function for the Cross Entropy loss. The input tensor can be in BF16/FP32; the loss and gradient calculation happens in FP32 only. The returned loss is always in FP32, and the input gradients are upcast to the datatype of the input.
Initialization
- static forward(
- ctx,
- _input,
- target,
- label_smoothing=0.0,
- reduce_loss=False,
- dist_process_group=None,
- ignore_idx=-100,
- )[source]#
The forward pass of the Cross Entropy loss. If dist_process_group is passed for distributed loss calculation, the input to each distributed rank should be of shape (*, V/world_size). Note that each rank should get an equal shard along the V dimension.
- Parameters:
ctx – The context object.
_input (tensor) – The input tensor of shape (B, SQ, V) or (SQ, B, V), where B is batch size, SQ is sequence length, and V is vocab size.
target (tensor) – The target tensor of shape (B, SQ) or (SQ, B), where each value is in [0, V-1].
label_smoothing (float) – The amount of smoothing when computing the loss; 0.0 means no smoothing.
reduce_loss (bool) – If True, returns the loss averaged across the B*SQ dimension.
dist_process_group (torch.distributed.ProcessGroup) – The distributed process group the loss computation is split across; None if on a single device.
ignore_idx (int) – The index for which loss and gradients are zeroed.
- Returns:
tensor – The computed loss.
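A single-device usage sketch of the autograd function; shapes, the CUDA device, and the BF16 dtype are illustrative assumptions, and arguments are passed positionally in the order documented above:

```python
import torch

from nemo_automodel.components.loss.te_parallel_ce import CrossEntropyFunction

B, SQ, V = 2, 8, 1024
logits = torch.randn(B, SQ, V, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, V, (B, SQ), device="cuda")

# Positional arguments mirror forward():
# _input, target, label_smoothing, reduce_loss, dist_process_group, ignore_idx
loss = CrossEntropyFunction.apply(logits, target, 0.0, True, None, -100)

loss.backward()  # loss is FP32; logits.grad is cast back to BF16
```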
- static backward(ctx, grad_output)[source]#
The backward pass of the Cross Entropy loss.
- Parameters:
ctx – The context object with saved tensors.
grad_output (tensor) – The gradient of the loss with respect to the output.
- Returns:
tuple – The gradients with respect to the inputs; each element is a tensor or None.
- nemo_automodel.components.loss.te_parallel_ce.parallel_cross_entropy#
None
- class nemo_automodel.components.loss.te_parallel_ce.TEParallelCrossEntropy(
- ignore_index: int = -100,
- reduction: str = 'sum',
- tp_group: Optional[torch.distributed.ProcessGroup] = None,
- )[source]#
Cross entropy loss module based on TransformerEngine's parallel cross entropy triton kernel.
Initialization
- Parameters:
ignore_index (int) – Target value that is ignored when computing the loss. Defaults to -100.
reduction (str) – Type of reduction (‘none’, ‘mean’, ‘sum’). Defaults to ‘sum’.
tp_group (Optional[torch.distributed.ProcessGroup]) – Process group for tensor parallelism. Defaults to None.
- __call__(
- logits: torch.Tensor,
- labels: torch.Tensor,
- mask: Optional[torch.Tensor] = None,
- num_label_tokens: Optional[int] = None,
- )[source]#
Compute parallel cross entropy loss that matches PyTorch’s cross_entropy behavior.
- Parameters:
logits – Input logits. Shape: [B, T, V]
labels – Target labels. Shape: [B, T]
mask – Mask to apply to the loss. Shape: [B, T]
num_label_tokens (Optional[int]) – The number of non-padding tokens.
- Returns:
Computed loss tensor.
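A minimal end-to-end sketch of the module, assuming reduction=‘sum’ with num_label_tokens supplying the non-padding count used to normalize the loss (per the parameter description); shapes, device, and dtype are illustrative:

```python
import torch

from nemo_automodel.components.loss.te_parallel_ce import TEParallelCrossEntropy

B, T, V = 2, 16, 1024
logits = torch.randn(B, T, V, device="cuda", dtype=torch.bfloat16, requires_grad=True)
labels = torch.randint(0, V, (B, T), device="cuda")
labels[:, :4] = -100  # mark leading positions as padding (ignore_index)

loss_fn = TEParallelCrossEntropy(ignore_index=-100, reduction="sum")

# Count the non-padding tokens so the summed loss can be normalized.
num_label_tokens = int((labels != -100).sum().item())
loss = loss_fn(logits, labels, num_label_tokens=num_label_tokens)
loss.backward()
```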