PhysicsNeMo ShardTensor#

In scientific AI applications, the parallelization techniques needed to enable state-of-the-art models differ from those used to train large language models. PhysicsNeMo introduces a new parallelization primitive, called ShardTensor, designed to enable domain parallelism for large-input AI applications.

ShardTensor provides a distributed tensor implementation that supports uneven sharding across devices. It builds on PyTorch’s DTensor while adding flexibility for cases where different ranks may have different local tensor sizes.

The example below shows how to create and work with ShardTensor:

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import Shard
from physicsnemo.distributed import DistributedManager
from physicsnemo.distributed.shard_tensor import ShardTensor, scatter_tensor

def main():
    # Initialize distributed environment
    DistributedManager.initialize()
    dm = DistributedManager()

    # Create a 1D device mesh; a -1 uses all available devices
    # (in a multi-dimensional mesh, a single -1 can be used to infer that dimension's size)
    mesh = dm.initialize_mesh((-1,), mesh_dim_names=["spatial"])

    # Create a global tensor on rank 0 only; other ranks pass None
    if dm.rank == 0:
        tensor = torch.randn(100, 64)
    else:
        tensor = None

    # Scatter the tensor across devices with uneven sharding
    # This will automatically determine appropriate local sizes
    sharded = scatter_tensor(
        tensor,
        global_src=0,
        mesh=mesh,
        placements=(Shard(0),)  # Shard along first dimension
    )

    # Work with local portions
    local_tensor = sharded.to_local()

    # Redistribute to a different sharding scheme
    new_sharded = sharded.redistribute(
        placements=(Shard(1),)  # Change to shard along the second dimension
    )


if __name__ == "__main__":
    main()
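A script like this runs the same code on every rank and is typically launched with a standard PyTorch distributed launcher; for example (the script name here is illustrative):

torchrun --nproc_per_node=8 shard_tensor_example.py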

How does this work?#

ShardTensor extends PyTorch’s DTensor to support uneven sharding where different ranks can have different local tensor sizes. It tracks shard size information and handles redistribution between different sharding schemes while maintaining gradient flow.

Key differences from DTensor include:

  • Support for uneven sharding where ranks have different local sizes

  • Tracking and propagation of shard size information

  • Custom collective operations optimized for uneven sharding

  • Flexible redistribution between different sharding schemes

Operations work by the following steps (a minimal sketch of this pattern follows the list):

  1. Converting inputs to local tensors

  2. Performing operations locally

  3. Constructing new ShardTensor with appropriate sharding

  4. Handling any needed communication between ranks
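As a rough illustration of this pattern (not the actual dispatch code), an elementwise operation could be applied manually with to_local and from_local; the helper name apply_elementwise is hypothetical:

import torch
from physicsnemo.distributed.shard_tensor import ShardTensor

def apply_elementwise(sharded: ShardTensor, op) -> ShardTensor:
    # 1. Convert the input to its local shard
    local = sharded.to_local()

    # 2. Perform the operation locally; elementwise ops leave local shapes unchanged
    result = op(local)

    # 3. & 4. Build a new ShardTensor on the same mesh and placements.
    #    "infer" uses collective communication to recover the (possibly uneven) shard shapes.
    return ShardTensor.from_local(
        result,
        device_mesh=sharded.device_mesh,
        placements=sharded.placements,
        sharding_shapes="infer",
    )

# In practice the dispatcher performs these steps automatically when an
# operation is called on a ShardTensor.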

ShardTensor#

class physicsnemo.distributed.shard_tensor.ShardTensor(
local_tensor: Tensor,
spec: ShardTensorSpec,
*,
requires_grad: bool,
)[source]#

Bases: DTensor

A class similar to PyTorch's native DTensor, but with more flexibility for uneven data sharding.

It exposes an API very close to DTensor's (identical where possible), while deliberately reworking routines that make implicit assumptions about how tensors are sharded.

The key differences from DTensor are:

  • Supports uneven sharding where different ranks can have different local tensor sizes

  • Tracks and propagates shard size information across operations

  • Handles redistribution of unevenly sharded tensors

  • Provides custom collective operations optimized for uneven sharding

Like DTensor, operations are dispatched through PyTorch’s dispatcher system. Most operations work by:

  1. Converting inputs to local tensors

  2. Performing the operation locally

  3. Constructing a new ShardTensor with appropriate sharding spec

  4. Handling any needed communication between ranks

The class provides methods for:

  • Converting to/from local tensors

  • Redistributing between different sharding schemes

  • Performing collective operations like all_gather and reduce_scatter

  • Basic tensor operations that maintain sharding information

backward(*args, **kwargs)[source]#

Backward pass for ShardTensor.

This method is used to perform the backward pass for a ShardTensor. It handles the redistribution of the tensor to the desired placements and then calls the backward pass on the local tensor.

classmethod from_dtensor(
dtensor: DTensor,
) ShardTensor[source]#

Convert a DTensor to a ShardTensor. We assume the DTensor is properly constructed.

Parameters:

dtensor – DTensor to convert

Returns:

Equivalent ShardTensor

static from_local(
local_tensor: Tensor,
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
sharding_shapes: str | Dict[int, List[Tuple[int, ...]]] = 'infer',
) ShardTensor[source]#

Generate a new ShardTensor from local torch tensors. Uses device mesh and placements to infer global tensor properties.

Local tensors are not required to have equal shapes. The only requirement is that the local shapes can be concatenated into a single global tensor according to the placements.

Parameters:
  • local_tensor – Local chunk of the tensor. All participating tensors must have the same rank and be concatenatable across the mesh dimensions

  • device_mesh – Target Device Mesh, if not specified will use the current mesh

  • placements – Target placements, must have same number of elements as device_mesh.ndim

  • sharding_shapes – Controls how the shard tensor spec is generated:
      - "chunk": use torch.chunk shapes to infer shard shapes from the global shape (no communication)
      - "infer": use collective communication to infer shard shapes from mesh neighbors
      - a dict mapping each mesh dim to a list of shard shapes: use the provided shapes (must be passed on every rank!)

Returns:

A new ShardTensor instance
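A minimal sketch, assuming a 1D mesh created with DistributedManager as in the first example, where each rank deliberately holds a different number of rows:

import torch
from torch.distributed.tensor.placement_types import Shard
from physicsnemo.distributed import DistributedManager
from physicsnemo.distributed.shard_tensor import ShardTensor

DistributedManager.initialize()
dm = DistributedManager()
mesh = dm.initialize_mesh((-1,), mesh_dim_names=["spatial"])

# Uneven local shards: rank 0 holds 10 rows, rank 1 holds 11 rows, and so on
local = torch.randn(10 + dm.rank, 64)

# "infer" runs a collective so every rank learns the other ranks' shard shapes
sharded = ShardTensor.from_local(
    local,
    device_mesh=mesh,
    placements=(Shard(0),),
    sharding_shapes="infer",
)

print(sharded.shape)  # global shape: sum of the local row counts x 64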

full_tensor(
*,
grad_placements: Sequence[Placement] | None = None,
) Tensor[source]#

Return the full (global) tensor. Re-implemented here so that the underlying redistribute call produces a ShardTensor as its output.
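For example, continuing from a sharded tensor such as the one above (note that this gathers every shard, so it can be memory-intensive for large global tensors):

# Re-assemble the full global tensor on every rank
global_tensor = sharded.full_tensor()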

offsets(mesh_dim: int | None = None) List[int][source]#

Get offsets of shards along a mesh dimension.

Parameters:

mesh_dim – Mesh dimension to get offsets for. If None, returns all offsets.

Returns:

List of offsets for shards along specified dimension

classmethod patches_enabled() bool[source]#

Whether to enable patches for this class.

Default is False, but can be changed by the user.

redistribute(
device_mesh: DeviceMesh | None = None,
placements: Sequence[Placement] | None = None,
*,
async_op: bool = False,
) ShardTensor[source]#

Redistribute tensor across device mesh with new placement scheme. Like DTensor redistribute but uses custom layer for shard redistribution.

Parameters:
  • device_mesh – Target device mesh. Uses current if None.

  • placements – Target placement scheme. Required.

  • async_op – Whether to run asynchronously

Returns:

Redistributed ShardTensor

Raises:

RuntimeError – If placements not specified or invalid
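A short sketch, again using a tensor sharded along its first dimension as in the earlier examples; Replicate and Shard come from torch.distributed.tensor.placement_types:

from torch.distributed.tensor.placement_types import Replicate, Shard

# Gather a full copy of the tensor onto every rank
replicated = sharded.redistribute(placements=(Replicate(),))

# Move data so the tensor is sharded along its second dimension instead
resharded = sharded.redistribute(placements=(Shard(1),))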

classmethod register_dispatch_handler(
op: OpOverload,
handler: Callable,
) None[source]#

Register a handler for a specific PyTorch operator in the dispatch system.

classmethod register_function_handler(
func: Callable,
handler: Callable,
) None[source]#

Register a handler for a Python-level function or method.

classmethod register_named_function_handler(
func_name: str,
handler: Callable,
) None[source]#

Register a named function that has been named via torch.library.custom_op

to_local(
*,
grad_placements: Sequence[Placement] | None = None,
) Tensor[source]#

Get local tensor from this ShardTensor.

Parameters:

grad_placements – Future layout of gradients. Optional.

Returns:

Local torch.Tensor. Shape may vary between ranks for sharded tensors.

Utility Functions#

physicsnemo.distributed.shard_tensor.scatter_tensor(
tensor: Tensor,
global_src: int,
mesh: DeviceMesh,
placements: Tuple[Placement, ...],
global_shape: Size | None = None,
dtype: dtype | None = None,
requires_grad: bool = False,
) ShardTensor[source]#

Take a tensor from source rank and distribute it across devices on the mesh according to placements.

This function takes a tensor that exists on a single source rank and distributes it across a device mesh according to the specified placement scheme. For multi-dimensional meshes, it performs a flattened scatter operation before constructing the sharded tensor.

Parameters:
  • tensor – The tensor to distribute, must exist on source rank

  • global_src – Global rank ID of the source process

  • mesh – Device mesh defining the process topology

  • placements – Tuple of placement specifications defining how to distribute the tensor

Returns:

The distributed tensor with specified placements

Return type:

ShardTensor

Raises:

ValueError – If global_src is not an integer or not in the mesh

Why do we need this?#

During deep learning training, memory usage can grow significantly when working with large input data, even if the model itself is relatively small. This is because many operations create intermediate tensors that temporarily consume memory.

For example, consider a 2D convolution operation on a high-resolution image. If we have a batch of 1024x1024 images, even a simple 3x3 convolution needs to save the entire input image in memory for computing the gradients in the backward pass.

For high-resolution images, this can easily lead to out-of-memory errors as model depth grows, even if the number of parameters is small. This is in sharp contrast to LLM training, where memory usage is dominated by the parameters and their corresponding optimizer states. Tools like DeepSpeed and ZeRO address that problem by partitioning the model across GPUs, but partitioning the model does not help in large-input applications.

ShardTensor helps address this by:

  • Distributing the input data across multiple devices

  • Performing operations on smaller local portions

  • Coordinating the necessary communication between devices in the forward and backward passes

ShardTensor is built as an extension of PyTorch’s DTensor, and it gains substantial functionality by leveraging the utilities already implemented in the PyTorch distributed package. However, some operations on sharded input data are not trivial to implement correctly, and they are not covered by model-sharding solutions. In PhysicsNeMo, we have implemented parallelized versions of several key operations, including (so far):

  • Convolution (1D, 2D, 3D)

  • Neighborhood Attention (2D)

These operations are implemented in the physicsnemo.distributed.shard_utils module, and are enabled by dynamically intercepting calls to (for example) torch.nn.functional.conv2d. When the function is called with ShardTensor inputs, the operation is automatically parallelized across the mesh associated with the input. When the function is called with non-ShardTensor inputs, the operation is executed in a non-parallelized manner, exactly as expected.

To enable these operations, you must import patch_operations from physicsnemo.distributed.shard_utils. This will patch the relevant functions in the distributed package to support ShardTensor inputs.
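A minimal sketch of enabling and using a patched operation, assuming the convolution weights are plain (replicated) tensors that are identical on every rank, as they would be inside an nn.Conv2d module:

import torch
import torch.nn.functional as F
from torch.distributed.tensor.placement_types import Shard
from physicsnemo.distributed import DistributedManager
from physicsnemo.distributed.shard_tensor import scatter_tensor

# Importing this module patches functions such as torch.nn.functional.conv2d
# so that they accept ShardTensor inputs.
from physicsnemo.distributed.shard_utils import patch_operations  # noqa: F401

DistributedManager.initialize()
dm = DistributedManager()
mesh = dm.initialize_mesh((-1,), mesh_dim_names=["spatial"])

# A single high-resolution image on rank 0, sharded along the height dimension
image = torch.randn(1, 3, 1024, 1024) if dm.rank == 0 else None
sharded_image = scatter_tensor(image, global_src=0, mesh=mesh, placements=(Shard(2),))

# Plain (replicated) convolution weights, identical on every rank
weight = torch.randn(8, 3, 3, 3)

# Because the input is a ShardTensor, the patched conv2d runs on each local shard
# and coordinates any data exchange needed near shard boundaries; the output is
# a ShardTensor sharded like the input.
output = F.conv2d(sharded_image, weight, padding=1)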

We are continuing to add more operations, and contributions are welcome!