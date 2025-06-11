Let’s now look at a slightly more complicated example: the dot product of two vectors. Because the output is a single scalar, this operation requires communication between devices, and we’ll see how to implement that seamlessly with ShardTensor.

Here’s the single-device implementation. Note that the only difference from the previous example is the definition of f:

```python
def f(x, y):
    return torch.dot(x, y)
```

For reference, here’s the full code:

Example 1: Vector Dot Product, single device



```python
import time

import torch

# Make a really big tensor:
N = 1_000_000_000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

a = torch.randn(N, device=device)
b = torch.randn(N, device=device)


def f(a, b):
    # This is a truly non-local operation: a full reduction is needed.
    return torch.dot(a, b)


# Run a couple of times to warm up:
for i in range(5):
    c = f(a, b)

# Optional: benchmark it if you like.
# torch.cuda.synchronize() is only needed (and only works) on a CUDA device.
if device.type == "cuda":
    torch.cuda.synchronize()
start_time = time.time()
for i in range(10):
    c = f(a, b)
if device.type == "cuda":
    torch.cuda.synchronize()
end_time = time.time()

elapsed_time = end_time - start_time
print(f"Execution time for 10 runs: {elapsed_time:.4f} seconds")
```





If we make the same change to the distributed version and run it, we get an error (as of torch 2.6!):

```text
NotImplementedError: Operator aten.dot.default does not have a sharding strategy registered.
```

This is a good time to talk about how PyTorch decides what to do each time an operation is called on one or more torch.Tensor objects, which will lead us to the fix for this error.

PyTorch, as you likely already know, implements operations on multiple backends and with multiple execution paths. How does it decide which path to use when you call an operation on a tensor? The answer lies in PyTorch’s __torch_function__ and __torch_dispatch__ interfaces.

There are many resources that are more detailed and more precise than this overview (for example, see this blog post or this one, and especially the official walkthrough), but here is a high-level summary: function routing is based on the types of the inputs, rather than on the functions themselves. So when you call a function with an object (like ShardTensor) that extends the torch.Tensor interface, you can use __torch_function__ and __torch_dispatch__ to capture operations and reroute them to custom implementations.

For functions built into PyTorch, this is simply a matter of registering a pair of functions with ShardTensor: the function you want to intercept, and the function you want to route data to instead (which is used whenever at least one argument is a ShardTensor). We’ll see this in action below. For functions that torch does not know about (external functions, user functions, etc.), we can tap into this system manually.
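To make type-based routing concrete, here is a toy model in plain Python. It is not PyTorch’s implementation: ToySharded, register, dispatch, and toy_sharded_dot are hypothetical names invented for this illustration. It only mirrors the pattern described above: a registry of (function, handler) pairs that is consulted when any argument has the custom type, with a fallback to the ordinary function otherwise.

```python
from typing import Any, Callable, Dict, List, Tuple


def dot(x: List[float], y: List[float]) -> float:
    """Ordinary single-device dot product on plain Python lists."""
    return sum(a * b for a, b in zip(x, y))


class ToySharded:
    """Hypothetical stand-in for a tensor subclass: holds one list per 'rank'."""

    # Maps an intercepted function to the handler that replaces it.
    _handlers: Dict[Callable, Callable] = {}

    def __init__(self, shards: List[List[float]]):
        self.shards = shards

    @classmethod
    def register(cls, func: Callable, handler: Callable) -> None:
        cls._handlers[func] = handler


def dispatch(func: Callable, *args: Any, **kwargs: Any) -> Any:
    """Route on the *types* of the arguments, not on the function itself."""
    types: Tuple[type, ...] = tuple(type(arg) for arg in args)
    for t in types:
        handlers = getattr(t, "_handlers", {})
        if func in handlers:
            # Same (func, types, args, kwargs) shape as a __torch_function__ handler.
            return handlers[func](func, types, args, kwargs)
    # No custom type among the arguments: fall back to the normal path.
    return func(*args, **kwargs)


def toy_sharded_dot(func, types, args, kwargs):
    x, y = args[0], args[1]
    # Local dot product on each shard, then a sum standing in for all_reduce(SUM).
    return sum(func(xs, ys) for xs, ys in zip(x.shards, y.shards))


ToySharded.register(dot, toy_sharded_dot)

plain = dispatch(dot, [1.0, 2.0, 3.0], [4.0, 5.0, 6.0])
sharded = dispatch(dot, ToySharded([[1.0, 2.0], [3.0]]), ToySharded([[4.0, 5.0], [6.0]]))
print(plain, sharded)  # 32.0 32.0
```

The same call, dispatch(dot, ...), takes a different path depending only on the argument types. Real __torch_function__ handlers receive the same (func, types, args, kwargs) arguments, which is why the sharded_dot_product handler below has that signature.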

With all that in mind, let’s add a handler for torch.dot that works on PhysicsNeMo’s ShardTensor:

```python
import torch
import torch.distributed as dist

from physicsnemo.distributed import DistributedManager, scatter_tensor, ShardTensor
from torch.distributed.tensor.placement_types import Shard, Replicate


def sharded_dot_product(func, types, args, kwargs):
    # NOTE: all functions overloaded and used by __torch_function__ will have
    # the same input signature. You can use python argument unpacking to
    # extract what you need:
    def extract_args(x, y, *args, **kwargs):
        return x, y

    # Guard against kwargs being None (the __torch_function__ protocol allows it):
    x, y = extract_args(*args, **(kwargs or {}))

    # Each tensor has a _spec attribute, which contains information about the tensor's placement
    # and the devices it lives on:
    x_spec = x._spec
    y_spec = y._spec

    # It's usually good to ensure the tensor placements work:
    if not x_spec.placements == y_spec.placements:
        raise NotImplementedError("Tensors must have the same placements")
    if not x_spec.mesh == y_spec.mesh:
        raise NotImplementedError("Tensors must be sharded on the same mesh")
    # And, you might want to check placements are valid in more complex cases

    # Extract the mesh - we'll want it for the all reduce:
    mesh = x_spec.mesh

    # This is a straightforward implementation, for clarity
    # Get the local values of each tensor:
    local_x = x.to_local()
    local_y = y.to_local()

    # This is a purely single-gpu operation:
    local_dot_product = torch.dot(local_x, local_y)

    # If you wanted to write a generic sharding handler for this type of operation,
    # you could do:
    # local_dot_product = func(local_x, local_y)
    # But it's overkill here...

    # SUM-reduce the local result across all ranks:
    dist.all_reduce(local_dot_product, op=dist.ReduceOp.SUM, group=mesh.get_group())

    # We do want to return the result as a ShardTensor, for consistency.
    # We can easily create one on the same mesh as a "Replicated" tensor:
    output = ShardTensor.from_local(
        local_tensor=local_dot_product,
        device_mesh=mesh,
        placements=(Replicate(),),
    )
    return output


# Don't forget to register it with ShardTensor:
ShardTensor.register_function_handler(torch.dot, sharded_dot_product)
```
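The handler works because a dot product splits cleanly over aligned shards: if x and y are cut at the same boundaries into pieces x⁽ʳ⁾ and y⁽ʳ⁾, then x·y = Σᵣ x⁽ʳ⁾·y⁽ʳ⁾. That identity is also why the handler insists on matching placements and meshes. The plain-Python check below simulates it on one process; split_contiguous is a hypothetical helper, and the exact chunk sizes ShardTensor chooses may differ, but the identity holds for any split where both vectors share the same boundaries.

```python
import math
import random
from typing import List


def dot(x: List[float], y: List[float]) -> float:
    return sum(a * b for a, b in zip(x, y))


def split_contiguous(v: List[float], world_size: int) -> List[List[float]]:
    """Hypothetical helper: cut v into world_size contiguous chunks (sizes differ by at most one)."""
    base, extra = divmod(len(v), world_size)
    chunks, start = [], 0
    for rank in range(world_size):
        size = base + (1 if rank < extra else 0)
        chunks.append(v[start:start + size])
        start += size
    return chunks


random.seed(0)
N, world_size = 1003, 4
a = [random.gauss(0.0, 1.0) for _ in range(N)]
b = [random.gauss(0.0, 1.0) for _ in range(N)]

# Each simulated "rank" computes a dot product on its local shards only ...
local_results = [
    dot(xa, xb)
    for xa, xb in zip(split_contiguous(a, world_size), split_contiguous(b, world_size))
]

# ... and a SUM all_reduce leaves this same total on every rank (hence Replicate()).
reduced = sum(local_results)

print(math.isclose(reduced, dot(a, b), rel_tol=1e-9, abs_tol=1e-9))  # True
```

The two totals agree up to floating-point reassociation, which is also why a sharded result is compared to the single-device baseline with a tolerance (torch.allclose) rather than exact equality.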

Once you have registered a path for ShardTensor to do this computation, you can run the same code as before and it should work out of the box. For completeness, here’s the full code:

Example 1: Vector Dot Product, distributed computation



```python
import time
from typing import Callable, Dict, Tuple

import torch
import torch.distributed as dist

from physicsnemo.distributed import DistributedManager, scatter_tensor, ShardTensor
from torch.distributed.tensor.placement_types import Shard, Replicate


def sharded_dot_product(func: Callable, types: Tuple, args: Tuple, kwargs: Dict):
    """
    Overload for torch.dot to support sharded tensors.

    This function enables multi-GPU dot product operations on ShardTensors,
    by computing the dot product locally on each rank and then summing across
    all GPUs. Requires the placements and mesh to agree across the two tensors.

    This is tutorial code: it does not handle all cases and you should not
    use it in production.

    Note the function signature: we are using this function in the
    __torch_function__ protocol and it has to follow the specific signature
    requirements.

    Args:
        func (Callable): The function to overload (e.g., torch.dot).
        types (Tuple): Tuple of types passed by __torch_function__ protocol.
        args (Tuple): Positional arguments passed to the function.
        kwargs (Dict): Keyword arguments passed to the function.

    In general, torch will use the values in `types` to determine which path
    of execution to take. In this function, we don't have to worry about that
    as much because it's already selected for execution.
    """
    # NOTE: all functions overloaded and used by __torch_function__ will have
    # the same input signature. You can use python argument unpacking to
    # extract what you need:
    def extract_args(x, y, *args, **kwargs):
        return x, y

    # Guard against kwargs being None (the __torch_function__ protocol allows it):
    x, y = extract_args(*args, **(kwargs or {}))

    # Each tensor has a _spec attribute, which contains information about the tensor's placement
    # and the devices it lives on:
    x_spec = x._spec
    y_spec = y._spec

    # It's usually good to ensure the tensor placements work:
    if not x_spec.placements == y_spec.placements:
        raise NotImplementedError("Tensors must have the same placements")
    if not x_spec.mesh == y_spec.mesh:
        raise NotImplementedError("Tensors must be sharded on the same mesh")
    # And, you might want to check placements are valid in more complex cases

    # Extract the mesh - we'll want it for the all reduce:
    mesh = x_spec.mesh

    # This is a straightforward implementation, for clarity
    # Get the local values of each tensor:
    local_x = x.to_local()
    local_y = y.to_local()

    # This is a purely single-gpu operation:
    local_dot_product = torch.dot(local_x, local_y)

    # If you wanted to write a generic sharding handler for this type of operation,
    # you could do:
    # local_dot_product = func(local_x, local_y)
    # But it's overkill here...

    # SUM-reduce the local result across all ranks:
    dist.all_reduce(local_dot_product, op=dist.ReduceOp.SUM, group=mesh.get_group())

    # We do want to return the result as a ShardTensor, for consistency.
    # We can easily create one on the same mesh as a "Replicated" tensor.
    # The output placements are now Replicated, not sharded. We have used all_reduce
    # to sum the local results across all ranks, and each rank has the full data -
    # exactly what the Replicate() placement expects.
    # (Even though it's a scalar output, we still have to specify a placement)
    output = ShardTensor.from_local(
        local_tensor=local_dot_product,
        device_mesh=mesh,
        placements=(Replicate(),),
    )
    return output


# Register the implementation with ShardTensor's function dispatch:
ShardTensor.register_function_handler(torch.dot, sharded_dot_product)

# Another really big tensor:
N = 1_000_000_000

DistributedManager.initialize()
dm = DistributedManager()
device = dm.device

a = torch.randn(N, device=device)
b = torch.randn(N, device=device)


def f(x, y):
    return torch.dot(x, y)


# Get the baseline result
c_baseline = f(a, b)

# DeviceMesh is a pytorch object - you can initialize it directly, or for added
# flexibility physicsnemo can infer up to one mesh dimension for you
# (as a -1, like in a tensor.reshape() call...)
mesh = dm.initialize_mesh(mesh_shape=[-1,], mesh_dim_names=["domain"])

# Shard(i) indicates we want the final tensor to be sharded along the tensor dimension i.
# The placements argument is a tuple or list with one entry per mesh dimension.
placements = (Shard(0),)

# This function will distribute the tensor from global_src to the specified mesh,
# using the input placements.
# Note that in multi-level parallelism, the source is the _global_ rank not the mesh group rank.
a_sharded = scatter_tensor(tensor=a, global_src=0, mesh=mesh, placements=placements)
b_sharded = scatter_tensor(tensor=b, global_src=0, mesh=mesh, placements=placements)

c_sharded = f(a_sharded, b_sharded)

# Comparison requires that we coalesce the results:
c_sharded = c_sharded.full_tensor()

# Now, performance measurement:
# Warm up:
for i in range(5):
    c = f(a_sharded, b_sharded)

# Measure execution time (synchronize only when running on a CUDA device):
if device.type == "cuda":
    torch.cuda.synchronize()
start_time = time.time()
for i in range(10):
    c = f(a_sharded, b_sharded)
if device.type == "cuda":
    torch.cuda.synchronize()
end_time = time.time()
elapsed_time = end_time - start_time

if dm.rank == 0:
    print(f"Rank {dm.rank}, Tensor agreement? {torch.allclose(c_baseline, c_sharded)}")
    print(f"Execution time for 10 runs: {elapsed_time:.4f} seconds")
```



