core.tensor_parallel.mappings#
Module Contents#
Classes#
| Class | Description |
| --- | --- |
| _CopyToModelParallelRegion | Pass the input to the model parallel region. |
| _ReduceFromModelParallelRegion | All-reduce the input from the model parallel region. |
| _ScatterToModelParallelRegion | Split the input and keep only the chunk corresponding to the rank. |
| _GatherFromModelParallelRegion | Gather the input from the model parallel region and concatenate. |
| _ScatterToSequenceParallelRegion | Split the input and keep only the chunk corresponding to the rank. |
| _GatherFromSequenceParallelRegion | Gather the input from the sequence parallel region and concatenate. |
| _ReduceScatterToSequenceParallelRegion | Reduce-scatter the input from the model parallel region. |
| _AllGatherFromTensorParallelRegion | Gather the input from the model parallel region and concatenate. |
| _ReduceScatterToTensorParallelRegion | Reduce-scatter the input from the model parallel region. |
Functions#
| Function | Description |
| --- | --- |
| _reduce | All-reduce the input tensor across model parallel group. |
| _split_along_last_dim | Split the tensor along its last dimension and keep the corresponding slice. |
| _split_along_first_dim | Split the tensor along its first dimension and keep the corresponding slice. |
| _gather_along_last_dim | Gather tensors and concatenate along the last dimension. |
| _reduce_scatter_along_last_dim | Reduce-scatter tensors on the last dimension. |
| _gather_along_first_dim | Gather tensors and concatenate along the first dimension. |
| _reduce_scatter_along_first_dim | Reduce-scatter the input tensor across model parallel group. |
| copy_to_tensor_model_parallel_region | Wrapper for autograd function: forward: copy, backward: all-reduce |
| reduce_from_tensor_model_parallel_region | Wrapper for autograd function: forward: all-reduce, backward: copy |
| scatter_to_tensor_model_parallel_region | Wrapper for autograd function: forward: RS, backward: AG |
| gather_from_tensor_model_parallel_region | Wrapper for autograd function: forward: AG, backward: split |
| scatter_to_sequence_parallel_region | Wrapper for autograd function: forward: split, backward: AG |
| gather_from_sequence_parallel_region | Wrapper for autograd function: forward: AG, backward: RS |
| reduce_scatter_to_sequence_parallel_region | Wrapper for autograd function: forward: RS, backward: AG |
| all_gather_last_dim_from_tensor_parallel_region | Wrapper for autograd function: forward: AG, backward: RS |
| reduce_scatter_last_dim_to_tensor_parallel_region | Wrapper for autograd function: forward: RS, backward: AG |
| all_to_all | Wrapper for autograd function |
| all_to_all_sp2hp | Perform an AlltoAll communication on the tensor parallel group, transforming the input tensor from shape [num_tokens/TP, H] to [num_tokens, H/TP]. |
| all_to_all_hp2sp | Perform an AlltoAll communication on the tensor parallel group, transforming the input tensor from shape [num_tokens, H/TP] to [num_tokens/TP, H]. |
API#
- core.tensor_parallel.mappings._reduce(input_, group)#
All-reduce the input tensor across model parallel group.
- core.tensor_parallel.mappings._split_along_last_dim(input_, group)#
Split the tensor along its last dimension and keep the corresponding slice.
- core.tensor_parallel.mappings._split_along_first_dim(input_, group)#
Split the tensor along its first dimension and keep the corresponding slice.
- core.tensor_parallel.mappings._gather_along_last_dim(input_, group)#
Gather tensors and concatenate along the last dimension.
- core.tensor_parallel.mappings._reduce_scatter_along_last_dim(input_, group)#
Reduce-scatter tensors on the last dimension.
- core.tensor_parallel.mappings._gather_along_first_dim(input_, group, output_split_sizes=None, use_global_buffer=False)#
Gather tensors and concatenate along the first dimension.
- Parameters:
input_ (torch.Tensor) – A tensor to be gathered.
output_split_sizes (List[int], optional) – A list specifying the sizes of the output splits along the first dimension. If None, equal splitting is assumed. Default: None.
- Returns:
Gathered tensor.
- Return type:
torch.Tensor
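For intuition about the shape semantics, the following single-process sketch emulates a first-dimension gather for a hypothetical 4-rank group, including an uneven split like the one output_split_sizes describes. It is plain PyTorch, not a call into this module:

```python
import torch

# Emulate 4 tensor-parallel ranks, each holding a chunk of rows.
# With output_split_sizes=None every rank contributes the same number of
# rows; here we use uneven chunks to mimic output_split_sizes=[1, 2, 3, 2].
output_split_sizes = [1, 2, 3, 2]
hidden = 8
per_rank_chunks = [torch.randn(n, hidden) for n in output_split_sizes]

# A gather along the first dimension conceptually all-gathers the chunks and
# concatenates them along dim 0, so every rank ends up with the full tensor.
gathered = torch.cat(per_rank_chunks, dim=0)
assert gathered.shape == (sum(output_split_sizes), hidden)
print(gathered.shape)  # torch.Size([8, 8])
```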
- core.tensor_parallel.mappings._reduce_scatter_along_first_dim(input_, group, input_split_sizes=None, use_global_buffer=False)#
Reduce-scatter the input tensor across model parallel group.
- Parameters:
input_ (torch.Tensor) – The input tensor to be reduce-scattered.
input_split_sizes (List[int], optional) – A list specifying the sizes of the input splits along the first dimension for each rank. If None, equal splitting is assumed. Default: None.
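The corresponding reduce-scatter can be emulated the same way: sum the per-rank inputs element-wise, then split the result along the first dimension according to input_split_sizes. Again a plain-PyTorch sketch, not the library code:

```python
import torch

# Emulate 4 ranks that each hold a full [num_rows, hidden] tensor.
world_size, hidden = 4, 8
input_split_sizes = [1, 2, 3, 2]          # rows assigned to each rank
num_rows = sum(input_split_sizes)
per_rank_inputs = [torch.randn(num_rows, hidden) for _ in range(world_size)]

# Reduce-scatter along the first dimension: sum the inputs element-wise,
# then split the result along dim 0 so rank r keeps its slice.
reduced = torch.stack(per_rank_inputs).sum(dim=0)
per_rank_outputs = list(torch.split(reduced, input_split_sizes, dim=0))
for rank, out in enumerate(per_rank_outputs):
    print(rank, tuple(out.shape))  # rank r gets [input_split_sizes[r], hidden]
```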
- class core.tensor_parallel.mappings._CopyToModelParallelRegion#
Bases: torch.autograd.Function
Pass the input to the model parallel region.
- static symbolic(graph, input_, group)#
Symbolic function for tracing.
- static forward(ctx, input_, group)#
Forward function.
- static backward(ctx, grad_output)#
Backward function.
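To make the forward/backward pairing concrete, here is a minimal single-process analogue of this class, not the library implementation: the forward pass is an identity and the backward pass all-reduces the gradient over a process group. The single-rank gloo group, address, and port exist only so the sketch runs standalone.

```python
import os
import torch
import torch.distributed as dist


class CopyToGroup(torch.autograd.Function):
    """Illustrative analogue: forward = identity, backward = all-reduce."""

    @staticmethod
    def forward(ctx, input_, group):
        ctx.group = group
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        grad = grad_output.clone()
        dist.all_reduce(grad, group=ctx.group)   # sum gradients across ranks
        return grad, None                        # no gradient for the group argument


if __name__ == "__main__":
    # Single-process "group" so the example runs without torchrun.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    x = torch.randn(2, 4, requires_grad=True)
    y = CopyToGroup.apply(x, dist.group.WORLD)
    y.sum().backward()
    print(x.grad.shape)  # gradient has the same shape as x
    dist.destroy_process_group()
```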
- class core.tensor_parallel.mappings._ReduceFromModelParallelRegion#
Bases: torch.autograd.Function
All-reduce the input from the model parallel region.
- static symbolic(graph, input_, group)#
Symbolic function for tracing.
- static forward(ctx, input_, group)#
Forward function.
- static backward(ctx, grad_output)#
Backward function.
- class core.tensor_parallel.mappings._ScatterToModelParallelRegion#
Bases: torch.autograd.Function
Split the input and keep only the chunk corresponding to the rank.
- static symbolic(graph, input_, group)#
Symbolic function for tracing.
- static forward(ctx, input_, group)#
Forward function.
- static backward(ctx, grad_output)#
Backward function.
- class core.tensor_parallel.mappings._GatherFromModelParallelRegion#
Bases: torch.autograd.Function
Gather the input from the model parallel region and concatenate.
- static symbolic(graph, input_, group)#
Symbolic function for tracing.
- static forward(ctx, input_, group)#
Forward function.
- static backward(ctx, grad_output)#
Backward function.
- class core.tensor_parallel.mappings._ScatterToSequenceParallelRegion#
Bases: torch.autograd.Function
Split the input and keep only the chunk corresponding to the rank.
- static symbolic(graph, input_, group)#
Symbolic function for tracing.
- static forward(ctx, input_, group)#
Forward function.
- static backward(ctx, grad_output)#
Backward function.
- class core.tensor_parallel.mappings._GatherFromSequenceParallelRegion#
Bases: torch.autograd.Function
Gather the input from the sequence parallel region and concatenate.
- static symbolic(graph, input_, group, tensor_parallel_output_grad=True, output_split_sizes=None, use_global_buffer=False)#
Symbolic function for tracing.
- static forward(ctx, input_, group, tensor_parallel_output_grad=True, output_split_sizes=None, use_global_buffer=False)#
Forward function.
- static backward(ctx, grad_output)#
Backward function.
- class core.tensor_parallel.mappings._ReduceScatterToSequenceParallelRegion#
Bases: torch.autograd.Function
Reduce-scatter the input from the model parallel region.
- static symbolic(graph, input_, group, input_split_sizes=None, use_global_buffer=False)#
Symbolic function for tracing.
- static forward(ctx, input_, group, input_split_sizes=None, use_global_buffer=False)#
Forward function.
- static backward(ctx, grad_output)#
Backward function.
- class core.tensor_parallel.mappings._AllGatherFromTensorParallelRegion#
Bases: torch.autograd.Function
Gather the input from the model parallel region and concatenate.
- static symbolic(graph, input_, group)#
Symbolic function for tracing.
- static forward(ctx, input_, group)#
Forward function.
- static backward(ctx, grad_output)#
Backward function.
- class core.tensor_parallel.mappings._ReduceScatterToTensorParallelRegion#
Bases: torch.autograd.Function
Reduce-scatter the input from the model parallel region.
- static symbolic(graph, input_, group)#
Symbolic function for tracing.
- static forward(ctx, input_, group)#
Forward function.
- static backward(ctx, grad_output)#
Backward function.
- class core.tensor_parallel.mappings._AllToAll#
Bases: torch.autograd.Function
- static forward(ctx, group, input, output_split_sizes, input_split_sizes)#
Forward function.
- static backward(ctx, *grad_output)#
Backward function.
- core.tensor_parallel.mappings.copy_to_tensor_model_parallel_region(input_, group=None)#
Wrapper for autograd function: forward: copy, backward: all-reduce
- core.tensor_parallel.mappings.reduce_from_tensor_model_parallel_region(input_, group=None)#
Wrapper for autograd function: forward: all-reduce, backward: copy
- core.tensor_parallel.mappings.scatter_to_tensor_model_parallel_region(input_, group=None)#
Wrapper for autograd function: forward: RS, backward: AG
- core.tensor_parallel.mappings.gather_from_tensor_model_parallel_region(input_, group=None)#
Wrapper for autograd function: forward: AG, backward: split
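A typical use of the wrappers above is a column-parallel linear layer: copy the input into the tensor-parallel region (so its gradient is all-reduced), multiply by the local weight shard, then all-gather the output shards along the last dimension. The sketch below is illustrative only; it assumes Megatron-LM is installed (the megatron.core import prefix is an assumption of this example), that the script is launched with torchrun with one process per GPU, and that NCCL is available.

```python
# Illustrative sketch -- launch with: torchrun --nproc_per_node=2 tp_sketch.py
import os
import torch
import torch.distributed as dist
from megatron.core import parallel_state
from megatron.core.tensor_parallel.mappings import (
    copy_to_tensor_model_parallel_region,
    gather_from_tensor_model_parallel_region,
)

dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
parallel_state.initialize_model_parallel(tensor_model_parallel_size=dist.get_world_size())

tp = parallel_state.get_tensor_model_parallel_world_size()
hidden, out_features = 16, 32

# Column-parallel weight shard: each rank holds out_features / tp output columns.
w_shard = torch.randn(hidden, out_features // tp, device="cuda", requires_grad=True)
x = torch.randn(4, hidden, device="cuda", requires_grad=True)

# forward: identity copy; backward: all-reduce of dL/dx across the TP group.
x_par = copy_to_tensor_model_parallel_region(x)
y_shard = x_par @ w_shard                       # [4, out_features // tp] per rank

# forward: all-gather shards along the last dim; backward: split the gradient.
y = gather_from_tensor_model_parallel_region(y_shard)
y.sum().backward()
print(dist.get_rank(), y.shape, x.grad.shape)   # y is [4, out_features] on every rank
```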
- core.tensor_parallel.mappings.scatter_to_sequence_parallel_region(input_, group=None)#
Wrapper for autograd function: forward: split, backward: AG
- core.tensor_parallel.mappings.gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True, group=None, output_split_sizes=None, use_global_buffer=False)#
Wrapper for autograd function: forward: AG, backward: RS
- core.tensor_parallel.mappings.reduce_scatter_to_sequence_parallel_region(input_, group=None, input_split_sizes=None, use_global_buffer=False)#
Wrapper for autograd function: forward: RS, backward: AG
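In sequence parallelism these two wrappers are used as a pair: activations sharded along the sequence (first) dimension are all-gathered before a tensor-parallel block, and the block's partial outputs are reduce-scattered back into sequence shards. The round trip can be emulated on a single process (plain PyTorch, not a call into this module):

```python
import torch

tp, seq_per_rank, hidden = 4, 2, 8

# Each rank holds a sequence shard of shape [seq/TP, hidden].
seq_shards = [torch.randn(seq_per_rank, hidden) for _ in range(tp)]

# gather_from_sequence_parallel_region: all-gather along dim 0,
# so every rank sees the full [seq, hidden] activation.
full_seq = torch.cat(seq_shards, dim=0)

# A tensor-parallel block then produces one partial result per rank
# (scaled copies here, just to keep the sketch simple).
partials = [full_seq * (r + 1) for r in range(tp)]

# reduce_scatter_to_sequence_parallel_region: sum the partials and split
# along dim 0, so rank r keeps rows [r*seq_per_rank : (r+1)*seq_per_rank].
summed = torch.stack(partials).sum(dim=0)
out_shards = list(torch.chunk(summed, tp, dim=0))
print([tuple(s.shape) for s in out_shards])  # 4 shards of (2, 8)
```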
- core.tensor_parallel.mappings.all_gather_last_dim_from_tensor_parallel_region(input_, group=None)#
Wrapper for autograd function: forward: AG, backward: RS
- core.tensor_parallel.mappings.reduce_scatter_last_dim_to_tensor_parallel_region(input_, group=None)#
Wrapper for autograd function: forward: RS, backward: AG
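These two last-dimension wrappers mirror each other. A single-process emulation of their shape behavior (plain PyTorch, not the library code):

```python
import torch

tp, tokens, hidden_per_rank = 4, 2, 4

# Each rank holds a [tokens, H/TP] shard.
shards = [torch.randn(tokens, hidden_per_rank) for _ in range(tp)]

# all_gather_last_dim_from_tensor_parallel_region: concatenate the shards
# along the last dim, giving every rank the full [tokens, H] tensor.
full = torch.cat(shards, dim=-1)

# reduce_scatter_last_dim_to_tensor_parallel_region: sum one full tensor per
# rank, then split along the last dim so rank r keeps its H/TP columns.
partials = [full * (r + 1) for r in range(tp)]           # stand-in partial results
reduced = torch.stack(partials).sum(dim=0)
out_shards = list(torch.chunk(reduced, tp, dim=-1))
print(full.shape, [tuple(s.shape) for s in out_shards])  # [2, 16], 4 x (2, 4)
```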
- core.tensor_parallel.mappings.all_to_all(group, input_, output_split_sizes_=None, input_split_sizes=None)#
Wrapper for autograd function
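The exchange pattern behind _AllToAll and this wrapper is: rank r sends its i-th input split to rank i and receives the r-th split of every rank's input. A single-process emulation with uneven split sizes (illustrative only; the per-rank input_split_sizes matrix below is a construct of this example, not a parameter of the real API):

```python
import torch

world_size, hidden = 3, 4

# input_split_sizes[r][i] = number of rows rank r sends to rank i.
input_split_sizes = [[1, 2, 1], [2, 1, 2], [1, 1, 1]]
inputs = [torch.randn(sum(sz), hidden) for sz in input_split_sizes]

# Split each rank's input into per-destination chunks...
chunks = [list(torch.split(inp, sz, dim=0))
          for inp, sz in zip(inputs, input_split_sizes)]

# ...then rank i's output is the concatenation of the chunks addressed to it.
outputs = [torch.cat([chunks[r][i] for r in range(world_size)], dim=0)
           for i in range(world_size)]

# Rank i receives input_split_sizes[r][i] rows from each rank r.
print([tuple(o.shape) for o in outputs])  # [(4, 4), (4, 4), (4, 4)]
```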
- core.tensor_parallel.mappings.all_to_all_sp2hp(input_, group=None)#
Perform an AlltoAll communication on the tensor parallel group, transforming the input tensor from shape [num_tokens/TP, H] to [num_tokens, H/TP].
- Parameters:
input_ (torch.Tensor) – The input tensor which has been distributed along the sequence dimension.
group (torch.distributed.ProcessGroup, optional) – The process group to work on. If None, the tensor model parallel group will be used.
- Returns:
The output tensor with shape [num_tokens, H/TP].
- Return type:
torch.Tensor
- core.tensor_parallel.mappings.all_to_all_hp2sp(input_, group=None)#
Perform an AlltoAll communication on the tensor parallel group, transforming the input tensor from shape [num_tokens, H/TP] to [num_tokens/TP, H].
- Parameters:
input_ (torch.Tensor) – The input tensor which has been distributed along the hidden dimension.
group (torch.distributed.ProcessGroup, optional) – The process group to work on. If None, the tensor model parallel group will be used.
- Returns:
The output tensor with shape [num_tokens/TP, H].
- Return type:
torch.Tensor
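The shape bookkeeping of all_to_all_sp2hp and all_to_all_hp2sp can be checked with a single-process emulation (plain PyTorch, not a call into this module): sp2hp trades a sharded token dimension for a sharded hidden dimension, and hp2sp undoes it.

```python
import torch

tp, num_tokens, hidden = 4, 8, 16

# Sequence-parallel layout: rank r holds [num_tokens/TP, H].
sp_shards = [torch.randn(num_tokens // tp, hidden) for _ in range(tp)]

# sp2hp: every rank sends the i-th H/TP slice of its token chunk to rank i,
# so rank i ends up with all tokens but only hidden slice i: [num_tokens, H/TP].
hp_shards = [
    torch.cat([shard[:, i * (hidden // tp):(i + 1) * (hidden // tp)]
               for shard in sp_shards], dim=0)
    for i in range(tp)
]
assert hp_shards[0].shape == (num_tokens, hidden // tp)

# hp2sp: the inverse exchange restores the [num_tokens/TP, H] layout.
sp_again = [
    torch.cat([hp[r * (num_tokens // tp):(r + 1) * (num_tokens // tp), :]
               for hp in hp_shards], dim=1)
    for r in range(tp)
]
assert all(torch.equal(a, b) for a, b in zip(sp_again, sp_shards))
print("round trip ok")
```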