core.dist_checkpointing.strategies.fully_parallel#
Module Contents#
Classes#
FullyParallelSaveStrategyWrapper – Wraps an arbitrary save strategy and distributes the save during save.
FullyParallelLoadStrategyWrapper – Wraps an arbitrary load strategy and distributes the load during load.
Functions#
distribute_main_replicas_with_precomputed_distribution – Applies the save distribution computed with determine_main_replica_uniform_distribution.
_defer_loading_sharded_items – Divides a state dict into parts loaded by this rank vs. other ranks.
_fill_in_deferred_sharded_items – Helper function to fill in items not loaded by the current rank.
Data#
API#
- core.dist_checkpointing.strategies.fully_parallel.logger#
‘getLogger(…)’
- core.dist_checkpointing.strategies.fully_parallel.T#
‘TypeVar(…)’
- class core.dist_checkpointing.strategies.fully_parallel.FullyParallelSaveStrategyWrapper(
- strategy: megatron.core.dist_checkpointing.strategies.base.SaveShardedStrategy,
- parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
- do_cache_distribution: bool = False,
Bases:
megatron.core.dist_checkpointing.strategies.base.AsyncSaveShardedStrategy
Wraps an arbitrary strategy and distributes the save during save.
The save distribution happens without any data communication. Only the metadata is exchanged and, based on the data replication across ranks, the save is distributed as uniformly as possible.
This wrapper assumes that setting replica_id to 0 makes the underlying strategy perform the save on the current rank. All other replica_ids are set to 1.
Currently, the save distribution is realized with a greedy algorithm described in distribute_shards_to_ranks.
- Parameters:
strategy (SaveShardedStrategy) – base strategy to wrap
parallelization_group (ProcessGroup, optional) – process group to use for save distribution. Note that this doesn’t have to match exactly the data distribution, but should cover the replication pattern to maximize performance. Defaults to the whole world.
do_cache_distribution (bool, optional) – whether to cache the save distribution from previous calls. Should be set to True only if the state dict structure is always the same between calls. Defaults to False.
Initialization
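The greedy distribution mentioned above can be illustrated with a minimal, self-contained sketch. This is not the actual distribute_shards_to_ranks implementation and all names are hypothetical; it only shows the idea of assigning the largest shards first, each to the least-loaded rank that holds a replica of it.

```python
def distribute_shards_greedy(shard_to_ranks, shard_sizes, num_ranks):
    # Greedy size balancing: assign the largest shards first, each to
    # the least-loaded rank among those that hold a replica of it.
    load = [0] * num_ranks          # bytes assigned to each rank so far
    assignment = {}
    for shard in sorted(shard_sizes, key=shard_sizes.get, reverse=True):
        best = min(shard_to_ranks[shard], key=lambda r: load[r])
        assignment[shard] = best
        load[best] += shard_sizes[shard]
    return assignment

# Two ranks, each holding a replica of every shard.
assignment = distribute_shards_greedy(
    shard_to_ranks={"A": [0, 1], "B": [0, 1], "C": [0, 1]},
    shard_sizes={"A": 100, "B": 60, "C": 50},
    num_ranks=2,
)
# The largest shard lands alone on one rank; the two smaller ones go
# to the other, keeping the per-rank totals close (100 vs. 110).
```

Because every rank computes the plan from the same exchanged metadata, no data needs to be communicated to agree on who saves what.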
- async_save(
- sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
- checkpoint_dir: pathlib.Path,
- save(
- sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
- checkpoint_dir: pathlib.Path,
- apply_saving_parallelization(
- sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
Distributes the save across ranks by exchanging metadata.
Exchanges metadata from the state dict and computes the uniform (as close as possible) distribution of saves among the ranks.
If self.do_cache_distribution is True, the distribution is cached between calls and subsequent distributions happen without any inter-rank communication.
- Parameters:
sharded_state_dict (ShardedStateDict) – state dict to distribute the saving
Returns: None
- property can_handle_sharded_objects#
- class core.dist_checkpointing.strategies.fully_parallel.FullyParallelLoadStrategyWrapper(
- strategy: megatron.core.dist_checkpointing.strategies.base.LoadShardedStrategy,
- parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
- do_cache_distribution: bool = False,
- exchange_algo: str = 'broadcast',
Bases:
megatron.core.dist_checkpointing.strategies.base.LoadShardedStrategy
Wraps an arbitrary load strategy and distributes the load during load.
See the load method docs for details.
- Parameters:
strategy (LoadShardedStrategy) – base strategy to wrap
parallelization_group (ProcessGroup, optional) – process group to use for load distribution. Note that this doesn’t have to match exactly the data distribution, but should cover the replication pattern to maximize performance. Defaults to the whole world. In most cases, it’s recommended to set it to the DP group.
do_cache_distribution (bool, optional) – whether to cache the load distribution from previous calls. Should be set to True only if the state dict structure between the calls is always the same. Defaults to False, since the loading in general happens only once during training. Note that the load distribution cannot be reused as a save distribution, because save/load is not fully symmetrical.
exchange_algo (str) –
algorithm to use for exchanging the data. Options:
broadcast (default) – each rank broadcasts individual tensors to others
gather_object – ranks all_gather_object the whole loaded state dicts
gather_rounds – ranks all-gather individual tensors in rounds
See method docs for more details.
Initialization
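The difference between the exchange options can be simulated in plain Python. The sketch below models the gather_rounds idea: instead of shipping whole state dicts at once (as gather_object does), each loaded tensor is broadcast individually from the rank that loaded it. The function and its signature are hypothetical; the real exchange uses torch.distributed collectives.

```python
def gather_rounds_exchange(loaded_per_rank):
    # Pure-Python simulation: broadcast tensors one at a time from
    # the rank that loaded them, counting one message per tensor.
    result = [dict(d) for d in loaded_per_rank]
    num_messages = 0
    for src, tensors in enumerate(loaded_per_rank):
        for name, tensor in tensors.items():
            num_messages += 1               # one broadcast per tensor
            for dst in range(len(result)):  # every rank receives it
                result[dst][name] = tensor
    return result, num_messages

ranks, num_messages = gather_rounds_exchange([{"A": 1, "B": 2}, {"C": 3}])
# After the exchange every rank holds {"A": 1, "B": 2, "C": 3},
# at the cost of one broadcast per tensor (3 here).
```

Per-tensor exchange keeps peak memory lower than gathering whole state dicts, at the price of more (smaller) messages.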
- load(
- sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
- checkpoint_dir: pathlib.Path,
Distributes the load and calls underlying strategy only for parts of the state dict.
Steps:
1. Load metadata is exchanged between the ranks in the parallelization group.
2. Each rank deterministically plans the load for the whole workload so that the loads are as uniform as possible.
3. Each rank loads its planned shard of the checkpoint.
4. All ranks exchange the loaded shards.
Inter-node communication is involved in steps (1) (metadata) and (4) (actual data). Storage interaction is involved in step (3).
Currently, the load distribution (step 2) is realized with a greedy algorithm described in distribute_shards_to_ranks (same as for the saving distribution).
Currently, the shards are all-gathered between all ranks in the parallelization group. This might not be optimal (some ranks do not need all tensors), but it's a reasonable approximation of an optimal exchange in most scenarios.
- Parameters:
sharded_state_dict (ShardedStateDict) – sharded state dict to load
checkpoint_dir (Path) – checkpoint directory to load from
- Returns:
loaded state dict. The state dict should be equivalent to a state dict that would be loaded with the underlying strategy without this wrapper.
- Return type:
StateDict
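The four steps above can be modeled end to end with a toy sketch. All names are illustrative; there is no real I/O or communication, and the planning step reuses the same greedy idea the docs describe.

```python
def fully_parallel_load(shard_sizes, shard_holders, num_ranks, storage):
    # Steps 1-2: every rank computes the same deterministic greedy plan
    # (largest shard first, assigned to the least-loaded holder).
    load = [0] * num_ranks
    plan = {}
    for shard in sorted(shard_sizes, key=shard_sizes.get, reverse=True):
        rank = min(shard_holders[shard], key=lambda r: load[r])
        plan[shard] = rank
        load[rank] += shard_sizes[shard]
    # Step 3: each rank "reads" only its planned shards from storage.
    loaded = [
        {s: storage[s] for s, r in plan.items() if r == rank}
        for rank in range(num_ranks)
    ]
    # Step 4: exchange, so every rank ends up with the full state dict.
    full = {}
    for part in loaded:
        full.update(part)
    return [dict(full) for _ in range(num_ranks)]

states = fully_parallel_load(
    shard_sizes={"A": 4, "B": 4},
    shard_holders={"A": [0, 1], "B": [0, 1]},
    num_ranks=2,
    storage={"A": "a-data", "B": "b-data"},
)
# Each rank reads only one shard from "storage" but ends up with both,
# matching the guarantee that the result equals a non-parallel load.
```

The key property is determinism of steps 1-2: since every rank plans from the same metadata, no coordination message is needed to agree on the plan itself.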
- static _defer_loading_sharded_objects(
- sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
- static _defer_loading_sharded_tensors(
- sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
- static fill_in_deferred_sharded_objects(
- sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
- loaded_objects: Dict[megatron.core.dist_checkpointing.utils._ShardId, Any],
Fill in objects not loaded by the current rank with objects from the loaded_objects map.
- Parameters:
sharded_state_dict (ShardedStateDict) – sharded state dict to fill in. ShardedObjects are completely replaced with corresponding objects.
loaded_objects (Dict[_ShardId, Any]) – dict allowing to map ShardedObject from the sharded_state_dict to loaded objects.
- Returns:
None
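The fill-in behavior described above can be sketched as a recursive walk over a (possibly nested) state dict, replacing each deferred placeholder with the loaded value keyed by its shard id. The helper name and signature are hypothetical; like the real method, it mutates the dict in place.

```python
def fill_in_deferred_items(state_dict, loaded_items, shard_id_func):
    # Replace every deferred item whose shard id appears in
    # loaded_items; recurse into nested dicts.
    for key, value in state_dict.items():
        if isinstance(value, dict):
            fill_in_deferred_items(value, loaded_items, shard_id_func)
        elif shard_id_func(value) in loaded_items:
            state_dict[key] = loaded_items[shard_id_func(value)]

state = {"layer1": {"weight": "shard:w1"}, "bias": "shard:b"}
# For this toy example the shard id is the placeholder string itself.
fill_in_deferred_items(state, {"shard:w1": [1, 2], "shard:b": [3]}, lambda v: v)
# state is now {"layer1": {"weight": [1, 2]}, "bias": [3]}
```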
- static fill_in_deferred_sharded_tensors(
- sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
- loaded_tensors: Dict[megatron.core.dist_checkpointing.utils._ShardId, torch.Tensor],
Fill in tensors not loaded by the current rank with tensors from the loaded_tensors map.
- Parameters:
sharded_state_dict (ShardedStateDict) – sharded state dict to fill in. ShardedTensors are completely replaced with corresponding torch.Tensors.
loaded_tensors (Dict[_ShardId, torch.Tensor]) – dict allowing to map ShardedTensor from the sharded_state_dict to loaded tensors.
- Returns:
None
- apply_loading_parallelization(
- sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
Distributes the load across ranks by exchanging metadata.
Exchanges metadata from the state dict and computes the uniform (as close as possible) distribution of loads among the ranks. Marks ShardedTensors to be loaded by the current rank with replica_id 0 (and others with nonzero values).
If self.do_cache_distribution is True, the distribution is cached between calls and subsequent distributions happen without any inter-rank communication.
- Parameters:
sharded_state_dict (ShardedStateDict) – state dict to distribute the loading
- Returns:
the computed loading distribution
- Return type:
ShardDistribution (optional)
- property can_handle_sharded_objects#
- load_tensors_metadata(checkpoint_dir: pathlib.Path)#
- load_sharded_metadata(checkpoint_dir: pathlib.Path)#
- check_backend_compatibility(loaded_version)#
- check_version_compatibility(loaded_version)#
- core.dist_checkpointing.strategies.fully_parallel.distribute_main_replicas_with_precomputed_distribution(
- sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
- parallelization_group: torch.distributed.ProcessGroup,
- precomputed_distribution: Optional[megatron.core.dist_checkpointing.exchange_utils.ShardDistribution],
Applies the save distribution computed with determine_main_replica_uniform_distribution.
Based on the rank assignment, sets the replica ids of the shards saved by the current rank to 0 and all other replica ids to 1.
- Parameters:
sharded_state_dict (ShardedStateDict) – state dict to apply the save distribution to
parallelization_group (ProcessGroup) – the distribution will be applied within this process group. Must match the process group passed to determine_main_replica_uniform_distribution.
precomputed_distribution (ShardDistribution) – distribution computed with determine_main_replica_uniform_distribution
Returns: None
Example replica ids of tensors A, B, C before distribution:
rank0: A: (0, 0, 0), B: (0, 0, 0), C: (0, 0, 0)
rank1: A: (0, 0, 1), B: (0, 0, 1), C: (0, 0, 1)
rank2: A: (0, 0, 2), B: (0, 0, 2), C: (0, 0, 2)
Replica ids after distribution for the example above:
rank0: A: 0, B: 1, C: 1
rank1: A: 1, B: 0, C: 1
rank2: A: 1, B: 1, C: 0
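The example replica-id marking can be reproduced with a small sketch. The function name and flat replica-id representation are hypothetical simplifications (real shards carry replica_id tuples); the point is only the 0-for-mine, 1-for-others marking driven by a precomputed assignment.

```python
def mark_main_replicas(shards, assignment, my_rank):
    # Replica id 0 for shards this rank was assigned to save,
    # 1 for everything else.
    return {s: (0 if assignment[s] == my_rank else 1) for s in shards}

assignment = {"A": 0, "B": 1, "C": 2}  # precomputed distribution
marked = [mark_main_replicas(["A", "B", "C"], assignment, my_rank=r)
          for r in range(3)]
# marked[0] == {"A": 0, "B": 1, "C": 1}
# marked[1] == {"A": 1, "B": 0, "C": 1}
# marked[2] == {"A": 1, "B": 1, "C": 0}
```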
- core.dist_checkpointing.strategies.fully_parallel._defer_loading_sharded_items(
- sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
- item_type: type,
- shard_id_func: Callable[[core.dist_checkpointing.strategies.fully_parallel.T], megatron.core.dist_checkpointing.utils._ShardId],
Divides state dict into parts loaded by this vs other ranks.
- Parameters:
sharded_state_dict (ShardedStateDict) – state dict with sharded items that will be divided.
item_type – The type of sharded item (ShardedObject or ShardedTensor)
shard_id_func – Function to get the shard ID for the item type
Returns: a tuple of:
- ShardedStateDict: sub-state dict only with sharded items
- ShardedStateDict: sub-state dict with non-sharded items
- Dict[_ShardId, T]: mapping from shard id to items loaded by this rank
- Dict[_ShardId, T]: mapping from shard id to items loaded by other ranks
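The division described above can be sketched with a simplified, hypothetical signature: split a flat state dict into items of the given type loaded by this rank, items of that type loaded elsewhere, and everything else (the real helper additionally keys the two item maps by shard id and handles nested dicts).

```python
def defer_loading_items(state_dict, item_type, loaded_here):
    # Three-way split: this rank's sharded items, other ranks'
    # sharded items, and non-sharded leftovers.
    mine, others, rest = {}, {}, {}
    for key, value in state_dict.items():
        if not isinstance(value, item_type):
            rest[key] = value
        elif loaded_here(value):
            mine[key] = value
        else:
            others[key] = value
    return mine, others, rest

# Toy stand-in: treat ints as "sharded items"; this rank loads the even ones.
mine, others, rest = defer_loading_items(
    {"a": 2, "b": 3, "c": "opaque"}, int, lambda v: v % 2 == 0
)
# mine == {"a": 2}, others == {"b": 3}, rest == {"c": "opaque"}
```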
- core.dist_checkpointing.strategies.fully_parallel._fill_in_deferred_sharded_items(
- sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
- loaded_items: Dict[megatron.core.dist_checkpointing.utils._ShardId, Any],
- item_type: type,
- shard_id_func: Callable[[core.dist_checkpointing.strategies.fully_parallel.T], megatron.core.dist_checkpointing.utils._ShardId],
Helper function to fill in items not loaded by current rank.