Modulus Models
Modulus contains its own Model class for constructing neural networks. This model class
is built on top of PyTorch’s nn.Module
and can be used interchangeably within the
PyTorch ecosystem. Using Modulus models allows you to leverage various features of
Modulus aimed at improving performance and ease of use. These features include, but are
not limited to, a model zoo, automatic mixed-precision, CUDA Graphs, and easy checkpointing.
We discuss each of these features in the following sections.
Modulus contains several optimized, customizable, and easy-to-use models. These include some very general models like Fourier Neural Operators (FNOs), ResNet, and Graph Neural Networks (GNNs), as well as domain-specific models like Deep Learning Weather Prediction (DLWP) and Spherical Fourier Neural Operators (SFNO).
For a list of currently available models, please refer to the models on GitHub.
Below are some simple examples of how to use these models.
>>> import torch
>>> from modulus.models.mlp.fully_connected import FullyConnected
>>> model = FullyConnected(in_features=32, out_features=64)
>>> input = torch.randn(128, 32)
>>> output = model(input)
>>> output.shape
torch.Size([128, 64])
>>> import torch
>>> from modulus.models.fno.fno import FNO
>>> model = FNO(
...     in_channels=4,
...     out_channels=3,
...     decoder_layers=2,
...     decoder_layer_size=32,
...     dimension=2,
...     latent_channels=32,
...     num_fno_layers=2,
...     padding=0,
... )
>>> input = torch.randn(32, 4, 32, 32) #(N, C, H, W)
>>> output = model(input)
>>> output.size()
torch.Size([32, 3, 32, 32])
There are a few different ways to construct a Modulus model. If you are a seasoned PyTorch user, the easiest way is to write your model using the optimized layers and utilities from Modulus or PyTorch. Let's take a look at a simple example of a UNet model, first showing a plain PyTorch implementation and then a Modulus implementation that supports CUDA Graphs and Automatic Mixed-Precision.
import torch.nn as nn

class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super(UNet, self).__init__()

        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)

        self.dec1 = self.upconv_block(128, 64)
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )

    def upconv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 2, stride=2),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x1 = self.enc1(x)
        x2 = self.enc2(x1)
        x = self.dec1(x2)
        return self.final(x)
Now we show this model rewritten in Modulus. First, let's subclass the model from modulus.Module instead of torch.nn.Module. The modulus.Module class acts as a drop-in replacement for torch.nn.Module and provides additional functionality for saving and loading checkpoints, etc. Refer to the API docs of modulus.Module for further details. Additionally, we will add metadata to the model to capture the optimizations that this model supports. In this case we will enable CUDA Graphs and Automatic Mixed-Precision.
from dataclasses import dataclass

import modulus
import torch.nn as nn

@dataclass
class UNetMetaData(modulus.ModelMetaData):
    name: str = "UNet"
    # Optimization
    jit: bool = True
    cuda_graphs: bool = True
    amp_cpu: bool = True
    amp_gpu: bool = True

class UNet(modulus.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super(UNet, self).__init__(meta=UNetMetaData())

        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)

        self.dec1 = self.upconv_block(128, 64)
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )

    def upconv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 2, stride=2),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x1 = self.enc1(x)
        x2 = self.enc2(x1)
        x = self.dec1(x2)
        return self.final(x)
Now that we have our Modulus model, we can make use of these optimizations using the modulus.utils.StaticCaptureTraining decorator. This decorator captures the training step function and optimizes it using the features enabled in the model's metadata, in this case CUDA Graphs and AMP.
import torch
from modulus.utils import StaticCaptureTraining

model = UNet().to("cuda")
input = torch.randn(8, 1, 128, 128).to("cuda")
output = torch.zeros(8, 1, 64, 64).to("cuda")

optim = torch.optim.Adam(model.parameters(), lr=0.001)

# Create training step function with optimization wrapper
# StaticCaptureTraining calls `backward` on the loss and
# `optimizer.step()` so you don't have to do that
# explicitly.
@StaticCaptureTraining(
    model=model,
    optim=optim,
    cuda_graph_warmup=11,
)
def training_step(invar, outvar):
    predvar = model(invar)
    loss = torch.sum(torch.pow(predvar - outvar, 2))
    return loss

# Sample training loop
for i in range(20):
    # In-place copy of input and output to support CUDA Graphs
    input.copy_(torch.randn(8, 1, 128, 128).to("cuda"))
    output.copy_(torch.zeros(8, 1, 64, 64).to("cuda"))
    # Run training step
    loss = training_step(input, output)
For the simple model above, you can observe ~1.1x speed-up due to CUDA Graphs and AMP. The speed-up observed changes from model to model and is typically greater for more complex models.
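If you want to verify the speed-up for your own model, a plain wall-clock comparison of the captured training step against an equivalent undecorated step is usually enough. The helper below is only a minimal sketch; the function name and iteration counts are illustrative and not part of Modulus.

import time

import torch

def benchmark(step_fn, invar, outvar, iters=100, warmup=20):
    # Warm-up iterations so CUDA Graph capture and autotuning are excluded
    for _ in range(warmup):
        step_fn(invar, outvar)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        step_fn(invar, outvar)
    # Wait for all queued GPU work before stopping the clock
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

# e.g. compare the StaticCapture-decorated training_step against a
# hand-written step that calls zero_grad/backward/step explicitly.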
Note that ModelMetaData and modulus.Module do not automatically make the model support CUDA Graphs, AMP, and other optimizations. The user is responsible for writing model code that is compatible with each of these optimizations; an illustrative sketch of what this typically entails is shown below. Models in the Modulus Model Zoo are written to support many of these optimizations and are checked against Modulus's CI to ensure that they work correctly.
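The sketch below is illustrative only and relies on general PyTorch behavior, not any Modulus-specific API: a CUDA Graph-friendly forward pass avoids data-dependent Python control flow and CPU synchronization (e.g. .item() or .cpu() calls), and AMP-friendly code keeps numerically sensitive operations in full precision even when autocast is enabled around the model.

import torch
import torch.nn as nn

class GraphSafeBlock(nn.Module):
    """Illustrative block written with CUDA Graphs and AMP in mind."""

    def __init__(self, channels: int = 32):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Avoid `.item()`, `.cpu()` and branches on tensor values here:
        # such CPU synchronization breaks CUDA Graph capture.
        y = self.conv(x)
        # Keep reduced-precision-sensitive math in fp32 even under autocast.
        with torch.autocast(device_type="cuda", enabled=False):
            norm = y.float().pow(2).mean(dim=(1, 2, 3), keepdim=True).add(1e-5).rsqrt()
        return x + y * norm.to(y.dtype)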
The StaticCaptureTraining decorator is still under development and may be refactored in the future.
In the above example we showed how to construct a Modulus model from scratch. However, you can also convert existing PyTorch models to Modulus models in order to leverage Modulus features. To do this, you can use the Module.from_torch method as shown below.
from dataclasses import dataclass

import modulus
import torch.nn as nn

class TorchModel(nn.Module):
    def __init__(self):
        super(TorchModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = self.conv1(x)
        return self.conv2(x)

@dataclass
class ConvMetaData(modulus.ModelMetaData):
    name: str = "ConvModel"
    # Optimization
    jit: bool = True
    cuda_graphs: bool = True
    amp_cpu: bool = True
    amp_gpu: bool = True

ModulusModel = modulus.Module.from_torch(TorchModel, meta=ConvMetaData())
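Once converted, the new class behaves like any other Modulus module. As a rough sketch (the input shape below is illustrative):

import torch

model = ModulusModel()                   # same constructor signature as TorchModel
out = model(torch.randn(1, 1, 28, 28))   # forward pass works as before
# Modulus-specific utilities such as .mdlus checkpointing (shown in the
# next section) are now available on this model as well.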
As mentioned above, Modulus models are interoperable with PyTorch models, which means you can save and load them using the standard PyTorch APIs. However, we provide a few additional utilities to make this process easier. A key challenge in saving and loading models is keeping track of the model metadata such as layer sizes, etc. Modulus models can be saved with this metadata to a custom .mdlus file. These files allow for easy loading and instantiation of the model. We show two examples of this below.
The first example shows saving and loading a model from an already instantiated model.
>>> from modulus.models.mlp.fully_connected import FullyConnected
>>> model = FullyConnected(in_features=32, out_features=64)
>>> model.save("model.mdlus") # Save model to .mdlus file
>>> model.load("model.mdlus") # Load model weights from .mdlus file from already instantiated model
>>> model
FullyConnected(
  (layers): ModuleList(
    (0): FCLayer(
      (activation_fn): SiLU()
      (linear): Linear(in_features=32, out_features=512, bias=True)
    )
    (1-5): 5 x FCLayer(
      (activation_fn): SiLU()
      (linear): Linear(in_features=512, out_features=512, bias=True)
    )
  )
  (final_layer): FCLayer(
    (activation_fn): Identity()
    (linear): Linear(in_features=512, out_features=64, bias=True)
  )
)
The second example shows loading a model from a .mdlus file without having to instantiate the model first. We note that in this case we don't know the class or parameters to pass to the constructor of the model. However, we can still load the model from the .mdlus file.
>>> from modulus import Module
>>> fc_model = Module.from_checkpoint("model.mdlus") # Instantiate model from .mdlus file.
>>> fc_model
FullyConnected(
  (layers): ModuleList(
    (0): FCLayer(
      (activation_fn): SiLU()
      (linear): Linear(in_features=32, out_features=512, bias=True)
    )
    (1-5): 5 x FCLayer(
      (activation_fn): SiLU()
      (linear): Linear(in_features=512, out_features=512, bias=True)
    )
  )
  (final_layer): FCLayer(
    (activation_fn): Identity()
    (linear): Linear(in_features=512, out_features=64, bias=True)
  )
)
In order to make use of this functionality, the model must have JSON-serializable inputs to its __init__ function. It is highly recommended that all Modulus models be developed with this requirement in mind.
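As a rough, hypothetical illustration of what this requirement means in practice, the constructor below only takes ints and strings, so the model can be re-instantiated from a .mdlus file; passing live objects such as an nn.Module instance as a constructor argument would not be JSON-serializable.

import modulus
import torch.nn as nn

class GoodModel(modulus.Module):
    # ints, floats, strings, booleans and lists/dicts of these are JSON
    # serializable, so the model can be rebuilt from its checkpoint metadata.
    def __init__(self, channels: int = 16, activation: str = "relu"):
        super().__init__()
        act = nn.ReLU() if activation == "relu" else nn.GELU()
        self.net = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            act,
        )

    def forward(self, x):
        return self.net(x)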
Modulus contains a model registry that allows for easy access and ingestion of models. Below is a simple example of how to use the model registry to obtain a model class.
>>> from modulus.registry import ModelRegistry
>>> model_registry = ModelRegistry()
>>> model_registry.list_models()
['AFNO', 'DLWP', 'FNO', 'FullyConnected', 'GraphCastNet', 'MeshGraphNet', 'One2ManyRNN', 'Pix2Pix', 'SFNO', 'SRResNet']
>>> FullyConnected = model_registry.factory("FullyConnected")
>>> model = FullyConnected(in_features=32, out_features=64)
The model registry also allows exposing models via entry points, which enables integration of models from external packages into the Modulus ecosystem. For example, suppose you have a package MyPackage that contains a model MyModel. You can expose this model to the Modulus registry by adding an entry point to your pyproject.toml file. For example, suppose your package structure is as follows:
# setup.py
from setuptools import setup, find_packages

setup()

# pyproject.toml
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "MyPackage"
description = "My Neural Network Zoo."
version = "0.1.0"

[project.entry-points."modulus.models"]
MyModulusModel = "mypackage.models:MyModulusModel"

# mypackage/models.py
import torch.nn as nn

from modulus import Module

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = self.conv1(x)
        return self.conv2(x)

MyModulusModel = Module.from_torch(MyModel)
Once this package is installed, you can access the model via the Modulus model registry.
>>> from modulus.registry import ModelRegistry
>>> model_registry = ModelRegistry()
>>> model_registry.list_models()
['MyModulusModel', 'AFNO', 'DLWP', 'FNO', 'FullyConnected', 'GraphCastNet', 'MeshGraphNet', 'One2ManyRNN', 'Pix2Pix', 'SFNO', 'SRResNet']
>>> MyModulusModel = model_registry.factory("MyModulusModel")
For more information on entry points and potential use cases, see this blog post.
Fully Connected Network
- class modulus.models.mlp.fully_connected.FullyConnected(*args, **kwargs)[source]
Bases:
Module
A densely-connected MLP architecture
- Parameters
in_features (int, optional) – Size of input features, by default 512
layer_size (int, optional) – Size of every hidden layer, by default 512
out_features (int, optional) – Size of output features, by default 512
num_layers (int, optional) – Number of hidden layers, by default 6
activation_fn (Union[str, List[str]], optional) – Activation function to use, by default ‘silu’
skip_connections (bool, optional) – Add skip connections every 2 hidden layers, by default False
adaptive_activations (bool, optional) – Use an adaptive activation function, by default False
weight_norm (bool, optional) – Use weight norm on fully connected layers, by default False
weight_fact (bool, optional) – Use weight factorization on fully connected layers, by default False
Example
>>> model = modulus.models.mlp.FullyConnected(in_features=32, out_features=64)
>>> input = torch.randn(128, 32)
>>> output = model(input)
>>> output.size()
torch.Size([128, 64])
- forward(x: Tensor) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class modulus.models.mlp.fully_connected.MetaData(name: str = 'FullyConnected', jit: bool = True, cuda_graphs: bool = True, amp: bool = True, amp_cpu: bool = None, amp_gpu: bool = None, torch_fx: bool = True, bf16: bool = False, onnx: bool = True, onnx_gpu: bool = None, onnx_cpu: bool = None, onnx_runtime: bool = True, trt: bool = False, var_dim: int = -1, func_torch: bool = True, auto_grad: bool = True)[source]
Bases:
ModelMetaData
Fourier Neural Operators
- class modulus.models.fno.fno.FNO(*args, **kwargs)[source]
Bases:
Module
Fourier neural operator (FNO) model.
Note: The FNO architecture supports options for 1D, 2D, 3D and 4D fields, which can be controlled using the dimension parameter.
- Parameters
in_channels (int) – Number of input channels
out_channels (int) – Number of output channels
decoder_layers (int, optional) – Number of decoder layers, by default 1
decoder_layer_size (int, optional) – Number of neurons in decoder layers, by default 32
decoder_activation_fn (str, optional) – Activation function for decoder, by default “silu”
dimension (int) – Model dimensionality (supports 1, 2, 3).
latent_channels (int, optional) – Latent features size in spectral convolutions, by default 32
num_fno_layers (int, optional) – Number of spectral convolutional layers, by default 4
num_fno_modes (Union[int, List[int]], optional) – Number of Fourier modes kept in spectral convolutions, by default 16
padding (int, optional) – Domain padding for spectral convolutions, by default 8
padding_type (str, optional) – Type of padding for spectral convolutions, by default “constant”
activation_fn (str, optional) – Activation function, by default “gelu”
coord_features (bool, optional) – Use coordinate grid as additional feature map, by default True
Example
>>> # define the 2d FNO model
>>> model = modulus.models.fno.FNO(
...     in_channels=4,
...     out_channels=3,
...     decoder_layers=2,
...     decoder_layer_size=32,
...     dimension=2,
...     latent_channels=32,
...     num_fno_layers=2,
...     padding=0,
... )
>>> input = torch.randn(32, 4, 32, 32) #(N, C, H, W)
>>> output = model(input)
>>> output.size()
torch.Size([32, 3, 32, 32])
Note: Reference: Li, Zongyi, et al. “Fourier neural operator for parametric partial differential equations.” arXiv preprint arXiv:2010.08895 (2020).
- forward(x: Tensor) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class modulus.models.fno.fno.FNO1DEncoder(in_channels: int = 1, num_fno_layers: int = 4, fno_layer_size: int = 32, num_fno_modes: Union[int, List[int]] = 16, padding: Union[int, List[int]] = 8, padding_type: str = 'constant', activation_fn: Module = GELU(approximate='none'), coord_features: bool = True)[source]
Bases:
Module
1D Spectral encoder for FNO
- Parameters
in_channels (int, optional) – Number of input channels, by default 1
num_fno_layers (int, optional) – Number of spectral convolutional layers, by default 4
fno_layer_size (int, optional) – Latent features size in spectral convolutions, by default 32
num_fno_modes (Union[int, List[int]], optional) – Number of Fourier modes kept in spectral convolutions, by default 16
padding (Union[int, List[int]], optional) – Domain padding for spectral convolutions, by default 8
padding_type (str, optional) – Type of padding for spectral convolutions, by default “constant”
activation_fn (nn.Module, optional) – Activation function, by default nn.GELU
coord_features (bool, optional) – Use coordinate grid as additional feature map, by default True
- build_fno(num_fno_modes: List[int]) → None[source]
Construct the FNO block.
- Parameters
num_fno_modes (List[int]) – Number of Fourier modes kept in spectral convolutions
- build_lift_network() → None[source]
Construct the network for lifting variables to latent space.
- forward(x: Tensor) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- grid_to_points(value: Tensor) → Tuple[Tensor, List[int]][source]
converting from grid based (image) to point based representation
- Parameters
value (Meshgrid tensor) –
- Returns
Tensor, meshgrid shape
- Return type
Tuple
- meshgrid(shape: List[int], device: device) → Tensor[source]
Creates 1D meshgrid feature
- Parameters
shape (List[int]) – Tensor shape
device (torch.device) – Device model is on
- Returns
Meshgrid tensor
- Return type
Tensor
- points_to_grid(value: Tensor, shape: List[int]) → Tensor[source]
converting from point based to grid based (image) representation
- Parameters
value (Tensor) – Tensor
shape (List[int]) – meshgrid shape
- Returns
Meshgrid tensor
- Return type
Tensor
- class modulus.models.fno.fno.FNO2DEncoder(in_channels: int = 1, num_fno_layers: int = 4, fno_layer_size: int = 32, num_fno_modes: Union[int, List[int]] = 16, padding: Union[int, List[int]] = 8, padding_type: str = 'constant', activation_fn: Module = GELU(approximate='none'), coord_features: bool = True)[source]
Bases:
Module
2D Spectral encoder for FNO
- Parameters
in_channels (int, optional) – Number of input channels, by default 1
num_fno_layers (int, optional) – Number of spectral convolutional layers, by default 4
fno_layer_size (int, optional) – Latent features size in spectral convolutions, by default 32
num_fno_modes (Union[int, List[int]], optional) – Number of Fourier modes kept in spectral convolutions, by default 16
padding (Union[int, List[int]], optional) – Domain padding for spectral convolutions, by default 8
padding_type (str, optional) – Type of padding for spectral convolutions, by default “constant”
activation_fn (nn.Module, optional) – Activation function, by default nn.GELU
coord_features (bool, optional) – Use coordinate grid as additional feature map, by default True
- build_fno(num_fno_modes: List[int]) → None[source]
Construct the FNO block.
- Parameters
num_fno_modes (List[int]) – Number of Fourier modes kept in spectral convolutions
- build_lift_network() → None[source]
Construct the network for lifting variables to latent space.
- forward(x: Tensor) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- grid_to_points(value: Tensor) → Tuple[Tensor, List[int]][source]
converting from grid based (image) to point based representation
- Parameters
value (Meshgrid tensor) –
- Returns
Tensor, meshgrid shape
- Return type
Tuple
- meshgrid(shape: List[int], device: device) → Tensor[source]
Creates 2D meshgrid feature
- Parameters
shape (List[int]) – Tensor shape
device (torch.device) – Device model is on
- Returns
Meshgrid tensor
- Return type
Tensor
- points_to_grid(value: Tensor, shape: List[int]) → Tensor[source]
converting from point based to grid based (image) representation
- Parameters
value (Tensor) – Tensor
shape (List[int]) – meshgrid shape
- Returns
Meshgrid tensor
- Return type
Tensor
- class modulus.models.fno.fno.FNO3DEncoder(in_channels: int = 1, num_fno_layers: int = 4, fno_layer_size: int = 32, num_fno_modes: Union[int, List[int]] = 16, padding: Union[int, List[int]] = 8, padding_type: str = 'constant', activation_fn: Module = GELU(approximate='none'), coord_features: bool = True)[source]
Bases:
Module
3D Spectral encoder for FNO
- Parameters
in_channels (int, optional) – Number of input channels, by default 1
num_fno_layers (int, optional) – Number of spectral convolutional layers, by default 4
fno_layer_size (int, optional) – Latent features size in spectral convolutions, by default 32
num_fno_modes (Union[int, List[int]], optional) – Number of Fourier modes kept in spectral convolutions, by default 16
padding (Union[int, List[int]], optional) – Domain padding for spectral convolutions, by default 8
padding_type (str, optional) – Type of padding for spectral convolutions, by default “constant”
activation_fn (nn.Module, optional) – Activation function, by default nn.GELU
coord_features (bool, optional) – Use coordinate grid as additional feature map, by default True
- build_fno(num_fno_modes: List[int]) → None[source]
Construct the FNO block.
- Parameters
num_fno_modes (List[int]) – Number of Fourier modes kept in spectral convolutions
- build_lift_network() → None[source]
Construct the network for lifting variables to latent space.
- forward(x: Tensor) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- grid_to_points(value: Tensor) → Tuple[Tensor, List[int]][source]
converting from grid based (image) to point based representation
- Parameters
value (Meshgrid tensor) –
- Returns
Tensor, meshgrid shape
- Return type
Tuple
- meshgrid(shape: List[int], device: device) → Tensor[source]
Creates 3D meshgrid feature
- Parameters
shape (List[int]) – Tensor shape
device (torch.device) – Device model is on
- Returns
Meshgrid tensor
- Return type
Tensor
- points_to_grid(value: Tensor, shape: List[int]) → Tensor[source]
converting from point based to grid based (image) representation
- Parameters
value (Tensor) – Tensor
shape (List[int]) – meshgrid shape
- Returns
Meshgrid tensor
- Return type
Tensor
- class modulus.models.fno.fno.FNO4DEncoder(in_channels: int = 1, num_fno_layers: int = 4, fno_layer_size: int = 32, num_fno_modes: Union[int, List[int]] = 16, padding: Union[int, List[int]] = 8, padding_type: str = 'constant', activation_fn: Module = GELU(approximate='none'), coord_features: bool = True)[source]
Bases:
Module
4D Spectral encoder for FNO
- Parameters
in_channels (int, optional) – Number of input channels, by default 1
num_fno_layers (int, optional) – Number of spectral convolutional layers, by default 4
fno_layer_size (int, optional) – Latent features size in spectral convolutions, by default 32
num_fno_modes (Union[int, List[int]], optional) – Number of Fourier modes kept in spectral convolutions, by default 16
padding (Union[int, List[int]], optional) – Domain padding for spectral convolutions, by default 8
padding_type (str, optional) – Type of padding for spectral convolutions, by default “constant”
activation_fn (nn.Module, optional) – Activation function, by default nn.GELU
coord_features (bool, optional) – Use coordinate grid as additional feature map, by default True
- build_fno(num_fno_modes: List[int]) → None[source]
Construct the FNO block.
- Parameters
num_fno_modes (List[int]) – Number of Fourier modes kept in spectral convolutions
- build_lift_network() → None[source]
Construct the network for lifting variables to latent space.
- forward(x: Tensor) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- grid_to_points(value: Tensor) → Tuple[Tensor, List[int]][source]
converting from grid based (image) to point based representation
- Parameters
value (Meshgrid tensor) –
- Returns
Tensor, meshgrid shape
- Return type
Tuple
- meshgrid(shape: List[int], device: device) → Tensor[source]
Creates 4D meshgrid feature
- Parameters
shape (List[int]) – Tensor shape
device (torch.device) – Device model is on
- Returns
Meshgrid tensor
- Return type
Tensor
- points_to_grid(value: Tensor, shape: List[int]) → Tensor[source]
converting from point based to grid based (image) representation
- Parameters
value (Tensor) – Tensor
shape (List[int]) – meshgrid shape
- Returns
Meshgrid tensor
- Return type
Tensor
- class modulus.models.fno.fno.MetaData(name: str = 'FourierNeuralOperator', jit: bool = True, cuda_graphs: bool = True, amp: bool = False, amp_cpu: bool = None, amp_gpu: bool = None, torch_fx: bool = False, bf16: bool = False, onnx: bool = False, onnx_gpu: bool = False, onnx_cpu: bool = False, onnx_runtime: bool = False, trt: bool = False, var_dim: int = 1, func_torch: bool = False, auto_grad: bool = False)[source]
Bases:
ModelMetaData
- class modulus.models.afno.afno.AFNO(*args, **kwargs)[source]
Bases:
Module
Adaptive Fourier neural operator (AFNO) model.
Note: AFNO is a model that is designed for 2D images only.
- Parameters
inp_shape (List[int]) – Input image dimensions [height, width]
in_channels (int) – Number of input channels
out_channels (int) – Number of output channels
patch_size (List[int], optional) – Size of image patches, by default [16, 16]
embed_dim (int, optional) – Embedded channel size, by default 256
depth (int, optional) – Number of AFNO layers, by default 4
mlp_ratio (float, optional) – Ratio of layer MLP latent variable size to input feature size, by default 4.0
drop_rate (float, optional) – Drop out rate in layer MLPs, by default 0.0
num_blocks (int, optional) – Number of blocks in the block-diag frequency weight matrices, by default 16
sparsity_threshold (float, optional) – Sparsity threshold (softshrink) of spectral features, by default 0.01
hard_thresholding_fraction (float, optional) – Threshold for limiting number of modes used [0,1], by default 1
Example
>>> model = modulus.models.afno.AFNO(
...     inp_shape=[32, 32],
...     in_channels=2,
...     out_channels=1,
...     patch_size=(8, 8),
...     embed_dim=16,
...     depth=2,
...     num_blocks=2,
... )
>>> input = torch.randn(32, 2, 32, 32) #(N, C, H, W)
>>> output = model(input)
>>> output.size()
torch.Size([32, 1, 32, 32])
Note: Reference: Guibas, John, et al. “Adaptive fourier neural operators: Efficient token mixers for transformers.” arXiv preprint arXiv:2111.13587 (2021).
- forward(x: Tensor) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- forward_features(x: Tensor) → Tensor[source]
Forward pass of core AFNO
- class modulus.models.afno.afno.AFNO2DLayer(hidden_size: int, num_blocks: int = 8, sparsity_threshold: float = 0.01, hard_thresholding_fraction: float = 1, hidden_size_factor: int = 1)[source]
Bases:
Module
AFNO spectral convolution layer
- Parameters
hidden_size (int) – Feature dimensionality
num_blocks (int, optional) – Number of blocks used in the block diagonal weight matrix, by default 8
sparsity_threshold (float, optional) – Sparsity threshold (softshrink) of spectral features, by default 0.01
hard_thresholding_fraction (float, optional) – Threshold for limiting number of modes used [0,1], by default 1
hidden_size_factor (int, optional) – Factor to increase spectral features by after weight multiplication, by default 1
- forward(x: Tensor) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class modulus.models.afno.afno.AFNOMlp(in_features: int, latent_features: int, out_features: int, activation_fn: Module = GELU(approximate='none'), drop: float = 0.0)[source]
Bases:
Module
Fully-connected Multi-layer perception used inside AFNO
- Parameters
in_features (int) – Input feature size
latent_features (int) – Latent feature size
out_features (int) – Output feature size
activation_fn (nn.Module, optional) – Activation function, by default nn.GELU
drop (float, optional) – Drop out rate, by default 0.0
- forward(x: Tensor) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class modulus.models.afno.afno.Block(embed_dim: int, num_blocks: int = 8, mlp_ratio: float = 4.0, drop: float = 0.0, activation_fn: ~torch.nn.modules.module.Module = GELU(approximate='none'), norm_layer: ~torch.nn.modules.module.Module = <class 'torch.nn.modules.normalization.LayerNorm'>, double_skip: bool = True, sparsity_threshold: float = 0.01, hard_thresholding_fraction: float = 1.0)[source]
Bases:
Module
AFNO block, spectral convolution and MLP
- Parameters
embed_dim (int) – Embedded feature dimensionality
num_blocks (int, optional) – Number of blocks used in the block diagonal weight matrix, by default 8
mlp_ratio (float, optional) – Ratio of MLP latent variable size to input feature size, by default 4.0
drop (float, optional) – Drop out rate in MLP, by default 0.0
activation_fn (nn.Module, optional) – Activation function used in MLP, by default nn.GELU
norm_layer (nn.Module, optional) – Normalization function, by default nn.LayerNorm
double_skip (bool, optional) – Residual, by default True
sparsity_threshold (float, optional) – Sparsity threshold (softshrink) of spectral features, by default 0.01
hard_thresholding_fraction (float, optional) – Threshold for limiting number of modes used [0,1], by default 1
- forward(x: Tensor) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class modulus.models.afno.afno.MetaData(name: str = 'AFNO', jit: bool = False, cuda_graphs: bool = True, amp: bool = True, amp_cpu: bool = None, amp_gpu: bool = None, torch_fx: bool = False, bf16: bool = False, onnx: bool = False, onnx_gpu: bool = True, onnx_cpu: bool = False, onnx_runtime: bool = True, trt: bool = False, var_dim: int = 1, func_torch: bool = False, auto_grad: bool = False)[source]
Bases:
ModelMetaData
- class modulus.models.afno.afno.PatchEmbed(inp_shape: List[int], in_channels: int, patch_size: List[int] = [16, 16], embed_dim: int = 256)[source]
Bases:
Module
Patch embedding layer
Converts 2D patch into a 1D vector for input to AFNO
- Parameters
inp_shape (List[int]) – Input image dimensions [height, width]
in_channels (int) – Number of input channels
patch_size (List[int], optional) – Size of image patches, by default [16, 16]
embed_dim (int, optional) – Embedded channel size, by default 256
- forward(x: Tensor) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class modulus.models.afno.modafno.Block(embed_dim: int, mod_dim: int, num_blocks: int = 8, mlp_ratio: float = 4.0, drop: float = 0.0, activation_fn: ~torch.nn.modules.module.Module = GELU(approximate='none'), norm_layer: ~torch.nn.modules.module.Module = <class 'torch.nn.modules.normalization.LayerNorm'>, double_skip: bool = True, sparsity_threshold: float = 0.01, hard_thresholding_fraction: float = 1.0, modulate_filter: bool = True, modulate_mlp: bool = True, scale_shift_mode: ~typing.Literal['complex', 'real'] = 'real')[source]
Bases:
Module
AFNO block, spectral convolution and MLP
- Parameters
embed_dim (int) – Embedded feature dimensionality
mod_dim (int) – Modulation input dimensionality
num_blocks (int, optional) – Number of blocks used in the block diagonal weight matrix, by default 8
mlp_ratio (float, optional) – Ratio of MLP latent variable size to input feature size, by default 4.0
drop (float, optional) – Drop out rate in MLP, by default 0.0
activation_fn (nn.Module, optional) – Activation function used in MLP, by default nn.GELU
norm_layer (nn.Module, optional) – Normalization function, by default nn.LayerNorm
double_skip (bool, optional) – Residual, by default True
sparsity_threshold (float, optional) – Sparsity threshold (softshrink) of spectral features, by default 0.01
hard_thresholding_fraction (float, optional) – Threshold for limiting number of modes used [0,1], by default 1
modulate_filter (bool, optional) – Whether to compute the modulation for the FFT filter
modulate_mlp (bool, optional) – Whether to compute the modulation for the MLP
scale_shift_mode (["complex", "real"]) – If ‘complex’ (default), compute the scale-shift operation using complex operations. If ‘real’, use real operations.
- forward(x: Tensor, mod_embed: Tensor) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class modulus.models.afno.modafno.MetaData(name: str = 'ModAFNO', jit: bool = False, cuda_graphs: bool = True, amp: bool = True, amp_cpu: bool = None, amp_gpu: bool = None, torch_fx: bool = False, bf16: bool = False, onnx: bool = False, onnx_gpu: bool = True, onnx_cpu: bool = False, onnx_runtime: bool = True, trt: bool = False, var_dim: int = 1, func_torch: bool = False, auto_grad: bool = False)[source]
Bases:
ModelMetaData
- class modulus.models.afno.modafno.ModAFNO(*args, **kwargs)[source]
Bases:
Module
Modulated Adaptive Fourier neural operator (ModAFNO) model.
- Parameters
inp_shape (List[int]) – Input image dimensions [height, width]
in_channels (int, optional) – Number of input channels
out_channels (int, optional) – Number of output channels
embed_model (dict, optional) – Dictionary of arguments to pass to the ModEmbedNet embedding model
patch_size (List[int], optional) – Size of image patches, by default [16, 16]
embed_dim (int, optional) – Embedded channel size, by default 256
mod_dim (int) – Modulation input dimensionality
modulate_filter (bool, optional) – Whether to compute the modulation for the FFT filter, by default True
modulate_mlp (bool, optional) – Whether to compute the modulation for the MLP, by default True
scale_shift_mode (["complex", "real"]) – If ‘complex’ (default), compute the scale-shift operation using complex operations. If ‘real’, use real operations.
depth (int, optional) – Number of AFNO layers, by default 4
mlp_ratio (float, optional) – Ratio of layer MLP latent variable size to input feature size, by default 4.0
drop_rate (float, optional) – Drop out rate in layer MLPs, by default 0.0
num_blocks (int, optional) – Number of blocks in the block-diag frequency weight matrices, by default 16
sparsity_threshold (float, optional) – Sparsity threshold (softshrink) of spectral features, by default 0.01
hard_thresholding_fraction (float, optional) – Threshold for limiting number of modes used [0,1], by default 1
The default settings correspond to the implementation in the paper cited below.
Example
>>> import torch
>>> from modulus.models.afno import ModAFNO
>>> model = ModAFNO(
...     inp_shape=[32, 32],
...     in_channels=2,
...     out_channels=1,
...     patch_size=(8, 8),
...     embed_dim=16,
...     depth=2,
...     num_blocks=2,
... )
>>> input = torch.randn(32, 2, 32, 32) #(N, C, H, W)
>>> time = torch.full((32, 1), 0.5)
>>> output = model(input, time)
>>> output.size()
torch.Size([32, 1, 32, 32])
Note: Reference: Leinonen et al. “Modulated Adaptive Fourier Neural Operators for Temporal Interpolation of Weather Forecasts.” arXiv preprint arXiv:TODO (2024).
- forward(x: Tensor, mod: Tensor) → Tensor[source]
The full ModAFNO model logic.
- forward_features(x: Tensor, mod: Tensor) → Tensor[source]
Forward pass of core ModAFNO
- class modulus.models.afno.modafno.ModAFNO2DLayer(hidden_size: int, mod_features: int, num_blocks: int = 8, sparsity_threshold: float = 0.01, hard_thresholding_fraction: float = 1, hidden_size_factor: int = 1, scale_shift_kwargs: Optional[dict] = None, scale_shift_mode: Literal['complex', 'real'] = 'complex')[source]
Bases:
AFNO2DLayer
AFNO spectral convolution layer
- Parameters
hidden_size (int) – Feature dimensionality
mod_features (int) – Number of modulation features
num_blocks (int, optional) – Number of blocks used in the block diagonal weight matrix, by default 8
sparsity_threshold (float, optional) – Sparsity threshold (softshrink) of spectral features, by default 0.01
hard_thresholding_fraction (float, optional) – Threshold for limiting number of modes used [0,1], by default 1
hidden_size_factor (int, optional) – Factor to increase spectral features by after weight multiplication, by default 1
scale_shift_kwargs (dict, optional) – Options to the MLP that computes the scale-shift parameters
scale_shift_mode (["complex", "real"]) – If ‘complex’ (default), compute the scale-shift operation using complex operations. If ‘real’, use real operations.
- forward(x: Tensor, mod_embed: Tensor) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class modulus.models.afno.modafno.ModAFNOMlp(in_features: int, latent_features: int, out_features: int, mod_features: int, activation_fn: Module = GELU(approximate='none'), drop: float = 0.0, scale_shift_kwargs: Optional[dict] = None)[source]
Bases:
AFNOMlp
Modulated MLP used inside ModAFNO
- Parameters
in_features (int) – Input feature size
latent_features (int) – Latent feature size
out_features (int) – Output feature size
activation_fn (nn.Module, optional) – Activation function, by default nn.GELU
drop (float, optional) – Drop out rate, by default 0.0
scale_shift_kwargs (dict, optional) – Options to the MLP that computes the scale-shift parameters
- forward(x: Tensor, mod_embed: Tensor) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class modulus.models.afno.modafno.ScaleShiftMlp(in_features: int, out_features: int, hidden_features: ~typing.Optional[int] = None, hidden_layers: int = 0, activation_fn: ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.activation.GELU'>)[source]
Bases:
Module
MLP used to compute the scale and shift parameters of the ModAFNO block
- Parameters
in_features (int) – Input feature size
out_features (int) – Output feature size
hidden_features (int, optional) – Hidden feature size, defaults to 2 * out_features
hidden_layers (int, optional) – Number of hidden layers, defaults to 0
activation_fn (nn.Module, optional) – Activation function, by default nn.GELU
- forward(x: Tensor)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
Graph Neural Networks
- class modulus.models.meshgraphnet.meshgraphnet.MeshGraphNet(*args, **kwargs)[source]
Bases:
Module
MeshGraphNet network architecture
- Parameters
input_dim_nodes (int) – Number of node features
input_dim_edges (int) – Number of edge features
output_dim (int) – Number of outputs
processor_size (int, optional) – Number of message passing blocks, by default 15
mlp_activation_fn (Union[str, List[str]], optional) – Activation function to use, by default ‘relu’
num_layers_node_processor (int, optional) – Number of MLP layers for processing nodes in each message passing block, by default 2
num_layers_edge_processor (int, optional) – Number of MLP layers for processing edge features in each message passing block, by default 2
hidden_dim_processor (int, optional) – Hidden layer size for the message passing blocks, by default 128
hidden_dim_node_encoder (int, optional) – Hidden layer size for the node feature encoder, by default 128
num_layers_node_encoder (Union[int, None], optional) – Number of MLP layers for the node feature encoder, by default 2. If None is provided, the MLP will collapse to an Identity function, i.e. no node encoder
hidden_dim_edge_encoder (int, optional) – Hidden layer size for the edge feature encoder, by default 128
num_layers_edge_encoder (Union[int, None], optional) – Number of MLP layers for the edge feature encoder, by default 2. If None is provided, the MLP will collapse to an Identity function, i.e. no edge encoder
hidden_dim_node_decoder (int, optional) – Hidden layer size for the node feature decoder, by default 128
num_layers_node_decoder (Union[int, None], optional) – Number of MLP layers for the node feature decoder, by default 2. If None is provided, the MLP will collapse to an Identity function, i.e. no decoder
aggregation (str, optional) – Message aggregation type, by default “sum”
do_concat_trick (bool, optional) – Whether to replace concat+MLP with MLP+idx+sum, by default False
num_processor_checkpoint_segments (int, optional) – Number of processor segments for gradient checkpointing, by default 0 (checkpointing disabled)
Example
>>> model = modulus.models.meshgraphnet.MeshGraphNet(
...     input_dim_nodes=4,
...     input_dim_edges=3,
...     output_dim=2,
... )
>>> graph = dgl.rand_graph(10, 5)
>>> node_features = torch.randn(10, 4)
>>> edge_features = torch.randn(5, 3)
>>> output = model(node_features, edge_features, graph)
>>> output.size()
torch.Size([10, 2])
Note: Reference: Pfaff, Tobias, et al. “Learning mesh-based simulation with graph networks.” arXiv preprint arXiv:2010.03409 (2020).
- forward(node_features: Tensor, edge_features: Tensor, graph: Union[DGLGraph, List[DGLGraph], CuGraphCSC], **kwargs) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class modulus.models.meshgraphnet.meshgraphnet.MeshGraphNetProcessor(processor_size: int = 15, input_dim_node: int = 128, input_dim_edge: int = 128, num_layers_node: int = 2, num_layers_edge: int = 2, aggregation: str = 'sum', norm_type: str = 'LayerNorm', activation_fn: Module = ReLU(), do_concat_trick: bool = False, num_processor_checkpoint_segments: int = 0)[source]
Bases:
Module
MeshGraphNet processor block
- forward(node_features: Tensor, edge_features: Tensor, graph: Union[DGLGraph, List[DGLGraph], CuGraphCSC]) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- run_function(segment_start: int, segment_end: int) → Callable[[Tensor, Tensor, Union[DGLGraph, List[DGLGraph]]], Tuple[Tensor, Tensor]][source]
Custom forward for gradient checkpointing
- Parameters
segment_start (int) – Layer index as start of the segment
segment_end (int) – Layer index as end of the segment
- Returns
Custom forward function
- Return type
Callable
- set_checkpoint_segments(checkpoint_segments: int)[source]
Set the number of checkpoint segments
- Parameters
checkpoint_segments (int) – number of checkpoint segments
- Raises
ValueError – if the number of processor layers is not a multiple of the number of checkpoint segments
- class modulus.models.meshgraphnet.meshgraphnet.MetaData(name: str = 'MeshGraphNet', jit: bool = False, cuda_graphs: bool = False, amp: bool = False, amp_cpu: bool = False, amp_gpu: bool = True, torch_fx: bool = False, bf16: bool = False, onnx: bool = False, onnx_gpu: bool = None, onnx_cpu: bool = None, onnx_runtime: bool = False, trt: bool = False, var_dim: int = -1, func_torch: bool = True, auto_grad: bool = True)[source]
Bases:
ModelMetaData
- class modulus.models.mesh_reduced.mesh_reduced.Mesh_Reduced(input_dim_nodes: int, input_dim_edges: int, output_decode_dim: int, output_encode_dim: int = 3, processor_size: int = 15, num_layers_node_processor: int = 2, num_layers_edge_processor: int = 2, hidden_dim_processor: int = 128, hidden_dim_node_encoder: int = 128, num_layers_node_encoder: int = 2, hidden_dim_edge_encoder: int = 128, num_layers_edge_encoder: int = 2, hidden_dim_node_decoder: int = 128, num_layers_node_decoder: int = 2, k: int = 3, aggregation: str = 'mean')[source]
Bases:
Module
PbGMR-GMUS architecture
- Parameters
input_dim_nodes (int) – Number of node features
input_dim_edges (int) – Number of edge features
output_decode_dim (int) – Number of decoding outputs (per node)
output_encode_dim (int, optional) – Number of encoding outputs (per pivotal position), by default 3
processor_size (int, optional) – Number of message passing blocks, by default 15
num_layers_node_processor (int, optional) – Number of MLP layers for processing nodes in each message passing block, by default 2
num_layers_edge_processor (int, optional) – Number of MLP layers for processing edge features in each message passing block, by default 2
hidden_dim_processor (int, optional) – Hidden layer size for the message passing blocks, by default 128
hidden_dim_node_encoder (int, optional) – Hidden layer size for the node feature encoder, by default 128
num_layers_node_encoder (int, optional) – Number of MLP layers for the node feature encoder, by default 2
hidden_dim_edge_encoder (int, optional) – Hidden layer size for the edge feature encoder, by default 128
num_layers_edge_encoder (int, optional) – Number of MLP layers for the edge feature encoder, by default 2
hidden_dim_node_decoder (int, optional) – Hidden layer size for the node feature decoder, by default 128
num_layers_node_decoder (int, optional) – Number of MLP layers for the node feature decoder, by default 2
k (int, optional) – Number of nodes considered per pivotal position, by default 3
aggregation (str, optional) – Message aggregation type, by default “mean”
Note: Reference: Han, Xu, et al. “Predicting physics in mesh-reduced space with temporal attention.” arXiv preprint arXiv:2201.09113 (2022).
- class modulus.models.meshgraphnet.bsms_mgn.BiStrideMeshGraphNet(*args, **kwargs)[source]
Bases:
MeshGraphNet
Bi-stride MeshGraphNet network architecture
- Parameters
input_dim_nodes (int) – Number of node features
input_dim_edges (int) – Number of edge features
output_dim (int) – Number of outputs
processor_size (int, optional) – Number of message passing blocks, by default 15
mlp_activation_fn (Union[str, List[str]], optional) – Activation function to use, by default ‘relu’
num_layers_node_processor (int, optional) – Number of MLP layers for processing nodes in each message passing block, by default 2
num_layers_edge_processor (int, optional) – Number of MLP layers for processing edge features in each message passing block, by default 2
hidden_dim_processor (int, optional) – Hidden layer size for the message passing blocks, by default 128
hidden_dim_node_encoder (int, optional) – Hidden layer size for the node feature encoder, by default 128
num_layers_node_encoder (Union[int, None], optional) – Number of MLP layers for the node feature encoder, by default 2. If None is provided, the MLP will collapse to an Identity function, i.e. no node encoder
hidden_dim_edge_encoder (int, optional) – Hidden layer size for the edge feature encoder, by default 128
num_layers_edge_encoder (Union[int, None], optional) – Number of MLP layers for the edge feature encoder, by default 2. If None is provided, the MLP will collapse to an Identity function, i.e. no edge encoder
hidden_dim_node_decoder (int, optional) – Hidden layer size for the node feature decoder, by default 128
num_layers_node_decoder (Union[int, None], optional) – Number of MLP layers for the node feature decoder, by default 2. If None is provided, the MLP will collapse to an Identity function, i.e. no decoder
aggregation (str, optional) – Message aggregation type, by default “sum”
do_concat_trick (bool, optional) – Whether to replace concat+MLP with MLP+idx+sum, by default False
num_processor_checkpoint_segments (int, optional) – Number of processor segments for gradient checkpointing, by default 0 (checkpointing disabled). The number of segments should be a factor of 2 * processor_size, for example, if processor_size is 15, then num_processor_checkpoint_segments can be 10 since it’s a factor of 15 * 2 = 30. It is recommended to start with a smaller number of segments until the model fits into memory since each segment will affect model training speed.
- forward(node_features: Tensor, edge_features: Tensor, graph: DGLGraph, ms_edges: Iterable[Tensor] = (), ms_ids: Iterable[Tensor] = (), **kwargs) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class modulus.models.meshgraphnet.bsms_mgn.MetaData(name: str = 'BiStrideMeshGraphNet', jit: bool = False, cuda_graphs: bool = False, amp: bool = False, amp_cpu: bool = False, amp_gpu: bool = True, torch_fx: bool = False, bf16: bool = False, onnx: bool = False, onnx_gpu: bool = None, onnx_cpu: bool = None, onnx_runtime: bool = False, trt: bool = False, var_dim: int = -1, func_torch: bool = True, auto_grad: bool = True)[source]
Bases:
ModelMetaData
Convolutional Networks
- class modulus.models.pix2pix.pix2pix.MetaData(name: str = 'Pix2Pix', jit: bool = True, cuda_graphs: bool = True, amp: bool = False, amp_cpu: bool = False, amp_gpu: bool = True, torch_fx: bool = False, bf16: bool = False, onnx: bool = True, onnx_gpu: bool = None, onnx_cpu: bool = None, onnx_runtime: bool = False, trt: bool = False, var_dim: int = 1, func_torch: bool = True, auto_grad: bool = True)[source]
Bases:
ModelMetaData
- class modulus.models.pix2pix.pix2pix.Pix2Pix(*args, **kwargs)[source]
Bases:
Module
Convolutional encoder-decoder based on pix2pix generator models.
Note: The pix2pix architecture supports options for 1D, 2D and 3D fields, which can be controlled using the dimension parameter.
- Parameters
in_channels (int) – Number of input channels
out_channels (Union[int, Any], optional) – Number of output channels
dimension (int) – Model dimensionality (supports 1, 2, 3).
conv_layer_size (int, optional) – Latent channel size after first convolution, by default 64
n_downsampling (int, optional) – Number of downsampling blocks, by default 3
n_upsampling (int, optional) – Number of upsampling blocks, by default 3
n_blocks (int, optional) – Number of residual blocks in middle of model, by default 3
activation_fn (Any, optional) – Activation function, by default “relu”
batch_norm (bool, optional) – Batch normalization, by default False
padding_type (str, optional) – Padding type (‘reflect’, ‘replicate’ or ‘zero’), by default “reflect”
Example
>>> #2D convolutional encoder decoder
>>> model = modulus.models.pix2pix.Pix2Pix(
...     in_channels=1,
...     out_channels=2,
...     dimension=2,
...     conv_layer_size=4)
>>> input = torch.randn(4, 1, 32, 32) #(N, C, H, W)
>>> output = model(input)
>>> output.size()
torch.Size([4, 2, 32, 32])
Note: Reference: Isola, Phillip, et al. “Image-To-Image translation with conditional adversarial networks” Conference on Computer Vision and Pattern Recognition, 2017. https://arxiv.org/abs/1611.07004
Reference: Wang, Ting-Chun, et al. “High-Resolution image synthesis and semantic manipulation with conditional GANs” Conference on Computer Vision and Pattern Recognition, 2018. https://arxiv.org/abs/1711.11585
Note: Based on the implementation: https://github.com/NVIDIA/pix2pixHD
- forward(input: Tensor) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class modulus.models.pix2pix.pix2pix.ResnetBlock(dimension: int, channels: int, padding_type: str = 'reflect', activation: Module = ReLU(), use_batch_norm: bool = False, use_dropout: bool = False)[source]
Bases:
Module
A simple ResNet block
- Parameters
dimension (int) – Model dimensionality (supports 1, 2, 3).
channels (int) – Number of feature channels
padding_type (str, optional) – Padding type (‘reflect’, ‘replicate’ or ‘zero’), by default “reflect”
activation (nn.Module, optional) – Activation function, by default nn.ReLU()
use_batch_norm (bool, optional) – Batch normalization, by default False
- forward(x: Tensor) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class modulus.models.srrn.super_res_net.ConvolutionalBlock3d(in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, batch_norm: bool = False, activation_fn: Module = Identity())[source]
Bases:
Module
3D convolutional block
- Parameters
in_channels (int) – Input channels
out_channels (int) – Output channels
kernel_size (int) – Kernel size
stride (int, optional) – Convolutional stride, by default 1
batch_norm (bool, optional) – Use batchnorm, by default False
- forward(input: Tensor) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class modulus.models.srrn.super_res_net.MetaData(name: str = 'SuperResolution', jit: bool = True, cuda_graphs: bool = False, amp: bool = False, amp_cpu: bool = False, amp_gpu: bool = False, torch_fx: bool = False, bf16: bool = False, onnx: bool = True, onnx_gpu: bool = None, onnx_cpu: bool = None, onnx_runtime: bool = False, trt: bool = False, var_dim: int = 1, func_torch: bool = True, auto_grad: bool = True)[source]
Bases:
ModelMetaData
- class modulus.models.srrn.super_res_net.PixelShuffle3d(scale: int)[source]
Bases:
Module
3D pixel-shuffle operation
- Parameters
scale (int) – Factor to downscale channel count by
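Example (a hedged usage sketch, not from the upstream docstring; it assumes the standard 3D pixel-shuffle semantics, where the channel count is divided by scale**3 and each spatial dimension is multiplied by scale):
>>> shuffle = modulus.models.srrn.super_res_net.PixelShuffle3d(scale=2)
>>> x = torch.randn(4, 16, 8, 8, 8)  # channel count must be divisible by scale**3
>>> shuffle(x).shape
torch.Size([4, 2, 16, 16, 16])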
- forward(input: Tensor) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class modulus.models.srrn.super_res_net.ResidualConvBlock3d(n_layers: int = 1, kernel_size: int = 3, conv_layer_size: int = 64, activation_fn: Module = Identity())[source]
Bases:
Module
3D ResNet block
- Parameters
n_layers (int, optional) – Number of convolutional layers, by default 1
kernel_size (int, optional) – Kernel size, by default 3
conv_layer_size (int, optional) – Latent channel size, by default 64
activation_fn (nn.Module, optional) – Activation function, by default nn.Identity()
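Example (a hedged usage sketch, not from the upstream docstring; it assumes the input channel count equals conv_layer_size and that the residual connection preserves the input shape):
>>> block = modulus.models.srrn.super_res_net.ResidualConvBlock3d(
...     n_layers=2, kernel_size=3, conv_layer_size=8)
>>> x = torch.randn(4, 8, 8, 8, 8)  # (N, C, D, H, W), C == conv_layer_size
>>> block(x).shape
torch.Size([4, 8, 8, 8, 8])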
- forward(input: Tensor) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class modulus.models.srrn.super_res_net.SRResNet(*args, **kwargs)[source]
Bases:
Module
3D convolutional super-resolution network
- Parameters
in_channels (int) – Number of input channels
out_channels (int) – Number of output channels
large_kernel_size (int, optional) – convolutional kernel size for first and last convolution, by default 7
small_kernel_size (int, optional) – convolutional kernel size for internal convolutions, by default 3
conv_layer_size (int, optional) – Latent channel size, by default 32
n_resid_blocks (int, optional) – Number of residual blocks before upsampling, by default 8
scaling_factor (int, optional) – Scaling factor to increase the output feature size compared to the input (2, 4, or 8), by default 8
activation_fn (Any, optional) – Activation function, by default “prelu”
Example
>>> # 3D convolutional super-resolution network
>>> model = modulus.models.srrn.SRResNet(
...     in_channels=1,
...     out_channels=2,
...     conv_layer_size=4,
...     scaling_factor=2)
>>> input = torch.randn(4, 1, 8, 8, 8)  # (N, C, D, H, W)
>>> output = model(input)
>>> output.size()
torch.Size([4, 2, 16, 16, 16])
Note: Based on the implementation: https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution
- forward(in_vars: Tensor) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class modulus.models.srrn.super_res_net.SubPixel_ConvolutionalBlock3d(kernel_size: int = 3, conv_layer_size: int = 64, scaling_factor: int = 2)[source]
Bases:
Module
Convolutional block with Pixel Shuffle operation
- Parameters
kernel_size (int, optional) – Kernel size, by default 3
conv_layer_size (int, optional) – Latent channel size, by default 64
scaling_factor (int, optional) – Pixel shuffle scaling factor, by default 2
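Example (a hedged usage sketch, not from the upstream docstring; it assumes the block expands the channel count by scaling_factor**3 and then pixel-shuffles, so the channel count is preserved and each spatial dimension grows by scaling_factor):
>>> block = modulus.models.srrn.super_res_net.SubPixel_ConvolutionalBlock3d(
...     kernel_size=3, conv_layer_size=8, scaling_factor=2)
>>> x = torch.randn(4, 8, 8, 8, 8)  # (N, C, D, H, W), C == conv_layer_size
>>> block(x).shape
torch.Size([4, 8, 16, 16, 16])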
- forward(input: Tensor) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
Recurrent Neural Networks
- class modulus.models.rnn.rnn_one2many.MetaData(name: str = 'One2ManyRNN', jit: bool = False, cuda_graphs: bool = False, amp: bool = True, amp_cpu: bool = None, amp_gpu: bool = None, torch_fx: bool = True, bf16: bool = False, onnx: bool = False, onnx_gpu: bool = None, onnx_cpu: bool = None, onnx_runtime: bool = False, trt: bool = False, var_dim: int = -1, func_torch: bool = False, auto_grad: bool = False)[source]
Bases:
ModelMetaData
- class modulus.models.rnn.rnn_one2many.One2ManyRNN(*args, **kwargs)[source]
Bases:
Module
An RNN model with encoder/decoder for 2D/3D problems that provides predictions based on a single initial condition.
- Parameters
input_channels (int) – Number of channels in the input
dimension (int, optional) – Spatial dimension of the input. Only 2d and 3d are supported, by default 2
nr_latent_channels (int, optional) – Channels for encoding/decoding, by default 512
nr_residual_blocks (int, optional) – Number of residual blocks, by default 2
activation_fn (str, optional) – Activation function to use, by default “relu”
nr_downsamples (int, optional) – Number of downsamples, by default 2
nr_tsteps (int, optional) – Time steps to predict, by default 32
Example
>>> model = modulus.models.rnn.One2ManyRNN(
...     input_channels=6,
...     dimension=2,
...     nr_latent_channels=32,
...     activation_fn="relu",
...     nr_downsamples=2,
...     nr_tsteps=16,
... )
>>> input = torch.randn(4, 6, 1, 16, 16)  # [N, C, T, H, W]
>>> output = model(input)
>>> output.size()
torch.Size([4, 6, 16, 16, 16])
- forward(x: Tensor) → Tensor[source]
Forward pass
- Parameters
x (Tensor) – Expects a tensor of size [N, C, 1, H, W] for 2D or [N, C, 1, D, H, W] for 3D, where N is the batch size, C is the number of channels, 1 is the number of input timesteps, and D, H, W are the spatial dimensions.
- Returns
Size [N, C, T, H, W] for 2D or [N, C, T, D, H, W] for 3D, where T is the number of timesteps being predicted.
- Return type
Tensor
- class modulus.models.rnn.rnn_seq2seq.MetaData(name: str = 'Seq2SeqRNN', jit: bool = False, cuda_graphs: bool = False, amp: bool = True, amp_cpu: bool = None, amp_gpu: bool = None, torch_fx: bool = True, bf16: bool = False, onnx: bool = False, onnx_gpu: bool = None, onnx_cpu: bool = None, onnx_runtime: bool = False, trt: bool = False, var_dim: int = -1, func_torch: bool = False, auto_grad: bool = False)[source]
Bases:
ModelMetaData
- class modulus.models.rnn.rnn_seq2seq.Seq2SeqRNN(*args, **kwargs)[source]
Bases:
Module
An RNN model with encoder/decoder for 2D/3D problems. Given input from time 0 to t-1, predicts the signal from t to t + nr_tsteps.
- Parameters
input_channels (int) – Number of channels in the input
dimension (int, optional) – Spatial dimension of the input. Only 2d and 3d are supported, by default 2
nr_latent_channels (int, optional) – Channels for encoding/decoding, by default 512
nr_residual_blocks (int, optional) – Number of residual blocks, by default 2
activation_fn (str, optional) – Activation function to use, by default “relu”
nr_downsamples (int, optional) – Number of downsamples, by default 2
nr_tsteps (int, optional) – Time steps to predict, by default 32
Example
>>> model = modulus.models.rnn.Seq2SeqRNN(
...     input_channels=6,
...     dimension=2,
...     nr_latent_channels=32,
...     activation_fn="relu",
...     nr_downsamples=2,
...     nr_tsteps=16,
... )
>>> input = torch.randn(4, 6, 16, 16, 16)  # [N, C, T, H, W]
>>> output = model(input)
>>> output.size()
torch.Size([4, 6, 16, 16, 16])
- forward(x: Tensor) → Tensor[source]
Forward pass
- Parameters
x (Tensor) – Expects a tensor of size [N, C, T, H, W] for 2D or [N, C, T, D, H, W] for 3D, where N is the batch size, C is the number of channels, T is the number of input timesteps, and D, H, W are the spatial dimensions. Currently, the number of input time steps must equal the number of predicted time steps.
- Returns
Size [N, C, T, H, W] for 2D or [N, C, T, D, H, W] for 3D, where T is the number of timesteps being predicted.
- Return type
Tensor
Weather / Climate Models
- class modulus.models.dlwp.dlwp.DLWP(*args, **kwargs)[source]
Bases:
Module
A Convolutional model for Deep Learning Weather Prediction that works on Cubed-sphere grids.
This model expects the input to be of shape [N, C, 6, Res, Res]
- Parameters
nr_input_channels (int) – Number of channels in the input
nr_output_channels (int) – Number of channels in the output
nr_initial_channels (int) – Number of channels in the initial convolution. This governs the overall channels in the model.
activation_fn (str) – Activation function for the convolutions
depth (int) – Depth for the U-Net
clamp_activation (Tuple of ints, floats or None) – The min and max value used for torch.clamp()
Example
>>> model = modulus.models.dlwp.DLWP(
...     nr_input_channels=2,
...     nr_output_channels=4,
... )
>>> input = torch.randn(4, 2, 6, 64, 64)  # [N, C, F, Res, Res]
>>> output = model(input)
>>> output.size()
torch.Size([4, 4, 6, 64, 64])
Note: Reference: Weyn, Jonathan A., et al. “Sub‐seasonal forecasting with a large ensemble of deep‐learning weather prediction models.” Journal of Advances in Modeling Earth Systems 13.7 (2021): e2021MS002502.
- forward(cubed_sphere_input)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class modulus.models.dlwp.dlwp.MetaData(name: str = 'DLWP', jit: bool = False, cuda_graphs: bool = True, amp: bool = False, amp_cpu: bool = True, amp_gpu: bool = True, torch_fx: bool = False, bf16: bool = False, onnx: bool = False, onnx_gpu: bool = None, onnx_cpu: bool = None, onnx_runtime: bool = False, trt: bool = False, var_dim: int = 1, func_torch: bool = False, auto_grad: bool = False)[source]
Bases:
ModelMetaData
- class modulus.models.dlwp_healpix.HEALPixRecUNet.HEALPixRecUNet(*args, **kwargs)[source]
Bases:
Module
Deep Learning Weather Prediction (DLWP) recurrent UNet model on the HEALPix mesh.
- forward(inputs: Sequence, output_only_last=False) → Tensor[source]
Forward pass of the HEALPixUnet
- Parameters
inputs (Sequence) – Inputs to the model, of the form [prognostics | TISR | constants]. [B, F, T, C, H, W] is the format for prognostics and TISR; [F, C, H, W] is the format for constants.
output_only_last (bool, optional) – If only the last dimension of the outputs should be returned
- Returns
Predicted outputs
- Return type
th.Tensor
- property integration_steps
Number of integration steps
- reset()[source]
Resets the state of the network
- class modulus.models.dlwp_healpix.HEALPixRecUNet.MetaData(name: str = 'DLWP_HEALPixRec', jit: bool = False, cuda_graphs: bool = True, amp: bool = False, amp_cpu: bool = True, amp_gpu: bool = True, torch_fx: bool = False, bf16: bool = False, onnx: bool = False, onnx_gpu: Optional[bool] = None, onnx_cpu: Optional[bool] = None, onnx_runtime: bool = False, trt: bool = False, var_dim: int = 1, func_torch: bool = False, auto_grad: bool = False)[source]
Bases:
ModelMetaData
Metadata for the DLWP HEALPix Model
- class modulus.models.graphcast.graph_cast_net.GraphCastNet(*args, **kwargs)[source]
Bases:
Module
GraphCast network architecture
- Parameters
multimesh_level (int, optional) – Level of the latent mesh, by default 6
multimesh (bool, optional) – Whether the latent mesh is a multimesh, by default True. If True, the latent mesh includes the nodes corresponding to the specified mesh_level and incorporates the edges from all mesh levels ranging from level 0 up to and including mesh_level.
input_res (Tuple[int, int]) – Input resolution of the latitude-longitude grid
input_dim_grid_nodes (int, optional) – Input dimensionality of the grid node features, by default 474
input_dim_mesh_nodes (int, optional) – Input dimensionality of the mesh node features, by default 3
input_dim_edges (int, optional) – Input dimensionality of the edge features, by default 4
output_dim_grid_nodes (int, optional) – Final output dimensionality of the grid node features, by default 227
processor_type (str, optional) – The type of processor used in this model. Available options are ‘MessagePassing’, and ‘GraphTransformer’, which correspond to the processors in GraphCast and GenCast, respectively. By default ‘MessagePassing’.
khop_neighbors (int, optional) – Number of khop neighbors used in the GraphTransformer. This option is ignored if ‘MessagePassing’ processor is used. By default 0.
processor_layers (int, optional) – Number of processor layers, by default 16
hidden_layers (int, optional) – Number of hidden layers, by default 1
hidden_dim (int, optional) – Number of neurons in each hidden layer, by default 512
aggregation (str, optional) – Message passing aggregation method (“sum”, “mean”), by default “sum”
activation_fn (str, optional) – Type of activation function, by default “silu”
norm_type (str, optional) – Normalization type [“TELayerNorm”, “LayerNorm”]. Use “TELayerNorm” for optimal performance. By default “LayerNorm”.
use_cugraphops_encoder (bool, default=False) – Flag to select cugraphops kernels in encoder
use_cugraphops_processor (bool, default=False) – Flag to select cugraphops kernels in the processor
use_cugraphops_decoder (bool, default=False) – Flag to select cugraphops kernels in the decoder
do_concat_trick (bool, default=False) – Whether to replace concat+MLP with MLP+idx+sum
recompute_activation (bool, optional) – Flag for recomputing activation in backward to save memory, by default False. Currently, only SiLU is supported.
partition_size (int, default=1) – Number of process groups across which graphs are distributed. If equal to 1, the model is run in a normal Single-GPU configuration.
partition_group_name (str, default=None) – Name of the process group across which graphs are distributed. If partition_size is set to 1, the model is run in a normal Single-GPU configuration and the specification of a process group is not necessary. If partition_size > 1, passing no process group name leads to parallelism across the default process group. Otherwise, the group size of a process group is expected to match partition_size.
use_lat_lon_partitioning (bool, default=False) – flag to specify whether all graphs (grid-to-mesh, mesh, mesh-to-grid) are partitioned based on lat-lon-coordinates of nodes or based on IDs.
expect_partitioned_input (bool, default=False) – Flag indicating whether the model expects the input to be already partitioned. This can be helpful e.g. in multi-step rollouts to avoid aggregating the output just to distribute it in the next step again.
global_features_on_rank_0 (bool, default=False) – Flag indicating whether the model expects the input to be present in its “global” form only on group_rank 0. During the input preparation phase, the model will take care of scattering the input accordingly onto all ranks of the process group across which the graph is partitioned. Note that only either this flag or expect_partitioned_input can be set at a time.
produce_aggregated_output (bool, default=True) – Flag indicating whether the model produces the aggregated output on each rank of the process group across which the graph is distributed, or whether the output is kept distributed. This can be helpful e.g. in multi-step rollouts to avoid aggregating the output just to distribute it in the next step again.
produce_aggregated_output_on_all_ranks (bool, default=True) – If produce_aggregated_output is True, flag indicating whether the model produces the aggregated output on each rank of the process group across which the graph is distributed, or only on group_rank 0. This can be helpful for computing the loss against global targets on a single rank, which avoids having to distribute the loss computation.
Note: Based on these papers:
- “GraphCast: Learning skillful medium-range global weather forecasting”
- “Forecasting Global Weather with Graph Neural Networks”
- “Learning Mesh-Based Simulation with Graph Networks”
- “MultiScale MeshGraphNets”
- “GenCast: Diffusion-based ensemble forecasting for medium-range weather”
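Example (a hedged sketch, not an official example; argument names follow the parameter list above, the coarse resolution is chosen only for illustration, and it assumes the graph backend required by modulus.models.graphcast, e.g. DGL, is installed):
>>> model = modulus.models.graphcast.graph_cast_net.GraphCastNet(
...     multimesh_level=2,
...     input_res=(32, 64),
...     input_dim_grid_nodes=4,
...     output_dim_grid_nodes=4,
...     processor_layers=4,
...     hidden_dim=32,
... )
>>> x = torch.randn(1, 4, 32, 64)  # [N, C, H, W] on the lat-lon grid
>>> y = model(x)                   # expected shape: [N, output_dim_grid_nodes, H, W]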
- custom_forward(grid_nfeat: Tensor) → Tensor[source]
GraphCast forward method with support for gradient checkpointing.
- Parameters
grid_nfeat (Tensor) – Node features of the latitude-longitude graph.
- Returns
grid_nfeat_finale – Predicted node features of the latitude-longitude graph.
- Return type
Tensor
- decoder_forward(mesh_efeat_processed: Tensor, mesh_nfeat_processed: Tensor, grid_nfeat_encoded: Tensor) → Tensor[source]
Forward method for the last layer of the processor, the decoder, and the final MLP.
- Parameters
mesh_efeat_processed (Tensor) – Multimesh edge features processed by the processor.
mesh_nfeat_processed (Tensor) – Multi-mesh node features processed by the processor.
grid_nfeat_encoded (Tensor) – The encoded node features for the latitude-longitude grid.
- Returns
grid_nfeat_finale – The final node features for the latitude-longitude grid.
- Return type
Tensor
- encoder_forward(grid_nfeat: Tensor) → Tensor[source]
Forward method for the embedder, encoder, and the first layer of the processor.
- Parameters
grid_nfeat (Tensor) – Node features for the latitude-longitude grid.
- Returns
mesh_efeat_processed (Tensor) – Processed edge features for the multimesh.
mesh_nfeat_processed (Tensor) – Processed node features for the multimesh.
grid_nfeat_encoded (Tensor) – Encoded node features for the latitude-longitude grid.
- forward(grid_nfeat: Tensor) → Tensor[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- prepare_input(invar: Tensor, expect_partitioned_input: bool, global_features_on_rank_0: bool) → Tensor[source]
Prepares the input to the model in the required shape.
- Parameters
invar (Tensor) – Input in the shape [N, C, H, W].
expect_partitioned_input (bool) – Flag indicating whether the input is partitioned according to the graph partitioning scheme
global_features_on_rank_0 (bool) – Flag indicating whether input is in its “global” form only on group_rank 0 which requires a scatter operation beforehand. Note that only either this flag or expect_partitioned_input can be set at a time.
- Returns
Reshaped input.
- Return type
Tensor
- prepare_output(outvar: Tensor, produce_aggregated_output: bool, produce_aggregated_output_on_all_ranks: bool = True) → Tensor[source]
Prepares the output of the model in the shape [N, C, H, W].
- Parameters
outvar (Tensor) – Output of the final MLP of the model.
produce_aggregated_output (bool) – Flag indicating whether the output is gathered onto each rank or kept distributed
produce_aggregated_output_on_all_ranks (bool) – Flag indicating whether the output is gathered on each rank or only on group_rank 0. True by default and only valid if produce_aggregated_output is set.
- Returns
The reshaped output of the model.
- Return type
Tensor
- set_checkpoint_decoder(checkpoint_flag: bool)[source]
Sets checkpoint function for the last layer of the processor, the decoder, and the final MLP.
This function returns the appropriate checkpoint function based on the provided checkpoint_flag flag. If checkpoint_flag is True, the function returns the checkpoint function from PyTorch’s torch.utils.checkpoint. Otherwise, it returns an identity function that simply passes the inputs through the given layer.
- Parameters
checkpoint_flag (bool) – Whether to use checkpointing for gradient computation. Checkpointing can reduce memory usage during backpropagation at the cost of increased computation time.
- Returns
The selected checkpoint function to use for gradient computation.
- Return type
Callable
- set_checkpoint_encoder(checkpoint_flag: bool)[source]
Sets checkpoint function for the embedder, encoder, and the first layer of the processor.
This function returns the appropriate checkpoint function based on the provided checkpoint_flag flag. If checkpoint_flag is True, the function returns the checkpoint function from PyTorch’s torch.utils.checkpoint. Otherwise, it returns an identity function that simply passes the inputs through the given layer.
- Parameters
checkpoint_flag (bool) – Whether to use checkpointing for gradient computation. Checkpointing can reduce memory usage during backpropagation at the cost of increased computation time.
- Returns
The selected checkpoint function to use for gradient computation.
- Return type
Callable
- set_checkpoint_model(checkpoint_flag: bool)[source]
Sets checkpoint function for the entire model.
This function returns the appropriate checkpoint function based on the provided checkpoint_flag flag. If checkpoint_flag is True, the function returns the checkpoint function from PyTorch’s torch.utils.checkpoint, and all other gradient checkpointing settings are disabled. Otherwise, it returns an identity function that simply passes the inputs through the given layer.
- Parameters
checkpoint_flag (bool) – Whether to use checkpointing for gradient computation. Checkpointing can reduce memory usage during backpropagation at the cost of increased computation time.
- Returns
The selected checkpoint function to use for gradient computation.
- Return type
Callable
- set_checkpoint_processor(checkpoint_segments: int)[source]
Sets checkpoint function for the processor excluding the first and last layers.
This function returns the appropriate checkpoint function based on the provided checkpoint_segments flag. If checkpoint_segments is positive, the function returns the checkpoint function from PyTorch’s torch.utils.checkpoint, with number of checkpointing segments equal to checkpoint_segments. Otherwise, it returns an identity function that simply passes the inputs through the given layer.
- Parameters
checkpoint_segments (int) – Number of checkpointing segments for gradient computation. Checkpointing can reduce memory usage during backpropagation at the cost of increased computation time.
- Returns
The selected checkpoint function to use for gradient computation.
- Return type
Callable
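The checkpointing setters above can be combined to trade recomputation for memory during training. A hedged sketch, assuming model is an already-constructed GraphCastNet instance:
>>> model.set_checkpoint_encoder(True)   # embedder, encoder, first processor layer
>>> model.set_checkpoint_processor(4)    # remaining processor layers in 4 segments
>>> model.set_checkpoint_decoder(True)   # last processor layer, decoder, final MLP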
- to(*args: Any, **kwargs: Any) → Self[source]
Moves the object to the specified device, dtype, or format. This method moves the object and its underlying graph and graph features to the specified device, dtype, or format, and returns the updated object.
- Parameters
*args (Any) – Positional arguments to be passed to the torch._C._nn._parse_to function.
**kwargs (Any) – Keyword arguments to be passed to the torch._C._nn._parse_to function.
- Returns
The updated object after moving to the specified device, dtype, or format.
- Return type
Self
- class modulus.models.graphcast.graph_cast_net.MetaData(name: str = 'GraphCastNet', jit: bool = False, cuda_graphs: bool = False, amp: bool = False, amp_cpu: bool = False, amp_gpu: bool = True, torch_fx: bool = False, bf16: bool = True, onnx: bool = False, onnx_gpu: bool = None, onnx_cpu: bool = None, onnx_runtime: bool = False, trt: bool = False, var_dim: int = -1, func_torch: bool = False, auto_grad: bool = False)[source]
Bases:
ModelMetaData
- modulus.models.graphcast.graph_cast_net.get_lat_lon_partition_separators(partition_size: int)[source]
Utility function to get separation intervals for the lat-lon grid for the partition sizes of interest.
- Parameters
partition_size (int) – size of graph partition
- class modulus.models.fengwu.fengwu.Fengwu(*args, **kwargs)[source]
Bases:
Module
FengWu: A PyTorch implementation of “FengWu: Pushing the Skillful Global Medium-range Weather Forecast beyond 10 Days Lead” - https://arxiv.org/pdf/2304.02948.pdf
- Parameters
img_size – Image size (Lat, Lon). Default: (721, 1440)
pressure_level – Number of pressure levels. Default: 37
embed_dim (int) – Patch embedding dimension. Default: 192
patch_size (tuple[int]) – Patch token size. Default: (4,4)
num_heads (tuple[int]) – Number of attention heads in different layers.
window_size (tuple[int]) – Window size.
- forward(x)[source]
- Parameters
surface (torch.Tensor) – 2D n_lat=721, n_lon=1440, chans=4.
z (torch.Tensor) – 2D n_lat=721, n_lon=1440, chans=37.
r (torch.Tensor) – 2D n_lat=721, n_lon=1440, chans=37.
u (torch.Tensor) – 2D n_lat=721, n_lon=1440, chans=37.
v (torch.Tensor) – 2D n_lat=721, n_lon=1440, chans=37.
t (torch.Tensor) – 2D n_lat=721, n_lon=1440, chans=37.
- prepare_input(surface, z, r, u, v, t)[source]
Prepares the input to the model in the required shape.
- Parameters
surface (torch.Tensor) – 2D n_lat=721, n_lon=1440, chans=4.
z (torch.Tensor) – 2D n_lat=721, n_lon=1440, chans=37.
r (torch.Tensor) – 2D n_lat=721, n_lon=1440, chans=37.
u (torch.Tensor) – 2D n_lat=721, n_lon=1440, chans=37.
v (torch.Tensor) – 2D n_lat=721, n_lon=1440, chans=37.
t (torch.Tensor) – 2D n_lat=721, n_lon=1440, chans=37.
- class modulus.models.fengwu.fengwu.MetaData(name: str = 'Fengwu', jit: bool = False, cuda_graphs: bool = True, amp: bool = True, amp_cpu: bool = None, amp_gpu: bool = None, torch_fx: bool = False, bf16: bool = False, onnx: bool = False, onnx_gpu: bool = True, onnx_cpu: bool = False, onnx_runtime: bool = True, trt: bool = False, var_dim: int = 1, func_torch: bool = False, auto_grad: bool = False)[source]
Bases:
ModelMetaData
- class modulus.models.pangu.pangu.MetaData(name: str = 'Pangu', jit: bool = False, cuda_graphs: bool = True, amp: bool = True, amp_cpu: bool = None, amp_gpu: bool = None, torch_fx: bool = False, bf16: bool = False, onnx: bool = False, onnx_gpu: bool = True, onnx_cpu: bool = False, onnx_runtime: bool = True, trt: bool = False, var_dim: int = 1, func_torch: bool = False, auto_grad: bool = False)[source]
Bases:
ModelMetaData
- class modulus.models.pangu.pangu.Pangu(*args, **kwargs)[source]
Bases:
Module
Pangu: A PyTorch implementation of “Pangu-Weather: A 3D High-Resolution Model for Fast and Accurate Global Weather Forecast” - https://arxiv.org/abs/2211.02556
- Parameters
img_size (tuple[int]) – Image size [Lat, Lon].
patch_size (tuple[int]) – Patch token size [Lat, Lon].
embed_dim (int) – Patch embedding dimension. Default: 192
num_heads (tuple[int]) – Number of attention heads in different layers.
window_size (tuple[int]) – Window size.
- forward(x)[source]
- Parameters
x (torch.Tensor) – [batch, 4+3+5*13, lat, lon]
- prepare_input(surface, surface_mask, upper_air)[source]
Prepares the input to the model in the required shape.
- Parameters
surface (torch.Tensor) – 2D n_lat=721, n_lon=1440, chans=4.
surface_mask (torch.Tensor) – 2D n_lat=721, n_lon=1440, chans=3.
upper_air (torch.Tensor) – 3D n_pl=13, n_lat=721, n_lon=1440, chans=5.
- class modulus.models.swinvrnn.swinvrnn.MetaData(name: str = 'SwinRNN', jit: bool = False, cuda_graphs: bool = True, amp: bool = True, amp_cpu: bool = None, amp_gpu: bool = None, torch_fx: bool = False, bf16: bool = False, onnx: bool = False, onnx_gpu: bool = True, onnx_cpu: bool = False, onnx_runtime: bool = True, trt: bool = False, var_dim: int = 1, func_torch: bool = False, auto_grad: bool = False)[source]
Bases:
ModelMetaData
- class modulus.models.swinvrnn.swinvrnn.SwinRNN(*args, **kwargs)[source]
Bases:
Module
Implementation of SwinRNN (https://arxiv.org/abs/2205.13158).
- Parameters
img_size (Sequence[int], optional) – Image size [T, Lat, Lon].
patch_size (Sequence[int], optional) – Patch token size [T, Lat, Lon].
in_chans (int, optional) – Number of input channels.
out_chans (int, optional) – Number of output channels.
embed_dim (int, optional) – Number of embed channels.
num_groups (Sequence[int] | int, optional) – Number of groups to separate the channels into.
num_heads (int, optional) – Number of attention heads.
window_size (int | tuple[int], optional) – Local window size.
- forward(x: Tensor)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
Diffusion Model
Model architectures used in the paper “Elucidating the Design Space of Diffusion-Based Generative Models”.
- class modulus.models.diffusion.dhariwal_unet.DhariwalUNet(*args, **kwargs)[source]
Bases:
Module
Reimplementation of the ADM architecture, a U-Net variant, with optional self-attention.
This model supports conditional and unconditional setups, as well as several options for various internal architectural choices such as encoder and decoder type, embedding type, etc., making it flexible and adaptable to different tasks and configurations.
- Parameters
img_resolution (int) – The resolution of the input/output image.
in_channels (int) – Number of channels in the input image.
out_channels (int) – Number of channels in the output image.
label_dim (int, optional) – Number of class labels; 0 indicates an unconditional model. By default 0.
augment_dim (int, optional) – Dimensionality of augmentation labels; 0 means no augmentation. By default 0.
model_channels (int, optional) – Base multiplier for the number of channels across the network, by default 192.
channel_mult (List[int], optional) – Per-resolution multipliers for the number of channels. By default [1,2,3,4].
channel_mult_emb (int, optional) – Multiplier for the dimensionality of the embedding vector. By default 4.
num_blocks (int, optional) – Number of residual blocks per resolution. By default 3.
attn_resolutions (List[int], optional) – Resolutions at which self-attention layers are applied. By default [32, 16, 8].
dropout (float, optional) – Dropout probability applied to intermediate activations. By default 0.10.
label_dropout (float, optional) – Dropout probability of class labels for classifier-free guidance. By default 0.0.
Reference: Dhariwal, P. and Nichol, A., 2021. Diffusion models beat GANs on image synthesis. Advances in Neural Information Processing Systems, 34, pp.8780-8794.
Note: Equivalent to the original implementation by Dhariwal and Nichol, available at https://github.com/openai/guided-diffusion
Example
>>> model = DhariwalUNet(img_resolution=16, in_channels=2, out_channels=2)
>>> noise_labels = torch.randn([1])
>>> class_labels = torch.randint(0, 1, (1, 1))
>>> input_image = torch.ones([1, 2, 16, 16])
>>> output_image = model(input_image, noise_labels, class_labels)
- forward(x, noise_labels, class_labels, augment_labels=None)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class modulus.models.diffusion.dhariwal_unet.MetaData(name: str = 'DhariwalUNet', jit: bool = False, cuda_graphs: bool = False, amp: bool = False, amp_cpu: bool = False, amp_gpu: bool = True, torch_fx: bool = False, bf16: bool = True, onnx: bool = False, onnx_gpu: bool = None, onnx_cpu: bool = None, onnx_runtime: bool = False, trt: bool = False, var_dim: int = -1, func_torch: bool = False, auto_grad: bool = False)[source]
Bases:
ModelMetaData
- class modulus.models.diffusion.song_unet.MetaData(name: str = 'SongUNet', jit: bool = False, cuda_graphs: bool = False, amp: bool = False, amp_cpu: bool = False, amp_gpu: bool = True, torch_fx: bool = False, bf16: bool = True, onnx: bool = False, onnx_gpu: bool = None, onnx_cpu: bool = None, onnx_runtime: bool = False, trt: bool = False, var_dim: int = -1, func_torch: bool = False, auto_grad: bool = False)[source]
Bases:
ModelMetaData
- class modulus.models.diffusion.song_unet.SongUNet(*args, **kwargs)[source]
Bases:
Module
Reimplementation of the DDPM++ and NCSN++ architectures, U-Net variants with optional self-attention, embeddings, and encoder-decoder components.
This model supports conditional and unconditional setups, as well as several options for various internal architectural choices such as encoder and decoder type, embedding type, etc., making it flexible and adaptable to different tasks and configurations.
- Parameters
img_resolution (Union[List[int], int]) – The resolution of the input/output image, 1 value represents a square image.
in_channels (int) – Number of channels in the input image.
out_channels (int) – Number of channels in the output image.
label_dim (int, optional) – Number of class labels; 0 indicates an unconditional model. By default 0.
augment_dim (int, optional) – Dimensionality of augmentation labels; 0 means no augmentation. By default 0.
model_channels (int, optional) – Base multiplier for the number of channels across the network, by default 128.
channel_mult (List[int], optional) – Per-resolution multipliers for the number of channels. By default [1,2,2,2].
channel_mult_emb (int, optional) – Multiplier for the dimensionality of the embedding vector. By default 4.
num_blocks (int, optional) – Number of residual blocks per resolution. By default 4.
attn_resolutions (List[int], optional) – Resolutions at which self-attention layers are applied. By default [16].
dropout (float, optional) – Dropout probability applied to intermediate activations. By default 0.10.
label_dropout (float, optional) – Dropout probability of class labels for classifier-free guidance. By default 0.0.
embedding_type (str, optional) – Timestep embedding type: ‘positional’ for DDPM++, ‘fourier’ for NCSN++. By default ‘positional’.
channel_mult_noise (int, optional) – Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1.
encoder_type (str, optional) – Encoder architecture: ‘standard’ for DDPM++, ‘residual’ for NCSN++. By default ‘standard’.
decoder_type (str, optional) – Decoder architecture: ‘standard’ for both DDPM++ and NCSN++. By default ‘standard’.
resample_filter (List[int], optional (default=[1,1])) – Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++.
checkpoint_level (int, optional (default=0)) – How many layers should use gradient checkpointing, 0 is None
Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and Poole, B., 2020. Score-based generative modeling through stochastic differential equations. arXiv preprint arXiv:2011.13456.
Note: Equivalent to the original implementation by Song et al., available at https://github.com/yang-song/score_sde_pytorch
Example
>>> model = SongUNet(img_resolution=16, in_channels=2, out_channels=2)
>>> noise_labels = torch.randn([1])
>>> class_labels = torch.randint(0, 1, (1, 1))
>>> input_image = torch.ones([1, 2, 16, 16])
>>> output_image = model(input_image, noise_labels, class_labels)
>>> output_image.shape
torch.Size([1, 2, 16, 16])
- forward(x, noise_labels, class_labels, augment_labels=None)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class modulus.models.diffusion.song_unet.SongUNetPosEmbd(*args, **kwargs)[source]
Bases:
SongUNet
Reimplementation of the DDPM++ and NCSN++ architectures, U-Net variants with optional self-attention, embeddings, and encoder-decoder components.
This model supports conditional and unconditional setups, as well as several options for various internal architectural choices such as encoder and decoder type, embedding type, etc., making it flexible and adaptable to different tasks and configurations.
- Parameters
img_resolution (Union[List[int], int]) – The resolution of the input/output image, 1 value represents a square image.
in_channels (int) – Number of channels in the input image.
out_channels (int) – Number of channels in the output image.
label_dim (int, optional) – Number of class labels; 0 indicates an unconditional model. By default 0.
augment_dim (int, optional) – Dimensionality of augmentation labels; 0 means no augmentation. By default 0.
model_channels (int, optional) – Base multiplier for the number of channels across the network, by default 128.
channel_mult (List[int], optional) – Per-resolution multipliers for the number of channels. By default [1,2,2,2].
channel_mult_emb (int, optional) – Multiplier for the dimensionality of the embedding vector. By default 4.
num_blocks (int, optional) – Number of residual blocks per resolution. By default 4.
attn_resolutions (List[int], optional) – Resolutions at which self-attention layers are applied. By default [16].
dropout (float, optional) – Dropout probability applied to intermediate activations. By default 0.13.
label_dropout (float, optional) – Dropout probability of class labels for classifier-free guidance. By default 0.0.
embedding_type (str, optional) – Timestep embedding type: ‘positional’ for DDPM++, ‘fourier’ for NCSN++. By default ‘positional’.
channel_mult_noise (int, optional) – Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1.
encoder_type (str, optional) – Encoder architecture: ‘standard’ for DDPM++, ‘residual’ for NCSN++. By default ‘standard’.
decoder_type (str, optional) – Decoder architecture: ‘standard’ for both DDPM++ and NCSN++. By default ‘standard’.
resample_filter (List[int], optional (default=[1,1])) – Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++.
Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and Poole, B., 2020. Score-based generative modeling through stochastic differential equations. arXiv preprint arXiv:2011.13456.
Note: Equivalent to the original implementation by Song et al., available at https://github.com/yang-song/score_sde_pytorch
Example
>>> model = SongUNet(img_resolution=16, in_channels=2, out_channels=2)
>>> noise_labels = torch.randn([1])
>>> class_labels = torch.randint(0, 1, (1, 1))
>>> input_image = torch.ones([1, 2, 16, 16])
>>> output_image = model(input_image, noise_labels, class_labels)
>>> output_image.shape
torch.Size([1, 2, 16, 16])
- forward(x, noise_labels, class_labels, global_index=None, augment_labels=None)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class modulus.models.diffusion.song_unet.SongUNetPosLtEmbd(*args, **kwargs)[source]
Bases:
SongUNet
This model is adapted from SongUNetPosEmbd, incorporating a lead-time-aware embedding for the GEFS-HRRR model. The lead-time embedding is activated by setting the lead_time_channels and lead_time_steps parameters.
- Parameters
img_resolution (Union[List[int], int]) – The resolution of the input/output image, 1 value represents a square image.
in_channels (int) – Number of channels in the input image.
out_channels (int) – Number of channels in the output image.
label_dim (int, optional) – Number of class labels; 0 indicates an unconditional model. By default 0.
augment_dim (int, optional) – Dimensionality of augmentation labels; 0 means no augmentation. By default 0.
model_channels (int, optional) – Base multiplier for the number of channels across the network, by default 128.
channel_mult (List[int], optional) – Per-resolution multipliers for the number of channels. By default [1,2,2,2].
channel_mult_emb (int, optional) – Multiplier for the dimensionality of the embedding vector. By default 4.
num_blocks (int, optional) – Number of residual blocks per resolution. By default 4.
attn_resolutions (List[int], optional) – Resolutions at which self-attention layers are applied. By default [16].
dropout (float, optional) – Dropout probability applied to intermediate activations. By default 0.13.
label_dropout (float, optional) – Dropout probability of class labels for classifier-free guidance. By default 0.0.
embedding_type (str, optional) – Timestep embedding type: ‘positional’ for DDPM++, ‘fourier’ for NCSN++. By default ‘positional’.
channel_mult_noise (int, optional) – Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1.
encoder_type (str, optional) – Encoder architecture: ‘standard’ for DDPM++, ‘residual’ for NCSN++. By default ‘standard’.
decoder_type (str, optional) – Decoder architecture: ‘standard’ for both DDPM++ and NCSN++. By default ‘standard’.
resample_filter (List[int], optional (default=[1,1])) – Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++.
lead_time_channels (int, optional) – Length of lead time embedding vector
lead_time_steps (int, optional) – Total number of lead times
Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and Poole, B., 2020. Score-based generative modeling through stochastic differential equations. arXiv preprint arXiv:2011.13456.
Note: Equivalent to the original implementation by Song et al., available at https://github.com/yang-song/score_sde_pytorch
Example
>>> model = SongUNet(img_resolution=16, in_channels=2, out_channels=2)
>>> noise_labels = torch.randn([1])
>>> class_labels = torch.randint(0, 1, (1, 1))
>>> input_image = torch.ones([1, 2, 16, 16])
>>> output_image = model(input_image, noise_labels, class_labels)
>>> output_image.shape
torch.Size([1, 2, 16, 16])
- forward(x, noise_labels, class_labels, lead_time_label=None, global_index=None, augment_labels=None)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class modulus.models.diffusion.unet.MetaData(name: str = 'UNet', jit: bool = False, cuda_graphs: bool = False, amp: bool = False, amp_cpu: bool = False, amp_gpu: bool = True, torch_fx: bool = False, bf16: bool = True, onnx: bool = False, onnx_gpu: bool = None, onnx_cpu: bool = None, onnx_runtime: bool = False, trt: bool = False, var_dim: int = -1, func_torch: bool = False, auto_grad: bool = False)[source]
Bases:
ModelMetaData
- class modulus.models.diffusion.unet.UNet(*args, **kwargs)[source]
Bases:
Module
U-Net Wrapper for CorrDiff.
- Parameters
img_resolution (int) – The resolution of the input/output image.
img_channels (int) – Number of color channels.
img_in_channels (int) – Number of input color channels.
img_out_channels (int) – Number of output color channels.
use_fp16 (bool, optional) – Whether to execute the underlying model at FP16 precision, by default False.
sigma_min (float, optional) – Minimum supported noise level, by default 0.
sigma_max (float, optional) – Maximum supported noise level, by default float(‘inf’).
sigma_data (float, optional) – Expected standard deviation of the training data, by default 0.5.
model_type (str, optional) – Class name of the underlying model, by default ‘DhariwalUNet’.
**model_kwargs (dict) – Keyword arguments for the underlying model.
Reference: Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. arXiv preprint arXiv:2309.15214.
- forward(x, img_lr, sigma, force_fp32=False, **model_kwargs)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- round_sigma(sigma)[source]
Convert a given sigma value(s) to a tensor representation.
- Parameters
sigma (Union[float, List, torch.Tensor]) – The sigma value(s) to convert.
- Returns
The tensor representation of the provided sigma value(s).
- Return type
torch.Tensor
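A hedged usage sketch (assuming model is an already-instantiated modulus.models.diffusion.unet.UNet):
>>> sigmas = model.round_sigma([0.1, 0.5])  # tensor representation of the given sigmas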