core.fusions.fused_cross_entropy#
Module Contents#
Classes#

- _VocabParallelCrossEntropy

Functions#

- calculate_logits_max: Calculates the maximum logits of the predicted tokens.
- calculate_predicted_logits: Calculates the predicted logits for the tokens.
- calculate_cross_entropy_loss: Calculates the final cross entropy loss for the tokens.
- calculate_gradients: Calculates the logits gradients, scaled based on the CE loss.
- fused_vocab_parallel_cross_entropy: Performs cross entropy loss when logits are split across tensor parallel ranks.
API#
- core.fusions.fused_cross_entropy.calculate_logits_max(vocab_parallel_logits: torch.Tensor)#

Calculates the maximum logits of the predicted tokens.
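In the vocab-parallel setting each rank holds only a shard of the vocabulary dimension, so the real function takes a local max and then all-reduces it with MAX across the tensor-parallel group. The helper below is a hypothetical single-rank sketch (hence the `_sketch` suffix); the all-reduce step is noted but omitted.

```python
import torch

def calculate_logits_max_sketch(vocab_parallel_logits: torch.Tensor) -> torch.Tensor:
    # Max over the (local) vocabulary dimension; in the fused path this local
    # max is then all-reduced with MAX over the tensor-parallel group.
    logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
    return logits_max

logits = torch.tensor([[[1.0, 3.0, 2.0]]])  # [seq_len=1, batch=1, vocab_shard=3]
print(calculate_logits_max_sketch(logits))  # tensor([[3.]])
```

Subtracting this max before exponentiation keeps the subsequent softmax numerically stable.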
- core.fusions.fused_cross_entropy.calculate_predicted_logits(vocab_parallel_logits: torch.Tensor, target: torch.Tensor, logits_max: torch.Tensor, vocab_start_index: int, vocab_end_index: int)#

Calculates the predicted logits for the tokens.
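The key idea is that each rank only owns vocab ids in `[vocab_start_index, vocab_end_index)`: targets outside that range are masked to contribute zero, so a later all-reduce SUM across ranks recovers the true predicted logit. A hypothetical single-rank sketch of that masking and gather (the all-reduce itself is not shown):

```python
import torch

def calculate_predicted_logits_sketch(vocab_parallel_logits, target, logits_max,
                                      vocab_start_index, vocab_end_index):
    # Shift by the (globally reduced) max for numerical stability.
    logits = vocab_parallel_logits - logits_max.unsqueeze(-1)
    # Targets outside this rank's vocab shard are masked; their predicted
    # logit is set to 0 so an all-reduce SUM recovers the true value.
    target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
    masked_target = (target - vocab_start_index).clone()
    masked_target[target_mask] = 0
    predicted_logits = torch.gather(logits, -1, masked_target.unsqueeze(-1)).squeeze(-1)
    predicted_logits[target_mask] = 0.0
    exp_logits = torch.exp(logits)
    sum_exp_logits = exp_logits.sum(dim=-1)
    return predicted_logits, exp_logits, sum_exp_logits

logits = torch.tensor([[[1.0, 3.0, 2.0]]])   # [seq=1, batch=1, vocab_shard=3]
target = torch.tensor([[1]])                  # this rank owns ids [0, 3)
logits_max = torch.tensor([[3.0]])
pred, exp_l, sum_exp = calculate_predicted_logits_sketch(logits, target, logits_max, 0, 3)
```

Here the target's shifted logit is `3.0 - 3.0 = 0.0`, and `sum_exp` is `e^-2 + e^0 + e^-1 ≈ 1.5032`.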
- core.fusions.fused_cross_entropy.calculate_cross_entropy_loss(exp_logits: torch.Tensor, predicted_logits_sum_exp_logits: torch.Tensor)#

Calculates the final cross entropy loss for the tokens.
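On max-shifted logits the loss reduces to `log(sum(exp(logits))) - logit_of_target`. Note the real signature packs the predicted logits and the exp-sum into one `predicted_logits_sum_exp_logits` tensor (so a single all-reduce covers both); the sketch below keeps them as separate arguments for clarity, which is an assumption about the packing, not the actual layout.

```python
import torch

def calculate_cross_entropy_loss_sketch(exp_logits, predicted_logits, sum_exp_logits):
    # CE = log(sum(exp(shifted_logits))) - shifted_logit_of_target.
    loss = torch.log(sum_exp_logits) - predicted_logits
    # Normalize exp_logits into the softmax; backward() reuses it.
    softmax = exp_logits / sum_exp_logits.unsqueeze(-1)
    return loss, softmax

shifted = torch.tensor([[[-2.0, 0.0, -1.0]]])  # logits minus their max
exp_logits = shifted.exp()
sum_exp = exp_logits.sum(dim=-1)               # ~1.5032
predicted = shifted[..., 1]                    # target id is 1
loss, softmax = calculate_cross_entropy_loss_sketch(exp_logits, predicted, sum_exp)
```

For these numbers the loss is `log(1.5032) - 0 ≈ 0.4076`, matching ordinary cross entropy on the unshifted logits.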
- core.fusions.fused_cross_entropy.calculate_gradients(softmax: torch.Tensor, grad_output: torch.Tensor, target_mask: torch.Tensor, masked_target_1d: torch.Tensor)#

Calculates the logits gradients, scaled based on the CE loss.
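The gradient of cross entropy with respect to the logits is `softmax - one_hot(target)`; only the rank whose shard contains the target id (where `target_mask` is False) subtracts the 1. A hypothetical single-rank sketch:

```python
import torch

def calculate_gradients_sketch(softmax, grad_output, target_mask, masked_target_1d):
    # dCE/dlogits = softmax - one_hot(target); the "- 1" applies only where
    # this rank owns the target id (target_mask == False).
    grad_2d = softmax.view(-1, softmax.size(-1))
    rows = torch.arange(grad_2d.size(0))
    grad_2d[rows, masked_target_1d] -= (~target_mask).view(-1).float()
    # Chain rule: scale each token's row by the incoming loss gradient.
    grad_2d *= grad_output.view(-1, 1)
    return grad_2d.view_as(softmax)

softmax = torch.tensor([[[-2.0, 0.0, -1.0]]]).exp()
softmax /= softmax.sum(dim=-1, keepdim=True)       # [0.0900, 0.6652, 0.2447]
grad = calculate_gradients_sketch(softmax,
                                  torch.ones(1, 1),        # upstream gradient
                                  torch.tensor([[False]]), # target owned here
                                  torch.tensor([1]))       # local target id
```

Each row of the gradient sums to zero, as it must for a softmax-based loss.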
- class core.fusions.fused_cross_entropy._VocabParallelCrossEntropy#
Bases: torch.autograd.Function

- static forward(ctx, vocab_parallel_logits, target, tp_group)#
Forward implementation for the cross entropy loss.
- static backward(ctx, grad_output)#
Backward implementation for the cross entropy loss.
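To show how the pieces above compose, here is a hypothetical single-process analogue of the autograd.Function: the same max-shift, gather, and log-sum-exp steps, with the tensor-parallel all-reduces omitted. It is a sketch for illustration, not the actual distributed implementation.

```python
import torch
import torch.nn.functional as F

class _SingleRankCrossEntropy(torch.autograd.Function):
    # Single-rank sketch: max-shift -> gather target logit -> log-sum-exp.

    @staticmethod
    def forward(ctx, logits, target):
        logits_max = logits.max(dim=-1, keepdim=True)[0]
        shifted = logits - logits_max
        exp_logits = shifted.exp()
        sum_exp = exp_logits.sum(dim=-1)
        predicted = shifted.gather(-1, target.unsqueeze(-1)).squeeze(-1)
        # Save the softmax and target for the backward pass.
        ctx.save_for_backward(exp_logits / sum_exp.unsqueeze(-1), target)
        return sum_exp.log() - predicted

    @staticmethod
    def backward(ctx, grad_output):
        softmax, target = ctx.saved_tensors
        grad = softmax.clone()
        # softmax - one_hot(target), then scale by the incoming gradient.
        grad.scatter_(-1, target.unsqueeze(-1),
                      grad.gather(-1, target.unsqueeze(-1)) - 1.0)
        return grad * grad_output.unsqueeze(-1), None

logits = torch.randn(4, 2, 8)                 # [seq_len, batch, vocab]
target = torch.randint(0, 8, (4, 2))
loss = _SingleRankCrossEntropy.apply(logits, target)
reference = F.cross_entropy(logits.view(-1, 8), target.view(-1), reduction="none")
```

The per-token loss matches `F.cross_entropy` with `reduction="none"` up to floating-point error.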
- core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy(vocab_parallel_logits, target, tp_group)#

Performs cross entropy loss when logits are split across tensor parallel ranks.
- Parameters:
vocab_parallel_logits – logits split across tensor parallel ranks; dimension is [sequence_length, batch_size, hidden_size]
target – correct vocab ids of dimension [sequence_length, micro_batch_size]
tp_group – the tensor parallel group over which to all-reduce
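Calling the fused function requires an initialized tensor-parallel process group, so a runnable example here uses a single-rank stand-in to illustrate the expected shapes: the returned per-token loss has shape [sequence_length, micro_batch_size], matching `target`. With a `tp_group` of size 1 the full vocabulary lives on one rank and the fused result agrees with ordinary cross entropy.

```python
import torch
import torch.nn.functional as F

seq_len, micro_batch, vocab = 5, 2, 16
logits = torch.randn(seq_len, micro_batch, vocab)   # full vocab on one rank
target = torch.randint(0, vocab, (seq_len, micro_batch))

# Unfused single-rank equivalent; the fused version returns the same
# per-token loss of shape [sequence_length, micro_batch_size].
loss = F.cross_entropy(logits.view(-1, vocab), target.view(-1),
                       reduction="none").view(seq_len, micro_batch)
print(loss.shape)  # torch.Size([5, 2])
```

Callers typically average this per-token loss with a loss mask rather than reducing it inside the function.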