core.dist_checkpointing.strategies.two_stage#

2-stage checkpoint loading.

Module Contents#

Classes#

_ShardedTensorMetadata

TwoStageDataParallelLoadShardedStrategy

Loads one checkpoint replica from storage and broadcasts to other nodes.

Functions#

timed

Timing decorator.

sharded_tensor_chunk_id

Id of a sharded tensor.

Data#

API#

core.dist_checkpointing.strategies.two_stage._import_trigger#

None

core.dist_checkpointing.strategies.two_stage.timers#

‘defaultdict(…)’

core.dist_checkpointing.strategies.two_stage.logger#

‘getLogger(…)’

core.dist_checkpointing.strategies.two_stage.timed(verbose=True)#

Timing decorator.
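The decorator records call durations into the module-level `timers` defaultdict so they can later be reported by `summarize_load_times()`. A minimal sketch of such a decorator is below; the exact fields recorded and the output format are assumptions, not the module's actual implementation.

```python
import time
from collections import defaultdict
from functools import wraps

# Mirrors the module-level `timers` defaultdict; storing a list of
# durations per function name is an assumption for illustration.
timers = defaultdict(list)

def timed(verbose=True):
    """Timing decorator: records each call's duration under the function name."""
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            start = time.time()
            result = fn(*args, **kwargs)
            duration = time.time() - start
            timers[fn.__name__].append(duration)
            if verbose:
                print(f'{fn.__name__} took {duration:.4f}s')
            return result
        return wrapper
    return decorator

@timed(verbose=False)
def example():
    return 42
```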

class core.dist_checkpointing.strategies.two_stage._ShardedTensorMetadata#
global_rank: int#

None

sharded_tensor_no_data: core.dist_checkpointing.mapping.ShardedTensor#

None

dist_group_rank: Tuple[int]#

None

dist_group_ranks: Tuple[int]#

None

data_size: Optional[int]#

None

core.dist_checkpointing.strategies.two_stage.sharded_tensor_chunk_id(
sharded_tensor: core.dist_checkpointing.mapping.ShardedTensor,
)#

Id of a sharded tensor.
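A chunk id lets different ranks agree on which `ShardedTensor` shards refer to the same piece of the global tensor. A plausible sketch, assuming the id is derived from the tensor's `key` and `global_offset` fields (a stand-in dataclass replaces the real `ShardedTensor`):

```python
from dataclasses import dataclass
from typing import Tuple

# Minimal stand-in for core.dist_checkpointing.mapping.ShardedTensor;
# only the fields assumed relevant to chunk identity are modeled.
@dataclass
class ShardedTensorStub:
    key: str                       # name of the tensor in the state dict
    global_offset: Tuple[int, ...] # position of this shard in the global tensor

def sharded_tensor_chunk_id(sharded_tensor):
    """Identify a chunk by its key and global offset (an assumption:
    two shards with equal key and offset cover the same data)."""
    return (sharded_tensor.key, tuple(sharded_tensor.global_offset))
```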

class core.dist_checkpointing.strategies.two_stage.TwoStageDataParallelLoadShardedStrategy(
data_parallel_group,
cpu_transfer=True,
)#

Bases: core.dist_checkpointing.strategies.base.LoadShardedStrategy

Loads one checkpoint replica from storage and broadcasts to other nodes.

This strategy loads the checkpoint from storage on a minimal set of nodes and distributes it to the other nodes with torch.distributed. Loading from storage is performed with tensorstore.

Steps:

  0. (optional) Create Gloo distributed groups

  1. Exchange ShardedTensors metadata between all nodes

  2. Align needed tensors within DP groups

  3. For each globally unique tensor:

     3.a) On one of the ranks, load it from storage to CPU and move to CUDA

     3.b) Allocate a CUDA tensor on the other ranks

     3.c) Broadcast within the DP group

     3.d) Copy the tensor content to the model param location

     3.e) Free the tensor buffers from a) and b)

Notes:

  1. Loading and broadcasting are done sequentially to avoid both host and device OOMs

  2. There is a lot of overlap potential between the three steps done for each tensor:

     2.a) Loading from storage to numpy

     2.b) Moving CPU tensors to CUDA

     2.c) Broadcast
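The per-tensor loop from step 3 can be sketched in plain Python. The real implementation broadcasts CUDA tensors with `torch.distributed` within the DP group; here the group is simulated as a list of per-rank state dicts so only the control flow (one tensor in flight at a time, buffers freed before the next iteration) is shown. The metadata dict shape is an assumption.

```python
def run_two_stage_load(ten_metas, storage, group_size):
    """Simulate the sequential load/broadcast loop over globally unique
    tensors. `ten_metas` is a list of dicts with a 'chunk_id' key (an
    assumed shape); `storage` maps chunk ids to loaded values."""
    state_dicts = [dict() for _ in range(group_size)]
    for meta in ten_metas:
        # 3.a) one designated rank reads the tensor from storage
        loaded = storage[meta['chunk_id']]
        # 3.b-c) broadcast: every rank in the group receives a copy
        for rank in range(group_size):
            # 3.d) copy the content into that rank's param location
            state_dicts[rank][meta['chunk_id']] = list(loaded)
        # 3.e) drop the staging buffer before the next tensor, so peak
        # memory stays bounded by a single tensor (note 1 above)
        del loaded
    return state_dicts
```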

Initialization

load(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
checkpoint_dir: pathlib.Path,
)#

Main load method.

summarize_load_times()#

Summarize load times.

load_tensor_from_storage(
checkpoint_dir,
ten_meta: core.dist_checkpointing.strategies.two_stage._ShardedTensorMetadata,
)#

Load tensor from storage.

maybe_init_gloo_group()#

Create Gloo groups.

check_backend_compatibility(loaded_version)#
check_version_compatibility(loaded_version)#
_build_load_plan(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
) → List[core.dist_checkpointing.strategies.two_stage._ShardedTensorMetadata]#
deduplicate_chunks(
ten_metas: List[core.dist_checkpointing.strategies.two_stage._ShardedTensorMetadata],
)#

Group tensors by chunk and then pick the tensor with the lowest rank.

NOTE: with proper loading overlap, loading from randomized ranks (instead of the smallest one) could be beneficial here.
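The grouping-then-selection logic can be sketched as follows; metadata entries are modeled as plain dicts with `chunk_id` and `global_rank` keys (an assumed shape, not the module's `_ShardedTensorMetadata`):

```python
from itertools import groupby

def deduplicate_chunks(ten_metas):
    """Group metadata entries by chunk id and keep, per chunk, the entry
    with the lowest global rank; that rank becomes the chunk's loader."""
    key_fn = lambda m: m['chunk_id']
    deduped = []
    # groupby requires the input sorted by the same key
    for _, group in groupby(sorted(ten_metas, key=key_fn), key=key_fn):
        deduped.append(min(group, key=lambda m: m['global_rank']))
    return deduped
```

Picking the smallest rank is deterministic across ranks, which is what lets every rank agree on the loader without extra communication.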

_exchange_loaded_tensors(
ten_metas: List[core.dist_checkpointing.strategies.two_stage._ShardedTensorMetadata],
sharded_state_dict,
checkpoint_dir,
)#
_distribute_data_to_state_dict(
ten_meta: core.dist_checkpointing.strategies.two_stage._ShardedTensorMetadata,
loaded_ten: torch.Tensor,
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
)#
load_tensors_metadata(checkpoint_dir: pathlib.Path)#