core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor#

Module Contents#

Functions#

gather_and_compute_chunk_metadata

Gather chunk metadata for a DTensor across all ranks and compute the offsets and sizes of each chunk. This is necessary for handling uneven sharding in distributed tensors.

update_uneven_dtensor_chunk_metadata

Update the DTensor’s chunk metadata to handle uneven sharding. This function modifies the DTensor in-place to include chunk metadata and write items closures for saving and loading.

validate_uneven_dtensor

Validates the chunk metadata of an uneven DTensor to ensure correctness and boundary coverage.

filter_unflattened_state_dict

Recursively traverses an unflattened state_dict and collects keys of items that meet the visit_condition. The keys are returned as lists of strings representing the path to each item in the state_dict.

get_unflattened_state_dict

Get a value from an unflattened state_dict at the specified key chain.

preprocess_state_dict_for_uneven_dtensor

Preprocess the state_dict to prepare it for saving or loading unevenly sharded DTensors. This function modifies the DTensors in the state_dict to include chunk metadata and write items closures.

uneven_dtensor_to_full_tensor

Gather a DTensor with potentially uneven sharding across ranks into a full tensor.

redistribute_uneven_dtensor_to_replicated

Redistribute an unevenly sharded DTensor to a fully replicated DTensor.

gather_uneven_dtensor_to_full_tensor

Deprecated: use redistribute_uneven_dtensor_to_replicated instead.

_intersection

_offset_slice

split_dtensor

Splits a DTensor into smaller DTensors along a specified dimension.

API#

core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.gather_and_compute_chunk_metadata(
dtensor: torch.distributed._tensor.DTensor,
) → torch.distributed.checkpoint.metadata.ChunkStorageMetadata#

Gather chunk metadata for a DTensor across all ranks and compute the offsets and sizes of each chunk. This is necessary for handling uneven sharding in distributed tensors.
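
A minimal usage sketch follows; the mesh construction, tensor shape, and the megatron_fsdp import path are illustrative assumptions rather than details from this page, and a process group must already be initialized:

  import torch
  from torch.distributed.device_mesh import init_device_mesh
  from torch.distributed._tensor import distribute_tensor, Shard
  from megatron_fsdp.uneven_dtensor import gather_and_compute_chunk_metadata  # assumed import path

  # Shard dim 0 of a (10, 16) tensor across 4 ranks; 10 does not divide by 4,
  # so the local shards are uneven.
  mesh = init_device_mesh("cuda", (4,))
  dtensor = distribute_tensor(torch.randn(10, 16), mesh, placements=[Shard(0)])

  # Gather each rank's chunk offset and size so the uneven sharding can be
  # described for checkpointing.
  chunk_metadata = gather_and_compute_chunk_metadata(dtensor)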

core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.update_uneven_dtensor_chunk_metadata(
dtensor: torch.distributed._tensor.DTensor,
) → dict#

Update the DTensor’s chunk metadata to handle uneven sharding. This function modifies the DTensor in-place to include chunk metadata and write items closures for saving and loading.
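
A one-line sketch, continuing the hypothetical dtensor (and assumed import path) from the example above; the contents of the returned dict are not documented here, so they are left opaque:

  # Attaches chunk metadata and write-item closures to the DTensor in place,
  # and returns a dict describing the update.
  meta = update_uneven_dtensor_chunk_metadata(dtensor)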

core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.validate_uneven_dtensor(
dtensor: torch.distributed._tensor.DTensor,
) → None#

Validates the chunk metadata of an uneven DTensor to ensure correctness and boundary coverage.

Notes:

  • gather_and_compute_chunk_metadata ensures that no chunks overlap.

This function performs the following checks:

  • All chunk offsets and sizes are within the tensor shape bounds.

  • All boundaries of each dimension are actually covered by shard placements.

Parameters:

dtensor (DTensor) – The distributed tensor to validate.

Raises:

AssertionError – If any chunk falls out of bounds or any dimension boundary is not covered by a shard.
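
An illustrative call, assuming the dtensor's chunk metadata has already been populated (e.g. by update_uneven_dtensor_chunk_metadata above):

  # Returns None on success; raises AssertionError if a chunk falls out of
  # bounds or a dimension boundary is not covered by any shard.
  validate_uneven_dtensor(dtensor)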

core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.filter_unflattened_state_dict(
state_dict,
key_chain=[],
visit_condition=lambda x: ...,
)#

Recursively traverses an unflattened state_dict and collects keys of items that meet the visit_condition. The keys are returned as lists of strings representing the path to each item in the state_dict.

core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.get_unflattened_state_dict(state_dict, key_chain=[])#

Get a value from an unflattened state_dict at the specified key chain.
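
A small sketch showing how these two helpers compose on a nested state_dict; the dictionary below is invented, the import path is assumed, and it is an assumption that visit_condition is applied to each visited value:

  import torch
  from megatron_fsdp.uneven_dtensor import (  # assumed import path
      filter_unflattened_state_dict,
      get_unflattened_state_dict,
  )

  state_dict = {
      "model": {
          "layer1": {"weight": torch.randn(4, 4), "bias": torch.zeros(4)},
          "layer2": {"weight": torch.randn(4, 4)},
      }
  }

  # Collect the key chain of every tensor leaf (assumes visit_condition
  # receives the visited value).
  key_chains = filter_unflattened_state_dict(
      state_dict, visit_condition=lambda value: torch.is_tensor(value)
  )
  # Hypothetically: [["model", "layer1", "weight"], ["model", "layer1", "bias"], ...]

  # Fetch the value stored at one of those key chains.
  weight = get_unflattened_state_dict(state_dict, key_chain=["model", "layer1", "weight"])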

core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.preprocess_state_dict_for_uneven_dtensor(state_dict: dict) → dict#

Preprocess the state_dict to prepare it for saving or loading unevenly sharded DTensors. This function modifies the DTensors in the state_dict to include chunk metadata and write items closures.
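
A hedged sketch of where this preprocessing step would sit relative to torch.distributed.checkpoint; the checkpoint path is an assumption, and the call assumes a recent torch.distributed.checkpoint.save API and the same assumed megatron_fsdp.uneven_dtensor import path as the earlier sketches:

  import torch.distributed.checkpoint as dcp

  # Attach chunk metadata and write-item closures to any unevenly sharded
  # DTensors before handing the state_dict to the checkpointing layer.
  state_dict = preprocess_state_dict_for_uneven_dtensor(state_dict)
  dcp.save(state_dict, checkpoint_id="/tmp/uneven_ckpt")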

core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.uneven_dtensor_to_full_tensor(
dtensor: torch.distributed._tensor.DTensor,
) → torch.Tensor#

Gather a DTensor with potentially uneven sharding across ranks into a full tensor.

This function handles DTensors with uneven shards (where different ranks may have different-sized chunks) by gathering chunk metadata and local tensors across all ranks, then reconstructing the complete tensor.

Parameters:

dtensor (DTensor) – The distributed tensor to gather. Must have chunk metadata available (either pre-existing or will be computed).

Returns:

The fully reconstructed tensor with shape matching the original DTensor’s global shape.

Return type:

torch.Tensor

Raises:
  • TypeError – If input is not a DTensor.

  • ValueError – If chunk metadata is malformed (expected exactly one chunk per rank).

  • AssertionError – If an unexpected placement type is encountered after processing Shard placements.

Notes:

  • This function performs collective operations (all_gather_object, all_gather) across the device mesh, requiring synchronization across ranks.

  • Works with Shard and _StridedShard placements, and expects Replicate placements for non-sharded dimensions.

  • The function modifies the DTensor in-place by adding chunk metadata if missing.

Example:

  >>> mesh = DeviceMesh("cuda", [0, 1, 2, 3])
  >>> # Create a DTensor with uneven sharding
  >>> dtensor = DTensor(..., placements=[Shard(0)])
  >>> full_tensor = uneven_dtensor_to_full_tensor(dtensor)
  >>> assert full_tensor.shape == dtensor.shape

core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.redistribute_uneven_dtensor_to_replicated(
dtensor: torch.distributed._tensor.DTensor,
) → torch.distributed._tensor.DTensor#

Redistribute an unevenly sharded DTensor to a fully replicated DTensor.

This function first gathers the unevenly sharded DTensor into a full tensor and then redistributes it as a replicated DTensor across all ranks.

Parameters:

dtensor (DTensor) – The unevenly sharded DTensor to redistribute.

Returns:

A replicated DTensor with the same data as the input DTensor.

Return type:

DTensor
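
A minimal sketch, reusing the hypothetical unevenly sharded dtensor from the earlier examples:

  replicated = redistribute_uneven_dtensor_to_replicated(dtensor)
  # Every rank now holds the full tensor as its local shard.
  assert replicated.to_local().shape == dtensor.shape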

core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.gather_uneven_dtensor_to_full_tensor(
dtensor: torch.distributed._tensor.DTensor,
) → torch.distributed._tensor.DTensor#

Deprecated: use redistribute_uneven_dtensor_to_replicated instead.

core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor._intersection(s1, s2)#
core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor._offset_slice(s, offset)#
core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.split_dtensor(
dtensor: torch.distributed._tensor.DTensor,
split_size_or_sections: Union[int, List[int]],
dim: int = 0,
update_uneven_dtensor_chunk_meta: bool = False,
) → Iterable[torch.distributed._tensor.DTensor]#

Splits a DTensor into smaller DTensors along a specified dimension.

This function manages uneven sharding by accurately assigning chunk metadata for each split. Unlike the native PyTorch DTensor split functionality, it does not redistribute Replicate placements, which helps avoid Out-Of-Memory (OOM) issues.

Parameters:
  • dtensor (DTensor) – The DTensor to split.

  • split_size_or_sections (int or list of int) – If int, defines the size of each chunk. If a list, specifies the sizes of each chunk in order.

  • dim (int, optional) – The axis along which to split. Default is 0.

  • update_uneven_dtensor_chunk_meta (bool, optional) – Whether to update chunk metadata for each resulting DTensor. Default is False.

Yields:

DTensor – Sub-DTensor resulting from the split, maintaining correct metadata.

Example:

  >>> for chunk in split_dtensor(dt, 3, dim=1):
  ...     print(chunk)
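
A slightly fuller sketch with explicit section sizes; the dtensor here is the hypothetical (10, 16) DTensor from the earlier examples, so the sizes below are illustrative:

  # Split along dim 0 into sections of sizes 4, 3, and 3, refreshing the
  # uneven-shard chunk metadata of each resulting piece.
  pieces = list(
      split_dtensor(dtensor, [4, 3, 3], dim=0, update_uneven_dtensor_chunk_meta=True)
  )
  assert sum(piece.shape[0] for piece in pieces) == dtensor.shape[0]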