# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from functools import partial
from typing import Any, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

import modulus
import modulus.models.layers.fft as fft

from ..meta import ModelMetaData
from ..module import Module
Tensor = torch.Tensor


class AFNOMlp(nn.Module):
    """Fully-connected multi-layer perceptron used inside AFNO

    Computes ``drop(fc2(drop(act(fc1(x)))))`` over the last dimension of ``x``.

    Parameters
    ----------
    in_features : int
        Input feature size
    latent_features : int
        Latent feature size
    out_features : int
        Output feature size
    activation_fn : nn.Module, optional
        Activation function, by default a new ``nn.GELU()`` for each instance
    drop : float, optional
        Drop out rate, by default 0.0
    """

    def __init__(
        self,
        in_features: int,
        latent_features: int,
        out_features: int,
        activation_fn: Optional[nn.Module] = None,
        drop: float = 0.0,
    ):
        super().__init__()
        # Build the default activation per instance rather than sharing one
        # module object created at function-definition time.
        if activation_fn is None:
            activation_fn = nn.GELU()
        self.fc1 = nn.Linear(in_features, latent_features)
        self.act = activation_fn
        self.fc2 = nn.Linear(latent_features, out_features)
        # The same dropout module is applied after both linear layers.
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class AFNO2DLayer(nn.Module):
    """AFNO spectral convolution layer

    Takes the 2D real FFT of the input over dims (1, 2), applies a two-layer
    complex block-diagonal MLP (ReLU between the layers) to a subset of the
    Fourier modes, soft-shrinks the result, applies the inverse real FFT and
    adds the input back as a residual.

    Parameters
    ----------
    hidden_size : int
        Feature dimensionality
    num_blocks : int, optional
        Number of blocks used in the block diagonal weight matrix, by default 8
    sparsity_threshold : float, optional
        Sparsity threshold (softshrink) of spectral features, by default 0.01
    hard_thresholding_fraction : float, optional
        Threshold for limiting number of modes used [0,1], by default 1
    hidden_size_factor : int, optional
        Factor to increase spectral features by after weight multiplication, by default 1
    """

    def __init__(
        self,
        hidden_size: int,
        num_blocks: int = 8,
        sparsity_threshold: float = 0.01,
        hard_thresholding_fraction: float = 1,
        hidden_size_factor: int = 1,
    ):
        super().__init__()
        assert (
            hidden_size % num_blocks == 0
        ), f"hidden_size {hidden_size} should be divisible by num_blocks {num_blocks}"

        self.hidden_size = hidden_size
        self.sparsity_threshold = sparsity_threshold
        self.num_blocks = num_blocks
        self.block_size = self.hidden_size // self.num_blocks
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.hidden_size_factor = hidden_size_factor
        self.scale = 0.02

        # Index 0 of the leading dimension holds the real part of each complex
        # weight/bias and index 1 the imaginary part (see forward()).
        self.w1 = nn.Parameter(
            self.scale
            * torch.randn(
                2,
                self.num_blocks,
                self.block_size,
                self.block_size * self.hidden_size_factor,
            )
        )
        self.b1 = nn.Parameter(
            self.scale
            * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor)
        )
        self.w2 = nn.Parameter(
            self.scale
            * torch.randn(
                2,
                self.num_blocks,
                self.block_size * self.hidden_size_factor,
                self.block_size,
            )
        )
        self.b2 = nn.Parameter(
            self.scale * torch.randn(2, self.num_blocks, self.block_size)
        )

    def forward(self, x: Tensor) -> Tensor:
        # x is unpacked as (B, H, W, C); C must equal hidden_size for the
        # block reshape below to succeed.
        bias = x
        dtype = x.dtype
        x = x.float()
        B, H, W, C = x.shape
        w_freqs = W // 2 + 1  # rfft2 keeps only the non-negative W frequencies
        block_mm = "nyxbi,bio->nyxbo"  # per-block matrix multiply

        # Using ONNX friendly FFT functions
        x = fft.rfft2(x, dim=(1, 2), norm="ortho")
        x_real, x_imag = fft.real(x), fft.imag(x)
        x_real = x_real.reshape(B, H, w_freqs, self.num_blocks, self.block_size)
        x_imag = x_imag.reshape(B, H, w_freqs, self.num_blocks, self.block_size)

        # Output spectrum with a trailing (real, imag) axis. Modes outside the
        # selected region stay zero.
        o2 = torch.zeros(
            x_real.shape + (2,), device=x.device, dtype=x_real.dtype
        )

        total_modes = H // 2 + 1
        kept_modes = int(total_modes * self.hard_thresholding_fraction)
        # Region of the spectrum that is processed. NOTE(review): both the H
        # range and the W cut-off are derived from the H-based mode count;
        # this matches the original implementation and is kept unchanged.
        modes = (
            slice(None),
            slice(total_modes - kept_modes, total_modes + kept_modes),
            slice(None, kept_modes),
        )
        xr = x_real[modes]
        xi = x_imag[modes]

        # Layer 1: (xr + i*xi) @ (w1[0] + i*w1[1]) + (b1[0] + i*b1[1]), with
        # ReLU applied to the real and imaginary parts separately.
        o1_real = F.relu(
            torch.einsum(block_mm, xr, self.w1[0])
            - torch.einsum(block_mm, xi, self.w1[1])
            + self.b1[0]
        )
        o1_imag = F.relu(
            torch.einsum(block_mm, xi, self.w1[0])
            + torch.einsum(block_mm, xr, self.w1[1])
            + self.b1[1]
        )

        # Layer 2: complex multiply by (w2[0] + i*w2[1]) plus bias, no activation.
        o2[modes + (Ellipsis, 0)] = (
            torch.einsum(block_mm, o1_real, self.w2[0])
            - torch.einsum(block_mm, o1_imag, self.w2[1])
            + self.b2[0]
        )
        o2[modes + (Ellipsis, 1)] = (
            torch.einsum(block_mm, o1_imag, self.w2[0])
            + torch.einsum(block_mm, o1_real, self.w2[1])
            + self.b2[1]
        )

        x = F.softshrink(o2, lambd=self.sparsity_threshold)
        x = fft.view_as_complex(x)
        # TODO(akamenev): replace the following branching with
        # a one-liner, something like x.reshape(..., -1).squeeze(-1),
        # but this currently fails during ONNX export.
        if torch.onnx.is_in_onnx_export():
            x = x.reshape(B, H, w_freqs, C, 2)
        else:
            x = x.reshape(B, H, w_freqs, C)
        # Using ONNX friendly FFT functions
        x = fft.irfft2(x, s=(H, W), dim=(1, 2), norm="ortho")
        x = x.type(dtype)

        return x + bias


