core.dist_checkpointing.strategies.checkpointable#

Module Contents#

Classes#

CheckpointableShardedTensor

ShardedTensor extension compatible with the PyTorch DCP checkpointing library.

LocalShardsContainer

DCP-compatible container for local shards.

API#

class core.dist_checkpointing.strategies.checkpointable.CheckpointableShardedTensor(
data: torch.Tensor,
sh_ten: core.dist_checkpointing.mapping.ShardedTensor,
)#

Bases: torch.Tensor

ShardedTensor extension compatible with the PyTorch DCP checkpointing library.

Implements the torch.distributed._checkpointable._Checkpointable protocol.

Initialization

__new__(
data: torch.Tensor,
sh_ten: core.dist_checkpointing.mapping.ShardedTensor,
)#
__create_write_items__(
fqn: str,
sh_ten: core.dist_checkpointing.strategies.checkpointable.CheckpointableShardedTensor,
index: int = None,
) list[torch.distributed.checkpoint.planner.WriteItem]#

Simple translation from ShardedTensor offsets into DCP offsets.

Parameters:
  • fqn (str) – tensor FQN.

  • sh_ten (CheckpointableShardedTensor) – same as self.

  • index (int, optional) – specifies the index within the LocalShardsContainer; an optimization hint used by DCP.

Returns:

list of DCP WriteItem metadata objects.

Return type:

List[WriteItem]

__create_chunk_list__() list[torch.distributed.checkpoint.metadata.ChunkStorageMetadata]#

Simple translation from ShardedTensor offsets into DCP offsets.

Returns:

list of DCP ChunkStorageMetadata metadata objects.

Return type:

List[ChunkStorageMetadata]
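
A minimal, hedged sketch of the translation these two hooks perform (not taken from the library; the megatron.core.* import paths and the ShardedTensor.from_rank_offsets constructor are assumptions, and in normal operation the DCP save planner, not user code, invokes these dunder hooks):

```python
import torch
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.strategies.checkpointable import (
    CheckpointableShardedTensor,
)

# Hypothetical layout: this rank holds shard 0 of 2 along dim 0, so the
# global tensor is (8, 8) and this shard's global offset is (0, 0).
local_shard = torch.zeros(4, 8)
sh_ten = ShardedTensor.from_rank_offsets('decoder.weight', local_shard, (0, 0, 2))
ckpt_ten = CheckpointableShardedTensor.from_sh_ten(sh_ten)

# ShardedTensor offsets translated into DCP WriteItem metadata.
write_items = ckpt_ten.__create_write_items__('decoder.weight', ckpt_ten)
print([wi.index for wi in write_items])        # one MetadataIndex per local shard

# The same offsets expressed as DCP ChunkStorageMetadata.
chunks = ckpt_ten.__create_chunk_list__()
print([(c.offsets, c.sizes) for c in chunks])  # global offset and local size
```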

__get_tensor_shard__(
index: torch.distributed.checkpoint.metadata.MetadataIndex,
) torch.Tensor#

Trivial implementation which simply returns the underlying tensor.

Parameters:

index (MetadataIndex) – unused

Returns:

the underlying data tensor

Return type:

Tensor

classmethod from_sh_ten(
sh_ten: core.dist_checkpointing.mapping.ShardedTensor,
) core.dist_checkpointing.strategies.checkpointable.CheckpointableShardedTensor#

Constructor which turns a ShardedTensor into a CheckpointableShardedTensor.

Parameters:

sh_ten (ShardedTensor) – a sharded tensor to wrap

Returns:

wrapped ShardedTensor

Return type:

CheckpointableShardedTensor
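
For illustration, a short hedged sketch of the wrapping step (same import-path and constructor assumptions as the sketch above): the result is a torch.Tensor subclass that carries the original ShardedTensor's sharding metadata and can be handed to PyTorch DCP.

```python
import torch
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.strategies.checkpointable import (
    CheckpointableShardedTensor,
)

# Hypothetical shard: rank offset 1 of 4 fragments along dim 0.
sh_ten = ShardedTensor.from_rank_offsets('embedding.weight', torch.zeros(4, 8), (0, 1, 4))
ckpt_ten = CheckpointableShardedTensor.from_sh_ten(sh_ten)

assert isinstance(ckpt_ten, torch.Tensor)
print(ckpt_ten.shape)  # same local shard shape, now exposing the DCP planner hooks
```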

abstractmethod classmethod __torch_dispatch__(func, types, args, kwargs=None)#

Placeholder implementation.

__repr__()#
class core.dist_checkpointing.strategies.checkpointable.LocalShardsContainer(local_shards: list[torch.Tensor])#

Bases: torch.Tensor

DCP-compatible container for local shards.

PyTorch DCP requires a single tensor per rank for a given global tensor FQN. This class acts as a container that allows multiple checkpointable shards per rank under the same FQN.

Implements the torch.distributed._checkpointable._Checkpointable protocol.
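
For example, a hedged sketch (same import-path assumptions as above) of a rank that owns two shards of the same global tensor and exposes both to DCP under a single FQN:

```python
import torch
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.strategies.checkpointable import (
    CheckpointableShardedTensor,
    LocalShardsContainer,
)

# Hypothetical layout: a (16, 8) global tensor split into 4 row shards;
# this rank holds shards 0 and 2.
shard_a = ShardedTensor.from_rank_offsets('mlp.weight', torch.zeros(4, 8), (0, 0, 4))
shard_b = ShardedTensor.from_rank_offsets('mlp.weight', torch.ones(4, 8), (0, 2, 4))

container = LocalShardsContainer(
    [CheckpointableShardedTensor.from_sh_ten(s) for s in (shard_a, shard_b)]
)

# DCP sees one object for 'mlp.weight' on this rank; the container delegates
# write-item and chunk-list creation to each wrapped shard.
print(len(container.__create_chunk_list__()))  # one ChunkStorageMetadata per shard
```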

Initialization

__new__(
local_shards: list[torch.Tensor],
) core.dist_checkpointing.strategies.checkpointable.LocalShardsContainer#
abstractmethod classmethod __torch_dispatch__(func, types, args=(), kwargs=None)#

Placeholder implementation.

__create_write_items__(
fqn: str,
local_shards_cont: core.dist_checkpointing.strategies.checkpointable.LocalShardsContainer,
) list[object]#

Delegates creating write items to local shards.

Parameters:
  • fqn (str) – tensor FQN.

  • local_shards_cont (LocalShardsContainer) – same as self.

Returns:

list of DCP WriteItem metadata objects.

Return type:

List[WriteItem]

__create_chunk_list__() list[torch.distributed.checkpoint.metadata.ChunkStorageMetadata]#

Delegates creating chunk items to local shards.

Returns:

list of DCP ChunkStorageMetadata metadata objects.

Return type:

List[ChunkStorageMetadata]

__get_tensor_shard__(
index: torch.distributed.checkpoint.metadata.MetadataIndex,
) torch.Tensor#

Performs a shard-matching lookup based on the index hint or offset.

Parameters:

index (MetadataIndex) – metadata specifying the offset of the queried shard. Optionally provides an index hint which speeds up the lookup.

Returns:

the matching shard data tensor

Return type:

Tensor
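
A hedged sketch of the lookup (same assumptions as above; MetadataIndex comes from torch.distributed.checkpoint.metadata), resolving an FQN plus a global offset, optionally with an integer position hint, to the local shard that owns that offset:

```python
import torch
from torch.distributed.checkpoint.metadata import MetadataIndex
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.strategies.checkpointable import (
    CheckpointableShardedTensor,
    LocalShardsContainer,
)

# Same hypothetical layout as above: this rank holds row shards 0 and 2
# of a (16, 8) global tensor.
shard_a = ShardedTensor.from_rank_offsets('mlp.weight', torch.zeros(4, 8), (0, 0, 4))
shard_b = ShardedTensor.from_rank_offsets('mlp.weight', torch.ones(4, 8), (0, 2, 4))
container = LocalShardsContainer(
    [CheckpointableShardedTensor.from_sh_ten(s) for s in (shard_a, shard_b)]
)

# The second shard starts at row 8 of the global tensor; index=1 is the
# optional hint giving its position inside the container.
shard = container.__get_tensor_shard__(
    MetadataIndex('mlp.weight', offset=(8, 0), index=1)
)
print(shard.shape)  # the (4, 8) local shard owning global offset (8, 0)
```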

__repr__()#