core.tensor_parallel.cross_entropy#
Module Contents#
Classes#
| VocabParallelCrossEntropy | Computes the cross entropy loss, splitting the vocab size across tensor parallel ranks. Used by both the fused and unfused cross entropy implementations. |
Functions#
| vocab_parallel_cross_entropy | Performs cross entropy loss when logits are split across tensor parallel ranks. |
API#
- class core.tensor_parallel.cross_entropy.VocabParallelCrossEntropy#
Computes the cross entropy loss, splitting the vocab size across tensor parallel ranks. This implementation is used in both the fused and unfused cross entropy implementations.
- static calculate_logits_max(
- vocab_parallel_logits: torch.Tensor,
- )#
Calculates logits_max over the local vocab shard.
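A minimal sketch of this step (not the library's exact code): the max is taken over the local vocab shard only, and the caller is expected to all-reduce it with MAX across the tensor parallel group so every rank holds the global max before it is subtracted for numerical stability.

```python
import torch

def calculate_logits_max_sketch(vocab_parallel_logits: torch.Tensor):
    # Work in fp32 and take the max over the local vocab shard (last dim).
    vocab_parallel_logits = vocab_parallel_logits.float()
    logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
    # The caller all-reduces logits_max with ReduceOp.MAX across the tensor
    # parallel group before subtracting it from the logits.
    return vocab_parallel_logits, logits_max
```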
- static calculate_predicted_logits(
- vocab_parallel_logits: torch.Tensor,
- target: torch.Tensor,
- logits_max: torch.Tensor,
- vocab_start_index: int,
- vocab_end_index: int,
- )#
Calculates the predicted logits, masking target ids that fall outside this rank's vocab shard.
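A minimal sketch, assuming logits_max has already been all-reduced across ranks. Target ids outside this rank's [vocab_start_index, vocab_end_index) shard are masked, so the rank contributes zero for tokens it does not own and a later SUM all-reduce can combine the per-rank partial results.

```python
import torch

def calculate_predicted_logits_sketch(vocab_parallel_logits, target, logits_max,
                                      vocab_start_index, vocab_end_index):
    # Subtract the global max for numerical stability.
    vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)

    # Mask target ids this rank does not own; shift the rest to local indices.
    target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
    masked_target = target.clone() - vocab_start_index
    masked_target[target_mask] = 0

    # Gather the logit at each (position, target) pair from a 2-D view.
    logits_2d = vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1))
    masked_target_1d = masked_target.view(-1)
    arange_1d = torch.arange(logits_2d.size(0), device=logits_2d.device)
    predicted_logits = logits_2d[arange_1d, masked_target_1d].view_as(target)
    predicted_logits[target_mask] = 0.0

    # Local pieces of the softmax numerator and denominator.
    exp_logits = vocab_parallel_logits.exp()
    sum_exp_logits = exp_logits.sum(dim=-1)
    return target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits
```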
- static calculate_cross_entropy_loss(
- exp_logits: torch.Tensor,
- predicted_logits: torch.Tensor,
- sum_exp_logits: torch.Tensor,
- )#
Calculates the cross entropy loss.
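A minimal sketch, assuming predicted_logits and sum_exp_logits have already been summed across tensor parallel ranks. This is the standard identity loss = log(sum(exp(logits))) - logit[target], with the softmax recovered in place for reuse in the backward pass.

```python
import torch

def calculate_cross_entropy_loss_sketch(exp_logits, predicted_logits, sum_exp_logits):
    # Per-token loss from the log-sum-exp identity.
    loss = torch.log(sum_exp_logits) - predicted_logits
    # Normalize exp_logits into the softmax, kept for the backward pass.
    softmax = exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
    return softmax, loss
```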
- static prepare_gradient_calculation_operands(
- softmax: torch.Tensor,
- target_mask: torch.Tensor,
- )#
Prepares the operands needed for the gradient calculation.
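A minimal sketch: the softmax is viewed as a 2-D gradient buffer, and softmax_update is 1 only where this rank owns the target id, so only the owning rank will subtract the one-hot term.

```python
import torch

def prepare_gradient_calculation_operands_sketch(softmax, target_mask):
    # The gradient starts out as the softmax itself.
    grad_input = softmax
    grad_2d = grad_input.view(-1, softmax.size(-1))
    arange_1d = torch.arange(grad_2d.size(0), device=grad_2d.device)
    # 1.0 where the target id lives on this rank's shard, 0.0 elsewhere.
    softmax_update = 1.0 - target_mask.view(-1).float()
    return grad_2d, arange_1d, softmax_update, grad_input
```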
- static calculate_gradients(
- grad_2d: torch.Tensor,
- arange_1d: torch.Tensor,
- masked_target_1d: torch.Tensor,
- softmax_update: torch.Tensor,
- grad_input: torch.Tensor,
- grad_output: torch.Tensor,
- )#
Calculates the gradients with respect to the logits.
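A minimal sketch: the gradient of cross entropy with respect to the logits is softmax - one_hot(target), scaled by the incoming grad_output. Because grad_2d is a view of grad_input, the indexed subtraction updates grad_input in place.

```python
import torch

def calculate_gradients_sketch(grad_2d, arange_1d, masked_target_1d,
                               softmax_update, grad_input, grad_output):
    # Subtract the one-hot term at the target positions this rank owns.
    grad_2d[arange_1d, masked_target_1d] -= softmax_update
    # Scale elementwise by the upstream gradient of the loss.
    grad_input.mul_(grad_output.unsqueeze(dim=-1))
    return grad_input
```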
- class core.tensor_parallel.cross_entropy._VocabParallelCrossEntropy#
Bases: torch.autograd.Function

- static forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0)#
Vocab parallel cross entropy forward function.
- static backward(ctx, grad_output)#
Vocab parallel cross entropy backward function.
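A conceptual sketch of how the forward pass composes the helper sketches above; `tp_group` is an assumed handle to the tensor parallel process group, and the real autograd.Function additionally saves tensors on ctx for backward and handles label smoothing.

```python
import torch
import torch.distributed as dist

def forward_sketch(vocab_parallel_logits, target, vocab_start_index,
                   vocab_end_index, tp_group):
    logits, logits_max = calculate_logits_max_sketch(vocab_parallel_logits)
    # Global max across ranks for numerical stability.
    dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=tp_group)

    (target_mask, masked_target_1d, predicted_logits,
     sum_exp_logits, exp_logits) = calculate_predicted_logits_sketch(
        logits, target, logits_max, vocab_start_index, vocab_end_index)

    # Each rank holds only partial results; SUM across ranks completes them.
    dist.all_reduce(predicted_logits, op=dist.ReduceOp.SUM, group=tp_group)
    dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=tp_group)

    softmax, loss = calculate_cross_entropy_loss_sketch(
        exp_logits, predicted_logits, sum_exp_logits)
    # The backward pass then computes (softmax - one_hot(target)) * grad_output
    # on the local shard, using target_mask and masked_target_1d.
    return loss
```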
- core.tensor_parallel.cross_entropy.vocab_parallel_cross_entropy(
- vocab_parallel_logits,
- target,
- label_smoothing=0.0,
- )#
Performs cross entropy loss when logits are split across tensor parallel ranks.
- Parameters:
vocab_parallel_logits – logits split across tensor parallel ranks; dimension is [sequence_length, batch_size, vocab_size/num_parallel_ranks]
target – correct vocab ids of dimension [sequence_length, micro_batch_size]
label_smoothing – smoothing factor, must be in range [0.0, 1.0); default is no smoothing (0.0)
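A hypothetical usage sketch; sizes and the 8-rank tensor parallel layout are illustrative, and it assumes torch.distributed plus the tensor parallel groups are already initialized.

```python
import torch
from core.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy

# Each rank holds a [sequence_length, micro_batch_size, vocab/8] logit shard.
seq_len, micro_batch, vocab_per_rank = 2048, 4, 8192
vocab_parallel_logits = torch.randn(
    seq_len, micro_batch, vocab_per_rank, device='cuda', requires_grad=True)
target = torch.randint(0, vocab_per_rank * 8, (seq_len, micro_batch), device='cuda')

# Returns per-token losses of shape [sequence_length, micro_batch_size].
loss = vocab_parallel_cross_entropy(vocab_parallel_logits, target)
loss.mean().backward()
```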