class Block(nn.Module):
    """AFNO block, spectral convolution and MLP

    Parameters
    ----------
    embed_dim : int
        Embedded feature dimensionality
    num_blocks : int, optional
        Number of blocks used in the block diagonal weight matrix, by default 8
    mlp_ratio : float, optional
        Ratio of MLP latent variable size to input feature size, by default 4.0
    drop : float, optional
        Drop out rate in MLP, by default 0.0
    activation_fn: nn.Module, optional
        Activation function used in MLP, by default a new ``nn.GELU()``
    norm_layer : nn.Module, optional
        Normalization layer constructor, called with ``embed_dim``,
        by default nn.LayerNorm
    double_skip : bool, optional
        If True, add a residual connection around the spectral filter in
        addition to the one around the MLP, by default True
    sparsity_threshold : float, optional
        Sparsity threshold (softshrink) of spectral features, by default 0.01
    hard_thresholding_fraction : float, optional
        Threshold for limiting number of modes used [0,1], by default 1
    """

    def __init__(
        self,
        embed_dim: int,
        num_blocks: int = 8,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
        activation_fn: Optional[nn.Module] = None,
        norm_layer: nn.Module = nn.LayerNorm,
        double_skip: bool = True,
        sparsity_threshold: float = 0.01,
        hard_thresholding_fraction: float = 1.0,
    ):
        super().__init__()
        self.norm1 = norm_layer(embed_dim)
        self.filter = AFNO2DLayer(
            hidden_size=embed_dim,
            num_blocks=num_blocks,
            sparsity_threshold=sparsity_threshold,
            hard_thresholding_fraction=hard_thresholding_fraction,
        )
        self.norm2 = norm_layer(embed_dim)
        mlp_latent_dim = int(embed_dim * mlp_ratio)
        # AFNOMlp builds its own GELU when activation_fn is None.
        self.mlp = AFNOMlp(
            in_features=embed_dim,
            latent_features=mlp_latent_dim,
            out_features=embed_dim,
            activation_fn=activation_fn,
            drop=drop,
        )
        self.double_skip = double_skip

    def forward(self, x: Tensor) -> Tensor:
        residual = x
        x = self.norm1(x)
        x = self.filter(x)

        if self.double_skip:
            x = x + residual
            residual = x

        x = self.norm2(x)
        x = self.mlp(x)
        x = x + residual
        return x


