# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
from torch import Tensor

from modulus.sym.models.arch import Arch
from modulus.sym.key import Key


class Mlp(nn.Module):
    """Two-layer perceptron: linear -> activation -> dropout -> linear -> dropout.

    Parameters
    ----------
    in_features : int
        Size of the last dimension of the input.
    hidden_features : int, optional
        Width of the hidden layer. Falls back to ``in_features`` when falsy.
    out_features : int, optional
        Size of the last dimension of the output. Falls back to
        ``in_features`` when falsy.
    act_layer : callable, optional
        Zero-argument factory that builds the activation module,
        by default ``nn.GELU``
    drop : float, optional
        Dropout probability; the same dropout module is applied after the
        activation and after the second linear layer, by default 0.0
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply ``fc1``, activation, dropout, ``fc2`` and dropout, in that order."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class AFNO2D(nn.Module):
    """Frequency-domain mixing layer.

    The input is transformed with a 2D real FFT, its channels are split into
    ``num_blocks`` groups, and each group is passed through a two-layer
    complex-valued linear map (ReLU after the first layer). The result is
    soft-shrunk, transformed back with the inverse FFT and added to the input.

    Parameters
    ----------
    hidden_size : int
        Channel size of the input; must be divisible by ``num_blocks``.
    num_blocks : int, optional
        Number of channel groups, each with its own weights, by default 8
    sparsity_threshold : float, optional
        ``lambd`` passed to ``F.softshrink``, by default 0.01
    hard_thresholding_fraction : float, optional
        Fraction of ``H // 2 + 1`` used to select which frequency modes are
        processed; all other modes stay zero, by default 1
    hidden_size_factor : int, optional
        Multiplier for the per-block hidden width, by default 1
    """

    def __init__(
        self,
        hidden_size,
        num_blocks=8,
        sparsity_threshold=0.01,
        hard_thresholding_fraction=1,
        hidden_size_factor=1,
    ):
        super().__init__()
        assert (
            hidden_size % num_blocks == 0
        ), f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}"

        self.hidden_size = hidden_size
        self.sparsity_threshold = sparsity_threshold
        self.num_blocks = num_blocks
        self.block_size = self.hidden_size // self.num_blocks
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.hidden_size_factor = hidden_size_factor
        self.scale = 0.02

        # Leading dimension of size 2 holds the (real, imaginary) parts of the
        # complex weights and biases. Creation order is kept as-is so that
        # random initialisation consumes the RNG in the same sequence.
        self.w1 = nn.Parameter(
            self.scale
            * torch.randn(
                2,
                self.num_blocks,
                self.block_size,
                self.block_size * self.hidden_size_factor,
            )
        )
        self.b1 = nn.Parameter(
            self.scale
            * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor)
        )
        self.w2 = nn.Parameter(
            self.scale
            * torch.randn(
                2,
                self.num_blocks,
                self.block_size * self.hidden_size_factor,
                self.block_size,
            )
        )
        self.b2 = nn.Parameter(
            self.scale * torch.randn(2, self.num_blocks, self.block_size)
        )

    def forward(self, x):
        """Filter ``x`` in the frequency domain and add the input back.

        Parameters
        ----------
        x : Tensor
            4D tensor unpacked as ``(B, H, W, C)``; ``C`` must equal
            ``num_blocks * block_size`` for the reshape to succeed.

        Returns
        -------
        Tensor
            Tensor of the same shape as ``x``; the filtered signal is cast back
            to the input dtype before the residual add.
        """
        bias = x

        dtype = x.dtype
        x = x.float()  # the FFT path runs in float32 regardless of input dtype
        B, H, W, C = x.shape
        w_freq = W // 2 + 1  # rfft2 keeps only the non-negative dim-2 frequencies

        x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
        x = x.reshape(B, H, w_freq, self.num_blocks, self.block_size)

        hidden_shape = [
            B,
            H,
            w_freq,
            self.num_blocks,
            self.block_size * self.hidden_size_factor,
        ]
        # Buffers start at zero, so every mode outside the kept window
        # contributes nothing (apart from soft-shrink of zero, which is zero).
        o1_real = torch.zeros(hidden_shape, device=x.device)
        o1_imag = torch.zeros(hidden_shape, device=x.device)
        o2_real = torch.zeros(x.shape, device=x.device)
        o2_imag = torch.zeros(x.shape, device=x.device)

        total_modes = H // 2 + 1
        kept_modes = int(total_modes * self.hard_thresholding_fraction)
        # NOTE(review): kept_modes is derived from H only, yet it also limits
        # the dim-2 (W) frequencies below — confirm this is intended for
        # non-square inputs.
        lo = total_modes - kept_modes
        hi = total_modes + kept_modes

        x_kept = x[:, lo:hi, :kept_modes]
        x_real = x_kept.real
        x_imag = x_kept.imag

        # First layer: complex product (x_r + i x_i)(w_r + i w_i) + b, with
        # ReLU applied to the real and imaginary parts separately.
        o1_real[:, lo:hi, :kept_modes] = F.relu(
            torch.einsum("...bi,bio->...bo", x_real, self.w1[0])
            - torch.einsum("...bi,bio->...bo", x_imag, self.w1[1])
            + self.b1[0]
        )
        o1_imag[:, lo:hi, :kept_modes] = F.relu(
            torch.einsum("...bi,bio->...bo", x_imag, self.w1[0])
            + torch.einsum("...bi,bio->...bo", x_real, self.w1[1])
            + self.b1[1]
        )

        h_real = o1_real[:, lo:hi, :kept_modes]
        h_imag = o1_imag[:, lo:hi, :kept_modes]

        # Second layer: same complex product, no non-linearity.
        o2_real[:, lo:hi, :kept_modes] = (
            torch.einsum("...bi,bio->...bo", h_real, self.w2[0])
            - torch.einsum("...bi,bio->...bo", h_imag, self.w2[1])
            + self.b2[0]
        )
        o2_imag[:, lo:hi, :kept_modes] = (
            torch.einsum("...bi,bio->...bo", h_imag, self.w2[0])
            + torch.einsum("...bi,bio->...bo", h_real, self.w2[1])
            + self.b2[1]
        )

        x = torch.stack([o2_real, o2_imag], dim=-1)
        x = F.softshrink(x, lambd=self.sparsity_threshold)
        x = torch.view_as_complex(x)
        x = x.reshape(B, H, w_freq, C)
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm="ortho")
        x = x.type(dtype)
        return x + bias


class Block(nn.Module):
    """One AFNO layer: norm -> AFNO2D filter -> (skip) -> norm -> MLP -> skip.

    Parameters
    ----------
    dim : int
        Channel size of the input.
    mlp_ratio : float, optional
        Hidden width of the MLP is ``int(dim * mlp_ratio)``, by default 4.0
    drop : float, optional
        Dropout probability passed to the MLP, by default 0.0
    act_layer : callable, optional
        Activation factory passed to the MLP, by default ``nn.GELU``
    norm_layer : callable, optional
        Factory called with ``dim`` to build both norm layers,
        by default ``nn.LayerNorm``
    double_skip : bool, optional
        If True, add a residual connection after the filter and start a new
        residual from that sum; otherwise only the final residual (from the
        block input) is added, by default True
    num_blocks : int, optional
        Forwarded to ``AFNO2D``, by default 8
    sparsity_threshold : float, optional
        Forwarded to ``AFNO2D``, by default 0.01
    hard_thresholding_fraction : float, optional
        Forwarded to ``AFNO2D``, by default 1.0
    """

    def __init__(
        self,
        dim,
        mlp_ratio=4.0,
        drop=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        double_skip=True,
        num_blocks=8,
        sparsity_threshold=0.01,
        hard_thresholding_fraction=1.0,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.filter = AFNO2D(
            hidden_size=dim,
            num_blocks=num_blocks,
            sparsity_threshold=sparsity_threshold,
            hard_thresholding_fraction=hard_thresholding_fraction,
        )
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.double_skip = double_skip

    def forward(self, x):
        """Run the filter and MLP sub-layers with their residual connections."""
        residual = x
        x = self.norm1(x)
        x = self.filter(x)

        if self.double_skip:
            x = x + residual
            residual = x

        x = self.norm2(x)
        x = self.mlp(x)
        x = x + residual
        return x


class AFNONet(nn.Module):
    """Patch-embedding backbone made of a stack of AFNO ``Block`` layers.

    Parameters
    ----------
    img_size : Tuple[int, int], optional
        Input image size (height, width); each entry must be divisible by the
        matching ``patch_size`` entry, by default (720, 1440)
    patch_size : Tuple[int, int], optional
        Patch size (height, width), by default (16, 16)
    in_channels : int, optional
        Number of input channels, by default 1
    out_channels : int, optional
        Number of output channels, by default 1
    embed_dim : int, optional
        Embedded channel size, by default 768
    depth : int, optional
        Number of ``Block`` layers, by default 12
    mlp_ratio : float, optional
        Forwarded to every ``Block``, by default 4.0
    drop_rate : float, optional
        Dropout probability used after the position embedding and inside the
        blocks, by default 0.0
    num_blocks : int, optional
        Forwarded to every ``Block``, by default 16
    sparsity_threshold : float, optional
        Forwarded to every ``Block``, by default 0.01
    hard_thresholding_fraction : float, optional
        Forwarded to every ``Block``, by default 1.0
    """

    def __init__(
        self,
        img_size=(720, 1440),
        patch_size=(16, 16),
        in_channels=1,
        out_channels=1,
        embed_dim=768,
        depth=12,
        mlp_ratio=4.0,
        drop_rate=0.0,
        num_blocks=16,
        sparsity_threshold=0.01,
        hard_thresholding_fraction=1.0,
    ) -> None:
        super().__init__()
        assert (
            img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0
        ), f"img_size {img_size} should be divisible by patch_size {patch_size}"

        self.in_chans = in_channels
        self.out_chans = out_channels
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_features = self.embed_dim = embed_dim
        self.num_blocks = num_blocks
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=self.patch_size,
            in_chans=self.in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Patch-grid dimensions used to restore the 2D layout of the tokens.
        self.h = img_size[0] // self.patch_size[0]
        self.w = img_size[1] // self.patch_size[1]

        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                    norm_layer=norm_layer,
                    num_blocks=self.num_blocks,
                    sparsity_threshold=sparsity_threshold,
                    hard_thresholding_fraction=hard_thresholding_fraction,
                )
                for _ in range(depth)
            ]
        )

        # Each token is projected to one flattened output patch.
        self.head = nn.Linear(
            embed_dim,
            self.out_chans * self.patch_size[0] * self.patch_size[1],
            bias=False,
        )

        torch.nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise ``nn.Linear`` and ``nn.LayerNorm`` sub-modules.

        Linear weights get a truncated normal (std 0.02) and zero bias;
        LayerNorm gets zero bias and unit weight. Other modules are untouched.
        """
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the set of parameter names ``{"pos_embed", "cls_token"}``."""
        return {"pos_embed", "cls_token"}

    def forward_features(self, x):
        """Embed patches, add the position embedding and run all blocks.

        Returns a tensor reshaped to ``(B, self.h, self.w, self.embed_dim)``
        before the blocks are applied.
        """
        B = x.shape[0]
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        x = x.reshape(B, self.h, self.w, self.embed_dim)
        for blk in self.blocks:
            x = blk(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        """Map an image batch to an output of shape ``[B, C_out, H, W]``."""
        x = self.forward_features(x)
        x = self.head(x)

        # Correct tensor shape back into [B, C, H, W]
        # [b h w (p1 p2 c_out)]
        out = x.view(list(x.shape[:-1]) + [self.patch_size[0], self.patch_size[1], -1])
        # [b h w p1 p2 c_out]
        out = torch.permute(out, (0, 5, 1, 3, 2, 4))
        # [b c_out, h, p1, w, p2]
        out = out.reshape(list(out.shape[:2]) + [self.img_size[0], self.img_size[1]])
        # [b c_out, (h*p1), (w*p2)]
        return out


class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    The embedding is a ``Conv2d`` whose kernel size and stride both equal
    ``patch_size``.

    Parameters
    ----------
    img_size : Tuple[int, int], optional
        Expected input size (height, width), by default (224, 224)
    patch_size : Tuple[int, int], optional
        Patch size (height, width), by default (16, 16)
    in_chans : int, optional
        Number of input channels, by default 3
    embed_dim : int, optional
        Number of output channels of the projection, by default 768
    """

    def __init__(
        self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768
    ):
        super().__init__()
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        """Embed ``x`` of shape ``(B, C, H, W)`` into ``(B, num_patches, embed_dim)``.

        Raises
        ------
        AssertionError
            If ``H``/``W`` do not match the configured ``img_size``.
        """
        B, C, H, W = x.shape
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, embed_dim, h, w) -> (B, embed_dim, h*w) -> (B, h*w, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class AFNOArch(Arch):
    """Adaptive Fourier neural operator (AFNO) model.

    Note
    ----
    AFNO is a model that is designed for 2D images only.

    Parameters
    ----------
    input_keys : List[Key]
        Input key list. The key dimension size should equal the variables channel dim.
    output_keys : List[Key]
        Output key list. The key dimension size should equal the variables channel dim.
    img_shape : Tuple[int, int]
        Input image dimensions (height, width)
    detach_keys : List[Key], optional
        List of keys to detach gradients, by default None (treated as an
        empty list)
    patch_size : int, optional
        Size of image patches, by default 16
    embed_dim : int, optional
        Embedded channel size, by default 256
    depth : int, optional
        Number of AFNO layers, by default 4
    num_blocks : int, optional
        Number of blocks in the frequency weight matrices, by default 4

    Variable Shape
    --------------
    - Input variable tensor shape: :math:`[N, size, H, W]`
    - Output variable tensor shape: :math:`[N, size, H, W]`

    Example
    -------
    >>> afno = modulus.sym.models.afno.AFNOArch([Key("x", size=2)], [Key("y", size=2)], (64, 64))
    >>> model = afno.make_node()
    >>> input = {"x": torch.randn(20, 2, 64, 64)}
    >>> output = model.evaluate(input)
    """

    def __init__(
        self,
        input_keys: List[Key],
        output_keys: List[Key],
        img_shape: Tuple[int, int],
        detach_keys: Optional[List[Key]] = None,
        patch_size: int = 16,
        embed_dim: int = 256,
        depth: int = 4,
        num_blocks: int = 4,
    ) -> None:
        # A fresh list per instance avoids sharing one mutable default
        # between every AFNOArch constructed without ``detach_keys``.
        if detach_keys is None:
            detach_keys = []

        # NOTE(review): ``detach_keys`` is not forwarded to the base
        # constructor, while ``forward`` reads ``self.detach_key_dict``, which
        # this class never assigns (presumably ``Arch`` does) — confirm that
        # detaching actually takes effect.
        super().__init__(input_keys=input_keys, output_keys=output_keys)

        self.input_keys = input_keys
        self.output_keys = output_keys
        self.detach_keys = detach_keys

        self.input_key_dict = {var.name: var.size for var in self.input_keys}
        self.output_key_dict = {var.name: var.size for var in self.output_keys}

        # All input (output) variables are stacked along the channel dimension.
        in_channels = sum(self.input_key_dict.values())
        out_channels = sum(self.output_key_dict.values())

        self._impl = AFNONet(
            in_channels=in_channels,
            out_channels=out_channels,
            patch_size=(patch_size, patch_size),
            img_size=img_shape,
            embed_dim=embed_dim,
            depth=depth,
            num_blocks=num_blocks,
        )

    def forward(self, in_vars: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Assemble the input tensor, run the AFNO network and split the output.

        Inputs are gathered with ``prepare_input`` along ``dim=1`` and the
        network output is converted back to a dictionary with
        ``prepare_output`` along ``dim=1``.
        """
        x = self.prepare_input(
            in_vars,
            mask=self.input_key_dict.keys(),
            detach_dict=self.detach_key_dict,
            dim=1,
            input_scales=self.input_scales,
        )
        y = self._impl(x)
        return self.prepare_output(
            y, output_var=self.output_key_dict, dim=1, output_scales=self.output_scales
        )