core.tensor_parallel.layers
Module Contents
Classes
VocabParallelEmbedding | Embedding parallelized in the vocabulary dimension.
LinearWithFrozenWeight | Linear operator that does not calculate gradient for weight. This op and LinearWithGradAccumulationAndAsyncCommunication perform mathematically identical forward and DGRAD.
LinearWithGradAccumulationAndAsyncCommunication | See linear_with_grad_accumulation_and_async_allreduce.
ColumnParallelLinear | Linear layer with column parallelism.
RowParallelLinear | Linear layer with row parallelism.
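Functions
param_is_not_tensor_parallel_duplicate | Returns true if the passed-in parameter is not a duplicate parameter on another TP rank.
set_tensor_model_parallel_attributes | Sets tensor-parallel (TP) attributes on a tensor.
set_defaults_if_not_set_tensor_model_parallel_attributes | Set default model parallel attributes if not set explicitly already.
copy_tensor_model_parallel_attributes | Copy model parallel attributes from one tensor to another.
_initialize_affine_weight_gpu | Initialize affine weight for model parallel on GPU.
_initialize_affine_weight_cpu | Initialize affine weight for model parallel.
linear_with_frozen_weight | Linear layer execution with weight.requires_grad == False.
linear_with_grad_accumulation_and_async_allreduce | Linear layer execution with asynchronous communication and gradient accumulation fusion in backprop.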
Data
_grad_accum_fusion_available
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS
API
- core.tensor_parallel.layers._grad_accum_fusion_available
True
- core.tensor_parallel.layers._MODEL_PARALLEL_ATTRIBUTE_DEFAULTS
None
- core.tensor_parallel.layers.param_is_not_tensor_parallel_duplicate(param)
Returns true if the passed-in parameter is not a duplicate parameter on another TP rank.
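A minimal sketch of the check, assuming (as in Megatron-Core's source, though this entry does not spell it out) that a parameter counts as unique when its tensor_model_parallel attribute is True, and that replicated parameters are kept only on TP rank 0. FakeParam and the explicit tp_rank argument are hypothetical stand-ins so the sketch runs without torch.distributed. The real function takes only param and looks up the rank itself.

```python
from dataclasses import dataclass


@dataclass
class FakeParam:
    """Hypothetical stand-in for a torch.nn.Parameter carrying TP metadata."""
    name: str
    tensor_model_parallel: bool = False


def param_is_not_tensor_parallel_duplicate(param, tp_rank: int) -> bool:
    # Sharded parameters hold different data on every TP rank, so each copy counts.
    # Replicated parameters are identical across ranks: count only rank 0's copy.
    return bool(getattr(param, "tensor_model_parallel", False)) or tp_rank == 0


params = [FakeParam("fc1.weight", tensor_model_parallel=True), FakeParam("fc2.bias")]
for rank in range(2):
    kept = [p.name for p in params if param_is_not_tensor_parallel_duplicate(p, rank)]
    print(rank, kept)  # rank 0 keeps both; rank 1 keeps only fc1.weight
```

A filter of this kind is typically applied when summing squared gradient norms on every TP rank before an all-reduce, so that replicated parameters are not counted once per rank.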
- core.tensor_parallel.layers.set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride)
Sets tensor-parallel (TP) attributes on a tensor.
- core.tensor_parallel.layers.set_defaults_if_not_set_tensor_model_parallel_attributes(tensor)
Set default model parallel attributes if not set explicitly already.
- core.tensor_parallel.layers.copy_tensor_model_parallel_attributes(destination_tensor, source_tensor)
Copy model parallel attributes from one tensor to another.
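A minimal sketch of the three attribute helpers. It assumes the attribute names tensor_model_parallel, partition_dim and partition_stride, with defaults False, -1 and 1. These come from Megatron-Core's source; the API entry above renders _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS only as None. SimpleNamespace objects stand in for tensors, which accept arbitrary Python attributes in the same way.

```python
from types import SimpleNamespace

# Assumed attribute names and default values (see lead-in).
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
    "tensor_model_parallel": False,
    "partition_dim": -1,
    "partition_stride": 1,
}


def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
    # Record whether the tensor is sharded, along which dim, and with what stride.
    tensor.tensor_model_parallel = is_parallel
    tensor.partition_dim = dim
    tensor.partition_stride = stride


def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
    # Fill in only the attributes that are missing; explicit values are kept.
    for name, value in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS.items():
        if not hasattr(tensor, name):
            setattr(tensor, name, value)


def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
    # Propagate sharding metadata, e.g. from a weight to a copy of it.
    for name in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        if hasattr(source_tensor, name):
            setattr(destination_tensor, name, getattr(source_tensor, name))


weight = SimpleNamespace()
set_tensor_model_parallel_attributes(weight, True, 0, 1)  # sharded along dim 0
copy = SimpleNamespace()
copy_tensor_model_parallel_attributes(copy, weight)
```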
- core.tensor_parallel.layers._initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1, is_expert=False)
Initialize affine weight for model parallel on GPU.
- core.tensor_parallel.layers._initialize_affine_weight_cpu(
    weight,
    output_size,
    input_size,
    per_partition_size,
    partition_dim,
    init_method,
    stride=1,
    return_master_weight=False,
    *,
    params_dtype=torch.float32,
    rank=None,
    world_size=None,
    skip_set_tensor_parallel_attributes=False,
)
Initialize affine weight for model parallel.
Build the master weight on all processes and scatter the relevant chunk.
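"Build the master weight and scatter the relevant chunk" can be sketched in pure Python as follows. The sketch assumes partition_dim=0, a deterministic init_method so that every rank builds the same master weight, and a strided split in which the master weight is cut into chunks of per_partition_size // stride rows and rank r keeps chunks r, r + world_size, and so on. That strided scheme mirrors Megatron-Core's source but is an assumption here. initialize_affine_weight_cpu_sketch is a hypothetical name.

```python
from typing import Callable, List

Matrix = List[List[float]]


def initialize_affine_weight_cpu_sketch(
    output_size: int,
    input_size: int,
    per_partition_size: int,
    init_method: Callable[[int, int], float],
    stride: int,
    rank: int,
    world_size: int,
) -> Matrix:
    assert per_partition_size * world_size == output_size
    assert per_partition_size % stride == 0
    # 1. Every rank builds the same full master weight [output_size, input_size].
    master_weight = [[init_method(r, c) for c in range(input_size)] for r in range(output_size)]
    # 2. Split along partition_dim=0 into chunks of per_partition_size // stride rows.
    chunk = per_partition_size // stride
    chunks = [master_weight[i:i + chunk] for i in range(0, output_size, chunk)]
    # 3. Rank r keeps chunks r, r + world_size, ... and concatenates them.
    return [row for part in chunks[rank::world_size] for row in part]


row_id = lambda r, c: float(r)  # each entry records its master-weight row index
shard = initialize_affine_weight_cpu_sketch(8, 1, 4, row_id, stride=2, rank=0, world_size=2)
print([row[0] for row in shard])  # rows 0, 1, 4, 5 of the master weight
```

With stride=1 each rank simply gets a contiguous block of rows. A stride greater than 1 gives every rank a piece of each of several logically separate sub-matrices that were packed into one weight.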
- class core.tensor_parallel.layers.VocabParallelEmbedding(
    num_embeddings: int,
    embedding_dim: int,
    *,
    init_method: Callable,
    reduce_scatter_embeddings: bool = False,
    config: megatron.core.model_parallel_config.ModelParallelConfig,
    tp_group: Optional[torch.distributed.ProcessGroup] = None,
)
Bases: torch.nn.Module
Embedding parallelized in the vocabulary dimension.
This is mainly adapted from torch.nn.Embedding and all the default values are kept.
- Parameters:
num_embeddings – vocabulary size.
embedding_dim – size of hidden state.
reduce_scatter_embeddings – Whether to perform a reduce-scatter after the embedding lookup.
- Keyword Arguments:
config – A megatron.core.ModelParallelConfig object
Initialization
- forward(input_)
Forward.
- Parameters:
input_ (torch.Tensor) – Input tensor.
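A pure-Python sketch of a vocabulary-parallel lookup. It assumes each of world_size ranks owns a contiguous, equally sized slice of the vocabulary, which is the usual Megatron-Core scheme but is not stated in this entry. vocab_parallel_lookup is a hypothetical name. The loop over ranks stands in for ranks running concurrently, and the final sum stands in for the all-reduce. With reduce_scatter_embeddings=True, a reduce-scatter would take the place of that sum.

```python
from typing import List

Table = List[List[float]]


def vocab_parallel_lookup(token_ids: List[int], full_table: Table, world_size: int) -> Table:
    vocab_size, hidden = len(full_table), len(full_table[0])
    assert vocab_size % world_size == 0, "vocab must divide evenly across ranks"
    per_rank = vocab_size // world_size

    partial_outputs = []
    for rank in range(world_size):  # each iteration plays the role of one TP rank
        start, end = rank * per_rank, (rank + 1) * per_rank
        shard = full_table[start:end]  # the only rows this rank stores
        out = []
        for t in token_ids:
            if start <= t < end:
                out.append(list(shard[t - start]))  # local lookup with a shifted id
            else:
                out.append([0.0] * hidden)  # owned by another rank: contribute zeros
        partial_outputs.append(out)

    # "All-reduce": sum the partial outputs; exactly one rank is non-zero per token.
    return [
        [sum(p[i][h] for p in partial_outputs) for h in range(hidden)]
        for i in range(len(token_ids))
    ]


table = [[float(v), 10.0 * v] for v in range(6)]  # vocab_size=6, hidden=2
print(vocab_parallel_lookup([0, 3, 5], table, world_size=3))
# → [[0.0, 0.0], [3.0, 30.0], [5.0, 50.0]]
```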
- sharded_state_dict(prefix: str = '', sharded_offsets: Tuple[Tuple[int, int, int]] = (), metadata: Optional[dict] = None)
Non-default implementation for embeddings due to the allow_shape_mismatch param.
- class core.tensor_parallel.layers.LinearWithFrozenWeight
Bases: torch.autograd.Function
Linear operator that does not calculate gradient for weight. This op and LinearWithGradAccumulationAndAsyncCommunication perform mathematically identical forward and DGRAD.
Conceptually this op is the same as torch.nn.functional.linear with weight.requires_grad==False, but in experiments they are not identical mathematically.
- static forward(ctx, input, weight, bias, allreduce_dgrad, tp_group)
Forward with frozen weight.
- static backward(ctx, grad_output)
Backward with frozen weight.
- core.tensor_parallel.layers.linear_with_frozen_weight(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    gradient_accumulation_fusion: bool,
    allreduce_dgrad: bool,
    sequence_parallel: bool,
    tp_group: Optional[torch.distributed.ProcessGroup],
    grad_output_buffer: Optional[List[torch.Tensor]] = None,
    wgrad_deferral_limit: None = None,
    async_grad_allreduce: Optional[bool] = None,
)
Linear layer execution with weight.requires_grad == False.
This function handles linear layers whose weight is frozen (untrainable). In the forward pass it saves only the weight, not the input activations. In the backward pass it performs neither the weight gradient calculation nor the weight gradient all-reduce.
- Parameters:
input (torch.Tensor required) – input like torch.nn.functional.linear
weight (torch.Tensor required) – weight like torch.nn.functional.linear
bias (torch.Tensor optional) – bias like torch.nn.functional.linear
gradient_accumulation_fusion (bool required) – Dummy argument, used to keep the API unified between all forward implementation functions.
allreduce_dgrad (bool required) – Do the all-reduce of input gradients. Here, async and sync all-reduce are the same. If sequence_parallel is True, this must be False, as no all-reduce is performed.
sequence_parallel (bool required) – Indicates that sequence parallelism is used, so the input is all-gathered in the forward pass and the input gradients are reduce-scattered in the backward pass.
tp_group (torch.distributed.ProcessGroup) – The process group to use for tensor parallel operations.
grad_output_buffer (List[torch.Tensor] optional) – Dummy argument, used to keep the API unified between all forward implementation functions.
wgrad_deferral_limit (int optional) – Dummy argument, used to keep the API unified between all forward implementation functions.
async_grad_allreduce (bool optional) – Will be removed with 0.11.0. Please use allreduce_dgrad instead.
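The frozen-weight contract (save only the weight, compute only the input gradient) can be modeled in pure Python. FrozenLinearSketch is hypothetical. It omits bias and all communication. With allreduce_dgrad=True, the real op would additionally all-reduce grad_input across the TP group.

```python
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


class FrozenLinearSketch:
    """Hypothetical pure-Python model of a linear op whose weight is frozen."""

    def forward(self, x: Matrix, weight: Matrix) -> Matrix:
        # Only the weight is saved: without a weight gradient, x is never needed again.
        self.saved_weight = weight
        return matmul(x, transpose(weight))  # y = x @ W^T, as in torch.nn.functional.linear

    def backward(self, grad_output: Matrix) -> Tuple[Matrix, Optional[Matrix]]:
        # DGRAD only: dX = dY @ W. The weight gradient is never computed (None).
        grad_input = matmul(grad_output, self.saved_weight)
        return grad_input, None


op = FrozenLinearSketch()
y = op.forward([[1.0, 2.0]], [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # [[1.0, 2.0, 3.0]]
grad_input, grad_weight = op.backward([[1.0, 1.0, 1.0]])  # [[2.0, 2.0]], None
```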
- class core.tensor_parallel.layers.LinearWithGradAccumulationAndAsyncCommunication
Bases: torch.autograd.Function
See linear_with_grad_accumulation_and_async_allreduce.
- static forward(
    ctx,
    input,
    weight,
    bias,
    gradient_accumulation_fusion,
    allreduce_dgrad,
    sequence_parallel,
    grad_output_buffer,
    wgrad_deferral_limit,
    tp_group,
)
Forward.
- static backward(ctx, grad_output)
Backward.
- core.tensor_parallel.layers.linear_with_grad_accumulation_and_async_allreduce(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    gradient_accumulation_fusion: bool,
    allreduce_dgrad: bool,
    sequence_parallel: bool,
    grad_output_buffer: Optional[List[torch.Tensor]] = None,
    wgrad_deferral_limit: Optional[int] = 0,
    async_grad_allreduce: Optional[bool] = None,
    tp_group: Optional[torch.distributed.ProcessGroup] = None,
)
Linear layer execution with asynchronous communication and gradient accumulation fusion in backprop.
This function can accumulate the result of the backprop calculation directly into an existing gradient buffer, which avoids launching a separate addition kernel after the gradient calculation.
Additionally, the tensor parallel all reduce of the input gradients can be done asynchronously with the calculation of the weight gradients.
In the case of sequence parallelism, the reduce scatter of the input gradients is done asynchronously with the calculation of the weight gradients.
Use of this module requires the environment variable CUDA_DEVICE_MAX_CONNECTIONS=1. A few collective operations, noted in the code, should be scheduled before compute kernels so that communication overlaps with computation. This ordering is needed for a speedup but not for correctness, so the scheduler does not impose it. Setting CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled in the order they are called.
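A hedged sketch of a pre-flight check for the requirement above. check_cuda_device_max_connections is a hypothetical helper, not part of this module. It only inspects the environment and warns, because a wrong setting costs overlap (speed) but not correctness. The variable has to be set before the CUDA context is created, so in practice it is usually exported in the launch script.

```python
import os
import warnings


def check_cuda_device_max_connections(sequence_parallel: bool, allreduce_dgrad: bool) -> bool:
    """Hypothetical check; returns True if launch order will be preserved."""
    needs_overlap = sequence_parallel or allreduce_dgrad
    if needs_overlap and os.environ.get("CUDA_DEVICE_MAX_CONNECTIONS") != "1":
        # Results stay correct, but collectives may not overlap with compute.
        warnings.warn(
            "Set CUDA_DEVICE_MAX_CONNECTIONS=1 so collectives are launched before the "
            "compute kernels they are meant to overlap with."
        )
        return False
    return True


# Must happen before any CUDA work, e.g. `export CUDA_DEVICE_MAX_CONNECTIONS=1`
# in the launcher; setting it from Python only helps if CUDA is not yet initialized.
os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1")
```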
- Parameters:
input (torch.Tensor required) – input like torch.nn.functional.linear
weight (torch.Tensor required) – weight like torch.nn.functional.linear
bias (torch.Tensor optional) – bias like torch.nn.functional.linear
gradient_accumulation_fusion (bool required) – Perform the gradient accumulation fusion. This requires the custom CUDA extension module fused_weight_gradient_mlp_cuda. To use gradient_accumulation_fusion you must install APEX with --cpp_ext and --cuda_ext, for example by running pip install --global-option="--cpp_ext" --global-option="--cuda_ext" . from the APEX source directory. Note that the extension requires CUDA>=11. Otherwise, you must turn off gradient accumulation fusion.
allreduce_dgrad (bool required) – Do the all-reduce of input gradients. The all-reduce is done asynchronously with the computation of weight gradients. If sequence_parallel is True, this must be False, as no all-reduce is performed.
sequence_parallel (bool required) – Indicates that sequence parallelism is used, so the input is all-gathered in the forward pass and the input gradients are reduce-scattered in the backward pass.
tp_group (torch.distributed.ProcessGroup required) – The process group to use for tensor parallel operations.
grad_output_buffer (List[torch.Tensor] optional) – Buffer used to save output gradients when embedding table wgrad compute is deferred. Defaults to None.
wgrad_deferral_limit (int optional) – Limit on the number of micro-batches for which embedding weight gradient GEMM should be deferred. Disable by setting this to 0. Defaults to 0.
async_grad_allreduce (bool optional) – Will be removed with 0.11.0. Please use allreduce_dgrad instead.
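What "gradient accumulation fusion" saves can be shown in pure Python. The unfused path materializes a temporary dW = dY^T X and then runs a separate addition into the gradient buffer. The fused path adds each dW element straight into the buffer. main_grad and the function names are hypothetical. In the real code, a kernel from the fused_weight_gradient_mlp_cuda extension does the fused accumulation on GPU. The sketch does not model the asynchronous dgrad all-reduce.

```python
from typing import List

Matrix = List[List[float]]


def weight_grad(grad_output: Matrix, inp: Matrix) -> Matrix:
    # dW = dY^T @ X, shape [out_features, in_features], summed over tokens.
    return [[sum(grad_output[n][o] * inp[n][i] for n in range(len(inp)))
             for i in range(len(inp[0]))]
            for o in range(len(grad_output[0]))]


def unfused_accumulate(main_grad: Matrix, grad_output: Matrix, inp: Matrix) -> None:
    dw = weight_grad(grad_output, inp)  # step 1: allocate a temporary dW
    for o, row in enumerate(dw):        # step 2: separate "addition kernel"
        for i, v in enumerate(row):
            main_grad[o][i] += v


def fused_accumulate(main_grad: Matrix, grad_output: Matrix, inp: Matrix) -> None:
    # Single pass: each dW element goes straight into the buffer, no temporary.
    for o in range(len(main_grad)):
        for i in range(len(main_grad[0])):
            main_grad[o][i] += sum(grad_output[n][o] * inp[n][i] for n in range(len(inp)))


x = [[1.0, 2.0], [3.0, 4.0]]                  # 2 tokens, in_features=2
dy = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]       # 2 tokens, out_features=3
a = [[0.0, 0.0] for _ in range(3)]
b = [[0.0, 0.0] for _ in range(3)]
for _ in range(2):  # two micro-batches accumulate into the same buffer
    unfused_accumulate(a, dy, x)
    fused_accumulate(b, dy, x)
```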
- class core.tensor_parallel.layers.ColumnParallelLinear(
    input_size,
    output_size,
    *,
    config: megatron.core.model_parallel_config.ModelParallelConfig,
    init_method: Callable,
    bias=True,
    gather_output=False,
    stride=1,
    keep_master_weight_for_test=False,
    skip_bias_add=False,
    skip_weight_param_allocation: bool = False,
    embedding_activation_buffer: Optional[List[torch.Tensor]] = None,
    grad_output_buffer: Optional[List[torch.Tensor]] = None,
    is_expert: bool = False,
    tp_comm_buffer_name: str = None,
    disable_grad_reduce: bool = False,
    tp_group: Optional[torch.distributed.ProcessGroup] = None,
)
Bases: torch.nn.Module
Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along its second dimension as A = [A_1, …, A_p].
- Parameters:
input_size – first dimension of matrix A.
output_size – second dimension of matrix A.
bias – If True, add bias.
gather_output – If True, call all-gather on the output so that Y is available on all GPUs; otherwise, every GPU keeps only its own output Y_i = XA_i.
init_method – method to initialize weights. Note that bias is always set to zero.
stride – For the strided linear layers.
keep_master_weight_for_test – This was added for testing and should be set to False. It returns the master weights used for initialization.
skip_bias_add – If True, do not add the bias term, instead return it to be added by the caller. This enables performance optimizations where bias can be fused with other elementwise operations.
skip_weight_param_allocation – If True, the weight parameter is not allocated and must be passed as the keyword argument weight during the forward pass. Note that this does not affect bias, which will be allocated if bias is True. Defaults to False.
embedding_activation_buffer – This buffer holds the input activations of the final embedding linear layer on the last pipeline stage when defer_embedding_wgrad_compute is enabled.
grad_output_buffer – This buffer holds the gradient outputs of the final embedding linear layer on the last pipeline stage when defer_embedding_wgrad_compute is enabled.
is_expert – If True, the layer is treated as an MoE expert layer.
config – ModelParallelConfig object
tp_comm_buffer_name – Communication buffer name is not used in non-Transformer-Engine modules.
disable_grad_reduce – If True, reduction of output gradients across tensor-parallel ranks is disabled. Defaults to False. This feature is used by the LoRA adapter in NeMo to delay the reduction and fuse it with the reduction of other gradients for performance.
Initialization
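The column-parallel decomposition Y = XA = [XA_1, …, XA_p] can be checked in pure Python. column_parallel_forward is a hypothetical helper. The loop over shards stands in for p ranks, and the row concatenation stands in for the all-gather that gather_output=True triggers. Bias is omitted. According to sharded_state_dict below, it would be sharded the same way as the output columns.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def column_parallel_forward(x: Matrix, a: Matrix, world_size: int, gather_output: bool):
    n_cols = len(a[0])  # output_size: columns of A
    assert n_cols % world_size == 0, "output_size must divide evenly across ranks"
    per = n_cols // world_size
    # A = [A_1, ..., A_p]: rank r stores a contiguous block of columns.
    shards = [[row[r * per:(r + 1) * per] for row in a] for r in range(world_size)]
    partials = [matmul(x, a_r) for a_r in shards]  # Y_r = X A_r, computed on rank r
    if not gather_output:
        return partials  # each rank keeps only its own Y_r
    # "All-gather" along the last dimension: Y = [Y_1, ..., Y_p]
    return [sum((y_r[n] for y_r in partials), []) for n in range(len(x))]


x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
a = [[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 2.0]]
print(column_parallel_forward(x, a, world_size=2, gather_output=True))
# → [[4.0, 5.0, 4.0, 7.0], [10.0, 11.0, 13.0, 16.0]]
```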
- _forward_impl(input, weight, *args, **kwargs)
- forward(input_: torch.Tensor, weight: Optional[torch.Tensor] = None, runtime_gather_output: Optional[bool] = None)
Forward of ColumnParallelLinear.
- Parameters:
input_ – 3D tensor whose order of dimensions is [sequence, batch, hidden].
weight (optional) – Weight tensor to use; compulsory when skip_weight_param_allocation is True.
runtime_gather_output (bool) – Gather output at runtime. The default, None, means the gather_output arg in the constructor will be used.
- Returns:
output
bias
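A sketch of how a caller might consume the returned (output, bias) pair. It assumes, as the skip_bias_add description implies, that bias is returned only when skip_bias_add=True and is None otherwise. layer_forward_sketch and caller are hypothetical. The bias add and a GeLU are done in one elementwise pass here, which is the kind of fusion skip_bias_add is meant to enable.

```python
import math
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def gelu(v: float) -> float:
    # Exact GeLU: v * Phi(v), where Phi is the standard normal CDF.
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def layer_forward_sketch(y_no_bias: Matrix, bias: List[float],
                         skip_bias_add: bool) -> Tuple[Matrix, Optional[List[float]]]:
    """Hypothetical stand-in for the (output, bias) pair returned by forward()."""
    if skip_bias_add:
        return y_no_bias, bias  # the caller is responsible for adding the bias
    return [[v + b for v, b in zip(row, bias)] for row in y_no_bias], None


def caller(y_no_bias: Matrix, bias: List[float], skip_bias_add: bool) -> Matrix:
    output, output_bias = layer_forward_sketch(y_no_bias, bias, skip_bias_add)
    if output_bias is not None:
        # Bias add and activation in a single elementwise pass.
        return [[gelu(v + b) for v, b in zip(row, output_bias)] for row in output]
    return [[gelu(v) for v in row] for row in output]
```

Both paths produce the same values; skipping the bias add only moves where it happens.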
- sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)
Sharding along axis 0, bias sharded.
- set_extra_state(state: Any)
Extra state is ignored.
- get_extra_state() → None
Keep compatibility with TE state dict.
- __repr__()
- class core.tensor_parallel.layers.RowParallelLinear(
    input_size: int,
    output_size: int,
    *,
    config: megatron.core.model_parallel_config.ModelParallelConfig,
    init_method: Callable,
    bias: bool,
    input_is_parallel: bool,
    skip_bias_add: bool,
    stride: int = 1,
    keep_master_weight_for_test: bool = False,
    is_expert: bool = False,
    tp_comm_buffer_name: str = None,
    tp_group: Optional[torch.distributed.ProcessGroup] = None,
)
Bases: torch.nn.Module
Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along its first dimension and X along its second dimension: A = transpose([A_1, …, A_p]), X = [X_1, …, X_p].
- Parameters:
input_size – first dimension of matrix A.
output_size – second dimension of matrix A.
bias – If true, add bias. Note that bias is not parallelized.
input_is_parallel – If true, we assume that the input is already split across the GPUs and we do not split again.
init_method – method to initialize weights. Note that bias is always set to zero.
stride – For the strided linear layers.
keep_master_weight_for_test – This was added for testing and should be set to False. It returns the master weights used for initialization.
skip_bias_add – If True, do not add the bias term, instead return it to be added by the caller. This enables performance optimizations where bias can be fused with other elementwise operations.
is_expert – If True, the layer is treated as an MoE expert layer.
tp_comm_buffer_name – Communication buffer name. Not used in non-Transformer-Engine modules.
config – ModelParallelConfig object
Initialization
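The row-parallel decomposition Y = XA + b = X_1 A_1 + … + X_p A_p + b can be checked in pure Python. row_parallel_forward is a hypothetical helper. Slicing x models the split that happens when input_is_parallel=False; with True, each rank would already hold its X_r. The sum over partial products stands in for the all-reduce. The bias, which is not parallelized, is added once after the reduction.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def row_parallel_forward(x: Matrix, a: Matrix, bias: List[float], world_size: int) -> Matrix:
    k = len(a)  # input_size: rows of A == columns of X
    assert k % world_size == 0, "input_size must divide evenly across ranks"
    per = k // world_size
    partials = []
    for r in range(world_size):  # each iteration plays the role of one TP rank
        x_r = [row[r * per:(r + 1) * per] for row in x]  # X_r: slice of X's last dim
        a_r = a[r * per:(r + 1) * per]  # A_r: the matching rows of A
        partials.append(matmul(x_r, a_r))  # partial product, same shape as Y
    # "All-reduce" (sum) of the partial products, then add the unsharded bias once.
    return [[sum(p[n][j] for p in partials) + bias[j] for j in range(len(a[0]))]
            for n in range(len(x))]


x = [[1.0, 2.0, 3.0, 4.0]]
a = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]]
print(row_parallel_forward(x, a, [0.5, -0.5], world_size=2))  # → [[12.5, 4.5]]
```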
- _forward_impl(input, weight, *args, **kwargs)
- forward(input_)
Forward of RowParallelLinear.
- Parameters:
input_ – 3D tensor whose order of dimensions is [sequence, batch, hidden].
- Returns:
output
bias
- sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)
Sharding along axis 1, bias not sharded.
- set_extra_state(state: Any)
Extra state is ignored.
- get_extra_state() → None
Keep compatibility with TE state dict.
- __repr__()