Date: 2025-06-11

The preamble to the training script includes an important patch that makes the conv2d operation work with ShardTensor:

    import torch

    # This is necessary to patch Conv2d to work with ShardTensor
    from physicsnemo.distributed.shard_utils import patch_operations
    import torch.nn as nn

    from physicsnemo.distributed import DistributedManager
    from physicsnemo.distributed.shard_tensor import ShardTensor
    from torch.distributed.tensor import distribute_module, distribute_tensor
    from torch.distributed.tensor.placement_types import Shard, Replicate

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

Next, set up the distributed environment, including the device mesh. Here we do it globally, but you can also do it locally and pass device_mesh objects around.
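
As a rough illustration of what a global mesh setup involves, here is a minimal sketch that uses plain PyTorch (`init_device_mesh`) rather than `DistributedManager`, whose exact calls are not shown in this section. The helper names `resolve_mesh_shape` and `build_global_mesh`, the 1D mesh shape `(-1,)`, and the dimension name `"domain"` are all assumptions made for the example, and the script is assumed to be launched with `torchrun` so that the usual rank and world-size environment variables are set.

```python
import os


def resolve_mesh_shape(world_size, mesh_shape):
    """Hypothetical helper: replace a single -1 entry in mesh_shape with the
    size that makes the product of all dimensions equal world_size."""
    shape = list(mesh_shape)
    if shape.count(-1) > 1:
        raise ValueError("at most one mesh dimension may be -1")
    known = 1
    for size in shape:
        if size == -1:
            continue
        if size <= 0:
            raise ValueError(f"invalid mesh dimension: {size}")
        known *= size
    if world_size % known != 0:
        raise ValueError(f"world size {world_size} is not divisible by {known}")
    if -1 in shape:
        shape[shape.index(-1)] = world_size // known
    elif known != world_size:
        raise ValueError(f"mesh {tuple(shape)} does not cover {world_size} ranks")
    return tuple(shape)


def build_global_mesh(mesh_shape=(-1,), mesh_dim_names=("domain",)):
    """Hypothetical helper: create one global DeviceMesh with plain PyTorch.

    Assumes a torchrun-style launch (RANK, WORLD_SIZE, LOCAL_RANK,
    MASTER_ADDR and MASTER_PORT are set in the environment).
    """
    # Imported lazily so resolve_mesh_shape can be used without PyTorch.
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        # Bind this process to its local GPU before creating the process group.
        torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl" if use_cuda else "gloo")

    shape = resolve_mesh_shape(dist.get_world_size(), mesh_shape)
    return init_device_mesh(
        "cuda" if use_cuda else "cpu",
        shape,
        mesh_dim_names=tuple(mesh_dim_names),
    )
```

Calling `build_global_mesh()` once near the top of the script gives a module-level mesh that the rest of the code can reference. The local alternative would be to call it inside a function and pass the returned mesh object to whatever needs it.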