core.dist_checkpointing.strategies.tensorstore

Strategies using TensorStore to load and save Zarr arrays.

Module Contents

Classes

TensorStoreLoadShardedStrategy

Load strategy for the Zarr backend, using TensorStore to read arrays.

Functions

register_default_tensorstore_strategies

Register the default strategies that use TensorStore.

merge_global_slice_with_shape

Intersects the global slice with the actual array shape (to prevent overflow).

_load_from_array

_load_regular_chunk

open_ts_array

Opens a Zarr file array with TensorStore using basic settings.

Data

logger

API

core.dist_checkpointing.strategies.tensorstore.logger

‘getLogger(…)’

core.dist_checkpointing.strategies.tensorstore.register_default_tensorstore_strategies()

Register the default strategies that use TensorStore.
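A minimal sketch of what "registering a default strategy" amounts to, assuming a registry keyed by (action, backend, version). The helpers `register_default_strategy` and `get_default_strategy`, the `"load_sharded"` action string, and the `("zarr", 1)` backend/version pair are illustrative assumptions, not the library's API; a stand-in class replaces the real strategy so the snippet runs without Megatron Core.

```python
from typing import Dict, Tuple

# Hypothetical registry keyed by (action, backend, version); names are illustrative only.
_DEFAULT_STRATEGIES: Dict[Tuple[str, str, int], object] = {}


def register_default_strategy(action: str, backend: str, version: int, strategy: object) -> None:
    """Record `strategy` as the default for the given (action, backend, version)."""
    _DEFAULT_STRATEGIES[(action, backend, version)] = strategy


def get_default_strategy(action: str, backend: str, version: int) -> object:
    """Look up a previously registered default strategy, or fail loudly."""
    try:
        return _DEFAULT_STRATEGIES[(action, backend, version)]
    except KeyError:
        raise KeyError(f"No default {action} strategy for backend {backend!r} v{version}") from None


class TensorStoreLoadShardedStrategy:
    """Stand-in for the real class; only used here as a registry value."""

    def __init__(self, load_directly_on_device: bool = False):
        self.load_directly_on_device = load_directly_on_device


def register_default_tensorstore_strategies() -> None:
    # Assumption: TensorStore loading is registered for the 'zarr' backend, version 1.
    register_default_strategy("load_sharded", "zarr", 1, TensorStoreLoadShardedStrategy())
```

After calling `register_default_tensorstore_strategies()`, a loader can look the strategy up by backend and version instead of constructing it directly, which is the usual reason for keeping such a registry.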

class core.dist_checkpointing.strategies.tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device: bool = False)

Bases: core.dist_checkpointing.strategies.base.LoadShardedStrategy

Load strategy for the Zarr backend, using TensorStore to read arrays.

Initialization

load(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
checkpoint_dir: Union[str, pathlib.Path],
)

load_tensors_metadata(checkpoint_dir: Union[str, pathlib.Path])

check_backend_compatibility(loaded_version)

check_version_compatibility(loaded_version)
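The idea behind a sharded load can be sketched in a few lines: each rank reads only the region of the stored array that its tensor shard covers. This toy model assumes 2-D tensors, an in-memory dict of nested lists in place of TensorStore arrays, and a hypothetical `ToyShardedTensor` with `key`, `global_offset`, and `local_shape`; the real `ShardedTensor` carries more metadata, and `TensorStoreLoadShardedStrategy.load` does not necessarily follow this exact flow.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

Matrix = List[List[int]]


@dataclass
class ToyShardedTensor:
    """Hypothetical, simplified stand-in for ShardedTensor: one rank's piece of a 2-D tensor."""
    key: str
    global_offset: Tuple[int, int]
    local_shape: Tuple[int, int]

    def global_slice(self) -> Tuple[slice, ...]:
        # The region of the global array owned by this shard, one slice per dimension.
        return tuple(
            slice(off, off + size) for off, size in zip(self.global_offset, self.local_shape)
        )


def toy_load(sharded: List[ToyShardedTensor], store: Dict[str, Matrix]) -> List[Matrix]:
    """Read only each shard's region from the array stored under its key."""
    results = []
    for st in sharded:
        # Assumption: arrays are looked up by tensor key; a dict stands in for a TensorStore array.
        full = store[st.key]
        rows, cols = st.global_slice()
        results.append([row[cols] for row in full[rows]])
    return results
```

For example, a shard with `global_offset=(2, 0)` and `local_shape=(2, 2)` over a 4 by 4 array reads only rows 2 and 3 and columns 0 and 1, leaving the rest of the array untouched.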

core.dist_checkpointing.strategies.tensorstore.merge_global_slice_with_shape(global_slice, actual_shape, key)

Intersects the global slice with the actual array shape (to prevent overflow).
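A pure-Python sketch of the intersection, written from the one-line description above rather than from the library source: each `slice` has its `stop` clipped to the stored size of that dimension, integer indices pass through unchanged, and a slice that starts at or beyond the stored extent is rejected. The error type and messages are assumptions.

```python
from typing import Sequence, Tuple, Union

DimIndex = Union[int, slice]


def merge_global_slice_with_shape(
    global_slice: Sequence[DimIndex], actual_shape: Sequence[int], key: str
) -> Tuple[DimIndex, ...]:
    """Clip each slice's stop to the stored array's extent along that dimension."""
    if len(global_slice) != len(actual_shape):
        raise ValueError(f"Rank mismatch for {key}: {global_slice} vs {actual_shape}")

    merged = []
    for dim_slice, dim_size in zip(global_slice, actual_shape):
        if isinstance(dim_slice, slice):
            start = 0 if dim_slice.start is None else dim_slice.start
            if start >= dim_size:
                # Nothing of this shard exists in the stored array along this dimension.
                raise ValueError(f"Empty slice for {key}: {dim_slice} with size {dim_size}")
            stop = dim_size if dim_slice.stop is None else min(dim_slice.stop, dim_size)
            dim_slice = slice(start, stop, dim_slice.step)
        merged.append(dim_slice)  # integer indices pass through unchanged
    return tuple(merged)
```

For instance, `merge_global_slice_with_shape((slice(0, 8), slice(4, 12)), (8, 10), "w")` returns `(slice(0, 8, None), slice(4, 10, None))`: the second dimension is clipped from 12 to the stored size 10.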

core.dist_checkpointing.strategies.tensorstore._load_from_array(
sharded_tensor: core.dist_checkpointing.mapping.ShardedTensor,
checkpoint_dir: pathlib.Path,
load_directly_on_device: bool = False,
apply_flattened_range: bool = True,
)

core.dist_checkpointing.strategies.tensorstore._load_regular_chunk(
sharded_tensor: core.dist_checkpointing.mapping.ShardedTensor,
checkpoint_dir: pathlib.Path,
)

core.dist_checkpointing.strategies.tensorstore.open_ts_array(arr_path: pathlib.Path)

Opens a Zarr file array with TensorStore using basic settings.

Parameters:

arr_path (Path) – path to a Zarr (TensorStore) array
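One plausible reading of "basic settings" is a TensorStore spec that pairs the `zarr` driver with a local `file` key-value store. The helper `zarr_file_spec` below is hypothetical and the real `open_ts_array` may set additional options; the spec keys follow TensorStore's JSON spec format. The snippet only builds the spec, so it runs without TensorStore installed.

```python
from pathlib import Path
from typing import Any, Dict


def zarr_file_spec(arr_path: Path) -> Dict[str, Any]:
    """Build a TensorStore JSON spec for a Zarr array stored in a local directory (hypothetical helper)."""
    return {
        "driver": "zarr",  # Zarr array format
        "kvstore": {"driver": "file", "path": str(arr_path)},  # local filesystem storage
    }
```

With TensorStore available, such a spec could be opened with `tensorstore.open(zarr_file_spec(arr_path), open=True).result()`; indexing the returned array and calling `.read().result()` on the indexed view yields a NumPy array for just that region.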