# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from typing import Dict, List, Optional

from modulus.models.srrn import SRResNet

from modulus.sym.key import Key
from modulus.sym.models.arch import Arch
from modulus.sym.models.activation import Activation, get_activation_fn

Tensor = torch.Tensor


class SRResNetArch(Arch):
    """3D super resolution network

    Thin ``Arch`` wrapper that builds a ``modulus.models.srrn.SRResNet`` from
    the channel counts implied by the input/output keys and routes the
    keyed tensor dictionaries through it.

    Based on the implementation:
    https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution

    Parameters
    ----------
    input_keys : List[Key]
        Input key list
    output_keys : List[Key]
        Output key list
    detach_keys : List[Key], optional
        List of keys to detach gradients, by default None, which is treated
        as an empty list
    large_kernel_size : int, optional
        convolutional kernel size for first and last convolution, by default 7
    small_kernel_size : int, optional
        convolutional kernel size for internal convolutions, by default 3
    conv_layer_size : int, optional
        Latent channel size, by default 32
    n_resid_blocks : int, optional
        Number of residual blocks, by default 8
    scaling_factor : int, optional
        Scaling factor to increase the output feature size
        compared to the input (2, 4, or 8), by default 8
    activation_fn : Activation, optional
        Activation function, by default Activation.PRELU
    """

    def __init__(
        self,
        input_keys: List[Key],
        output_keys: List[Key],
        detach_keys: Optional[List[Key]] = None,
        large_kernel_size: int = 7,
        small_kernel_size: int = 3,
        conv_layer_size: int = 32,
        n_resid_blocks: int = 8,
        scaling_factor: int = 8,
        activation_fn: Activation = Activation.PRELU,
    ):
        # A ``None`` sentinel replaces the former ``[]`` default so that each
        # instance gets its own list instead of sharing one module-level
        # object across every construction that omits ``detach_keys``.
        if detach_keys is None:
            detach_keys = []

        super().__init__(
            input_keys=input_keys,
            output_keys=output_keys,
            detach_keys=detach_keys,
        )

        # Channel counts are the summed sizes of the keys, as recorded by the
        # base class in ``input_key_dict`` / ``output_key_dict``.
        in_channels = sum(self.input_key_dict.values())
        out_channels = sum(self.output_key_dict.values())

        # Resolve the Activation enum member into the object SRResNet expects.
        activation = get_activation_fn(activation_fn)

        self.srrn = SRResNet(
            in_channels=in_channels,
            out_channels=out_channels,
            large_kernel_size=large_kernel_size,
            small_kernel_size=small_kernel_size,
            conv_layer_size=conv_layer_size,
            n_resid_blocks=n_resid_blocks,
            scaling_factor=scaling_factor,
            activation_fn=activation,
        )

    def forward(self, in_vars: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Run the super resolution network on a dictionary of input tensors.

        Parameters
        ----------
        in_vars : Dict[str, Tensor]
            Tensors keyed by variable name; the entries named by the input
            keys are consumed.

        Returns
        -------
        Dict[str, Tensor]
            Network output as produced by ``prepare_output`` for the
            configured output keys.
        """
        # Assemble the network input from the keyed tensors. ``dim=1`` is
        # handed to the base-class helper; presumably this is the channel
        # axis -- confirm against ``Arch.prepare_input``.
        # Named ``net_input`` rather than ``input`` to avoid shadowing the
        # builtin.
        net_input = self.prepare_input(
            in_vars,
            self.input_key_dict.keys(),
            detach_dict=self.detach_key_dict,
            dim=1,
            input_scales=self.input_scales,
            periodicity=self.periodicity,
        )

        net_output = self.srrn(net_input)

        # Hand the raw network output back to the base-class helper, which
        # produces the dictionary for the output keys.
        return self.prepare_output(
            net_output,
            self.output_key_dict,
            dim=1,
            output_scales=self.output_scales,
        )