core.tensor_parallel.utils#

Module Contents#

Classes#

VocabUtility

Split the vocabulary into world_size chunks and return the first and last index of the chunk belonging to a given rank. The range is half-open: [first, last).

Functions#

split_tensor_along_last_dim

Split a tensor along its last dimension.

split_tensor_into_1d_equal_chunks

Break a tensor into equal 1D chunks across tensor parallel ranks.

gather_split_1d_tensor

Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor model parallel ranks.

API#

core.tensor_parallel.utils.split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,
contiguous_split_chunks: bool = False,
) → List[torch.Tensor]#

Split a tensor along its last dimension.

Parameters:
  • tensor – input tensor.

  • num_partitions – number of partitions to split the tensor into.

  • contiguous_split_chunks – If True, make each chunk contiguous in memory.

Returns:

A list of Tensors
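The effect of contiguous_split_chunks is easiest to see on a small tensor. The sketch below does not call the library function; split_last_dim_sketch is a hypothetical stand-in built on torch.split, and it assumes the last dimension divides evenly by num_partitions.

```python
import torch


def split_last_dim_sketch(tensor, num_partitions, contiguous_split_chunks=False):
    """Hypothetical stand-in for illustration, not the library implementation."""
    last_dim_size = tensor.size(-1)
    # Assumption: the last dimension divides evenly among the partitions.
    assert last_dim_size % num_partitions == 0, "last dim must divide evenly"
    # torch.split returns views that share storage with the input tensor.
    chunks = torch.split(tensor, last_dim_size // num_partitions, dim=-1)
    if contiguous_split_chunks:
        # .contiguous() copies each view into its own densely packed buffer.
        return [chunk.contiguous() for chunk in chunks]
    return list(chunks)


x = torch.arange(24).reshape(2, 3, 4)
views = split_last_dim_sketch(x, 2)
copies = split_last_dim_sketch(x, 2, contiguous_split_chunks=True)

print(tuple(views[0].shape))        # shape of one chunk
print(views[0].is_contiguous())     # strided view into x
print(copies[0].is_contiguous())    # densely packed copy
```

Without the flag, each chunk is a strided view of the original storage. Requesting contiguous chunks costs a copy but gives each chunk a dense layout, which some downstream kernels and collectives expect.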

core.tensor_parallel.utils.split_tensor_into_1d_equal_chunks(
tensor,
new_buffer=False,
tp_group=None,
)#

Break a tensor into equal 1D chunks across tensor parallel ranks.

Returns a Tensor or View with this rank’s portion of the data.

Parameters:

tensor – The tensor to split

Keyword Arguments:

new_buffer (bool) – If True, returns a new Tensor. If False, returns a view into the existing Tensor. Default is False

core.tensor_parallel.utils.gather_split_1d_tensor(tensor, tp_group=None)#

Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor model parallel ranks.

Returns a new Tensor with the gathered data.

Parameters:

tensor – A Tensor or view of this rank’s portion of the data.
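The relationship between split_tensor_into_1d_equal_chunks and gather_split_1d_tensor can be sketched in a single process with plain Python lists. This is only an illustration of the index arithmetic: chunk_bounds is a hypothetical helper, the ranks are simulated with a loop rather than a real process group, and the element count is assumed to be divisible by the number of ranks.

```python
def chunk_bounds(numel, rank, world_size):
    """Hypothetical helper: [start, end) of the flattened slice owned by rank."""
    partition_size = numel // world_size  # assumes numel % world_size == 0
    start = partition_size * rank
    return start, start + partition_size


# A flattened "tensor" of 12 elements shared by 4 simulated ranks.
flat = list(range(12))
world_size = 4

# Split: each rank keeps only its own equal-sized 1D slice.
per_rank = []
for rank in range(world_size):
    start, end = chunk_bounds(len(flat), rank, world_size)
    per_rank.append(flat[start:end])

# Gather: concatenating the slices in rank order restores the flat data.
gathered = [value for chunk in per_rank for value in chunk]

print(per_rank)
print(gathered == flat)
```

In the real functions the slices live on different ranks and the gather step is a collective over tp_group; the loop above only mimics the resulting data layout.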

class core.tensor_parallel.utils.VocabUtility#

Split the vocabulary into world_size chunks and return the first and last index of the chunk belonging to a given rank. The range is half-open: [first, last).

static vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size: int,
rank,
world_size: int,
) → Sequence[int]#

Return the [first, last) vocabulary index range owned by rank, given the number of vocabulary entries in each partition.

static vocab_range_from_global_vocab_size(
global_vocab_size: int,
rank: int,
world_size: int,
) → Sequence[int]#

Return the [first, last) vocabulary index range owned by rank, given the global vocabulary size and world_size.
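The half-open range arithmetic described for VocabUtility can be sketched in plain Python. The two functions below are hypothetical stand-ins with a _sketch suffix, not the class's own methods, and they assume the global vocabulary size is divisible by world_size.

```python
def vocab_range_from_per_partition_sketch(per_partition_vocab_size, rank):
    """Hypothetical stand-in: half-open [first, last) range owned by rank."""
    first = rank * per_partition_vocab_size
    last = first + per_partition_vocab_size
    return first, last


def vocab_range_from_global_sketch(global_vocab_size, rank, world_size):
    """Hypothetical stand-in: derive the per-partition size, then the range."""
    # Assumption: the vocabulary splits evenly across all ranks.
    assert global_vocab_size % world_size == 0, "vocab must divide evenly"
    per_partition = global_vocab_size // world_size
    return vocab_range_from_per_partition_sketch(per_partition, rank)


# A 50304-entry vocabulary split across 4 ranks: 12576 entries per rank.
for rank in range(4):
    first, last = vocab_range_from_global_sketch(50304, rank, 4)
    print(rank, first, last)
```

Because the range is half-open, the last index of one rank equals the first index of the next, so the partitions cover the vocabulary without gaps or overlap.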