core.dist_checkpointing.strategies.checkpointable#
Module Contents#
Classes#
- CheckpointableShardedTensor – ShardedTensor extension compatible with the PyTorch DCP checkpointing library.
- LocalShardsContainer – DCP-compatible container for local shards.
API#
- class core.dist_checkpointing.strategies.checkpointable.CheckpointableShardedTensor(data: torch.Tensor, sh_ten: core.dist_checkpointing.mapping.ShardedTensor)#
Bases: torch.Tensor
ShardedTensor extension compatible with the PyTorch DCP checkpointing library.
Implements the torch.distributed._checkpointable._Checkpointable protocol.
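Example (a minimal sketch, not part of the API reference; it assumes the Megatron-Core import path `megatron.core.dist_checkpointing`, a placeholder FQN `'model.weight'`, and a single-process run with `rank = 0`):

```python
import torch
from megatron.core.dist_checkpointing.mapping import ShardedTensor
from megatron.core.dist_checkpointing.strategies.checkpointable import (
    CheckpointableShardedTensor,
)

rank = 0  # placeholder; normally torch.distributed.get_rank()
local_chunk = torch.zeros(4, 8)

# This rank owns the `rank`-th of 2 chunks along dim 0 of an (8, 8) global tensor.
sh_ten = ShardedTensor.from_rank_offsets('model.weight', local_chunk, (0, rank, 2))

# Wrap it so PyTorch DCP can drive saving through the _Checkpointable protocol.
ckpt_ten = CheckpointableShardedTensor.from_sh_ten(sh_ten)
assert isinstance(ckpt_ten, torch.Tensor)  # it is a torch.Tensor subclass
```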
Initialization
- __new__(data: torch.Tensor, sh_ten: core.dist_checkpointing.mapping.ShardedTensor)#
- __create_write_items__(fqn: str, sh_ten: core.dist_checkpointing.strategies.checkpointable.CheckpointableShardedTensor, index: int = None)#
Simple translation from ShardedTensor offsets into DCP offsets.
- Parameters:
fqn (str) – tensor FQN.
sh_ten (CheckpointableShardedTensor) – same as self.
index (int) – specifies the index within the LocalShardsContainer. This is an optimization hint used by DCP.
- Returns:
list of DCP WriteItem metadata objects.
- Return type:
List[WriteItem]
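Continuing the sketch above (a hedged illustration, not normative; `ckpt_ten` and `'model.weight'` are placeholders), each produced WriteItem carries a MetadataIndex encoding the FQN and this rank's global offset:

```python
# DCP's save planner calls this hook when it encounters a _Checkpointable object.
write_items = ckpt_ten.__create_write_items__('model.weight', ckpt_ten)
for item in write_items:
    # item.index is a torch.distributed.checkpoint.metadata.MetadataIndex:
    # the tensor FQN plus the global offset of this rank's chunk.
    print(item.index.fqn, item.index.offset)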
- __create_chunk_list__() list[torch.distributed.checkpoint.metadata.ChunkStorageMetadata]#
Simple translation from ShardedTensor offsets into DCP offsets.
- Returns:
list of DCP ChunkStorageMetadata metadata objects.
- Return type:
List[ChunkStorageMetadata]
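Continuing the same sketch, the chunk list mirrors the write items and describes where each local chunk sits inside the global tensor:

```python
for chunk in ckpt_ten.__create_chunk_list__():
    # ChunkStorageMetadata holds the chunk's global offset and size.
    print(chunk.offsets, chunk.sizes)
```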
- __get_tensor_shard__(index: torch.distributed.checkpoint.metadata.MetadataIndex)#
Trivial implementation which simply yields the underlying tensor.
- Parameters:
index (MetadataIndex) – unused
- Returns:
the underlying data tensor
- Return type:
Tensor
- classmethod from_sh_ten(sh_ten: core.dist_checkpointing.mapping.ShardedTensor) core.dist_checkpointing.strategies.checkpointable.CheckpointableShardedTensor#
Constructor which turns a ShardedTensor into a CheckpointableShardedTensor.
- Parameters:
sh_ten (ShardedTensor) – a sharded tensor to wrap
- Returns:
wrapped ShardedTensor
- Return type:
CheckpointableShardedTensor
- abstractmethod classmethod __torch_dispatch__(func, types, args, kwargs=None)#
Placeholder implementation.
- __repr__()#
- class core.dist_checkpointing.strategies.checkpointable.LocalShardsContainer(local_shards: list[torch.Tensor])#
Bases: torch.Tensor
DCP-compatible container for local shards.
PyTorch DCP requires a single tensor per rank for a given global tensor FQN. This class acts as a container allowing multiple checkpointable shards per rank.
Implements the torch.distributed._checkpointable._Checkpointable protocol.
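Example (a minimal, self-contained sketch under the same assumptions as above; the FQN, shapes, and chunk indices are placeholder values):

```python
import torch
from megatron.core.dist_checkpointing.mapping import ShardedTensor
from megatron.core.dist_checkpointing.strategies.checkpointable import (
    CheckpointableShardedTensor,
    LocalShardsContainer,
)

# One rank holding two non-contiguous chunks of the same (8, 8) global tensor:
# chunks 0 and 3 out of 4 along dim 0.
shards = [
    CheckpointableShardedTensor.from_sh_ten(
        ShardedTensor.from_rank_offsets('model.weight', torch.zeros(2, 8), (0, chunk, 4))
    )
    for chunk in (0, 3)
]

# DCP allows only one object per (rank, FQN) pair, so both local shards are
# exposed under a single container.
container = LocalShardsContainer(shards)

# Protocol calls are delegated to each shard; the chunk list is the
# concatenation of the per-shard chunk lists.
for chunk in container.__create_chunk_list__():
    print(chunk.offsets, chunk.sizes)
```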
Initialization
- __new__(local_shards: list[torch.Tensor])#
- abstractmethod classmethod __torch_dispatch__(func, types, args=(), kwargs=None)#
Placeholder implementation.
- __create_write_items__(fqn: str, local_shards_cont: core.dist_checkpointing.strategies.checkpointable.LocalShardsContainer)#
Delegates creating write items to local shards.
- Parameters:
fqn (str) – tensor FQN.
local_shards_cont (LocalShardsContainer) – same as self.
- Returns:
list of DCP WriteItem metadata objects.
- Return type:
List[WriteItem]
- __create_chunk_list__() list[torch.distributed.checkpoint.metadata.ChunkStorageMetadata]#
Delegates creating chunk items to local shards.
- Returns:
list of DCP ChunkStorageMetadata metadata objects.
- Return type:
List[ChunkStorageMetadata]
- __get_tensor_shard__(index: torch.distributed.checkpoint.metadata.MetadataIndex)#
Performs shard matching lookup based on index hint or offset.
- Parameters:
index (MetadataIndex) – metadata specifying the offset of the queried shard. Optionally provides an index hint which speeds up the lookup.
- Returns:
the matching shard data tensor
- Return type:
Tensor
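A hedged sketch of the lookup performed during loading, continuing the container example above; the FQN, offset, and index hint are placeholder values:

```python
import torch
from torch.distributed.checkpoint.metadata import MetadataIndex

# The offset identifies the requested chunk within the global tensor; the
# optional `index` is a hint into the container's shard list that speeds up
# the search.
idx = MetadataIndex(fqn='model.weight', offset=torch.Size([6, 0]), index=1)
shard = container.__get_tensor_shard__(idx)  # returns the matching torch.Tensor
```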
- __repr__()#