# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Architecture/Model configs
"""

from dataclasses import dataclass, field
from typing import Any, Union, List, Dict, Tuple

from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, SI, II


@dataclass
class ModelConf:
    """Base schema shared by every architecture config in this module.

    ``arch_type``, ``input_keys``, ``output_keys`` and ``detach_keys`` default
    to OmegaConf's ``MISSING`` marker; ``scaling`` defaults to ``None``.
    """

    arch_type: str = MISSING
    input_keys: Any = MISSING
    output_keys: Any = MISSING
    detach_keys: Any = MISSING
    scaling: Any = None


@dataclass
class AFNOConf(ModelConf):
    """Config schema for the ``afno`` architecture type.

    ``img_shape`` has no default (``MISSING``).
    """

    arch_type: str = "afno"
    img_shape: Tuple[int] = MISSING
    patch_size: int = 16
    embed_dim: int = 256
    depth: int = 4
    num_blocks: int = 8


@dataclass
class DistributedAFNOConf(ModelConf):
    """Config schema for the ``distributed_afno`` architecture type.

    Carries the same fields as :class:`AFNOConf` plus the two
    ``channel_parallel_*`` flags. ``img_shape`` has no default (``MISSING``).
    """

    arch_type: str = "distributed_afno"
    img_shape: Tuple[int] = MISSING
    patch_size: int = 16
    embed_dim: int = 256
    depth: int = 4
    num_blocks: int = 8
    channel_parallel_inputs: bool = False
    channel_parallel_outputs: bool = False


@dataclass
class DeepOConf(ModelConf):
    """Config schema for the ``deeponet`` architecture type."""

    arch_type: str = "deeponet"
    # branch_net: Union[Arch, str],
    # trunk_net: Union[Arch, str],
    trunk_dim: Any = None  # Union[None, int]
    branch_dim: Any = None  # Union[None, int]


@dataclass
class FNOConf(ModelConf):
    """Config schema for the ``fno`` architecture type.

    ``dimension`` has no default (``MISSING``).
    """

    arch_type: str = "fno"
    dimension: int = MISSING
    # decoder_net: Arch
    nr_fno_layers: int = 4
    fno_modes: Any = 16  # Union[int, List[int]]
    padding: int = 8
    padding_type: str = "constant"
    activation_fn: str = "gelu"
    coord_features: bool = True


@dataclass
class FourierConf(ModelConf):
    """Config schema for the ``fourier`` architecture type.

    ``frequencies`` and ``frequencies_params`` default to a *string* holding a
    Python expression; how that string is interpreted is decided by the
    consumer of this config, not here.
    """

    arch_type: str = "fourier"
    frequencies: Any = "('axis', [i for i in range(10)])"
    frequencies_params: Any = "('axis', [i for i in range(10)])"
    activation_fn: str = "silu"
    layer_size: int = 512
    nr_layers: int = 6
    skip_connections: bool = False
    weight_norm: bool = True
    adaptive_activations: bool = False


@dataclass
class FullyConnectedConf(ModelConf):
    """Config schema for the ``fully_connected`` architecture type."""

    arch_type: str = "fully_connected"
    layer_size: int = 512
    nr_layers: int = 6
    skip_connections: bool = False
    activation_fn: str = "silu"
    adaptive_activations: bool = False
    weight_norm: bool = True


@dataclass
class ConvFullyConnectedConf(ModelConf):
    """Config schema for the ``conv_fully_connected`` architecture type."""

    arch_type: str = "conv_fully_connected"
    layer_size: int = 512
    nr_layers: int = 6
    skip_connections: bool = False
    activation_fn: str = "silu"
    adaptive_activations: bool = False
    weight_norm: bool = True


@dataclass
class FusedMLPConf(ModelConf):
    """Config schema for the ``fused_fully_connected`` architecture type."""

    arch_type: str = "fused_fully_connected"
    layer_size: int = 128
    nr_layers: int = 6
    activation_fn: str = "sigmoid"


@dataclass
class FusedFourierNetConf(ModelConf):
    """Config schema for the ``fused_fourier`` architecture type."""

    arch_type: str = "fused_fourier"
    layer_size: int = 128
    nr_layers: int = 6
    activation_fn: str = "sigmoid"
    n_frequencies: int = 12


@dataclass
class FusedGridEncodingNetConf(ModelConf):
    """Config schema for the ``fused_hash_encoding`` architecture type."""

    arch_type: str = "fused_hash_encoding"
    layer_size: int = 128
    nr_layers: int = 6
    activation_fn: str = "sigmoid"
    indexing: str = "Hash"
    n_levels: int = 16
    n_features_per_level: int = 2
    log2_hashmap_size: int = 19
    base_resolution: int = 16
    per_level_scale: float = 2.0
    interpolation: str = "Smoothstep"


@dataclass
class MultiresolutionHashNetConf(ModelConf):
    """Config schema for the ``hash_encoding`` architecture type.

    ``bounds`` defaults to a *string* holding a Python literal; parsing is left
    to the consumer of this config.
    """

    arch_type: str = "hash_encoding"
    layer_size: int = 64
    nr_layers: int = 3
    skip_connections: bool = False
    weight_norm: bool = True
    adaptive_activations: bool = False
    bounds: Any = "[(1.0, 1.0), (1.0, 1.0)]"
    nr_levels: int = 16
    nr_features_per_level: int = 2
    log2_hashmap_size: int = 19
    base_resolution: int = 2
    finest_resolution: int = 32


@dataclass
class HighwayFourierConf(ModelConf):
    """Config schema for the ``highway_fourier`` architecture type.

    ``frequencies`` and ``frequencies_params`` default to a *string* holding a
    Python expression; interpretation is left to the consumer of this config.
    """

    arch_type: str = "highway_fourier"
    frequencies: Any = "('axis', [i for i in range(10)])"
    frequencies_params: Any = "('axis', [i for i in range(10)])"
    activation_fn: str = "silu"
    layer_size: int = 512
    nr_layers: int = 6
    skip_connections: bool = False
    weight_norm: bool = True
    adaptive_activations: bool = False
    transform_fourier_features: bool = True
    project_fourier_features: bool = False


@dataclass
class ModifiedFourierConf(ModelConf):
    """Config schema for the ``modified_fourier`` architecture type.

    ``frequencies`` and ``frequencies_params`` default to a *string* holding a
    Python expression; interpretation is left to the consumer of this config.
    """

    arch_type: str = "modified_fourier"
    frequencies: Any = "('axis', [i for i in range(10)])"
    frequencies_params: Any = "('axis', [i for i in range(10)])"
    activation_fn: str = "silu"
    layer_size: int = 512
    nr_layers: int = 6
    skip_connections: bool = False
    weight_norm: bool = True
    adaptive_activations: bool = False


@dataclass
class MultiplicativeFilterConf(ModelConf):
    """Config schema for the ``multiplicative_fourier`` architecture type."""

    arch_type: str = "multiplicative_fourier"
    layer_size: int = 512
    nr_layers: int = 6
    skip_connections: bool = False
    activation_fn: str = "identity"
    filter_type: str = "fourier"
    weight_norm: bool = True
    input_scale: float = 10.0
    gabor_alpha: float = 6.0
    gabor_beta: float = 1.0
    normalization: Any = (
        None  # Change to Union[None, Dict[str, Tuple[float, float]]] when supported
    )


@dataclass
class MultiscaleFourierConf(ModelConf):
    """Config schema for the ``multiscale_fourier`` architecture type."""

    arch_type: str = "multiscale_fourier"
    # default_factory gives every instance its own list instead of a shared one.
    frequencies: Any = field(default_factory=lambda: [32])
    frequencies_params: Any = None
    activation_fn: str = "silu"
    layer_size: int = 512
    nr_layers: int = 6
    skip_connections: bool = False
    weight_norm: bool = True
    adaptive_activations: bool = False


@dataclass
class Pix2PixConf(ModelConf):
    """Config schema for the ``pix2pix`` architecture type.

    ``dimension`` has no default (``MISSING``).
    """

    arch_type: str = "pix2pix"
    dimension: int = MISSING
    conv_layer_size: int = 64
    n_downsampling: int = 3
    n_blocks: int = 3
    scaling_factor: int = 1
    batch_norm: bool = True
    padding_type: str = "reflect"
    activation_fn: str = "relu"


@dataclass
class SirenConf(ModelConf):
    """Config schema for the ``siren`` architecture type."""

    arch_type: str = "siren"
    layer_size: int = 512
    nr_layers: int = 6
    first_omega: float = 30.0
    omega: float = 30.0
    normalization: Any = (
        None  # Change to Union[None, Dict[str, Tuple[float, float]]] when supported
    )


@dataclass
class SRResConf(ModelConf):
    """Config schema for the ``super_res`` architecture type."""

    arch_type: str = "super_res"
    large_kernel_size: int = 7
    small_kernel_size: int = 3
    conv_layer_size: int = 32
    n_resid_blocks: int = 8
    scaling_factor: int = 8
    activation_fn: str = "prelu"


def register_arch_configs() -> None:
    """Register every architecture config in Hydra's ``ConfigStore``.

    Two entries are stored in the ``arch`` group for each architecture:

    * an *instance* node ``{name: ConfClass()}`` stored under ``name``, and
    * a *schema* node (the dataclass itself) stored under a ``*_cfg`` name.

    Returns:
        None. The only effect is the ``ConfigStore`` registrations.
    """
    # Information regarding multiple config groups
    # https://hydra.cc/docs/next/patterns/select_multiple_configs_from_config_group/
    instance_nodes = (
        ("fused_fully_connected", FusedMLPConf),
        ("fused_fourier", FusedFourierNetConf),
        ("fused_hash_encoding", FusedGridEncodingNetConf),
        ("fully_connected", FullyConnectedConf),
        ("conv_fully_connected", ConvFullyConnectedConf),
        ("fourier", FourierConf),
        ("highway_fourier", HighwayFourierConf),
        ("modified_fourier", ModifiedFourierConf),
        ("multiplicative_fourier", MultiplicativeFilterConf),
        ("multiscale_fourier", MultiscaleFourierConf),
        ("siren", SirenConf),
        ("hash_encoding", MultiresolutionHashNetConf),
        ("fno", FNOConf),
        ("afno", AFNOConf),
        ("distributed_afno", DistributedAFNOConf),
        ("deeponet", DeepOConf),
        ("super_res", SRResConf),
        ("pix2pix", Pix2PixConf),
    )

    # Schemas for extending models
    # Info: https://hydra.cc/docs/next/patterns/extending_configs/
    schema_nodes = (
        ("fully_connected_cfg", FullyConnectedConf),
        ("conv_fully_connected_cfg", ConvFullyConnectedConf),
        ("fused_mlp_cfg", FusedMLPConf),
        ("fused_fourier_net_cfg", FusedFourierNetConf),
        ("fused_grid_encoding_net_cfg", FusedGridEncodingNetConf),
        ("fourier_cfg", FourierConf),
        ("highway_fourier_cfg", HighwayFourierConf),
        ("modified_fourier_cfg", ModifiedFourierConf),
        ("multiplicative_fourier_cfg", MultiplicativeFilterConf),
        ("multiscale_fourier_cfg", MultiscaleFourierConf),
        ("siren_cfg", SirenConf),
        ("hash_net_cfg", MultiresolutionHashNetConf),
        ("fno_cfg", FNOConf),
        ("afno_cfg", AFNOConf),
        ("distributed_afno_cfg", DistributedAFNOConf),
        ("deeponet_cfg", DeepOConf),
        ("super_res_cfg", SRResConf),
        ("pix2pix_cfg", Pix2PixConf),
    )

    cs = ConfigStore.instance()

    # Instance nodes: the config is nested under a key equal to its own name,
    # and a fresh dataclass instance is built on every call.
    for name, conf_cls in instance_nodes:
        cs.store(group="arch", name=name, node={name: conf_cls()})

    # Schema nodes: the dataclass type itself is stored, not an instance.
    for name, conf_cls in schema_nodes:
        cs.store(group="arch", name=name, node=conf_cls)