class PatchEmbed(nn.Module):
    """Patch embedding layer

    Converts 2D patch into a 1D vector for input to AFNO

    Parameters
    ----------
    img_size : Tuple[int, int]
        Input image dimensions (height, width)
    in_channels : int
        Number of input channels
    patch_size : Tuple[int, int], optional
        Size of image patches, by default (16, 16)
    embed_dim : int, optional
        Embedded channel size, by default 256
    """

    def __init__(
        self,
        img_size: Tuple[int, int],
        in_channels: int,
        patch_size: Tuple[int, int] = (16, 16),
        embed_dim: int = 256,
    ):
        super().__init__()
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        # Non-overlapping patches: kernel size equals stride.
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x: Tensor) -> Tensor:
        B, C, H, W = x.shape
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, embed_dim, h, w) -> (B, h*w, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


@dataclass
class MetaData(ModelMetaData):
    """Model metadata passed to ``Module`` by ``AFNO.__init__``."""

    # NOTE(review): only the model name is set here; all other capability
    # flags fall back to the ModelMetaData defaults — confirm against upstream.
    name: str = "AFNO"


class AFNO(Module):
    """Adaptive Fourier neural operator (AFNO) model.

    Note
    ----
    AFNO is a model that is designed for 2D images only.

    Parameters
    ----------
    img_size : Tuple[int, int]
        Input image dimensions (height, width)
    in_channels : int
        Number of input channels
    out_channels: int
        Number of output channels
    patch_size : Tuple[int, int], optional
        Size of image patches, by default (16, 16)
    embed_dim : int, optional
        Embedded channel size, by default 256
    depth : int, optional
        Number of AFNO layers, by default 4
    mlp_ratio : float, optional
        Ratio of layer MLP latent variable size to input feature size, by default 4.0
    drop_rate : float, optional
        Drop out rate in layer MLPs, by default 0.0
    num_blocks : int, optional
        Number of blocks in the block-diag frequency weight matrices, by default 16
    sparsity_threshold : float, optional
        Sparsity threshold (softshrink) of spectral features, by default 0.01
    hard_thresholding_fraction : float, optional
        Threshold for limiting number of modes used [0,1], by default 1

    Example
    -------
    >>> model = modulus.models.afno.AFNO(
    ...     img_size=(32, 32),
    ...     in_channels=2,
    ...     out_channels=1,
    ...     patch_size=(8, 8),
    ...     embed_dim=16,
    ...     depth=2,
    ...     num_blocks=2,
    ... )
    >>> input = torch.randn(32, 2, 32, 32) #(N, C, H, W)
    >>> output = model(input)
    >>> output.size()
    torch.Size([32, 1, 32, 32])

    Note
    ----
    Reference: Guibas, John, et al. "Adaptive fourier neural operators:
    Efficient token mixers for transformers." arXiv preprint arXiv:2111.13587 (2021).
    """

    def __init__(
        self,
        img_size: Tuple[int, int],
        in_channels: int,
        out_channels: int,
        patch_size: Tuple[int, int] = (16, 16),
        embed_dim: int = 256,
        depth: int = 4,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        num_blocks: int = 16,
        sparsity_threshold: float = 0.01,
        hard_thresholding_fraction: float = 1.0,
    ) -> None:
        super().__init__(meta=MetaData())
        assert (
            img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0
        ), f"img_size {img_size} should be divisible by patch_size {patch_size}"

        self.in_chans = in_channels
        self.out_chans = out_channels
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_features = self.embed_dim = embed_dim
        self.num_blocks = num_blocks
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            in_channels=self.in_chans,
            patch_size=self.patch_size,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        # Learned positional embedding, one vector per patch.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Patch-grid dimensions.
        self.h = img_size[0] // self.patch_size[0]
        self.w = img_size[1] // self.patch_size[1]

        self.blocks = nn.ModuleList(
            [
                Block(
                    embed_dim=embed_dim,
                    num_blocks=self.num_blocks,
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                    norm_layer=norm_layer,
                    sparsity_threshold=sparsity_threshold,
                    hard_thresholding_fraction=hard_thresholding_fraction,
                )
                for _ in range(depth)
            ]
        )

        # Each patch token is projected to all pixels of its patch for every
        # output channel.
        self.head = nn.Linear(
            embed_dim,
            self.out_chans * self.patch_size[0] * self.patch_size[1],
            bias=False,
        )

        torch.nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Init model weights"""
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x: Tensor) -> Tensor:
        """Forward pass of core AFNO

        Returns the patch features as a (B, h, w, embed_dim) tensor, where
        (h, w) is the patch grid.
        """
        B = x.shape[0]
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Tokens back onto the 2D patch grid for the spectral blocks.
        x = x.reshape(B, self.h, self.w, self.embed_dim)
        for blk in self.blocks:
            x = blk(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        x = self.forward_features(x)
        x = self.head(x)

        # Correct tensor shape back into [B, C, H, W]
        # [b, h, w, (p1 p2 c_out)] -> [b, h, w, p1, p2, c_out]
        out = x.view(list(x.shape[:-1]) + [self.patch_size[0], self.patch_size[1], -1])
        # -> [b, c_out, h, p1, w, p2]
        out = torch.permute(out, (0, 5, 1, 3, 2, 4))
        # -> [b, c_out, (h*p1), (w*p2)]
        out = out.reshape(list(out.shape[:2]) + [self.img_size[0], self.img_size[1]])

        return out