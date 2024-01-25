The distributed utilities in Modulus simplify the implementation of parallel training and inference scripts by providing a unified way to configure and query the parameters of the distributed environment. The utilities in modulus.distributed build on top of those in torch.distributed and abstract away some of the complexity of setting up a distributed execution environment.
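To make "detecting the distributed environment" more concrete, here is a minimal sketch of how a launcher-agnostic helper could read per-process rank information from environment variables. This is not the Modulus implementation: DistEnv and detect_dist_env are hypothetical names. The sketch only assumes the standard variables each launcher exports: RANK, WORLD_SIZE and LOCAL_RANK for torchrun; OMPI_COMM_WORLD_RANK, OMPI_COMM_WORLD_SIZE and OMPI_COMM_WORLD_LOCAL_RANK for OpenMPI; and SLURM_PROCID, SLURM_NTASKS and SLURM_LOCALID for SLURM.

```python
import os
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class DistEnv:
    """Hypothetical container for this process's distributed parameters."""
    rank: int
    world_size: int
    local_rank: int
    launcher: str

    @property
    def distributed(self) -> bool:
        # More than one process means gradients must be synchronized.
        return self.world_size > 1


# (rank, world size, local rank) variable names per launcher, in priority order.
_LAUNCHER_VARS = (
    ("torchrun", ("RANK", "WORLD_SIZE", "LOCAL_RANK")),
    ("openmpi", ("OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE",
                 "OMPI_COMM_WORLD_LOCAL_RANK")),
    ("slurm", ("SLURM_PROCID", "SLURM_NTASKS", "SLURM_LOCALID")),
)


def detect_dist_env(environ: Mapping[str, str] = os.environ) -> DistEnv:
    """Return rank info from the first launcher whose variables are all set.

    Falls back to a single process (rank 0 of 1), which is what a plain
    `python train.py` launch looks like.
    """
    for launcher, (rank_var, size_var, local_var) in _LAUNCHER_VARS:
        if all(name in environ for name in (rank_var, size_var, local_var)):
            return DistEnv(
                rank=int(environ[rank_var]),
                world_size=int(environ[size_var]),
                local_rank=int(environ[local_var]),
                launcher=launcher,
            )
    return DistEnv(rank=0, world_size=1, local_rank=0, launcher="none")
```

For example, detect_dist_env({"RANK": "3", "WORLD_SIZE": "8", "LOCAL_RANK": "3"}) reports rank 3 of 8 under torchrun. torchrun is checked first because a torchrun job started inside a SLURM allocation also inherits SLURM variables, and in that case the torchrun values describe the actual worker processes.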

The example below shows how to set up a simple distributed data parallel training recipe using the distributed utilities in Modulus. DistributedDataParallel in PyTorch provides the framework for data parallel training by reducing parameter gradients across multiple worker processes during the backward pass. The code below shows how to set the device_ids, output_device, broadcast_buffers and find_unused_parameters arguments of DistributedDataParallel using the DistributedManager.
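The gradient reduction that DistributedDataParallel performs can be illustrated without GPUs. The pure-Python sketch below simulates several workers fitting a scalar model y ≈ w·x with a mean-squared-error loss. Each worker computes the gradient on its own equal-sized shard of the batch, and the per-worker gradients are then averaged, which is the role the all-reduce plays. The result matches the full-batch gradient. The helper names are hypothetical. Real DDP averages tensors bucket by bucket through a communication backend such as NCCL or Gloo, not Python lists.

```python
from typing import List, Sequence


def mse_grad(w: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """Derivative with respect to w of mean((w*x - y)^2) over one shard."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values: List[float]) -> float:
    """Stand-in for an all-reduce followed by division by the world size."""
    return sum(values) / len(values)


def data_parallel_grad(w: float, xs: Sequence[float], ys: Sequence[float],
                       world_size: int) -> float:
    """Give each simulated worker a contiguous, equal-sized shard."""
    if len(xs) % world_size != 0:
        raise ValueError("equal shard sizes are needed for an exact average")
    shard = len(xs) // world_size
    local_grads = [
        mse_grad(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
        for r in range(world_size)
    ]
    return all_reduce_mean(local_grads)


xs = [float(i) for i in range(8)]
ys = [2.0 * x + 1.0 for x in xs]
full = mse_grad(0.5, xs, ys)
parallel = data_parallel_grad(0.5, xs, ys, world_size=4)
print(full, parallel)  # -59.5 -59.5
```

Because every worker then applies the same averaged gradient, the model replicas stay identical without ever exchanging the parameters themselves.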

```python
import torch
from torch.nn.parallel import DistributedDataParallel

from modulus.distributed import DistributedManager
from modulus.models.mlp.fully_connected import FullyConnected


def main():
    # Initialize the DistributedManager. This will automatically
    # detect the number of processes the job was launched with and
    # set those configuration parameters appropriately. Currently
    # torchrun (or any other PyTorch-compatible launcher), mpirun (OpenMPI)
    # and SLURM-based launchers are supported.
    DistributedManager.initialize()

    # Since this is a singleton class, you can get an instance
    # of it at any time after initialization without
    # reinitializing it.
    dist = DistributedManager()

    # Set up the model on the appropriate device. DistributedManager
    # figures out which device should be used by this process.
    arch = FullyConnected(in_features=32, out_features=64).to(dist.device)

    # Set up DistributedDataParallel if using more than a single process.
    # The `distributed` property of DistributedManager can be used to
    # check this.
    if dist.distributed:
        # Wrap the model inside a separate CUDA stream, then make the
        # current stream wait for that work to finish.
        ddps = torch.cuda.Stream()
        with torch.cuda.stream(ddps):
            arch = DistributedDataParallel(
                arch,
                # Use the local rank of this process on this node as its device.
                device_ids=[dist.local_rank],
                output_device=dist.device,
                broadcast_buffers=dist.broadcast_buffers,
                find_unused_parameters=dist.find_unused_parameters,
            )
        torch.cuda.current_stream().wait_stream(ddps)

    # Set up the optimizer
    optimizer = torch.optim.Adam(
        arch.parameters(),
        lr=0.001,
    )

    def training_step(invar, target):
        optimizer.zero_grad()  # Clear gradients left over from the previous step
        pred = arch(invar)
        loss = torch.sum(torch.pow(pred - target, 2))
        loss.backward()  # With DDP, gradients are averaged across processes here
        optimizer.step()
        return loss

    # Sample training loop
    for i in range(20):
        # Random inputs and targets for simplicity
        invar = torch.randn(128, 32, device=dist.device)
        target = torch.randn(128, 64, device=dist.device)

        # Training step
        loss = training_step(invar, target)


if __name__ == "__main__":
    main()
```
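The training loop in the example draws random tensors on every rank for simplicity. In a real data parallel job each rank should read a different part of the dataset. In PyTorch this is usually handled by torch.utils.data.distributed.DistributedSampler. The pure-Python sketch below mimics that sampler with shuffle=False and drop_last=False. The index list is padded by wrapping around to its start until it divides evenly, and rank r then takes every world_size-th index starting at r. The function name shard_indices is hypothetical, and the sketch leaves out shuffling and per-epoch seeding.

```python
import math
from typing import List


def shard_indices(dataset_len: int, world_size: int, rank: int) -> List[int]:
    """Indices read by `rank`, mimicking DistributedSampler(shuffle=False, drop_last=False)."""
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} out of range for world_size {world_size}")
    indices = list(range(dataset_len))
    num_samples = math.ceil(dataset_len / world_size)  # samples per rank
    total_size = num_samples * world_size
    padding = total_size - dataset_len
    if padding:
        # Repeat indices from the start so every rank gets the same count.
        indices += (indices * math.ceil(padding / dataset_len))[:padding]
    # Strided split: rank r takes positions r, r + world_size, ...
    return indices[rank:total_size:world_size]


for r in range(4):
    print(r, shard_indices(10, world_size=4, rank=r))
```

With 10 samples on 4 ranks, each rank receives 3 indices and samples 0 and 1 are read twice per epoch. Keeping the per-rank count equal matters because every rank must run the same number of training steps, or the gradient all-reduce in one rank will wait on ranks that have already finished.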

This training script can be run on a single GPU using python train.py or on multiple GPUs using

```bash
torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> train.py
```

or

```bash
mpirun -np <num_gpus> python train.py
```

if using OpenMPI. The script can also be run on a SLURM cluster using