core.dist_checkpointing.strategies.tensorstore#
Strategies using TensorStore to load and save Zarr arrays.
Module Contents#
Classes#
TensorStoreLoadShardedStrategy | Load strategy for Zarr backend using tensorstore for loading.
Functions#
register_default_tensorstore_strategies | Register default strategies leveraging tensorstore.
merge_global_slice_with_shape | Intersects the global slice with the actual shape (prevents overflow).
open_ts_array | Opens a Zarr array with TensorStore using basic settings.
Data#
logger
API#
- core.dist_checkpointing.strategies.tensorstore.logger#
‘getLogger(…)’
- core.dist_checkpointing.strategies.tensorstore.register_default_tensorstore_strategies()#
Register default strategies leveraging tensorstore.
- class core.dist_checkpointing.strategies.tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device: bool = False)#
Bases: core.dist_checkpointing.strategies.base.LoadShardedStrategy
Load strategy for Zarr backend using tensorstore for loading.
Initialization
- load(sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict, checkpoint_dir: Union[str, pathlib.Path])#
- load_tensors_metadata(checkpoint_dir: Union[str, pathlib.Path])#
- check_backend_compatibility(loaded_version)#
- check_version_compatibility(loaded_version)#
- core.dist_checkpointing.strategies.tensorstore.merge_global_slice_with_shape(global_slice, actual_shape, key)#
Intersects the global slice with the actual shape of the stored array, so that the slice does not extend past the array bounds (prevents overflow).
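The intersection described above can be illustrated with a minimal sketch. This is not the library's implementation: the helper name `clamp_slices_to_shape` is hypothetical, and the sketch assumes every entry of `global_slice` is a `slice` with integer `start` and `stop`.

```python
from typing import Sequence, Tuple


def clamp_slices_to_shape(
    global_slice: Sequence[slice], actual_shape: Sequence[int], key: str
) -> Tuple[slice, ...]:
    """Hypothetical helper: clip each slice's stop to the size of its dimension."""
    if len(global_slice) != len(actual_shape):
        raise ValueError(f"rank mismatch for {key}: {global_slice} vs {actual_shape}")
    clamped = []
    for dim_slice, dim_size in zip(global_slice, actual_shape):
        # A slice that starts at or past the end of the dimension selects nothing.
        if dim_slice.start >= dim_size:
            raise ValueError(
                f"empty slice for {key}: {dim_slice} in dimension of size {dim_size}"
            )
        # Keep start and step, but never let stop exceed the stored extent.
        stop = min(dim_slice.stop, dim_size)
        clamped.append(slice(dim_slice.start, stop, dim_slice.step))
    return tuple(clamped)
```

For example, clamping `(slice(0, 4), slice(2, 10))` against a stored shape of `(4, 8)` leaves the first slice unchanged and shortens the second to `slice(2, 8)`. The `key` argument is used here only to make error messages identify the offending tensor.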
- core.dist_checkpointing.strategies.tensorstore._load_from_array(sharded_tensor: core.dist_checkpointing.mapping.ShardedTensor, checkpoint_dir: pathlib.Path, load_directly_on_device: bool = False, apply_flattened_range: bool = True)#
- core.dist_checkpointing.strategies.tensorstore._load_regular_chunk(sharded_tensor: core.dist_checkpointing.mapping.ShardedTensor, checkpoint_dir: pathlib.Path)#
- core.dist_checkpointing.strategies.tensorstore.open_ts_array(arr_path: pathlib.Path)#
Opens a Zarr array with TensorStore using basic settings.
- Parameters:
arr_path (Path) – path to a Zarr (TensorStore) array
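As a rough illustration of opening a Zarr array through TensorStore, the sketch below builds a spec for the `zarr` driver over a local `file` key-value store and opens it. The helper names `build_zarr_spec` and `open_zarr_array` are hypothetical, and the exact spec that `open_ts_array` uses is not shown in this page, so treat this as one plausible "basic settings" configuration and not as the library's code.

```python
from pathlib import Path
from typing import Any, Dict


def build_zarr_spec(arr_path: Path) -> Dict[str, Any]:
    """Hypothetical helper: minimal TensorStore spec for a local Zarr array."""
    return {
        "driver": "zarr",
        "kvstore": {"driver": "file", "path": str(arr_path)},
    }


def open_zarr_array(arr_path: Path):
    """Hypothetical helper: open an existing Zarr array and return the TensorStore."""
    # Imported lazily so that build_zarr_spec is usable without tensorstore installed.
    import tensorstore as ts

    # ts.open returns a future; .result() blocks until the array is opened.
    return ts.open(build_zarr_spec(arr_path), open=True).result()
```

With an array opened this way, a region can be read with an indexing expression followed by `.read().result()`, which returns a NumPy array for that region.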