core.dist_checkpointing.exchange_utils#

Utilities for exchanging data between ranks.

Module Contents#

Classes#

ShardDistribution

Represents a distribution of ShardedTensors.

Functions#

is_float8tensor

Check if a tensor is a Transformer Engine Float8Tensor.

_shard_size

Returns size in bytes of a given sharded tensor.

_get_empty_tensor_for_exchange

Determines the empty tensor to use for exchange.

distribute_shards_to_ranks

Computes uniform distribution of workload across ranks, based on sizes.

determine_main_replica_uniform_distribution

Computes the save distribution.

exchange_loaded_tensors_gather_rounds

Exchange the tensors loaded by different ranks with several all_gather calls.

exchange_loaded_tensors_gather_object

Exchange the tensors loaded by different ranks with a simple all_gather_object call.

exchange_loaded_objects_gather_object

Exchange the objects loaded by different ranks with a simple all_gather_object call.

exchange_loaded_tensors_broadcast

Exchange the tensors loaded by different ranks by a series of broadcasts.

exchange_by_distribution

Exchange tensors loaded by different ranks using the specified exchange_algo.

Data#

API#

core.dist_checkpointing.exchange_utils.is_float8tensor(tensor: torch.Tensor) bool#

Check if a tensor is a Transformer Engine Float8Tensor.

core.dist_checkpointing.exchange_utils.logger#

‘getLogger(…)’

class core.dist_checkpointing.exchange_utils.ShardDistribution#

Bases: typing.NamedTuple

Represents a distribution of ShardedTensors.

Given distribution is valid only for a specific parallelization group, which is implicit here (not referenced by this class).

Parameters:
  • main_rank_for_shard (Dict[_ShardId, int]) – specifies which rank should hold the main replica for a given shard

  • shards_in_this_group (Set[_ShardId]) – which shards have a main replica in this parallelization group

  • shard_to_metadata (Dict[_ShardId, ShardedTensor]) – maps ShardedTensor identifier to the original ShardedTensor

  • all_ranks_for_shard (Dict[_ShardId, List[int]]) – specifies which ranks need a given shard in a given parallelization group

main_rank_for_shard: Dict[core.dist_checkpointing.utils._ShardId, int]#

None

shards_in_this_group: Set[core.dist_checkpointing.utils._ShardId]#

None

shard_to_metadata: Dict[core.dist_checkpointing.utils._ShardId, core.dist_checkpointing.mapping.ShardedTensor]#

None

all_ranks_for_shard: Dict[core.dist_checkpointing.utils._ShardId, List[int]]#

None
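
For illustration, a minimal sketch of building a ShardDistribution by hand is shown below. The shard ids are plain strings and the ShardedTensor metadata is replaced by None placeholders; real distributions are produced by determine_main_replica_uniform_distribution and use _ShardId keys and ShardedTensor values.

```python
# Illustrative sketch only: shard ids are plain strings and metadata is None,
# standing in for real _ShardId keys and ShardedTensor values.
# The import path is assumed to match the module path documented on this page.
from core.dist_checkpointing.exchange_utils import ShardDistribution

distribution = ShardDistribution(
    main_rank_for_shard={"layer1.weight/shard0": 0, "layer1.weight/shard1": 1},
    shards_in_this_group={"layer1.weight/shard0", "layer1.weight/shard1"},
    shard_to_metadata={"layer1.weight/shard0": None, "layer1.weight/shard1": None},
    all_ranks_for_shard={"layer1.weight/shard0": [0, 1], "layer1.weight/shard1": [0, 1]},
)

# Being a NamedTuple, the fields are accessible by name on every rank:
main_rank = distribution.main_rank_for_shard["layer1.weight/shard0"]  # -> 0
```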

core.dist_checkpointing.exchange_utils._shard_size(sh_ten: core.dist_checkpointing.mapping.ShardedTensor)#

Returns size in bytes of a given sharded tensor.

core.dist_checkpointing.exchange_utils._get_empty_tensor_for_exchange(
shard_id: core.dist_checkpointing.utils._ShardId,
needed_shards: Dict[core.dist_checkpointing.utils._ShardId, core.dist_checkpointing.mapping.ShardedTensor],
unneeded_shards: Dict[core.dist_checkpointing.utils._ShardId, core.dist_checkpointing.mapping.ShardedTensor],
loaded_tensors: Dict[core.dist_checkpointing.utils._ShardId, torch.Tensor],
) Tuple[torch.Tensor, Optional[torch.device]]#

Determines the empty tensor to use for exchange.

If shard_id is needed by this rank, its metadata can be found in needed_shards. Otherwise, the metadata for this tensor can be found in unneeded_shards.

Parameters:
  • shard_id (_ShardId) – shard_id that will be exchanged

  • needed_shards (Dict[_ShardId, ShardedTensor]) – mapping from shard ids to metadata for shards needed by this rank

  • unneeded_shards (Dict[_ShardId, ShardedTensor]) – mapping from shard ids to metadata for shards that can be discarded after exchange

  • loaded_tensors (Dict[_ShardId, torch.Tensor]) – mapping into which tensors useful to this rank are placed

Returns:

empty CUDA tensor to be exchanged, and the device of the original state dict tensor (if there was any)

Return type:

Tuple[torch.Tensor, Optional[torch.device]]
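
A minimal sketch of how an exchange round might use this helper to obtain a receive buffer. The surrounding function, src_rank and group variables are assumptions; in practice the helper is called internally by the exchange functions below.

```python
# Hypothetical receive step of an exchange round; only the call to
# _get_empty_tensor_for_exchange follows the documented signature.
import torch.distributed as dist

def receive_shard(shard_id, needed_shards, unneeded_shards, loaded_tensors, src_rank, group):
    # Allocate an empty CUDA tensor shaped like the shard; orig_device remembers
    # where the original state-dict tensor lived (if this rank needs the shard).
    recv_buffer, orig_device = _get_empty_tensor_for_exchange(
        shard_id, needed_shards, unneeded_shards, loaded_tensors
    )
    # Receive the data in place from the rank that loaded this shard.
    dist.broadcast(recv_buffer, src=src_rank, group=group)
    # Move the received tensor back to the original device if there was one.
    if orig_device is not None:
        loaded_tensors[shard_id] = recv_buffer.to(orig_device, non_blocking=True)
```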

core.dist_checkpointing.exchange_utils.T#

‘TypeVar(…)’

core.dist_checkpointing.exchange_utils.distribute_shards_to_ranks(
shard_to_ranks: Dict[core.dist_checkpointing.exchange_utils.T, List[int]],
shard_to_size: Dict[core.dist_checkpointing.exchange_utils.T, int],
num_ranks: int,
cross_parallelization_group_loads: Set[core.dist_checkpointing.exchange_utils.T],
) Dict[core.dist_checkpointing.exchange_utils.T, int]#

Computes uniform distribution of workload across ranks, based on sizes.

Currently, the assignment is greedy, based on:

  1. Cross-parallelization group dependencies: shards whose main rank is in another group are assigned last, to keep the load and save distributions as similar as possible.

  2. The coverage of each shard: shards available on fewer ranks are assigned first.

  3. The size of each shard: larger shards are assigned first.

  4. The shard id, as a deterministic tiebreaker.

The last criterion is needed because we rely on the assignment being deterministic on all ranks. (A simplified sketch of this heuristic follows the parameter descriptions below.)

Parameters:
  • shard_to_ranks (Dict[T, List[int]]) – maps each shard to the ranks that have access to it

  • shard_to_size (Dict[T, int]) – sizes of each shard

  • num_ranks (int) – number of ranks in the parallelization group

  • cross_parallelization_group_loads (Set[T]) – shards to be loaded whose main replica belongs to another parallelization group

Returns:

assignment of shard to rank (which rank should do the work to achieve maximal uniformity)

Return type:

Dict[T, int]
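
Below is a simplified sketch of the greedy heuristic described above. It is an illustration, not the library implementation; the function name and the use of str(shard) as the final tiebreaker are assumptions.

```python
from typing import Dict, List, Set, TypeVar

T = TypeVar("T")

def toy_distribute_shards(shard_to_ranks: Dict[T, List[int]],
                          shard_to_size: Dict[T, int],
                          num_ranks: int,
                          cross_group_loads: Set[T]) -> Dict[T, int]:
    rank_load = [0] * num_ranks           # bytes assigned to each rank so far
    assignment: Dict[T, int] = {}
    # Visit shards in the priority order described above: cross-group shards
    # last, then lowest coverage, then largest size, then shard id so that
    # every rank produces exactly the same order.
    order = sorted(
        shard_to_ranks,
        key=lambda s: (s in cross_group_loads,
                       len(shard_to_ranks[s]),
                       -shard_to_size[s],
                       str(s)),
    )
    for shard in order:
        # Greedily pick the eligible rank with the least assigned bytes.
        rank = min(shard_to_ranks[shard], key=lambda r: rank_load[r])
        rank_load[rank] += shard_to_size[shard]
        assignment[shard] = rank
    return assignment
```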

core.dist_checkpointing.exchange_utils.determine_main_replica_uniform_distribution(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
parallelization_group: torch.distributed.ProcessGroup,
ignore_groups: bool = False,
) Optional[core.dist_checkpointing.exchange_utils.ShardDistribution]#

Computes the save distribution.

Should be used in conjunction with distribute_main_replicas_with_precomputed_distribution which applies the computed save distribution.

We rely on the fact that the assignment algorithm is deterministic on all ranks, so there is no extra communication needed after metadata exchange.

Parameters:
  • sharded_state_dict (ShardedStateDict) – state dict to compute the distribution of

  • parallelization_group (ProcessGroup) – distribution will be computed within this process group

  • ignore_groups (bool, optional) – whether the distribution defines groups. This option is primarily used during loading, as it ensures that all replicas, including non-main ones, are loaded by this parallelization group. Defaults to False.

Returns:

distribution that can be used to apply the parallelization. Returns None if the parallelization group is trivial (1 rank)

Return type:

Optional[ShardDistribution]
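
A hedged usage sketch of the compute-then-apply pattern described above. The exact signature of distribute_main_replicas_with_precomputed_distribution is an assumption based on the description; sharded_state_dict is a placeholder for an already-built sharded state dict.

```python
import torch.distributed as dist

# Use the default (world) group here; in practice this is typically a
# data- or tensor-parallel subgroup.
parallelization_group = dist.group.WORLD

# 1. Every rank deterministically computes the same save distribution,
#    so no extra communication is needed after the metadata exchange.
distribution = determine_main_replica_uniform_distribution(
    sharded_state_dict, parallelization_group
)

# 2. Apply it (assumed signature); None means the group has a single rank
#    and there is nothing to distribute.
if distribution is not None:
    distribute_main_replicas_with_precomputed_distribution(
        sharded_state_dict, parallelization_group, distribution
    )
```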

core.dist_checkpointing.exchange_utils.exchange_loaded_tensors_gather_rounds(
loaded_tensors: Dict[core.dist_checkpointing.utils._ShardId, torch.Tensor],
unloaded_shards: Dict[core.dist_checkpointing.utils._ShardId, core.dist_checkpointing.mapping.ShardedTensor],
shard_distribution: core.dist_checkpointing.exchange_utils.ShardDistribution = None,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
) Dict[core.dist_checkpointing.utils._ShardId, torch.Tensor]#

Exchange the tensors loaded by different ranks with several all_gather calls.

Groups tensors by dtype, divides the tensors to be exchanged into rounds, and executes an all_gather for the tensors in each round.

Note: the loading is distributed across ranks based on the total loaded size in bytes, so there is no guarantee that the number of rounds needed by each rank will be similar, which might result in many almost-empty all_gathers. A solution would be to pack all tensors into a single byte tensor and perform one all_gather (with similarly sized messages).

Parameters:
  • loaded_tensors (Dict[_ShardId, torch.Tensor]) – mapping from ShardedTensor shard ids to tensors already loaded by this rank.

  • unloaded_shards (Dict[_ShardId, ShardedTensor]) – mapping from ShardedTensor shard ids to ShardedTensors that aren’t loaded yet.

  • shard_distribution (ShardDistribution) – distribution of all shards

  • parallelization_group (ProcessGroup, optional) – process group used for load distribution. Tensors will be exchanged within this group

Returns:

dictionary mapping shard ids to tensors needed by this rank to load a given state dict. Includes previously loaded tensors (from loaded_tensors input)

Return type:

Dict[_ShardId, torch.Tensor]
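
The dtype grouping and round structure can be pictured with the simplified planning sketch below. It is an illustration only, not the library implementation, and it assumes the ShardedTensor metadata exposes a dtype attribute.

```python
from collections import defaultdict

def toy_plan_gather_rounds(shard_distribution, num_ranks):
    # Bucket shard ids by dtype and by the rank assigned to load them.
    by_dtype = defaultdict(lambda: [[] for _ in range(num_ranks)])
    for shard_id, main_rank in shard_distribution.main_rank_for_shard.items():
        dtype = shard_distribution.shard_to_metadata[shard_id].dtype
        by_dtype[dtype][main_rank].append(shard_id)
    # Round r of a given dtype exchanges the r-th shard assigned to each rank
    # with a single all_gather; ranks that ran out of shards contribute an
    # empty tensor, which is where the almost-empty all_gathers come from.
    rounds_per_dtype = {
        dtype: max(len(shards) for shards in per_rank)
        for dtype, per_rank in by_dtype.items()
    }
    return by_dtype, rounds_per_dtype
```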

core.dist_checkpointing.exchange_utils.exchange_loaded_tensors_gather_object(
loaded_tensors: Dict[core.dist_checkpointing.utils._ShardId, torch.Tensor],
unloaded_shards: Dict[core.dist_checkpointing.utils._ShardId, core.dist_checkpointing.mapping.ShardedTensor],
shard_distribution: core.dist_checkpointing.exchange_utils.ShardDistribution,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
) Dict[core.dist_checkpointing.utils._ShardId, torch.Tensor]#

Exchange the tensors loaded by different ranks with a simple all_gather_object call.

This version can be used for debugging purposes due to its simplistic implementation. It shouldn’t be used if performance is important.

Parameters:
  • loaded_tensors (Dict[_ShardId, torch.Tensor]) – mapping from ShardedTensor shard ids to tensors already loaded by this rank.

  • unloaded_shards (Dict[_ShardId, ShardedTensor]) – mapping from ShardedTensor shard ids to ShardedTensors that aren’t loaded yet.

  • shard_distribution (ShardDistribution) – distribution of all shards

  • parallelization_group (ProcessGroup, optional) – process group used for load distribution. Tensors will be exchanged within this group

Returns:

dictionary mapping shard ids to tensors needed by this rank to load a given state dict. Includes previously loaded tensors (from loaded_tensors input)

Return type:

Dict[_ShardId, torch.Tensor]

core.dist_checkpointing.exchange_utils.exchange_loaded_objects_gather_object(
loaded_objects: Dict[core.dist_checkpointing.utils._ShardId, Any],
) Dict[core.dist_checkpointing.utils._ShardId, Any]#

Exchange the objects loaded by different ranks with a simple all_gather_object call.

Parameters:

loaded_objects (Dict[_ShardId, Any]) – mapping from shard ids to objects already loaded by this rank.

Returns:

dictionary mapping shard ids to objects needed by this rank to load a given state dict.

Return type:

Dict[_ShardId, Any]
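
A conceptual sketch of the all_gather_object-based exchange (an illustration, not the library implementation):

```python
import torch.distributed as dist

def toy_exchange_objects(loaded_objects, group=None):
    world_size = dist.get_world_size(group)
    gathered = [None] * world_size
    # Every rank contributes its own shard_id -> object mapping ...
    dist.all_gather_object(gathered, loaded_objects, group=group)
    # ... and merges the per-rank mappings into one dictionary.
    merged = {}
    for per_rank_objects in gathered:
        merged.update(per_rank_objects)
    return merged
```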

core.dist_checkpointing.exchange_utils.exchange_loaded_tensors_broadcast(
loaded_tensors: Dict[core.dist_checkpointing.utils._ShardId, torch.Tensor],
unloaded_shards: Dict[core.dist_checkpointing.utils._ShardId, core.dist_checkpointing.mapping.ShardedTensor],
shard_distribution: core.dist_checkpointing.exchange_utils.ShardDistribution,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
) Dict[core.dist_checkpointing.utils._ShardId, torch.Tensor]#

Exchange the tensors loaded by different ranks by a series of broadcasts.

For each rank and each loaded tensor, performs a broadcast to the whole group. This is a reasonable tradeoff between performance and simplicity.

Parameters:
  • loaded_tensors (Dict[_ShardId, torch.Tensor]) – mapping from ShardedTensor shard ids to tensors already loaded by this rank.

  • unloaded_shards (Dict[_ShardId, ShardedTensor]) – mapping from ShardedTensor shard ids to ShardedTensors that aren’t loaded yet.

  • shard_distribution (ShardDistribution) – distribution of all shards

  • parallelization_group (ProcessGroup, optional) – process group used for load distribution. Tensors will be exchanged within this group

Returns:

dictionary mapping shard ids to tensors needed by this rank to load a given state dict. Includes previously loaded tensors (from loaded_tensors input)

Return type:

Dict[_ShardId, torch.Tensor]
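
A conceptual sketch of the broadcast loop described above (an illustration, not the library implementation). It assumes main_rank_for_shard holds group-local ranks and that the ShardedTensor metadata exposes local_shape and dtype.

```python
import torch
import torch.distributed as dist

def toy_exchange_broadcast(loaded_tensors, shard_distribution, group=None):
    result = dict(loaded_tensors)
    for shard_id, owner in shard_distribution.main_rank_for_shard.items():
        meta = shard_distribution.shard_to_metadata[shard_id]
        if dist.get_rank(group) == owner:
            # This rank loaded the shard and acts as the broadcast source.
            buf = loaded_tensors[shard_id].cuda()
        else:
            # Shape and dtype are known from metadata, so every receiver
            # can allocate an empty buffer of the right size.
            buf = torch.empty(meta.local_shape, dtype=meta.dtype, device="cuda")
        # dist.broadcast expects a global rank as src.
        src = dist.get_global_rank(group, owner) if group is not None else owner
        dist.broadcast(buf, src=src, group=group)
        result[shard_id] = buf
    return result
```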

core.dist_checkpointing.exchange_utils.exchange_by_distribution(
loaded_tensors: Dict[core.dist_checkpointing.utils._ShardId, torch.Tensor],
unloaded_shards: Dict[core.dist_checkpointing.utils._ShardId, core.dist_checkpointing.mapping.ShardedTensor],
shard_distribution: core.dist_checkpointing.exchange_utils.ShardDistribution,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
exchange_algo='broadcast',
) Dict[core.dist_checkpointing.utils._ShardId, torch.Tensor]#

Exchange tensors loaded by different ranks using the specified exchange_algo.

Parameters:
  • loaded_tensors (Dict[_ShardId, torch.Tensor]) – mapping from ShardedTensor shard ids to tensors already loaded by this rank.

  • unloaded_shards (Dict[_ShardId, ShardedTensor]) – mapping from ShardedTensor shard ids to ShardedTensors that aren’t loaded yet.

  • shard_distribution (ShardDistribution) – distribution of all shards

  • parallelization_group (ProcessGroup, optional) – process group used for load distribution. Tensors will be exchanged within this group

  • exchange_algo (str) – The algorithm used for performing exchanges. Defaults to ‘broadcast’.

Returns:

dictionary mapping shard ids to tensors needed by this rank to load a given state dict. Includes previously loaded tensors (from loaded_tensors input)

Return type:

Dict[_ShardId, torch.Tensor]
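
A hedged end-to-end sketch tying the pieces together during loading. Variable names are placeholders, and the alternative exchange_algo strings are assumptions inferred from the function names above (only 'broadcast' is documented as the default).

```python
# Compute the distribution for loading (ignore_groups=True per the note above),
# then exchange the missing shards with the chosen algorithm.
distribution = determine_main_replica_uniform_distribution(
    sharded_state_dict, parallelization_group, ignore_groups=True
)
if distribution is not None:
    loaded_tensors = exchange_by_distribution(
        loaded_tensors=loaded_tensors,
        unloaded_shards=unloaded_shards,
        shard_distribution=distribution,
        parallelization_group=parallelization_group,
        exchange_algo="broadcast",  # assumed alternatives: "gather_rounds", "gather_object"
    )
```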