Megatron Core User Guide

dist_checkpointing.strategies package

Package defining different checkpoint formats (backends) and saving/loading algorithms (strategies).

Strategies can be used to implement new checkpoint formats, or to implement new ways of saving and loading existing formats that are better suited to a given use case. Strategies are passed to the dist_checkpointing.load and dist_checkpointing.save functions and control the actual saving and loading procedure.

Strategies base interfaces.

class core.dist_checkpointing.strategies.base.AsyncSaveShardedStrategy(backend: str, version: int)

Bases: core.dist_checkpointing.strategies.base.SaveShardedStrategy

Save strategy suitable for async save.

abstract async_save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path) → core.dist_checkpointing.strategies.async_utils.AsyncRequest

Perform preparation and return an AsyncRequest to the external caller.

Parameters
  • sharded_state_dict (ShardedStateDict) – sharded state dict to save

  • checkpoint_dir (Path) – checkpoint target directory

Returns

An object representing the async save function and the finalization function.

It is the caller's responsibility to actually schedule the async save.

Return type

AsyncRequest

save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

Each async strategy can be trivially used as a sync strategy.
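The relationship between async_save and save can be sketched as follows. This is a minimal, self-contained illustration under stated assumptions: ToyAsyncRequest and ToyAsyncStrategy are hypothetical stand-ins (not the library's AsyncRequest or strategy classes), their field and method names are invented for the example, and the "checkpoint" is just an in-memory dict.

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple


@dataclass
class ToyAsyncRequest:
    """Hypothetical stand-in: a save function, its arguments, and finalizers."""

    async_fn: Callable[..., None]
    async_fn_args: Tuple[Any, ...]
    finalize_fns: List[Callable[[], None]] = field(default_factory=list)

    def execute_sync(self) -> None:
        # Run the save function in the calling thread, then finalize.
        self.async_fn(*self.async_fn_args)
        for finalize in self.finalize_fns:
            finalize()


class ToyAsyncStrategy:
    """Hypothetical strategy that 'saves' into an in-memory dict."""

    def __init__(self) -> None:
        self.storage: Dict[str, Any] = {}
        self.finalized = False

    def async_save(
        self, sharded_state_dict: Dict[str, Any], checkpoint_dir: Path
    ) -> ToyAsyncRequest:
        # Preparation only: nothing is written until the request is executed.
        def write(state: Dict[str, Any], target: Path) -> None:
            for key, value in state.items():
                self.storage[f"{target}/{key}"] = value

        def finalize() -> None:
            self.finalized = True

        return ToyAsyncRequest(write, (sharded_state_dict, checkpoint_dir), [finalize])

    def save(self, sharded_state_dict: Dict[str, Any], checkpoint_dir: Path) -> None:
        # Sync save: build the async request and run it immediately.
        self.async_save(sharded_state_dict, checkpoint_dir).execute_sync()


strategy = ToyAsyncStrategy()
strategy.save({"weight": [1, 2, 3]}, Path("ckpt"))
```

In this sketch a caller that wants asynchronous behaviour would call async_save and schedule the returned request itself, while save simply executes the same request in place.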

class core.dist_checkpointing.strategies.base.LoadCommonStrategy

Bases: core.dist_checkpointing.strategies.base.LoadStrategyBase

Load strategy for common (non-sharded) objects

abstract load_common(checkpoint_dir: pathlib.Path)

Load common part of the checkpoint.

load_sharded_metadata(checkpoint_dir: pathlib.Path) → Dict[str, Any]

Load just the metadata from the checkpoint.

abstract load_sharded_objects(sharded_objects_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

Load sharded objects from the checkpoint.

class core.dist_checkpointing.strategies.base.LoadShardedStrategy

Bases: core.dist_checkpointing.strategies.base.LoadStrategyBase

Load strategy for sharded tensors

abstract load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

Load the sharded part of the checkpoint.

load_sharded_metadata(checkpoint_dir: pathlib.Path)

Load sharded metadata from the checkpoint for ShardedTensors and ShardedObjects.

Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply sharded keys (contrary to the actual sharded state dicts where keys correspond to state dict keys).

Dict values are ShardedTensors or ShardedObjects without any data or sharding information.

abstract load_tensors_metadata(checkpoint_dir: pathlib.Path)

Load tensors metadata from the checkpoint for ShardedTensors.

Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (contrary to the actual sharded state dicts where keys correspond to state dict keys).

Dict values are ShardedTensors without any data or sharding information (so the only useful information is each tensor's global shape and dtype).
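The difference between state dict keys and sharded keys can be illustrated with a small sketch. Everything here is hypothetical: ToyTensorMeta is a stand-in for a data-free ShardedTensor, and the key names and shapes are invented for the example.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class ToyTensorMeta:
    """Hypothetical metadata record: no data and no sharding information."""

    key: str
    global_shape: Tuple[int, ...]
    dtype: str


# State dict keys (left) may differ from the sharded keys used in the checkpoint (right).
state_dict_key_to_sharded_key = {
    "model.layers.0.weight": "decoder.layers.0.weight",
    "model.layers.0.bias": "decoder.layers.0.bias",
}

# What a metadata-only load might return: a dict keyed by sharded key.
tensors_metadata: Dict[str, ToyTensorMeta] = {
    "decoder.layers.0.weight": ToyTensorMeta("decoder.layers.0.weight", (1024, 1024), "float32"),
    "decoder.layers.0.bias": ToyTensorMeta("decoder.layers.0.bias", (1024,), "float32"),
}


def metadata_for(state_dict_key: str) -> ToyTensorMeta:
    # Translate the state dict key to its sharded key before the lookup.
    return tensors_metadata[state_dict_key_to_sharded_key[state_dict_key]]
```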

remove_sharded_tensors(checkpoint_dir: str, key_prefix: str)

Remove all tensors whose key starts with key_prefix

class core.dist_checkpointing.strategies.base.LoadStrategyBase

Bases: abc.ABC

Base class for a load strategy. Requires implementing checks for compatibility with a given checkpoint version.

property can_handle_sharded_objects

Returns whether or not this strategy can handle loading ShardedObjects.

abstract check_backend_compatibility(loaded_backend)

Verifies if this strategy is compatible with loaded_backend.

abstract check_version_compatibility(loaded_version)

Verifies if this strategy is compatible with loaded_version.
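A load strategy's compatibility checks might look like the sketch below. The class names, the supported backend and version values, and the choice to signal incompatibility by raising ValueError are all assumptions made for illustration; the source does not specify how the real checks report a mismatch.

```python
from abc import ABC, abstractmethod


class ToyLoadStrategyBase(ABC):
    """Hypothetical stand-in for a load strategy base class."""

    @abstractmethod
    def check_backend_compatibility(self, loaded_backend: str) -> None:
        """Raise if the checkpoint backend is not supported."""

    @abstractmethod
    def check_version_compatibility(self, loaded_version: int) -> None:
        """Raise if the checkpoint version is not supported."""

    @property
    def can_handle_sharded_objects(self) -> bool:
        # Assumed default for this sketch: ShardedObjects are not handled.
        return False


class ToyZarrLoadStrategy(ToyLoadStrategyBase):
    SUPPORTED_BACKEND = "zarr"  # assumed value
    SUPPORTED_VERSIONS = (1,)   # assumed value

    def check_backend_compatibility(self, loaded_backend: str) -> None:
        if loaded_backend != self.SUPPORTED_BACKEND:
            raise ValueError(f"Unsupported backend: {loaded_backend!r}")

    def check_version_compatibility(self, loaded_version: int) -> None:
        if loaded_version not in self.SUPPORTED_VERSIONS:
            raise ValueError(f"Unsupported version: {loaded_version!r}")


def is_compatible(strategy: ToyLoadStrategyBase, backend: str, version: int) -> bool:
    # Run both checks and turn a raised error into a boolean answer.
    try:
        strategy.check_backend_compatibility(backend)
        strategy.check_version_compatibility(version)
    except ValueError:
        return False
    return True
```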

class core.dist_checkpointing.strategies.base.SaveCommonStrategy(backend: str, version: int)

Bases: core.dist_checkpointing.strategies.base.SaveStrategyBase

Save strategy for common (non-sharded) objects

abstract save_common(common_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

Save common part of the state dict.

save_sharded_objects(sharded_objects_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

Save sharded objects from the state dict.

class core.dist_checkpointing.strategies.base.SaveShardedStrategy(backend: str, version: int)

Bases: core.dist_checkpointing.strategies.base.SaveStrategyBase

Save strategy for sharded tensors

abstract save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

Save the sharded part of the state dict.

class core.dist_checkpointing.strategies.base.SaveStrategyBase(backend: str, version: int)

Bases: abc.ABC

Base class for a save strategy. Requires defining a backend type and version of the saved format.

property can_handle_sharded_objects

Returns whether or not this strategy can handle saving ShardedObjects.

class core.dist_checkpointing.strategies.base.StrategyAction(value)

Bases: enum.Enum

Specifies save vs load and sharded vs common action.

LOAD_COMMON = 'load_common'
LOAD_SHARDED = 'load_sharded'
SAVE_COMMON = 'save_common'
SAVE_SHARDED = 'save_sharded'

core.dist_checkpointing.strategies.base.get_default_strategy(action: core.dist_checkpointing.strategies.base.StrategyAction, backend: str, version: int)

Retrieves a default strategy for a given action, backend and version.

core.dist_checkpointing.strategies.base.register_default_strategy(action: core.dist_checkpointing.strategies.base.StrategyAction, backend: str, version: int, strategy: Union[core.dist_checkpointing.strategies.base.SaveStrategyBase, core.dist_checkpointing.strategies.base.LoadStrategyBase])

Adds a given strategy to the registry of default strategies.

Parameters
  • action (StrategyAction) – specifies save/load and sharded/common

  • backend (str) – backend that the strategy becomes a default for

  • version (int) – version that the strategy becomes a default for

  • strategy (SaveStrategyBase, LoadStrategyBase) – strategy to register
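One way to picture the registry of default strategies is a dictionary keyed by (action, backend, version). The sketch below is an assumption about the general shape of such a registry, not the library's implementation; the toy_ prefixed names and the LookupError on a missing entry are hypothetical.

```python
from enum import Enum
from typing import Any, Dict, Tuple


class ToyStrategyAction(Enum):
    LOAD_COMMON = "load_common"
    LOAD_SHARDED = "load_sharded"
    SAVE_COMMON = "save_common"
    SAVE_SHARDED = "save_sharded"


# Registry mapping (action value, backend, version) to a strategy object.
_toy_registry: Dict[Tuple[str, str, int], Any] = {}


def toy_register_default_strategy(
    action: ToyStrategyAction, backend: str, version: int, strategy: Any
) -> None:
    # A later registration for the same triple replaces the earlier one.
    _toy_registry[(action.value, backend, version)] = strategy


def toy_get_default_strategy(action: ToyStrategyAction, backend: str, version: int) -> Any:
    try:
        return _toy_registry[(action.value, backend, version)]
    except KeyError as exc:
        raise LookupError(
            f"No default strategy for {action.value}, backend={backend!r}, version={version}"
        ) from exc


zarr_saver = object()  # placeholder for a real strategy instance
toy_register_default_strategy(ToyStrategyAction.SAVE_SHARDED, "zarr", 1, zarr_saver)
```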

Strategies using TensorStore to load and save Zarr arrays.

class core.dist_checkpointing.strategies.tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device: bool = False)

Bases: core.dist_checkpointing.strategies.base.LoadShardedStrategy

Load strategy for Zarr backend using tensorstore for loading.

check_backend_compatibility(loaded_version)

Verifies if this strategy is compatible with loaded_backend.

check_version_compatibility(loaded_version)

Verifies if this strategy is compatible with loaded_version.

load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

Load the sharded part of the checkpoint.

load_tensors_metadata(checkpoint_dir: pathlib.Path)

Load tensors metadata from the checkpoint for ShardedTensors.

Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (contrary to the actual sharded state dicts where keys correspond to state dict keys).

Dict values are ShardedTensors without any data or sharding information (so the only useful information is each tensor's global shape and dtype).

core.dist_checkpointing.strategies.tensorstore.merge_global_slice_with_shape(global_slice, actual_shape, key)

Intersects the global slice with the actual shape (prevent overflow).
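Intersecting a slice with an array shape amounts to clamping each slice's stop to the size of the corresponding dimension. The helper below is a hypothetical sketch of that idea (clamp_slices_to_shape is not a library function); it assumes non-negative bounds and no step.

```python
from typing import Sequence, Tuple


def clamp_slices_to_shape(
    global_slice: Sequence[slice], actual_shape: Sequence[int]
) -> Tuple[slice, ...]:
    """Clamp each per-dimension slice so it does not run past the array bounds."""
    if len(global_slice) != len(actual_shape):
        raise ValueError("slice and shape must have the same number of dimensions")
    clamped = []
    for sl, dim in zip(global_slice, actual_shape):
        start = 0 if sl.start is None else sl.start
        stop = dim if sl.stop is None else sl.stop
        if start > dim:
            raise ValueError(f"slice start {start} lies outside dimension of size {dim}")
        clamped.append(slice(start, min(stop, dim)))
    return tuple(clamped)


# A request for rows 0..8 and columns 4..12 against a 6 x 10 array.
result = clamp_slices_to_shape((slice(0, 8), slice(4, 12)), (6, 10))
```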

core.dist_checkpointing.strategies.tensorstore.open_ts_array(arr_path: pathlib.Path)

Opens a Zarr array file with TensorStore using basic settings.

Parameters

arr_path (Path) – path to a Zarr (Tensorstore) array

core.dist_checkpointing.strategies.tensorstore.register_default_tensorstore_strategies()

Register default strategies leveraging tensorstore.

2-stage checkpoint loading.

class core.dist_checkpointing.strategies.two_stage.TwoStageDataParallelLoadShardedStrategy(data_parallel_group, cpu_transfer=True)

Bases: core.dist_checkpointing.strategies.base.LoadShardedStrategy

Loads one checkpoint replica from storage and broadcasts to other nodes.

This strategy loads the checkpoint from storage on a minimal set of nodes and distributes it to the other nodes with torch.distributed. Loading is performed with tensorstore.

Steps:
  0. (optional) create Gloo distributed groups
  1. Exchange ShardedTensors metadata between all nodes
  2. Align needed tensors within DP groups
  3. For each globally unique tensor:
     3.a) on one of the ranks load it from storage to CPU and move to CUDA
     3.b) allocate CUDA tensor on other ranks
     3.c) broadcast within DP group
     3.d) copy tensor content to the model param location
     3.e) free tensor buffers from a) and b)

Notes:
  1. Loading and broadcasting is done sequentially to avoid both host and device OOMs
  2. There is a lot of overlap potential between all three steps done for each tensor:
     2.a) loading from storage to numpy
     2.b) moving CPU tensors to CUDA
     2.c) broadcast
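The per-tensor loop in step 3 can be pictured with a single-process simulation. This sketch deliberately does not use torch.distributed, CUDA, or tensorstore: ranks are plain dict entries, the "broadcast" is a list copy, and simulate_two_stage_load is a hypothetical name. It only shows that each unique tensor is read from storage once and then handed to every rank that needs it.

```python
from typing import Dict, List, Tuple


def simulate_two_stage_load(
    storage: Dict[str, List[float]], needed_by_rank: Dict[int, List[str]]
) -> Tuple[Dict[int, Dict[str, List[float]]], Dict[str, int], int]:
    """Read each unique tensor once and copy it to all ranks that requested it."""
    loaded: Dict[int, Dict[str, List[float]]] = {rank: {} for rank in needed_by_rank}
    loader_rank: Dict[str, int] = {}
    storage_reads = 0

    # Steps 1-2 (simplified): collect the globally unique tensor keys.
    unique_keys = sorted({key for keys in needed_by_rank.values() for key in keys})

    for key in unique_keys:
        consumers = sorted(rank for rank, keys in needed_by_rank.items() if key in keys)
        loader_rank[key] = consumers[0]  # assumed choice: lowest rank reads from storage
        buffer = list(storage[key])      # 3.a) single read from storage
        storage_reads += 1
        for rank in consumers:           # 3.b)-3.d) "broadcast" and copy into place
            loaded[rank][key] = list(buffer)
        del buffer                       # 3.e) release the temporary buffer

    return loaded, loader_rank, storage_reads


storage = {"w": [1.0, 2.0], "b": [3.0]}
needed = {0: ["w", "b"], 1: ["w", "b"], 2: ["w"]}
loaded, loader_rank, storage_reads = simulate_two_stage_load(storage, needed)
```

Because tensors are processed one at a time, only one temporary buffer exists at any moment, which mirrors the sequential behaviour described in note 1.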

check_backend_compatibility(loaded_version)

Verifies if this strategy is compatible with loaded_backend.

check_version_compatibility(loaded_version)

Verifies if this strategy is compatible with loaded_version.

deduplicate_chunks(ten_metas: List[core.dist_checkpointing.strategies.two_stage._ShardedTensorMetadata])

Group tensors by chunk and then pick the tensor with the lowest rank.

NOTE: with proper loading overlap, loading from randomized ranks (instead of the smallest one) could be beneficial here.
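Grouping by chunk and picking the lowest rank can be sketched in a few lines. ToyTensorMeta and its fields are hypothetical stand-ins for the private _ShardedTensorMetadata type, and identifying a chunk by (key, offset) is an assumption made for the example.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class ToyTensorMeta:
    """Hypothetical metadata: which rank needs which chunk of which tensor."""

    key: str
    chunk_offset: Tuple[int, ...]
    rank: int


def deduplicate_by_chunk(
    metas: List[ToyTensorMeta],
) -> Dict[Tuple[str, Tuple[int, ...]], ToyTensorMeta]:
    # Group metadata entries that refer to the same chunk.
    groups: Dict[Tuple[str, Tuple[int, ...]], List[ToyTensorMeta]] = defaultdict(list)
    for meta in metas:
        groups[(meta.key, meta.chunk_offset)].append(meta)
    # Keep, for every chunk, the entry with the lowest rank.
    return {
        chunk_id: min(group, key=lambda meta: meta.rank)
        for chunk_id, group in groups.items()
    }


metas = [
    ToyTensorMeta("w", (0, 0), rank=2),
    ToyTensorMeta("w", (0, 0), rank=0),
    ToyTensorMeta("w", (4, 0), rank=1),
]
unique = deduplicate_by_chunk(metas)
```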

load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

Load the sharded part of the checkpoint.

load_tensor_from_storage(checkpoint_dir, ten_meta: core.dist_checkpointing.strategies.two_stage._ShardedTensorMetadata)

load_tensors_metadata(checkpoint_dir: pathlib.Path)

Load tensors metadata from the checkpoint for ShardedTensors.

Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (contrary to the actual sharded state dicts where keys correspond to state dict keys).

Dict values are ShardedTensors without any data or sharding information (so the only useful information is each tensor's global shape and dtype).

maybe_init_gloo_group()

summarize_load_times()

core.dist_checkpointing.strategies.two_stage.sharded_tensor_chunk_id(sharded_tensor: core.dist_checkpointing.mapping.ShardedTensor)

core.dist_checkpointing.strategies.two_stage.timed(verbose=True)

Strategies using Zarr as an underlying format.

class core.dist_checkpointing.strategies.zarr.ZarrLoadShardedStrategy

Bases: core.dist_checkpointing.strategies.base.LoadShardedStrategy

Load strategy for the Zarr backend.

check_backend_compatibility(loaded_version)

Verifies if this strategy is compatible with loaded_backend.

check_version_compatibility(loaded_version)

Verifies if this strategy is compatible with loaded_version.

load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

Load the sharded part of the checkpoint.

load_tensors_metadata(checkpoint_dir: pathlib.Path)

Load tensors metadata from the checkpoint for ShardedTensors.

Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (contrary to the actual sharded state dicts where keys correspond to state dict keys).

Dict values are ShardedTensors without any data or sharding information (so the only useful information is each tensor's global shape and dtype).

class core.dist_checkpointing.strategies.zarr.ZarrSaveShardedStrategy(backend: str, version: int)

Bases: core.dist_checkpointing.strategies.base.SaveShardedStrategy

Save strategy for Zarr backend.

save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

Save the sharded part of the state dict.

core.dist_checkpointing.strategies.zarr.flatten_range(sharded_tensor, x)

Apply flattened range to a tensor.
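Applying a flattened range can be read as: flatten the tensor to one dimension, then take a slice of it. The sketch below shows this with nested Python lists instead of torch tensors; apply_flattened_range is a hypothetical helper and the row-major flattening order is an assumption.

```python
from typing import List, Sequence


def apply_flattened_range(rows: Sequence[Sequence[float]], flattened_range: slice) -> List[float]:
    # Flatten the 2-D input in row-major order, then slice the 1-D result.
    flat = [value for row in rows for value in row]
    return flat[flattened_range]


# A 2 x 3 "tensor"; the flattened range selects elements 2, 3 and 4.
selected = apply_flattened_range([[0, 1, 2], [3, 4, 5]], slice(2, 5))
```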

core.dist_checkpointing.strategies.zarr.load_zarr_based_sharded_metadata(checkpoint_dir: pathlib.Path, get_shape_dtype_fn: Callable[[str], Tuple[Tuple[int], numpy.dtype]]) → Dict[str, Any]

Load metadata of Zarr arrays.

Parameters
  • checkpoint_dir (Path) – checkpoint root directory

  • get_shape_dtype_fn (str -> ((int, ...), np.dtype)) – a function returning an array shape and dtype for a given Zarr array path

core.dist_checkpointing.strategies.zarr.pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: core.dist_checkpointing.mapping.ShardedTensor)

Pad tensor to the expected shape.
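Padding to an expected shape can be sketched for the 2-D case with plain lists. This is only an illustration: pad_rows_to_shape is a hypothetical helper, and padding with zeros at the end of each dimension is an assumption, since the source does not say where or with what value the padding is applied.

```python
from typing import List, Sequence, Tuple


def pad_rows_to_shape(
    rows: Sequence[Sequence[float]], expected_shape: Tuple[int, int], fill: float = 0.0
) -> List[List[float]]:
    """Pad a 2-D nested list with trailing `fill` values up to expected_shape."""
    n_rows, n_cols = expected_shape
    if len(rows) > n_rows or any(len(row) > n_cols for row in rows):
        raise ValueError("input is larger than the expected shape")
    # Extend every existing row to the expected number of columns.
    padded = [list(row) + [fill] * (n_cols - len(row)) for row in rows]
    # Append fully padded rows until the expected number of rows is reached.
    padded += [[fill] * n_cols for _ in range(n_rows - len(rows))]
    return padded


padded = pad_rows_to_shape([[1.0, 2.0]], (2, 3))
```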

core.dist_checkpointing.strategies.zarr.postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range=True)

Turn numpy array to torch tensor.

core.dist_checkpointing.strategies.zarr.register_default_zarr_strategies()

Register default strategies related to Zarr backend.

Various loading and saving strategies

© Copyright 2022-2025, NVIDIA. Last updated on Jan 14, 2025.