core.dist_checkpointing.strategies.zarr

Strategies using Zarr as an underlying format.

Module Contents

Classes

ZarrSaveShardedStrategy

Save strategy for the Zarr backend.

ZarrLoadShardedStrategy

Load strategy for the Zarr backend.

Functions

register_default_zarr_strategies

Register default strategies related to the Zarr backend.

_create_or_open_zarr_arrays

Returns a list of Zarr arrays corresponding to the given tensors.

_should_create_array

_save_to_existing_array

_create_zarr_array

_load_from_array

_open_zarr_array_verbose

postprocess_numpy_array

Turn a NumPy array into a torch tensor.

flatten_range

Apply the flattened range to a tensor.

pad_to_expected_shape

Pad a tensor to the expected shape.

load_zarr_based_sharded_metadata

Load metadata of Zarr arrays.

Data

API

core.dist_checkpointing.strategies.zarr.logger

‘getLogger(…)’

core.dist_checkpointing.strategies.zarr.numpy_to_torch_dtype_dict

None

core.dist_checkpointing.strategies.zarr.torch_to_numpy_dtype_dict

None

core.dist_checkpointing.strategies.zarr.register_default_zarr_strategies()

Register default strategies related to the Zarr backend.
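
Registering strategies presumably makes the Zarr save and load classes discoverable by backend name and version. The following is a minimal sketch of such a registry, assuming it is keyed by (action, backend, version) and that the default Zarr version is 1. `StrategyAction`, `register_strategy`, `get_strategy`, the stub classes and the version number are hypothetical illustrations, not the library's API.

```python
from enum import Enum
from typing import Dict, Tuple


class StrategyAction(Enum):
    # Hypothetical: distinguishes save strategies from load strategies.
    SAVE_SHARDED = "save_sharded"
    LOAD_SHARDED = "load_sharded"


# Hypothetical global registry: (action, backend, version) -> strategy instance.
_REGISTRY: Dict[Tuple[StrategyAction, str, int], object] = {}


def register_strategy(action: StrategyAction, backend: str, version: int, strategy: object) -> None:
    """Store a strategy under its (action, backend, version) key."""
    _REGISTRY[(action, backend, version)] = strategy


def get_strategy(action: StrategyAction, backend: str, version: int) -> object:
    """Look up a registered strategy, failing with a descriptive error if absent."""
    try:
        return _REGISTRY[(action, backend, version)]
    except KeyError:
        raise KeyError(f"No {action.value} strategy for backend={backend!r}, version={version}") from None


class ZarrSaveStub:
    # Stand-in for ZarrSaveShardedStrategy; only keeps its constructor arguments.
    def __init__(self, backend: str, version: int):
        self.backend, self.version = backend, version


class ZarrLoadStub(ZarrSaveStub):
    # Stand-in for ZarrLoadShardedStrategy.
    pass


def register_default_zarr_strategies_sketch() -> None:
    # Mirrors the documented constructors (backend: str, version: int).
    # Assumption: version 1 is the default Zarr format version.
    register_strategy(StrategyAction.SAVE_SHARDED, "zarr", 1, ZarrSaveStub("zarr", 1))
    register_strategy(StrategyAction.LOAD_SHARDED, "zarr", 1, ZarrLoadStub("zarr", 1))
```

Keying on the version as well as the backend lets a loader reject, or route differently, checkpoints written in an older format.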

class core.dist_checkpointing.strategies.zarr.ZarrSaveShardedStrategy(backend: str, version: int)

Bases: core.dist_checkpointing.strategies.base.SaveShardedStrategy

Save strategy for the Zarr backend.

Initialization

save(
    sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
    checkpoint_dir: Union[str, pathlib.Path],
)

core.dist_checkpointing.strategies.zarr._create_or_open_zarr_arrays(
    sharded_tensors: List[core.dist_checkpointing.mapping.ShardedTensor],
    checkpoint_dir: pathlib.Path,
) -> List[Optional[zarr.Array]]

Returns a list of Zarr arrays corresponding to the given tensors.

For each sharded tensor:
  a) if it is the main replica and represents the first chunk (all offsets are 0), creates the Zarr array;
  b) if it is the main replica but not the first chunk, opens the array created in (a), possibly by another process;
  c) otherwise, sets the corresponding array to None, since it won't be used.

Parameters:
  • sharded_tensors (List[ShardedTensor]) – sharded tensors from a given rank that will be saved to the checkpoint

  • checkpoint_dir (Path) – checkpoint directory in which the arrays will be created
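
The three cases reduce to a small per-tensor decision. The following is a minimal sketch, assuming a tensor is the main replica when its replica id is 0 (or an all-zero tuple) and that "first chunk" means every global offset is 0. `ShardInfo`, `is_main_replica` and `array_action` are hypothetical names, not part of this module.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass
class ShardInfo:
    # Hypothetical stand-in for the ShardedTensor fields used by the decision.
    key: str
    global_offset: Tuple[int, ...]
    replica_id: Union[int, Tuple[int, ...]]


def is_main_replica(replica_id: Union[int, Tuple[int, ...]]) -> bool:
    # Assumption: replica id 0 (or an all-zero tuple) marks the replica that writes.
    if isinstance(replica_id, int):
        return replica_id == 0
    return all(r == 0 for r in replica_id)


def array_action(shard: ShardInfo) -> str:
    """Return 'create', 'open' or 'skip', following cases (a), (b) and (c)."""
    if not is_main_replica(shard.replica_id):
        return "skip"  # (c): no array is needed, the slot stays None
    if all(off == 0 for off in shard.global_offset):
        return "create"  # (a): the first chunk creates the Zarr array
    return "open"  # (b): other chunks open the array created in (a)
```

Because case (b) opens an array that another process may be creating in case (a), a real implementation presumably needs some synchronization between the create and open phases; this sketch does not model it.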

core.dist_checkpointing.strategies.zarr._should_create_array(
    ten: core.dist_checkpointing.mapping.ShardedTensor,
)

core.dist_checkpointing.strategies.zarr._save_to_existing_array(
    sharded_tensor: core.dist_checkpointing.mapping.ShardedTensor,
    arr: Optional[zarr.Array],
)
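
No description is given for `_save_to_existing_array`. Judging by its parameters, it presumably writes a shard's local data into the matching region of an already opened array. The following is a minimal sketch of that region computation, assuming each shard is described by per-axis element offsets and a local shape. `shard_region` is a hypothetical helper. A tuple of slices like the one it returns is the index form that Zarr and NumPy accept for region assignment (`arr[region] = data`).

```python
from typing import Sequence, Tuple


def shard_region(global_offset: Sequence[int], local_shape: Sequence[int]) -> Tuple[slice, ...]:
    """Return one slice per axis covering [offset, offset + size)."""
    # Hypothetical helper; assumes offsets are element offsets, not chunk indices.
    if len(global_offset) != len(local_shape):
        raise ValueError("offset and shape must have the same number of axes")
    return tuple(slice(off, off + size) for off, size in zip(global_offset, local_shape))
```

For example, a shard at offset `(4, 0)` with local shape `(4, 8)` maps to `(slice(4, 8), slice(0, 8))`, so the write would be `arr[4:8, 0:8] = data`.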

core.dist_checkpointing.strategies.zarr._create_zarr_array(
    sharded_tensor: core.dist_checkpointing.mapping.ShardedTensor,
    checkpoint_dir: pathlib.Path,
)

class core.dist_checkpointing.strategies.zarr.ZarrLoadShardedStrategy(backend: str, version: int)

Bases: core.dist_checkpointing.strategies.base.LoadShardedStrategy

Load strategy for the Zarr backend.

Initialization

load(
    sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
    checkpoint_dir: Union[str, pathlib.Path],
)

load_tensors_metadata(checkpoint_dir: Union[str, pathlib.Path])

check_backend_compatibility(loaded_version)

check_version_compatibility(loaded_version)

core.dist_checkpointing.strategies.zarr._load_from_array(
    sharded_tensor: core.dist_checkpointing.mapping.ShardedTensor,
    checkpoint_dir: pathlib.Path,
)

core.dist_checkpointing.strategies.zarr._open_zarr_array_verbose(path: pathlib.Path, mode: str, **open_kwargs)
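
Judging by its name, `_open_zarr_array_verbose` wraps array opening with more informative error reporting. The following is a minimal sketch of that pattern, assuming a missing array should raise an error naming the path and, when the checkpoint directory exists, log its contents. `open_verbose` and the injected `opener` callable are hypothetical stand-ins for the real function and `zarr.open`, and the exception types are assumptions.

```python
import logging
from pathlib import Path
from typing import Any, Callable

logger = logging.getLogger(__name__)


def open_verbose(opener: Callable[..., Any], path: Path, mode: str, **open_kwargs: Any) -> Any:
    """Call opener(str(path), mode, **open_kwargs), adding context on failure."""
    # Hypothetical: `opener` stands in for zarr.open; FileNotFoundError is an assumed error type.
    try:
        return opener(str(path), mode, **open_kwargs)
    except FileNotFoundError as e:
        ckpt_dir = path.parent
        msg = f"Array {path} not found"
        if ckpt_dir.exists():
            # Listing what *is* there helps diagnose a wrong key or a partial checkpoint.
            contents = sorted(p.name for p in ckpt_dir.iterdir())
            logger.debug("%s. Checkpoint directory %s contains: %s", msg, ckpt_dir, contents)
        else:
            msg += f". Checkpoint directory {ckpt_dir} does not exist."
        raise RuntimeError(msg) from e
```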

core.dist_checkpointing.strategies.zarr.postprocess_numpy_array(
    loaded_array,
    sharded_tensor,
    apply_flattened_range=True,
)

Turn a NumPy array into a torch tensor.
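
The module-level `numpy_to_torch_dtype_dict` and `torch_to_numpy_dtype_dict` presumably drive the dtype side of this conversion. The following is a minimal, name-based sketch of such a pair of lookup tables, assuming the reverse table is built by inverting the forward one. Dtype names stand in for `numpy.dtype` and `torch.dtype` objects so the sketch runs without either library; the entries are illustrative, not the module's actual contents, and `to_torch_dtype_name` is hypothetical.

```python
# Hypothetical, name-based illustration: strings stand in for
# numpy.dtype / torch.dtype objects so the sketch needs neither library.
NUMPY_TO_TORCH = {
    "bool": "torch.bool",
    "uint8": "torch.uint8",
    "int8": "torch.int8",
    "int16": "torch.int16",
    "int32": "torch.int32",
    "int64": "torch.int64",
    "float16": "torch.float16",
    "float32": "torch.float32",
    "float64": "torch.float64",
}

# Assumption: the reverse table is the exact inverse of the forward one,
# which is only well defined because the forward mapping is one-to-one.
TORCH_TO_NUMPY = {v: k for k, v in NUMPY_TO_TORCH.items()}


def to_torch_dtype_name(numpy_dtype_name: str) -> str:
    """Look up the torch dtype for a NumPy dtype, with a clear error if unsupported."""
    try:
        return NUMPY_TO_TORCH[numpy_dtype_name]
    except KeyError:
        raise TypeError(f"Unsupported NumPy dtype for conversion: {numpy_dtype_name}") from None
```

NumPy has no native bfloat16 dtype, so a real table would need special handling for `torch.bfloat16`.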

core.dist_checkpointing.strategies.zarr.flatten_range(sharded_tensor, x)

Apply the flattened range to a tensor.
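
A flattened range presumably selects a contiguous slice of the local shard viewed as a 1-D array. The following is a minimal sketch of applying such a range, assuming the range is a Python `slice` with step 1 over the row-major (C-order) flattening. Nested lists stand in for tensors, and `flatten_nested` and `apply_flattened_range` are hypothetical helpers. For a torch tensor the same operation would be `x.flatten()[rng]`.

```python
from typing import Any, List


def flatten_nested(x: Any) -> List[Any]:
    """Row-major (C-order) flattening of an arbitrarily nested list."""
    # Hypothetical helper: nested lists stand in for tensors.
    if not isinstance(x, list):
        return [x]
    out: List[Any] = []
    for item in x:
        out.extend(flatten_nested(item))
    return out


def apply_flattened_range(x: Any, rng: slice) -> List[Any]:
    """Flatten x, then keep only the elements selected by rng."""
    # Assumption: flattened ranges are contiguous, so only step 1 is accepted.
    if rng.step not in (None, 1):
        raise ValueError("flattened range is assumed to be contiguous (step 1)")
    return flatten_nested(x)[rng]
```

For example, `apply_flattened_range([[0, 1, 2], [3, 4, 5]], slice(2, 5))` selects `[2, 3, 4]`, a range that spans the boundary between the two rows.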

core.dist_checkpointing.strategies.zarr.pad_to_expected_shape(
    x: torch.Tensor,
    expected_sharded_ten: core.dist_checkpointing.mapping.ShardedTensor,
)

Pad a tensor to the expected shape.
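
The core of padding to an expected shape is working out how much to add along each axis. The following is a minimal sketch, assuming padding is appended after the existing data and that a dimension larger than expected is an error rather than something to crop. `trailing_pad_amounts` is a hypothetical helper. Its output is ordered last axis first as `(before, after)` pairs, which is the layout `torch.nn.functional.pad` expects for its `pad` argument.

```python
from typing import List, Sequence


def trailing_pad_amounts(actual: Sequence[int], expected: Sequence[int]) -> List[int]:
    """Return a flat pad list, last axis first, that grows `actual` to `expected`."""
    if len(actual) != len(expected):
        raise ValueError(f"rank mismatch: {tuple(actual)} vs {tuple(expected)}")
    pads: List[int] = []
    # Walk axes from last to first to match the torch.nn.functional.pad convention.
    for got, want in zip(reversed(actual), reversed(expected)):
        if got > want:
            raise ValueError(f"actual size {got} exceeds expected size {want}")
        # Assumption: nothing is added before the data, the whole shortfall after it.
        pads.extend((0, want - got))
    return pads
```

For a tensor of shape `(3, 5)` and an expected shape of `(4, 5)`, this yields `[0, 0, 0, 1]`, which could be passed as `torch.nn.functional.pad(x, pads)`. Which fill mode the real function uses is not stated here.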

core.dist_checkpointing.strategies.zarr.load_zarr_based_sharded_metadata(
    checkpoint_dir: pathlib.Path,
    get_shape_dtype_fn: Callable[[str], Tuple[Tuple[int], numpy.dtype]],
) -> core.dist_checkpointing.mapping.ShardedStateDict

Load metadata of Zarr arrays.

Parameters:
  • checkpoint_dir (Path) – checkpoint root directory

  • get_shape_dtype_fn (str -> ((int, ...), np.dtype)) – a function returning an array shape and dtype for a given Zarr array path
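
A `get_shape_dtype_fn` only needs to read each array's shape and dtype. For Zarr format version 2, this information is stored in the JSON file `.zarray` inside the array directory, under the `shape` and `dtype` keys. The following is a minimal sketch, assuming each Zarr array is a direct subdirectory of `checkpoint_dir`. `shape_dtype_from_zarray` and `scan_checkpoint` are hypothetical names. The dtype is returned as its string code (for example `'<f4'`) rather than as a `numpy.dtype`, which `numpy.dtype(code)` would produce.

```python
import json
from pathlib import Path
from typing import Dict, Tuple


def shape_dtype_from_zarray(array_dir: str) -> Tuple[Tuple[int, ...], str]:
    """Read shape and dtype code from a Zarr v2 `.zarray` metadata file."""
    meta = json.loads((Path(array_dir) / ".zarray").read_text())
    return tuple(meta["shape"]), meta["dtype"]


def scan_checkpoint(checkpoint_dir: Path) -> Dict[str, Tuple[Tuple[int, ...], str]]:
    """Map each array subdirectory name to its (shape, dtype) metadata."""
    result = {}
    for sub in sorted(checkpoint_dir.iterdir()):
        # Assumption: arrays are direct subdirectories of checkpoint_dir,
        # and only directories holding `.zarray` metadata count as arrays.
        if sub.is_dir() and (sub / ".zarray").exists():
            result[sub.name] = shape_dtype_from_zarray(str(sub))
    return result
```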