Simple Training and Inference recipe
In this tutorial, we will see how to use utilities from Modulus to set up a simple model training pipeline. Once the initial setup is complete, we will look into optimizing the training loop and running it in a distributed fashion. We will finish the tutorial with an inference workflow that demonstrates how to use Modulus models for inference.
Let’s get started. For the purposes of this tutorial, we will focus on the Modulus utilities rather than the correctness of the problem definition or the results. A typical training workflow requires data, a trainable model, and an optimizer to update the model parameters.
Using built-in models
In this example, we will look at the different ways you can interact with models in Modulus. Modulus provides a library of models suitable for Physics-ML applications that you can use directly in your training workflows. In this tutorial we will see how to use a simple model from Modulus to set up data-driven training. Using models from Modulus enables various other Modulus features, such as optimizations and quality-of-life functionalities like checkpointing and model entrypoints.
Later we will also see how to customize these models in Modulus.
In this example we will use the FNO model from Modulus. To demonstrate training with this model, we need a dataset. To allow for fast prototyping of models, Modulus provides a set of benchmark datasets that can be used out of the box without the need to set up data-loading pipelines. In this example, we will use one such datapipe, Darcy2D, to get the training data.
Let’s start by importing a few utilities and packages.
import torch
import modulus
from modulus.datapipes.benchmarks.darcy import Darcy2D
from modulus.metrics.general.mse import mse
from modulus.models.fno.fno import FNO
In this example we want to learn a mapping from the permeability field to the corresponding pressure field for a given forcing function. Refer to Modulus Datapipes for additional details.
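To get a feel for what the Darcy2D datapipe produces, you can pull a single batch and inspect it. The sketch below uses the same normaliser values as the training loop later in this tutorial; the exact tensor shapes are an assumption and depend on the resolution and batch_size you choose.
# A quick look at what the Darcy2D datapipe yields (a sketch; exact shapes
# depend on the `resolution` and `batch_size` arguments).
normaliser = {
    "permeability": (1.25, 0.75),
    "darcy": (4.52e-2, 2.79e-2),
}
dataloader = Darcy2D(
    resolution=256, batch_size=64, nr_permeability_freq=5, normaliser=normaliser
)
batch = next(iter(dataloader))
print(batch.keys())                 # "permeability" (input) and "darcy" (target)
print(batch["permeability"].shape)  # e.g. torch.Size([64, 1, 256, 256])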
Then a simple training loop for this example can be written as follows:
normaliser = {
    "permeability": (1.25, 0.75),
    "darcy": (4.52e-2, 2.79e-2),
}
dataloader = Darcy2D(
    resolution=256, batch_size=64, nr_permeability_freq=5, normaliser=normaliser
)
model = FNO(
    in_channels=1,
    out_channels=1,
    decoder_layers=1,
    decoder_layer_size=32,
    dimension=2,
    latent_channels=32,
    num_fno_layers=4,
    num_fno_modes=12,
    padding=5,
).to("cuda")

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: 0.85**step
)

# run for 20 iterations
for i in range(20):
    optimizer.zero_grad()  # clear gradients from the previous iteration
    batch = next(iter(dataloader))
    true = batch["darcy"]
    pred = model(batch["permeability"])
    loss = mse(pred, true)
    loss.backward()
    optimizer.step()
    scheduler.step()
    print(f"Iteration: {i}. Loss: {loss.detach().cpu().numpy()}")
That’s it! This shows how to use a model from Modulus. Most of the models in Modulus are highly configurable, allowing you to use them out-of-the-box for different applications. Refer to Modulus Models for a more complete list of available models.
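As an illustration of that configurability, the same FNO class can be set up for other problems just by changing its constructor arguments. The values below are arbitrary and only meant to show the available knobs; they are not tuned for any particular dataset.
# Hypothetical configuration of the same FNO class for a 3D, two-channel input
# problem (argument values are illustrative only).
model_3d = FNO(
    in_channels=2,
    out_channels=1,
    decoder_layers=2,
    decoder_layer_size=64,
    dimension=3,
    latent_channels=32,
    num_fno_layers=4,
    num_fno_modes=16,
    padding=4,
).to("cuda")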
Using custom models in Modulus
Modulus provides many pre-built, optimized models. However, there may be times when the shipped models do not serve your application. In such cases, you can easily write your own models and have them interact with the other Modulus utilities and features. Modulus uses PyTorch under the hood, and most Modulus models are, at their core, PyTorch models. In this section we will see how to go from a typical PyTorch model to a Modulus model.
Let’s stay with the same Darcy problem and write a simple UNet to solve it. A simple PyTorch UNet can be written as shown below:
import torch.nn as nn

import modulus
from modulus.datapipes.benchmarks.darcy import Darcy2D
from modulus.metrics.general.mse import mse


class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super(UNet, self).__init__()
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.dec1 = self.upconv_block(128, 64)
        self.dec2 = self.upconv_block(64, 32)
        self.final = nn.Conv2d(32, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )

    def upconv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 2, stride=2),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.enc1(x)
        x = self.enc2(x)
        x = self.dec1(x)
        x = self.dec2(x)
        return self.final(x)
Let’s now convert this to a Modulus model. Modulus provides a Module class designed as a drop-in replacement for torch.nn.Module. Along with it, you also need to pass a MetaData object that captures the optimizations and other features the model supports. Using the Module subclass lets you take advantage of these optimizations and other Modulus features such as checkpointing.
Thus, converting a PyTorch model to a Modulus model is very simple. For the model above, the diff would look something like this:
- import torch.nn as nn
+ from dataclasses import dataclass

+ from modulus.models.meta import ModelMetaData
+ from modulus.models.module import Module

- class UNet(nn.Module):
+ @dataclass
+ class MetaData(ModelMetaData):
+     name: str = "UNet"
+     # Optimization
+     jit: bool = False
+     cuda_graphs: bool = True
+     amp_cpu: bool = True
+     amp_gpu: bool = True
+
+ class UNet(Module):
      def __init__(self, in_channels=1, out_channels=1):
-         super(UNet, self).__init__()
+         super(UNet, self).__init__(meta=MetaData())
          self.enc1 = self.conv_block(in_channels, 64)
          self.enc2 = self.conv_block(64, 128)
With simple changes like this you can convert a PyTorch model to a Modulus Model!
The optimizations are not applied automatically; you are responsible for writing the model in a way that supports them. However, if the model supports an optimization and that is captured in the MetaData, the downstream features will work out-of-the-box.
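For example, downstream utilities can inspect this metadata at runtime. The snippet below is a minimal sketch that assumes the Module base class exposes the passed metadata on the instance as model.meta.
# Minimal sketch: inspect the metadata attached to the Modulus UNet defined
# above (assumes the metadata is exposed as `model.meta`).
model = UNet()
print(model.meta.cuda_graphs)  # True -> CUDA graph capture can be used
print(model.meta.jit)          # False -> TorchScript JIT is not supported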
To use the checkpointing functionality of Modulus, the model's instantiation arguments must be JSON-serializable.
You can also use a Modulus model as a standard PyTorch model as they are interoperable.
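For instance, the Modulus-based UNet defined above can be called exactly like a regular torch.nn.Module; the input shape below is only an illustration.
# Interoperability sketch: the Modulus UNet behaves like a plain PyTorch module
# (the input shape here is arbitrary and only for illustration).
unet = UNet(in_channels=1, out_channels=1).to("cuda")
x = torch.randn(4, 1, 64, 64, device="cuda")
y = unet(x)  # standard nn.Module forward call
print(y.shape)  # spatial size matches the input for this architecture
print(sum(p.numel() for p in unet.parameters()))  # parameters() works as usual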
Suppose you already have a PyTorch model and don’t want to change its code. You can convert it to a Modulus model using the modulus.Module.from_torch method. This is described in detail in Converting PyTorch Models to Modulus Models.
from dataclasses import dataclass

import torch.nn as nn

from modulus.models.meta import ModelMetaData
from modulus.models.module import Module


@dataclass
class MdlsUNetMetaData(ModelMetaData):
    name: str = "MdlsUNet"
    # Optimization
    jit: bool = False
    cuda_graphs: bool = True
    amp_cpu: bool = True
    amp_gpu: bool = True


MdlsUNet = Module.from_torch(UNet, meta=MdlsUNetMetaData)
And just like that you can use your existing PyTorch model as a Modulus Model.
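A quick sanity check of the wrapped class might look like the sketch below; the constructor arguments mirror the original UNet, and the input shape is only illustrative.
# Sketch: instantiate the wrapped model and run a forward pass
# (constructor arguments follow the original UNet; shapes are illustrative).
model = MdlsUNet(in_channels=1, out_channels=1).to("cuda")
y = model(torch.randn(2, 1, 64, 64, device="cuda"))
print(type(model).__name__, y.shape)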
A very similar process can be followed to convert a Modulus model to a Modulus Sym model so that you can use the Constraints and other definitions from the Modulus Sym repository. Here you will use the Arch class from Modulus Sym, which provides utilities and methods to go from tensor data to the dictionary format that Modulus Sym uses.
from typing import Dict, Optional

from modulus.sym.key import Key
from modulus.sym.models.arch import Arch


class MdlsSymUNet(Arch):
    def __init__(
        self,
        input_keys=[Key("a")],
        output_keys=[Key("b")],
        in_channels=1,
        out_channels=1,
    ):
        super(MdlsSymUNet, self).__init__(
            input_keys=input_keys, output_keys=output_keys
        )
        self.mdls_model = MdlsUNet(in_channels, out_channels)  # MdlsUNet defined above

    def forward(self, dict_tensor: Dict[str, torch.Tensor]):
        x = self.concat_input(
            dict_tensor,
            self.input_key_dict,
            detach_dict=None,
            dim=1,
        )
        out = self.mdls_model(x)
        return self.split_output(out, self.output_key_dict, dim=1)
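As a rough usage sketch, the wrapper takes and returns dictionaries keyed by the Keys passed at construction. The key names "a" and "b" and the tensor shape below are assumptions made for illustration only.
# Illustrative usage of the Modulus Sym wrapper (key names and shapes are
# assumptions for this sketch).
sym_model = MdlsSymUNet(
    input_keys=[Key("a", size=1)], output_keys=[Key("b", size=1)]
).to("cuda")
in_dict = {"a": torch.randn(4, 1, 64, 64, device="cuda")}
out_dict = sym_model(in_dict)
print(out_dict["b"].shape)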
Once we have a model defined in the Modulus style, we can apply optimizations like AMP, CUDA Graphs, and JIT using the modulus.utils.StaticCaptureTraining decorator. This decorator captures the training step function and applies the requested optimizations to it.
The StaticCaptureTraining decorator is still under development and may be refactored in the future.
import time

from modulus.utils import StaticCaptureTraining

normaliser = {
    "permeability": (1.25, 0.75),
    "darcy": (4.52e-2, 2.79e-2),
}
dataloader = Darcy2D(
    resolution=256, batch_size=8, nr_permeability_freq=5, normaliser=normaliser
)
model = MdlsUNet().to("cuda")

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: 0.85**step
)


# Create training step function with optimization wrapper.
# StaticCaptureTraining calls `backward` on the loss and
# `optimizer.step()` so you don't have to do that explicitly.
@StaticCaptureTraining(
    model=model,
    optim=optimizer,
    cuda_graph_warmup=11,
)
def training_step(invar, outvar):
    predvar = model(invar)
    loss = mse(predvar, outvar)
    return loss


# run for 20 iterations
for i in range(20):
    batch = next(iter(dataloader))
    true = batch["darcy"]
    input = batch["permeability"]
    loss = training_step(input, true)
    scheduler.step()
Distributed training
Modulus has several distributed utilities that simplify the implementation of parallel training and inference scripts by providing a unified way to configure and query parameters associated with the distributed environment.
In this example, we will see how to convert our existing workflow to use data parallelism. For a deep dive into the Modulus distributed utilities, refer to Modulus Distributed.
from torch.nn.parallel import DistributedDataParallel

from modulus.distributed import DistributedManager


def main():
    # Initialize the DistributedManager. This will automatically
    # detect the number of processes the job was launched with and
    # set those configuration parameters appropriately.
    DistributedManager.initialize()

    # Get instance of the DistributedManager
    dist = DistributedManager()

    normaliser = {
        "permeability": (1.25, 0.75),
        "darcy": (4.52e-2, 2.79e-2),
    }
    dataloader = Darcy2D(
        resolution=256, batch_size=64, nr_permeability_freq=5, normaliser=normaliser
    )
    model = FNO(
        in_channels=1,
        out_channels=1,
        decoder_layers=1,
        decoder_layer_size=32,
        dimension=2,
        latent_channels=32,
        num_fno_layers=4,
        num_fno_modes=12,
        padding=5,
    ).to(dist.device)

    # Set up DistributedDataParallel if using more than a single process.
    if dist.distributed:
        ddps = torch.cuda.Stream()
        with torch.cuda.stream(ddps):
            model = DistributedDataParallel(
                model,
                device_ids=[dist.local_rank],  # the local rank of this process on this node
                output_device=dist.device,
                broadcast_buffers=dist.broadcast_buffers,
                find_unused_parameters=dist.find_unused_parameters,
            )
        torch.cuda.current_stream().wait_stream(ddps)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda step: 0.85**step
    )

    # Create training step function with optimization wrapper.
    # StaticCaptureTraining calls `backward` on the loss and
    # `optimizer.step()` so you don't have to do that explicitly.
    @StaticCaptureTraining(
        model=model,
        optim=optimizer,
        cuda_graph_warmup=11,
    )
    def training_step(invar, outvar):
        predvar = model(invar)
        loss = mse(predvar, outvar)
        return loss

    # run for 20 iterations
    for i in range(20):
        batch = next(iter(dataloader))
        true = batch["darcy"]
        input = batch["permeability"]
        loss = training_step(input, true)
        scheduler.step()


if __name__ == "__main__":
    main()
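To run this script on multiple GPUs, use any launcher that sets the standard PyTorch distributed environment variables; for example, assuming the script is saved as train_darcy_distributed.py (a hypothetical file name), `torchrun --nproc_per_node=2 train_darcy_distributed.py` starts one process per GPU and DistributedManager picks up the world size and ranks from the environment.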
Running inference
Running inference on a trained model is simple, as shown in the code below.
model = FNO(
    in_channels=1,
    out_channels=1,
    decoder_layers=1,
    decoder_layer_size=32,
    dimension=2,
    latent_channels=32,
    num_fno_layers=4,
    num_fno_modes=12,
    padding=5,
).to("cuda")

# Save the checkpoint. For this demo, we will just save an untrained checkpoint,
# but in a typical workflow the checkpoint is saved after model training.
model.save("untrained_checkpoint.mdlus")

# Inference code
# The parameters to instantiate the model will be loaded from the checkpoint
model_inf = modulus.Module.from_checkpoint("untrained_checkpoint.mdlus").to("cuda")

# put the model in evaluation mode
model_inf.eval()

# run inference
with torch.inference_mode():
    input = torch.ones(8, 1, 256, 256).to("cuda")
    output = model_inf(input)
    print(output.shape)
The static capture and distributed utilities can also be used during inference to speed up the inference workflow, but that is out of scope for this tutorial.