core.tensor_parallel.mappings#

Module Contents#

Classes#

_CopyToModelParallelRegion

Pass the input to the model parallel region.

_ReduceFromModelParallelRegion

All-reduce the input from the model parallel region.

_ScatterToModelParallelRegion

Split the input and keep only the chunk corresponding to the rank.

_GatherFromModelParallelRegion

Gather the input from the model parallel region and concatenate.

_ScatterToSequenceParallelRegion

Split the input and keep only the chunk corresponding to the rank.

_GatherFromSequenceParallelRegion

Gather the input from the sequence parallel region and concatenate.

_ReduceScatterToSequenceParallelRegion

Reduce-scatter the input from the model parallel region.

_AllGatherFromTensorParallelRegion

Gather the input from the model parallel region and concatenate.

_ReduceScatterToTensorParallelRegion

Reduce-scatter the input from the model parallel region.

_AllToAll

Functions#

_reduce

All-reduce the input tensor across the model parallel group.

_split_along_last_dim

Split the tensor along its last dimension and keep the corresponding slice.

_split_along_first_dim

Split the tensor along its first dimension and keep the corresponding slice.

_gather_along_last_dim

Gather tensors and concatenate along the last dimension.

_reduce_scatter_along_last_dim

Reduce-scatter tensors on the last dimension.

_gather_along_first_dim

Gather tensors and concatenate along the first dimension.

_reduce_scatter_along_first_dim

Reduce-scatter the input tensor across the model parallel group.

copy_to_tensor_model_parallel_region

Wrapper for autograd function: forward: copy, backward: all-reduce

reduce_from_tensor_model_parallel_region

Wrapper for autograd function: forward: all-reduce, backward: copy

scatter_to_tensor_model_parallel_region

Wrapper for autograd function: forward: split, backward: AG

gather_from_tensor_model_parallel_region

Wrapper for autograd function: forward: AG, backward: split

scatter_to_sequence_parallel_region

Wrapper for autograd function: forward: split, backward: AG

gather_from_sequence_parallel_region

Wrapper for autograd function: forward: AG, backward: RS

reduce_scatter_to_sequence_parallel_region

Wrapper for autograd function: forward: RS, backward: AG

all_gather_last_dim_from_tensor_parallel_region

Wrapper for autograd function: forward: AG, backward: RS

reduce_scatter_last_dim_to_tensor_parallel_region

Wrapper for autograd function: forward: RS, backward: AG

all_to_all

Wrapper for autograd function

all_to_all_sp2hp

Perform all-to-all communication on the tensor parallel group, transforming the input tensor from shape [num_tokens/TP, H] to [num_tokens, H/TP].

all_to_all_hp2sp

Perform all-to-all communication on the tensor parallel group, transforming the input tensor from shape [num_tokens, H/TP] to [num_tokens/TP, H].

API#

core.tensor_parallel.mappings._reduce(input_, group)#

All-reduce the input tensor across the model parallel group.
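
For orientation, a minimal sketch of what this helper amounts to, assuming an already-initialized torch.distributed process group; the function name below is illustrative, not part of the module:

```python
import torch
import torch.distributed as dist

def reduce_sketch(input_: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    """All-reduce `input_` in place across the given model-parallel group."""
    # A group of size 1 needs no communication.
    if dist.get_world_size(group=group) == 1:
        return input_
    dist.all_reduce(input_, group=group)
    return input_
```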

core.tensor_parallel.mappings._split_along_last_dim(input_, group)#

Split the tensor along its last dimension and keep the corresponding slice.

core.tensor_parallel.mappings._split_along_first_dim(input_, group)#

Split the tensor along its first dimension and keep the corresponding slice.
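
Both split helpers are purely local (no communication): each rank simply keeps its own chunk. A hedged sketch of the last-dimension variant; the first-dimension variant is the same with dim=0. The helper name and the divisibility assumption are illustrative:

```python
import torch
import torch.distributed as dist

def split_along_last_dim_sketch(input_: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    """Keep only this rank's chunk of the last dimension (no communication)."""
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return input_
    rank = dist.get_rank(group=group)
    # Assumes the last dimension divides evenly by the group size.
    chunks = torch.chunk(input_, world_size, dim=-1)
    return chunks[rank].contiguous()
```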

core.tensor_parallel.mappings._gather_along_last_dim(input_, group)#

Gather tensors and concatenate along the last dimension.
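
A hedged sketch of this gather, assuming equal shard sizes and an initialized process group; the function name is illustrative:

```python
import torch
import torch.distributed as dist

def gather_along_last_dim_sketch(input_: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    """All-gather the per-rank shards and concatenate them along the last dimension."""
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return input_
    shards = [torch.empty_like(input_) for _ in range(world_size)]
    dist.all_gather(shards, input_.contiguous(), group=group)
    return torch.cat(shards, dim=-1).contiguous()
```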

core.tensor_parallel.mappings._reduce_scatter_along_last_dim(input_, group)#

Reduce-scatter tensors on the last dimension.

core.tensor_parallel.mappings._gather_along_first_dim(
input_,
group,
output_split_sizes=None,
use_global_buffer=False,
)#

Gather tensors and concatenate along the first dimension.

Parameters:
  • input_ (torch.Tensor) – The tensor to be gathered.

  • output_split_sizes (List[int], optional) – A list specifying the sizes of the output splits along the first dimension. If None, equal splitting is assumed. Default: None.

Returns:

Gathered tensor.

Return type:

torch.Tensor
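
A hedged sketch of the equal-split case only; output_split_sizes and use_global_buffer are omitted, and the function name is illustrative:

```python
import torch
import torch.distributed as dist

def gather_along_first_dim_sketch(input_: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    """All-gather shards into one tensor whose first dimension grows by the group size."""
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return input_
    out_shape = list(input_.shape)
    out_shape[0] *= world_size
    output = torch.empty(out_shape, dtype=input_.dtype, device=input_.device)
    # Equal-sized shards only; uneven `output_split_sizes` would need per-rank buffers.
    dist.all_gather_into_tensor(output, input_.contiguous(), group=group)
    return output
```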

core.tensor_parallel.mappings._reduce_scatter_along_first_dim(
input_,
group,
input_split_sizes=None,
use_global_buffer=False,
)#

Reduce-scatter the input tensor across the model parallel group.

Parameters:
  • input_ (torch.Tensor) – The input tensor to be reduce-scattered.

  • input_split_sizes (List[int], optional) – A list specifying the sizes of the input splits along the first dimension for each rank. If None, equal splitting is assumed. Default: None.
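
A hedged sketch of the equal-split case; input_split_sizes and use_global_buffer are omitted, and the function name is illustrative:

```python
import torch
import torch.distributed as dist

def reduce_scatter_along_first_dim_sketch(input_: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    """Sum across ranks, then keep 1/world_size of the first dimension on each rank."""
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return input_
    assert input_.shape[0] % world_size == 0, "first dimension must divide evenly"
    out_shape = list(input_.shape)
    out_shape[0] //= world_size
    output = torch.empty(out_shape, dtype=input_.dtype, device=input_.device)
    dist.reduce_scatter_tensor(output, input_.contiguous(), group=group)
    return output
```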

class core.tensor_parallel.mappings._CopyToModelParallelRegion#

Bases: torch.autograd.Function

Pass the input to the model parallel region.

static symbolic(graph, input_, group)#

Symbolic function for tracing.

static forward(ctx, input_, group)#

Forward function.

static backward(ctx, grad_output)#

Backward function.
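
The identity/all-reduce pair this class implements ("forward: copy, backward: all-reduce" in the wrapper docstring below) can be sketched as follows. This is a hedged illustration of the pattern, not the module's implementation, and the class name is illustrative:

```python
import torch
import torch.distributed as dist

class CopyToModelParallelRegionSketch(torch.autograd.Function):
    """Forward: identity (copy); backward: all-reduce of the incoming gradient."""

    @staticmethod
    def forward(ctx, input_, group):
        ctx.group = group
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        # Each tensor-parallel rank produced a gradient for the same replicated
        # activation, so the gradients are summed across the group.
        if dist.get_world_size(group=ctx.group) > 1:
            dist.all_reduce(grad_output, group=ctx.group)
        # One gradient per forward argument; `group` receives None.
        return grad_output, None
```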

class core.tensor_parallel.mappings._ReduceFromModelParallelRegion#

Bases: torch.autograd.Function

All-reduce the input from the model parallel region.

static symbolic(graph, input_, group)#

Symbolic function for tracing.

static forward(ctx, input_, group)#

Forward function.

static backward(ctx, grad_output)#

Backward function.

class core.tensor_parallel.mappings._ScatterToModelParallelRegion#

Bases: torch.autograd.Function

Split the input and keep only the chunk corresponding to the rank.

static symbolic(graph, input_, group)#

Symbolic function for tracing.

static forward(ctx, input_, group)#

Forward function.

static backward(ctx, grad_output)#

Backward function.

class core.tensor_parallel.mappings._GatherFromModelParallelRegion#

Bases: torch.autograd.Function

Gather the input from the model parallel region and concatenate.

static symbolic(graph, input_, group)#

Symbolic function for tracing.

static forward(ctx, input_, group)#

Forward function.

static backward(ctx, grad_output)#

Backward function.

class core.tensor_parallel.mappings._ScatterToSequenceParallelRegion#

Bases: torch.autograd.Function

Split the input and keep only the chunk corresponding to the rank.

static symbolic(graph, input_, group)#

Symbolic function for tracing.

static forward(ctx, input_, group)#

Forward function.

static backward(ctx, grad_output)#

Backward function.

class core.tensor_parallel.mappings._GatherFromSequenceParallelRegion#

Bases: torch.autograd.Function

Gather the input from the sequence parallel region and concatenate.

static symbolic(
graph,
input_,
group,
tensor_parallel_output_grad=True,
output_split_sizes=None,
use_global_buffer=False,
)#

Symbolic function for tracing.

static forward(
ctx,
input_,
group,
tensor_parallel_output_grad=True,
output_split_sizes=None,
use_global_buffer=False,
)#

Forward function.

static backward(ctx, grad_output)#

Backward function.

class core.tensor_parallel.mappings._ReduceScatterToSequenceParallelRegion#

Bases: torch.autograd.Function

Reduce-scatter the input from the model parallel region.

static symbolic(
graph,
input_,
group,
input_split_sizes=None,
use_global_buffer=False,
)#

Symbolic function for tracing.

static forward(
ctx,
input_,
group,
input_split_sizes=None,
use_global_buffer=False,
)#

Forward function.

static backward(ctx, grad_output)#

Backward function.

class core.tensor_parallel.mappings._AllGatherFromTensorParallelRegion#

Bases: torch.autograd.Function

Gather the input from the model parallel region and concatenate.

static symbolic(graph, input_, group)#

Symbolic function for tracing.

static forward(ctx, input_, group)#

Forward function.

static backward(ctx, grad_output)#

Backward function.

class core.tensor_parallel.mappings._ReduceScatterToTensorParallelRegion#

Bases: torch.autograd.Function

Reduce-scatter the input from the model parallel region.

static symbolic(graph, input_, group)#

Symbolic function for tracing.

static forward(ctx, input_, group)#

Forward function.

static backward(ctx, grad_output)#

Backward function.

class core.tensor_parallel.mappings._AllToAll#

Bases: torch.autograd.Function

static forward(ctx, group, input, output_split_sizes, input_split_sizes)#

Forward function.

static backward(ctx, *grad_output)#

Backward function.
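
A hedged sketch of the forward/backward pattern for an all-to-all autograd function with this signature, built on torch.distributed.all_to_all_single. The class name is illustrative, and the world-size-1 short circuit and any buffer reuse are omitted:

```python
import torch
import torch.distributed as dist

class AllToAllSketch(torch.autograd.Function):
    """Forward: all-to-all; backward: all-to-all with the split sizes swapped."""

    @staticmethod
    def forward(ctx, group, input_, output_split_sizes, input_split_sizes):
        ctx.group = group
        ctx.output_split_sizes = output_split_sizes
        ctx.input_split_sizes = input_split_sizes
        input_ = input_.contiguous()
        if output_split_sizes is None:
            # Equal splits: the output has the same shape as the input.
            output = torch.empty_like(input_)
        else:
            output = input_.new_empty([sum(output_split_sizes)] + list(input_.shape[1:]))
        dist.all_to_all_single(
            output,
            input_,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Routing gradients back to their source ranks swaps the roles of the splits.
        grad_input = AllToAllSketch.apply(
            ctx.group, grad_output, ctx.input_split_sizes, ctx.output_split_sizes
        )
        return None, grad_input, None, None
```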

core.tensor_parallel.mappings.copy_to_tensor_model_parallel_region(input_, group=None)#

Wrapper for autograd function: forward: copy, backward: all-reduce

core.tensor_parallel.mappings.reduce_from_tensor_model_parallel_region(input_, group=None)#

Wrapper for autograd function: forward: all-reduce, backward: copy

core.tensor_parallel.mappings.scatter_to_tensor_model_parallel_region(input_, group=None)#

Wrapper for autograd function: forward: split, backward: AG

core.tensor_parallel.mappings.gather_from_tensor_model_parallel_region(input_, group=None)#

Wrapper for autograd function: forward: AG, backward: split
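
To show how the public wrappers compose, here is a hedged sketch of the classic column-parallel linear pattern: copy the activation into the tensor-parallel region before the sharded matmul, then gather the per-rank outputs along the last dimension. The function name, weight layout, and import path are assumptions for illustration:

```python
import torch
import torch.nn.functional as F

# Import path follows the module path documented here; adjust it to your
# installation (e.g. megatron.core.tensor_parallel.mappings).
from core.tensor_parallel.mappings import (
    copy_to_tensor_model_parallel_region,
    gather_from_tensor_model_parallel_region,
)

def column_parallel_linear_sketch(x: torch.Tensor, weight_shard: torch.Tensor) -> torch.Tensor:
    """x: [batch, hidden_in]; weight_shard: [hidden_out / TP, hidden_in] on each rank."""
    # Forward: identity copy; backward: all-reduce of the activation gradient.
    x = copy_to_tensor_model_parallel_region(x)
    partial = F.linear(x, weight_shard)                        # [batch, hidden_out / TP]
    # Forward: all-gather along the last dim; backward: split back to this rank's shard.
    return gather_from_tensor_model_parallel_region(partial)   # [batch, hidden_out]
```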

core.tensor_parallel.mappings.scatter_to_sequence_parallel_region(input_, group=None)#

Wrapper for autograd function: forward: split, backward: AG

core.tensor_parallel.mappings.gather_from_sequence_parallel_region(
input_,
tensor_parallel_output_grad=True,
group=None,
output_split_sizes=None,
use_global_buffer=False,
)#

Wrapper for autograd function: forward: AG, backward: RS

core.tensor_parallel.mappings.reduce_scatter_to_sequence_parallel_region(
input_,
group=None,
input_split_sizes=None,
use_global_buffer=False,
)#

Wrapper for autograd function: forward: RS, backward: AG
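
The sequence-parallel pair above is typically used around a tensor-parallel block: all-gather the sequence shards on entry and reduce-scatter the partial sums back onto the sequence dimension on exit. A hedged sketch of that pattern for an MLP; the function name, weight layouts, and import path are assumptions:

```python
import torch
import torch.nn.functional as F

# Import path follows the module path documented here; adjust it to your installation.
from core.tensor_parallel.mappings import (
    gather_from_sequence_parallel_region,
    reduce_scatter_to_sequence_parallel_region,
)

def sequence_parallel_mlp_sketch(
    x: torch.Tensor,          # [seq / TP, batch, hidden]
    w1_shard: torch.Tensor,   # column-parallel weight: [ffn / TP, hidden]
    w2_shard: torch.Tensor,   # row-parallel weight:    [hidden, ffn / TP]
) -> torch.Tensor:
    # Forward: all-gather the sequence shards; backward: reduce-scatter the gradient.
    full = gather_from_sequence_parallel_region(x)              # [seq, batch, hidden]
    h = F.gelu(F.linear(full, w1_shard))                        # [seq, batch, ffn / TP]
    partial = F.linear(h, w2_shard)                             # [seq, batch, hidden], partial sums
    # Forward: reduce-scatter the partial sums onto the sequence dim; backward: all-gather.
    return reduce_scatter_to_sequence_parallel_region(partial)  # [seq / TP, batch, hidden]
```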

core.tensor_parallel.mappings.all_gather_last_dim_from_tensor_parallel_region(input_, group=None)#

Wrapper for autograd function: forward: AG, backward: RS

core.tensor_parallel.mappings.reduce_scatter_last_dim_to_tensor_parallel_region(input_, group=None)#

Wrapper for autograd function: forward: RS, backward: AG

core.tensor_parallel.mappings.all_to_all(
group,
input_,
output_split_sizes_=None,
input_split_sizes=None,
)#

Wrapper for autograd function

core.tensor_parallel.mappings.all_to_all_sp2hp(input_, group=None)#

Perform all-to-all communication on the tensor parallel group, transforming the input tensor from shape [num_tokens/TP, H] to [num_tokens, H/TP].

Parameters:
  • input_ (torch.Tensor) – The input tensor which has been distributed along the sequence dimension.

  • group (torch.distributed.ProcessGroup, optional) – The process group to work on. If None, the tensor model parallel group will be used.

Returns:

The output tensor with shape [num_tokens, H/TP].

Return type:

torch.Tensor
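
Shape-wise, the transform can be read as: slice this rank's tokens into TP hidden chunks, exchange them, and stack the received token chunks. A hedged, shape-level illustration using torch.distributed.all_to_all_single (not the module's implementation; equal splits and divisibility are assumed, and the function name is illustrative):

```python
import torch
import torch.distributed as dist

def sp2hp_shape_sketch(input_: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    """[num_tokens / TP, H] per rank -> [num_tokens, H / TP] per rank."""
    tp = dist.get_world_size(group=group)
    tokens_per_rank, hidden = input_.shape
    # Slice the hidden dimension into TP chunks and stack them along dim 0, so the
    # chunk destined for rank j occupies rows [j * tokens_per_rank, (j + 1) * tokens_per_rank).
    send = torch.cat(torch.split(input_, hidden // tp, dim=1), dim=0).contiguous()
    recv = torch.empty(
        tokens_per_rank * tp, hidden // tp, dtype=input_.dtype, device=input_.device
    )
    dist.all_to_all_single(recv, send, group=group)
    return recv  # rows arrive in rank order, i.e. the full token sequence
```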

core.tensor_parallel.mappings.all_to_all_hp2sp(input_, group=None)#

Perform all-to-all communication on the tensor parallel group, transforming the input tensor from shape [num_tokens, H/TP] to [num_tokens/TP, H].

Parameters:
  • input_ (torch.Tensor) – The input tensor which has been distributed along the hidden dimension.

  • group (torch.distributed.ProcessGroup, optional) – The process group to work on. If None, the tensor model parallel group will be used.

Returns:

The output tensor with shape [num_tokens/TP, H].

Return type:

torch.Tensor
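
The inverse transform exchanges token chunks back and stitches the hidden slices together. Again a hedged, shape-level illustration rather than the module's implementation, with an illustrative function name and the same divisibility assumptions:

```python
import torch
import torch.distributed as dist

def hp2sp_shape_sketch(input_: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    """[num_tokens, H / TP] per rank -> [num_tokens / TP, H] per rank."""
    tp = dist.get_world_size(group=group)
    num_tokens, _ = input_.shape
    recv = torch.empty_like(input_)
    # Each rank sends token chunk j to rank j and receives its own token chunk's
    # k-th hidden slice from every rank k.
    dist.all_to_all_single(recv, input_.contiguous(), group=group)
    # Stitch the received hidden slices back together along the hidden dimension.
    return torch.cat(torch.split(recv, num_tokens // tp, dim=0), dim=1)
```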