
Source code for modulus.models.afno.afno

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from functools import partial
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

import modulus  # noqa: F401 for docs
import modulus.models.layers.fft as fft

from ..meta import ModelMetaData
from ..module import Module

Tensor = torch.Tensor


class AFNOMlp(nn.Module):
    """Fully-connected multi-layer perceptron used inside AFNO

    Parameters
    ----------
    in_features : int
        Input feature size
    latent_features : int
        Latent feature size
    out_features : int
        Output feature size
    activation_fn : nn.Module, optional
        Activation function, by default nn.GELU
    drop : float, optional
        Drop out rate, by default 0.0
    """

    def __init__(
        self,
        in_features: int,
        latent_features: int,
        out_features: int,
        activation_fn: nn.Module = nn.GELU(),
        drop: float = 0.0,
    ):
        super().__init__()
        self.fc1 = nn.Linear(in_features, latent_features)
        self.act = activation_fn
        self.fc2 = nn.Linear(latent_features, out_features)
        self.drop = nn.Dropout(drop)
    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
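# Usage sketch (illustrative only, not part of the original module): the MLP acts
# on the last dimension, so it accepts the channels-last token/feature tensors that
# Block passes through it.
#
# >>> mlp = AFNOMlp(in_features=16, latent_features=64, out_features=16)
# >>> tokens = torch.randn(2, 8 * 8, 16)
# >>> mlp(tokens).shape
# torch.Size([2, 64, 16])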

class AFNO2DLayer(nn.Module):
    """AFNO spectral convolution layer

    Parameters
    ----------
    hidden_size : int
        Feature dimensionality
    num_blocks : int, optional
        Number of blocks used in the block diagonal weight matrix, by default 8
    sparsity_threshold : float, optional
        Sparsity threshold (softshrink) of spectral features, by default 0.01
    hard_thresholding_fraction : float, optional
        Threshold for limiting number of modes used [0,1], by default 1
    hidden_size_factor : int, optional
        Factor to increase spectral features by after weight multiplication, by default 1
    """

    def __init__(
        self,
        hidden_size: int,
        num_blocks: int = 8,
        sparsity_threshold: float = 0.01,
        hard_thresholding_fraction: float = 1,
        hidden_size_factor: int = 1,
    ):
        super().__init__()
        if not (hidden_size % num_blocks == 0):
            raise ValueError(
                f"hidden_size {hidden_size} should be divisible by num_blocks {num_blocks}"
            )

        self.hidden_size = hidden_size
        self.sparsity_threshold = sparsity_threshold
        self.num_blocks = num_blocks
        self.block_size = self.hidden_size // self.num_blocks
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.hidden_size_factor = hidden_size_factor
        self.scale = 0.02

        self.w1 = nn.Parameter(
            self.scale
            * torch.randn(
                2,
                self.num_blocks,
                self.block_size,
                self.block_size * self.hidden_size_factor,
            )
        )
        self.b1 = nn.Parameter(
            self.scale
            * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor)
        )
        self.w2 = nn.Parameter(
            self.scale
            * torch.randn(
                2,
                self.num_blocks,
                self.block_size * self.hidden_size_factor,
                self.block_size,
            )
        )
        self.b2 = nn.Parameter(
            self.scale * torch.randn(2, self.num_blocks, self.block_size)
        )
    def forward(self, x: Tensor) -> Tensor:
        bias = x

        dtype = x.dtype
        x = x.float()
        B, H, W, C = x.shape
        # Using ONNX friendly FFT functions
        x = fft.rfft2(x, dim=(1, 2), norm="ortho")
        x_real, x_imag = fft.real(x), fft.imag(x)
        x_real = x_real.reshape(B, H, W // 2 + 1, self.num_blocks, self.block_size)
        x_imag = x_imag.reshape(B, H, W // 2 + 1, self.num_blocks, self.block_size)

        o1_real = torch.zeros(
            [
                B,
                H,
                W // 2 + 1,
                self.num_blocks,
                self.block_size * self.hidden_size_factor,
            ],
            device=x.device,
        )
        o1_imag = torch.zeros(
            [
                B,
                H,
                W // 2 + 1,
                self.num_blocks,
                self.block_size * self.hidden_size_factor,
            ],
            device=x.device,
        )
        o2 = torch.zeros(x_real.shape + (2,), device=x.device)

        total_modes = H // 2 + 1
        kept_modes = int(total_modes * self.hard_thresholding_fraction)

        o1_real[
            :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes
        ] = F.relu(
            torch.einsum(
                "nyxbi,bio->nyxbo",
                x_real[
                    :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes
                ],
                self.w1[0],
            )
            - torch.einsum(
                "nyxbi,bio->nyxbo",
                x_imag[
                    :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes
                ],
                self.w1[1],
            )
            + self.b1[0]
        )

        o1_imag[
            :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes
        ] = F.relu(
            torch.einsum(
                "nyxbi,bio->nyxbo",
                x_imag[
                    :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes
                ],
                self.w1[0],
            )
            + torch.einsum(
                "nyxbi,bio->nyxbo",
                x_real[
                    :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes
                ],
                self.w1[1],
            )
            + self.b1[1]
        )

        o2[
            :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes, ..., 0
        ] = (
            torch.einsum(
                "nyxbi,bio->nyxbo",
                o1_real[
                    :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes
                ],
                self.w2[0],
            )
            - torch.einsum(
                "nyxbi,bio->nyxbo",
                o1_imag[
                    :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes
                ],
                self.w2[1],
            )
            + self.b2[0]
        )

        o2[
            :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes, ..., 1
        ] = (
            torch.einsum(
                "nyxbi,bio->nyxbo",
                o1_imag[
                    :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes
                ],
                self.w2[0],
            )
            + torch.einsum(
                "nyxbi,bio->nyxbo",
                o1_real[
                    :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes
                ],
                self.w2[1],
            )
            + self.b2[1]
        )

        x = F.softshrink(o2, lambd=self.sparsity_threshold)
        x = fft.view_as_complex(x)
        # TODO(akamenev): replace the following branching with
        # a one-liner, something like x.reshape(..., -1).squeeze(-1),
        # but this currently fails during ONNX export.
        if torch.onnx.is_in_onnx_export():
            x = x.reshape(B, H, W // 2 + 1, C, 2)
        else:
            x = x.reshape(B, H, W // 2 + 1, C)
        # Using ONNX friendly FFT functions
        x = fft.irfft2(x, s=(H, W), dim=(1, 2), norm="ortho")
        x = x.type(dtype)
        return x + bias
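# Usage sketch (illustrative only, not part of the original module): the layer
# consumes a channels-last feature map [B, H, W, hidden_size], mixes tokens in the
# Fourier domain with block-diagonal weights, and returns the same shape with the
# input added back as a residual.
#
# >>> layer = AFNO2DLayer(hidden_size=16, num_blocks=4)
# >>> feats = torch.randn(2, 8, 8, 16)
# >>> layer(feats).shape
# torch.Size([2, 8, 8, 16])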

class Block(nn.Module):
    """AFNO block, spectral convolution and MLP

    Parameters
    ----------
    embed_dim : int
        Embedded feature dimensionality
    num_blocks : int, optional
        Number of blocks used in the block diagonal weight matrix, by default 8
    mlp_ratio : float, optional
        Ratio of MLP latent variable size to input feature size, by default 4.0
    drop : float, optional
        Drop out rate in MLP, by default 0.0
    activation_fn: nn.Module, optional
        Activation function used in MLP, by default nn.GELU
    norm_layer : nn.Module, optional
        Normalization function, by default nn.LayerNorm
    double_skip : bool, optional
        Whether to add a residual connection after the spectral filter in addition
        to the one after the MLP, by default True
    sparsity_threshold : float, optional
        Sparsity threshold (softshrink) of spectral features, by default 0.01
    hard_thresholding_fraction : float, optional
        Threshold for limiting number of modes used [0,1], by default 1
    """

    def __init__(
        self,
        embed_dim: int,
        num_blocks: int = 8,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
        activation_fn: nn.Module = nn.GELU(),
        norm_layer: nn.Module = nn.LayerNorm,
        double_skip: bool = True,
        sparsity_threshold: float = 0.01,
        hard_thresholding_fraction: float = 1.0,
    ):
        super().__init__()
        self.norm1 = norm_layer(embed_dim)
        self.filter = AFNO2DLayer(
            embed_dim, num_blocks, sparsity_threshold, hard_thresholding_fraction
        )
        # self.drop_path = nn.Identity()
        self.norm2 = norm_layer(embed_dim)
        mlp_latent_dim = int(embed_dim * mlp_ratio)
        self.mlp = AFNOMlp(
            in_features=embed_dim,
            latent_features=mlp_latent_dim,
            out_features=embed_dim,
            activation_fn=activation_fn,
            drop=drop,
        )
        self.double_skip = double_skip
    def forward(self, x: Tensor) -> Tensor:
        residual = x
        x = self.norm1(x)
        x = self.filter(x)

        if self.double_skip:
            x = x + residual
            residual = x

        x = self.norm2(x)
        x = self.mlp(x)
        x = x + residual
        return x
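# Usage sketch (illustrative only, not part of the original module): a Block wraps
# the spectral filter and MLP with layer norms and residual connections and operates
# on the [B, H, W, embed_dim] patch grid produced inside AFNO.forward_features.
#
# >>> block = Block(embed_dim=16, num_blocks=4)
# >>> feats = torch.randn(2, 8, 8, 16)
# >>> block(feats).shape
# torch.Size([2, 8, 8, 16])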

class PatchEmbed(nn.Module):
    """Patch embedding layer

    Converts 2D image patches into a 1D sequence of embedding vectors for input to AFNO

    Parameters
    ----------
    inp_shape : List[int]
        Input image dimensions [height, width]
    in_channels : int
        Number of input channels
    patch_size : List[int], optional
        Size of image patches, by default [16, 16]
    embed_dim : int, optional
        Embedded channel size, by default 256
    """

    def __init__(
        self,
        inp_shape: List[int],
        in_channels: int,
        patch_size: List[int] = [16, 16],
        embed_dim: int = 256,
    ):
        super().__init__()
        if len(inp_shape) != 2:
            raise ValueError("inp_shape should be a list of length 2")
        if len(patch_size) != 2:
            raise ValueError("patch_size should be a list of length 2")
        num_patches = (inp_shape[1] // patch_size[1]) * (inp_shape[0] // patch_size[0])
        self.inp_shape = inp_shape
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )
    def forward(self, x: Tensor) -> Tensor:
        B, C, H, W = x.shape
        if not (H == self.inp_shape[0] and W == self.inp_shape[1]):
            raise ValueError(
                f"Input image size ({H}*{W}) doesn't match model ({self.inp_shape[0]}*{self.inp_shape[1]})."
            )
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
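# Usage sketch (illustrative only, not part of the original module): a [B, C, H, W]
# image is split into (H/p1)*(W/p2) non-overlapping patches, each projected to an
# embed_dim vector.
#
# >>> embed = PatchEmbed(inp_shape=[32, 32], in_channels=2, patch_size=[8, 8], embed_dim=24)
# >>> img = torch.randn(4, 2, 32, 32)
# >>> embed(img).shape
# torch.Size([4, 16, 24])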

@dataclass
class MetaData(ModelMetaData):
    name: str = "AFNO"
    # Optimization
    jit: bool = False  # ONNX Ops Conflict
    cuda_graphs: bool = True
    amp: bool = True
    # Inference
    onnx_cpu: bool = False  # No FFT op on CPU
    onnx_gpu: bool = True
    onnx_runtime: bool = True
    # Physics informed
    var_dim: int = 1
    func_torch: bool = False
    auto_grad: bool = False

class AFNO(Module):
    """Adaptive Fourier neural operator (AFNO) model.

    Note
    ----
    AFNO is a model that is designed for 2D images only.

    Parameters
    ----------
    inp_shape : List[int]
        Input image dimensions [height, width]
    in_channels : int
        Number of input channels
    out_channels: int
        Number of output channels
    patch_size : List[int], optional
        Size of image patches, by default [16, 16]
    embed_dim : int, optional
        Embedded channel size, by default 256
    depth : int, optional
        Number of AFNO layers, by default 4
    mlp_ratio : float, optional
        Ratio of layer MLP latent variable size to input feature size, by default 4.0
    drop_rate : float, optional
        Drop out rate in layer MLPs, by default 0.0
    num_blocks : int, optional
        Number of blocks in the block-diag frequency weight matrices, by default 16
    sparsity_threshold : float, optional
        Sparsity threshold (softshrink) of spectral features, by default 0.01
    hard_thresholding_fraction : float, optional
        Threshold for limiting number of modes used [0,1], by default 1

    Example
    -------
    >>> model = modulus.models.afno.AFNO(
    ...     inp_shape=[32, 32],
    ...     in_channels=2,
    ...     out_channels=1,
    ...     patch_size=(8, 8),
    ...     embed_dim=16,
    ...     depth=2,
    ...     num_blocks=2,
    ... )
    >>> input = torch.randn(32, 2, 32, 32)  # (N, C, H, W)
    >>> output = model(input)
    >>> output.size()
    torch.Size([32, 1, 32, 32])

    Note
    ----
    Reference: Guibas, John, et al. "Adaptive fourier neural operators:
    Efficient token mixers for transformers." arXiv preprint arXiv:2111.13587 (2021).
    """

    def __init__(
        self,
        inp_shape: List[int],
        in_channels: int,
        out_channels: int,
        patch_size: List[int] = [16, 16],
        embed_dim: int = 256,
        depth: int = 4,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        num_blocks: int = 16,
        sparsity_threshold: float = 0.01,
        hard_thresholding_fraction: float = 1.0,
    ) -> None:
        super().__init__(meta=MetaData())
        if len(inp_shape) != 2:
            raise ValueError("inp_shape should be a list of length 2")
        if len(patch_size) != 2:
            raise ValueError("patch_size should be a list of length 2")
        if not (
            inp_shape[0] % patch_size[0] == 0 and inp_shape[1] % patch_size[1] == 0
        ):
            raise ValueError(
                f"input shape {inp_shape} should be divisible by patch_size {patch_size}"
            )

        self.in_chans = in_channels
        self.out_chans = out_channels
        self.inp_shape = inp_shape
        self.patch_size = patch_size
        self.num_features = self.embed_dim = embed_dim
        self.num_blocks = num_blocks
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = PatchEmbed(
            inp_shape=inp_shape,
            in_channels=self.in_chans,
            patch_size=self.patch_size,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.h = inp_shape[0] // self.patch_size[0]
        self.w = inp_shape[1] // self.patch_size[1]

        self.blocks = nn.ModuleList(
            [
                Block(
                    embed_dim=embed_dim,
                    num_blocks=self.num_blocks,
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                    norm_layer=norm_layer,
                    sparsity_threshold=sparsity_threshold,
                    hard_thresholding_fraction=hard_thresholding_fraction,
                )
                for i in range(depth)
            ]
        )

        self.head = nn.Linear(
            embed_dim,
            self.out_chans * self.patch_size[0] * self.patch_size[1],
            bias=False,
        )

        torch.nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Init model weights"""
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    # Unused timm-style hook that would list parameters to exclude from weight decay:
    # @torch.jit.ignore
    # def no_weight_decay(self):
    #     return {"pos_embed", "cls_token"}
    def forward_features(self, x: Tensor) -> Tensor:
        """Forward pass of core AFNO"""
        B = x.shape[0]
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        x = x.reshape(B, self.h, self.w, self.embed_dim)
        for blk in self.blocks:
            x = blk(x)

        return x
    def forward(self, x: Tensor) -> Tensor:
        x = self.forward_features(x)
        x = self.head(x)

        # Correct tensor shape back into [B, C, H, W]
        # [b h w (p1 p2 c_out)]
        out = x.view(list(x.shape[:-1]) + [self.patch_size[0], self.patch_size[1], -1])
        # [b h w p1 p2 c_out]
        out = torch.permute(out, (0, 5, 1, 3, 2, 4))
        # [b c_out, h, p1, w, p2]
        out = out.reshape(list(out.shape[:2]) + [self.inp_shape[0], self.inp_shape[1]])
        # [b c_out, (h*p1), (w*p2)]
        return out
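# Sanity-check sketch for AFNO.forward's patch folding (illustrative only; assumes
# the optional einops package is available): the view/permute/reshape sequence above
# should match
#
# >>> from einops import rearrange
# >>> out_ref = rearrange(
# ...     x, "b h w (p1 p2 c) -> b c (h p1) (w p2)",
# ...     p1=model.patch_size[0], p2=model.patch_size[1],
# ... )
#
# where x is the head output and model is an AFNO instance; this gives a compact
# equivalence to verify in a unit test.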