core.dist_checkpointing.mapping#

Core library classes for representing sharding of tensors and objects.

The main expected usage is wrapping torch.Tensors in state dicts with the ShardedTensor class (mostly via the ShardedTensor.from_rank_offsets classmethod).

Module Contents#

Classes#

ShardedBase

Base class for ShardedTensor and ShardedStateDict.

ShardedTensor

Represents a mapping between a local tensor and a global tensor.

LocalNonpersistentObject

Object that should not be stored in a checkpoint, but restored locally.

ShardedObject

Represents a mapping between a local object and a global object.

ShardedTensorFactory

Allows applying transformations to tensors before/after serialization.

Functions#

is_main_replica

Checks if the given replica_id is considered the main replica.

apply_factories

Turn ShardedTensorFactories into ShardedTensors in-place.

apply_factory_merges

Apply merges defined by ShardedTensorFactories in-place.

Data#

API#

core.dist_checkpointing.mapping.logger#

‘getLogger(…)’

core.dist_checkpointing.mapping.StateDict#

None

core.dist_checkpointing.mapping.CommonStateDict#

None

core.dist_checkpointing.mapping.ShardedStateDict#

None

core.dist_checkpointing.mapping.ReplicaId#

None

core.dist_checkpointing.mapping._logged_deprecations#

None

class core.dist_checkpointing.mapping.ShardedBase#

Bases: abc.ABC

Base class for ShardedTensor and ShardedStateDict.

key: str#

None

data: object#

None

replica_id: core.dist_checkpointing.mapping.ReplicaId#

None

abstractmethod validate_metadata_integrity()#

Codifies the constraints on metadata attributes.

abstractmethod without_data() → core.dist_checkpointing.mapping.ShardedBase#

Returns a new ShardedBase instance with data=None.

class core.dist_checkpointing.mapping.ShardedTensor#

Bases: core.dist_checkpointing.mapping.ShardedBase

Represents a mapping between a local tensor and a global tensor.

The global tensor is assumed to consist of many local tensors distributed across different processes.

Parameters:
  • key – unique identifier of a global tensor

  • data – local tensor data. Can be None only for consistency validation

  • dtype – tensor dtype

  • local_shape – local tensor shape

  • global_shape – global tensor shape

  • global_offset – offset of a local tensor in a global tensor, specified in number of tensor elements

  • axis_fragmentations – global tensor fragmentation of each axis

  • replica_id – indicates the given local tensor’s replication with respect to local tensors in different processes

  • prepend_axis_num – number of axes prepended to the local tensor to reflect global tensor shape. The behavior is similar to unsqueezing the local tensor.

  • allow_shape_mismatch – if True, during loading, the global shape of a stored tensor does not have to match the expected global shape. Useful for representing tensors with flexible shape, e.g. padded.

  • flattened_range – specifies a slice that should be applied to a flattened tensor with local_shape in order to get the tensor stored as data

key: str#

None

data: Optional[torch.Tensor]#

‘field(…)’

dtype: torch.dtype#

None

local_shape: Tuple[int, ...]#

None

global_shape: Tuple[int, ...]#

None

global_offset: Tuple[int, ...]#

None

axis_fragmentations: Optional[Tuple[int, ...]]#

None

replica_id: core.dist_checkpointing.mapping.ReplicaId#

0

prepend_axis_num: int#

0

allow_shape_mismatch: bool#

False

flattened_range: Optional[slice]#

None

__post_init__()#
validate_metadata_integrity() → None#

Codifies the constraints on metadata attributes.

Meeting those constraints is guaranteed when instantiating a ShardedTensor with the from_rank_offsets or from_rank_offsets_flat constructors.

Returns:

None

property has_regular_grid#

Alias for having a regular sharding grid.

global_slice() → Tuple[Union[int, slice], ...]#

Returns a tuple of int and slice objects representing a slice of the global tensor that this ShardedTensor corresponds to.
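
A minimal sketch of how such a global slice can be derived from global_offset, local_shape, and prepend_axis_num. This follows the documented semantics (prepended axes hold a single index; the remaining axes cover the local chunk) and is not the library implementation:

```python
def global_slice(global_offset, local_shape, prepend_axis_num=0):
    # First `prepend_axis_num` entries are plain ints: prepended axes
    # index a single position in the global tensor.
    prepended = tuple(global_offset[:prepend_axis_num])
    # Remaining entries are slices covering this rank's local chunk.
    body = tuple(
        slice(off, off + length)
        for off, length in zip(global_offset[prepend_axis_num:], local_shape)
    )
    return prepended + body

# A (4, 8) local chunk with one prepended axis at index 1:
global_slice((1, 8, 0), (4, 8), prepend_axis_num=1)
# → (1, slice(8, 12), slice(0, 8))
```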

local_chunk_offset_in_global() → Tuple[int, ...]#

Offset of a local chunk in a global array of chunks.

Returns:

the offset of the whole local chunk in a global array of chunks.

Return type:

Tuple[int, …]

max_allowed_chunks() → Tuple[int, ...]#

Returns the maximum allowed chunks for this ShardedTensor.

without_data()#
classmethod from_rank_offsets(
key: str,
data: torch.Tensor,
*rank_offsets: Tuple[int, int, int],
replica_id: core.dist_checkpointing.mapping.ReplicaId = 0,
prepend_axis_num: int = 0,
flattened_range: None = None,
**init_kwargs,
)#

Allows constructing a ShardedTensor with offsets specified in terms of process ranks.

Parameters:
  • key (str) – unique key

  • data (torch.Tensor) – local tensor data

  • rank_offsets (Tuple[int, int, int]) – each tuple (axis, axis_rank_offset, axis_fragm) states that if the global tensor is divided into axis_fragm fragments along axis axis, then the local tensor data corresponds to the axis_rank_offset-th chunk.

  • replica_id (ReplicaId) – see ShardedTensor

  • prepend_axis_num (int) – see ShardedTensor

  • flattened_range (None) – must be None when using this constructor

  • init_kwargs – passed to ShardedTensor.__init__
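
The semantics of the rank_offsets tuples can be illustrated with a small self-contained sketch. The helper name derive_global_metadata is hypothetical and this is not the library implementation; it only mirrors the documented meaning of each (axis, axis_rank_offset, axis_fragm) tuple (prepend_axis_num is ignored for simplicity):

```python
def derive_global_metadata(local_shape, rank_offsets):
    """For each (axis, axis_rank_offset, axis_fragm) tuple, the global
    tensor is split into axis_fragm fragments along `axis`, and this rank
    holds the axis_rank_offset-th fragment."""
    global_shape = list(local_shape)
    global_offset = [0] * len(local_shape)
    axis_fragmentations = [1] * len(local_shape)
    for axis, axis_rank_offset, axis_fragm in rank_offsets:
        global_shape[axis] = local_shape[axis] * axis_fragm
        # offsets are expressed in numbers of tensor elements
        global_offset[axis] = axis_rank_offset * local_shape[axis]
        axis_fragmentations[axis] = axis_fragm
    return tuple(global_shape), tuple(global_offset), tuple(axis_fragmentations)

# A (4, 8) local chunk that is fragment 2 of 4 along axis 0:
derive_global_metadata((4, 8), [(0, 2, 4)])
# → global_shape (16, 8), global_offset (8, 0), axis_fragmentations (4, 1)
```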

init_data(
device: Union[str, torch.device],
init_fn=torch.empty,
)#

Initialize the tensor data of this ShardedTensor.

Only called if data attribute is None.

Parameters:
  • device (Union[str, torch.device]) – device to place the tensor on

  • init_fn (Callable, optional) – function to use to initialize the tensor. Defaults to torch.empty.

narrow(
dim: int,
start: int,
length: int,
) → List[core.dist_checkpointing.mapping.ShardedTensor]#

This is an analogue of torch.narrow for ShardedTensors.

Narrowing assumes that we narrow a local tensor on each rank. This has consequences on local_shape, global_shape, global_offset, etc.

Parameters:
  • dim (int) – dimension to narrow. Doesn’t include prepended axes.

  • start (int) – start element

  • length (int) – length of the slice

Returns:

narrowed ShardedTensors. For non-flat tensors, the list will always have 1 element. For flat ShardedTensors the number of elements varies depending on dim and on overlap, because flat tensors must be contiguous. In particular the list can be empty.

Return type:

List[ShardedTensor]

core.dist_checkpointing.mapping.is_main_replica(replica_id: core.dist_checkpointing.mapping.ReplicaId)#

Checks if the given replica_id is considered the main replica.

“Main” replica is:

  • integer 0

  • or an iterable with all 0 elements

It is the application’s responsibility to set correct replica ids for sharded tensors.

Parameters:

replica_id (Union[int, Tuple[int, ...]]) – replica id

Returns:

True for a “main” replica

Return type:

(bool)
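
The documented rule can be captured in a few lines. This is a sketch of the stated semantics, not the actual library code:

```python
def is_main_replica(replica_id):
    # Documented semantics: the main replica is the integer 0,
    # or an iterable whose elements are all 0.
    if isinstance(replica_id, int):
        return replica_id == 0
    return all(r == 0 for r in replica_id)

is_main_replica(0)          # True
is_main_replica((0, 0, 0))  # True
is_main_replica((0, 1, 0))  # False
```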

class core.dist_checkpointing.mapping.LocalNonpersistentObject(obj)#

Object that should not be stored in a checkpoint, but restored locally.

Wrapping any object in the state dict with LocalNonpersistentObject results in the following:

  • during saving, the object is not stored in the checkpoint

  • during loading, a local version of the object is placed in the state dict

Initialization

unwrap()#

Returns the original object.
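
A minimal stand-in class illustrating the wrap/unwrap contract described above (not the actual implementation; the rng_state entry is a made-up example):

```python
class LocalNonpersistentObject:
    """Stand-in: holds a value that is skipped on save and
    restored from the local copy on load."""
    def __init__(self, obj):
        self.obj = obj

    def unwrap(self):
        return self.obj

# Wrap a value that should stay local rather than go into the checkpoint:
state_dict = {"rng_state": LocalNonpersistentObject({"seed": 1234})}
# On load, the local version is placed back into the state dict:
loaded = {k: v.unwrap() for k, v in state_dict.items()}
# loaded == {"rng_state": {"seed": 1234}}
```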

class core.dist_checkpointing.mapping.ShardedObject#

Bases: core.dist_checkpointing.mapping.ShardedBase

Represents a mapping between a local object and a global object.

The global object is assumed to consist of many local objects distributed across different processes.

NOTE: In contrast to ShardedTensor, the global object sharding cannot be changed. Conceptually, a ShardedObject is a fully sharded ShardedTensor with atomic, arbitrarily-typed elements.

Parameters:
  • key – unique identifier of a global object

  • data – local object data. Can be None only for consistency validation

  • global_shape – global object shape

  • global_offset – offset of a local object in a global object, specified in number of shards

  • replica_id – indicates the local object’s replication with respect to local objects in different processes

key: str#

None

data: object#

None

global_shape: Tuple[int, ...]#

None

global_offset: Tuple[int, ...]#

None

replica_id: core.dist_checkpointing.mapping.ReplicaId#

0

__post_init__()#
validate_metadata_integrity()#
without_data()#
property unique_key#

Returns a unique key for this object.

__str__()#
classmethod empty_from_unique_key(
unique_key,
replica_id: core.dist_checkpointing.mapping.ReplicaId = 0,
) → core.dist_checkpointing.mapping.ShardedObject#

Instantiates a ShardedObject from a unique key.

Parameters:
  • unique_key – a string of the form /shard_<global_offset>_<global_shape>

  • replica_id – indicates the local object’s replication with respect to local objects in different processes

Returns:

a ShardedObject with data=None
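
A hedged sketch of forming and parsing such a key. The documentation only specifies the general /shard_<global_offset>_<global_shape> shape; the dot-joined integer rendering and the helper names here are assumptions for illustration:

```python
def format_unique_key(key, global_offset, global_shape):
    # Assumed rendering: offset and shape as dot-joined integers.
    off = ".".join(str(x) for x in global_offset)
    shp = ".".join(str(x) for x in global_shape)
    return f"{key}/shard_{off}_{shp}"

def parse_unique_key(unique_key):
    # Inverse of format_unique_key under the same assumed rendering.
    key, shard = unique_key.split("/shard_")
    off, shp = shard.split("_")

    def to_tuple(s):
        return tuple(int(x) for x in s.split("."))

    return key, to_tuple(off), to_tuple(shp)

uk = format_unique_key("opt.state", (0, 2), (1, 4))
# uk == "opt.state/shard_0.2_1.4"
# parse_unique_key(uk) == ("opt.state", (0, 2), (1, 4))
```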

core.dist_checkpointing.mapping.FactoryBuildFn#

None

core.dist_checkpointing.mapping.FactoryMergeFn#

None

class core.dist_checkpointing.mapping.ShardedTensorFactory#

Bases: core.dist_checkpointing.mapping.ShardedBase

Allows applying transformations to tensors before/after serialization.

The essence of those transformations is that they can be applied to optimizer states the same way they are applied to the model params. The ultimate state dict with sharded tensors must depend functionally on build_fn arguments (key, data, replica_id, flattened_range), which will be provided by the optimizer.

The builder creates a sub-state-dict out of a tensor before saving, and the merger merges the corresponding state dict back into a single tensor after loading.

Parameters:
  • key (str) – unique identifier of the factory

  • data (torch.Tensor) – original model parameter that will be further transformed by this factory

  • build_fn (callable) – function that transforms the original tensor to a sharded state dict

  • merge_fn (callable) – function that transforms loaded subtree back into a single tensor (inverse of build_fn)

  • replica_id (ReplicaId) – indicates the factory’s replication with respect to factories in different processes

  • flattened_range (slice, optional) – indicates additional flattening applied to the ShardedTensors produced by the factory
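
The build_fn / merge_fn contract can be illustrated with plain Python objects in place of torch tensors. The function bodies and sub-state-dict keys below are illustrative, not library API; the only assumption is that merge_fn is the inverse of build_fn:

```python
def build_fn(key, data, replica_id, flattened_range):
    # Split a "tensor" (here, a list) into a sub-state-dict of shards.
    half = len(data) // 2
    return {"first_half": data[:half], "second_half": data[half:]}

def merge_fn(sub_state_dict):
    # Inverse of build_fn: merge the loaded subtree back into one object.
    return sub_state_dict["first_half"] + sub_state_dict["second_half"]

# Round trip: build before saving, merge after loading.
sub = build_fn("param", [1, 2, 3, 4], 0, None)
# sub == {"first_half": [1, 2], "second_half": [3, 4]}
merged = merge_fn(sub)
# merged == [1, 2, 3, 4]
```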

key: str#

None

data: torch.Tensor#

None

build_fn: core.dist_checkpointing.mapping.FactoryBuildFn#

None

merge_fn: core.dist_checkpointing.mapping.FactoryMergeFn#

None

replica_id: core.dist_checkpointing.mapping.ReplicaId#

0

flattened_range: Optional[slice]#

None

build()#

Builds a ShardedStateDict from the original tensor.

validate_metadata_integrity()#

No reasonable checks can be applied.

without_data()#
core.dist_checkpointing.mapping.apply_factories(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
)#

Turn ShardedTensorFactories into ShardedTensors in-place.

Parameters:

sharded_state_dict (ShardedStateDict) – state dict possibly containing ShardedTensorFactory objects

Returns:

the state dict is modified in place

Return type:

None
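
The in-place replacement can be sketched as a recursive walk over the nested state dict. FakeFactory is a hypothetical stand-in for ShardedTensorFactory; only the replace-leaf-with-build() behavior is taken from the documentation:

```python
class FakeFactory:
    """Hypothetical stand-in for ShardedTensorFactory."""
    def __init__(self, key, data):
        self.key, self.data = key, data

    def build(self):
        # A real build_fn would return a sharded sub-state-dict.
        return {"weight": self.data}

def apply_factories(state_dict):
    # Recursively replace every factory leaf with its built sub-state-dict.
    for k, v in state_dict.items():
        if isinstance(v, dict):
            apply_factories(v)
        elif isinstance(v, FakeFactory):
            state_dict[k] = v.build()  # replaced in place

sd = {"layer": {"proj": FakeFactory("proj", [1, 2])}}
apply_factories(sd)
# sd == {"layer": {"proj": {"weight": [1, 2]}}}
```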

core.dist_checkpointing.mapping.apply_factory_merges(
x1: core.dist_checkpointing.mapping.StateDict,
x2: core.dist_checkpointing.mapping.ShardedStateDict,
key: Tuple[str, ...] = (),
) → core.dist_checkpointing.mapping.StateDict#

Apply merges defined by ShardedTensorFactories in-place.

Parameters:
  • x1 (StateDict) – state dict loaded from the checkpoint

  • x2 (ShardedStateDict) – subset of x1 (in terms of dict keys) with ShardedTensorFactories as (possibly nested) values that define how to merge objects from the x1 state dict

  • key (Tuple[str, ...]) – current key in a recursive call. Used only for reporting meaningful errors

Returns:

x1 modified in-place

Return type:

StateDict