deeplearning/modulus/modulus-core/_modules/modulus/models/rnn/rnn_one2many.html

Source code for modulus.models.rnn.rnn_one2many

# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

import torch
import torch.nn as nn
from torch import Tensor

import modulus  # noqa: F401 for docs
from modulus.models.layers import get_activation
from modulus.models.meta import ModelMetaData
from modulus.models.module import Module
from modulus.models.rnn.layers import (
    _ConvGRULayer,
    _ConvLayer,
    _ConvResidualBlock,
    _TransposeConvLayer,
)


@dataclass
class MetaData(ModelMetaData):
    """Metadata record (model name and feature flags) for ``One2ManyRNN``."""

    name: str = "One2ManyRNN"

    # --- Optimization -------------------------------------------------------
    jit: bool = False
    cuda_graphs: bool = False
    amp: bool = True
    torch_fx: bool = True

    # --- Inference ----------------------------------------------------------
    onnx: bool = False
    onnx_runtime: bool = False

    # --- Physics informed ---------------------------------------------------
    func_torch: bool = False
    auto_grad: bool = False
class One2ManyRNN(Module):
    """A RNN model with encoder/decoder for 2d/3d problems that provides predictions
    based on single initial condition.

    Parameters
    ----------
    input_channels : int
        Number of channels in the input
    dimension : int, optional
        Spatial dimension of the input. Only 2d and 3d are supported, by default 2
    nr_latent_channels : int, optional
        Channels for encoding/decoding, by default 512
    nr_residual_blocks : int, optional
        Number of residual blocks per encoder/decoder stage, by default 2
    activation_fn : str, optional
        Activation function to use, by default "relu"
    nr_downsamples : int, optional
        Number of encoder/decoder stages, by default 2. The encoder applies a
        stride-2 block (doubling the channel count) at the end of every stage
        except the last one.
    nr_tsteps : int, optional
        Time steps to predict, by default 32

    Raises
    ------
    ValueError
        If ``dimension`` is not 2 or 3, or if ``nr_downsamples`` or
        ``nr_tsteps`` is smaller than 1.

    Example
    -------
    >>> model = modulus.models.rnn.One2ManyRNN(
    ...     input_channels=6,
    ...     dimension=2,
    ...     nr_latent_channels=32,
    ...     activation_fn="relu",
    ...     nr_downsamples=2,
    ...     nr_tsteps=16,
    ... )
    >>> invar = torch.randn(4, 6, 1, 16, 16)  # [N, C, T, H, W]
    >>> output = model(invar)
    >>> output.size()
    torch.Size([4, 6, 16, 16, 16])
    """

    def __init__(
        self,
        input_channels: int,
        dimension: int = 2,
        nr_latent_channels: int = 512,
        nr_residual_blocks: int = 2,
        activation_fn: str = "relu",
        nr_downsamples: int = 2,
        nr_tsteps: int = 32,
    ) -> None:
        super().__init__(meta=MetaData())

        self.nr_tsteps = nr_tsteps
        self.nr_residual_blocks = nr_residual_blocks
        self.nr_downsamples = nr_downsamples
        self.encoder_layers = nn.ModuleList()
        channels_out = nr_latent_channels
        activation_fn = get_activation(activation_fn)

        # check valid dimensions
        if dimension not in [2, 3]:
            raise ValueError("Only 2D and 3D spatial dimensions are supported")
        # forward() always reads the last decoder stage and stacks the
        # predicted timesteps, so both counts must be positive.
        if nr_downsamples < 1:
            raise ValueError("nr_downsamples must be at least 1")
        if nr_tsteps < 1:
            raise ValueError("nr_tsteps must be at least 1")

        # Encoder: nr_downsamples stages of nr_residual_blocks gated residual
        # blocks; the last block of every stage but the final one downsamples
        # (stride 2) and doubles the channel count.
        for i in range(nr_downsamples):
            for j in range(nr_residual_blocks):
                stride = 1
                if i == 0 and j == 0:
                    channels_in = input_channels
                else:
                    channels_in = channels_out
                if (j == nr_residual_blocks - 1) and (i < nr_downsamples - 1):
                    channels_out = channels_out * 2
                    stride = 2
                self.encoder_layers.append(
                    _ConvResidualBlock(
                        in_channels=channels_in,
                        out_channels=channels_out,
                        stride=stride,
                        dimension=dimension,
                        gated=True,
                        layer_normalization=False,
                        begin_activation_fn=not ((i == 0) and (j == 0)),
                        activation_fn=activation_fn,
                    )
                )

        self.rnn_layer = _ConvGRULayer(
            in_features=channels_out, hidden_size=channels_out, dimension=dimension
        )

        # Decoder: each stage upsamples (transpose conv, stride 2) and halves
        # the channel count, followed by residual blocks. conv_layers[i] maps
        # the input of stage i to nr_latent_channels.
        self.conv_layers = nn.ModuleList()
        self.decoder_layers = nn.ModuleList()
        for i in range(nr_downsamples):
            # NOTE: kept as an attribute (not a local) on purpose: it registers
            # the last stage a second time under ``upsampling_layers.*`` and
            # existing checkpoints contain those state_dict keys.
            self.upsampling_layers = nn.ModuleList()
            channels_in = channels_out
            channels_out = channels_out // 2
            self.upsampling_layers.append(
                _TransposeConvLayer(
                    in_channels=channels_in,
                    out_channels=channels_out,
                    kernel_size=4,
                    stride=2,
                    dimension=dimension,
                )
            )
            for j in range(nr_residual_blocks):
                self.upsampling_layers.append(
                    _ConvResidualBlock(
                        in_channels=channels_out,
                        out_channels=channels_out,
                        stride=1,
                        dimension=dimension,
                        gated=True,
                        layer_normalization=False,
                        begin_activation_fn=not ((i == 0) and (j == 0)),
                        activation_fn=activation_fn,
                    )
                )
            self.conv_layers.append(
                _ConvLayer(
                    in_channels=channels_in,
                    out_channels=nr_latent_channels,
                    kernel_size=1,
                    stride=1,
                    dimension=dimension,
                )
            )
            self.decoder_layers.append(self.upsampling_layers)

        if dimension == 2:
            self.final_conv = nn.Conv2d(
                nr_latent_channels, input_channels, (1, 1), (1, 1), padding="valid"
            )
        else:  # dimension is 3
            self.final_conv = nn.Conv3d(
                nr_latent_channels,
                input_channels,
                (1, 1, 1),
                (1, 1, 1),
                padding="valid",
            )

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass

        Parameters
        ----------
        x : Tensor
            Expects a tensor of size [N, C, 1, H, W] for 2D or [N, C, 1, D, H, W]
            for 3D. Where, N is the batch size, C is the number of channels, 1 is
            the number of input timesteps and D, H, W are spatial dimensions.

        Returns
        -------
        Tensor
            Size [N, C, T, H, W] for 2D or [N, C, T, D, H, W] for 3D.
            Where, T is the number of timesteps being predicted.
        """
        # Encoding step: only the single initial condition (time index 0)
        encoded = x[:, :, 0, ...]
        for layer in self.encoder_layers:
            encoded = layer(encoded)

        # RNN step: the first step consumes the encoded initial condition,
        # every later step is fed the previous hidden state.
        # zeros_like allocates directly on the right device with the
        # encoder's dtype instead of a CPU float32 tensor copied over.
        h = torch.zeros_like(encoded)
        x_in_rnn = encoded
        rnn_output = []
        for _ in range(self.nr_tsteps):
            h = self.rnn_layer(x_in_rnn, h)
            x_in_rnn = h
            rnn_output.append(h)

        # Decoding step: only the context grid of the last decoder stage
        # reaches the output, i.e. conv_layers[-1] applied after all but the
        # last upsampling stage. The last stage's upsampling and the earlier
        # context grids never influenced the result, so they are not
        # computed (their modules stay registered for checkpoint
        # compatibility).
        upsampling_stages = list(self.decoder_layers)[:-1]
        last_context_conv = self.conv_layers[-1]
        decoded_output = []
        for x_out in rnn_output:
            for stage in upsampling_stages:
                for layer in stage:
                    x_out = layer(x_out)
            # Match the channel dimension of the output
            decoded_output.append(self.final_conv(last_context_conv(x_out)))

        return torch.stack(decoded_output, dim=2)
© Copyright 2023, NVIDIA Modulus Team. Last updated on Jan 25, 2024.