core.tensor_parallel.data#
Module Contents#
Functions#
_check_data_types | Check that all the keys have the same target data type. |
_build_key_size_numel_dictionaries | Build the size on rank 0 and broadcast. |
broadcast_data | Broadcast data from rank zero of each model parallel group to the members of the same model parallel group. |
Data#
_MAX_DATA_DIM | |
API#
- core.tensor_parallel.data._MAX_DATA_DIM#
5
- core.tensor_parallel.data._check_data_types(keys, data, target_dtype)#
Check that the tensor stored under every key in keys has the data type target_dtype.
- core.tensor_parallel.data._build_key_size_numel_dictionaries(keys, data, tp_group=None)#
Build the tensor sizes on rank 0 and broadcast them to the other ranks of the tensor model parallel group.
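The size exchange can be pictured as packing every tensor shape into a fixed-width list of integers, so that one broadcast carries all of them. The following is a minimal single-process sketch of that idea, not the library's code: `pack_sizes` and `unpack_sizes` are hypothetical helper names, shapes are plain tuples rather than tensors, the broadcast itself is left out, and the zero-padding convention is an assumption.

```python
from math import prod

MAX_DATA_DIM = 5  # mirrors the documented value of _MAX_DATA_DIM


def pack_sizes(keys, shapes, max_dim=MAX_DATA_DIM):
    """Pack each key's shape into max_dim zero-padded slots (sender side)."""
    sizes = [0] * (max_dim * len(keys))
    for index, key in enumerate(keys):
        shape = list(shapes[key])
        # Keep at least one trailing zero so the receiver can find the end.
        if len(shape) >= max_dim:
            raise ValueError(f"{key!r} has too many dimensions: {len(shape)}")
        start = index * max_dim
        sizes[start:start + len(shape)] = shape
    return sizes


def unpack_sizes(keys, sizes, max_dim=MAX_DATA_DIM):
    """Rebuild per-key size and numel dictionaries (receiver side)."""
    key_size, key_numel, total_numel = {}, {}, 0
    for index, key in enumerate(keys):
        shape = []
        for extent in sizes[index * max_dim:(index + 1) * max_dim]:
            if extent == 0:  # assumption: a zero marks the end of the shape
                break
            shape.append(extent)
        key_size[key] = shape
        key_numel[key] = prod(shape)
        total_numel += key_numel[key]
    return key_size, key_numel, total_numel


packed = pack_sizes(["tokens", "labels"], {"tokens": (4, 1024), "labels": (4,)})
key_size, key_numel, total_numel = unpack_sizes(["tokens", "labels"], packed)
```

Under this convention a dimension of extent zero cannot be represented, which is one reason to treat the sketch as illustrative only.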
- core.tensor_parallel.data.broadcast_data(keys, data, datatype, tp_group=None)#
Broadcast data from rank zero of each model parallel group to the members of the same model parallel group.
- Parameters:
keys – list of keys in the data dictionary to be broadcast
data – data dictionary of string keys and CPU tensor values.
datatype – torch data type of all tensors in data associated with keys.
tp_group – the tensor model parallel group to broadcast to.
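Because every tensor named in keys shares one datatype, a dictionary like this can be sent with a single collective: flatten the values into one buffer, broadcast it, then slice it apart on the receiving ranks. The sketch below illustrates that general technique in one process and should not be read as the implementation of broadcast_data. The names `flatten_values` and `split_values` are hypothetical, flat row-major Python lists stand in for tensors, and the broadcast step is omitted.

```python
from math import prod


def flatten_values(keys, data):
    """Concatenate each key's flat values in key order (sender side)."""
    flat = []
    for key in keys:
        flat.extend(data[key])
    return flat


def split_values(keys, key_size, flat):
    """Slice the flat buffer back into one chunk per key (receiver side)."""
    output, offset = {}, 0
    for key in keys:
        numel = prod(key_size[key])  # number of elements implied by the shape
        output[key] = flat[offset:offset + numel]
        offset += numel
    if offset != len(flat):
        raise ValueError("buffer length does not match the expected sizes")
    return output


# Hypothetical batch: "tokens" has shape (2, 3), "labels" has shape (2,).
batch = {"tokens": [1, 2, 3, 4, 5, 6], "labels": [7, 8]}
keys = ["tokens", "labels"]
buffer = flatten_values(keys, batch)
restored = split_values(keys, {"tokens": [2, 3], "labels": [2]}, buffer)
```

The receiver needs the per-key sizes before it can slice the buffer, which is why the sizes have to be known on every rank first.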