Megatron Core User Guide

tensor_parallel package

This package contains an implementation of tensor parallelism for transformer models (see "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism" and "Reducing Activation Recomputation in Large Transformer Models" for details).

class core.tensor_parallel.cross_entropy.VocabParallelCrossEntropy

Bases: object

Computes the cross entropy loss with the vocabulary dimension split across tensor parallel ranks. This class is shared by both the fused and the unfused cross entropy implementations.
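The reason the vocabulary can be split is that cross entropy only needs three per-token scalars from the full vocabulary: a maximum and two sums. The following is a single-process sketch in plain Python, not Megatron Core's code. Each inner list stands in for one rank's vocabulary shard, and ordinary `max`/`sum` calls stand in for all-reduce MAX and SUM. The function names are hypothetical. The sketch only checks that the sharded computation reproduces the unsharded loss.

```python
import math
from typing import List


def reference_cross_entropy(logits: List[float], target: int) -> float:
    """Unsharded cross entropy for one token: logsumexp(logits) - logits[target]."""
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits)) - logits[target]


def sharded_cross_entropy(shards: List[List[float]], target: int) -> float:
    """Cross entropy for one token whose logits are split into contiguous vocab shards.

    Each inner list plays the role of one tensor-parallel rank. Ranks only exchange
    per-token scalars: one max and two sums (stand-ins for all-reduce MAX and SUM).
    """
    # Rank r owns the vocab range [starts[r], starts[r] + len(shards[r])).
    starts, offset = [], 0
    for shard in shards:
        starts.append(offset)
        offset += len(shard)

    # 1. Global max for numerical stability (all-reduce MAX of the local maxima).
    logits_max = max(max(shard) for shard in shards)

    # 2. Only the rank that owns `target` contributes a nonzero logit (all-reduce SUM).
    predicted_logit = 0.0
    for start, shard in zip(starts, shards):
        if start <= target < start + len(shard):
            predicted_logit += shard[target - start] - logits_max

    # 3. Sum of exponentials over the whole vocabulary (all-reduce SUM of local sums).
    sum_exp = sum(sum(math.exp(z - logits_max) for z in shard) for shard in shards)

    return math.log(sum_exp) - predicted_logit


logits = [2.0, -1.0, 0.5, 3.0, 1.5, 0.0]
shards = [logits[:3], logits[3:]]  # two simulated ranks, three vocab entries each
for t in range(len(logits)):
    print(t, sharded_cross_entropy(shards, t), reference_cross_entropy(logits, t))
```

Both columns agree for every target. No rank ever needs another rank's full logit shard, which is what keeps the communication cost independent of the vocabulary size.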

static calculate_cross_entropy_loss(exp_logits: torch.Tensor, predicted_logits: torch.Tensor, sum_exp_logits: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]

Calculates cross entropy loss.
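The argument names suggest the usual numerically stable formulation. This is an interpretation, not something the source states. In it, z denotes one token's logits, t its target id, m the global maximum, R the number of ranks, and V_r the slice of the vocabulary owned by rank r:

```latex
m = \max_{j} z_j, \qquad
S = \sum_{r=1}^{R} \sum_{j \in \mathcal{V}_r} e^{z_j - m}, \qquad
\ell = -\log \frac{e^{z_t - m}}{S} = \log S - (z_t - m), \qquad
p_j = \frac{e^{z_j - m}}{S}
```

Under that reading, the terms map onto the arguments as follows:

- `exp_logits` holds e^{z_j - m} for j in the local shard.
- `sum_exp_logits` holds S after the all-reduce.
- `predicted_logits` holds z_t - m after the all-reduce.

The softmax p_j is the quantity the backward pass reuses.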

static calculate_gradients(grad_2d: torch.Tensor, arange_1d: torch.Tensor, masked_target_1d: torch.Tensor, softmax_update: torch.Tensor, grad_input: torch.Tensor, grad_output: torch.Tensor) → torch.Tensor

Calculates gradients.
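The cross entropy gradient with respect to a logit is p_j minus 1 when j is the target, and plain p_j otherwise. With a sharded vocabulary, only the rank that owns the target subtracts the 1. This plain-Python sketch shows that step for one rank, with one assumption: `softmax_update` is 1.0 on the owning rank and 0.0 elsewhere. `softmax` and `local_gradients` are hypothetical names, and the loops stand in for in-place tensor indexing.

```python
import math
from typing import List


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(row)
    exps = [math.exp(z - m) for z in row]
    total = sum(exps)
    return [e / total for e in exps]


def local_gradients(softmax_rows: List[List[float]], masked_target: List[int],
                    softmax_update: List[float], grad_output: List[float]) -> List[List[float]]:
    """Gradient of the per-token loss w.r.t. one rank's logit shard (hypothetical helper).

    softmax_rows[i][j] : global softmax probability of local vocab entry j for token i
    masked_target[i]   : target index relative to this shard (any valid index if not owned)
    softmax_update[i]  : 1.0 if this shard owns token i's target, else 0.0
    grad_output[i]     : upstream gradient of token i's loss
    """
    grad = [row[:] for row in softmax_rows]  # dloss/dz_j starts as p_j ...
    for i, row in enumerate(grad):
        row[masked_target[i]] -= softmax_update[i]  # ... minus 1 at the owned target
        for j in range(len(row)):
            row[j] *= grad_output[i]  # chain rule with the upstream gradient
    return grad


# Single-rank usage: the whole vocabulary is one shard, so it always owns the target.
p = softmax([1.0, 2.0, 0.5])
print(local_gradients([p], masked_target=[1], softmax_update=[1.0], grad_output=[1.0]))
```

On a rank that does not own the target, `softmax_update` is 0.0. The subtraction is then a no-op, so every rank can run the same code without branching.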

static calculate_logits_max(vocab_parallel_logits: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]

Calculates logits_max.
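Subtracting a per-token maximum before exponentiating keeps `exp()` from overflowing. The maximum has to cover the full vocabulary, so each rank takes a local maximum over its shard. The ranks then combine the local maxima, for example with `torch.distributed.all_reduce(..., op=torch.distributed.ReduceOp.MAX)`. This sketch simulates two ranks in one process. `local_logits_max` and `all_reduce_max` are hypothetical names. Returning the logits together with the local maximum, and leaving the all-reduce to the caller, is an assumption based on the two-tensor return type above.

```python
from typing import List, Tuple


def local_logits_max(shard: List[List[float]]) -> Tuple[List[List[float]], List[float]]:
    """Per-token maximum over this rank's slice of the vocabulary (hypothetical helper)."""
    return shard, [max(row) for row in shard]


def all_reduce_max(per_rank: List[List[float]]) -> List[float]:
    """Single-process stand-in for all_reduce(..., op=ReduceOp.MAX) across ranks."""
    return [max(values) for values in zip(*per_rank)]


# Two simulated ranks; each holds two tokens x two vocab entries.
rank0 = [[1.0, 4.0], [0.5, -2.0]]
rank1 = [[3.0, 2.0], [7.0, 1.0]]
local_maxes = [local_logits_max(shard)[1] for shard in (rank0, rank1)]
global_max = all_reduce_max(local_maxes)  # one value per token, over the full vocabulary

# Every rank subtracts the same global max, so all shifted logits are <= 0.
shifted = [[[z - m for z in row] for row, m in zip(shard, global_max)] for shard in (rank0, rank1)]
```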

static calculate_predicted_logits(vocab_parallel_logits: torch.Tensor, target: torch.Tensor, logits_max: torch.Tensor, vocab_start_index: int, vocab_end_index: int) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

Calculates predicted logits.
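Each rank owns the contiguous vocabulary range [vocab_start_index, vocab_end_index). Targets outside that range are masked in two ways:

- Their local index is replaced by a dummy 0, which keeps the lookup in bounds.
- Their predicted logit is zeroed, so a later all-reduce SUM keeps only the owning rank's value.

The plain-Python sketch below covers one rank and omits the all-reduce. It returns the five values in the signature above, but the order shown is an assumption, and `local_predicted_logits` is a hypothetical name.

```python
import math
from typing import List, Tuple


def local_predicted_logits(shard: List[List[float]], target: List[int], logits_max: List[float],
                           vocab_start_index: int, vocab_end_index: int
                           ) -> Tuple[List[bool], List[int], List[float], List[float], List[List[float]]]:
    """Returns (target_mask, masked_target, predicted_logits, sum_exp_logits, exp_logits).

    predicted_logits and sum_exp_logits are partial per rank; they still need an all-reduce SUM.
    """
    target_mask, masked_target, predicted, sum_exp, exp_logits = [], [], [], [], []
    for row, t, m in zip(shard, target, logits_max):
        shifted = [z - m for z in row]  # subtract the global max for stability
        outside = not (vocab_start_index <= t < vocab_end_index)
        local_t = 0 if outside else t - vocab_start_index  # dummy in-range index when masked
        target_mask.append(outside)
        masked_target.append(local_t)
        predicted.append(0.0 if outside else shifted[local_t])  # zero on non-owning ranks
        exps = [math.exp(z) for z in shifted]
        exp_logits.append(exps)
        sum_exp.append(sum(exps))
    return target_mask, masked_target, predicted, sum_exp, exp_logits


# Vocabulary of 4 split over two ranks; one token with target id 3 and global max 4.0.
r0 = local_predicted_logits([[1.0, 4.0]], [3], [4.0], vocab_start_index=0, vocab_end_index=2)
r1 = local_predicted_logits([[3.0, 2.0]], [3], [4.0], vocab_start_index=2, vocab_end_index=4)
predicted = r0[2][0] + r1[2][0]  # simulated all-reduce SUM: only rank 1 contributes
loss = math.log(r0[3][0] + r1[3][0]) - predicted
```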

static prepare_gradient_calculation_operands(softmax: torch.Tensor, target_mask: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

Prepares the gradient calculation operands.
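The gradient with respect to a logit is the softmax probability minus 1 at the target. The backward pass can therefore start from the softmax itself. Per token, it only needs to know whether this rank owns the target.

The following is a hedged plain-Python sketch of that preparation. It flattens [sequence_length, batch_size, vocab_shard] into one row per token and turns the target mask into a 0/1 update factor. `prepare_operands` is a hypothetical name, and the sketch covers only three of the four return values listed in the signature. Unlike a PyTorch view, the list copy here does not share memory with its input.

```python
from typing import List, Tuple


def prepare_operands(softmax_3d: List[List[List[float]]],
                     target_mask_2d: List[List[bool]]) -> Tuple[List[List[float]], List[int], List[float]]:
    """Hypothetical sketch: flatten [seq, batch, vocab_shard] into one row per token."""
    grad_2d = [row[:] for step in softmax_3d for row in step]  # gradient starts as the softmax
    arange_1d = list(range(len(grad_2d)))  # row index of each token in grad_2d
    # 1.0 where this rank owns the target (subtract 1 there later), 0.0 where it is masked.
    softmax_update = [0.0 if masked else 1.0 for step in target_mask_2d for masked in step]
    return grad_2d, arange_1d, softmax_update


# sequence_length=2, batch_size=2, two local vocab entries per token.
softmax_3d = [[[0.1, 0.9], [0.5, 0.5]], [[0.3, 0.7], [0.2, 0.8]]]
target_mask_2d = [[True, False], [False, True]]
grad_2d, arange_1d, softmax_update = prepare_operands(softmax_3d, target_mask_2d)
```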

core.tensor_parallel.cross_entropy.vocab_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0)

Computes the cross entropy loss when the logits are split across tensor parallel ranks.

  • vocab_parallel_logits – logits split across tensor parallel ranks, of dimension [sequence_length, batch_size, vocab_size/num_parallel_ranks]

  • target – correct vocab IDs, of dimension [sequence_length, micro_batch_size]

  • label_smoothing – smoothing factor; must be in the range [0.0, 1.0). The default, 0.0, applies no smoothing.

, data, datatype)

Broadcasts data from rank zero of each model parallel group to the other members of that group.

  • keys – list of keys in the data dictionary to broadcast

  • data – data dictionary of string keys and CPU tensor values.

  • datatype – torch data type of all tensors in data associated with keys.
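The sketch below rests on an assumption about how such a broadcast is commonly built:

1. Rank zero shares the tensor shapes first.
2. Rank zero concatenates all tensors for `keys` into one flat buffer of `datatype`, so a single collective (for example `torch.distributed.broadcast`) suffices.
3. Receivers slice the buffer back apart using the shapes.

Tensors are modeled as (shape, flat values) pairs, and the collectives are replaced by ordinary return values. `pack`, `unpack`, and the example keys are hypothetical, not Megatron Core functions or fields.

```python
from math import prod
from typing import Dict, List, Tuple

# Stand-in for a CPU tensor: (shape, values in row-major order).
Tensor = Tuple[Tuple[int, ...], List[int]]


def pack(keys: List[str], data: Dict[str, Tensor]) -> Tuple[Dict[str, Tuple[int, ...]], List[int]]:
    """Rank zero: record each tensor's shape, then concatenate all values into one buffer."""
    sizes = {key: data[key][0] for key in keys}
    flat = [value for key in keys for value in data[key][1]]
    return sizes, flat  # in a real run, both would be broadcast to the group


def unpack(keys: List[str], sizes: Dict[str, Tuple[int, ...]], flat: List[int]) -> Dict[str, Tensor]:
    """Every rank: slice the received buffer back into tensors, in the same key order."""
    output, offset = {}, 0
    for key in keys:
        numel = prod(sizes[key])
        output[key] = (sizes[key], flat[offset:offset + numel])
        offset += numel
    return output


keys = ["text", "labels"]
data = {"text": ((2, 3), [11, 12, 13, 14, 15, 16]), "labels": ((2,), [1, 0])}
sizes, flat = pack(keys, data)
received = unpack(keys, sizes, flat)  # what a non-zero rank would reconstruct
```

Packing everything into one buffer is only possible because all tensors share a single element type. That is why `datatype` applies to every key at once.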

