core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor
Module Contents
Functions
gather_and_compute_chunk_metadata – Gather chunk metadata for a DTensor across all ranks and compute the offsets and sizes of each chunk. This is necessary for handling uneven sharding in distributed tensors.
update_uneven_dtensor_chunk_metadata – Update the DTensor’s chunk metadata to handle uneven sharding. This function modifies the DTensor in-place to include chunk metadata and write items closures for saving and loading.
validate_uneven_dtensor – Validates the chunk metadata of an uneven DTensor to ensure correctness and boundary coverage.
filter_unflattened_state_dict – Recursively traverses an unflattened state_dict and collects keys of items that meet the visit_condition. The keys are returned as lists of strings representing the path to each item in the state_dict.
get_unflattened_state_dict – Get a value from an unflattened state_dict at the specified key chain.
preprocess_state_dict_for_uneven_dtensor – Preprocess the state_dict to prepare it for saving or loading unevenly sharded DTensors. This function modifies the DTensors in the state_dict to include chunk metadata and write items closures.
gather_uneven_dtensor_to_full_tensor – Gather an unevenly sharded DTensor distributed across multiple ranks, reconstructing the full (unsharded) tensor on each rank.
_assemble_full_tensor_from_uneven_chunks – Assemble the full tensor from unevenly sized chunks gathered from all ranks.
split_dtensor – Splits a DTensor into smaller DTensors along a specified dimension.
API
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.gather_and_compute_chunk_metadata(
- dtensor: torch.distributed._tensor.DTensor,
Gather chunk metadata for a DTensor across all ranks and compute the offsets and sizes of each chunk. This is necessary for handling uneven sharding in distributed tensors.
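To make the offset computation concrete, here is a minimal single-process sketch. It assumes shards are laid out contiguously in rank order along one dimension and that the per-rank sizes have already been collected; the helper name `compute_chunk_offsets` and the record keys are hypothetical and are not part of this module, which operates on a DTensor and uses collective communication.

```python
from itertools import accumulate
from typing import Dict, List


def compute_chunk_offsets(local_sizes: List[int]) -> List[Dict[str, int]]:
    """Turn per-rank shard lengths into offset/size records (hypothetical helper)."""
    # Exclusive prefix sum: rank r starts where ranks 0..r-1 end.
    offsets = [0] + list(accumulate(local_sizes))[:-1]
    return [
        {"rank": rank, "offset": offset, "size": size}
        for rank, (offset, size) in enumerate(zip(offsets, local_sizes))
    ]


# Hypothetical example: 10 rows sharded unevenly over 4 ranks; the last rank is empty.
metadata = compute_chunk_offsets([4, 3, 3, 0])
for record in metadata:
    print(record)
```

An empty shard still receives an offset (the end of the dimension), which keeps one record per rank even when a rank holds no data.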
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.update_uneven_dtensor_chunk_metadata(
- dtensor: torch.distributed._tensor.DTensor,
Update the DTensor’s chunk metadata to handle uneven sharding. This function modifies the DTensor in-place to include chunk metadata and write items closures for saving and loading.
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.validate_uneven_dtensor(
- dtensor: torch.distributed._tensor.DTensor,
Validates the chunk metadata of an uneven DTensor to ensure correctness and boundary coverage.
Notes:
gather_and_compute_chunk_metadata ensures that chunks do not overlap.
This function performs the following checks:
All chunk offsets and sizes are within the tensor shape bounds.
All boundaries of each dimension are actually covered by shard placements.
- Parameters:
dtensor (DTensor) – The distributed tensor to validate.
- Raises:
AssertionError – If any chunk falls out of bounds or not all boundaries are touched.
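The two checks listed above can be modeled on plain tuples. This is only a sketch of one reading of "boundary coverage" (some chunk starts at index 0 and some chunk ends at the full extent of every dimension); the function name `validate_chunks` and the `(offsets, sizes)` representation are assumptions for illustration, not the module's internal format.

```python
from typing import List, Sequence, Tuple

# One chunk is (offsets, sizes), each with one entry per tensor dimension.
Chunk = Tuple[Sequence[int], Sequence[int]]


def validate_chunks(shape: Sequence[int], chunks: List[Chunk]) -> None:
    """Raise AssertionError if a chunk is out of bounds or a boundary is untouched."""
    # Check 1: every chunk lies inside the global tensor shape.
    for offsets, sizes in chunks:
        for dim, (offset, size) in enumerate(zip(offsets, sizes)):
            assert 0 <= offset and offset + size <= shape[dim], (
                f"chunk out of bounds on dim {dim}: offset={offset}, size={size}"
            )
    # Check 2: on every dimension, some chunk touches index 0 and some touches the end.
    for dim, extent in enumerate(shape):
        starts = {offsets[dim] for offsets, _ in chunks}
        ends = {offsets[dim] + sizes[dim] for offsets, sizes in chunks}
        assert 0 in starts, f"lower boundary of dim {dim} is not covered"
        assert extent in ends, f"upper boundary of dim {dim} is not covered"


# Hypothetical example: a (10, 4) tensor split into row blocks of 6 and 4.
validate_chunks((10, 4), [((0, 0), (6, 4)), ((6, 0), (4, 4))])
```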
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.filter_unflattened_state_dict(
- state_dict,
- key_chain=[],
- visit_condition=lambda x: ...,
Recursively traverses an unflattened state_dict and collects keys of items that meet the visit_condition. The keys are returned as lists of strings representing the path to each item in the state_dict.
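The traversal described here can be sketched as follows. This is an illustrative stand-in, not the module's code: the name `collect_key_chains` is hypothetical, and it assumes only nested dicts are descended into and that keys are already strings.

```python
from typing import Any, Callable, Dict, List, Tuple


def collect_key_chains(
    tree: Dict[str, Any],
    visit_condition: Callable[[Any], bool],
    key_chain: Tuple[str, ...] = (),
) -> List[List[str]]:
    """Return the key path of every item for which visit_condition is true."""
    found: List[List[str]] = []
    for key, value in tree.items():
        path = key_chain + (key,)
        if visit_condition(value):
            found.append(list(path))
        elif isinstance(value, dict):
            # Descend into nested dicts, carrying the path accumulated so far.
            found.extend(collect_key_chains(value, visit_condition, path))
    return found


# Hypothetical state_dict; floats stand in for the tensors of interest.
state = {"model": {"layer.weight": 1.0, "layer.bias": 2.0}, "optimizer": {"step": 3}}
float_paths = collect_key_chains(state, lambda v: isinstance(v, float))
print(float_paths)
```

The sketch uses an immutable tuple default for `key_chain` so that recursive calls cannot share or mutate a default list.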
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.get_unflattened_state_dict(state_dict, key_chain=[])
Get a value from an unflattened state_dict at the specified key chain.
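A key-chain lookup of this kind amounts to indexing one level per key. The sketch below uses a hypothetical helper name, `get_by_key_chain`, and assumes a plain nested dict; it is not the module's implementation.

```python
from functools import reduce
from typing import Any, Dict, Sequence


def get_by_key_chain(tree: Dict[str, Any], key_chain: Sequence[str]) -> Any:
    """Follow key_chain one level at a time; an empty chain returns the tree itself."""
    return reduce(lambda node, key: node[key], key_chain, tree)


# Hypothetical nested state_dict.
state = {"model": {"layer.weight": 1.0}, "optimizer": {"step": 3}}
value = get_by_key_chain(state, ["model", "layer.weight"])
print(value)
```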
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.preprocess_state_dict_for_uneven_dtensor(state_dict: dict) → dict
Preprocess the state_dict to prepare it for saving or loading unevenly sharded DTensors. This function modifies the DTensors in the state_dict to include chunk metadata and write items closures.
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.gather_uneven_dtensor_to_full_tensor(
- dtensor: torch.distributed._tensor.DTensor,
- target_device: Optional[torch.device] = None,
Gather an unevenly sharded DTensor distributed across multiple ranks, reconstructing the full (unsharded) tensor on each rank.
This function handles uneven chunk sizes and offsets by collecting chunk metadata from all ranks, performing all-gather operations, and assembling the full tensor accordingly. The returned tensor is fully replicated across the given device mesh.
- Parameters:
dtensor (DTensor) – Distributed tensor with uneven sharding across ranks.
target_device (Optional[torch.device]) – If specified, move the resulting full tensor to this device. Otherwise, use the original device.
- Returns:
Fully replicated DTensor representing the reconstructed full tensor.
- Return type:
DTensor
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor._assemble_full_tensor_from_uneven_chunks(
- dtensor: torch.distributed._tensor.DTensor,
- all_chunk_info: List[dict],
- process_group: torch.distributed.ProcessGroup,
- target_device: Optional[torch.device],
Assemble the full tensor from unevenly sized chunks gathered from all ranks.
- Parameters:
dtensor (DTensor) – The original distributed tensor.
all_chunk_info (List[Dict]) – List of shard info dicts from all ranks, including shapes and offsets.
process_group – Process group for collective communication.
target_device – Optional device to move the final full tensor onto.
- Returns:
Fully replicated tensor constructed by placing chunks at the appropriate offsets.
- Return type:
DTensor
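The placement step, copying each gathered chunk to its offset in the full tensor, can be illustrated without any distributed setup. The sketch below uses nested Python lists in place of tensors and a hypothetical chunk record with `offset` and `data` keys; the real function works on torch tensors and gathers chunks over a process group.

```python
from typing import Any, Dict, List, Tuple


def assemble_full(shape: Tuple[int, int], chunks: List[Dict[str, Any]]) -> List[List[Any]]:
    """Place 2-D chunks of differing sizes into a full (rows, cols) grid."""
    rows, cols = shape
    full: List[List[Any]] = [[None] * cols for _ in range(rows)]
    for chunk in chunks:
        row0, col0 = chunk["offset"]
        # Copy the chunk element by element to its global position.
        for i, row in enumerate(chunk["data"]):
            for j, value in enumerate(row):
                full[row0 + i][col0 + j] = value
    return full


# Hypothetical example: a (3, 2) tensor where rank 0 holds two rows and rank 1 holds one.
chunks = [
    {"offset": (0, 0), "data": [[1, 2], [3, 4]]},
    {"offset": (2, 0), "data": [[5, 6]]},
]
full = assemble_full((3, 2), chunks)
print(full)
```

Any cell still set to `None` after assembly would indicate a region that no chunk covered.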
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor._intersection(s1, s2)
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor._offset_slice(s, offset)
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.split_dtensor(
- dtensor: torch.distributed._tensor.DTensor,
- split_size_or_sections: Union[int, List[int]],
- dim: int = 0,
- update_uneven_dtensor_chunk_meta: bool = False,
Splits a DTensor into smaller DTensors along a specified dimension.
This function manages uneven sharding by accurately assigning chunk metadata for each split. Unlike the native PyTorch DTensor split functionality, it does not redistribute Replicate placements, which helps avoid out-of-memory (OOM) issues.
- Parameters:
dtensor (DTensor) – The DTensor to split.
split_size_or_sections (int or list of int) – If int, defines the size of each chunk. If a list, specifies the sizes of each chunk in order.
dim (int, optional) – The axis along which to split. Default is 0.
update_uneven_dtensor_chunk_meta (bool, optional) – Whether to update chunk metadata for each resulting DTensor. Default is False.
- Yields:
DTensor – Sub-DTensor resulting from the split, maintaining correct metadata.
- Example:
for chunk in split_dtensor(dt, 3, dim=1):
    print(chunk)
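To show how a split interacts with an uneven local shard, the following sketch works on index ranges only. It assumes `split_size_or_sections` follows the same convention as `torch.split` (an int gives equal pieces with a possibly smaller last piece; a list gives explicit sizes). The helpers `split_ranges` and `intersect` are hypothetical illustrations and are not the module's `_intersection` or `_offset_slice`.

```python
from typing import List, Optional, Tuple, Union

Range = Tuple[int, int]  # half-open interval [start, stop)


def split_ranges(length: int, split_size_or_sections: Union[int, List[int]]) -> List[Range]:
    """Global index range of each piece along the split dimension."""
    if isinstance(split_size_or_sections, int):
        size = split_size_or_sections
        # Equal pieces; the last one is smaller if length is not divisible by size.
        sizes = [min(size, length - start) for start in range(0, length, size)]
    else:
        sizes = list(split_size_or_sections)
        assert sum(sizes) == length, "section sizes must sum to the dimension length"
    ranges, start = [], 0
    for size in sizes:
        ranges.append((start, start + size))
        start += size
    return ranges


def intersect(a: Range, b: Range) -> Optional[Range]:
    """Overlap of two half-open ranges, or None if they do not overlap."""
    lo, hi = max(a[0], b[0]), min(a[1], b[1])
    return (lo, hi) if lo < hi else None


# Hypothetical case: a dimension of length 10, and this rank owns rows [4, 7).
local = (4, 7)
pieces = [intersect(piece, local) for piece in split_ranges(10, 3)]
print(pieces)
```

Each piece's overlap with the local range is the part of that sub-tensor this rank would hold; pieces with no overlap correspond to empty local shards.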