Source code for modulus.sym.models.super_res_net

# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from typing import List, Dict

from modulus.models.srrn import SRResNet
from modulus.sym.key import Key
from modulus.sym.models.arch import Arch
from modulus.sym.models.activation import Activation, get_activation_fn

Tensor = torch.Tensor

[docs]class SRResNetArch(Arch): """3D super resolution network Based on the implementation: Parameters ---------- input_keys : List[Key] Input key list output_keys : List[Key] Output key list detach_keys : List[Key], optional List of keys to detach gradients, by default [] large_kernel_size : int, optional convolutional kernel size for first and last convolution, by default 7 small_kernel_size : int, optional convolutional kernel size for internal convolutions, by default 3 conv_layer_size : int, optional Latent channel size, by default 32 n_resid_blocks : int, optional Number of residual blocks before , by default 8 scaling_factor : int, optional Scaling factor to increase the output feature size compared to the input (2, 4, or 8), by default 8 activation_fn : Activation, optional Activation function, by default Activation.PRELU """ def __init__( self, input_keys: List[Key], output_keys: List[Key], detach_keys: List[Key] = [], large_kernel_size: int = 7, small_kernel_size: int = 3, conv_layer_size: int = 32, n_resid_blocks: int = 8, scaling_factor: int = 8, activation_fn: Activation = Activation.PRELU, ): super().__init__( input_keys=input_keys, output_keys=output_keys, detach_keys=detach_keys ) in_channels = sum(self.input_key_dict.values()) out_channels = sum(self.output_key_dict.values()) activation_fn = get_activation_fn(activation_fn) self.srrn = SRResNet( in_channels=in_channels, out_channels=out_channels, large_kernel_size=large_kernel_size, small_kernel_size=small_kernel_size, conv_layer_size=conv_layer_size, n_resid_blocks=n_resid_blocks, scaling_factor=scaling_factor, activation_fn=activation_fn, )
[docs] def forward(self, in_vars: Dict[str, Tensor]) -> Dict[str, Tensor]: input = self.prepare_input( in_vars, self.input_key_dict.keys(), detach_dict=self.detach_key_dict, dim=1, input_scales=self.input_scales, periodicity=self.periodicity, ) output = self.srrn(input) return self.prepare_output( output, self.output_key_dict, dim=1, output_scales=self.output_scales )
© Copyright 2023, NVIDIA Modulus Team. Last updated on Jan 25, 2024.