In scientific AI, one of the most challenging aspects of training a model is dealing with extremely high resolution data. In this tutorial, we’ll explore what makes high resolution data so challenging to handle, for both training and inference, and why that’s different from the scaling challenges in other domains (like NLP, image processing, etc.). We’ll also take a technical look at how we’re working to streamline high-resolution model training in PhysicsNeMo, and how you can leverage our tools for your own scientific workloads as well.

To understand why scientific AI hits unique challenges in training and inference on high resolution data, let’s take a look at the computational and memory cost of training models and subsequently running inference. “Cost” here refers to two fundamental, high level concepts: computational cost is how much computing power is needed to complete an operation (and is, in general, a complicated interplay of GPU FLOPs, memory bandwidth, cache sizes, algorithm efficiencies, and more); memory costs refer to the amount of GPU memory required to perform the computations.

For all AI models, the memory cost of inference is dominated by just two categories of use:

Model parameters (weights, biases, encodings, etc.): all of these must be loaded into GPU memory for fast access during inference. A model with N total parameters needs 4N bytes in float32 precision, or 2N bytes in float16/bfloat16. As a rough approximation, a 100M parameter model requires 400MB of memory in float32 precision. For Large Language Models with billions of parameters, this is a large amount of memory even at inference time.

Active data (the inputs and outputs!): this is the memory required to actually compute the layers and outputs of the model. For inference, the available memory has to be enough to hold the input data, output data, and model parameters, and to temporarily accommodate the intermediate activations. Because one layer’s output is consumed by the next layer, the total memory needed typically never exceeds the requirements of the most memory-intensive layer.

For scientific AI with high resolution data, the memory cost at inference can be dominated not by the model parameters but by the data - though it’s not always a clear cut winner.
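To make the comparison concrete, here is a minimal sketch of the arithmetic described above. The 3D volume shape is a hypothetical example chosen for illustration, and the helper names are not part of any library.

```python
# Rough inference-memory estimate: parameters vs. a single high-resolution input.
# All sizes below are illustrative assumptions, not measurements.

BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2}


def parameter_bytes(num_parameters: int, dtype: str = "float32") -> int:
    """Memory needed to hold the model parameters."""
    return num_parameters * BYTES_PER_ELEMENT[dtype]


def tensor_bytes(shape: tuple, dtype: str = "float32") -> int:
    """Memory needed to hold one dense tensor of the given shape."""
    count = 1
    for extent in shape:
        count *= extent
    return count * BYTES_PER_ELEMENT[dtype]


# A 100M-parameter model in float32: 400 MB, matching the estimate in the text.
params = parameter_bytes(100_000_000)

# A hypothetical 3D volume: batch 1, 8 channels, 512^3 grid points.
volume = tensor_bytes((1, 8, 512, 512, 512))

print(f"parameters: {params / 1e6:.0f} MB")
print(f"one input:  {volume / 1e6:.0f} MB")
```

Under these assumptions a single input tensor is already about ten times larger than the whole parameter set, before counting outputs or intermediate activations.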

During training (for a standard training loop), the high resolution of the data is even more challenging. In most cases, there are two additional memory consumers during model training:

Optimizer states (gradients, moments): these are needed to accumulate and update the model’s parameters during training. For plain SGD, this can take as little memory as the model’s parameters themselves. For more complicated optimizers, like Adam, the optimizer must also store moments and running gradient averages, and the usage increases.

Activations: for each layer during training, PyTorch will typically save some version of the layer’s input, output, or other component as the “intermediate activation” for that layer. In practice, this is a computational optimization that lets the backward pass compute and propagate gradients more efficiently. Each layer, however, requires extra memory during training that is proportional to the resolution of the input data.
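The parameter-related part of the training footprint can be sketched as simple bookkeeping. The buffer counts below are assumptions for the textbook form of each optimizer (everything in float32, no mixed-precision master weights); real implementations may keep additional state.

```python
# Back-of-the-envelope memory for parameters, gradients, and optimizer state.
# Assumes float32 everywhere and the textbook form of each optimizer.

BYTES_FP32 = 4

# Extra per-parameter buffers each optimizer keeps, beyond the gradients.
# Plain SGD keeps none; SGD with momentum keeps one; Adam keeps two moments.
OPTIMIZER_BUFFERS = {"sgd": 0, "sgd_momentum": 1, "adam": 2}


def training_state_bytes(num_parameters: int, optimizer: str) -> int:
    """Bytes for parameters + gradients + optimizer buffers."""
    copies = 1 + 1 + OPTIMIZER_BUFFERS[optimizer]
    return num_parameters * BYTES_FP32 * copies


for name in OPTIMIZER_BUFFERS:
    megabytes = training_state_bytes(100_000_000, name) / 1e6
    print(f"{name:>13}: {megabytes:.0f} MB")
```

Note that none of these terms depend on the input resolution; only the activation term does.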

As a cumulative effect, as models continue to stack up layers and save intermediate activations, the activation-related memory required to train a model grows with both the depth of the model and the resolution of the input data. In contrast to Large Language Models, where the memory usage during training is dominated by the parameters, gradients, and optimizer states, for high resolution scientific AI models with modest parameter counts the memory usage is dominated by activations!
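A rough estimate illustrates the scaling. This sketch assumes a hypothetical stack of identical 3x3 convolution layers in which each layer saves exactly one float32 tensor for the backward pass; real models save more or less depending on the layer type, so treat the numbers as order-of-magnitude only.

```python
# Order-of-magnitude comparison of saved activations vs. parameters
# for a hypothetical stack of identical 3x3 Conv2d layers (float32).

BYTES_FP32 = 4


def activation_bytes(batch, channels, height, width, num_layers):
    """Saved-activation memory, assuming one (B, C, H, W) tensor per layer."""
    return batch * channels * height * width * BYTES_FP32 * num_layers


def conv_stack_parameter_bytes(channels, num_layers, kernel=3):
    """Weights plus biases for num_layers square convolutions with C in/out channels."""
    per_layer = channels * channels * kernel * kernel + channels
    return per_layer * BYTES_FP32 * num_layers


params = conv_stack_parameter_bytes(channels=64, num_layers=20)
for side in (256, 1024, 2048):
    acts = activation_bytes(1, 64, side, side, num_layers=20)
    print(f"{side:>4} x {side:<4} activations: {acts / 1e9:6.2f} GB   "
          f"parameters: {params / 1e6:.2f} MB")
```

Doubling the side length of a 2D input quadruples the activation term, and adding layers grows it linearly, while the parameter term is unchanged by resolution.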

To address this challenge, in PhysicsNeMo we have developed a domain-parallelism framework specifically designed to parallelize the high compute and memory costs of training models and running inference on high resolution data. Named ShardTensor, and built on top of PyTorch’s DTensor framework, it allows models to divide expensive operations across multiple GPUs - parallelizing both the compute required and the storage of the intermediate activations.

The remainder of this tutorial will focus on the high level concepts of ShardTensor and domain parallelism; implementing new layers for ShardTensor will be covered in a separate tutorial.

As a high level example, let’s consider a simple 2D convolution operation. There have been many tutorials on the mathematics and efficient computation of convolutions; let’s not focus on that here. Instead, consider the case where the input data to the convolution is spread across two GPUs, and we want to correctly compute the output of the convolution without ever coalescing the input data on a single GPU.

Just applying the convolution to each half provides incorrect results. We can actually simulate this in PyTorch on one device:

    import torch

    full_image = torch.randn(1, 8, 1024, 1024)

    # Split the image in half along the height dimension (dim=2)
    left_image = full_image[:, :, :512, :]
    right_image = full_image[:, :, 512:, :]

    convolution_operator = torch.nn.Conv2d(8, 8, 3, stride=1, padding=1)

    full_output = convolution_operator(full_image)

    left_output = convolution_operator(left_image)
    right_output = convolution_operator(right_image)

    recombined_output = torch.cat([left_output, right_output], dim=2)

    # Do the shapes agree?
    print(full_output.shape)
    print(recombined_output.shape)
    # (they should!)

    # Do the values agree?
    print(torch.allclose(full_output, recombined_output))
    # (they do not!)

To understand why they don’t agree, we can look at the location of the disagreement:

    diff = full_output - recombined_output

    # Find every (batch, channel, height, width) index where the outputs differ
    b_locs, c_locs, h_locs, w_locs = torch.where(torch.abs(diff) > 1e-6)

    print(torch.unique(b_locs))
    print(torch.unique(c_locs))
    print(torch.unique(h_locs))
    print(torch.unique(w_locs))

This will produce the following output:

    tensor([0])
    tensor([0, 1, 2, 3, 4, 5, 6, 7])
    tensor([511, 512])
    tensor([ 0, 1, 2, ..., 1021, 1022, 1023])

We see in particular that along the height dimension (dim=2), the output is incorrect only along the pixels 511 and 512 - right where we split the data! The problem is that the convolution operator is a local operation, but splitting the data prevents it from seeing the correct neighboring pixels right at the border. You could fix this directly:

    # Slice off the data needed on the other image (around the center of the original image)
    missing_left_data = right_image[:, :, 0:1, :]
    missing_right_data = left_image[:, :, -1:, :]

    # Add it to the correct image
    padded_left_image = torch.cat([left_image, missing_left_data], 2)
    padded_right_image = torch.cat([missing_right_data, right_image], 2)

    # Recompute convolutions
    right_output = convolution_operator(padded_right_image)[:, :, 1:, :]
    left_output = convolution_operator(padded_left_image)[:, :, :-1, :]
    # ^ Need to drop the extra pixels in the output here

    recombined_output = torch.cat([left_output, right_output], dim=2)

    # Now, the output works correctly:
    print(torch.allclose(recombined_output, full_output))
    # True

In the example above, for a simple convolution, we saw that just splitting the data and applying the base operation didn’t give the results we needed. In general, this is true of many operations we see in AI models: splitting the data across GPUs requires extra operations or communication, depending on the operation, to get everything right. We also haven’t even mentioned the gradients yet - to call backward() through this split operation across devices also requires extra operations and communication. But, in order to get the memory and potential computational benefits of domain parallelism, it’s necessary.
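The one-pixel fix above is a special case: for a stride-1 convolution with an odd kernel of size k, each shard needs (k - 1) / 2 extra points from each neighbor (often called a halo). The following is a minimal pure-Python sketch of that idea in 1D, with any number of shards simulated in a single process. The function names are hypothetical, it assumes the signal length divides evenly by the shard count, and it is not how ShardTensor is implemented.

```python
def conv1d_same(signal, kernel):
    """Zero-padded 'same' cross-correlation of a list with an odd-length kernel."""
    halo = len(kernel) // 2
    padded = [0.0] * halo + list(signal) + [0.0] * halo
    return [
        sum(k * padded[i + j] for j, k in enumerate(kernel))
        for i in range(len(signal))
    ]


def conv1d_sharded(signal, kernel, num_shards):
    """Same result, computed shard by shard with a simulated halo exchange."""
    halo = len(kernel) // 2
    shard_len = len(signal) // num_shards  # assumes an even split
    output = []
    for rank in range(num_shards):
        start, stop = rank * shard_len, (rank + 1) * shard_len
        # "Receive" halo points from the neighbors. At the domain edges,
        # fill with zeros to match the padding of the unsharded convolution.
        left = list(signal[max(start - halo, 0):start])
        right = list(signal[stop:stop + halo])
        left = [0.0] * (halo - len(left)) + left
        right = right + [0.0] * (halo - len(right))
        local = left + list(signal[start:stop]) + right
        # Unpadded convolution over the haloed shard yields exactly shard_len outputs.
        output.extend(
            sum(k * local[i + j] for j, k in enumerate(kernel))
            for i in range(shard_len)
        )
    return output


signal = [float(i * i % 7) for i in range(24)]
kernel = [1.0, -2.0, 0.5, 3.0, -1.0]  # size 5, so each side needs a halo of 2

print(conv1d_sharded(signal, kernel, num_shards=4) == conv1d_same(signal, kernel))
```

On real hardware the two slices labeled "receive" become point-to-point communication between neighboring GPUs, and the backward pass needs the matching exchange in the opposite direction.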

PyTorch’s DTensor already provides an interface for a distributed tensor mechanism, and it’s great - great enough, in fact, that ShardTensor is built upon it. However, DTensor is built with a different paradigm of parallelism in mind, including the model parallelism of DeepSpeed and Megatron, which is supported in PyTorch via Fully Sharded Data Parallelism. For domain parallelism it has several shortcomings: notably, it cannot accommodate data that isn’t distributed uniformly or according to torch.chunk syntax. For scientific data, such as mesh data, point clouds, or anything else irregular, this is a nearly-immediate dead end for deploying domain parallelism. Further, DTensor’s mechanism for implementing parallelism is largely restricted to lower level torch operations - great for broad support in PyTorch, but not as accessible for most developers.
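To see why chunk-style sharding is limiting, the sketch below reproduces the split sizes that torch.chunk documents (every chunk has the same size except possibly the last) and compares them with a made-up irregular partition of a point cloud. The helper name and the point counts are hypothetical, and DTensor's internal logic may differ in detail.

```python
import math


def chunk_sizes(length: int, chunks: int) -> list:
    """Shard sizes produced by torch.chunk-style splitting along one dimension."""
    size = math.ceil(length / chunks)
    sizes = []
    remaining = length
    while remaining > 0:
        sizes.append(min(size, remaining))
        remaining -= sizes[-1]
    return sizes


# A regular grid splits cleanly: 1024 rows over 4 GPUs.
print(chunk_sizes(1024, 4))

# An uneven length still follows a fixed rule: equal sizes, smaller last chunk.
print(chunk_sizes(10, 4))

# Hypothetical point counts after spatially partitioning a 12000-point cloud.
# The sizes follow the geometry of the data, not a chunking rule.
points_per_gpu = [3120, 2875, 4410, 1595]
print(points_per_gpu == chunk_sizes(sum(points_per_gpu), 4))
```

A sharding description that only records "split dimension 0 over 4 devices" can reconstruct the first two layouts but not the third, which is the gap the per-shard shape tracking in ShardTensor is meant to close.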

With ShardTensor, we extend the functionality of DTensor in the ways needed to make domain parallelism simpler and easier to apply. In practice, this looks like the following, if we reuse the convolution example from before:

Example of domain parallel convolution with ShardTensor



    import torch
    from torch.distributed.tensor import (
        Shard,
        distribute_module,
    )

    from physicsnemo.distributed import (
        DistributedManager,
        ShardTensor,
        scatter_tensor,
    )

    DistributedManager.initialize()
    dm = DistributedManager()

    ###########################
    # Single GPU - Create input
    ###########################
    original_tensor = torch.randn(1, 8, 1024, 1024, device=dm.device, requires_grad=True)

    ###########################################
    # Single GPU - Create a single-layer model:
    ###########################################
    conv = torch.nn.Conv2d(8, 8, 3, stride=1, padding=1).to(dm.device)

    ########################################
    # Single GPU - forward + loss + backward
    ########################################
    single_gpu_output = conv(original_tensor)

    # This isn't really a loss, just a pretend one that's scalar!
    single_gpu_output.mean().backward()
    # Copy the gradients produced here - so we don't overwrite them later.
    original_tensor_grad = original_tensor.grad.data.clone()

    ####################
    # Single GPU - DONE!
    ####################

    #################
    # Sharded - Setup
    #################
    # DeviceMesh is a pytorch object - you can initialize it directly, or for added
    # flexibility physicsnemo can infer up to one mesh dimension for you
    # (as a -1, like in a tensor.reshape() call...)
    mesh = dm.initialize_mesh(mesh_shape=(-1,), mesh_dim_names=("domain_parallel",))
    # A mesh, by the way, refers to devices and not data: it's a mesh of connected
    # GPUs in this case, and the python DeviceMesh can be reused as many times as needed.
    # That said, it can be decomposed similar to a tensor - multiple mesh axes, and
    # you can access sub-meshes. Each mesh also has ways to access process groups
    # for targeted collectives.

    ###########################
    # Sharded - Distribute Data
    ###########################
    # This is now a tensor across all GPUs, spread on the "height" dimension == 2
    # In general, to create a ShardTensor (or DTensor) you need to specify placements.
    # Placements must be a list or tuple of `Shard()` or `Replicate()` objects
    # from torch.distributed.tensor.
    #
    # Each index in the tuple represents the placement over the corresponding mesh dimension
    # (so, mesh.ndim == len(placements)! )
    # `Shard()` takes an argument representing the **tensor** index that is sharded.
    # So below, the tensor is sharded over the tensor dimension 2 on the mesh dimension 0.
    sharded_tensor = scatter_tensor(original_tensor, 0, mesh, (Shard(2),), requires_grad=True)

    ################################
    # Sharded - distribute the model
    ################################
    # We tell pytorch that the convolution will work on distributed tensors:
    # And, over the same mesh!
    distributed_conv = distribute_module(conv, mesh)

    #####################################
    # Sharded - forward + loss + backward
    #####################################
    # Now, we can do the distributed convolution:
    sharded_output = distributed_conv(sharded_tensor)
    sharded_output.mean().backward()

    ############################################
    # Sharded - gather up outputs to all devices
    ############################################
    # This triggers a collective allgather.
    full_output = sharded_output.full_tensor()
    full_grad = sharded_tensor.grad.full_tensor()

    #################
    # Accuracy Checks
    #################
    if dm.rank == 0:
        # Only check on rank 0 because we used its data and weights for the sharded tensor.

        # Check that the output is the same as the single-device output:
        assert torch.allclose(full_output, single_gpu_output)
        print("Global operation matches local!")

        # Check that the gradient is correct:
        assert torch.allclose(original_tensor_grad, full_grad)
        print("Gradient check passed!")

        print(
            "Distributed grad sharding and local shape: "
            f"{sharded_tensor.grad._spec.placements}, {sharded_tensor.grad.to_local().shape}"
        )





If you run this (torchrun --nproc-per-node 4 conv_example.py), you’ll see the checks on output and gradients both pass. Further, the last line will print:

    Distributed grad sharding and local shape: (Shard(dim=2),), torch.Size([1, 8, 256, 1024])

Note that when running this, there was no need to perform manual communication or padding, in either the forward or backward pass. And, though we used a convolution, the details of the operation didn’t need to be explicitly specified. In this case, it just worked.

At a high level, PyTorch’s DTensor pairs a local chunk of a tensor (stored as a torch.Tensor) with a DTensorSpec object. The spec combines a DeviceMesh object, representing the group of GPUs the tensor lives on, with a description of how the global tensor is distributed (or replicated) over that mesh. ShardTensor extends this API with an addition to the specification that tracks the shape of each local tensor along the sharding axes. This becomes important when the input data is something like a point cloud, rather than an evenly-distributed tensor.
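As a purely conceptual illustration of why tracking per-rank shapes matters, here is a toy specification for a 1D mesh. The class `ToyShardSpec` and its fields are hypothetical stand-ins invented for this sketch; they are not the DTensorSpec or ShardTensor classes.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyShardSpec:
    """Hypothetical stand-in for a sharding spec on a 1D device mesh."""

    shard_dim: int       # tensor dimension that is sharded
    local_shapes: tuple  # local tensor shape on each rank, in rank order

    def global_shape(self) -> tuple:
        """Global shape: local extents add up along the sharded dimension."""
        shape = list(self.local_shapes[0])
        shape[self.shard_dim] = sum(s[self.shard_dim] for s in self.local_shapes)
        return tuple(shape)

    def offset(self, rank: int) -> int:
        """Start index of a rank's shard along the sharded dimension."""
        return sum(s[self.shard_dim] for s in self.local_shapes[:rank])


# An unevenly partitioned point cloud: (num_points, 3) coordinates on 4 ranks.
spec = ToyShardSpec(
    shard_dim=0,
    local_shapes=((3120, 3), (2875, 3), (4410, 3), (1595, 3)),
)

print(spec.global_shape())
print([spec.offset(rank) for rank in range(4)])
```

With the per-rank shapes recorded, every rank can work out the global shape and where its own data sits without assuming an even split.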

At run time, when an operation in torch has a DTensor as input, PyTorch uses a custom dispatcher in DTensor to route and perform the operation correctly on the inputs. ShardTensor extends this by intercepting a little higher than DTensor: operations can be intercepted at the functional level or at the dispatch level, and if ShardTensor has no registered implementation it will fall back to DTensor.

ShardTensor also has dedicated implementations of the common reduction operations sum and mean, in order to properly intercept and distribute gradients. This is why, in the example above, you can seamlessly call mean().backward() on a ShardTensor and the gradients will arrive with their proper sharding. No need to do anything special - reducing a ShardTensor will handle this automatically.
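A small pure-Python sketch shows why reductions need care when shards are uneven: averaging the per-shard means gives the wrong answer, while combining per-shard sums and counts gives the right one, and the gradient of the mean with respect to every element is one over the global count. The function names are hypothetical, and this models only the arithmetic, not the collectives ShardTensor actually issues.

```python
def sharded_mean(shards):
    """Global mean from per-shard partial sums and counts (like a sum all-reduce)."""
    total = sum(sum(shard) for shard in shards)
    count = sum(len(shard) for shard in shards)
    return total / count


def naive_mean_of_means(shards):
    """Incorrect for uneven shards: weights every shard equally."""
    return sum(sum(shard) / len(shard) for shard in shards) / len(shards)


def mean_gradients(shards):
    """d(mean)/dx for every element, laid out with the same sharding as the input."""
    count = sum(len(shard) for shard in shards)
    return [[1.0 / count] * len(shard) for shard in shards]


# Two uneven shards of one logical tensor with 8 elements.
shards = [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [10.0, 20.0]]

print(sharded_mean(shards))         # matches the mean of all 8 values
print(naive_mean_of_means(shards))  # differs, because the shards are uneven
print(mean_gradients(shards))
```

The gradient comes back with the same layout as the input shards, which is the behavior the tutorial's `mean().backward()` example relies on.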