core.tensor_parallel.layers#

Module Contents#

Classes#

VocabParallelEmbedding

Embedding parallelized in the vocabulary dimension.

LinearWithFrozenWeight

Linear operator that does not calculate the gradient for its weight. This op and LinearWithGradAccumulationAndAsyncCommunication perform a mathematically identical forward pass and DGRAD.

LinearWithGradAccumulationAndAsyncCommunication

See linear_with_grad_accumulation_and_async_allreduce

ColumnParallelLinear

Linear layer with column parallelism.

RowParallelLinear

Linear layer with row parallelism.

Functions#

param_is_not_tensor_parallel_duplicate

Returns true if the passed-in parameter is not a duplicate parameter on another TP rank.

set_tensor_model_parallel_attributes

Sets tensor-parallel attributes on a tensor.

set_defaults_if_not_set_tensor_model_parallel_attributes

Set default model parallel attributes if not set explicitly already.

copy_tensor_model_parallel_attributes

Copy model parallel attributes from one tensor to another.

_initialize_affine_weight_gpu

Initialize affine weight for model parallel on GPU.

_initialize_affine_weight_cpu

Initialize affine weight for model parallel.

linear_with_frozen_weight

Linear layer execution with weight.requires_grad == False.

linear_with_grad_accumulation_and_async_allreduce

Linear layer execution with asynchronous communication and gradient accumulation fusion in backprop.

Data#

API#

core.tensor_parallel.layers._grad_accum_fusion_available#

True

core.tensor_parallel.layers._MODEL_PARALLEL_ATTRIBUTE_DEFAULTS#

None

core.tensor_parallel.layers.param_is_not_tensor_parallel_duplicate(param)#

Returns true if the passed-in parameter is not a duplicate parameter on another TP rank.
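
A minimal sketch of how this predicate is typically used: filtering out parameters that are replicated on every TP rank (for example, LayerNorm weights) before reducing a global statistic. The import path and the surrounding loop are assumptions, not part of this page.

```python
# Hedged sketch: keep only the parameters this tensor-parallel rank "owns" so
# replicated parameters are not double-counted in a cross-rank reduction.
# Assumes megatron.core parallel state has been initialized.
from megatron.core.tensor_parallel.layers import param_is_not_tensor_parallel_duplicate


def tp_unique_parameters(model):
    for param in model.parameters():
        if param_is_not_tensor_parallel_duplicate(param):
            yield param
```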

core.tensor_parallel.layers.set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride)#

Sets tensor-parallel attributes on a tensor.

core.tensor_parallel.layers.set_defaults_if_not_set_tensor_model_parallel_attributes(tensor)#

Set default model parallel attributes if not set explicitly already.

core.tensor_parallel.layers.copy_tensor_model_parallel_attributes(
destination_tensor,
source_tensor,
)#

Copy model parallel attributes from one tensor to another.
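
A small sketch of the three attribute helpers above applied to bare tensors; the shapes and the choice of dim/stride are illustrative assumptions.

```python
import torch

from megatron.core.tensor_parallel.layers import (
    copy_tensor_model_parallel_attributes,
    set_defaults_if_not_set_tensor_model_parallel_attributes,
    set_tensor_model_parallel_attributes,
)

weight = torch.empty(1024, 256)
# Mark the tensor as partitioned along dim 0 with a contiguous (stride=1) split.
set_tensor_model_parallel_attributes(weight, is_parallel=True, dim=0, stride=1)

# A tensor without explicit attributes gets the non-parallel defaults.
bias = torch.empty(1024)
set_defaults_if_not_set_tensor_model_parallel_attributes(bias)

# Propagate the attributes onto a re-created tensor (e.g. after a clone).
weight_copy = torch.empty_like(weight)
copy_tensor_model_parallel_attributes(weight_copy, weight)
```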

core.tensor_parallel.layers._initialize_affine_weight_gpu(
weight,
init_method,
partition_dim,
stride=1,
is_expert=False,
)#

Initialize affine weight for model parallel on GPU.

core.tensor_parallel.layers._initialize_affine_weight_cpu(
weight,
output_size,
input_size,
per_partition_size,
partition_dim,
init_method,
stride=1,
return_master_weight=False,
*,
params_dtype=torch.float32,
rank=None,
world_size=None,
skip_set_tensor_parallel_attributes=False,
)#

Initialize affine weight for model parallel.

Build the master weight on all processes and scatter the relevant chunk.
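
The split described above can be pictured with plain torch, independently of the private helper: build the full master weight, cut it into world_size * stride pieces along the partition dimension, and keep the pieces belonging to this rank. This is a conceptual sketch under assumed shapes, not the library implementation.

```python
import torch

output_size, input_size = 1024, 256          # full (unpartitioned) weight shape
world_size, rank = 4, 1                      # assumed TP world size and rank
partition_dim, stride = 0, 1

master_weight = torch.empty(output_size, input_size, dtype=torch.float32)
torch.nn.init.xavier_uniform_(master_weight)

# Cut into world_size * stride chunks along the partition dim, keep every
# world_size-th chunk starting at this rank, then re-concatenate locally.
per_partition_per_stride = output_size // (world_size * stride)
chunks = torch.split(master_weight, per_partition_per_stride, dim=partition_dim)
local_weight = torch.cat(chunks[rank::world_size], dim=partition_dim)
assert local_weight.shape == (output_size // world_size, input_size)
```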

class core.tensor_parallel.layers.VocabParallelEmbedding(
num_embeddings: int,
embedding_dim: int,
*,
init_method: Callable,
reduce_scatter_embeddings: bool = False,
config: megatron.core.model_parallel_config.ModelParallelConfig,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
)#

Bases: torch.nn.Module

Embedding parallelized in the vocabulary dimension.

This is mainly adapted from torch.nn.Embedding and all the default values are kept.

Parameters:
  • num_embeddings – vocabulary size.

  • embedding_dim – size of hidden state.

  • reduce_scatter_embeddings – Decides whether to perform ReduceScatter after embedding lookup

Keyword Arguments:

config – A megatron.core.ModelParallelConfig object

Initialization
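
A minimal construction-and-lookup sketch, assuming tensor model parallelism has already been initialized (e.g. via megatron.core.parallel_state.initialize_model_parallel) and that a default-constructed ModelParallelConfig is acceptable; both assumptions are outside this page.

```python
import torch

from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.tensor_parallel.layers import VocabParallelEmbedding

config = ModelParallelConfig()               # assumption: defaults suffice here
embedding = VocabParallelEmbedding(
    num_embeddings=32000,                    # vocab size, assumed divisible by TP size
    embedding_dim=4096,                      # hidden size
    init_method=torch.nn.init.xavier_uniform_,
    config=config,
)

token_ids = torch.randint(0, 32000, (2048, 4))   # integer ids; lookup is elementwise
hidden = embedding(token_ids)                    # appends the hidden dimension
```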

forward(input_)#

Forward.

Parameters:

input_ (torch.Tensor) – Input tensor.

sharded_state_dict(
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int]] = (),
metadata: Optional[dict] = None,
) → core.dist_checkpointing.mapping.ShardedStateDict#

Non-default sharded_state_dict implementation for embeddings, needed because of the allow_shape_mismatch param.

class core.tensor_parallel.layers.LinearWithFrozenWeight#

Bases: torch.autograd.Function

Linear operator that does not calculate the gradient for its weight. This op and LinearWithGradAccumulationAndAsyncCommunication perform a mathematically identical forward pass and DGRAD.

Conceptually this op is the same as torch.nn.functional.linear with weight.requires_grad == False, but in experiments they are not mathematically identical.

static forward(ctx, input, weight, bias, allreduce_dgrad, tp_group)#

Forward with frozen weight.

static backward(ctx, grad_output)#

Backward with frozen weight.

core.tensor_parallel.layers.linear_with_frozen_weight(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
gradient_accumulation_fusion: bool,
allreduce_dgrad: bool,
sequence_parallel: bool,
tp_group: Optional[torch.distributed.ProcessGroup],
grad_output_buffer: Optional[List[torch.Tensor]] = None,
wgrad_deferral_limit: None = None,
async_grad_allreduce: Optional[bool] = None,
) → torch.Tensor#

Linear layer execution with weight.requires_grad == False.

This function handles linear layers whose weight is frozen (untrainable). In the forward pass, it saves only the weight and does not save the input activations. In the backward pass, it performs neither the weight gradient calculation nor the weight gradient allreduce.

Parameters:
  • input (torch.Tensor required) – input like torch.nn.functional.linear

  • weight (torch.Tensor required) – weight like torch.nn.functional.linear

  • bias (torch.Tensor optional) – bias like torch.nn.functional.linear

  • gradient_accumulation_fusion (bool required) – dummy argument, used to keep the API unified between all forward implementation functions.

  • allreduce_dgrad (bool required) – Do the allreduce of input gradients. Here, async and sync allreduce are the same. If sequence_parallel is True, this must be False, as no all reduce is performed.

  • sequence_parallel (bool required) – Indicates that sequence parallelism is used and thus in the forward pass the input is all gathered, and in the backward pass the input gradients are reduce scattered.

  • tp_group (torch.distributed.ProcessGroup) – The process group to use for tensor parallel operations.

  • grad_output_buffer (List[torch.Tensor] optional) – dummy argument, used to keep the API unified between all forward implementation functions.

  • wgrad_deferral_limit (int optional) – dummy argument, used to keep the API unified between all forward implementation functions.

  • async_grad_allreduce (bool optional) – Will be removed with 0.11.0. Please use allreduce_dgrad instead.
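
A direct-call sketch mirroring the signature above. In practice this path is normally selected inside the parallel linear layers when the weight has requires_grad == False; the shapes, the CUDA device, and the use of the default TP group (tp_group=None) are assumptions.

```python
import torch

from megatron.core.tensor_parallel.layers import linear_with_frozen_weight

x = torch.randn(2048, 4, 1024, device="cuda", requires_grad=True)
w = torch.randn(4096, 1024, device="cuda", requires_grad=False)   # frozen weight

out = linear_with_frozen_weight(
    input=x,
    weight=w,
    bias=None,
    gradient_accumulation_fusion=False,   # dummy, kept for API parity
    allreduce_dgrad=False,                # no input-grad allreduce in this sketch
    sequence_parallel=False,
    tp_group=None,                        # assumption: default TP group is used
)
out.sum().backward()                      # grads flow to x only, not to w
```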

class core.tensor_parallel.layers.LinearWithGradAccumulationAndAsyncCommunication#

Bases: torch.autograd.Function

See linear_with_grad_accumulation_and_async_allreduce

static forward(
ctx,
input,
weight,
bias,
gradient_accumulation_fusion,
allreduce_dgrad,
sequence_parallel,
grad_output_buffer,
wgrad_deferral_limit,
tp_group,
)#

Forward.

static backward(ctx, grad_output)#

Backward.

core.tensor_parallel.layers.linear_with_grad_accumulation_and_async_allreduce(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
gradient_accumulation_fusion: bool,
allreduce_dgrad: bool,
sequence_parallel: bool,
grad_output_buffer: Optional[List[torch.Tensor]] = None,
wgrad_deferral_limit: Optional[int] = 0,
async_grad_allreduce: Optional[bool] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
) → torch.Tensor#

Linear layer execution with asynchronous communication and gradient accumulation fusion in backprop.

This has the option to accumulate the result of backprop calculation into an existing gradient buffer, preventing the need to do an additional addition kernel after the gradient calculation.

Additionally, the tensor parallel all reduce of the input gradients can be done asynchronously with the calculation of the weight gradients.

In the case of sequence parallelism, the reduce scatter of the input gradients is done asynchronously with the calculation of the weight gradients.

Use of this module requires that the environment variable CUDA_DEVICE_MAX_CONNECTIONS=1. A few collective operations, noted in the code, should be scheduled before compute kernels so that communication overlaps with computation; this overlap is necessary for a speedup but not for correctness, so the scheduler does not impose that ordering on its own. Setting CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled in the order they are called.
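
In practice this requirement is usually met by exporting the variable before the process creates its CUDA context, for example:

```python
import os

# Must be set before the CUDA context is created (e.g. at the top of the
# training script or in the launcher environment).
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
```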

Parameters:
  • input (torch.Tensor required) – input like torch.nn.functional.linear

  • weight (torch.Tensor required) – weight like torch.nn.functional.linear

  • bias (torch.Tensor optional) – bias like torch.nn.functional.linear

  • gradient_accumulation_fusion (bool required) – Perform the gradient accumulation fusion, requires the custom CUDA extension fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install APEX with --cpp_ext and --cuda_ext. For example: pip install --global-option="--cpp_ext" --global-option="--cuda_ext" . Note that the extension requires CUDA>=11. Otherwise, you must turn off gradient accumulation fusion.

  • allreduce_dgrad (bool required) – Do the allreduce of input gradients. The allreduce is done asynchronously with the computation of weight gradients. If sequence_parallel is True, this must be False, as no all reduce is performed.

  • sequence_parallel (bool required) – Indicates that sequence parallelism is used and thus in the forward pass the input is all gathered, and in the backward pass the input gradients are reduce scattered.

  • tp_group (torch.distributed.ProcessGroup required) – The process group to use for tensor parallel operations.

  • grad_output_buffer (List[torch.Tensor] optional) – Buffer used to save output gradients when embedding table wgrad compute is deferred. Defaults to None.

  • wgrad_deferral_limit (int optional) – Limit on the number of micro-batches for which embedding weight gradient GEMM should be deferred. Disable by setting this to 0. Defaults to 0.

  • async_grad_allreduce (bool optional) – Will be removed with 0.11.0. Please use allreduce_dgrad instead.
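
A direct-call sketch under the same assumptions as for linear_with_frozen_weight above (initialized parallel state, default TP group, illustrative shapes); inside Megatron this function is normally invoked by ColumnParallelLinear and RowParallelLinear rather than by user code.

```python
import torch

from megatron.core.tensor_parallel.layers import (
    linear_with_grad_accumulation_and_async_allreduce,
)

x = torch.randn(2048, 4, 1024, device="cuda", requires_grad=True)
w = torch.randn(4096, 1024, device="cuda", requires_grad=True)

out = linear_with_grad_accumulation_and_async_allreduce(
    input=x,
    weight=w,
    bias=None,
    gradient_accumulation_fusion=False,   # True requires the APEX fused kernel
    allreduce_dgrad=True,                 # overlapped with the wgrad computation
    sequence_parallel=False,              # must stay False when allreduce_dgrad=True
    tp_group=None,                        # assumption: default TP group
)
out.sum().backward()
```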

class core.tensor_parallel.layers.ColumnParallelLinear(
input_size,
output_size,
*,
config: megatron.core.model_parallel_config.ModelParallelConfig,
init_method: Callable,
bias=True,
gather_output=False,
stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False,
skip_weight_param_allocation: bool = False,
embedding_activation_buffer: Optional[List[torch.Tensor]] = None,
grad_output_buffer: Optional[List[torch.Tensor]] = None,
is_expert: bool = False,
tp_comm_buffer_name: str = None,
disable_grad_reduce: bool = False,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
)#

Bases: torch.nn.Module

Linear layer with column parallelism.

The linear layer is defined as Y = XA + b. A is parallelized along its second dimension as A = [A_1, …, A_p].
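
The column split is exact: concatenating the per-partition products reproduces the full product. A plain-torch illustration of that identity (the two-way split is an arbitrary choice):

```python
import torch

X = torch.randn(8, 16)
A = torch.randn(16, 32)
A_1, A_2 = torch.chunk(A, 2, dim=1)              # A = [A_1, A_2] along dim 1

Y_full = X @ A
Y_parallel = torch.cat([X @ A_1, X @ A_2], dim=1)
assert torch.allclose(Y_full, Y_parallel, atol=1e-5)
```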

Parameters:
  • input_size – first dimension of matrix A.

  • output_size – second dimension of matrix A.

  • bias – If true, add bias

  • gather_output – If true, call all-gather on the output and make Y available to all GPUs; otherwise, every GPU will have only its own partition of the output, Y_i = XA_i.

  • init_method – method to initialize weights. Note that bias is always set to zero.

  • stride – For the strided linear layers.

  • keep_master_weight_for_test – This was added for testing and should be set to False. It returns the master weights used for initialization.

  • skip_bias_add – If True, do not add the bias term, instead return it to be added by the caller. This enables performance optimizations where bias can be fused with other elementwise operations.

  • skip_weight_param_allocation – If True, weight parameter is not allocated and must be passed as a keyword argument weight during the forward pass. Note that this does not affect bias, which will be allocated if bias is True. Defaults to False.

  • embedding_activation_buffer – This buffer holds the input activations of the final embedding linear layer on the last pipeline stage when defer_embedding_wgrad_compute is enabled.

  • grad_output_buffer – This buffer holds the gradient outputs of the final embedding linear layer on the last pipeline stage when defer_embedding_wgrad_compute is enabled.

  • is_expert – If True, the layer is treated as an MoE expert layer.

  • config – ModelParallelConfig object

  • tp_comm_buffer_name – Communication buffer name. Not used in non-Transformer-Engine modules.

  • disable_grad_reduce – If True, reduction of output gradients across tensor-parallel ranks will be disabled. Defaults to False. This feature is used by the LoRA adapter in NeMo to delay and fuse the reduction with other gradients for performance optimization.

Initialization
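
A minimal usage sketch, assuming initialized parallel state and a default-constructed ModelParallelConfig as in the VocabParallelEmbedding example above; shapes are illustrative.

```python
import torch

from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.tensor_parallel.layers import ColumnParallelLinear

config = ModelParallelConfig()
fc = ColumnParallelLinear(
    input_size=1024,
    output_size=4096,
    config=config,
    init_method=torch.nn.init.xavier_uniform_,
    bias=True,
    gather_output=False,      # keep the per-rank shard Y_i = X A_i
    skip_bias_add=True,       # hand the bias back for later fusion
)

x = torch.randn(2048, 4, 1024)       # [sequence, batch, hidden]
output, bias = fc(x)                 # per-rank output shard and the unapplied bias
```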

_forward_impl(input, weight, *args, **kwargs)#

forward(
input_: torch.Tensor,
weight: Optional[torch.Tensor] = None,
runtime_gather_output: Optional[bool] = None,
)#

Forward of ColumnParallelLinear

Parameters:
  • input_ – 3D tensor whose dimension order is [sequence, batch, hidden]

  • weight (optional) – weight tensor to use, compulsory when skip_weight_param_allocation is True.

  • runtime_gather_output (bool) – Gather output at runtime. Default None means gather_output arg in the constructor will be used.

Returns:

  • output

  • bias

sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#

Sharding along axis 0, bias sharded

set_extra_state(state: Any)#

Extra state is ignored

get_extra_state() → None#

Keep compatibility with TE state dict.

__repr__()#

class core.tensor_parallel.layers.RowParallelLinear(
input_size: int,
output_size: int,
*,
config: megatron.core.model_parallel_config.ModelParallelConfig,
init_method: Callable,
bias: bool,
input_is_parallel: bool,
skip_bias_add: bool,
stride: int = 1,
keep_master_weight_for_test: bool = False,
is_expert: bool = False,
tp_comm_buffer_name: str = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
)#

Bases: torch.nn.Module

Linear layer with row parallelism.

The linear layer is defined as Y = XA + b. A is parallelized along its first dimension and X along its second dimension, so that A = transpose([A_1, …, A_p]) and X = [X_1, …, X_p].
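
The row-parallel identity can likewise be checked in plain torch: with A split along its rows and X along its columns, the full product is the sum of the per-partition partial products (this sum is what the layer's all-reduce computes).

```python
import torch

X = torch.randn(8, 16)
A = torch.randn(16, 32)
X_1, X_2 = torch.chunk(X, 2, dim=1)              # X = [X_1, X_2]
A_1, A_2 = torch.chunk(A, 2, dim=0)              # A split along its first dimension

Y_full = X @ A
Y_parallel = X_1 @ A_1 + X_2 @ A_2               # partial sums combined by all-reduce
assert torch.allclose(Y_full, Y_parallel, atol=1e-5)
```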

Parameters:
  • input_size – first dimension of matrix A.

  • output_size – second dimension of matrix A.

  • bias – If true, add bias. Note that bias is not parallelized.

  • input_is_parallel – If true, we assume that the input is already split across the GPUs and we do not split again.

  • init_method – method to initialize weights. Note that bias is always set to zero.

  • stride – For the strided linear layers.

  • keep_master_weight_for_test – This was added for testing and should be set to False. It returns the master weights used for initialization.

  • skip_bias_add – If True, do not add the bias term, instead return it to be added by the caller. This enables performance optimizations where bias can be fused with other elementwise operations.

  • is_expert – If True, the layer is treated as an MoE expert layer

  • tp_comm_buffer_name – Communication buffer name. Not used in non-Transformer-Engine modules.

  • config – ModelParallelConfig object

Initialization
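
A sketch of the usual pairing with ColumnParallelLinear (the Megatron MLP pattern): the column-parallel output is already split per rank, so the row-parallel layer is built with input_is_parallel=True. Same assumptions as the earlier sketches (initialized parallel state, default ModelParallelConfig, illustrative shapes).

```python
import torch

from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.tensor_parallel.layers import (
    ColumnParallelLinear,
    RowParallelLinear,
)

config = ModelParallelConfig()
fc1 = ColumnParallelLinear(
    1024, 4096,
    config=config,
    init_method=torch.nn.init.xavier_uniform_,
    gather_output=False,
)
fc2 = RowParallelLinear(
    4096, 1024,
    config=config,
    init_method=torch.nn.init.xavier_uniform_,
    bias=True,
    input_is_parallel=True,    # consume the un-gathered fc1 shard directly
    skip_bias_add=False,
)

x = torch.randn(2048, 4, 1024)                    # [sequence, batch, hidden]
intermediate, _ = fc1(x)
output, _ = fc2(torch.nn.functional.gelu(intermediate))
```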

_forward_impl(input, weight, *args, **kwargs)#

forward(input_)#

Forward of RowParallelLinear

Parameters:

input_ – 3D tensor whose dimension order is [sequence, batch, hidden]

Returns:

  • output

  • bias

sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#

Sharding along axis 1, bias not sharded

set_extra_state(state: Any)#

Extra state is ignored

get_extra_state() → None#

Keep compatibility with TE state dict.

__repr__()#