PhysicsNeMo domain_parallel
In scientific AI applications, the parallelization techniques needed for state-of-the-art
models differ from those used to train large language models. PhysicsNeMo
introduces a new parallelization primitive, ShardTensor, designed to enable domain
parallelization for AI applications with large inputs.
ShardTensor provides a distributed tensor implementation that supports uneven sharding across devices.
It builds on PyTorch’s DTensor while adding flexibility for cases where different ranks may have
different local tensor sizes.
ShardTensor
class physicsnemo.domain_parallel.ShardTensor(
    local_tensor: Tensor,
    spec: ShardTensorSpec,
    *,
    requires_grad: bool,
)
Bases: DTensor
A distributed tensor class with support for uneven data sharding.
It is similar to PyTorch’s native DTensor, but more flexible about uneven data
sharding. It uses an API very similar to DTensor’s (identical where possible)
but deliberately adjusts routines to avoid implicit assumptions about how
tensors are sharded.
The key differences from DTensor are:
- Supports uneven sharding, where different ranks can have different local tensor sizes
- Tracks and propagates shard size information across operations
- Handles redistribution of unevenly sharded tensors
- Provides custom collective operations optimized for uneven sharding
Like DTensor, operations are dispatched through PyTorch’s dispatcher
system. Most operations work by:
- Converting inputs to local tensors
- Performing the operation locally
- Constructing a new ShardTensor with the appropriate sharding spec
- Handling any needed communication between ranks
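The four-step pattern above can be illustrated with a toy, single-process simulation. This is not PhysicsNeMo code: each "rank" is just a Python list, and the helpers elementwise_op and global_sum are hypothetical. An elementwise operation runs purely locally and preserves every rank's shard size, while a reduction needs a communication step, simulated here by summing per-rank partial results (the role an all-reduce would play in a real distributed run).

```python
from typing import Callable, List

# Each inner list stands in for one rank's local tensor (uneven sizes).
LocalShards = List[List[float]]


def elementwise_op(local_shards: LocalShards, fn: Callable[[float], float]) -> LocalShards:
    """Apply fn locally on every rank: no communication, shard sizes preserved."""
    return [[fn(x) for x in shard] for shard in local_shards]


def global_sum(local_shards: LocalShards) -> float:
    """Reduce locally, then combine the partial sums (a simulated all-reduce)."""
    partial_sums = [sum(shard) for shard in local_shards]  # local computation
    return sum(partial_sums)  # communication between ranks


shards = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]
doubled = elementwise_op(shards, lambda x: 2 * x)
print([len(s) for s in doubled])  # [3, 2, 1]
print(global_sum(doubled))  # 42.0
```

Only the reduction touches data owned by other ranks, which is why a sharded tensor needs per-operation rules about when communication is required.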
The class provides methods for:
- Converting to/from local tensors
- Redistributing between different sharding schemes
- Performing collective operations like all_gather and reduce_scatter
- Basic tensor operations that maintain sharding information
_local_tensor
The local tensor data on this rank.
Type: torch.Tensor
_spec
The specification defining the sharding scheme and metadata.
Type: ShardTensorSpec
backward(*args, **kwargs)
Perform backward pass for ShardTensor.
Handles the redistribution of the tensor to resolve any partial
placements before calling backward on the local tensor.
classmethod from_dtensor(dtensor: DTensor) → ShardTensor
Convert a DTensor to a ShardTensor.
Assumes the DTensor is properly constructed. Because DTensor always shards
a tensor in torch.chunk format, the shard sizes can be inferred without
any communication.
If the DTensor is a non-leaf (has a grad_fn), the autograd graph
is preserved via _PromoteDTensorToShardTensor.
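Communication-free inference is possible because chunk-style sizes depend only on the global size of the sharded dimension and the number of shards. The pure-Python helper below (chunk_shard_sizes is a hypothetical name, not part of PhysicsNeMo) follows the torch.chunk rule: each shard gets ceil(n / k) elements until the data runs out. When torch.chunk would return fewer than k pieces, this sketch assumes the trailing ranks hold empty shards, which is our reading of how DTensor's Shard placement handles that case.

```python
from typing import List


def chunk_shard_sizes(dim_size: int, num_shards: int) -> List[int]:
    """Per-rank sizes of a dimension split in torch.chunk style.

    Trailing ranks receive size 0 when torch.chunk would produce fewer
    than num_shards pieces (e.g. 5 elements over 4 ranks).
    """
    if num_shards <= 0:
        raise ValueError("num_shards must be positive")
    chunk = -(-dim_size // num_shards)  # ceiling division
    sizes = []
    remaining = dim_size
    for _ in range(num_shards):
        size = min(chunk, remaining)
        sizes.append(size)
        remaining -= size
    return sizes


print(chunk_shard_sizes(10, 4))  # [3, 3, 3, 1]
print(chunk_shard_sizes(5, 4))   # [2, 2, 1, 0]
```

Every rank can evaluate this locally from the global shape, so no rank needs to ask its neighbors for their sizes.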
Parameters:
dtensor (DTensor) – DTensor to convert.
Returns:
Equivalent ShardTensor with the same local tensor and inferred spec.
Return type:
ShardTensor
static from_local(
    local_tensor: Tensor,
    device_mesh: DeviceMesh | None = None,
    placements: Sequence[Placement] | None = None,
    sharding_shapes: str | dict[int, list[tuple[int, ...]]] = 'infer',
) → ShardTensor
Generate a new ShardTensor from local torch tensors.
Uses the device mesh and placements to infer global tensor properties.
Local tensors are not required to have equal shapes across ranks; the only
requirement is that they can be concatenated into a single tensor according
to the placements.
Parameters:
local_tensor (torch.Tensor) – Local chunk of the tensor. All participating tensors must have the
same number of dimensions and be concatenatable across the mesh dimensions.
device_mesh (Optional[DeviceMesh], optional) – Target device mesh. If not specified, the current mesh is used.
placements (Optional[Sequence[Placement]], optional) – Target placements. Must have the same number of elements as
device_mesh.ndim.
sharding_shapes (Union[str, Dict[int, List[Tuple[int, ...]]]], default="infer") –
Controls how the ShardTensorSpec is generated:
"chunk": Use torch.chunk shapes to infer shard shapes from the
global shape (no communication).
"infer": Use collective communication to infer shard shapes from
mesh neighbors.
Manual dict mapping each mesh dimension to a list of shard shapes: Use the
provided shapes. Must be passed on every rank.
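To make the manual sharding_shapes option and the concatenation requirement concrete, the sketch below assumes a 1-D mesh of three ranks that shards tensor dimension 0, and assumes each list entry is the full local shape of one rank's shard, in mesh order. The helper global_shape_from_shards is hypothetical; it only checks the rule stated for from_local: shard shapes must agree on every dimension except the sharded one, and the sharded sizes add up to the global size.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]

# Assumed layout: mesh dim 0 -> one local shape per rank, in mesh order.
# Rank 0 holds 4 rows, rank 1 holds 2 rows, rank 2 holds 3 rows.
sharding_shapes: Dict[int, List[Shape]] = {0: [(4, 8), (2, 8), (3, 8)]}


def global_shape_from_shards(shapes: List[Shape], tensor_dim: int) -> Shape:
    """Concatenate shard shapes along tensor_dim, validating the other dims."""
    reference = shapes[0]
    for shape in shapes:
        if len(shape) != len(reference):
            raise ValueError("all shards must have the same number of dimensions")
        for d, (a, b) in enumerate(zip(shape, reference)):
            if d != tensor_dim and a != b:
                raise ValueError(f"shards disagree on non-sharded dim {d}")
    total = sum(shape[tensor_dim] for shape in shapes)
    return reference[:tensor_dim] + (total,) + reference[tensor_dim + 1:]


print(global_shape_from_shards(sharding_shapes[0], tensor_dim=0))  # (9, 8)
```

Note that the uneven split [4, 2, 3] cannot be produced by torch.chunk, which is exactly the kind of layout ShardTensor is meant to represent.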
Returns:
A new ShardTensor instance.
Return type:
ShardTensor
full_tensor(*, grad_placements: Sequence[Placement] | None = None) → Tensor
Gather the full tensor from all ranks.
Redistributes to Replicate placement on all mesh dimensions and
returns the local tensor.
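Gathering unevenly sharded data needs care, because standard all-gather collectives expect every rank to contribute a buffer of the same size. One common workaround, shown here as a single-process simulation and not necessarily what PhysicsNeMo does internally, is to pad every shard to the largest shard size, gather the equal-size buffers, then trim the padding using the known per-rank sizes. padded_all_gather is a hypothetical helper for 1-D shards.

```python
from typing import List, Sequence


def padded_all_gather(local_shards: Sequence[List[float]], pad_value: float = 0.0) -> List[float]:
    """Simulate an all-gather of uneven 1-D shards via pad-to-max, then trim."""
    sizes = [len(shard) for shard in local_shards]  # known from the shard spec
    max_size = max(sizes)
    # Every rank contributes a buffer of identical length, as all_gather requires.
    padded = [shard + [pad_value] * (max_size - len(shard)) for shard in local_shards]
    # After the collective, each rank holds all padded buffers; drop the padding.
    gathered: List[float] = []
    for buffer, size in zip(padded, sizes):
        gathered.extend(buffer[:size])
    return gathered


print(padded_all_gather([[1.0, 2.0, 3.0], [4.0], [5.0, 6.0]]))
# [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
```

Because the per-rank sizes travel with the sharding spec, the trimming step needs no extra communication.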
Parameters:
grad_placements (Optional[Sequence[Placement]], optional) – Future layout of gradients. If provided, gradients will be
constructed with this placement scheme during the backward pass.
Returns:
The full gathered tensor, identical on all ranks.
Return type:
torch.Tensor
offsets(mesh_dim: int | None = None) → list[int] | int
Get the offsets of shards along a mesh dimension.
Parameters:
mesh_dim (Optional[int], optional) – Mesh dimension to get offsets for. If None, returns offsets for all mesh dimensions.
Returns:
List of shard offsets along all mesh dimensions, or a single offset
if mesh_dim is specified.
Return type:
Union[List[int], int]
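Along a single sharded dimension, a shard's offset is the sum of the sizes of all shards that precede it on the mesh, i.e. an exclusive prefix sum of the shard sizes. The snippet below computes this for the uneven sizes [3, 3, 3, 1]; shard_offsets is a hypothetical helper illustrating the arithmetic, not the implementation of offsets.

```python
from itertools import accumulate
from typing import List


def shard_offsets(shard_sizes: List[int]) -> List[int]:
    """Exclusive prefix sum: starting index of each shard along the sharded dim."""
    return list(accumulate(shard_sizes, initial=0))[:-1]


print(shard_offsets([3, 3, 3, 1]))  # [0, 3, 6, 9]
```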
classmethod patches_enabled() → bool
Check whether patches are enabled for this class.
Returns:
True if shard patches are enabled, False otherwise.
Default is False until a ShardTensor is constructed.
Return type:
bool
redistribute(
    device_mesh: DeviceMesh | None = None,
    placements: Sequence[Placement] | None = None,
    *,
    async_op: bool = False,
) → ShardTensor
Redistribute the tensor across the device mesh with a new placement scheme.
Like DTensor.redistribute, but uses a custom redistribution layer that
supports uneven sharding.
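Redistributing between two uneven layouts of the same dimension amounts to intersecting index ranges: each source rank sends every target rank the part of its shard that overlaps the target's range. The sketch below computes that transfer plan in pure Python. transfer_plan is a hypothetical helper that shows the bookkeeping only; the communication pattern actually used by redistribute may differ.

```python
from itertools import accumulate
from typing import List, Tuple

# (src_rank, dst_rank, start, stop) in global indices along the sharded dim.
Transfer = Tuple[int, int, int, int]


def _ranges(sizes: List[int]) -> List[Tuple[int, int]]:
    """Convert shard sizes to half-open [start, stop) global ranges."""
    starts = list(accumulate(sizes, initial=0))
    return list(zip(starts[:-1], starts[1:]))


def transfer_plan(src_sizes: List[int], dst_sizes: List[int]) -> List[Transfer]:
    """List every non-empty overlap between source and target shards."""
    if sum(src_sizes) != sum(dst_sizes):
        raise ValueError("source and target layouts must cover the same extent")
    plan: List[Transfer] = []
    for src, (s0, s1) in enumerate(_ranges(src_sizes)):
        for dst, (d0, d1) in enumerate(_ranges(dst_sizes)):
            start, stop = max(s0, d0), min(s1, d1)
            if start < stop:  # non-empty overlap: this slice must move
                plan.append((src, dst, start, stop))
    return plan


# Entries with src == dst are local copies that need no communication.
for step in transfer_plan([3, 3, 3, 1], [5, 5]):
    print(step)
# (0, 0, 0, 3)
# (1, 0, 3, 5)
# (1, 1, 5, 6)
# (2, 1, 6, 9)
# (3, 1, 9, 10)
```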
Parameters:
device_mesh (Optional[DeviceMesh], optional) – Target device mesh. Uses current mesh if None.
placements (Optional[Sequence[Placement]], optional) – Target placement scheme. Required.
async_op (bool, default=False) – Whether to run asynchronously.
Returns:
Redistributed ShardTensor with new placement scheme.
Return type:
ShardTensor
Raises:
RuntimeError – If placements is not specified or contains invalid placements
(e.g., Partial placements or negative shard dimensions).
classmethod register_dispatch_handler(op: OpOverload, handler: Callable) → None
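Register a handler for a specific PyTorch operator in the dispatch system.
Parameters:
op (OpOverload) – The PyTorch operator overload to handle.
handler (Callable) – The handler function to call when the operator is dispatched.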
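classmethod register_function_handler(func: Callable, handler: Callable) → None
Register a handler for a Python-level function or method.
Parameters:
func (Callable) – The Python-level function or method to handle.
handler (Callable) – The handler function to call when the function is invoked.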
classmethod register_named_function_handler(func_name: str, handler: Callable) → None
Register a handler for a named function registered via torch.library.custom_op.
Parameters:
func_name (str) – The string name of the custom op (e.g., "module.function_name.default").
handler (Callable) – The handler function to call when the function is invoked.
to_local(*, grad_placements: Sequence[Placement] | None = None) → Tensor
Get the local tensor from this ShardTensor.
Parameters:
grad_placements (Optional[Sequence[Placement]], optional) – Future layout of gradients. If provided, gradients will be
constructed with this placement scheme during the backward pass.
Returns:
Local tensor. Shape may vary between ranks for sharded tensors.
Return type:
torch.Tensor
Utility Functions
physicsnemo.domain_parallel.scatter_tensor(
    tensor: Tensor,
    global_src: int,
    mesh: DeviceMesh,
    placements: tuple[Placement, ...],
    global_shape: Size | None = None,
    dtype: dtype | None = None,
    requires_grad: bool = False,
) → ShardTensor
Distribute a tensor from source rank across devices on the mesh.
Takes a tensor that exists on a single source rank and distributes it
across a device mesh according to the specified placement scheme. For
multi-dimensional meshes, it performs a flattened scatter operation
before constructing the sharded tensor.
Parameters:
tensor (torch.Tensor) – The tensor to distribute. Must exist on the source rank; can be None
on other ranks.
global_src (int) – Global rank ID of the source process.
mesh (DeviceMesh) – Device mesh defining the process topology.
placements (Tuple[Placement, ...]) – Tuple of placement specifications defining how to distribute the tensor.
global_shape (Optional[torch.Size], optional) – Global shape of the tensor. If None, will be broadcast from source.
dtype (Optional[torch.dtype], optional) – Data type of the tensor. If None, will be broadcast from source.
requires_grad (bool, default=False) – Whether the resulting ShardTensor requires gradients.
Returns:
The distributed tensor with specified placements.
Return type:
ShardTensor
Raises:
ValueError – If global_src is not an integer or not in the mesh.
For detailed information on ShardTensor and domain parallelism, please refer to the Domain Parallelism tutorial.