Source code for physicsnemo.nn.module.fno_layers

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Fourier Neural Operator (FNO) encoder layers.

This module contains reusable FNO encoder building blocks that can be used
in various FNO-based architectures.
"""

from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor

import physicsnemo.nn as layers
from physicsnemo.core.module import Module


[docs] class FNO1DEncoder(Module): r"""1D Spectral encoder for FNO. This encoder applies a lifting network followed by spectral convolution layers in the Fourier domain for 1D input data. Parameters ---------- in_channels : int, optional, default=1 Number of input channels. num_fno_layers : int, optional, default=4 Number of spectral convolutional layers. fno_layer_size : int, optional, default=32 Latent features size in spectral convolutions. num_fno_modes : Union[int, List[int]], optional, default=16 Number of Fourier modes kept in spectral convolutions. padding : Union[int, List[int]], optional, default=8 Domain padding for spectral convolutions. padding_type : str, optional, default="constant" Type of padding for spectral convolutions. activation_fn : nn.Module, optional, default=nn.GELU() Activation function. coord_features : bool, optional, default=True Use coordinate grid as additional feature map. Forward ------- x : torch.Tensor Input tensor of shape :math:`(B, C_{in}, L)` where :math:`B` is batch size, :math:`C_{in}` is the number of input channels, and :math:`L` is the sequence length (spatial dimension). Outputs ------- torch.Tensor Output tensor of shape :math:`(B, C_{latent}, L)` where :math:`C_{latent}` is ``fno_layer_size``. Examples -------- >>> import torch >>> encoder = FNO1DEncoder(in_channels=3, fno_layer_size=32, num_fno_modes=8) >>> x = torch.randn(4, 3, 64) >>> output = encoder(x) >>> output.shape torch.Size([4, 32, 64]) """ def __init__( self, in_channels: int = 1, num_fno_layers: int = 4, fno_layer_size: int = 32, num_fno_modes: Union[int, List[int]] = 16, padding: Union[int, List[int]] = 8, padding_type: str = "constant", activation_fn: nn.Module = nn.GELU(), coord_features: bool = True, ) -> None: super().__init__() self._input_channels = in_channels self.in_channels = in_channels self.num_fno_layers = num_fno_layers self.fno_width = fno_layer_size self.activation_fn = activation_fn # Add relative coordinate feature self.coord_features = coord_features if self.coord_features: self.in_channels = self.in_channels + 1 # Padding values for spectral conv if isinstance(padding, int): padding = [padding] self.pad = padding[:1] self.ipad = [-pad if pad > 0 else None for pad in self.pad] self.padding_type = padding_type if isinstance(num_fno_modes, int): num_fno_modes = [num_fno_modes] # build lift self._build_lift_network() self._build_fno(num_fno_modes) def _build_lift_network(self) -> None: r"""Construct network for lifting variables to latent space.""" self.lift_network = torch.nn.Sequential() self.lift_network.append( layers.Conv1dFCLayer(self.in_channels, int(self.fno_width / 2)) ) self.lift_network.append(self.activation_fn) self.lift_network.append( layers.Conv1dFCLayer(int(self.fno_width / 2), self.fno_width) ) def _build_fno(self, num_fno_modes: List[int]) -> None: r"""Construct FNO spectral convolution layers. Parameters ---------- num_fno_modes : List[int] Number of Fourier modes kept in spectral convolutions. """ # Build Neural Fourier Operators self.spconv_layers = nn.ModuleList() self.conv_layers = nn.ModuleList() for _ in range(self.num_fno_layers): self.spconv_layers.append( layers.SpectralConv1d(self.fno_width, self.fno_width, num_fno_modes[0]) ) self.conv_layers.append(nn.Conv1d(self.fno_width, self.fno_width, 1))
[docs] def forward(self, x: Float[Tensor, "B C_in L"]) -> Float[Tensor, "B C_latent L"]: r"""Forward pass of the 1D FNO encoder.""" # Input validation: single check for ndim and channels if not torch.compiler.is_compiling(): if x.ndim != 3 or x.shape[1] != self._input_channels: raise ValueError( f"Expected 3D input (B, {self._input_channels}, L), " f"got {x.ndim}D tensor with shape {tuple(x.shape)}" ) # Add coordinate features if enabled if self.coord_features: coord_feat = self._meshgrid(list(x.shape), x.device) x = torch.cat((x, coord_feat), dim=1) # Lift input to latent space x = self.lift_network(x) # Apply padding for spectral convolution x = F.pad(x, (0, self.pad[0]), mode=self.padding_type) # Apply spectral convolution layers for k, conv_w in enumerate(zip(self.conv_layers, self.spconv_layers)): conv, w = conv_w if k < len(self.conv_layers) - 1: x = self.activation_fn(conv(x) + w(x)) else: x = conv(x) + w(x) # Remove padding x = x[..., : self.ipad[0]] return x
def _meshgrid(self, shape: List[int], device: torch.device) -> Tensor: r"""Create 1D meshgrid feature. Parameters ---------- shape : List[int] Tensor shape as ``[batch, channels, L]``. device : torch.device Device model is on. Returns ------- Tensor Meshgrid tensor of shape :math:`(B, 1, L)`. """ bsize, size_x = shape[0], shape[2] grid_x = torch.linspace(0, 1, size_x, dtype=torch.float32, device=device) grid_x = grid_x.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1) return grid_x
[docs] def grid_to_points(self, value: Tensor) -> Tuple[Tensor, List[int]]: r"""Convert from grid-based (image) to point-based representation. Parameters ---------- value : Tensor Grid tensor of shape :math:`(B, C, L)`. Returns ------- Tuple[Tensor, List[int]] Tuple of (flattened tensor, original shape). """ y_shape = list(value.size()) output = torch.permute(value, (0, 2, 1)) return output.reshape(-1, output.size(-1)), y_shape
[docs] def points_to_grid(self, value: Tensor, shape: List[int]) -> Tensor: r"""Convert from point-based to grid-based (image) representation. Parameters ---------- value : Tensor Point tensor of shape :math:`(B \times X, C)`. shape : List[int] Original grid shape as ``[B, C, L]``. Returns ------- Tensor Grid tensor of shape :math:`(B, C, L)`. """ output = value.reshape(shape[0], shape[2], value.size(-1)) return torch.permute(output, (0, 2, 1))
[docs] class FNO2DEncoder(Module): r"""2D Spectral encoder for FNO. This encoder applies a lifting network followed by spectral convolution layers in the Fourier domain for 2D input data. Parameters ---------- in_channels : int, optional, default=1 Number of input channels. num_fno_layers : int, optional, default=4 Number of spectral convolutional layers. fno_layer_size : int, optional, default=32 Latent features size in spectral convolutions. num_fno_modes : Union[int, List[int]], optional, default=16 Number of Fourier modes kept in spectral convolutions. padding : Union[int, List[int]], optional, default=8 Domain padding for spectral convolutions. padding_type : str, optional, default="constant" Type of padding for spectral convolutions. activation_fn : nn.Module, optional, default=nn.GELU() Activation function. coord_features : bool, optional, default=True Use coordinate grid as additional feature map. Forward ------- x : torch.Tensor Input tensor of shape :math:`(B, C_{in}, H, W)` where :math:`B` is batch size, :math:`C_{in}` is the number of input channels, and :math:`H, W` are spatial dimensions. Outputs ------- torch.Tensor Output tensor of shape :math:`(B, C_{latent}, H, W)` where :math:`C_{latent}` is ``fno_layer_size``. Examples -------- >>> import torch >>> encoder = FNO2DEncoder(in_channels=3, fno_layer_size=32, num_fno_modes=8) >>> x = torch.randn(4, 3, 32, 32) >>> output = encoder(x) >>> output.shape torch.Size([4, 32, 32, 32]) """ def __init__( self, in_channels: int = 1, num_fno_layers: int = 4, fno_layer_size: int = 32, num_fno_modes: Union[int, List[int]] = 16, padding: Union[int, List[int]] = 8, padding_type: str = "constant", activation_fn: nn.Module = nn.GELU(), coord_features: bool = True, ) -> None: super().__init__() self._input_channels = in_channels self.in_channels = in_channels self.num_fno_layers = num_fno_layers self.fno_width = fno_layer_size self.coord_features = coord_features self.activation_fn = activation_fn # Add relative coordinate feature if self.coord_features: self.in_channels = self.in_channels + 2 # Padding values for spectral conv if isinstance(padding, int): padding = [padding, padding] padding = padding + [0, 0] # Pad with zeros for smaller lists self.pad = padding[:2] self.ipad = [-pad if pad > 0 else None for pad in self.pad] self.padding_type = padding_type if isinstance(num_fno_modes, int): num_fno_modes = [num_fno_modes, num_fno_modes] # build lift self._build_lift_network() self._build_fno(num_fno_modes) def _build_lift_network(self) -> None: r"""Construct network for lifting variables to latent space.""" # Initial lift network self.lift_network = torch.nn.Sequential() self.lift_network.append( layers.Conv2dFCLayer(self.in_channels, int(self.fno_width / 2)) ) self.lift_network.append(self.activation_fn) self.lift_network.append( layers.Conv2dFCLayer(int(self.fno_width / 2), self.fno_width) ) def _build_fno(self, num_fno_modes: List[int]) -> None: r"""Construct FNO spectral convolution layers. Parameters ---------- num_fno_modes : List[int] Number of Fourier modes kept in spectral convolutions. """ # Build Neural Fourier Operators self.spconv_layers = nn.ModuleList() self.conv_layers = nn.ModuleList() for _ in range(self.num_fno_layers): self.spconv_layers.append( layers.SpectralConv2d( self.fno_width, self.fno_width, num_fno_modes[0], num_fno_modes[1] ) ) self.conv_layers.append(nn.Conv2d(self.fno_width, self.fno_width, 1))
[docs] def forward( self, x: Float[Tensor, "B C_in H W"] ) -> Float[Tensor, "B C_latent H W"]: r"""Forward pass of the 2D FNO encoder.""" # Input validation: single check for ndim and channels if not torch.compiler.is_compiling(): if x.ndim != 4 or x.shape[1] != self._input_channels: raise ValueError( f"Expected 4D input (B, {self._input_channels}, H, W), " f"got {x.ndim}D tensor with shape {tuple(x.shape)}" ) # Add coordinate features if enabled if self.coord_features: coord_feat = self._meshgrid(list(x.shape), x.device) x = torch.cat((x, coord_feat), dim=1) # Lift input to latent space x = self.lift_network(x) # Apply padding for spectral convolution x = F.pad(x, (0, self.pad[1], 0, self.pad[0]), mode=self.padding_type) # Apply spectral convolution layers for k, conv_w in enumerate(zip(self.conv_layers, self.spconv_layers)): conv, w = conv_w if k < len(self.conv_layers) - 1: x = self.activation_fn(conv(x) + w(x)) else: x = conv(x) + w(x) # Remove padding x = x[..., : self.ipad[0], : self.ipad[1]] return x
def _meshgrid(self, shape: List[int], device: torch.device) -> Tensor: r"""Create 2D meshgrid feature. Parameters ---------- shape : List[int] Tensor shape as ``[batch, channels, height, width]``. device : torch.device Device model is on. Returns ------- Tensor Meshgrid tensor of shape :math:`(B, 2, H, W)`. """ bsize, size_x, size_y = shape[0], shape[2], shape[3] grid_x = torch.linspace(0, 1, size_x, dtype=torch.float32, device=device) grid_y = torch.linspace(0, 1, size_y, dtype=torch.float32, device=device) grid_x, grid_y = torch.meshgrid(grid_x, grid_y, indexing="ij") grid_x = grid_x.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1) grid_y = grid_y.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1) return torch.cat((grid_x, grid_y), dim=1)
[docs] def grid_to_points(self, value: Tensor) -> Tuple[Tensor, List[int]]: r"""Convert from grid-based (image) to point-based representation. Parameters ---------- value : Tensor Grid tensor of shape :math:`(B, C, H, W)`. Returns ------- Tuple[Tensor, List[int]] Tuple of (flattened tensor, original shape). """ y_shape = list(value.size()) output = torch.permute(value, (0, 2, 3, 1)) return output.reshape(-1, output.size(-1)), y_shape
[docs] def points_to_grid(self, value: Tensor, shape: List[int]) -> Tensor: r"""Convert from point-based to grid-based (image) representation. Parameters ---------- value : Tensor Point tensor of shape :math:`(B \times H \times W, C)`. shape : List[int] Original grid shape as ``[B, C, H, W]``. Returns ------- Tensor Grid tensor of shape :math:`(B, C, H, W)`. """ output = value.reshape(shape[0], shape[2], shape[3], value.size(-1)) return torch.permute(output, (0, 3, 1, 2))
[docs] class FNO3DEncoder(Module): r"""3D Spectral encoder for FNO. This encoder applies a lifting network followed by spectral convolution layers in the Fourier domain for 3D input data. Parameters ---------- in_channels : int, optional, default=1 Number of input channels. num_fno_layers : int, optional, default=4 Number of spectral convolutional layers. fno_layer_size : int, optional, default=32 Latent features size in spectral convolutions. num_fno_modes : Union[int, List[int]], optional, default=16 Number of Fourier modes kept in spectral convolutions. padding : Union[int, List[int]], optional, default=8 Domain padding for spectral convolutions. padding_type : str, optional, default="constant" Type of padding for spectral convolutions. activation_fn : nn.Module, optional, default=nn.GELU() Activation function. coord_features : bool, optional, default=True Use coordinate grid as additional feature map. Forward ------- x : torch.Tensor Input tensor of shape :math:`(B, C_{in}, D, H, W)` where :math:`B` is batch size, :math:`C_{in}` is the number of input channels, and :math:`D, H, W` are spatial dimensions. Outputs ------- torch.Tensor Output tensor of shape :math:`(B, C_{latent}, D, H, W)` where :math:`C_{latent}` is ``fno_layer_size``. Examples -------- >>> import torch >>> encoder = FNO3DEncoder(in_channels=3, fno_layer_size=32, num_fno_modes=8) >>> x = torch.randn(4, 3, 16, 16, 16) >>> output = encoder(x) >>> output.shape torch.Size([4, 32, 16, 16, 16]) """ def __init__( self, in_channels: int = 1, num_fno_layers: int = 4, fno_layer_size: int = 32, num_fno_modes: Union[int, List[int]] = 16, padding: Union[int, List[int]] = 8, padding_type: str = "constant", activation_fn: nn.Module = nn.GELU(), coord_features: bool = True, ) -> None: super().__init__() self._input_channels = in_channels self.in_channels = in_channels self.num_fno_layers = num_fno_layers self.fno_width = fno_layer_size self.coord_features = coord_features self.activation_fn = activation_fn # Add relative coordinate feature if self.coord_features: self.in_channels = self.in_channels + 3 # Padding values for spectral conv if isinstance(padding, int): padding = [padding, padding, padding] padding = padding + [0, 0, 0] # Pad with zeros for smaller lists self.pad = padding[:3] self.ipad = [-pad if pad > 0 else None for pad in self.pad] self.padding_type = padding_type if isinstance(num_fno_modes, int): num_fno_modes = [num_fno_modes, num_fno_modes, num_fno_modes] # build lift self._build_lift_network() self._build_fno(num_fno_modes) def _build_lift_network(self) -> None: r"""Construct network for lifting variables to latent space.""" # Initial lift network self.lift_network = torch.nn.Sequential() self.lift_network.append( layers.Conv3dFCLayer(self.in_channels, int(self.fno_width / 2)) ) self.lift_network.append(self.activation_fn) self.lift_network.append( layers.Conv3dFCLayer(int(self.fno_width / 2), self.fno_width) ) def _build_fno(self, num_fno_modes: List[int]) -> None: r"""Construct FNO spectral convolution layers. Parameters ---------- num_fno_modes : List[int] Number of Fourier modes kept in spectral convolutions. """ # Build Neural Fourier Operators self.spconv_layers = nn.ModuleList() self.conv_layers = nn.ModuleList() for _ in range(self.num_fno_layers): self.spconv_layers.append( layers.SpectralConv3d( self.fno_width, self.fno_width, num_fno_modes[0], num_fno_modes[1], num_fno_modes[2], ) ) self.conv_layers.append(nn.Conv3d(self.fno_width, self.fno_width, 1))
[docs] def forward( self, x: Float[Tensor, "B C_in D H W"] ) -> Float[Tensor, "B C_latent D H W"]: r"""Forward pass of the 3D FNO encoder.""" # Input validation: single check for ndim and channels if not torch.compiler.is_compiling(): if x.ndim != 5 or x.shape[1] != self._input_channels: raise ValueError( f"Expected 5D input (B, {self._input_channels}, D, H, W), " f"got {x.ndim}D tensor with shape {tuple(x.shape)}" ) # Add coordinate features if enabled if self.coord_features: coord_feat = self._meshgrid(list(x.shape), x.device) x = torch.cat((x, coord_feat), dim=1) # Lift input to latent space x = self.lift_network(x) # Apply padding for spectral convolution x = F.pad( x, (0, self.pad[2], 0, self.pad[1], 0, self.pad[0]), mode=self.padding_type, ) # Apply spectral convolution layers for k, conv_w in enumerate(zip(self.conv_layers, self.spconv_layers)): conv, w = conv_w if k < len(self.conv_layers) - 1: x = self.activation_fn(conv(x) + w(x)) else: x = conv(x) + w(x) # Remove padding x = x[..., : self.ipad[0], : self.ipad[1], : self.ipad[2]] return x
def _meshgrid(self, shape: List[int], device: torch.device) -> Tensor: r"""Create 3D meshgrid feature. Parameters ---------- shape : List[int] Tensor shape as ``[batch, channels, depth, height, width]``. device : torch.device Device model is on. Returns ------- Tensor Meshgrid tensor of shape :math:`(B, 3, D, H, W)`. """ bsize, size_x, size_y, size_z = shape[0], shape[2], shape[3], shape[4] grid_x = torch.linspace(0, 1, size_x, dtype=torch.float32, device=device) grid_y = torch.linspace(0, 1, size_y, dtype=torch.float32, device=device) grid_z = torch.linspace(0, 1, size_z, dtype=torch.float32, device=device) grid_x, grid_y, grid_z = torch.meshgrid(grid_x, grid_y, grid_z, indexing="ij") grid_x = grid_x.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1) grid_y = grid_y.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1) grid_z = grid_z.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1) return torch.cat((grid_x, grid_y, grid_z), dim=1)
[docs] def grid_to_points(self, value: Tensor) -> Tuple[Tensor, List[int]]: r"""Convert from grid-based (image) to point-based representation. Parameters ---------- value : Tensor Grid tensor of shape :math:`(B, C, D, H, W)`. Returns ------- Tuple[Tensor, List[int]] Tuple of (flattened tensor, original shape). """ y_shape = list(value.size()) output = torch.permute(value, (0, 2, 3, 4, 1)) return output.reshape(-1, output.size(-1)), y_shape
[docs] def points_to_grid(self, value: Tensor, shape: List[int]) -> Tensor: r"""Convert from point-based to grid-based (image) representation. Parameters ---------- value : Tensor Point tensor of shape :math:`(B \times D \times H \times W, C)`. shape : List[int] Original grid shape as ``[B, C, D, H, W]``. Returns ------- Tensor Grid tensor of shape :math:`(B, C, D, H, W)`. """ output = value.reshape(shape[0], shape[2], shape[3], shape[4], value.size(-1)) return torch.permute(output, (0, 4, 1, 2, 3))
[docs] class FNO4DEncoder(Module): r"""4D Spectral encoder for FNO. This encoder applies a lifting network followed by spectral convolution layers in the Fourier domain for 4D input data (3D spatial + time). Parameters ---------- in_channels : int, optional, default=1 Number of input channels. num_fno_layers : int, optional, default=4 Number of spectral convolutional layers. fno_layer_size : int, optional, default=32 Latent features size in spectral convolutions. num_fno_modes : Union[int, List[int]], optional, default=16 Number of Fourier modes kept in spectral convolutions. padding : Union[int, List[int]], optional, default=8 Domain padding for spectral convolutions. padding_type : str, optional, default="constant" Type of padding for spectral convolutions. activation_fn : nn.Module, optional, default=nn.GELU() Activation function. coord_features : bool, optional, default=True Use coordinate grid as additional feature map. Forward ------- x : torch.Tensor Input tensor of shape :math:`(B, C_{in}, X, Y, Z, T)` where :math:`B` is batch size, :math:`C_{in}` is the number of input channels, and :math:`X, Y, Z, T` are spatial and temporal dimensions. Outputs ------- torch.Tensor Output tensor of shape :math:`(B, C_{latent}, X, Y, Z, T)` where :math:`C_{latent}` is ``fno_layer_size``. Examples -------- >>> import torch >>> encoder = FNO4DEncoder(in_channels=3, fno_layer_size=32, num_fno_modes=4) >>> x = torch.randn(2, 3, 8, 8, 8, 8) >>> output = encoder(x) >>> output.shape torch.Size([2, 32, 8, 8, 8, 8]) """ def __init__( self, in_channels: int = 1, num_fno_layers: int = 4, fno_layer_size: int = 32, num_fno_modes: Union[int, List[int]] = 16, padding: Union[int, List[int]] = 8, padding_type: str = "constant", activation_fn: nn.Module = nn.GELU(), coord_features: bool = True, ) -> None: super().__init__() self._input_channels = in_channels self.in_channels = in_channels self.num_fno_layers = num_fno_layers self.fno_width = fno_layer_size self.coord_features = coord_features self.activation_fn = activation_fn # Add relative coordinate feature if self.coord_features: self.in_channels = self.in_channels + 4 # Padding values for spectral conv if isinstance(padding, int): padding = [padding, padding, padding, padding] padding = padding + [0, 0, 0, 0] # Pad with zeros for smaller lists self.pad = padding[:4] self.ipad = [-pad if pad > 0 else None for pad in self.pad] self.padding_type = padding_type if isinstance(num_fno_modes, int): num_fno_modes = [num_fno_modes, num_fno_modes, num_fno_modes, num_fno_modes] # build lift self._build_lift_network() self._build_fno(num_fno_modes) def _build_lift_network(self) -> None: r"""Construct network for lifting variables to latent space.""" # Initial lift network self.lift_network = torch.nn.Sequential() self.lift_network.append( layers.ConvNdFCLayer(self.in_channels, int(self.fno_width / 2)) ) self.lift_network.append(self.activation_fn) self.lift_network.append( layers.ConvNdFCLayer(int(self.fno_width / 2), self.fno_width) ) def _build_fno(self, num_fno_modes: List[int]) -> None: r"""Construct FNO spectral convolution layers. Parameters ---------- num_fno_modes : List[int] Number of Fourier modes kept in spectral convolutions. """ # Build Neural Fourier Operators self.spconv_layers = nn.ModuleList() self.conv_layers = nn.ModuleList() for _ in range(self.num_fno_layers): self.spconv_layers.append( layers.SpectralConv4d( self.fno_width, self.fno_width, num_fno_modes[0], num_fno_modes[1], num_fno_modes[2], num_fno_modes[3], ) ) self.conv_layers.append( layers.ConvNdKernel1Layer(self.fno_width, self.fno_width) )
[docs] def forward( self, x: Float[Tensor, "B C_in X Y Z T"] ) -> Float[Tensor, "B C_latent X Y Z T"]: r"""Forward pass of the 4D FNO encoder.""" # Input validation: single check for ndim and channels if not torch.compiler.is_compiling(): if x.ndim != 6 or x.shape[1] != self._input_channels: raise ValueError( f"Expected 6D input (B, {self._input_channels}, X, Y, Z, T), " f"got {x.ndim}D tensor with shape {tuple(x.shape)}" ) # Add coordinate features if enabled if self.coord_features: coord_feat = self._meshgrid(list(x.shape), x.device) x = torch.cat((x, coord_feat), dim=1) # Lift input to latent space x = self.lift_network(x) # Apply padding for spectral convolution x = F.pad( x, (0, self.pad[3], 0, self.pad[2], 0, self.pad[1], 0, self.pad[0]), mode=self.padding_type, ) # Apply spectral convolution layers for k, conv_w in enumerate(zip(self.conv_layers, self.spconv_layers)): conv, w = conv_w if k < len(self.conv_layers) - 1: x = self.activation_fn(conv(x) + w(x)) else: x = conv(x) + w(x) # Remove padding x = x[..., : self.ipad[0], : self.ipad[1], : self.ipad[2], : self.ipad[3]] return x
def _meshgrid(self, shape: List[int], device: torch.device) -> Tensor: r"""Create 4D meshgrid feature. Parameters ---------- shape : List[int] Tensor shape as ``[batch, channels, x, y, z, t]``. device : torch.device Device model is on. Returns ------- Tensor Meshgrid tensor of shape :math:`(B, 4, X, Y, Z, T)`. """ bsize, size_x, size_y, size_z, size_t = ( shape[0], shape[2], shape[3], shape[4], shape[5], ) grid_x = torch.linspace(0, 1, size_x, dtype=torch.float32, device=device) grid_y = torch.linspace(0, 1, size_y, dtype=torch.float32, device=device) grid_z = torch.linspace(0, 1, size_z, dtype=torch.float32, device=device) grid_t = torch.linspace(0, 1, size_t, dtype=torch.float32, device=device) grid_x, grid_y, grid_z, grid_t = torch.meshgrid( grid_x, grid_y, grid_z, grid_t, indexing="ij" ) grid_x = grid_x.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1, 1) grid_y = grid_y.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1, 1) grid_z = grid_z.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1, 1) grid_t = grid_t.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1, 1) return torch.cat((grid_x, grid_y, grid_z, grid_t), dim=1)
[docs] def grid_to_points(self, value: Tensor) -> Tuple[Tensor, List[int]]: r"""Convert from grid-based (image) to point-based representation. Parameters ---------- value : Tensor Grid tensor of shape :math:`(B, C, X, Y, Z, T)`. Returns ------- Tuple[Tensor, List[int]] Tuple of (flattened tensor, original shape). """ y_shape = list(value.size()) output = torch.permute(value, (0, 2, 3, 4, 5, 1)) return output.reshape(-1, output.size(-1)), y_shape
[docs] def points_to_grid(self, value: Tensor, shape: List[int]) -> Tensor: r"""Convert from point-based to grid-based (image) representation. Parameters ---------- value : Tensor Point tensor of shape :math:`(B \times X \times Y \times Z \times T, C)`. shape : List[int] Original grid shape as ``[B, C, X, Y, Z, T]``. Returns ------- Tensor Grid tensor of shape :math:`(B, C, X, Y, Z, T)`. """ output = value.reshape( shape[0], shape[2], shape[3], shape[4], shape[5], value.size(-1) ) return torch.permute(output, (0, 5, 1, 2, 3, 4))