core.fusions.fused_cross_entropy#

Module Contents#

Classes#

Functions#

calculate_logits_max

Calculates the per-token maximum over the local vocab-parallel logits.

calculate_predicted_logits

Calculates the predicted logits for the target tokens, masking out targets that fall outside this rank's vocabulary partition.

calculate_cross_entropy_loss

Calculates the final cross entropy loss for the tokens.

calculate_gradients

Calculates the logits gradients, scaled by the incoming gradient of the cross-entropy loss.

fused_vocab_parallel_cross_entropy

Performs the cross-entropy loss computation when logits are split across tensor parallel ranks.

API#

core.fusions.fused_cross_entropy.calculate_logits_max(
vocab_parallel_logits: torch.Tensor,
) Tuple[torch.Tensor, torch.Tensor]#

Calculates the per-token maximum over the local vocab-parallel logits.
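The max-reduction step can be sketched in a single process by simulating the vocabulary shards (shard count, shapes, and the simulated all-reduce are assumptions for illustration): each tensor-parallel rank takes the max of its local shard, and an all-reduce with MAX then yields the true per-token maximum used for numerical stability.

```python
import torch

# Hypothetical single-process sketch: each "rank" holds one vocab shard.
shards = [torch.randn(4, 2, 8) for _ in range(2)]      # [seq, batch, local_vocab] per rank

local_max = [s.max(dim=-1).values for s in shards]     # per-rank reduction
logits_max = torch.maximum(local_max[0], local_max[1]) # simulated all-reduce(MAX)

# Check against the full logits tensor, which no single rank ever holds.
full_logits = torch.cat(shards, dim=-1)
assert torch.equal(logits_max, full_logits.max(dim=-1).values)
```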

core.fusions.fused_cross_entropy.calculate_predicted_logits(
vocab_parallel_logits: torch.Tensor,
target: torch.Tensor,
logits_max: torch.Tensor,
vocab_start_index: int,
vocab_end_index: int,
) Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]#

Calculates the predicted logits for the target tokens, masking out targets that fall outside this rank's vocabulary partition.
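The masking logic can be sketched on a single rank as follows (the variable names and the flattened shapes are assumptions for illustration): targets outside the rank's slice `[vocab_start_index, vocab_end_index)` are masked and contribute zero until the subsequent all-reduce sums in the owning rank's value.

```python
import torch

# Hypothetical single-rank sketch of the target masking.
logits = torch.randn(5, 7)               # [tokens, local_vocab], max already subtracted
target = torch.tensor([0, 3, 6, 9, 12])  # global vocab ids; 9 and 12 live on another rank
vocab_start_index, vocab_end_index = 0, 7

target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
masked_target = target - vocab_start_index
masked_target[target_mask] = 0           # any in-range index works; result is zeroed below

predicted_logits = logits[torch.arange(5), masked_target].clone()
predicted_logits[target_mask] = 0.0      # the owning rank supplies these via all-reduce(SUM)
```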

core.fusions.fused_cross_entropy.calculate_cross_entropy_loss(
exp_logits: torch.Tensor,
predicted_logits_sum_exp_logits: torch.Tensor,
) Tuple[torch.Tensor, torch.Tensor]#

Calculates the final cross entropy loss for the tokens.
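Assuming the predicted logits and the exp-logit sums have already been all-reduced, the final loss reduces to `log(sum_exp_logits) - predicted_logit`, i.e. the negative log-softmax at the target index. A single-rank sketch, checked against `torch.nn.functional.cross_entropy`:

```python
import torch

logits = torch.randn(4, 10)
target = torch.tensor([1, 4, 7, 9])

# Numerically stable pieces, as the earlier helpers would produce them.
logits_max = logits.max(dim=-1, keepdim=True).values
shifted = logits - logits_max
exp_logits = shifted.exp()
sum_exp_logits = exp_logits.sum(dim=-1)
predicted_logits = shifted[torch.arange(4), target]

# loss = log(sum exp) - predicted logit == -log_softmax[target]
loss = torch.log(sum_exp_logits) - predicted_logits
ref = torch.nn.functional.cross_entropy(logits, target, reduction="none")
assert torch.allclose(loss, ref, atol=1e-5)
```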

core.fusions.fused_cross_entropy.calculate_gradients(
softmax: torch.Tensor,
grad_output: torch.Tensor,
target_mask: torch.Tensor,
masked_target_1d: torch.Tensor,
) torch.Tensor#

Calculates the logits gradients, scaled by the incoming gradient of the cross-entropy loss.
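The backward rule this helper applies can be sketched as `d(loss)/d(logits) = softmax - one_hot(target)`, scaled by the upstream `grad_output` (on a real shard, out-of-range targets would additionally be masked). Verified here against autograd:

```python
import torch

logits = torch.randn(3, 6, requires_grad=True)
target = torch.tensor([2, 0, 5])
grad_output = torch.ones(3)

softmax = torch.softmax(logits, dim=-1)
grad = softmax.clone()
grad[torch.arange(3), target] -= 1.0  # subtract 1 at the target index
grad *= grad_output.unsqueeze(-1)     # scale by the upstream gradient

# Reference gradient from autograd.
loss = torch.nn.functional.cross_entropy(logits, target, reduction="sum")
loss.backward()
assert torch.allclose(grad, logits.grad, atol=1e-5)
```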

class core.fusions.fused_cross_entropy._VocabParallelCrossEntropy#

Bases: torch.autograd.Function

static forward(ctx, vocab_parallel_logits, target, tp_group)#

Forward implementation for the cross entropy loss.

static backward(ctx, grad_output)#

Backward implementation for the cross entropy loss.
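As a rough illustration of the `torch.autograd.Function` pattern these two methods follow (the toy loss below is an assumption for illustration, not the module's code): `forward` stashes what `backward` will need on `ctx`, and `backward` returns one gradient per `forward` input.

```python
import torch

class _SquareLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)       # stash tensors needed in backward
        return (x * x).sum()

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2.0 * x * grad_output   # chain rule: d(sum x^2)/dx = 2x

x = torch.randn(4, requires_grad=True)
_SquareLoss.apply(x).backward()
assert torch.allclose(x.grad, 2.0 * x)
```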

core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy(
vocab_parallel_logits,
target,
tp_group,
)#

Performs the cross-entropy loss computation when logits are split across tensor parallel ranks.

Parameters:
  • vocab_parallel_logits – logits split along the vocabulary dimension across tensor parallel ranks; dimension is [sequence_length, micro_batch_size, vocab_size / tp_world_size]

  • target – correct vocab ids of dimension [sequence_length, micro_batch_size]

  • tp_group – the tensor parallel group over which to all reduce
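Putting the pieces together, the assumed algorithm (local max followed by a MAX all-reduce; local exp-sums and predicted logits followed by a SUM all-reduce; then `loss = log(sum_exp) - predicted`) can be simulated in one process and checked against unsharded cross entropy. The loop below stands in for the tensor-parallel collectives; a real call needs an initialized `tp_group`.

```python
import torch

seq, batch, vocab, tp = 3, 2, 8, 2
logits = torch.randn(seq, batch, vocab)
target = torch.randint(0, vocab, (seq, batch))
shards = logits.chunk(tp, dim=-1)        # one vocab shard per simulated rank
shard_size = vocab // tp

# Step 1: per-shard max, then simulated all-reduce(MAX).
logits_max = torch.stack([s.max(dim=-1).values for s in shards]).max(dim=0).values

# Step 2: accumulate exp-sums and predicted logits, simulating all-reduce(SUM).
sum_exp_logits = torch.zeros(seq, batch)
predicted_logits = torch.zeros(seq, batch)
for rank, shard in enumerate(shards):
    start, end = rank * shard_size, (rank + 1) * shard_size
    shifted = shard - logits_max.unsqueeze(-1)
    sum_exp_logits += shifted.exp().sum(dim=-1)
    mask = (target >= start) & (target < end)              # targets this rank owns
    local = (target - start).clamp(0, shard_size - 1)
    predicted_logits += shifted.gather(-1, local.unsqueeze(-1)).squeeze(-1) * mask

# Step 3: final loss, checked against the unsharded reference.
loss = torch.log(sum_exp_logits) - predicted_logits
ref = torch.nn.functional.cross_entropy(
    logits.view(-1, vocab), target.view(-1), reduction="none"
).view(seq, batch)
assert torch.allclose(loss, ref, atol=1e-5)
```

The fused version exists so that these three steps run with a fixed, small number of collectives instead of materializing the full softmax on any one rank.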