core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor#

Module Contents#

Functions#

gather_and_compute_chunk_metadata

Gather chunk metadata for a DTensor across all ranks and compute the offsets and sizes of each chunk. This is necessary for handling uneven sharding in distributed tensors.

update_uneven_dtensor_chunk_metadata

Update the DTensor’s chunk metadata to handle uneven sharding. This function modifies the DTensor in place to include chunk metadata and write-item closures for saving and loading.

validate_uneven_dtensor

Validates the chunk metadata of an uneven DTensor to ensure correctness and boundary coverage.

filter_unflattened_state_dict

Recursively traverses an unflattened state_dict and collects keys of items that meet the visit_condition. The keys are returned as lists of strings representing the path to each item in the state_dict.

get_unflattened_state_dict

Get a value from an unflattened state_dict at the specified key chain.

preprocess_state_dict_for_uneven_dtensor

Preprocess the state_dict to prepare it for saving or loading unevenly sharded DTensors. This function modifies the DTensors in the state_dict to include chunk metadata and write-item closures.

gather_uneven_dtensor_to_full_tensor

Gather an unevenly sharded DTensor distributed across multiple ranks, reconstructing the full (unsharded) tensor on each rank.

_assemble_full_tensor_from_uneven_chunks

Assemble the full tensor from unevenly sized chunks gathered from all ranks.

_intersection

_offset_slice

split_dtensor

Splits a DTensor into smaller DTensors along a specified dimension.

API#

core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.gather_and_compute_chunk_metadata(
dtensor: torch.distributed._tensor.DTensor,
) → torch.distributed.checkpoint.metadata.ChunkStorageMetadata#

Gather chunk metadata for a DTensor across all ranks and compute the offsets and sizes of each chunk. This is necessary for handling uneven sharding in distributed tensors.
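
Example (a minimal sketch, not taken from this page: it assumes torch.distributed is already initialized, e.g. via torchrun, and that the module is importable as megatron_fsdp.uneven_dtensor):

import torch
import torch.distributed as dist
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh
from megatron_fsdp.uneven_dtensor import gather_and_compute_chunk_metadata

# Shard a tensor whose first dimension does not divide evenly across ranks,
# so different ranks hold chunks of different sizes.
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
dtensor = distribute_tensor(torch.randn(10, 4), mesh, placements=[Shard(0)])

# Gather every rank's chunk offsets/sizes; the returned ChunkStorageMetadata is
# assumed here to describe the local rank's chunk.
chunk_meta = gather_and_compute_chunk_metadata(dtensor)
print(chunk_meta.offsets, chunk_meta.sizes)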

core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.update_uneven_dtensor_chunk_metadata(
dtensor: torch.distributed._tensor.DTensor,
) → dict#

Update the DTensor’s chunk metadata to handle uneven sharding. This function modifies the DTensor in place to include chunk metadata and write-item closures for saving and loading.
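
Example (same import-path assumption as above; dtensor is assumed to be an unevenly sharded DTensor such as the one built in the previous sketch):

from megatron_fsdp.uneven_dtensor import update_uneven_dtensor_chunk_metadata

# Attaches chunk metadata and the save/load write-item closures to the DTensor
# in place, and also returns a dict (its exact contents are an implementation detail).
chunk_info = update_uneven_dtensor_chunk_metadata(dtensor)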

core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.validate_uneven_dtensor(
dtensor: torch.distributed._tensor.DTensor,
) → None#

Validates the chunk metadata of an uneven DTensor to ensure correctness and boundary coverage.

Notes:

  • gather_and_compute_chunk_metadata ensures that no chunks overlap.

This function performs the following checks:

  • All chunk offsets and sizes are within the tensor shape bounds.

  • All boundaries of each dimension are actually covered by shard placements.

Parameters:

dtensor (DTensor) – The distributed tensor to validate.

Raises:

AssertionError – If any chunk falls out of bounds or any boundary is not covered by a shard.
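
Example (a minimal sketch; assumes the DTensor already carries chunk metadata, e.g. set by update_uneven_dtensor_chunk_metadata, and the same import-path assumption as above):

from megatron_fsdp.uneven_dtensor import validate_uneven_dtensor

# Raises AssertionError if any chunk lies outside the tensor shape or a
# dimension boundary is not covered by a shard; returns None on success.
validate_uneven_dtensor(dtensor)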

core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.filter_unflattened_state_dict(
state_dict,
key_chain=[],
visit_condition=lambda x: ...,
)#

Recursively traverses an unflattened state_dict and collects keys of items that meet the visit_condition. The keys are returned as lists of strings representing the path to each item in the state_dict.
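
Example (an illustrative sketch; the nested state_dict and its key names are hypothetical, and the import path is the same assumption as above):

import torch
from torch.distributed._tensor import DTensor
from megatron_fsdp.uneven_dtensor import filter_unflattened_state_dict

state_dict = {
    "model": {"layer1": {"weight": dtensor, "bias": torch.zeros(4)}},
    "step": 100,
}

# Collect the key chains of all DTensor leaves,
# e.g. [["model", "layer1", "weight"]].
dtensor_keys = filter_unflattened_state_dict(
    state_dict, visit_condition=lambda value: isinstance(value, DTensor)
)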

core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.get_unflattened_state_dict(state_dict, key_chain=[])#

Get a value from an unflattened state_dict at the specified key chain.
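
Example (continuing the hypothetical state_dict from the previous sketch):

from megatron_fsdp.uneven_dtensor import get_unflattened_state_dict

# Follows state_dict["model"]["layer1"]["weight"] and returns that value.
weight = get_unflattened_state_dict(state_dict, key_chain=["model", "layer1", "weight"])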

core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.preprocess_state_dict_for_uneven_dtensor(state_dict: dict) → dict#

Preprocess the state_dict to prepare it for saving or loading unevenly sharded DTensors. This function modifies the DTensors in the state_dict to include chunk metadata and write-item closures.
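
Example (a hedged sketch of a possible call site; pairing this with torch.distributed.checkpoint is an assumption, not something this docstring states):

import torch.distributed.checkpoint as dcp
from megatron_fsdp.uneven_dtensor import preprocess_state_dict_for_uneven_dtensor

# Attach chunk metadata and write-item closures to every unevenly sharded
# DTensor in the state_dict before checkpointing.
state_dict = preprocess_state_dict_for_uneven_dtensor(state_dict)

# Hypothetical follow-up: save with PyTorch distributed checkpointing.
dcp.save(state_dict, checkpoint_id="/tmp/ckpt")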

core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.gather_uneven_dtensor_to_full_tensor(
dtensor: torch.distributed._tensor.DTensor,
target_device: Optional[torch.device] = None,
) → torch.distributed._tensor.DTensor#

Gather an unevenly sharded DTensor distributed across multiple ranks, reconstructing the full (unsharded) tensor on each rank.

This function handles uneven chunk sizes and offsets by collecting chunk metadata from all ranks, performing all-gather operations, and assembling the full tensor accordingly. The returned tensor is fully replicated across the given device mesh.

Parameters:
  • dtensor (DTensor) – Distributed tensor with uneven sharding across ranks.

  • target_device (Optional[torch.device]) – If specified, move the resulting full tensor to this device. Otherwise, use the original device.

Returns:

Fully replicated DTensor representing the reconstructed full tensor.

Return type:

DTensor
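
Example (a minimal sketch; dtensor is assumed to be an unevenly sharded DTensor, and the import path is the same assumption as above):

import torch
from megatron_fsdp.uneven_dtensor import gather_uneven_dtensor_to_full_tensor

# Returns a DTensor replicated over the original device mesh; to_local()
# then yields the full, unsharded tensor on every rank.
full_dtensor = gather_uneven_dtensor_to_full_tensor(dtensor, target_device=torch.device("cpu"))
full_tensor = full_dtensor.to_local()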

core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor._assemble_full_tensor_from_uneven_chunks(
dtensor: torch.distributed._tensor.DTensor,
all_chunk_info: List[dict],
process_group: torch.distributed.ProcessGroup,
target_device: Optional[torch.device],
) → torch.distributed._tensor.DTensor#

Assemble the full tensor from unevenly sized chunks gathered from all ranks.

Parameters:
  • dtensor (DTensor) – The original distributed tensor.

  • all_chunk_info (List[Dict]) – List of shard info dicts from all ranks, including shapes and offsets.

  • process_group – Process group for collective communication.

  • target_device – Optional device to move the final full tensor onto.

Returns:

Fully replicated tensor constructed by placing chunks at the appropriate offsets.

Return type:

DTensor

core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor._intersection(s1, s2)#

core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor._offset_slice(s, offset)#
core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.split_dtensor(
dtensor: torch.distributed._tensor.DTensor,
split_size_or_sections: Union[int, List[int]],
dim: int = 0,
update_uneven_dtensor_chunk_meta: bool = False,
) → Iterable[torch.distributed._tensor.DTensor]#

Splits a DTensor into smaller DTensors along a specified dimension.

This function manages uneven sharding by accurately assigning chunk metadata for each split. Unlike the native PyTorch DTensor split functionality, it does not redistribute Replicate placements, which helps avoid Out-Of-Memory (OOM) issues.

Parameters:
  • dtensor (DTensor) – The DTensor to split.

  • split_size_or_sections (int or list of int) – If int, defines the size of each chunk. If a list, specifies the sizes of each chunk in order.

  • dim (int, optional) – The axis along which to split. Default is 0.

  • update_uneven_dtensor_chunk_meta (bool, optional) – Whether to update chunk metadata for each resulting DTensor. Default is False.

Yields:

DTensor – Sub-DTensor resulting from the split, maintaining correct metadata.

Example

>>> for chunk in split_dtensor(dt, 3, dim=1):
...     print(chunk)