NVIDIA Megatron-Core
Developer Guide (Latest)

tensor_parallel package

This package contains an implementation of tensor parallelism for transformer models (see Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism and Reducing Activation Recomputation in Large Transformer Models for details).

core.tensor_parallel.cross_entropy.vocab_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0)

Computes the cross-entropy loss when the logits are split across tensor-parallel ranks.

Parameters
  • vocab_parallel_logits – logits split across tensor-parallel ranks, of dimension [sequence_length, batch_size, vocab_size / tensor_model_parallel_size]

  • target – correct vocabulary ids, of dimension [sequence_length, micro_batch_size]

  • label_smoothing – smoothing factor; must be in the range [0.0, 1.0). The default is no smoothing (0.0).
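For intuition, the following NumPy sketch simulates in a single process the three reductions that a vocabulary-parallel cross entropy needs: a max all-reduce over the per-rank maxima, a sum all-reduce of the exponentials, and a sum all-reduce of the target logit, which exactly one rank owns. The function name and its list-of-shards input are hypothetical, the collectives are replaced by a Python loop, and label_smoothing is not modeled; this is an illustration of the technique, not the Megatron-Core implementation.

```python
import numpy as np


def simulated_vocab_parallel_ce(logits_shards, target):
    """Per-token cross entropy from logits split along the vocabulary dimension.

    logits_shards: list of arrays in rank order; shard r has shape [s, b, v_r]
    and holds the logits for vocabulary ids [start_r, start_r + v_r).
    target: integer array of shape [s, b] holding global vocabulary ids.
    """
    # Step 1: max all-reduce of the per-rank maxima (for numerical stability).
    global_max = np.max([shard.max(axis=-1) for shard in logits_shards], axis=0)
    sum_exp = np.zeros_like(global_max)
    target_logit = np.zeros_like(global_max)
    start = 0
    for shard in logits_shards:
        shifted = shard - global_max[..., None]
        # Step 2: sum all-reduce of the per-rank sums of exponentials.
        sum_exp += np.exp(shifted).sum(axis=-1)
        # Step 3: only the rank that owns a target id contributes its logit;
        # every other rank contributes 0 to a sum all-reduce.
        end = start + shard.shape[-1]
        owned = (target >= start) & (target < end)
        local_ids = np.where(owned, target - start, 0)
        picked = np.take_along_axis(shifted, local_ids[..., None], axis=-1)[..., 0]
        target_logit += np.where(owned, picked, 0.0)
        start = end
    # loss = log(sum_j exp(z_j)) - z_target, expressed with the shifted logits.
    return np.log(sum_exp) - target_logit
```

No rank ever materializes the full [sequence_length, batch_size, vocab_size] tensor; only [sequence_length, batch_size] statistics cross rank boundaries.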

core.tensor_parallel.data.broadcast_data(keys, data, datatype)

Broadcast data from rank zero of each model parallel group to the members of the same model parallel group.

Parameters
  • keys – list of keys in the data dictionary to be broadcast

  • data – data dictionary with string keys and CPU tensor values.

  • datatype – torch data type of all tensors in data associated with keys.
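Because all tensors listed in keys share one datatype, a broadcast like this can be organized as two steps: broadcast the shapes, then broadcast a single flattened buffer. The NumPy sketch below shows only the pack and unpack steps on one process, under that assumption; pack_for_broadcast and unpack_after_broadcast are hypothetical helpers, and the collective itself (for example torch.distributed.broadcast) is omitted.

```python
import numpy as np


def pack_for_broadcast(data, keys, dtype):
    """Flatten the tensors under `keys` into one 1-D buffer plus shape metadata."""
    shapes = {k: data[k].shape for k in keys}
    flat = np.concatenate(
        [np.ascontiguousarray(data[k], dtype=dtype).reshape(-1) for k in keys]
    )
    return shapes, flat


def unpack_after_broadcast(shapes, flat, keys):
    """Rebuild the dictionary on a receiving rank from the shared buffer."""
    out, offset = {}, 0
    for k in keys:
        n = int(np.prod(shapes[k]))
        out[k] = flat[offset:offset + n].reshape(shapes[k])
        offset += n
    return out
```

Sending one buffer instead of one tensor per key keeps the number of collectives constant regardless of how many keys are broadcast.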

class core.tensor_parallel.layers.ColumnParallelLinear(*args: Any, **kwargs: Any)

Bases: torch.nn.Module

Linear layer with column parallelism.

The linear layer is defined as Y = XA + b. A is parallelized along its second dimension as A = [A_1, …, A_p].

Parameters
  • input_size – first dimension of matrix A.

  • output_size – second dimension of matrix A.

  • bias – If True, add a bias.

  • gather_output – If True, call all-gather on the output and make Y available to all GPUs; otherwise, each GPU keeps only its own output, Y_i = XA_i.

  • init_method – method to initialize weights. Note that bias is always set to zero.

  • stride – For the strided linear layers.

  • keep_master_weight_for_test – This was added for testing and should be set to False. It returns the master weights used for initialization.

  • skip_bias_add – If True, do not add the bias term; instead, return it so that the caller can add it. This enables performance optimizations where the bias can be fused with other elementwise operations.

  • skip_weight_param_allocation – If True, weight parameter is not allocated and must be passed as a keyword argument weight during the forward pass. Note that this does not affect bias, which will be allocated if bias is True. Defaults to False.

  • is_expert – If True, the layer is treated as an MoE expert layer.

  • config – ModelParallelConfig object

  • tp_comm_buffer_name – Communication buffer name. Not used in non-Transformer-Engine modules.
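The column split can be checked numerically. The NumPy sketch below is a single-process illustration with a hypothetical world size p = 4; it does not reproduce the layer's actual weight layout or code. Each simulated rank holds one column block A_i and the matching bias slice b_i and computes Y_i = XA_i + b_i; concatenating the local outputs along the last dimension, which is what gather_output=True does with an all-gather, reproduces Y = XA + b.

```python
import numpy as np

rng = np.random.default_rng(0)
p = 4                          # hypothetical tensor-parallel world size
X = rng.normal(size=(5, 8))    # [tokens, input_size]
A = rng.normal(size=(8, 12))   # [input_size, output_size]
b = rng.normal(size=12)

# Rank i holds the column block A_i and the matching bias slice b_i.
A_shards = np.split(A, p, axis=1)
b_shards = np.split(b, p)

# Every rank sees the full X and computes only its slice Y_i = X A_i + b_i.
Y_local = [X @ A_i + b_i for A_i, b_i in zip(A_shards, b_shards)]

# gather_output=True corresponds to an all-gather along the last dimension.
Y_gathered = np.concatenate(Y_local, axis=-1)
```

With gather_output=False, each rank would simply keep its Y_local[i], which is the usual input to a following RowParallelLinear.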

forward(input_: torch.Tensor, weight: Optional[torch.Tensor] = None)

Forward of ColumnParallelLinear

Parameters
  • input_ – 3D tensor whose dimensions are ordered [sequence, batch, hidden]

  • weight (optional) – weight tensor to use; required when skip_weight_param_allocation is True.

Returns

  • output

  • bias – the bias tensor if skip_bias_add is True, otherwise None

get_extra_state() → None

Keep compatibility with TE state dict.

set_extra_state(state: Any)

Extra state is ignored.

sharded_state_dict(prefix='', sharded_offsets=())

Shards the weight along axis 0; the bias is sharded as well.

class core.tensor_parallel.layers.LinearWithFrozenWeight(*args: Any, **kwargs: Any)

Bases: torch.autograd.Function

Linear operator that does not calculate the gradient for the weight. This op and LinearWithGradAccumulationAndAsyncCommunication perform mathematically identical forward and DGRAD (input-gradient) computations.

Conceptually, this op is the same as torch.nn.functional.linear with weight.requires_grad == False, but in experiments the two are not mathematically identical.

static backward(ctx, grad_output)

static forward(ctx, input, weight, bias)

class core.tensor_parallel.layers.LinearWithGradAccumulationAndAsyncCommunication(*args: Any, **kwargs: Any)

Bases: torch.autograd.Function

See linear_with_grad_accumulation_and_async_allreduce

static backward(ctx, grad_output)

static forward(ctx, input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce, sequence_parallel)

class core.tensor_parallel.layers.RowParallelLinear(*args: Any, **kwargs: Any)

Bases: torch.nn.Module

Linear layer with row parallelism.

The linear layer is defined as Y = XA + b. A is parallelized along its first dimension and X along its second dimension: A = transpose([A_1, …, A_p]) and X = [X_1, …, X_p].

Parameters
  • input_size – first dimension of matrix A.

  • output_size – second dimension of matrix A.

  • bias – If True, add a bias. Note that the bias is not parallelized.

  • input_is_parallel – If True, assume that the input is already split across the GPUs and do not split it again.

  • init_method – method to initialize weights. Note that bias is always set to zero.

  • stride – For the strided linear layers.

  • keep_master_weight_for_test – This was added for testing and should be set to False. It returns the master weights used for initialization.

  • skip_bias_add – If True, do not add the bias term; instead, return it so that the caller can add it. This enables performance optimizations where the bias can be fused with other elementwise operations.

  • is_expert – If True, the layer is treated as an MoE expert layer.

  • tp_comm_buffer_name – Communication buffer name. Not used in non-Transformer-Engine modules.

  • config – ModelParallelConfig object
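Row parallelism replaces the concatenation with a sum. In the NumPy sketch below (single process, hypothetical world size p = 4, not the layer's actual code), simulated rank i holds the row block A_i and the matching column block X_i of the input, as when input_is_parallel=True. The partial products X_i A_i are summed, standing in for the all-reduce, and the unsharded bias is added once afterwards, which is consistent with the bias not being parallelized.

```python
import numpy as np

rng = np.random.default_rng(0)
p = 4                          # hypothetical tensor-parallel world size
X = rng.normal(size=(5, 8))    # [tokens, input_size]
A = rng.normal(size=(8, 6))    # [input_size, output_size]
b = rng.normal(size=6)

A_shards = np.split(A, p, axis=0)  # A_i: a block of rows of A
X_shards = np.split(X, p, axis=1)  # X_i: the matching columns of X

# Each rank computes a full-size partial result from its slice of the inner dimension.
partials = [X_i @ A_i for X_i, A_i in zip(X_shards, A_shards)]

# Sum (the all-reduce), then add the replicated bias exactly once.
Y = np.sum(partials, axis=0) + b
```

Adding b on every rank before the reduction would count it p times, so the bias has to stay outside the sharded computation.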

forward(input_)

Forward of RowParallelLinear

Parameters
  • input_ – 3D tensor whose dimensions are ordered [sequence, batch, hidden]

Returns

  • output

  • bias – the bias tensor if skip_bias_add is True, otherwise None

get_extra_state() → None

Keep compatibility with TE state dict.

set_extra_state(state: Any)

Extra state is ignored.

sharded_state_dict(prefix='', sharded_offsets=())

Shards the weight along axis 1; the bias is not sharded.

class core.tensor_parallel.layers.VocabParallelEmbedding(*args: Any, **kwargs: Any)

Bases: torch.nn.Module

Embedding parallelized in the vocabulary dimension.

This is mainly adapted from torch.nn.Embedding, and all the default values are kept.

Parameters
  • num_embeddings – vocabulary size.

  • embedding_dim – size of the hidden state.

Keyword Arguments
  • config – A megatron.core.ModelParallelConfig object
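One way a vocabulary-parallel lookup can work is sketched below in NumPy, assuming the vocabulary size divides evenly across p simulated ranks. Each rank holds a contiguous block of rows, looks up only the token ids it owns, writes zeros for the rest, and a sum across ranks (standing in for an all-reduce) yields the full embedding. Sizes and variable names are illustrative, not taken from the class.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, hidden, p = 12, 4, 3            # illustrative sizes; vocab divisible by p
table = rng.normal(size=(vocab, hidden))
ids = np.array([[0, 5], [11, 7]])      # [batch, seq] token ids

per_part = vocab // p
partials = []
for rank in range(p):
    first, last = rank * per_part, (rank + 1) * per_part
    local_table = table[first:last]            # this rank's rows only
    mask = (ids < first) | (ids >= last)       # ids owned by other ranks
    local_ids = np.where(mask, 0, ids - first) # any valid index for masked ids
    out = local_table[local_ids]               # fancy indexing returns a copy
    out[mask] = 0.0                            # zero rows this rank does not own
    partials.append(out)

emb = np.sum(partials, axis=0)  # all-reduce across tensor-parallel ranks
```

Because every token id is owned by exactly one rank, the sum adds one real row to p - 1 zero rows.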

forward(input_)

sharded_state_dict(prefix: str = '', sharded_offsets: Tuple[Tuple[int, int, int]] = ()) → Dict[str, Any]

Non-default implementation for embeddings, required because of the allow_shape_mismatch parameter.

core.tensor_parallel.layers.copy_tensor_model_parallel_attributes(destination_tensor, source_tensor)

core.tensor_parallel.layers.linear_with_frozen_weight(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], gradient_accumulation_fusion: bool, async_grad_allreduce: bool, sequence_parallel: bool) → torch.Tensor

Linear layer execution with weight.requires_grad == False.

This function handles linear layers with weight frozen (untrainable). In the forward, it only saves weight and does not save input activations. In the backward, it does not perform weight gradient calculation, or weight gradient allreduce.

Arguments:

input (torch.Tensor required): input like torch.nn.functional.linear

weight (torch.Tensor required): weight like torch.nn.functional.linear

bias (torch.Tensor optional): bias like torch.nn.functional.linear

gradient_accumulation_fusion (bool required): dummy argument, used to keep the API unified between all forward implementation functions.

async_grad_allreduce (bool required): dummy argument, used to keep the API unified between all forward implementation functions.

sequence_parallel (bool required): Indicates that sequence parallelism is used, and thus in the forward pass the input is all-gathered and in the backward pass the input gradients are reduce-scattered.
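The saving described above can be illustrated with a NumPy sketch (hypothetical helper names, single process, no sequence parallelism). The forward pass keeps only the weight, and the backward pass computes only the input gradient, following the torch.nn.functional.linear convention y = x weight^T + bias.

```python
import numpy as np


def frozen_linear_forward(x, weight, bias=None):
    """y = x @ weight.T + bias, as in torch.nn.functional.linear.

    Only the weight is returned for the backward pass; the input activation x
    is not kept, because no weight gradient will be computed.
    """
    y = x @ weight.T
    if bias is not None:
        y = y + bias
    return y, weight


def frozen_linear_backward(grad_output, saved_weight):
    # DGRAD only: dL/dx = dL/dy @ weight. There is no WGRAD (dL/dweight),
    # and therefore no weight-gradient all-reduce to launch.
    return grad_output @ saved_weight
```

The memory saving comes from not storing x: a trainable weight would need x to form dL/dweight = dL/dy^T x.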

core.tensor_parallel.layers.linear_with_grad_accumulation_and_async_allreduce(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], gradient_accumulation_fusion: bool, async_grad_allreduce: bool, sequence_parallel: bool) → torch.Tensor

Linear layer execution with asynchronous communication and gradient accumulation fusion in backprop.

This has the option to accumulate the result of the backprop calculation into an existing gradient buffer, avoiding an additional addition kernel after the gradient calculation.

Additionally, the tensor-parallel all-reduce of the input gradients can be done asynchronously with the calculation of the weight gradients.

In the case of sequence parallelism, the reduce-scatter of the input gradients is done asynchronously with the calculation of the weight gradients.

Use of this module requires the environment variable CUDA_DEVICE_MAX_CONNECTIONS=1. A few collective operations, noted in the code, should be scheduled before compute kernels so that the communication overlaps with the computation. This ordering is necessary for a speedup but not for correctness, so the scheduler does not impose it; setting CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled in the order they are called.
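A minimal way to satisfy this requirement from Python, assuming the variable is set inside the training process before anything initializes CUDA (exporting it in the launch script is an equivalent alternative):

```python
import os

# Set before the first CUDA call in this process (for example, before any
# torch.cuda usage): the variable is typically read when the CUDA context is
# created, so setting it afterwards would have no effect.
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
```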

Arguments:

input (torch.Tensor required): input like torch.nn.functional.linear

weight (torch.Tensor required): weight like torch.nn.functional.linear

bias (torch.Tensor optional): bias like torch.nn.functional.linear

gradient_accumulation_fusion (bool required): Perform the gradient accumulation fusion, which requires the custom CUDA extension module fused_weight_gradient_mlp_cuda. To use gradient_accumulation_fusion, you must install APEX with --cpp_ext and --cuda_ext, for example by running pip install --global-option="--cpp_ext" --global-option="--cuda_ext" ./ from the APEX source directory. Note that the extension requires CUDA>=11. Otherwise, you must turn off gradient accumulation fusion.

async_grad_allreduce (bool required): Perform the all-reduce of the input gradients asynchronously with the computation of the weight gradients. If sequence_parallel is True, this must be False, as no all-reduce is performed.

sequence_parallel (bool required): Indicates that sequence parallelism is used, and thus in the forward pass the input is all-gathered and in the backward pass the input gradients are reduce-scattered.
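The sequence_parallel communication pattern can be illustrated in NumPy on one process. all_gather_sequence and reduce_scatter_sequence are hypothetical stand-ins for the collectives, with the sequence dimension first as in [sequence, batch, hidden]. Reduce-scatter is the adjoint (transpose) of all-gather: if every rank uses the gathered full sequence, the gradient for rank j's shard is slice j of the sum of all ranks' full-length gradients, which is why the backward pass uses it for the input gradients.

```python
import numpy as np


def all_gather_sequence(shards):
    # Forward: every rank receives the full sequence, concatenated along dim 0.
    return np.concatenate(shards, axis=0)


def reduce_scatter_sequence(full_grads, world_size):
    # Backward: sum the full-length gradients produced on every rank, then
    # hand each rank the sequence slice it owns.
    total = np.sum(full_grads, axis=0)
    return np.split(total, world_size, axis=0)
```

Since no rank ever needs the summed full-length gradient, reduce-scatter replaces the all-reduce that async_grad_allreduce would otherwise perform.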

core.tensor_parallel.layers.param_is_not_tensor_parallel_duplicate(param)

core.tensor_parallel.layers.set_defaults_if_not_set_tensor_model_parallel_attributes(tensor)

core.tensor_parallel.layers.set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride)

core.tensor_parallel.mappings.copy_to_tensor_model_parallel_region(input_)

core.tensor_parallel.mappings.gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True)

core.tensor_parallel.mappings.gather_from_sequence_parallel_region_to_moe(input_)

core.tensor_parallel.mappings.gather_from_tensor_model_parallel_region(input_)

core.tensor_parallel.mappings.reduce_from_tensor_model_parallel_region(input_)

core.tensor_parallel.mappings.reduce_scatter_to_sequence_parallel_region(input_)

core.tensor_parallel.mappings.reduce_scatter_to_sequence_parallel_region_from_moe(input_)

core.tensor_parallel.mappings.scatter_to_sequence_parallel_region(input_)

core.tensor_parallel.mappings.scatter_to_tensor_model_parallel_region(input_)

class core.tensor_parallel.random.CheckpointFunction(*args: Any, **kwargs: Any)

Bases: torch.autograd.Function

Checkpoint Function

This function is adapted from torch.utils.checkpoint with two main changes: (1) torch.cuda.set_rng_state is replaced with _set_cuda_rng_state, and (2) the states in the model-parallel tracker are also properly tracked, set, and reset.

static backward(ctx, *args)

static forward(ctx, run_function, distribute_saved_activations, *args)
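A checkpoint function has to manage RNG state because the backward pass re-runs the forward computation, and any randomness (such as dropout) must repeat exactly. The stdlib sketch below uses Python's random module as a stand-in for the CUDA RNG, which is an assumption made purely for illustration. It records the state before the forward pass, restores it for the recomputation, and then restores the state the rest of the program expects.

```python
import random


def dropout_like(values, p=0.5):
    # Draws from the global RNG, the way dropout draws from the device RNG.
    return [0.0 if random.random() < p else v / (1 - p) for v in values]


values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

# Forward under checkpointing: record the RNG state, run, keep only the output.
random.seed(0)
saved_state = random.getstate()
forward_out = dropout_like(values)

random.random()  # unrelated work advances the RNG before backward runs

# Backward: restore the recorded state so the recomputation draws the same
# random numbers, then put back the state the rest of training expects.
current_state = random.getstate()
random.setstate(saved_state)
recomputed_out = dropout_like(values)
random.setstate(current_state)
```

With model parallelism there are several RNG streams (the default one and the tracker's named states), which is why the tracker states must be saved and restored alongside the default state.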

class core.tensor_parallel.random.CudaRNGStatesTracker

Bases: object

Tracker for the CUDA RNG states.

Using the add method, a CUDA RNG state is initialized from the input seed and assigned to name. Later, by forking that RNG state, we can perform operations and then return to our starting CUDA state.

add(name, seed)

Track the RNG state.

fork(name='model-parallel-rng')

Fork the CUDA RNG state, perform operations, and exit with the original state.

get_states()

Get the RNG states. The dictionary is copied so that the caller holds direct references to the states, not just a reference to the dictionary.

reset()

Set to the initial state (no tracker).

set_states(states)

Set the RNG states. For efficiency, the size of the seed is not checked for compatibility.
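The add and fork behavior described above can be mimicked with a small context manager. The sketch below uses Python's global random module in place of the CUDA RNG; RNGStatesTrackerSketch is a hypothetical class written for illustration, not the Megatron-Core class. fork swaps in the named state, runs the body, saves the advanced named state, and restores the original state even if the body raises.

```python
import contextlib
import random


class RNGStatesTrackerSketch:
    """Hypothetical stand-in for CudaRNGStatesTracker using Python's global RNG."""

    def __init__(self):
        self.states = {}

    def add(self, name, seed):
        if name in self.states:
            raise ValueError(f"RNG state {name} already exists")
        orig = random.getstate()
        random.seed(seed)
        self.states[name] = random.getstate()
        random.setstate(orig)  # adding a state must not disturb the default RNG

    @contextlib.contextmanager
    def fork(self, name="model-parallel-rng"):
        orig = random.getstate()
        random.setstate(self.states[name])
        try:
            yield
        finally:
            # Keep the advanced tracked state, then restore the default state.
            self.states[name] = random.getstate()
            random.setstate(orig)
```

Typical use wraps only the region that needs the tracked stream, for example a dropout inside a tensor-parallel region: with tracker.fork(): ...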

core.tensor_parallel.random.checkpoint(function, distribute_saved_activations, *args)

Checkpoint a model or part of the model. This has been directly copied from torch.utils.checkpoint.

core.tensor_parallel.random.get_cuda_rng_tracker()

Get the CUDA RNG tracker.

core.tensor_parallel.random.get_data_parallel_rng_tracker_name()

core.tensor_parallel.random.get_expert_parallel_rng_tracker_name()

core.tensor_parallel.random.model_parallel_cuda_manual_seed(seed)

Initialize the model-parallel CUDA seed.

This function should be called after model parallelism is initialized. Also, torch.cuda.manual_seed should not be called after this function; this function is essentially a replacement for it. Two sets of RNG states are tracked:
  • default state – used for data parallelism. It is the same among a set of model-parallel GPUs but different across different model-parallel groups, and is used, for example, for dropout in the non-tensor-model-parallel regions.

  • tensor-model-parallel state – different among a set of model-parallel GPUs but the same across data-parallel groups. It is used, for example, for dropout in the model-parallel regions.
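The relationship between the two seeds can be illustrated with simple arithmetic. In the sketch below, derive_seeds is hypothetical and tp_offset is an illustrative constant, not necessarily the value Megatron-Core uses; the caller's seed is taken as given. The default seed is identical on every tensor-parallel rank, while the tensor-model-parallel seed depends on the tensor-parallel rank but not on the data-parallel rank.

```python
def derive_seeds(seed, tp_rank, tp_offset=2718):
    """Hypothetical illustration of the two seeds described above.

    seed: the caller's seed, taken as given.
    tp_rank: this GPU's tensor-model-parallel rank.
    tp_offset: illustrative constant separating the two streams (an assumption).
    """
    default_seed = seed  # identical on every tensor-parallel rank
    tensor_model_parallel_seed = seed + tp_offset + tp_rank  # unique per TP rank
    return default_seed, tensor_model_parallel_seed


# Two data-parallel replicas x four tensor-parallel ranks, same caller seed.
grid = {(dp, tp): derive_seeds(1234, tp) for dp in range(2) for tp in range(4)}
```

The data-parallel rank never enters the formula, so replicas at the same tensor-parallel rank draw the same tensor-parallel dropout masks, as the description above requires.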

class core.tensor_parallel.utils.VocabUtility

Bases: object

Split the vocabulary into world_size chunks and return the first and last index of the vocabulary belonging to the rank's partition. Note that the indices are in [first, last).

static vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) → Sequence[int]

static vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int, rank, world_size: int) → Sequence[int]
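The arithmetic behind these two static methods can be sketched as follows, assuming the vocabulary size divides evenly across ranks. The function names are hypothetical standalone versions, not the VocabUtility signatures, and the divisibility check is an added assumption.

```python
def vocab_range_from_per_partition(per_partition_vocab_size, rank):
    # Rank r owns the half-open range [r * n, (r + 1) * n).
    first = rank * per_partition_vocab_size
    return first, first + per_partition_vocab_size


def vocab_range_from_global(global_vocab_size, rank, world_size):
    # Assumes the vocabulary size divides evenly across ranks.
    if global_vocab_size % world_size != 0:
        raise ValueError("global_vocab_size must be divisible by world_size")
    return vocab_range_from_per_partition(global_vocab_size // world_size, rank)
```

For example, a 50304-token vocabulary over 8 ranks gives each rank 6288 ids, and rank 7 owns [44016, 50304).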

core.tensor_parallel.utils.gather_split_1d_tensor(tensor)

Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor model parallel ranks.

Returns a new Tensor with the gathered data.

Parameters

tensor – A Tensor or view of this rank’s portion of the data.

core.tensor_parallel.utils.split_tensor_along_last_dim(tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False) → List[torch.Tensor]

Split a tensor along its last dimension.

Parameters
  • tensor – input tensor.

  • num_partitions – number of partitions into which to split the tensor

  • contiguous_split_chunks – If True, make each chunk contiguous in memory.

Returns

A list of Tensors
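A NumPy analogue shows why contiguous_split_chunks exists: a chunk taken from the last dimension is a strided view, so it is not contiguous in memory unless it is copied. split_along_last_dim is a hypothetical helper written for illustration, and the explicit divisibility check is an assumption.

```python
import numpy as np


def split_along_last_dim(x, num_partitions, contiguous_split_chunks=False):
    last = x.shape[-1]
    if last % num_partitions != 0:
        raise ValueError(f"{last} is not divisible by {num_partitions}")
    chunks = np.split(x, num_partitions, axis=-1)  # views into x
    if contiguous_split_chunks:
        # Copy each strided view into its own contiguous buffer.
        chunks = [np.ascontiguousarray(c) for c in chunks]
    return chunks
```

Keeping views avoids a copy; requesting contiguous chunks costs a copy but helps when the next operation requires a contiguous layout.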

core.tensor_parallel.utils.split_tensor_into_1d_equal_chunks(tensor, new_buffer=False)

Break a tensor into equal 1D chunks across tensor parallel ranks.

Returns a Tensor or View with this rank’s portion of the data.

Parameters

tensor – The tensor to split

Keyword Arguments

new_buffer (bool) – If True, returns a new Tensor. If False, returns a view into the existing Tensor. Default is False
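The split and gather pair above can be simulated on one process. In this NumPy sketch, split_1d_chunk and gather_1d_chunks are hypothetical helpers, the element count is assumed to divide evenly by the world size, and the concatenation stands in for an all-gather across tensor-parallel ranks.

```python
import numpy as np


def split_1d_chunk(x, rank, world_size, new_buffer=False):
    flat = x.reshape(-1)              # a view when x is contiguous
    n = flat.size // world_size       # assumes size divides evenly
    chunk = flat[rank * n:(rank + 1) * n]
    return chunk.copy() if new_buffer else chunk  # view unless a copy is requested


def gather_1d_chunks(chunks):
    # Stand-in for an all-gather of the per-rank chunks into one flat buffer.
    return np.concatenate(chunks)
```

Reshaping the gathered buffer back to the original shape recovers the full tensor.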

class core.tensor_parallel.ColumnParallelLinear(*args: Any, **kwargs: Any)

Bases: torch.nn.Module

Linear layer with column parallelism.

The linear layer is defined as Y = XA + b. A is parallelized along its second dimension as A = [A_1, …, A_p].

Parameters
  • input_size – first dimension of matrix A.

  • output_size – second dimension of matrix A.

  • bias – If True, add a bias.

  • gather_output – If True, call all-gather on the output and make Y available to all GPUs; otherwise, each GPU keeps only its own output, Y_i = XA_i.

  • init_method – method to initialize weights. Note that bias is always set to zero.

  • stride – For the strided linear layers.

  • keep_master_weight_for_test – This was added for testing and should be set to False. It returns the master weights used for initialization.

  • skip_bias_add – If True, do not add the bias term; instead, return it so that the caller can add it. This enables performance optimizations where the bias can be fused with other elementwise operations.

  • skip_weight_param_allocation – If True, weight parameter is not allocated and must be passed as a keyword argument weight during the forward pass. Note that this does not affect bias, which will be allocated if bias is True. Defaults to False.

  • is_expert – If True, the layer is treated as an MoE expert layer.

  • config – ModelParallelConfig object

  • tp_comm_buffer_name – Communication buffer name. Not used in non-Transformer-Engine modules.

forward(input_: torch.Tensor, weight: Optional[torch.Tensor] = None)

Forward of ColumnParallelLinear

Parameters
  • input_ – 3D tensor whose dimensions are ordered [sequence, batch, hidden]

  • weight (optional) – weight tensor to use; required when skip_weight_param_allocation is True.

Returns

  • output

  • bias – the bias tensor if skip_bias_add is True, otherwise None

get_extra_state() → None

Keep compatibility with TE state dict.

set_extra_state(state: Any)

Extra state is ignored.

sharded_state_dict(prefix='', sharded_offsets=())

Shards the weight along axis 0; the bias is sharded as well.

class core.tensor_parallel.RowParallelLinear(*args: Any, **kwargs: Any)

Bases: torch.nn.Module

Linear layer with row parallelism.

The linear layer is defined as Y = XA + b. A is parallelized along its first dimension and X along its second dimension: A = transpose([A_1, …, A_p]) and X = [X_1, …, X_p].

Parameters
  • input_size – first dimension of matrix A.

  • output_size – second dimension of matrix A.

  • bias – If True, add a bias. Note that the bias is not parallelized.

  • input_is_parallel – If True, assume that the input is already split across the GPUs and do not split it again.

  • init_method – method to initialize weights. Note that bias is always set to zero.

  • stride – For the strided linear layers.

  • keep_master_weight_for_test – This was added for testing and should be set to False. It returns the master weights used for initialization.

  • skip_bias_add – If True, do not add the bias term; instead, return it so that the caller can add it. This enables performance optimizations where the bias can be fused with other elementwise operations.

  • is_expert – If True, the layer is treated as an MoE expert layer.

  • tp_comm_buffer_name – Communication buffer name. Not used in non-Transformer-Engine modules.

  • config – ModelParallelConfig object

forward(input_)

Forward of RowParallelLinear

Parameters
  • input_ – 3D tensor whose dimensions are ordered [sequence, batch, hidden]

Returns

  • output

  • bias – the bias tensor if skip_bias_add is True, otherwise None

get_extra_state() → None

Keep compatibility with TE state dict.

set_extra_state(state: Any)

Extra state is ignored.

sharded_state_dict(prefix='', sharded_offsets=())

Shards the weight along axis 1; the bias is not sharded.

class core.tensor_parallel.VocabParallelEmbedding(*args: Any, **kwargs: Any)

Bases: torch.nn.Module

Embedding parallelized in the vocabulary dimension.

This is mainly adapted from torch.nn.Embedding, and all the default values are kept.

Parameters
  • num_embeddings – vocabulary size.

  • embedding_dim – size of the hidden state.

Keyword Arguments
  • config – A megatron.core.ModelParallelConfig object

forward(input_)

sharded_state_dict(prefix: str = '', sharded_offsets: Tuple[Tuple[int, int, int]] = ()) → Dict[str, Any]

Non-default implementation for embeddings, required because of the allow_shape_mismatch parameter.

core.tensor_parallel.broadcast_data(keys, data, datatype)

Broadcast data from rank zero of each model parallel group to the members of the same model parallel group.

Parameters
  • keys – list of keys in the data dictionary to be broadcast

  • data – data dictionary with string keys and CPU tensor values.

  • datatype – torch data type of all tensors in data associated with keys.

core.tensor_parallel.checkpoint(function, distribute_saved_activations, *args)

Checkpoint a model or part of the model. This has been directly copied from torch.utils.checkpoint.

core.tensor_parallel.copy_tensor_model_parallel_attributes(destination_tensor, source_tensor)

core.tensor_parallel.copy_to_tensor_model_parallel_region(input_)

core.tensor_parallel.gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True)

core.tensor_parallel.gather_from_sequence_parallel_region_to_moe(input_)

core.tensor_parallel.gather_from_tensor_model_parallel_region(input_)

core.tensor_parallel.gather_split_1d_tensor(tensor)

Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor model parallel ranks.

Returns a new Tensor with the gathered data.

Parameters

tensor – A Tensor or view of this rank’s portion of the data.

core.tensor_parallel.get_cuda_rng_tracker()

Get the CUDA RNG tracker.

core.tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], gradient_accumulation_fusion: bool, async_grad_allreduce: bool, sequence_parallel: bool) → torch.Tensor

Linear layer execution with asynchronous communication and gradient accumulation fusion in backprop.

This has the option to accumulate the result of the backprop calculation into an existing gradient buffer, avoiding an additional addition kernel after the gradient calculation.

Additionally, the tensor-parallel all-reduce of the input gradients can be done asynchronously with the calculation of the weight gradients.

In the case of sequence parallelism, the reduce-scatter of the input gradients is done asynchronously with the calculation of the weight gradients.

Use of this module requires the environment variable CUDA_DEVICE_MAX_CONNECTIONS=1. A few collective operations, noted in the code, should be scheduled before compute kernels so that the communication overlaps with the computation. This ordering is necessary for a speedup but not for correctness, so the scheduler does not impose it; setting CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled in the order they are called.

Arguments:

input (torch.Tensor required): input like torch.nn.functional.linear

weight (torch.Tensor required): weight like torch.nn.functional.linear

bias (torch.Tensor optional): bias like torch.nn.functional.linear

gradient_accumulation_fusion (bool required): Perform the gradient accumulation fusion, which requires the custom CUDA extension module fused_weight_gradient_mlp_cuda. To use gradient_accumulation_fusion, you must install APEX with --cpp_ext and --cuda_ext, for example by running pip install --global-option="--cpp_ext" --global-option="--cuda_ext" ./ from the APEX source directory. Note that the extension requires CUDA>=11. Otherwise, you must turn off gradient accumulation fusion.

async_grad_allreduce (bool required): Perform the all-reduce of the input gradients asynchronously with the computation of the weight gradients. If sequence_parallel is True, this must be False, as no all-reduce is performed.

sequence_parallel (bool required): Indicates that sequence parallelism is used, and thus in the forward pass the input is all-gathered and in the backward pass the input gradients are reduce-scattered.

core.tensor_parallel.model_parallel_cuda_manual_seed(seed)

Initialize the model-parallel CUDA seed.

This function should be called after model parallelism is initialized. Also, torch.cuda.manual_seed should not be called after this function; this function is essentially a replacement for it. Two sets of RNG states are tracked:
  • default state – used for data parallelism. It is the same among a set of model-parallel GPUs but different across different model-parallel groups, and is used, for example, for dropout in the non-tensor-model-parallel regions.

  • tensor-model-parallel state – different among a set of model-parallel GPUs but the same across data-parallel groups. It is used, for example, for dropout in the model-parallel regions.

core.tensor_parallel.param_is_not_tensor_parallel_duplicate(param)

core.tensor_parallel.reduce_scatter_to_sequence_parallel_region_from_moe(input_)

core.tensor_parallel.scatter_to_sequence_parallel_region(input_)

core.tensor_parallel.scatter_to_tensor_model_parallel_region(input_)

core.tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(tensor)

core.tensor_parallel.set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride)

core.tensor_parallel.split_tensor_along_last_dim(tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False) → List[torch.Tensor]

Split a tensor along its last dimension.

Parameters
  • tensor – input tensor.

  • num_partitions – number of partitions into which to split the tensor

  • contiguous_split_chunks – If True, make each chunk contiguous in memory.

Returns

A list of Tensors

core.tensor_parallel.split_tensor_into_1d_equal_chunks(tensor, new_buffer=False)

Break a tensor into equal 1D chunks across tensor parallel ranks.

Returns a Tensor or View with this rank’s portion of the data.

Parameters

tensor – The tensor to split

Keyword Arguments

new_buffer (bool) – If True, returns a new Tensor. If False, returns a view into the existing Tensor. Default is False

core.tensor_parallel.vocab_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0)

Computes the cross-entropy loss when the logits are split across tensor-parallel ranks.

Parameters
  • vocab_parallel_logits – logits split across tensor-parallel ranks, of dimension [sequence_length, batch_size, vocab_size / tensor_model_parallel_size]

  • target – correct vocabulary ids, of dimension [sequence_length, micro_batch_size]

  • label_smoothing – smoothing factor; must be in the range [0.0, 1.0). The default is no smoothing (0.0).

© Copyright 2022-2024, NVIDIA. Last updated on Mar 16, 2024.