Simple Training and Inference recipe

In this tutorial, we will see how to use utilities from Modulus to set up a simple model training pipeline. Once the initial setup is complete, we will look at optimizing the training loop and running it in a distributed fashion. We will finish the tutorial with an inference workflow that demonstrates how to use Modulus models in inference.

Let’s get started. For the purposes of this tutorial, we will focus on the Modulus utilities rather than the correctness of the problem definition or the results. A typical training workflow requires data, a trainable model, and an optimizer to update the model parameters.

Using built-in models

In this example, we will look at different ways one can interact with models in Modulus. Modulus presents a library of models suitable for Physics-ML applications for you to use directly in your training workflows. In this tutorial we will see how to use a simple model in Modulus to set up data-driven training. Using the models from Modulus enables various other Modulus features, such as performance optimizations and quality-of-life functionalities like checkpointing and model entrypoints.

Later we will also see how to customize these models in Modulus.

In this example we will use the FNO model from Modulus. To demonstrate training with this model, we need a dataset. To allow for fast prototyping of models, Modulus provides a set of benchmark datasets that can be used out of the box without setting up data-loading pipelines. In this example, we will use one such datapipe, Darcy2D, to get the training data.

Let’s start by importing a few utilities and packages.


import torch

import modulus
from modulus.datapipes.benchmarks.darcy import Darcy2D
from modulus.metrics.general.mse import mse
from modulus.models.fno.fno import FNO

In this example we want to learn a mapping from the permeability field to the corresponding pressure field for a given forcing function. Refer to Modulus Datapipes for additional details.
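Before wiring up a model, it can help to pull a single batch from the datapipe and inspect it. Below is a minimal sketch assuming the same Darcy2D arguments used later in this tutorial; the exact tensor shapes depend on the resolution and batch size you choose.

from modulus.datapipes.benchmarks.darcy import Darcy2D

normaliser = {
    "permeability": (1.25, 0.75),
    "darcy": (4.52e-2, 2.79e-2),
}
dataloader = Darcy2D(
    resolution=256, batch_size=64, nr_permeability_freq=5, normaliser=normaliser
)

# Pull one batch and inspect the input (permeability) and
# target (darcy pressure) tensors.
batch = next(iter(dataloader))
print(batch["permeability"].shape)  # expected: [64, 1, 256, 256]
print(batch["darcy"].shape)         # expected: [64, 1, 256, 256]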

Then a simple training loop for this example can be written as follows:


normaliser = {
    "permeability": (1.25, 0.75),
    "darcy": (4.52e-2, 2.79e-2),
}
dataloader = Darcy2D(
    resolution=256, batch_size=64, nr_permeability_freq=5, normaliser=normaliser
)

model = FNO(
    in_channels=1,
    out_channels=1,
    decoder_layers=1,
    decoder_layer_size=32,
    dimension=2,
    latent_channels=32,
    num_fno_layers=4,
    num_fno_modes=12,
    padding=5,
).to("cuda")

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: 0.85**step
)

# run for 20 iterations
for i in range(20):
    batch = next(iter(dataloader))
    true = batch["darcy"]
    pred = model(batch["permeability"])
    loss = mse(pred, true)
    # clear gradients from the previous iteration before backprop
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    print(f"Iteration: {i}. Loss: {loss.detach().cpu().numpy()}")

That’s it! This shows how to use a model from Modulus. Most of the models in Modulus are highly configurable, allowing you to use them out of the box for different applications. Refer to Modulus Models for a more complete list of available models.
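For instance, the same FNO class can be reconfigured for a different problem simply by changing its constructor arguments. The sketch below is a hypothetical variant: it reuses the argument names from the block above but widens the network for a multi-channel 2D problem.

# Hypothetical wider FNO for a multi-channel 2D problem; only the
# constructor arguments change, the training loop stays the same.
wide_model = FNO(
    in_channels=3,   # e.g. three input fields
    out_channels=2,  # e.g. two output fields
    decoder_layers=2,
    decoder_layer_size=64,
    dimension=2,
    latent_channels=64,
    num_fno_layers=6,
    num_fno_modes=16,
    padding=5,
).to("cuda")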

Using custom models in Modulus

Modulus provides many pre-built, optimized models. However, there might be times when the shipped models do not serve your application. In such cases, you can easily write your own models and have them interact with the other Modulus utilities and features. Modulus uses PyTorch under the hood, and most Modulus models are, at their core, PyTorch models. In this section we will see how to go from a typical PyTorch model to a Modulus model.

Let’s stick with the same Darcy problem and write a simple UNet to solve it. A simple PyTorch model for a UNet can be written as shown below:


import torch.nn as nn

import modulus
from modulus.datapipes.benchmarks.darcy import Darcy2D
from modulus.metrics.general.mse import mse


class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super(UNet, self).__init__()

        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)

        self.dec1 = self.upconv_block(128, 64)
        self.dec2 = self.upconv_block(64, 32)

        self.final = nn.Conv2d(32, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        # 3x3 convolution + ReLU followed by 2x downsampling
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )

    def upconv_block(self, in_channels, out_channels):
        # 2x upsampling followed by 3x3 convolution + ReLU
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 2, stride=2),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.enc1(x)
        x = self.enc2(x)
        x = self.dec1(x)
        x = self.dec2(x)
        return self.final(x)

Let’s now convert this to a Modulus model. Modulus provides a Module class that is designed to be a drop-in replacement for torch.nn.Module. Along with that, you also need to pass a MetaData instance that captures the optimizations and other features supported by the model. Subclassing Module allows you to use these optimizations, as well as other features like checkpointing, from Modulus.

Thus, converting a PyTorch model to a Modulus model is very simple. For the above model, the diff looks like this:


  import torch.nn as nn
+ from dataclasses import dataclass
+
+ from modulus.models.meta import ModelMetaData
+ from modulus.models.module import Module

- class UNet(nn.Module):
+ @dataclass
+ class MetaData(ModelMetaData):
+     name: str = "UNet"
+     # Optimization
+     jit: bool = False
+     cuda_graphs: bool = True
+     amp_cpu: bool = True
+     amp_gpu: bool = True
+
+ class UNet(Module):
      def __init__(self, in_channels=1, out_channels=1):
-         super(UNet, self).__init__()
+         super(UNet, self).__init__(meta=MetaData())

          self.enc1 = self.conv_block(in_channels, 64)
          self.enc2 = self.conv_block(64, 128)

With simple changes like this, you can convert a PyTorch model into a Modulus model!

Note

The optimizations are not applied automatically. The user is responsible for writing the model such that it supports these optimizations. However, if the model supports an optimization and that support is captured in the MetaData, then the downstream features will work out of the box.

Note

To use the checkpointing functionality of Modulus, the model’s instantiation arguments must be JSON-serializable.
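In practice, this means the constructor should take plain Python values (numbers, strings, booleans, lists, dicts) rather than live objects. A hedged sketch of the distinction, reusing the MetaData dataclass defined above:

from modulus.models.module import Module


class GoodUNet(Module):
    # OK: plain ints are JSON-serializable, so a Modulus checkpoint can
    # record how to re-instantiate this model.
    def __init__(self, in_channels: int = 1, out_channels: int = 1):
        super().__init__(meta=MetaData())  # MetaData as defined above
        ...


class BadUNet(Module):
    # Problematic: a live nn.Module argument cannot be serialized to
    # JSON, so re-instantiation from a checkpoint would fail.
    def __init__(self, encoder: "nn.Module"):
        super().__init__(meta=MetaData())
        self.encoder = encoder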

You can also use a Modulus model as a standard PyTorch model, as they are interoperable.
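For example, the converted UNet can be dropped straight into ordinary PyTorch code. A minimal sketch, assuming the UNet defined above:

import torch

# The Modulus UNet behaves like any torch.nn.Module: forward passes,
# state_dict access, and standard optimizers all work unchanged.
model = UNet(in_channels=1, out_channels=1).to("cuda")
x = torch.randn(4, 1, 64, 64, device="cuda")
y = model(x)
print(y.shape)  # the pool/upsample stages restore the input resolution

# Plain PyTorch checkpointing works alongside Modulus checkpointing.
torch.save(model.state_dict(), "unet_weights.pt")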

Let’s say you don’t want to make changes to the code, but you already have a PyTorch model. You can convert it to a Modulus model using the modulus.Module.from_torch method. This is described in detail in Converting PyTorch Models to Modulus Models.


from dataclasses import dataclass

import torch.nn as nn

from modulus.models.meta import ModelMetaData
from modulus.models.module import Module


@dataclass
class MdlsUNetMetaData(ModelMetaData):
    name: str = "MdlsUNet"
    # Optimization
    jit: bool = False
    cuda_graphs: bool = True
    amp_cpu: bool = True
    amp_gpu: bool = True


MdlsUNet = Module.from_torch(UNet, meta=MdlsUNetMetaData)
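The generated class can then be instantiated with the same arguments as the original UNet. A short sketch; the instance is a regular Modulus Module:

# Instantiate the generated class just like the original UNet; the
# instance now carries the MetaData above and supports Modulus
# features such as checkpointing.
model = MdlsUNet(in_channels=1, out_channels=1).to("cuda")
print(isinstance(model, Module))  # True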

And just like that, you can use your existing PyTorch model as a Modulus model. A very similar process can be followed to convert a Modulus model to a Modulus Sym model so that you can use the Constraints and other definitions from the Modulus Sym repository. Here you will use the Arch class from Modulus Sym, which provides utilities and methods to go from tensor data to the dict format that Modulus Sym uses.


from typing import Dict, Optional

import torch

from modulus.sym.key import Key
from modulus.sym.models.arch import Arch


class MdlsSymUNet(Arch):
    def __init__(
        self,
        input_keys=[Key("a")],
        output_keys=[Key("b")],
        in_channels=1,
        out_channels=1,
    ):
        super(MdlsSymUNet, self).__init__(
            input_keys=input_keys, output_keys=output_keys
        )
        self.mdls_model = MdlsUNet(in_channels, out_channels)  # MdlsUNet defined above

    def forward(self, dict_tensor: Dict[str, torch.Tensor]):
        # concatenate the dict inputs into a single tensor along the channel dim
        x = self.concat_input(
            dict_tensor,
            self.input_key_dict,
            detach_dict=None,
            dim=1,
        )
        out = self.mdls_model(x)
        # split the output tensor back into a dict keyed by the output Keys
        return self.split_output(out, self.output_key_dict, dim=1)
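A quick way to sanity-check the wrapper is to call it with the dictionary format Modulus Sym expects. A hedged sketch using the default keys from the constructor above:

model = MdlsSymUNet(
    input_keys=[Key("a")], output_keys=[Key("b")], in_channels=1, out_channels=1
).to("cuda")

# Modulus Sym models consume and produce dictionaries of tensors,
# keyed by the Keys declared in the constructor.
invar = {"a": torch.randn(4, 1, 64, 64, device="cuda")}
outvar = model(invar)
print(outvar["b"].shape)  # expected: [4, 1, 64, 64]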

Once we have a model defined in the Modulus style, we can apply optimizations like AMP, CUDA Graphs, and JIT using the modulus.utils.StaticCaptureTraining decorator. This decorator captures the training step function and optimizes it using the features enabled in the model’s MetaData.

Note

The StaticCaptureTraining decorator is still under development and may be refactored in the future.


from modulus.utils import StaticCaptureTraining

normaliser = {
    "permeability": (1.25, 0.75),
    "darcy": (4.52e-2, 2.79e-2),
}
dataloader = Darcy2D(
    resolution=256, batch_size=8, nr_permeability_freq=5, normaliser=normaliser
)

model = MdlsUNet().to("cuda")

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: 0.85**step
)

# Create training step function with optimization wrapper
# StaticCaptureTraining calls `backward` on the loss and
# `optimizer.step()` so you don't have to do that
# explicitly.
@StaticCaptureTraining(
    model=model,
    optim=optimizer,
    cuda_graph_warmup=11,
)
def training_step(invar, outvar):
    predvar = model(invar)
    loss = mse(predvar, outvar)
    return loss

# run for 20 iterations
for i in range(20):
    batch = next(iter(dataloader))
    true = batch["darcy"]
    input = batch["permeability"]
    loss = training_step(input, true)
    scheduler.step()

Modulus has several distributed utilities that simplify the implementation of parallel training and inference scripts by providing a unified way to configure and query the parameters associated with the distributed environment.

In this example, we will see how to convert our existing workflow to use data parallelism. For a deep dive into the Modulus distributed utilities, refer to Modulus Distributed.


import torch
from torch.nn.parallel import DistributedDataParallel

from modulus.datapipes.benchmarks.darcy import Darcy2D
from modulus.distributed import DistributedManager
from modulus.metrics.general.mse import mse
from modulus.models.fno.fno import FNO
from modulus.utils import StaticCaptureTraining


def main():
    # Initialize the DistributedManager. This will automatically
    # detect the number of processes the job was launched with and
    # set those configuration parameters appropriately.
    DistributedManager.initialize()

    # Get instance of the DistributedManager
    dist = DistributedManager()

    normaliser = {
        "permeability": (1.25, 0.75),
        "darcy": (4.52e-2, 2.79e-2),
    }
    dataloader = Darcy2D(
        resolution=256, batch_size=64, nr_permeability_freq=5, normaliser=normaliser
    )
    model = FNO(
        in_channels=1,
        out_channels=1,
        decoder_layers=1,
        decoder_layer_size=32,
        dimension=2,
        latent_channels=32,
        num_fno_layers=4,
        num_fno_modes=12,
        padding=5,
    ).to(dist.device)

    # Set up DistributedDataParallel if using more than a single process.
    if dist.distributed:
        ddps = torch.cuda.Stream()
        with torch.cuda.stream(ddps):
            model = DistributedDataParallel(
                model,
                # Set the device_id to be the local rank of this process on this node
                device_ids=[dist.local_rank],
                output_device=dist.device,
                broadcast_buffers=dist.broadcast_buffers,
                find_unused_parameters=dist.find_unused_parameters,
            )
        torch.cuda.current_stream().wait_stream(ddps)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda step: 0.85**step
    )

    # Create training step function with optimization wrapper
    # StaticCaptureTraining calls `backward` on the loss and
    # `optimizer.step()` so you don't have to do that
    # explicitly.
    @StaticCaptureTraining(
        model=model,
        optim=optimizer,
        cuda_graph_warmup=11,
    )
    def training_step(invar, outvar):
        predvar = model(invar)
        loss = mse(predvar, outvar)
        return loss

    # run for 20 iterations
    for i in range(20):
        batch = next(iter(dataloader))
        true = batch["darcy"]
        input = batch["permeability"]
        loss = training_step(input, true)
        scheduler.step()


if __name__ == "__main__":
    main()
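The script is meant to be launched with a standard PyTorch launcher so that DistributedManager can detect the process count from the environment. For example, on a single node with two GPUs (the script name here is hypothetical):

torchrun --nproc_per_node=2 train_darcy_distributed.py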

Running inference on a trained model is simple, as the code below shows.


model = FNO(
    in_channels=1,
    out_channels=1,
    decoder_layers=1,
    decoder_layer_size=32,
    dimension=2,
    latent_channels=32,
    num_fno_layers=4,
    num_fno_modes=12,
    padding=5,
).to("cuda")

# Save the checkpoint. For this demo, we save an untrained checkpoint,
# but in typical workflows the checkpoint is saved after model training.
model.save("untrained_checkpoint.mdlus")

# Inference code
# The parameters to instantiate the model are loaded from the checkpoint
model_inf = modulus.Module.from_checkpoint("untrained_checkpoint.mdlus").to("cuda")

# put the model in evaluation mode
model_inf.eval()

# run inference
with torch.inference_mode():
    input = torch.ones(8, 1, 256, 256).to("cuda")
    output = model_inf(input)
    print(output.shape)

The static capture and distributed utilities can also be used to speed up inference workflows, but that is out of scope for this tutorial.
