core.tensor_parallel.cross_entropy

Module Contents

Classes

VocabParallelCrossEntropy

Computes the cross entropy loss with the vocab size split across tensor parallel ranks. These routines are shared by both the fused and unfused cross entropy implementations.

_VocabParallelCrossEntropy

Functions

vocab_parallel_cross_entropy

Computes the cross entropy loss when the logits are split across tensor parallel ranks.

API

class core.tensor_parallel.cross_entropy.VocabParallelCrossEntropy

Computes the cross entropy loss with the vocab size split across tensor parallel ranks. These routines are shared by both the fused and unfused cross entropy implementations.

static calculate_logits_max(
vocab_parallel_logits: torch.Tensor,
) → Tuple[torch.Tensor, torch.Tensor]

Calculates logits_max.
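
The per-token maximum over the full vocabulary is what makes the later exponentiation numerically stable; since each rank holds only a shard of the vocab dimension, a local max must be combined with a MAX all-reduce across the tensor parallel group. A minimal sketch of the idea, assuming an already-initialized process group `tp_group` (the placement of the all-reduce relative to this method is an assumption, not this module's exact code):

```python
import torch
import torch.distributed as dist

def logits_max_sketch(vocab_parallel_logits: torch.Tensor, tp_group) -> torch.Tensor:
    # Local max over this rank's vocab shard: [s, b, v/p] -> [s, b].
    logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
    # Combine shard maxima into the global per-token max across ranks.
    dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=tp_group)
    return logits_max
```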

static calculate_predicted_logits(
vocab_parallel_logits: torch.Tensor,
target: torch.Tensor,
logits_max: torch.Tensor,
vocab_start_index: int,
vocab_end_index: int,
) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

Calculates predicted logits.
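
Each rank owns the vocab ids in [vocab_start_index, vocab_end_index); targets outside that range contribute nothing locally, and the per-rank partial results are reconciled by a later SUM all-reduce. A hedged sketch of the shard-local computation (the function name and exact ordering are illustrative, not the module's literal code):

```python
import torch

def predicted_logits_sketch(vocab_parallel_logits, target, logits_max,
                            vocab_start_index, vocab_end_index):
    # Shift by the (already all-reduced) max for numerical stability.
    vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(-1)

    # Targets outside this rank's vocab shard are masked out; their local
    # contribution to the predicted logit is zero.
    target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
    masked_target = target.clone() - vocab_start_index
    masked_target[target_mask] = 0

    # Gather each token's logit at its (shard-local) target index.
    logits_2d = vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1))
    masked_target_1d = masked_target.view(-1)
    arange_1d = torch.arange(logits_2d.size(0), device=logits_2d.device)
    predicted_logits = logits_2d[arange_1d, masked_target_1d].view_as(target)
    predicted_logits[target_mask] = 0.0

    # exp and the local sum over the shard; both are later SUM all-reduced.
    exp_logits = vocab_parallel_logits.exp()
    sum_exp_logits = exp_logits.sum(dim=-1)
    return target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits
```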

static calculate_cross_entropy_loss(
exp_logits: torch.Tensor,
predicted_logits: torch.Tensor,
sum_exp_logits: torch.Tensor,
) → Tuple[torch.Tensor, torch.Tensor]

Calculates cross entropy loss.
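
Once predicted_logits and sum_exp_logits have been SUM all-reduced across ranks, the loss takes the standard log-sum-exp form, and exp_logits can be normalized into the softmax that the backward pass needs. A minimal sketch, assuming the all-reduces have already happened:

```python
import torch

def cross_entropy_loss_sketch(exp_logits, predicted_logits, sum_exp_logits):
    # loss = log(sum_j exp(x_j)) - x_target, i.e. -log_softmax at the target.
    loss = torch.log(sum_exp_logits) - predicted_logits
    # Normalize exp_logits into the softmax, reused in the backward pass.
    exp_logits = exp_logits / sum_exp_logits.unsqueeze(-1)
    return exp_logits, loss
```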

static prepare_gradient_calculation_operands(
softmax: torch.Tensor,
target_mask: torch.Tensor,
) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

Prepares gradient calculation operands.
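
The gradient of cross entropy with respect to the logits is softmax minus a one-hot of the target, so the backward pass starts from the saved softmax and precomputes the pieces needed to subtract that one-hot only on the rank that owns each target id. A sketch of what these operands might look like (names mirror the signature; the body is an assumption):

```python
import torch

def prepare_gradient_operands_sketch(softmax, target_mask):
    # Start from the softmax and prepare a flattened 2D view to index into.
    grad_input = softmax
    grad_2d = grad_input.view(-1, grad_input.size(-1))
    arange_1d = torch.arange(grad_2d.size(0), device=grad_2d.device)
    # 1.0 where the target id lives on this rank, 0.0 where it was masked.
    softmax_update = 1.0 - target_mask.view(-1).float()
    return grad_2d, arange_1d, softmax_update, grad_input
```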

static calculate_gradients(
grad_2d: torch.Tensor,
arange_1d: torch.Tensor,
masked_target_1d: torch.Tensor,
softmax_update: torch.Tensor,
grad_input: torch.Tensor,
grad_output: torch.Tensor,
) → torch.Tensor

Calculates gradients.
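
With the operands above, the update is a scatter-subtract at each token's target column followed by scaling with the incoming gradient. A hedged sketch of that step:

```python
def calculate_gradients_sketch(grad_2d, arange_1d, masked_target_1d,
                               softmax_update, grad_input, grad_output):
    # Subtract 1 at each token's target index, but only on the rank that
    # owns that index (softmax_update is 0 elsewhere): softmax - one_hot.
    grad_2d[arange_1d, masked_target_1d] -= softmax_update
    # Chain rule: scale by the incoming per-token gradient.
    grad_input.mul_(grad_output.unsqueeze(-1))
    return grad_input
```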

class core.tensor_parallel.cross_entropy._VocabParallelCrossEntropy

Bases: torch.autograd.Function

static forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0)

Vocab parallel cross entropy forward function.

static backward(ctx, grad_output)

Vocab parallel cross entropy backward function.
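
Being a torch.autograd.Function, this class is meant to be invoked via apply rather than instantiated; a thin wrapper along the following lines (a sketch, not necessarily the module's literal code) is what exposes it publicly:

```python
def vocab_parallel_cross_entropy_sketch(vocab_parallel_logits, target, label_smoothing=0.0):
    # apply() routes through the forward/backward defined above.
    return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, label_smoothing)
```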

core.tensor_parallel.cross_entropy.vocab_parallel_cross_entropy(
vocab_parallel_logits,
target,
label_smoothing=0.0,
)

Computes the cross entropy loss when the logits are split across tensor parallel ranks.

Parameters:
  • vocab_parallel_logits – logits split across tensor parallel ranks; dimension is [sequence_length, batch_size, vocab_size/num_parallel_ranks]

  • target – correct vocab ids of dimension [sequence_length, micro_batch_size]

  • label_smoothing – smoothing factor; must be in the range [0.0, 1.0). Default is 0.0 (no smoothing)
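
An end-to-end usage sketch, assuming an initialized tensor parallel group of size 2, a global vocab of 8 ids (4 per rank), and the public export at megatron.core.tensor_parallel (shapes, the device, and the import path should be checked against your installed version):

```python
import torch
from megatron.core.tensor_parallel import vocab_parallel_cross_entropy

seq_len, micro_batch, vocab_per_rank = 16, 4, 4

# Each rank holds only its shard of the vocab dimension.
vocab_parallel_logits = torch.randn(seq_len, micro_batch, vocab_per_rank,
                                    device='cuda', requires_grad=True)
# Targets are global vocab ids, identical on every rank.
target = torch.randint(0, 8, (seq_len, micro_batch), device='cuda')

loss = vocab_parallel_cross_entropy(vocab_parallel_logits, target)  # [s, b]
loss.mean().backward()
```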