
Bases: LoadShardedStrategy

Loads one checkpoint replica from storage and broadcasts it to the other nodes.

This strategy loads the checkpoint from storage on a minimal set of nodes and distributes it to the other nodes with torch.distributed. Loading is performed with tensorstore.

Steps:
0. (optional) Create Gloo distributed groups.
1. Exchange ShardedTensors metadata between all nodes.
2. Align needed tensors within DP groups.
3. For each globally unique tensor:
   3.a) on one of the ranks, load it from storage to CPU and move it to CUDA;
   3.b) allocate a CUDA tensor on the other ranks;
   3.c) broadcast it within the DP group;
   3.d) copy the tensor content to the model param location;
   3.e) free the tensor buffers from 3.a) and 3.b).
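The pure-Python simulation below sketches steps 1 to 3 for a single DP group. It is illustrative only: ranks are plain objects, "storage" is a dict, and "broadcast" is a list copy; the names `SimRank`, `ChunkId`, and `two_stage_load` are hypothetical and not part of this strategy's API. It shows the key property of the design: each globally unique chunk is read from storage exactly once, by the lowest rank that needs it, and its temporary buffer is dropped before the next chunk is processed.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

# Hypothetical chunk id: (ShardedTensor key, global offset of the chunk).
ChunkId = Tuple[str, Tuple[int, ...]]


@dataclass
class SimRank:
    """A simulated rank; `needs` lists the chunks its model params require."""
    rank: int
    needs: List[ChunkId]
    params: Dict[ChunkId, List[float]] = field(default_factory=dict)


def two_stage_load(ranks: List[SimRank], storage: Dict[ChunkId, List[float]]) -> Dict[str, int]:
    """Simulate steps 1-3 within one DP group; returns read/broadcast counts."""
    stats = {"storage_reads": 0, "broadcasts": 0}
    # Steps 1-2: gather every rank's needs so all ranks agree on who needs what.
    needed_by: Dict[ChunkId, List[int]] = {}
    for r in ranks:
        for chunk in r.needs:
            needed_by.setdefault(chunk, []).append(r.rank)
    by_rank = {r.rank: r for r in ranks}
    # Step 3: handle globally unique chunks one at a time (sequentially).
    for chunk in sorted(needed_by):
        receivers = sorted(needed_by[chunk])
        # 3.a: only the lowest rank (receivers[0]) reads the chunk from storage.
        buffer = list(storage[chunk])
        stats["storage_reads"] += 1
        # 3.b-3.d: every rank that needs the chunk gets its own copy in its params.
        for dst in receivers:
            by_rank[dst].params[chunk] = list(buffer)
        if len(receivers) > 1:
            stats["broadcasts"] += 1
        # 3.e: drop the temporary buffer before moving on to the next chunk.
        del buffer
    return stats
```

With two ranks that both need `("w", (0,))`, the chunk is read once and broadcast once, instead of being read from storage by both ranks.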

Notes:
1. Loading and broadcasting are done sequentially to avoid both host and device OOMs.
2. The three steps performed for each tensor offer significant overlap potential:
   2.a) loading from storage to numpy;
   2.b) moving CPU tensors to CUDA;
   2.c) broadcasting.
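Note 1 says loading is currently sequential. One possible way to exploit the overlap from note 2 is a prefetch pipeline that reads tensor i+1 from storage while tensor i is being broadcast. The sketch below is a hypothetical illustration using a single background thread; `pipelined_load`, `load_fn`, and `broadcast_fn` are assumed names, not part of this strategy. Keeping only one prefetch in flight bounds the number of live buffers to two, which caps the extra memory the overlap costs.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def pipelined_load(
    keys: List[str],
    load_fn: Callable[[str], T],
    broadcast_fn: Callable[[str, T], R],
) -> List[R]:
    """Overlap loading key i+1 with broadcasting key i (prefetch depth 1)."""
    results: List[R] = []
    if not keys:
        return results
    with ThreadPoolExecutor(max_workers=1) as pool:
        pending = pool.submit(load_fn, keys[0])
        for i, key in enumerate(keys):
            data = pending.result()  # wait until the current tensor is loaded
            if i + 1 < len(keys):
                # Start reading the next tensor before broadcasting this one.
                pending = pool.submit(load_fn, keys[i + 1])
            results.append(broadcast_fn(key, data))
            del data  # release this buffer so at most two are alive at once
    return results
```

A deeper prefetch queue would hide more storage latency but, as note 1 warns, at the cost of more host and device memory held at once.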

check_backend_compatibility(loaded_version) # Verifies whether this strategy is compatible with the loaded backend.

check_version_compatibility(loaded_version) # Verifies whether this strategy is compatible with loaded_version.
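These two checks are described here only by their docstrings. The following is a minimal sketch of what such a check could look like, assuming (hypothetically) that a strategy declares a set of supported backends or versions and raises on a mismatch; `CheckpointCompatibilityError`, `check_compatibility`, and the value `"example_backend"` are illustrative names, not Megatron APIs.

```python
from typing import Collection, Hashable


class CheckpointCompatibilityError(ValueError):
    """Hypothetical error for a checkpoint this strategy cannot load."""


def check_compatibility(loaded: Hashable, supported: Collection[Hashable], what: str) -> None:
    """Raise if `loaded` (e.g. a backend name or a version) is not in `supported`."""
    if loaded not in supported:
        raise CheckpointCompatibilityError(
            f"loaded {what} {loaded!r} is not supported (supported: {sorted(supported, key=repr)})"
        )


# Hypothetical usage, mirroring the two methods above:
check_compatibility("example_backend", {"example_backend"}, "backend")  # passes silently
check_compatibility(1, {1}, "version")  # passes silently
```

Running such checks before any tensor is read avoids partially loading a checkpoint whose format the strategy cannot interpret.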

deduplicate_chunks(ten_metas: List[_ShardedTensorMetadata]) # Group tensors by chunk and then pick the tensor with the lowest rank. NOTE: with proper loading overlap, loading from randomized ranks (instead of the smallest one) could be beneficial here.
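A minimal sketch of the grouping described above, using a simplified, hypothetical `TensorChunkMeta` in place of `_ShardedTensorMetadata` and assuming a chunk is identified by its (key, global offset) pair. The optional `rng` argument illustrates the randomized-rank variant mentioned in the NOTE; it is an assumption for illustration, not a parameter of the real method.

```python
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

ChunkKey = Tuple[str, Tuple[int, ...]]


@dataclass(frozen=True)
class TensorChunkMeta:
    """Hypothetical, simplified stand-in for _ShardedTensorMetadata."""
    key: str
    global_offset: Tuple[int, ...]
    global_rank: int


def deduplicate_chunks(
    ten_metas: List[TensorChunkMeta], rng: Optional[random.Random] = None
) -> Dict[ChunkKey, TensorChunkMeta]:
    """Map each unique chunk to the single entry whose rank will load it."""
    groups: Dict[ChunkKey, List[TensorChunkMeta]] = {}
    for meta in ten_metas:
        groups.setdefault((meta.key, meta.global_offset), []).append(meta)
    if rng is None:
        # Default behavior: the lowest rank holding the chunk loads it.
        return {chunk: min(metas, key=lambda m: m.global_rank) for chunk, metas in groups.items()}
    # NOTE variant: a random holder spreads storage reads across ranks.
    return {chunk: rng.choice(metas) for chunk, metas in groups.items()}
```

Picking the lowest rank is deterministic, so every rank can compute the same assignment independently; a randomized choice needs a shared seed to keep that property.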

load(sharded_state_dict: Dict[str, Any], checkpoint_dir: Path) # Main load method.

load_tensor_from_storage(checkpoint_dir, ten_meta: _ShardedTensorMetadata) # Load a single tensor from storage.

load_tensors_metadata(checkpoint_dir: Path) # Load tensors metadata from the checkpoint for ShardedTensors. Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (unlike actual sharded state dicts, where keys correspond to state dict keys). Dict values are ShardedTensors without any data and sharding, so the only useful information they carry is each tensor's global shape and dtype.
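The sketch below shows how shape and dtype can be recovered without reading any tensor data. It assumes, hypothetically, a Zarr v2 layout with one subdirectory per ShardedTensor key, each holding a `.zarray` JSON file; the function name `load_tensors_metadata_sketch` is illustrative, and it returns plain `(shape, dtype)` tuples rather than data-less ShardedTensors.

```python
import json
from pathlib import Path
from typing import Dict, Tuple


def load_tensors_metadata_sketch(checkpoint_dir: Path) -> Dict[str, Tuple[Tuple[int, ...], str]]:
    """Return {ShardedTensor key: (global shape, dtype string)} without reading data.

    Assumes a hypothetical Zarr v2 layout: one subdirectory per ShardedTensor
    key, each containing a `.zarray` JSON metadata file.
    """
    metadata: Dict[str, Tuple[Tuple[int, ...], str]] = {}
    for sub in sorted(Path(checkpoint_dir).iterdir()):
        zarray = sub / ".zarray"
        if not (sub.is_dir() and zarray.is_file()):
            continue  # skip files and directories that are not Zarr arrays
        spec = json.loads(zarray.read_text())
        # Only the small JSON header is parsed; chunk data files are never opened.
        metadata[sub.name] = (tuple(spec["shape"]), spec["dtype"])
    return metadata
```

Because only the JSON headers are parsed, this step stays cheap even for very large checkpoints.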

maybe_init_gloo_group() # Create Gloo groups.