deeplearning/modulus/modulus-core-v021/_modules/modulus/models/rnn/rnn_one2many.html

Core v0.2.1

Source code for modulus.models.rnn.rnn_one2many

# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import modulus

from torch import Tensor
from dataclasses import dataclass
from typing import Union, List
from modulus.models.meta import ModelMetaData
from modulus.models.module import Module
from modulus.models.rnn.layers import (
    _ConvLayer,
    _TransposeConvLayer,
    _ConvGRULayer,
    _ConvResidualBlock,
)


@dataclass
class MetaData(ModelMetaData):
    """Metadata describing the ``One2ManyRNN`` model.

    Each field below is declared as a dataclass default on this subclass of
    ``ModelMetaData``; the grouping comments mirror the categories used for
    the flags.
    """

    name: str = "One2ManyRNN"

    # -- Optimization ------------------------------------------------------
    jit: bool = False
    cuda_graphs: bool = False
    amp: bool = True
    torch_fx: bool = True

    # -- Inference ---------------------------------------------------------
    onnx: bool = False
    onnx_runtime: bool = False

    # -- Physics informed --------------------------------------------------
    func_torch: bool = False
    auto_grad: bool = False
class One2ManyRNN(Module):
    """An RNN model with encoder/decoder for 2d/3d problems that provides
    predictions based on a single initial condition.

    The initial condition is encoded by a stack of gated residual blocks. A
    convolutional GRU rolls it out in latent space for ``nr_tsteps`` steps.
    Every latent state is then decoded and projected back to
    ``input_channels`` channels.

    Parameters
    ----------
    input_channels : int
        Number of channels in the input
    dimension : int, optional
        Spatial dimension of the input. Only 2d and 3d are supported, by default 2
    nr_latent_channels : int, optional
        Channels for encoding/decoding, by default 512
    nr_residual_blocks : int, optional
        Number of residual blocks per encoder/decoder stage, by default 2
    activation_fn : Union[nn.Module, List[nn.Module]], optional
        Activation function to use, by default nn.ReLU()
    nr_downsamples : int, optional
        Number of downsamples (encoder/decoder stages), must be at least 1,
        by default 2
    nr_tsteps : int, optional
        Time steps to predict, by default 32

    Raises
    ------
    ValueError
        If ``dimension`` is not 2 or 3, or if ``nr_downsamples`` is less than 1.

    Example
    -------
    >>> model = modulus.models.rnn.One2ManyRNN(
    ...     input_channels=6,
    ...     dimension=2,
    ...     nr_latent_channels=32,
    ...     activation_fn=torch.nn.ReLU(),
    ...     nr_downsamples=2,
    ...     nr_tsteps=16,
    ... )
    >>> x = torch.randn(4, 6, 1, 16, 16)  # [N, C, T, H, W]
    >>> output = model(x)
    >>> output.size()
    torch.Size([4, 6, 16, 16, 16])
    """

    def __init__(
        self,
        input_channels: int,
        dimension: int = 2,
        nr_latent_channels: int = 512,
        nr_residual_blocks: int = 2,
        activation_fn: Union[nn.Module, List[nn.Module]] = nn.ReLU(),
        nr_downsamples: int = 2,
        nr_tsteps: int = 32,
    ) -> None:
        super().__init__(meta=MetaData())

        # Validate arguments before any layer is constructed.
        if dimension not in [2, 3]:
            raise ValueError("Only 2D and 3D spatial dimensions are supported")
        if nr_downsamples < 1:
            # forward() reads the last latent context grid produced by the
            # decoder stages. With zero stages there is none, and every call
            # would fail with an IndexError.
            raise ValueError("nr_downsamples must be at least 1")

        self.nr_tsteps = nr_tsteps
        self.nr_residual_blocks = nr_residual_blocks
        self.nr_downsamples = nr_downsamples

        # ---- Encoder ---------------------------------------------------------
        # nr_downsamples stages of nr_residual_blocks gated residual blocks.
        # The last block of every stage except the final one uses stride 2 and
        # doubles the channel count.
        self.encoder_layers = nn.ModuleList()
        channels_out = nr_latent_channels
        for i in range(nr_downsamples):
            for j in range(nr_residual_blocks):
                is_first_block = (i == 0) and (j == 0)
                # The very first block consumes the raw input channels.
                channels_in = input_channels if is_first_block else channels_out
                stride = 1
                if (j == nr_residual_blocks - 1) and (i < nr_downsamples - 1):
                    channels_out = channels_out * 2
                    stride = 2
                self.encoder_layers.append(
                    _ConvResidualBlock(
                        in_channels=channels_in,
                        out_channels=channels_out,
                        stride=stride,
                        dimension=dimension,
                        gated=True,
                        layer_normalization=False,
                        # No leading activation on the raw input.
                        begin_activation_fn=not is_first_block,
                        activation_fn=activation_fn,
                    )
                )

        # ---- Latent recurrence ----------------------------------------------
        self.rnn_layer = _ConvGRULayer(
            in_features=channels_out, hidden_size=channels_out, dimension=dimension
        )

        # ---- Decoder ---------------------------------------------------------
        # Each stage halves the channels with a stride-2 transposed conv and is
        # followed by residual blocks. A 1x1 conv per stage projects that
        # stage's *input* to nr_latent_channels (the "latent context grid").
        self.conv_layers = nn.ModuleList()
        self.decoder_layers = nn.ModuleList()
        for i in range(nr_downsamples):
            # NOTE: intentionally assigned to an attribute rather than a local.
            # Assuming Module derives from nn.Module, this registers the last
            # stage a second time under the ``upsampling_layers`` prefix, in
            # addition to ``decoder_layers``. Existing checkpoints contain
            # those keys, so the assignment is kept for state_dict
            # compatibility.
            self.upsampling_layers = nn.ModuleList()
            channels_in = channels_out
            channels_out = channels_out // 2
            self.upsampling_layers.append(
                _TransposeConvLayer(
                    in_channels=channels_in,
                    out_channels=channels_out,
                    kernel_size=4,
                    stride=2,
                    dimension=dimension,
                )
            )
            for j in range(nr_residual_blocks):
                self.upsampling_layers.append(
                    _ConvResidualBlock(
                        in_channels=channels_out,
                        out_channels=channels_out,
                        stride=1,
                        dimension=dimension,
                        gated=True,
                        layer_normalization=False,
                        begin_activation_fn=not ((i == 0) and (j == 0)),
                        activation_fn=activation_fn,
                    )
                )
            self.conv_layers.append(
                _ConvLayer(
                    in_channels=channels_in,
                    out_channels=nr_latent_channels,
                    kernel_size=1,
                    stride=1,
                    dimension=dimension,
                )
            )
            self.decoder_layers.append(self.upsampling_layers)

        # 1x1 projection from the latent channels back to the input channels.
        if dimension == 2:
            self.final_conv = nn.Conv2d(
                nr_latent_channels, input_channels, (1, 1), (1, 1), padding="valid"
            )
        else:  # dimension is 3
            self.final_conv = nn.Conv3d(
                nr_latent_channels,
                input_channels,
                (1, 1, 1),
                (1, 1, 1),
                padding="valid",
            )

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass

        Parameters
        ----------
        x : Tensor
            Expects a tensor of size [N, C, 1, H, W] for 2D or [N, C, 1, D, H, W]
            for 3D. Here N is the batch size, C is the number of channels, 1 is
            the number of input timesteps, and D, H, W are spatial dimensions.
            Only index 0 of the time axis is read; any further time steps are
            ignored.

        Returns
        -------
        Tensor
            Size [N, C, T, H, W] for 2D or [N, C, T, D, H, W] for 3D, where T is
            the number of timesteps being predicted (``nr_tsteps``).
        """
        # Encoding step: encode the single initial condition.
        encoded = x[:, :, 0, ...]
        for layer in self.encoder_layers:
            encoded = layer(encoded)

        # RNN step: the encoded initial condition is the GRU input at t=0.
        # Afterwards the previous hidden state is fed back as the input.
        # The zero state is allocated directly on the target device (default
        # dtype, as before) instead of on the host followed by a copy.
        h = torch.zeros(encoded.shape, device=x.device)
        x_in_rnn = encoded
        rnn_output = []
        for _ in range(self.nr_tsteps):
            h = self.rnn_layer(x_in_rnn, h)
            x_in_rnn = h
            rnn_output.append(h)

        # Decoding step, applied independently to every latent state.
        decoded_output = []
        for x_out in rnn_output:
            latent_context_grid = []
            for conv_layer, decoder in zip(self.conv_layers, self.decoder_layers):
                # Project this stage's input to the latent channels, then upsample.
                latent_context_grid.append(conv_layer(x_out))
                for upsampling_layer in decoder:
                    x_out = upsampling_layer(x_out)
            # Only the last latent context grid is used, although multi-resolution
            # grids are available. NOTE(review): the last decoder stage's
            # upsampled output therefore never reaches the result; it is still
            # computed to keep behaviour unchanged.
            out = self.final_conv(latent_context_grid[-1])
            decoded_output.append(out)

        # Stack along a new time axis -> [N, C, T, ...].
        return torch.stack(decoded_output, dim=2)
© Copyright 2023, NVIDIA Modulus Team. Last updated on Sep 21, 2023.