core.dist_checkpointing.exchange_utils#
Utilities for exchanging data between ranks.
Module Contents#
Classes#
ShardDistribution | Represents a distribution of ShardedTensors. |
Functions#
is_float8tensor | Check if a tensor is a Transformer Engine Float8Tensor |
_shard_size | Returns size in bytes of a given sharded tensor. |
_get_empty_tensor_for_exchange | Determines the empty tensor to use for exchange. |
distribute_shards_to_ranks | Computes uniform distribution of workload across ranks, based on sizes. |
determine_main_replica_uniform_distribution | Computes the save distribution. |
exchange_loaded_tensors_gather_rounds | Exchange the tensors loaded by different ranks with several all_gather calls. |
exchange_loaded_tensors_gather_object | Exchange the tensors loaded by different ranks with a simple all_gather_object call. |
exchange_loaded_objects_gather_object | Exchange the objects loaded by different ranks with a simple all_gather_object call. |
exchange_loaded_tensors_broadcast | Exchange the tensors loaded by different ranks by a series of broadcasts. |
exchange_by_distribution | Exchange tensors loaded by different ranks using the specified exchange_algo. |
Data#
API#
- core.dist_checkpointing.exchange_utils.is_float8tensor(tensor: torch.Tensor) bool#
Check if a tensor is a Transformer Engine Float8Tensor
- core.dist_checkpointing.exchange_utils.logger#
'getLogger(…)'
- class core.dist_checkpointing.exchange_utils.ShardDistribution#
Bases: typing.NamedTuple
Represents a distribution of ShardedTensors.
Given distribution is valid only for a specific parallelization group, which is implicit here (not referenced by this class).
- Parameters:
main_rank_for_shard (Dict[_ShardId, int]) – specifies which rank should hold the main replica for a given shard
shards_in_this_group (Set[_ShardId]) – which shards have a main replica in this parallelization group
shard_to_metadata (Dict[_ShardId, ShardedTensor]) – maps ShardedTensor identifier to the original ShardedTensor
all_ranks_for_shard (Dict[_ShardId, List[int]]) – specifies which ranks need a given shard in a given parallelization group
- main_rank_for_shard: Dict[core.dist_checkpointing.utils._ShardId, int]#
None
- shards_in_this_group: Set[core.dist_checkpointing.utils._ShardId]#
None
- shard_to_metadata: Dict[core.dist_checkpointing.utils._ShardId, core.dist_checkpointing.mapping.ShardedTensor]#
None
- all_ranks_for_shard: Dict[core.dist_checkpointing.utils._ShardId, List[int]]#
None
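Since ShardDistribution is a plain NamedTuple, it can be constructed directly. Below is a purely hypothetical sketch: the shard ids are simple tuples standing in for the real _ShardId values, None stands in for the ShardedTensor metadata, and the import path assumes the package is installed as megatron.core (adapt to your install).

```python
# Hypothetical illustration only: shard ids are opaque hashable keys here, and
# None stands in for the ShardedTensor metadata objects used in practice.
from megatron.core.dist_checkpointing.exchange_utils import ShardDistribution

shard_a = ("layer.weight", (0,))
shard_b = ("layer.weight", (1,))

distribution = ShardDistribution(
    main_rank_for_shard={shard_a: 0, shard_b: 1},       # rank 0 owns shard_a, rank 1 owns shard_b
    shards_in_this_group={shard_a, shard_b},             # both main replicas live in this group
    shard_to_metadata={shard_a: None, shard_b: None},    # placeholders for ShardedTensor metadata
    all_ranks_for_shard={shard_a: [0, 1], shard_b: [0, 1]},  # every rank needs both shards
)

print(distribution.main_rank_for_shard[shard_a])  # NamedTuple fields are accessed by name
```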
- core.dist_checkpointing.exchange_utils._shard_size(sh_ten: core.dist_checkpointing.mapping.ShardedTensor)#
Returns size in bytes of a given sharded tensor.
- core.dist_checkpointing.exchange_utils._get_empty_tensor_for_exchange(
- shard_id: core.dist_checkpointing.utils._ShardId,
- needed_shards: Dict[core.dist_checkpointing.utils._ShardId, core.dist_checkpointing.mapping.ShardedTensor],
- unneeded_shards: Dict[core.dist_checkpointing.utils._ShardId, core.dist_checkpointing.mapping.ShardedTensor],
- loaded_tensors: Dict[core.dist_checkpointing.utils._ShardId, torch.Tensor],
- )#
Determines the empty tensor to use for exchange.
If shard_id is needed by this rank, its metadata will be in needed_shards. Otherwise, the metadata for this tensor can be found in unneeded_shards.
- Parameters:
shard_id (_ShardId) – shard_id that will be exchanged
needed_shards (Dict[_ShardId, ShardedTensor]) – mapping from shard ids to metadata for shards needed by this rank
unneeded_shards (Dict[_ShardId, ShardedTensor]) – mapping from shard ids to metadata for shards that can be discarded after exchange
loaded_tensors (Dict[_ShardId, torch.Tensor]) – mapping into which tensors needed by this rank are placed
- Returns:
empty CUDA tensor to be exchanged, and the device of the original state dict tensor (if there was any)
- Return type:
Tuple[torch.Tensor, Optional[torch.device]]
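A simplified sketch of the behavior described above (an assumption-laden illustration, not the actual implementation; the local_shape, dtype and data attributes are those documented for ShardedTensor):

```python
import torch

def get_empty_tensor_for_exchange_sketch(shard_id, needed_shards, unneeded_shards, loaded_tensors):
    # Shards needed by this rank live in needed_shards; everything else in unneeded_shards.
    needed = shard_id in needed_shards
    sh_ten = needed_shards[shard_id] if needed else unneeded_shards[shard_id]
    # Remember the device of the original state dict tensor, if one is still attached.
    orig_device = sh_ten.data.device if getattr(sh_ten, "data", None) is not None else None
    # Allocate an empty CUDA buffer matching the shard's local shape and dtype.
    tensor = torch.empty(sh_ten.local_shape, dtype=sh_ten.dtype, device="cuda")
    if needed:
        loaded_tensors[shard_id] = tensor  # this rank keeps the exchanged result
    return tensor, orig_device
```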
- core.dist_checkpointing.exchange_utils.T#
'TypeVar(…)'
- core.dist_checkpointing.exchange_utils.distribute_shards_to_ranks(
- shard_to_ranks: Dict[core.dist_checkpointing.exchange_utils.T, List[int]],
- shard_to_size: Dict[core.dist_checkpointing.exchange_utils.T, int],
- num_ranks: int,
- cross_parallelization_group_loads: Set[core.dist_checkpointing.exchange_utils.T],
- )#
Computes uniform distribution of workload across ranks, based on sizes.
Currently, the assignment is greedy, based on:
- Cross-parallelization-group dependencies: shards whose main rank is in another group are assigned last, so that the load and save distributions stay as similar as possible.
- Secondly, the coverage of each shard (how many ranks the shard is available on); lower coverage is assigned first.
- Then, the size of each shard; larger shards are assigned first.
- Finally, the shard id, for differentiation.
The last criterion is needed because we rely on the assignment being deterministic on all ranks.
- Parameters:
shard_to_ranks (Dict[T, List[int]]) – maps each shard to the ranks that have access to it
shard_to_size (Dict[T, int]) – sizes of each shard
num_ranks (int) – number of ranks in the parallelization group
cross_parallelization_group_loads (Set[T]) – Shards to load that are not in the main replica
Returns (Dict[T, int]): assignment of shard to rank (which rank should do the work to achieve maximal uniformity)
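A minimal sketch of this greedy strategy (simplified: it drops the cross-parallelization-group criterion and assumes shard ids sort consistently), shown with hypothetical inputs:

```python
from typing import Dict, List, TypeVar

T = TypeVar("T")

def greedy_shard_assignment(
    shard_to_ranks: Dict[T, List[int]],
    shard_to_size: Dict[T, int],
    num_ranks: int,
) -> Dict[T, int]:
    rank_load = [0] * num_ranks            # bytes assigned to each rank so far
    assignment: Dict[T, int] = {}
    # Order mirrors the criteria above: lowest coverage first, then largest size, then shard id.
    for shard in sorted(
        shard_to_ranks,
        key=lambda s: (len(shard_to_ranks[s]), -shard_to_size[s], s),
    ):
        # Assign the shard to the least loaded rank that has access to it.
        rank = min(shard_to_ranks[shard], key=lambda r: rank_load[r])
        assignment[shard] = rank
        rank_load[rank] += shard_to_size[shard]
    return assignment

# Example: shard "c" can only go to rank 1, so the largest shard "a" lands on rank 0
# and "b" balances rank 1: {'c': 1, 'a': 0, 'b': 1}
print(greedy_shard_assignment(
    {"a": [0, 1], "b": [0, 1], "c": [1]},
    {"a": 100, "b": 60, "c": 50},
    num_ranks=2,
))
```

The real function additionally defers the shards listed in cross_parallelization_group_loads to the end of this ordering, so that load and save assignments stay aligned.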
- core.dist_checkpointing.exchange_utils.determine_main_replica_uniform_distribution(
- sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
- parallelization_group: torch.distributed.ProcessGroup,
- ignore_groups: bool = False,
- )#
Computes the save distribution.
Should be used in conjunction with distribute_main_replicas_with_precomputed_distribution, which applies the computed save distribution.
We rely on the fact that the assignment algorithm is deterministic on all ranks, so there is no extra communication needed after the metadata exchange.
- Parameters:
sharded_state_dict (ShardedStateDict) – state dict to compute the distribution of
parallelization_group (ProcessGroup) – distribution will be computed within this process group
ignore_groups (bool, optional) – whether the distribution defines groups. This option is primarily used during loading, as it ensures that all replicas, including non-main ones, are loaded by this parallelization group. Defaults to False.
Returns (ShardDistribution, optional): distribution that can be used to apply the parallelization. Returns None if the process_group is trivial (1 rank)
- core.dist_checkpointing.exchange_utils.exchange_loaded_tensors_gather_rounds(
- loaded_tensors: Dict[core.dist_checkpointing.utils._ShardId, torch.Tensor],
- unloaded_shards: Dict[core.dist_checkpointing.utils._ShardId, core.dist_checkpointing.mapping.ShardedTensor],
- shard_distribution: core.dist_checkpointing.exchange_utils.ShardDistribution = None,
- parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
- )#
Exchange the tensors loaded by different ranks with several all_gather calls.
Groups tensors by dtype, divides the tensors to be exchanged into rounds, and executes an all_gather for the tensors of each round.
Note: the loading is distributed across ranks based on total loaded size in bytes, so there is no guarantee that the number of rounds needed by each rank will be similar, which might result in many almost empty all_gathers. A solution would be to group all tensors into a single byte tensor and perform one all_gather (with similarly sized messages).
- Parameters:
loaded_tensors (Dict[_ShardId, torch.Tensor]) – mapping from ShardedTensor shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, ShardedTensor]) – mapping from ShardedTensor shard ids to ShardedTensors that aren’t loaded yet.
shard_distribution (ShardDistribution) – distribution of all shards
parallelization_group (ProcessGroup, optional) – process group used for load distribution. Tensors will be exchanged within this group
- Returns:
dictionary mapping shard ids to tensors needed by this rank to load a given state dict. Includes previously loaded tensors (from the loaded_tensors input).
- Return type:
Dict[_ShardId, torch.Tensor]
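A simplified sketch of the round-based exchange (assumptions: a single dtype group, shard ids that sort consistently across ranks, group-local ranks in the distribution, and a torch.distributed.all_gather that accepts per-rank tensors of different shapes; the real implementation additionally groups tensors by dtype):

```python
import torch
import torch.distributed as dist

def exchange_tensors_gather_rounds_sketch(loaded_tensors, unloaded_shards, distribution, group=None):
    all_loaded = dict(loaded_tensors)
    my_rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)
    # Bucket shard ids by the rank assigned to load them, in a deterministic order.
    per_rank = [[] for _ in range(world_size)]
    for shard_id, rank in sorted(distribution.main_rank_for_shard.items()):
        per_rank[rank].append(shard_id)
    num_rounds = max(len(ids) for ids in per_rank)
    for round_idx in range(num_rounds):
        round_ids = [ids[round_idx] if round_idx < len(ids) else None for ids in per_rank]
        round_tensors = []
        for rank, shard_id in enumerate(round_ids):
            if shard_id is None:
                # This rank has nothing to contribute in this round: use an empty placeholder.
                round_tensors.append(torch.empty(0, device="cuda"))
            elif rank == my_rank:
                round_tensors.append(loaded_tensors[shard_id].cuda())
            else:
                sh_ten = distribution.shard_to_metadata[shard_id]
                round_tensors.append(torch.empty(sh_ten.local_shape, dtype=sh_ten.dtype, device="cuda"))
        # One all_gather per round: every rank receives every shard of this round.
        dist.all_gather(round_tensors, round_tensors[my_rank], group=group)
        for shard_id, tensor in zip(round_ids, round_tensors):
            if shard_id is not None and shard_id in unloaded_shards:
                all_loaded[shard_id] = tensor  # keep only the shards this rank still needs
    return all_loaded
```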
- core.dist_checkpointing.exchange_utils.exchange_loaded_tensors_gather_object(
- loaded_tensors: Dict[core.dist_checkpointing.utils._ShardId, torch.Tensor],
- unloaded_shards: Dict[core.dist_checkpointing.utils._ShardId, core.dist_checkpointing.mapping.ShardedTensor],
- shard_distribution: core.dist_checkpointing.exchange_utils.ShardDistribution,
- parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
- )#
Exchange the tensors loaded by different ranks with a simple all_gather_object call.
This version can be used for debugging purposes due to its simplistic implementation. It shouldn’t be used if performance is important.
- Parameters:
loaded_tensors (Dict[_ShardId, torch.Tensor]) – mapping from ShardedTensor shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, ShardedTensor]) – mapping from ShardedTensor shard ids to ShardedTensors that aren’t loaded yet.
shard_distribution (ShardDistribution) – distribution of all shards
parallelization_group (ProcessGroup, optional) – process group used for load distribution. Tensors will be exchanged within this group
- Returns:
dictionary mapping shard ids to tensors needed by this rank to load a given state dict. Includes previously loaded tensors (from the loaded_tensors input).
- Return type:
Dict[_ShardId, torch.Tensor]
- core.dist_checkpointing.exchange_utils.exchange_loaded_objects_gather_object(
- loaded_objects: Dict[core.dist_checkpointing.utils._ShardId, Any],
- )#
Exchange the objects loaded by different ranks with a simple all_gather_object call.
- Parameters:
loaded_objects (Dict[_ShardId, Any]) – mapping from shard ids to objects already loaded by this rank.
- Returns:
dictionary mapping shard ids to objects needed by this rank to load a given state dict.
- Return type:
Dict[_ShardId, Any]
- core.dist_checkpointing.exchange_utils.exchange_loaded_tensors_broadcast(
- loaded_tensors: Dict[core.dist_checkpointing.utils._ShardId, torch.Tensor],
- unloaded_shards: Dict[core.dist_checkpointing.utils._ShardId, core.dist_checkpointing.mapping.ShardedTensor],
- shard_distribution: core.dist_checkpointing.exchange_utils.ShardDistribution,
- parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
- )#
Exchange the tensors loaded by different ranks by a series of broadcasts.
For each rank and each loaded tensor, a broadcast is performed to the whole group. This is a reasonable tradeoff in terms of performance and simplicity.
- Parameters:
loaded_tensors (Dict[_ShardId, torch.Tensor]) – mapping from ShardedTensor shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, ShardedTensor]) – mapping from ShardedTensor shard ids to ShardedTensors that aren’t loaded yet.
shard_distribution (ShardDistribution) – distribution of all shards
parallelization_group (ProcessGroup, optional) – process group used for load distribution. Tensors will be exchanged within this group
- Returns:
dictionary mapping shard ids to tensors needed by this rank to load a given state dict. Includes previously loaded tensors (from the loaded_tensors input).
- Return type:
Dict[_ShardId, torch.Tensor]
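A simplified sketch of the broadcast-based exchange (assumptions: the ranks stored in the distribution are group-local, shard ids sort consistently across ranks, and torch.distributed.get_global_rank is available to translate them for the broadcast call):

```python
import torch
import torch.distributed as dist

def exchange_tensors_broadcast_sketch(loaded_tensors, unloaded_shards, distribution, group=None):
    all_loaded = dict(loaded_tensors)
    my_rank = dist.get_rank(group)
    # Iterate shards in a deterministic order so every rank issues the same broadcasts.
    for shard_id, src_rank in sorted(distribution.main_rank_for_shard.items()):
        if my_rank == src_rank:
            tensor = loaded_tensors[shard_id].cuda()
        else:
            # Receiving ranks allocate an empty buffer from the shard's metadata.
            sh_ten = distribution.shard_to_metadata[shard_id]
            tensor = torch.empty(sh_ten.local_shape, dtype=sh_ten.dtype, device="cuda")
        src = dist.get_global_rank(group, src_rank) if group is not None else src_rank
        dist.broadcast(tensor, src=src, group=group)
        if shard_id in unloaded_shards:
            all_loaded[shard_id] = tensor  # keep only the shards this rank still needs
    return all_loaded
```

Compared with the round-based all_gather variant, one broadcast per shard is simpler and keeps message sizes exact, at the cost of issuing more collectives.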
- core.dist_checkpointing.exchange_utils.exchange_by_distribution(
- loaded_tensors: Dict[core.dist_checkpointing.utils._ShardId, torch.Tensor],
- unloaded_shards: Dict[core.dist_checkpointing.utils._ShardId, core.dist_checkpointing.mapping.ShardedTensor],
- shard_distribution: core.dist_checkpointing.exchange_utils.ShardDistribution,
- parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
- exchange_algo='broadcast',
- )#
Exchange tensors loaded by different ranks using the specified exchange_algo.
- Parameters:
loaded_tensors (Dict[_ShardId, torch.Tensor]) – mapping from ShardedTensor shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, ShardedTensor]) – mapping from ShardedTensor shard ids to ShardedTensors that aren’t loaded yet.
shard_distribution (ShardDistribution) – distribution of all shards
parallelization_group (ProcessGroup, optional) – process group used for load distribution. Tensors will be exchanged within this group
exchange_algo (str) – The algorithm used for performing exchanges. Defaults to ‘broadcast’.
- Returns:
dictionary mapping shard ids to tensors needed by this rank to load a given state dict. Includes previously loaded tensors (from the loaded_tensors input).
- Return type:
Dict[_ShardId, torch.Tensor]
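A hedged end-to-end usage sketch for loading. The placeholders sharded_state_dict, loaded_tensors and unloaded_shards are assumed to come from the caller's load step, and the import path assumes the package is installed as megatron.core (adapt to your install):

```python
import torch.distributed as dist
from megatron.core.dist_checkpointing.exchange_utils import (
    determine_main_replica_uniform_distribution,
    exchange_by_distribution,
)

# Process group across which the load is distributed (here: all ranks).
parallelization_group = dist.new_group(ranks=list(range(dist.get_world_size())))

# 1. Compute which rank is responsible for loading which shard. ignore_groups=True
#    ensures all replicas, not only main ones, are loaded by this group (see above).
distribution = determine_main_replica_uniform_distribution(
    sharded_state_dict, parallelization_group, ignore_groups=True
)

# 2. Each rank loads only its assigned shards into loaded_tensors; the exchange then
#    fills in the shards listed in unloaded_shards. None is returned for a 1-rank group.
if distribution is not None:
    loaded_tensors = exchange_by_distribution(
        loaded_tensors,
        unloaded_shards,
        distribution,
        parallelization_group,
        exchange_algo="broadcast",
    )
```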