core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor
Module Contents
Functions
| Function | Description |
|---|---|
| gather_and_compute_chunk_metadata | Gather chunk metadata for a DTensor across all ranks and compute the offsets and sizes of each chunk. This is necessary for handling uneven sharding in distributed tensors. |
| update_uneven_dtensor_chunk_metadata | Update the DTensor’s chunk metadata to handle uneven sharding. This function modifies the DTensor in-place to include chunk metadata and write items closures for saving and loading. |
| validate_uneven_dtensor | Validates the chunk metadata of an uneven DTensor to ensure correctness and boundary coverage. |
| filter_unflattened_state_dict | Recursively traverses an unflattened state_dict and collects keys of items that meet the visit_condition. The keys are returned as lists of strings representing the path to each item in the state_dict. |
| get_unflattened_state_dict | Get a value from an unflattened state_dict at the specified key chain. |
| preprocess_state_dict_for_uneven_dtensor | Preprocess the state_dict to prepare it for saving or loading unevenly sharded DTensors. This function modifies the DTensors in the state_dict to include chunk metadata and write items closures. |
| uneven_dtensor_to_full_tensor | Gather a DTensor with potentially uneven sharding across ranks into a full tensor. |
| redistribute_uneven_dtensor_to_replicated | Redistribute an unevenly sharded DTensor to a fully replicated DTensor. |
| gather_uneven_dtensor_to_full_tensor | Deprecated: use redistribute_uneven_dtensor_to_replicated instead. |
| split_dtensor | Splits a DTensor into smaller DTensors along a specified dimension. |
API
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.gather_and_compute_chunk_metadata(dtensor: torch.distributed._tensor.DTensor)
Gather chunk metadata for a DTensor across all ranks and compute the offsets and sizes of each chunk. This is necessary for handling uneven sharding in distributed tensors.
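For a DTensor with a single Shard(dim) placement, each rank's chunk offset along the sharded dimension is the sum of the local sizes held by lower-indexed ranks. The sketch below illustrates only that arithmetic in plain Python. It assumes the per-rank local shapes have already been gathered (for example with torch.distributed.all_gather_object) and that chunks are laid out contiguously in rank order; compute_chunk_offsets is a hypothetical helper, not part of this module.

```python
from typing import List, Tuple


def compute_chunk_offsets(
    local_shapes: List[Tuple[int, ...]], shard_dim: int
) -> List[Tuple[Tuple[int, ...], Tuple[int, ...]]]:
    """Return (offsets, sizes) for each rank's chunk.

    Hypothetical helper: assumes one Shard(shard_dim) placement with chunks
    stored contiguously in rank order.
    """
    chunks = []
    running = 0  # start of the next chunk along shard_dim
    for shape in local_shapes:
        offsets = [0] * len(shape)
        offsets[shard_dim] = running
        chunks.append((tuple(offsets), tuple(shape)))
        running += shape[shard_dim]
    return chunks


# Global shape (10, 4) split unevenly over 4 ranks along dim 0.
gathered = [(3, 4), (3, 4), (3, 4), (1, 4)]
for rank, (offsets, sizes) in enumerate(compute_chunk_offsets(gathered, 0)):
    print(rank, offsets, sizes)
# 0 (0, 0) (3, 4)
# 1 (3, 0) (3, 4)
# 2 (6, 0) (3, 4)
# 3 (9, 0) (1, 4)
```

A running prefix sum is enough here because chunks along a single sharded dimension are disjoint and ordered; the last rank simply ends up with a smaller size.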
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.update_uneven_dtensor_chunk_metadata(dtensor: torch.distributed._tensor.DTensor)
Update the DTensor’s chunk metadata to handle uneven sharding. This function modifies the DTensor in-place to include chunk metadata and write items closures for saving and loading.
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.validate_uneven_dtensor(dtensor: torch.distributed._tensor.DTensor)
Validates the chunk metadata of an uneven DTensor to ensure correctness and boundary coverage.
Notes:
gather_and_compute_chunk_metadata already ensures that chunks do not overlap, so overlap is not checked here.
This function performs the following checks:
All chunk offsets and sizes are within the tensor shape bounds.
All boundaries of each dimension are actually covered by shard placements.
- Parameters:
dtensor (DTensor) – The distributed tensor to validate.
- Raises:
AssertionError – If any chunk falls outside the tensor bounds, or if some dimension boundary is not covered by any chunk.
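The two checks listed above can be sketched on plain (offsets, sizes) tuples. This is an illustrative reimplementation, not the module's code: check_chunks is hypothetical, and "boundary coverage" is interpreted here as "on every dimension, some chunk starts at 0 and some chunk ends at the global size", which is an assumption about what the check means.

```python
from typing import Sequence, Tuple

Chunk = Tuple[Tuple[int, ...], Tuple[int, ...]]  # (offsets, sizes)


def check_chunks(global_shape: Sequence[int], chunks: Sequence[Chunk]) -> None:
    """Assert bounds and boundary coverage for a list of chunks (sketch)."""
    ndim = len(global_shape)
    for offsets, sizes in chunks:
        assert len(offsets) == len(sizes) == ndim, "rank mismatch"
        for d in range(ndim):
            # Each chunk must lie inside [0, global_shape[d]] on every dim.
            assert 0 <= offsets[d] and offsets[d] + sizes[d] <= global_shape[d], (
                f"chunk {offsets}/{sizes} out of bounds on dim {d}"
            )
    for d in range(ndim):
        # Assumed meaning of coverage: some chunk touches each edge of dim d.
        assert any(o[d] == 0 for o, _ in chunks), f"dim {d}: start not covered"
        assert any(o[d] + s[d] == global_shape[d] for o, s in chunks), (
            f"dim {d}: end not covered"
        )


# Passes: two chunks tile a (10, 4) tensor along dim 0.
check_chunks((10, 4), [((0, 0), (3, 4)), ((3, 0), (7, 4))])
```

Note that this sketch does not detect gaps between chunks; like the documented function, it relies on the bounds and edge checks only.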
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.filter_unflattened_state_dict(state_dict, key_chain=[], visit_condition=lambda x: ...)
Recursively traverses an unflattened state_dict and collects keys of items that meet the visit_condition. The keys are returned as lists of strings representing the path to each item in the state_dict.
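A recursive traversal of this kind can be sketched as follows. filter_keys is a hypothetical stand-in, not this module's function: it descends only into nested dicts and tests visit_condition on non-dict leaves, which is an assumption, since the description above does not say whether intermediate dicts or lists are also visited. It uses a None default for key_chain to avoid sharing a mutable default list between calls.

```python
from typing import Any, Callable, Dict, List, Optional


def filter_keys(
    state_dict: Dict[str, Any],
    visit_condition: Callable[[Any], bool],
    key_chain: Optional[List[str]] = None,
) -> List[List[str]]:
    """Return the key path of every leaf for which visit_condition is True."""
    key_chain = [] if key_chain is None else key_chain
    matches: List[List[str]] = []
    for key, value in state_dict.items():
        path = key_chain + [key]
        if isinstance(value, dict):
            # Descend into nested dicts, extending the key path.
            matches.extend(filter_keys(value, visit_condition, path))
        elif visit_condition(value):
            matches.append(path)
    return matches


sd = {"model": {"layer1": {"weight": 1.0, "bias": "skip"}, "step": 3}}
print(filter_keys(sd, lambda v: isinstance(v, (int, float))))
# [['model', 'layer1', 'weight'], ['model', 'step']]
```

In the documented function, visit_condition would typically select DTensor values; the numeric predicate here only keeps the sketch independent of torch.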
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.get_unflattened_state_dict(state_dict, key_chain=[])
Get a value from an unflattened state_dict at the specified key chain.
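Resolving a key chain, such as one returned by filter_unflattened_state_dict, amounts to indexing the nested dicts one key at a time. The sketch below does this with functools.reduce; get_by_key_chain is a hypothetical name, and its behavior on a missing key (raising KeyError) is a property of this sketch, not a documented property of the module's function.

```python
import functools
import operator
from typing import Any, Dict, List


def get_by_key_chain(state_dict: Dict[str, Any], key_chain: List[str]) -> Any:
    """Follow key_chain through nested dicts; an empty chain returns the root."""
    return functools.reduce(operator.getitem, key_chain, state_dict)


sd = {"model": {"layer1": {"weight": 1.0}}}
print(get_by_key_chain(sd, ["model", "layer1", "weight"]))  # 1.0
```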
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.preprocess_state_dict_for_uneven_dtensor(state_dict: dict) → dict
Preprocess the state_dict to prepare it for saving or loading unevenly sharded DTensors. This function modifies the DTensors in the state_dict to include chunk metadata and write items closures.
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.uneven_dtensor_to_full_tensor(dtensor: torch.distributed._tensor.DTensor) → torch.Tensor
Gather a DTensor with potentially uneven sharding across ranks into a full tensor.
This function handles DTensors with uneven shards (where different ranks may have different-sized chunks) by gathering chunk metadata and local tensors across all ranks, then reconstructing the complete tensor.
- Parameters:
dtensor (DTensor) – The distributed tensor to gather. Existing chunk metadata is reused; if it is missing, it is computed and attached to the DTensor.
- Returns:
The fully reconstructed tensor with shape matching the original DTensor’s global shape.
- Return type:
torch.Tensor
- Raises:
TypeError – If input is not a DTensor.
ValueError – If chunk metadata is malformed (expected exactly one chunk per rank).
AssertionError – If an unexpected placement type is encountered after processing Shard placements.
Notes:
This function performs collective operations (all_gather_object, all_gather) across the device mesh, requiring synchronization across ranks.
Works with Shard and _StridedShard placements, and expects Replicate placements for non-sharded dimensions.
The function modifies the DTensor in-place by adding chunk metadata if missing.
Example:
mesh = DeviceMesh("cuda", [0, 1, 2, 3])
# Create a DTensor with uneven sharding
dtensor = DTensor(…, placements=[Shard(0)])
full_tensor = uneven_dtensor_to_full_tensor(dtensor)
assert full_tensor.shape == dtensor.shape
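The reconstruction step of uneven_dtensor_to_full_tensor amounts to placing each rank's local data at its chunk offset in a buffer of the global shape. The pure-Python sketch below shows this for a single Shard(0) placement on a 1-D tensor, with Python lists standing in for tensors and for the results of the all_gather_object and all_gather collectives. assemble_full is a hypothetical name; the documented function additionally handles _StridedShard placements and multi-dimensional meshes.

```python
from typing import List, Sequence, Tuple


def assemble_full(
    global_size: int, gathered: Sequence[Tuple[int, List[float]]]
) -> List[float]:
    """Write each (offset, local_values) chunk into a buffer of global_size."""
    full: List[float] = [0.0] * global_size
    for offset, local in gathered:
        # Each rank's chunk occupies [offset, offset + len(local)).
        full[offset : offset + len(local)] = local
    return full


# 10 elements split unevenly over 4 ranks: sizes 3, 3, 3, 1.
chunks = [(0, [0.0, 1.0, 2.0]), (3, [3.0, 4.0, 5.0]), (6, [6.0, 7.0, 8.0]), (9, [9.0])]
print(assemble_full(10, chunks))
# [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
```

Because chunk sizes differ across ranks, the offsets come from gathered metadata rather than from rank * chunk_size, which is why the metadata step precedes the data gather.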
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.redistribute_uneven_dtensor_to_replicated(dtensor: torch.distributed._tensor.DTensor) → DTensor
Redistribute an unevenly sharded DTensor to a fully replicated DTensor.
This function first gathers the unevenly sharded DTensor into a full tensor and then redistributes it as a replicated DTensor across all ranks.
- Parameters:
dtensor (DTensor) – The unevenly sharded DTensor to redistribute.
- Returns:
A replicated DTensor with the same data as the input DTensor.
- Return type:
DTensor
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.gather_uneven_dtensor_to_full_tensor(dtensor: torch.distributed._tensor.DTensor)
Deprecated: use redistribute_uneven_dtensor_to_replicated instead.
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor._intersection(s1, s2)
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor._offset_slice(s, offset)
- core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor.split_dtensor(dtensor: torch.distributed._tensor.DTensor, split_size_or_sections: Union[int, List[int]], dim: int = 0, update_uneven_dtensor_chunk_meta: bool = False)
Splits a DTensor into smaller DTensors along a specified dimension.
This function handles uneven sharding by assigning accurate chunk metadata to each split. Unlike the native PyTorch DTensor split, it does not redistribute the DTensor to Replicate placements, which helps avoid out-of-memory (OOM) errors.
- Parameters:
dtensor (DTensor) – The DTensor to split.
split_size_or_sections (int or list of int) – If int, defines the size of each chunk. If a list, specifies the sizes of each chunk in order.
dim (int, optional) – The axis along which to split. Default is 0.
update_uneven_dtensor_chunk_meta (bool, optional) – Whether to update chunk metadata for each resulting DTensor. Default is False.
- Yields:
DTensor – Sub-DTensor resulting from the split, maintaining correct metadata.
Example:
for chunk in split_dtensor(dt, 3, dim=1):
    print(chunk)
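Splitting without gathering works because each output piece covers a global range along dim, and a rank only needs the intersection of that range with its own shard. The sketch below computes, for one rank, which local rows feed each split and where they land inside that split. It assumes a single Shard(dim) placement, follows torch.split's convention for an integer split_size_or_sections (the last piece may be smaller), and uses the hypothetical helpers split_ranges, intersect, and local_pieces; it illustrates the range arithmetic only and is not the module's implementation.

```python
from typing import List, Optional, Tuple, Union

Range = Tuple[int, int]  # half-open [start, stop)


def split_ranges(total: int, split_size_or_sections: Union[int, List[int]]) -> List[Range]:
    """Global [start, stop) range of each split, following torch.split semantics."""
    if isinstance(split_size_or_sections, int):
        step = split_size_or_sections
        sizes = [min(step, total - s) for s in range(0, total, step)]
    else:
        sizes = list(split_size_or_sections)
        assert sum(sizes) == total, "sections must sum to the dimension size"
    ranges, start = [], 0
    for size in sizes:
        ranges.append((start, start + size))
        start += size
    return ranges


def intersect(a: Range, b: Range) -> Optional[Range]:
    """Overlap of two half-open ranges, or None if they are disjoint."""
    start, stop = max(a[0], b[0]), min(a[1], b[1])
    return (start, stop) if start < stop else None


def local_pieces(shard: Range, splits: List[Range]) -> List[Optional[Tuple[slice, int]]]:
    """For each split: (slice into this rank's local data, offset inside the split)."""
    pieces: List[Optional[Tuple[slice, int]]] = []
    for split in splits:
        overlap = intersect(shard, split)
        if overlap is None:
            pieces.append(None)  # this rank holds no data for this split
        else:
            local = slice(overlap[0] - shard[0], overlap[1] - shard[0])
            pieces.append((local, overlap[0] - split[0]))
    return pieces


# Dim of size 10 sharded as [0,3) [3,6) [6,9) [9,10); split into pieces of 4.
splits = split_ranges(10, 4)  # [(0, 4), (4, 8), (8, 10)]
print(local_pieces((3, 6), splits))  # view of the rank holding rows [3, 6)
# [(slice(0, 1, None), 3), (slice(1, 3, None), 0), None]
```

Each rank only slices data it already holds, so no rank materializes more than its own shard, which is the memory property the description above emphasizes.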