PhysicsNeMo domain_parallel

In scientific AI applications, the parallelization techniques needed for state-of-the-art models differ from those used to train large language models. PhysicsNeMo introduces a parallelization primitive, ShardTensor, designed to enable domain parallelism for AI applications with large inputs.

ShardTensor provides a distributed tensor implementation that supports uneven sharding across devices. It builds on PyTorch’s DTensor while adding flexibility for cases where different ranks may have different local tensor sizes.

ShardTensor

class physicsnemo.domain_parallel.ShardTensor(
local_tensor: Tensor,
spec: ShardTensorSpec,
*,
requires_grad: bool,
)

Bases: DTensor

A distributed tensor class with support for uneven data sharding.

Similar to PyTorch’s native DTensor, but with more flexibility for uneven data sharding. The API closely follows DTensor’s (identical where possible), but some routines are deliberately changed to avoid implicit assumptions about tensor sharding.

The key differences from DTensor are:

  • Supports uneven sharding where different ranks can have different local tensor sizes

  • Tracks and propagates shard size information across operations

  • Handles redistribution of unevenly sharded tensors

  • Provides custom collective operations optimized for uneven sharding

Like DTensor, operations are dispatched through PyTorch’s dispatcher system. Most operations work by:

  1. Converting inputs to local tensors

  2. Performing the operation locally

  3. Constructing a new ShardTensor with appropriate sharding spec

  4. Handling any needed communication between ranks
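The four steps above can be sketched with a toy, single-process stand-in. This is not the ShardTensor implementation: `ToyShard`, `toy_elementwise`, and `toy_global_sum` are hypothetical names, each "rank" is a plain Python list, and communication is simulated by summing per-rank partial results.

```python
from dataclasses import dataclass


@dataclass
class ToyShard:
    """Hypothetical stand-in: one rank's local data plus shard metadata."""
    local: list          # this rank's local values
    shard_sizes: tuple   # size of every rank's shard along the sharded dim


def toy_elementwise(op, x: ToyShard) -> ToyShard:
    local_in = x.local                      # 1. convert the input to local data
    local_out = [op(v) for v in local_in]   # 2. perform the operation locally
    # 3. rewrap with a spec; an elementwise op leaves shard sizes unchanged
    return ToyShard(local_out, x.shard_sizes)


def toy_global_sum(shards: list) -> float:
    # 4. a reduction needs communication; here an all-reduce is simulated
    #    by summing the per-rank partial sums in one process
    partials = [sum(s.local) for s in shards]
    return sum(partials)


# Two simulated ranks with uneven shards of sizes 3 and 1.
ranks = [ToyShard([1.0, 2.0, 3.0], (3, 1)), ToyShard([4.0], (3, 1))]
doubled = [toy_elementwise(lambda v: 2 * v, s) for s in ranks]
total = toy_global_sum(doubled)
```

Elementwise operations need no communication and preserve the shard sizes, while reductions are the typical case where step 4 does real work.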

The class provides methods for:

  • Converting to/from local tensors

  • Redistributing between different sharding schemes

  • Performing collective operations like all_gather and reduce_scatter

  • Basic tensor operations that maintain sharding information

_local_tensor

The local tensor data on this rank.

Type:

torch.Tensor

_spec

The specification defining sharding scheme and metadata.

Type:

ShardTensorSpec

backward(*args, **kwargs)

Perform backward pass for ShardTensor.

Handles the redistribution of the tensor to resolve any partial placements before calling backward on the local tensor.

Parameters:
  • *args – Positional arguments passed to torch.Tensor.backward.

  • **kwargs – Keyword arguments passed to torch.Tensor.backward.

classmethod from_dtensor(
dtensor: DTensor,
) → ShardTensor

Convert a DTensor to a ShardTensor.

Assumes the DTensor is properly constructed. Because DTensor always shards a tensor in chunk format, the shard sizes can be inferred without communication.

If the DTensor is a non-leaf (has a grad_fn), the autograd graph is preserved via _PromoteDTensorToShardTensor.

Parameters:

dtensor (DTensor) – DTensor to convert.

Returns:

Equivalent ShardTensor with the same local tensor and inferred spec.

Return type:

ShardTensor

static from_local(
local_tensor: Tensor,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
sharding_shapes: str | dict[int, list[tuple[int, ...]]] = 'infer',
) → ShardTensor

Generate a new ShardTensor from local torch tensors.

Uses the device mesh and placements to infer global tensor properties. Local tensors are not required to have equal shapes. The only requirement is that they can be concatenated into a single tensor according to the placements.

Parameters:
  • local_tensor (torch.Tensor) – Local chunk of the tensor. All participating tensors must have the same number of dimensions (tensor rank) and be concatenatable across the mesh dimensions.

  • device_mesh (Optional[DeviceMesh], optional) – Target device mesh. If not specified, will use the current mesh.

  • placements (Optional[Sequence[Placement]], optional) – Target placements. Must have same number of elements as device_mesh.ndim.

  • sharding_shapes (Union[str, Dict[int, List[Tuple[int, ...]]]], default="infer") –

    Controls how shard tensor spec is generated:

    • "chunk": Use torch.chunk shapes to infer shapes from global shape (no communication).

    • "infer": Use collective communication to infer shapes from mesh neighbors.

    • A dict mapping each mesh dim to a list of shard shapes: Use the provided shapes. The dict must be passed on every rank.

Returns:

A new ShardTensor instance.

Return type:

ShardTensor
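To make the shape requirements concrete, here is a minimal pure-Python sketch of two ideas from the parameters above: the shard sizes that `torch.chunk`-style splitting produces, and the check that unevenly sized local shapes can be concatenated along one dimension. The helper names `chunk_sizes` and `global_shape` are hypothetical and not part of PhysicsNeMo. Padding with zero-size shards when there are fewer chunks than ranks is an assumption made for illustration.

```python
import math


def chunk_sizes(dim_size: int, num_shards: int) -> list:
    """Sizes from torch.chunk-style splitting, zero-padded to num_shards entries."""
    step = math.ceil(dim_size / num_shards) if dim_size else 0
    sizes, remaining = [], dim_size
    for _ in range(num_shards):
        take = min(step, remaining)  # every chunk is `step` long except the tail
        sizes.append(take)
        remaining -= take
    return sizes


def global_shape(local_shapes: list, shard_dim: int) -> tuple:
    """Global shape implied by concatenating local shapes along shard_dim."""
    first = local_shapes[0]
    for shape in local_shapes:
        if len(shape) != len(first):
            raise ValueError("all local tensors need the same number of dims")
        for d, (a, b) in enumerate(zip(shape, first)):
            if d != shard_dim and a != b:
                raise ValueError(f"dim {d} differs between ranks: {a} vs {b}")
    out = list(first)
    out[shard_dim] = sum(s[shard_dim] for s in local_shapes)
    return tuple(out)


even = chunk_sizes(10, 4)                                   # chunk layout
uneven = global_shape([(100, 3), (250, 3), (75, 3)], 0)     # uneven layout
```

The chunk layout is fully determined by the global size and the number of shards, which is why it needs no communication. An uneven layout such as 100/250/75 points per rank cannot be recovered that way, so the sizes must either be exchanged between ranks or supplied explicitly.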

full_tensor(
*,
grad_placements: Sequence[Placement] | None = None,
) → Tensor

Gather the full tensor from all ranks.

Redistributes to Replicate placement on all mesh dimensions and returns the local tensor.

Parameters:

grad_placements (Optional[Sequence[Placement]], optional) – Future layout of gradients. If provided, gradients will be constructed with this placement scheme during backward pass.

Returns:

The full gathered tensor, identical on all ranks.

Return type:

torch.Tensor

offsets(mesh_dim: int | None = None) → list[int] | int

Get offsets of shards along a mesh dimension.

Parameters:

mesh_dim (Optional[int], optional) – Mesh dimension to get offsets for. If None, returns all offsets.

Returns:

List of offsets for shards along all dimensions, or single offset if mesh_dim is specified.

Return type:

Union[List[int], int]
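As a rough illustration of what a shard offset means, the sketch below assumes that an offset is the starting index of a shard within the global tensor along the sharded dimension. The reference above does not spell out that definition, so treat it as an assumption. `shard_offsets` is a hypothetical helper, not a PhysicsNeMo function.

```python
from itertools import accumulate


def shard_offsets(shard_sizes: list) -> list:
    """Start index of each shard, given the shard sizes along one dimension."""
    # Running totals give each shard's end; shifting by one gives its start.
    return [0, *accumulate(shard_sizes)][:-1]


offsets = shard_offsets([100, 250, 75])
```

With uneven shards, offsets cannot be computed as `rank * chunk_size`, which is one reason the shard sizes have to be tracked explicitly.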

classmethod patches_enabled() → bool

Check whether patches are enabled for this class.

Returns:

True if shard patches are enabled, False otherwise. Default is False until a ShardTensor is constructed.

Return type:

bool

redistribute(
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
*,
async_op: bool = False,
) → ShardTensor

Redistribute tensor across device mesh with new placement scheme.

Like DTensor.redistribute, but uses a custom redistribution layer that supports uneven sharding.

Parameters:
  • device_mesh (Optional[DeviceMesh], optional) – Target device mesh. Uses current mesh if None.

  • placements (Optional[Sequence[Placement]], optional) – Target placement scheme. Required.

  • async_op (bool, default=False) – Whether to run asynchronously.

Returns:

Redistributed ShardTensor with new placement scheme.

Return type:

ShardTensor

Raises:

RuntimeError – If placements is not specified or contains invalid placements (e.g., Partial placements or negative shard dimensions).
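The effect of redistributing an unevenly sharded tensor can be simulated in a single process. The sketch below is only a conceptual model: `gather` and `reshard` are hypothetical helpers, each rank's shard is a Python list, and everything is gathered before re-splitting for clarity. A real implementation would presumably exchange only the slices each rank needs.

```python
def gather(shards: list) -> list:
    """Concatenate per-rank shards in rank order (a simulated all-gather)."""
    return [value for shard in shards for value in shard]


def reshard(shards: list, new_sizes: list) -> list:
    """Re-split the gathered data into shards of the requested sizes."""
    full = gather(shards)
    if sum(new_sizes) != len(full):
        raise ValueError("new shard sizes must cover the whole tensor")
    out, start = [], 0
    for size in new_sizes:
        out.append(full[start:start + size])
        start += size
    return out


uneven = [[0, 1, 2, 3, 4], [5], [6, 7]]   # shard sizes 5, 1, 2
balanced = reshard(uneven, [3, 3, 2])     # chunk-style sizes for 8 elements
```

Gathering the shards corresponds to moving to a replicated layout, and re-splitting corresponds to sharding again with different sizes. The global contents are unchanged.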

classmethod register_dispatch_handler(
op: OpOverload,
handler: Callable,
) → None

Register a handler for a specific PyTorch operator in the dispatch system.

Parameters:
  • op (torch._ops.OpOverload) – The PyTorch operator to register a handler for.

  • handler (Callable) – The handler function to call when the operator is invoked.

classmethod register_function_handler(
func: Callable,
handler: Callable,
) → None

Register a handler for a Python-level function or method.

Parameters:
  • func (Callable) – The Python function to register a handler for.

  • handler (Callable) – The handler function to call when the function is invoked.

classmethod register_named_function_handler(
func_name: str,
handler: Callable,
) → None

Register a handler for a named function registered via torch.library.custom_op.

Parameters:
  • func_name (str) – The string name of the custom op (e.g., "module.function_name.default").

  • handler (Callable) – The handler function to call when the function is invoked.
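The three registration methods above share one pattern: a table maps an operator or function to a handler, and the handler is called in place of the default path. The toy registry below illustrates that pattern only. `ToyDispatcher` and `sharded_len` are hypothetical, and the handler signature that ShardTensor actually expects is not documented in this section, so the signature here is an assumption.

```python
from typing import Any, Callable


class ToyDispatcher:
    """Hypothetical registry showing handler lookup with a default fallback."""

    _handlers: dict = {}

    @classmethod
    def register_handler(cls, op: Any, handler: Callable) -> None:
        cls._handlers[op] = handler

    @classmethod
    def dispatch(cls, op: Callable, *args, **kwargs):
        handler = cls._handlers.get(op)
        if handler is not None:
            return handler(*args, **kwargs)   # custom, sharding-aware path
        return op(*args, **kwargs)            # default path


def sharded_len(shards: list) -> int:
    """Global length of data stored as a list of per-rank shards."""
    return sum(len(shard) for shard in shards)


ToyDispatcher.register_handler(len, sharded_len)
global_len = ToyDispatcher.dispatch(len, [[1, 2], [3]])   # uses the handler
fallback = ToyDispatcher.dispatch(max, [1, 5, 2])         # no handler registered
```

Operations without a registered handler fall through to their default behavior, so handlers only need to be written for operations whose default behavior is wrong or inefficient on sharded data.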

to_local(
*,
grad_placements: Sequence[Placement] | None = None,
) → Tensor

Get local tensor from this ShardTensor.

Parameters:

grad_placements (Optional[Sequence[Placement]], optional) – Future layout of gradients. If provided, gradients will be constructed with this placement scheme during backward pass.

Returns:

Local tensor. Shape may vary between ranks for sharded tensors.

Return type:

torch.Tensor

Utility Functions

physicsnemo.domain_parallel.scatter_tensor(
tensor: Tensor,
global_src: int,
mesh: DeviceMesh,
placements: tuple[Placement, ...],
global_shape: Size | None = None,
dtype: dtype | None = None,
requires_grad: bool = False,
) → ShardTensor

Distribute a tensor from source rank across devices on the mesh.

Takes a tensor that exists on a single source rank and distributes it across a device mesh according to the specified placement scheme. For multi-dimensional meshes, it performs a flattened scatter operation before constructing the sharded tensor.

Parameters:
  • tensor (torch.Tensor) – The tensor to distribute. Must exist on source rank; can be None on other ranks.

  • global_src (int) – Global rank ID of the source process.

  • mesh (DeviceMesh) – Device mesh defining the process topology.

  • placements (Tuple[Placement, ...]) – Tuple of placement specifications defining how to distribute the tensor.

  • global_shape (Optional[torch.Size], optional) – Global shape of the tensor. If None, will be broadcast from source.

  • dtype (Optional[torch.dtype], optional) – Data type of the tensor. If None, will be broadcast from source.

  • requires_grad (bool, default=False) – Whether the resulting ShardTensor requires gradients.

Returns:

The distributed tensor with specified placements.

Return type:

ShardTensor

Raises:

ValueError – If global_src is not an integer or not in the mesh.

For detailed information on ShardTensor and domain parallelism, please refer to the Domain Parallelism tutorial.