core.dist_checkpointing.mapping#
Core library classes for representing sharding of tensors and objects.
The main expected usage is wrapping torch.Tensors in state dicts with the ShardedTensor class (mostly with the ShardedTensor.from_rank_offsets classmethod).
Module Contents#
Classes#
ShardedBase – Base class for ShardedTensor and ShardedStateDict.
ShardedTensor – Represents a mapping between a local tensor and a global tensor.
LocalNonpersistentObject – Object that should not be stored in a checkpoint, but restored locally.
ShardedObject – Represents a mapping between a local object and a global object.
ShardedTensorFactory – Allows applying transformations to tensors before/after serialization.
Functions#
is_main_replica – Checks if a given replica_id is considered the main replica.
apply_factories – Turn ShardedTensorFactories into ShardedTensors in-place.
apply_factory_merges – Apply merges defined by ShardedTensorFactories in-place.
Data#
API#
- core.dist_checkpointing.mapping.logger#
‘getLogger(…)’
- core.dist_checkpointing.mapping.StateDict#
None
- core.dist_checkpointing.mapping.CommonStateDict#
None
- core.dist_checkpointing.mapping.ShardedStateDict#
None
- core.dist_checkpointing.mapping.ReplicaId#
None
- core.dist_checkpointing.mapping._logged_deprecations#
None
- class core.dist_checkpointing.mapping.ShardedBase#
Bases: abc.ABC
Base class for ShardedTensor and ShardedStateDict.
- key: str#
None
- data: object#
None
- replica_id: core.dist_checkpointing.mapping.ReplicaId#
None
- abstractmethod validate_metadata_integrity()#
Codifies the constraints on metadata attributes.
- abstractmethod without_data() → core.dist_checkpointing.mapping.ShardedBase#
Returns a new ShardedBase instance with data=None.
- class core.dist_checkpointing.mapping.ShardedTensor#
Bases: core.dist_checkpointing.mapping.ShardedBase
Represents a mapping between a local tensor and a global tensor.
Global tensor is assumed to consist of many local tensors distributed between different processes.
- Parameters:
key – unique identifier of a global tensor
data – local tensor data. Can be None only for consistency validation
dtype – tensor dtype
local_shape – local tensor shape
global_shape – global tensor shape
global_offset – offset of a local tensor in a global tensor, specified in number of tensor elements
axis_fragmentations – global tensor fragmentation of each axis
replica_id – indicates given local tensor’s replication wrt. local tensors in different processes
prepend_axis_num – number of axes prepended to the local tensor to reflect global tensor shape. The behavior is similar to unsqueezing the local tensor.
allow_shape_mismatch – if True, during loading, the global shape of a stored tensor does not have to match the expected global shape. Useful for representing tensors with flexible shape, e.g. padded.
flattened_range – specifies a slice that should be applied to a flattened tensor with local_shape in order to get the tensor stored as data
- key: str#
None
- data: Optional[torch.Tensor]#
‘field(…)’
- dtype: torch.dtype#
None
- local_shape: Tuple[int, ...]#
None
- global_shape: Tuple[int, ...]#
None
- global_offset: Tuple[int, ...]#
None
- axis_fragmentations: Optional[Tuple[int, ...]]#
None
- replica_id: core.dist_checkpointing.mapping.ReplicaId#
0
- prepend_axis_num: int#
0
- allow_shape_mismatch: bool#
False
- flattened_range: Optional[slice]#
None
- __post_init__()#
- validate_metadata_integrity() → None#
Codifies the constraints on metadata attributes.
Meeting those constraints is guaranteed when instantiating a ShardedTensor with the from_rank_offsets or from_rank_offsets_flat constructors.
- Returns:
None
- property has_regular_grid#
Alias for having a regular sharding grid.
- global_slice() → Tuple[Union[int, slice], ...]#
Returns a tuple of int and slice objects representing a slice of the global tensor that this ShardedTensor corresponds to.
- local_chunk_offset_in_global() → Tuple[int, ...]#
Offset of a local chunk in a global array of chunks.
- Returns:
the offset of the whole local chunk in a global array of chunks.
- Return type:
Tuple[int, …]
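The relationship between global_offset, local_shape and prepended axes can be illustrated with a minimal sketch of the global_slice semantics. This is a hypothetical reimplementation for illustration, not the library code: prepended axes (which have no local extent) map to plain integer indices, and every regular axis maps to a slice of length local_shape[dim] starting at its global offset.

```python
from typing import Tuple, Union

def global_slice_sketch(
    global_offset: Tuple[int, ...],
    local_shape: Tuple[int, ...],
    prepend_axis_num: int = 0,
) -> Tuple[Union[int, slice], ...]:
    """Illustrative sketch of the global_slice semantics (assumption,
    not the library implementation)."""
    # Prepended axes have no local extent: index them with a plain int.
    prepend = tuple(int(off) for off in global_offset[:prepend_axis_num])
    # Regular axes: the half-open range [offset, offset + local_dim_size).
    body = tuple(
        slice(off, off + size)
        for off, size in zip(global_offset[prepend_axis_num:], local_shape)
    )
    return prepend + body
```

For example, a local chunk of shape (4,) placed at offset 0 of the second prepended-axis row yields `(2, slice(0, 4))`.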
- max_allowed_chunks() → Tuple[int, ...]#
Returns the maximum allowed chunks for this ShardedTensor.
- without_data()#
- classmethod from_rank_offsets(
- key: str,
- data: torch.Tensor,
- *rank_offsets: Tuple[int, int, int],
- replica_id: core.dist_checkpointing.mapping.ReplicaId = 0,
- prepend_axis_num: int = 0,
- flattened_range: None = None,
- **init_kwargs,
Allows constructing a ShardedTensor given offsets specified in process ranks.
- Parameters:
key (str) – unique key
data (torch.Tensor) – local tensor data
rank_offsets (Tuple[int, int, int]) – each tuple (axis, axis_rank_offset, axis_fragm) says that if the global tensor is divided into axis_fragm fragments along the given axis, then the local tensor data corresponds to chunk number axis_rank_offset.
replica_id (ReplicaId) – see ShardedTensor
prepend_axis_num (int) – see ShardedTensor
flattened_range (None) – must be None when using this constructor
init_kwargs – passed to ShardedTensor.__init__
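How the (axis, axis_rank_offset, axis_fragm) triples determine the global metadata can be sketched with a small self-contained helper. This is a hypothetical illustration of the arithmetic, assuming equal-sized chunks; it is not the library implementation.

```python
from typing import Tuple

def rank_offsets_to_metadata(
    local_shape: Tuple[int, ...],
    *rank_offsets: Tuple[int, int, int],
    prepend_axis_num: int = 0,
) -> Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]:
    """Hypothetical helper: derive (global_shape, global_offset,
    axis_fragmentations) from rank-offset triples, assuming the global
    tensor is split into equal chunks along each fragmented axis."""
    # Prepended axes behave like local axes of size 1.
    padded_local = (1,) * prepend_axis_num + local_shape
    global_shape = list(padded_local)
    global_offset = [0] * len(padded_local)
    axis_fragmentations = [1] * len(padded_local)
    for axis, axis_rank_offset, axis_fragm in rank_offsets:
        # The global tensor is divided into `axis_fragm` chunks along
        # `axis`; this rank holds chunk number `axis_rank_offset`.
        global_shape[axis] = axis_fragm * padded_local[axis]
        global_offset[axis] = axis_rank_offset * padded_local[axis]
        axis_fragmentations[axis] = axis_fragm
    return tuple(global_shape), tuple(global_offset), tuple(axis_fragmentations)
```

For a local chunk of shape (4, 8) that is chunk 1 of 2 along axis 0, the sketch yields global shape (8, 8) and offset (4, 0).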
- init_data(
- device: Union[str, torch.device],
- init_fn=torch.empty,
Initialize the tensor data of this ShardedTensor.
Only called if the data attribute is None.
- Parameters:
device (Union[str, torch.device]) – device to place the tensor on
init_fn (Callable, optional) – function to use to initialize the tensor. Defaults to torch.empty.
- narrow(
- dim: int,
- start: int,
- length: int,
This is an analogue of torch.narrow for ShardedTensors.
Narrowing assumes that we narrow a local tensor on each rank. This has consequences on local_shape, global_shape, global_offset, etc.
- Parameters:
dim (int) – dimension to narrow. Doesn’t include prepended axes.
start (int) – start element
length (int) – length of the slice
- Returns:
narrowed ShardedTensors. For non-flat tensors, the list will always have 1 element. For flat ShardedTensors the number of elements varies depending on dim and on overlap, because flat tensors must be contiguous. In particular, the list can be empty.
- Return type:
List[ShardedTensor]
- core.dist_checkpointing.mapping.is_main_replica(replica_id: core.dist_checkpointing.mapping.ReplicaId)#
Checks if the given replica_id is considered the main replica.
A “main” replica is:
the integer 0
or an iterable with all elements equal to 0
It is the application's responsibility to set correct replica ids for sharded tensors.
- Parameters:
replica_id (Union[int, Tuple[int, ...]]) – replica id
- Returns:
True for a “main” replica
- Return type:
(bool)
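The documented rule is simple enough to state as a short sketch. This is an illustrative reimplementation of the described contract, not the library code:

```python
from typing import Iterable, Union

ReplicaIdSketch = Union[int, Iterable[int]]

def is_main_replica_sketch(replica_id: ReplicaIdSketch) -> bool:
    """Sketch of the documented rule: a replica is "main" iff its id is
    the integer 0 or an iterable whose elements are all 0."""
    if isinstance(replica_id, int):
        return replica_id == 0
    return all(r == 0 for r in replica_id)
```

Only main replicas are typically responsible for actually writing their shards, which is why correct replica ids matter for checkpoint consistency.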
- class core.dist_checkpointing.mapping.LocalNonpersistentObject(obj)#
Object that should not be stored in a checkpoint, but restored locally.
Wrapping any object inside the state dict with LocalNonpersistentObject will result in:
during saving, this object will not be stored in the checkpoint
during loading, a local version of this object will be placed in a state dict
Initialization
- unwrap()#
Returns the original object.
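The wrapper's contract can be shown with a minimal stand-in class. This is an illustrative sketch of the described behavior, not the library implementation:

```python
class LocalNonpersistentObjectSketch:
    """Minimal stand-in: the wrapped object is kept only locally and
    unwrap() returns it unchanged. A saving routine would skip entries
    of this type; a loading routine would keep the local value."""

    def __init__(self, obj):
        self.obj = obj

    def unwrap(self):
        # Return the original object, untouched.
        return self.obj
```

During saving, entries wrapped this way are simply not written to the checkpoint; during loading, the locally constructed value is what ends up in the state dict.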
- class core.dist_checkpointing.mapping.ShardedObject#
Bases: core.dist_checkpointing.mapping.ShardedBase
Represents a mapping between a local object and a global object.
Global object is assumed to consist of many local objects distributed between different processes.
NOTE: Contrary to ShardedTensor, it’s impossible to change global object sharding. Conceptually, ShardedObject is a fully-sharded ShardedTensor with atomic arbitrary typed elements.
- Parameters:
key – unique identifier of a global object
data – local object data. Can be None only for consistency validation
global_shape – global object shape
global_offset – offset of a local object in a global object, specified in number of shards
replica_id – indicates local object replication wrt. local objects in different processes
- key: str#
None
- data: object#
None
- global_shape: Tuple[int, ...]#
None
- global_offset: Tuple[int, ...]#
None
- replica_id: core.dist_checkpointing.mapping.ReplicaId#
0
- __post_init__()#
- validate_metadata_integrity()#
- without_data()#
- property unique_key#
Returns a unique key for this object.
- __str__()#
- classmethod empty_from_unique_key(
- unique_key,
- replica_id: core.dist_checkpointing.mapping.ReplicaId = 0,
Instantiates a ShardedObject from a unique key.
- Parameters:
unique_key – a string of the form /shard_<global_offset>_<global_shape>
replica_id – indicates local object replication wrt. local objects in different processes
- Returns:
a ShardedObject with data=None
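Parsing such a key can be sketched as follows. The exact serialization of the offset and shape components is an assumption here (this sketch takes them to be '.'-joined integers); only the overall `<key>/shard_<global_offset>_<global_shape>` layout comes from the documentation above.

```python
from typing import Tuple

def parse_unique_key_sketch(unique_key: str) -> Tuple[str, Tuple[int, ...], Tuple[int, ...]]:
    """Hypothetical parser for a key of the assumed form
    '<key>/shard_<global_offset>_<global_shape>', where offset and shape
    are '.'-joined integers (an assumption for illustration)."""
    # The object key may itself contain '/', so split off only the last part.
    key, shard_part = unique_key.rsplit('/', 1)
    _, offset_str, shape_str = shard_part.split('_')
    global_offset = tuple(int(x) for x in offset_str.split('.'))
    global_shape = tuple(int(x) for x in shape_str.split('.'))
    return key, global_offset, global_shape
```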
- core.dist_checkpointing.mapping.FactoryBuildFn#
None
- core.dist_checkpointing.mapping.FactoryMergeFn#
None
- class core.dist_checkpointing.mapping.ShardedTensorFactory#
Bases: core.dist_checkpointing.mapping.ShardedBase
Allows applying transformations to tensors before/after serialization.
The essence of those transformations is that they can be applied to optimizer states the same way they are applied to the model params. The ultimate state dict with sharded tensors must depend functionally on the build_fn arguments (key, data, replica_id, flattened_range), which will be provided by the optimizer.
The builder creates a sub-state-dict out of a tensor before saving, and the merger merges the corresponding state dict back after loading.
- Parameters:
key (str) – unique identifier of the factory
data (torch.Tensor) – original model parameter that will be further transformed by this factory
build_fn (callable) – function that transforms the original tensor to a sharded state dict
merge_fn (callable) – function that transforms the loaded subtree back into a single tensor (inverse of build_fn)
replica_id (ReplicaId) – indicates factory replication wrt. factories in different processes
flattened_range (slice, optional) – indicates additional flattening applied to the ShardedTensors produced by the factory
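The key contract is that merge_fn inverts build_fn. A hypothetical pair illustrating that contract (the particular split is made up for illustration; real build functions emit sub-state-dicts of ShardedTensors):

```python
def split_in_halves(tensor):
    """Hypothetical build_fn: turn one value into a sub-state-dict
    by splitting it into two halves along its first dimension."""
    mid = len(tensor) // 2
    return {'lo': tensor[:mid], 'hi': tensor[mid:]}

def join_halves(sub_state_dict):
    """Hypothetical merge_fn: the inverse of split_in_halves,
    reassembling the original value from the loaded subtree."""
    return sub_state_dict['lo'] + sub_state_dict['hi']
```

Because the same pair is applied to the corresponding optimizer states, the round trip `join_halves(split_in_halves(x)) == x` must hold for any input.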
- key: str#
None
- data: torch.Tensor#
None
- build_fn: core.dist_checkpointing.mapping.FactoryBuildFn#
None
- merge_fn: core.dist_checkpointing.mapping.FactoryMergeFn#
None
- replica_id: core.dist_checkpointing.mapping.ReplicaId#
0
- flattened_range: Optional[slice]#
None
- build()#
Builds a ShardedStateDict from the original tensor
- validate_metadata_integrity()#
No reasonable checks can be applied
- without_data()#
- core.dist_checkpointing.mapping.apply_factories(
- sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
Turn ShardedTensorFactories into ShardedTensors in-place.
- Parameters:
sharded_state_dict (ShardedStateDict) – state dict possibly containing ShardedTensorFactory objects
- Returns:
state dict is modified in place
- Return type:
None
- core.dist_checkpointing.mapping.apply_factory_merges(
- x1: core.dist_checkpointing.mapping.StateDict,
- x2: core.dist_checkpointing.mapping.ShardedStateDict,
- key: Tuple[str, ...] = (),
Apply merges defined by ShardedTensorFactories in-place.
- Parameters:
x1 (StateDict) – state dict loaded from the checkpoint
x2 (ShardedStateDict) – subset of x1 (in terms of dict keys) with ShardedTensorFactory as (possibly nested) values that define how to merge objects from the x1 state dict
key (Tuple[str, ...]) – current key in a recursive call. Used only for reporting meaningful errors
- Returns:
x1 modified in-place
- Return type:
StateDict
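The in-place traversal that apply_factories and apply_factory_merges perform can be sketched with plain dicts. FactorySketch and both functions below are illustrative stand-ins, not the library implementation; real factories operate on ShardedTensors and track keys for error reporting.

```python
class FactorySketch:
    """Hypothetical stand-in for ShardedTensorFactory: pairs a build
    function with its inverse merge function."""

    def __init__(self, data, build_fn, merge_fn):
        self.data = data
        self.build_fn = build_fn
        self.merge_fn = merge_fn

    def build(self):
        return self.build_fn(self.data)

def apply_factories_sketch(sharded_state_dict):
    """Before saving: replace every factory in the (possibly nested)
    dict with the sub-state-dict its build function produces, in place."""
    for k, v in sharded_state_dict.items():
        if isinstance(v, dict):
            apply_factories_sketch(v)
        elif isinstance(v, FactorySketch):
            sharded_state_dict[k] = v.build()

def apply_factory_merges_sketch(x1, x2):
    """After loading: wherever x2 holds a factory, collapse the loaded
    subtree in x1 back into a single value with the factory's merge
    function, in place."""
    for k, v in x2.items():
        if isinstance(v, FactorySketch):
            x1[k] = v.merge_fn(x1[k])
        elif isinstance(v, dict):
            apply_factory_merges_sketch(x1[k], v)
    return x1
```

A round trip over a nested state dict then reproduces the original values, which is exactly the invariant the two library functions rely on.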