core.dist_checkpointing.strategies.fully_parallel#

Module Contents#

Classes#

FullyParallelSaveStrategyWrapper

Wraps an arbitrary save strategy and distributes the save during save.

FullyParallelLoadStrategyWrapper

Wraps an arbitrary load strategy and distributes the load during load.

Functions#

distribute_main_replicas_with_precomputed_distribution

Applies the save distribution computed with determine_main_replica_uniform_distribution.

_defer_loading_sharded_items

Divides the state dict into parts loaded by this rank vs. other ranks.

_fill_in_deferred_sharded_items

Helper function to fill in items not loaded by the current rank.

Data#

API#

core.dist_checkpointing.strategies.fully_parallel.logger#

‘getLogger(…)’

core.dist_checkpointing.strategies.fully_parallel.T#

‘TypeVar(…)’

class core.dist_checkpointing.strategies.fully_parallel.FullyParallelSaveStrategyWrapper(
strategy: megatron.core.dist_checkpointing.strategies.base.SaveShardedStrategy,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
do_cache_distribution: bool = False,
)#

Bases: megatron.core.dist_checkpointing.strategies.base.AsyncSaveShardedStrategy

Wraps an arbitrary save strategy and distributes the save during save.

The save distribution happens without any data communication. Only the metadata is exchanged and, based on the data replication across ranks, the save is distributed as uniformly as possible.

This wrapper assumes that setting replica_id to 0 makes the underlying strategy save the given shard on the current rank. All the other replica_ids are set to 1.

Currently, the save distribution is realized with a greedy algorithm described in distribute_shards_to_ranks.

Parameters:
  • strategy (SaveShardedStrategy) – base strategy to wrap

  • parallelization_group (ProcessGroup, optional) – process group to use for save distribution. Note that this doesn’t have to match exactly the data distribution, but should cover the replication pattern to maximize performance. Defaults to the whole world.

  • do_cache_distribution (bool, optional) – whether to cache the save distribution from previous calls. Should be set to True only if the state dict structure between the calls is always the same. Defaults to False.

Initialization
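A minimal usage sketch of the save wrapper. This is illustrative only: the helper get_default_save_sharded_strategy and the dist_checkpointing.save entry point are assumed from Megatron-Core and may differ between versions; model, the initialized torch.distributed setup, and the (existing, empty) checkpoint directory are assumptions.

```python
# Sketch: wrap the default save strategy so that each replicated shard is persisted
# by exactly one rank in the parallelization group.
from megatron.core import dist_checkpointing
from megatron.core.dist_checkpointing.serialization import get_default_save_sharded_strategy
from megatron.core.dist_checkpointing.strategies.fully_parallel import (
    FullyParallelSaveStrategyWrapper,
)

base_strategy = get_default_save_sharded_strategy()
save_strategy = FullyParallelSaveStrategyWrapper(
    base_strategy,
    parallelization_group=None,  # None = the whole world; a DP group is also a common choice
)

sharded_state_dict = model.sharded_state_dict()  # `model` is assumed to exist
dist_checkpointing.save(
    sharded_state_dict,
    '/tmp/ckpt',                 # assumed to be an existing, empty directory
    sharded_strategy=save_strategy,
)
```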

async_save(
sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
checkpoint_dir: pathlib.Path,
)#
save(
sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
checkpoint_dir: pathlib.Path,
)#
apply_saving_parallelization(
sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
) → None#

Distributes the save across ranks by exchanging metadata.

Exchanges metadata from the state dict and computes a distribution of saves among the ranks that is as uniform as possible.

If self.do_cache_distribution is True, the distribution is cached between calls, so subsequent distributions happen without any inter-rank communication.

Parameters:

sharded_state_dict (ShardedStateDict) – state dict for which to distribute the saving

Returns: None
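When checkpoints are written repeatedly and the sharded state dict structure never changes, the metadata exchange above can be skipped after the first call by enabling caching. A hedged sketch (checkpoint_dirs and model are hypothetical; base_strategy is assumed from the previous sketch):

```python
# Sketch: with do_cache_distribution=True the save distribution is computed once and
# reused, so later saves involve no extra inter-rank communication for planning.
save_strategy = FullyParallelSaveStrategyWrapper(
    base_strategy,
    parallelization_group=None,
    do_cache_distribution=True,  # only valid if the state dict structure is stable
)

for ckpt_dir in checkpoint_dirs:  # hypothetical list of target directories
    dist_checkpointing.save(
        model.sharded_state_dict(), ckpt_dir, sharded_strategy=save_strategy
    )
```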

property can_handle_sharded_objects#
class core.dist_checkpointing.strategies.fully_parallel.FullyParallelLoadStrategyWrapper(
strategy: megatron.core.dist_checkpointing.strategies.base.LoadShardedStrategy,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
do_cache_distribution: bool = False,
exchange_algo: str = 'broadcast',
)#

Bases: megatron.core.dist_checkpointing.strategies.base.LoadShardedStrategy

Wraps an arbitrary load strategy and distributes the load during load.

See load method docs for details.

Parameters:
  • strategy (LoadShardedStrategy) – base strategy to wrap

  • parallelization_group (ProcessGroup, optional) – process group to use for load distribution. Note that this doesn’t have to match exactly the data distribution, but should cover the replication pattern to maximize performance. Defaults to the whole world. In most cases, it’s recommended to set it to the DP group.

  • do_cache_distribution (bool, optional) – whether to cache the load distribution from previous calls. Should be set to True only if the state dict structure between the calls is always the same. Defaults to False, since the loading in general happens only once during training. Note that the load distribution cannot be reused as a save distribution, because save/load is not fully symmetrical.

  • exchange_algo (str) –

    algorithm to use for exchanging the data. Options:

    • broadcast (default) - each rank broadcasts individual tensors to the others

    • gather_object - ranks all_gather_object the whole loaded state dicts

    • gather_rounds - ranks all-gather individual tensors in rounds

    See the load method docs for more details.

Initialization
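A minimal usage sketch of the load wrapper. Illustrative only: get_default_load_sharded_strategy and dist_checkpointing.load are assumed from Megatron-Core and may differ between versions; dp_group and model are placeholders.

```python
# Sketch: distribute the load across the data-parallel group, then exchange the
# loaded shards so every rank ends up with the full requested state dict.
from megatron.core import dist_checkpointing
from megatron.core.dist_checkpointing.serialization import get_default_load_sharded_strategy
from megatron.core.dist_checkpointing.strategies.fully_parallel import (
    FullyParallelLoadStrategyWrapper,
)

ckpt_dir = '/tmp/ckpt'
base_strategy = get_default_load_sharded_strategy(ckpt_dir)
load_strategy = FullyParallelLoadStrategyWrapper(
    base_strategy,
    parallelization_group=dp_group,  # hypothetical DP process group (recommended above)
    exchange_algo='broadcast',
)

state_dict = dist_checkpointing.load(
    model.sharded_state_dict(),      # `model` is assumed to exist
    ckpt_dir,
    sharded_strategy=load_strategy,
)
```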

load(
sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
checkpoint_dir: pathlib.Path,
) → megatron.core.dist_checkpointing.mapping.StateDict#

Distributes the load and calls underlying strategy only for parts of the state dict.

Steps:

  1. Load metadata is exchanged between the ranks in the parallelization group.

  2. Each rank deterministically plans the load for the whole workload so that the loads are as uniform as possible.

  3. Each rank loads its planned shard of the checkpoint.

  4. All ranks exchange the loaded shards.

Internode communication is involved in steps (1) (with metadata) and (4) (with actual data). Storage interaction is involved in step (3).

Currently, the load distribution (step 2) is realized with a greedy algorithm described in distribute_shards_to_ranks (same as for saving distribution).

Currently, the shards are all-gathered among all ranks in the parallelization group. This might not be optimal (some ranks do not need all tensors), but it is a reasonable approximation of an optimal exchange in most scenarios.

Parameters:
  • sharded_state_dict (ShardedStateDict) – sharded state dict to load

  • checkpoint_dir (Path) – checkpoint directory to load from

Returns:

loaded state dict. The state dict should be equivalent to a state dict that would be loaded with the underlying strategy without this wrapper.

Return type:

StateDict
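The greedy planning in step (2) can be pictured with a standalone sketch. This is illustrative only; the real logic lives in distribute_shards_to_ranks and differs in details. The idea: process shards from largest to smallest and give each one to the least-loaded rank that already holds a replica of it.

```python
# Standalone illustration of greedy, replication-aware load balancing.
from typing import Dict, List

def greedy_assign(
    shard_sizes: Dict[str, int],        # shard name -> size (e.g. numel)
    shard_ranks: Dict[str, List[int]],  # shard name -> ranks holding a replica
    num_ranks: int,
) -> Dict[str, int]:
    load = [0] * num_ranks
    assignment: Dict[str, int] = {}
    for shard in sorted(shard_sizes, key=shard_sizes.get, reverse=True):
        best = min(shard_ranks[shard], key=lambda r: load[r])
        assignment[shard] = best
        load[best] += shard_sizes[shard]
    return assignment

# Three shards replicated on both of 2 ranks end up split between them:
# {'A': 0, 'B': 1, 'C': 1}
print(greedy_assign({'A': 4, 'B': 3, 'C': 2}, {'A': [0, 1], 'B': [0, 1], 'C': [0, 1]}, 2))
```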

static _defer_loading_sharded_objects(
sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
) → Tuple[megatron.core.dist_checkpointing.mapping.ShardedStateDict, megatron.core.dist_checkpointing.mapping.ShardedStateDict, Dict[megatron.core.dist_checkpointing.utils._ShardId, megatron.core.dist_checkpointing.ShardedObject], Dict[megatron.core.dist_checkpointing.utils._ShardId, megatron.core.dist_checkpointing.ShardedObject]]#
static _defer_loading_sharded_tensors(
sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
) → Tuple[megatron.core.dist_checkpointing.mapping.ShardedStateDict, megatron.core.dist_checkpointing.mapping.ShardedStateDict, Dict[megatron.core.dist_checkpointing.utils._ShardId, megatron.core.dist_checkpointing.ShardedTensor], Dict[megatron.core.dist_checkpointing.utils._ShardId, megatron.core.dist_checkpointing.ShardedTensor]]#
static fill_in_deferred_sharded_objects(
sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
loaded_objects: Dict[megatron.core.dist_checkpointing.utils._ShardId, Any],
) → None#

Fill in objects not loaded by the current rank with objects from the loaded_objects map.

Parameters:
  • sharded_state_dict (ShardedStateDict) – sharded state dict to fill in. ShardedObjects are completely replaced with corresponding objects.

  • loaded_objects (Dict[_ShardId, Any]) – dict mapping ShardedObjects from the sharded_state_dict to the loaded objects.

Returns:

None

static fill_in_deferred_sharded_tensors(
sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
loaded_tensors: Dict[megatron.core.dist_checkpointing.utils._ShardId, torch.Tensor],
) → None#

Fill in tensors not loaded by the current rank with tensors from the loaded_tensors map.

Parameters:
  • sharded_state_dict (ShardedStateDict) – sharded state dict to fill in. ShardedTensors are completely replaced with corresponding torch.Tensors.

  • loaded_tensors (Dict[_ShardId, torch.Tensor]) – dict mapping ShardedTensors from the sharded_state_dict to the loaded tensors.

Returns:

None
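Conceptually, the fill-in step walks the state dict and swaps every deferred placeholder for the item exchanged from another rank, looked up by its shard id. An illustrative sketch, not Megatron's implementation:

```python
from typing import Any, Callable, Dict, Hashable

def fill_in(
    state_dict: Dict[str, Any],
    loaded: Dict[Hashable, Any],
    shard_id_of: Callable[[Any], Hashable],
    is_placeholder: Callable[[Any], bool],
) -> None:
    """Recursively replace placeholders in-place with items loaded by other ranks."""
    for key, value in state_dict.items():
        if isinstance(value, dict):
            fill_in(value, loaded, shard_id_of, is_placeholder)
        elif is_placeholder(value):
            state_dict[key] = loaded[shard_id_of(value)]
```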

apply_loading_parallelization(
sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
) → Optional[megatron.core.dist_checkpointing.exchange_utils.ShardDistribution]#

Distributes the load across ranks by exchanging metadata.

Exchanges metadata from the state dict and computes a distribution of loads among the ranks that is as uniform as possible. Marks ShardedTensors to be loaded by the current rank with replica_id 0 (and the others with non-zero values).

If self.do_cache_distribution is True, the distribution is cached between calls, so subsequent distributions happen without any inter-rank communication.

Parameters:

sharded_state_dict (ShardedStateDict) – state dict for which to distribute the loading

Returns:

the computed loading distribution

Return type:

ShardDistribution (optional)

property can_handle_sharded_objects#
load_tensors_metadata(checkpoint_dir: pathlib.Path)#
load_sharded_metadata(checkpoint_dir: pathlib.Path)#
check_backend_compatibility(loaded_version)#
check_version_compatibility(loaded_version)#
core.dist_checkpointing.strategies.fully_parallel.distribute_main_replicas_with_precomputed_distribution(
sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
parallelization_group: torch.distributed.ProcessGroup,
precomputed_distribution: Optional[megatron.core.dist_checkpointing.exchange_utils.ShardDistribution],
)#

Applies the save distribution computed with determine_main_replica_uniform_distribution.

Based on the rank assignment, sets the replica ids of the shards saved by the current rank to 0 and all the other replica ids to 1.

Parameters:
  • sharded_state_dict (ShardedStateDict) – state dict to apply the save distribution to

  • parallelization_group (ProcessGroup) – distribution will be applied within this process group. Must match with the process group passed to determine_main_replica_uniform_distribution.

  • precomputed_distribution (ShardDistribution) – distribution computed with determine_main_replica_uniform_distribution

Returns: None

Example replica ids of tensors A, B, C before distribution:

  • rank0: A: (0, 0, 0), B: (0, 0, 0), C: (0, 0, 0)

  • rank1: A: (0, 0, 1), B: (0, 0, 1), C: (0, 0, 1)

  • rank2: A: (0, 0, 2), B: (0, 0, 2), C: (0, 0, 2)

Replica ids after distribution for the example above:

  • rank0: A: 0, B: 1, C: 1

  • rank1: A: 1, B: 0, C: 1

  • rank2: A: 1, B: 1, C: 0
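A toy sketch reproducing the example above (illustrative only, plain Python instead of Megatron types):

```python
# Given which rank was chosen as the main replica for each shard, the local
# replica id becomes 0 for the shards this rank should save and 1 for the rest,
# so the wrapped strategy persists only the local rank's share.
assignment = {'A': 0, 'B': 1, 'C': 2}  # shard name -> rank assigned as main replica

def mark_main_replicas(rank: int) -> dict:
    return {name: (0 if owner == rank else 1) for name, owner in assignment.items()}

for rank in range(3):
    print(f'rank{rank}:', mark_main_replicas(rank))
# rank0: {'A': 0, 'B': 1, 'C': 1}
# rank1: {'A': 1, 'B': 0, 'C': 1}
# rank2: {'A': 1, 'B': 1, 'C': 0}
```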

core.dist_checkpointing.strategies.fully_parallel._defer_loading_sharded_items(
sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
item_type: type,
shard_id_func: Callable[[core.dist_checkpointing.strategies.fully_parallel.T], megatron.core.dist_checkpointing.utils._ShardId],
) → Tuple[megatron.core.dist_checkpointing.mapping.ShardedStateDict, megatron.core.dist_checkpointing.mapping.ShardedStateDict, Dict[megatron.core.dist_checkpointing.utils._ShardId, core.dist_checkpointing.strategies.fully_parallel.T], Dict[megatron.core.dist_checkpointing.utils._ShardId, core.dist_checkpointing.strategies.fully_parallel.T]]#

Divides the state dict into parts loaded by this rank vs. other ranks.

Parameters:
  • sharded_state_dict (ShardedStateDict) – state dict with sharded items that will be divided.

  • item_type – The type of sharded item (ShardedObject or ShardedTensor)

  • shard_id_func – Function to get the shard ID for the item type

Returns: a tuple of:

  • ShardedStateDict: sub-state dict only with sharded items

  • ShardedStateDict: sub-state dict with non-sharded items

  • Dict[_ShardId, T]: mapping from shard id to items loaded by this rank

  • Dict[_ShardId, T]: mapping from shard id to items loaded by other ranks
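A conceptual sketch of the split (illustrative; the real helper walks a nested ShardedStateDict and keys items by _ShardId):

```python
from typing import Any, Callable, Dict, Tuple

def split_by_ownership(
    flat_items: Dict[str, Any],
    loads_locally: Callable[[Any], bool],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split a flat shard-id -> item mapping into locally loaded vs deferred items."""
    mine: Dict[str, Any] = {}
    deferred: Dict[str, Any] = {}
    for shard_id, item in flat_items.items():
        (mine if loads_locally(item) else deferred)[shard_id] = item
    return mine, deferred

# Mirrors the replica_id-based convention used by the wrapper: the rank marked as the
# main replica (replica_id 0) loads the shard, everything else is deferred.
mine, deferred = split_by_ownership(
    {'A': {'replica_id': 0}, 'B': {'replica_id': 1}},
    loads_locally=lambda item: item['replica_id'] == 0,
)
```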

core.dist_checkpointing.strategies.fully_parallel._fill_in_deferred_sharded_items(
sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
loaded_items: Dict[megatron.core.dist_checkpointing.utils._ShardId, Any],
item_type: type,
shard_id_func: Callable[[core.dist_checkpointing.strategies.fully_parallel.T], megatron.core.dist_checkpointing.utils._ShardId],
) → None#

Helper function to fill in items not loaded by the current rank.