# deeplearning/modulus/modulus-core-v030/_modules/modulus/models/fno/fno.html
# Source code for modulus.models.fno.fno
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
import modulus.models.layers as layers
import modulus
from typing import Dict, List, Union, Tuple
from torch import Tensor
from dataclasses import dataclass
from ..meta import ModelMetaData
from ..module import Module
from ..mlp import FullyConnected
# ===================================================================
# ===================================================================
# 1D FNO
# ===================================================================
# ===================================================================
class FNO1DEncoder(nn.Module):
    """1D Spectral encoder for FNO

    The input is optionally concatenated with a normalised coordinate channel,
    lifted to ``fno_layer_size`` channels by two ``Conv1dFCLayer`` layers,
    right-padded, passed through ``num_fno_layers`` blocks of
    ``Conv1d(kernel_size=1) + SpectralConv1d`` (activation after every block
    except the last) and finally cropped back to the original grid length.

    Parameters
    ----------
    in_channels : int, optional
        Number of input channels, by default 1
    num_fno_layers : int, optional
        Number of spectral convolutional layers, by default 4
    fno_layer_size : int, optional
        Latent features size in spectral convolutions, by default 32
    num_fno_modes : Union[int, List[int]], optional
        Number of Fourier modes kept in spectral convolutions, by default 16.
        When a list is given only its first entry is used.
    padding : Union[int, List[int]], optional
        Domain padding for spectral convolutions, by default 8. When a list or
        tuple is given only its first entry is used; an empty sequence means
        no padding.
    padding_type : str, optional
        Type of padding for spectral convolutions, by default "constant"
    activation_fn : nn.Module, optional
        Activation function, by default nn.GELU
    coord_features : bool, optional
        Use coordinate grid as additional feature map, by default True
    """

    def __init__(
        self,
        in_channels: int = 1,
        num_fno_layers: int = 4,
        fno_layer_size: int = 32,
        num_fno_modes: Union[int, List[int]] = 16,
        padding: Union[int, List[int]] = 8,
        padding_type: str = "constant",
        activation_fn: nn.Module = nn.GELU(),
        coord_features: bool = True,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.num_fno_layers = num_fno_layers
        self.fno_width = fno_layer_size
        self.coord_features = coord_features
        # Spectral modes to have weights
        if isinstance(num_fno_modes, int):
            num_fno_modes = [num_fno_modes]
        # Add relative coordinate feature (one extra channel, see meshgrid)
        if self.coord_features:
            self.in_channels = self.in_channels + 1
        self.activation_fn = activation_fn

        self.spconv_layers = nn.ModuleList()
        self.conv_layers = nn.ModuleList()

        # Initial lift network
        self.lift_network = torch.nn.Sequential()
        self.lift_network.append(
            layers.Conv1dFCLayer(self.in_channels, int(self.fno_width / 2))
        )
        self.lift_network.append(self.activation_fn)
        self.lift_network.append(
            layers.Conv1dFCLayer(int(self.fno_width / 2), self.fno_width)
        )

        # Build Neural Fourier Operators; spectral and 1x1 conv layers are
        # created interleaved, which fixes the parameter-initialisation order.
        for _ in range(self.num_fno_layers):
            self.spconv_layers.append(
                layers.SpectralConv1d(self.fno_width, self.fno_width, num_fno_modes[0])
            )
            self.conv_layers.append(nn.Conv1d(self.fno_width, self.fno_width, 1))

        # Padding values for spectral conv
        if isinstance(padding, int):
            padding = [padding]
        # Append a zero so an empty sequence means "no padding" instead of an
        # IndexError; list() also accepts tuples.
        padding = list(padding) + [0]
        self.pad = padding[:1]
        # Negative slice end used to crop the padding; None keeps the full axis
        self.ipad = [-pad if pad > 0 else None for pad in self.pad]
        self.padding_type = padding_type

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` of shape [batch, in_channels, grid_x] into
        [batch, fno_layer_size, grid_x] latent features.

        Raises
        ------
        AssertionError
            If ``x`` is not a 3D tensor.
        """
        assert (
            x.dim() == 3
        ), "Only 3D tensors [batch, in_channels, grid_x] accepted for 1D FNO"
        if self.coord_features:
            coord_feat = self.meshgrid(list(x.shape), x.device)
            x = torch.cat((x, coord_feat), dim=1)

        x = self.lift_network(x)
        # (left, right): only the right end of the grid is padded
        x = F.pad(x, (0, self.pad[0]), mode=self.padding_type)
        # Spectral layers; no activation after the last block
        for k, conv_w in enumerate(zip(self.conv_layers, self.spconv_layers)):
            conv, w = conv_w
            if k < len(self.conv_layers) - 1:
                x = self.activation_fn(conv(x) + w(x))
            else:
                x = conv(x) + w(x)

        # Crop the right padding added above
        x = x[..., : self.ipad[0]]
        return x

    def meshgrid(self, shape: List[int], device: torch.device) -> Tensor:
        """Creates 1D meshgrid feature

        Parameters
        ----------
        shape : List[int]
            Tensor shape; entries 0 (batch) and 2 (grid_x) are used
        device : torch.device
            Device model is on

        Returns
        -------
        Tensor
            float32 tensor of shape [batch, 1, grid_x] with values
            linearly spaced in [0, 1] along grid_x
        """
        bsize, size_x = shape[0], shape[2]
        grid_x = torch.linspace(0, 1, size_x, dtype=torch.float32, device=device)
        grid_x = grid_x.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1)
        return grid_x


# ===================================================================
# ===================================================================
# 2D FNO
# ===================================================================
# ===================================================================
class FNO2DEncoder(nn.Module):
    """2D Spectral encoder for FNO

    Optionally appends normalised (x, y) coordinate channels, lifts the result
    to ``fno_layer_size`` channels with two ``Conv2dFCLayer`` layers, pads the
    bottom/right of the grid, runs ``num_fno_layers`` blocks of
    ``Conv2d(kernel_size=1) + SpectralConv2d`` (activated except for the final
    block) and crops the padding off again.

    Parameters
    ----------
    in_channels : int, optional
        Number of input channels, by default 1
    num_fno_layers : int, optional
        Number of spectral convolutional layers, by default 4
    fno_layer_size : int, optional
        Latent features size in spectral convolutions, by default 32
    num_fno_modes : Union[int, List[int]], optional
        Number of Fourier modes kept in spectral convolutions, by default 16.
        An int is used for both axes; a list supplies [modes_x, modes_y].
    padding : Union[int, List[int]], optional
        Domain padding for spectral convolutions, by default 8. An int pads
        both axes; a shorter list leaves the missing axes unpadded.
    padding_type : str, optional
        Type of padding for spectral convolutions, by default "constant"
    activation_fn : nn.Module, optional
        Activation function, by default nn.GELU
    coord_features : bool, optional
        Use coordinate grid as additional feature map, by default True
    """

    def __init__(
        self,
        in_channels: int = 1,
        num_fno_layers: int = 4,
        fno_layer_size: int = 32,
        num_fno_modes: Union[int, List[int]] = 16,
        padding: Union[int, List[int]] = 8,
        padding_type: str = "constant",
        activation_fn: nn.Module = nn.GELU(),
        coord_features: bool = True,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.num_fno_layers = num_fno_layers
        self.fno_width = fno_layer_size
        self.coord_features = coord_features

        # A scalar mode count applies to both spatial axes
        if isinstance(num_fno_modes, int):
            num_fno_modes = [num_fno_modes] * 2
        # Two extra channels carry the (x, y) coordinate grids
        if self.coord_features:
            self.in_channels += 2
        self.activation_fn = activation_fn

        self.spconv_layers = nn.ModuleList()
        self.conv_layers = nn.ModuleList()

        # Lift: in_channels -> width/2 -> width
        half_width = int(self.fno_width / 2)
        self.lift_network = torch.nn.Sequential(
            layers.Conv2dFCLayer(self.in_channels, half_width),
            self.activation_fn,
            layers.Conv2dFCLayer(half_width, self.fno_width),
        )

        # Each Fourier block = spectral conv + 1x1 conv, built in that order
        for _ in range(self.num_fno_layers):
            spectral = layers.SpectralConv2d(
                self.fno_width, self.fno_width, num_fno_modes[0], num_fno_modes[1]
            )
            self.spconv_layers.append(spectral)
            pointwise = nn.Conv2d(self.fno_width, self.fno_width, 1)
            self.conv_layers.append(pointwise)

        # Per-axis padding, zero-filled when fewer than two values are given
        if isinstance(padding, int):
            padding = [padding] * 2
        self.pad = (padding + [0, 0])[:2]
        # Slice ends that undo the padding (None = axis was not padded)
        self.ipad = [None if amount <= 0 else -amount for amount in self.pad]
        self.padding_type = padding_type

    def forward(self, x: Tensor) -> Tensor:
        """Map [batch, in_channels, grid_x, grid_y] to
        [batch, fno_layer_size, grid_x, grid_y] latent features."""
        assert (
            x.dim() == 4
        ), "Only 4D tensors [batch, in_channels, grid_x, grid_y] accepted for 2D FNO"
        if self.coord_features:
            x = torch.cat((x, self.meshgrid(list(x.shape), x.device)), dim=1)

        x = self.lift_network(x)
        # F.pad lists the last axis first: (y_before, y_after, x_before, x_after)
        x = F.pad(x, (0, self.pad[1], 0, self.pad[0]), mode=self.padding_type)

        final_idx = len(self.conv_layers) - 1
        for idx, pair in enumerate(zip(self.conv_layers, self.spconv_layers)):
            conv, spectral = pair
            x = conv(x) + spectral(x)
            if idx < final_idx:
                x = self.activation_fn(x)

        # Strip the padding from the end of both grid axes
        return x[..., : self.ipad[0], : self.ipad[1]]

    def meshgrid(self, shape: List[int], device: torch.device) -> Tensor:
        """Build the 2D coordinate feature map.

        Parameters
        ----------
        shape : List[int]
            Input shape [batch, channels, grid_x, grid_y]
        device : torch.device
            Device the grid is created on

        Returns
        -------
        Tensor
            float32 tensor [batch, 2, grid_x, grid_y]; channel 0 holds x and
            channel 1 holds y, each linearly spaced in [0, 1]
        """
        batch, nx, ny = shape[0], shape[2], shape[3]
        xs = torch.linspace(0, 1, nx, dtype=torch.float32, device=device)
        ys = torch.linspace(0, 1, ny, dtype=torch.float32, device=device)
        mesh_x, mesh_y = torch.meshgrid(xs, ys, indexing="ij")
        coords = torch.stack((mesh_x, mesh_y), dim=0)
        return coords.unsqueeze(0).repeat(batch, 1, 1, 1)


# ===================================================================
# ===================================================================
# 3D FNO
# ===================================================================
# ===================================================================
class FNO3DEncoder(nn.Module):
    """3D Spectral encoder for FNO

    The input is optionally concatenated with three normalised coordinate
    channels, lifted to ``fno_layer_size`` channels by two ``Conv3dFCLayer``
    layers, padded at the end of each grid axis, passed through
    ``num_fno_layers`` blocks of ``Conv3d(kernel_size=1) + SpectralConv3d``
    (activation after every block except the last) and cropped back.

    Parameters
    ----------
    in_channels : int, optional
        Number of input channels, by default 1
    num_fno_layers : int, optional
        Number of spectral convolutional layers, by default 4
    fno_layer_size : int, optional
        Latent features size in spectral convolutions, by default 32
    num_fno_modes : Union[int, List[int]], optional
        Number of Fourier modes kept in spectral convolutions, by default 16.
        An int is used for all three axes; a list supplies one value per axis.
    padding : Union[int, List[int]], optional
        Domain padding for spectral convolutions, by default 8. An int pads all
        three axes; a shorter list/tuple leaves the missing axes unpadded.
    padding_type : str, optional
        Type of padding for spectral convolutions, by default "constant"
    activation_fn : nn.Module, optional
        Activation function, by default nn.GELU
    coord_features : bool, optional
        Use coordinate grid as additional feature map, by default True
    """

    def __init__(
        self,
        in_channels: int = 1,
        num_fno_layers: int = 4,
        fno_layer_size: int = 32,
        num_fno_modes: Union[int, List[int]] = 16,
        padding: Union[int, List[int]] = 8,
        padding_type: str = "constant",
        activation_fn: nn.Module = nn.GELU(),
        coord_features: bool = True,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.num_fno_layers = num_fno_layers
        self.fno_width = fno_layer_size
        self.coord_features = coord_features
        # Spectral modes to have weights
        if isinstance(num_fno_modes, int):
            num_fno_modes = [num_fno_modes, num_fno_modes, num_fno_modes]
        # Add relative coordinate feature (three extra channels, see meshgrid)
        if self.coord_features:
            self.in_channels = self.in_channels + 3
        self.activation_fn = activation_fn

        self.spconv_layers = nn.ModuleList()
        self.conv_layers = nn.ModuleList()

        # Initial lift network
        self.lift_network = torch.nn.Sequential()
        self.lift_network.append(
            layers.Conv3dFCLayer(self.in_channels, int(self.fno_width / 2))
        )
        self.lift_network.append(self.activation_fn)
        self.lift_network.append(
            layers.Conv3dFCLayer(int(self.fno_width / 2), self.fno_width)
        )

        # Build Neural Fourier Operators (spectral and 1x1 conv interleaved)
        for _ in range(self.num_fno_layers):
            self.spconv_layers.append(
                layers.SpectralConv3d(
                    self.fno_width,
                    self.fno_width,
                    num_fno_modes[0],
                    num_fno_modes[1],
                    num_fno_modes[2],
                )
            )
            self.conv_layers.append(nn.Conv3d(self.fno_width, self.fno_width, 1))

        # Padding values for spectral conv
        if isinstance(padding, int):
            padding = [padding, padding, padding]
        # Zero-fill shorter sequences; list() also accepts tuples, which
        # previously failed with TypeError on ``tuple + list``.
        padding = list(padding) + [0, 0, 0]
        self.pad = padding[:3]
        # Negative slice ends used to crop the padding; None keeps the full axis
        self.ipad = [-pad if pad > 0 else None for pad in self.pad]
        self.padding_type = padding_type

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` of shape [batch, in_channels, grid_x, grid_y, grid_z]
        into [batch, fno_layer_size, grid_x, grid_y, grid_z] latent features.

        Raises
        ------
        AssertionError
            If ``x`` is not a 5D tensor.
        """
        assert (
            x.dim() == 5
        ), "Only 5D tensors [batch, in_channels, grid_x, grid_y, grid_z] accepted for 3D FNO"
        if self.coord_features:
            coord_feat = self.meshgrid(list(x.shape), x.device)
            x = torch.cat((x, coord_feat), dim=1)

        x = self.lift_network(x)
        # F.pad lists the last axis first; only the end of each axis is padded:
        # (z_before, z_after, y_before, y_after, x_before, x_after)
        x = F.pad(
            x,
            (0, self.pad[2], 0, self.pad[1], 0, self.pad[0]),
            mode=self.padding_type,
        )
        # Spectral layers; no activation after the last block
        for k, conv_w in enumerate(zip(self.conv_layers, self.spconv_layers)):
            conv, w = conv_w
            if k < len(self.conv_layers) - 1:
                x = self.activation_fn(conv(x) + w(x))
            else:
                x = conv(x) + w(x)

        # Crop the padding added above
        x = x[..., : self.ipad[0], : self.ipad[1], : self.ipad[2]]
        return x

    def meshgrid(self, shape: List[int], device: torch.device) -> Tensor:
        """Creates 3D meshgrid feature

        Parameters
        ----------
        shape : List[int]
            Tensor shape; entries 0 (batch) and 2-4 (grid_x, grid_y, grid_z)
            are used
        device : torch.device
            Device model is on

        Returns
        -------
        Tensor
            float32 tensor [batch, 3, grid_x, grid_y, grid_z] whose channels
            hold the x, y and z coordinates, each linearly spaced in [0, 1]
        """
        bsize, size_x, size_y, size_z = shape[0], shape[2], shape[3], shape[4]
        grid_x = torch.linspace(0, 1, size_x, dtype=torch.float32, device=device)
        grid_y = torch.linspace(0, 1, size_y, dtype=torch.float32, device=device)
        grid_z = torch.linspace(0, 1, size_z, dtype=torch.float32, device=device)
        grid_x, grid_y, grid_z = torch.meshgrid(grid_x, grid_y, grid_z, indexing="ij")
        grid_x = grid_x.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1)
        grid_y = grid_y.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1)
        grid_z = grid_z.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1)
        return torch.cat((grid_x, grid_y, grid_z), dim=1)


# ===================================================================
# ===================================================================
# 4D FNO
# ===================================================================
# ===================================================================
class FNO4DEncoder(nn.Module):
    """4D Spectral encoder for FNO

    The input is optionally concatenated with four normalised coordinate
    channels, lifted to ``fno_layer_size`` channels by two ``ConvNdFCLayer``
    layers, padded at the end of each grid axis, passed through
    ``num_fno_layers`` blocks of ``ConvNdKernel1Layer + SpectralConv4d``
    (activation after every block except the last) and cropped back.

    Parameters
    ----------
    in_channels : int, optional
        Number of input channels, by default 1
    num_fno_layers : int, optional
        Number of spectral convolutional layers, by default 4
    fno_layer_size : int, optional
        Latent features size in spectral convolutions, by default 32
    num_fno_modes : Union[int, List[int]], optional
        Number of Fourier modes kept in spectral convolutions, by default 16.
        An int is used for all four axes; a list supplies one value per axis.
    padding : Union[int, List[int]], optional
        Domain padding for spectral convolutions, by default 8. An int pads all
        four axes; a shorter list/tuple leaves the missing axes unpadded.
    padding_type : str, optional
        Type of padding for spectral convolutions, by default "constant"
    activation_fn : nn.Module, optional
        Activation function, by default nn.GELU
    coord_features : bool, optional
        Use coordinate grid as additional feature map, by default True
    """

    def __init__(
        self,
        in_channels: int = 1,
        num_fno_layers: int = 4,
        fno_layer_size: int = 32,
        num_fno_modes: Union[int, List[int]] = 16,
        padding: Union[int, List[int]] = 8,
        padding_type: str = "constant",
        activation_fn: nn.Module = nn.GELU(),
        coord_features: bool = True,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.num_fno_layers = num_fno_layers
        self.fno_width = fno_layer_size
        self.coord_features = coord_features
        # Spectral modes to have weights
        if isinstance(num_fno_modes, int):
            num_fno_modes = [num_fno_modes, num_fno_modes, num_fno_modes, num_fno_modes]
        # Add relative coordinate feature (four extra channels, see meshgrid)
        if self.coord_features:
            self.in_channels = self.in_channels + 4
        self.activation_fn = activation_fn

        self.spconv_layers = nn.ModuleList()
        self.conv_layers = nn.ModuleList()

        # Initial lift network
        self.lift_network = torch.nn.Sequential()
        self.lift_network.append(
            layers.ConvNdFCLayer(self.in_channels, int(self.fno_width / 2))
        )
        self.lift_network.append(self.activation_fn)
        self.lift_network.append(
            layers.ConvNdFCLayer(int(self.fno_width / 2), self.fno_width)
        )

        # Build Neural Fourier Operators (spectral and kernel-1 conv interleaved)
        for _ in range(self.num_fno_layers):
            self.spconv_layers.append(
                layers.SpectralConv4d(
                    self.fno_width,
                    self.fno_width,
                    num_fno_modes[0],
                    num_fno_modes[1],
                    num_fno_modes[2],
                    num_fno_modes[3],
                )
            )
            self.conv_layers.append(
                layers.ConvNdKernel1Layer(self.fno_width, self.fno_width)
            )

        # Padding values for spectral conv
        if isinstance(padding, int):
            padding = [padding, padding, padding, padding]
        # Zero-fill shorter sequences; list() also accepts tuples, which
        # previously failed with TypeError on ``tuple + list``.
        padding = list(padding) + [0, 0, 0, 0]
        self.pad = padding[:4]
        # Negative slice ends used to crop the padding; None keeps the full axis
        self.ipad = [-pad if pad > 0 else None for pad in self.pad]
        self.padding_type = padding_type

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` of shape [batch, in_channels, grid_x, grid_y, grid_z,
        grid_t] into [batch, fno_layer_size, grid_x, grid_y, grid_z, grid_t].

        Raises
        ------
        AssertionError
            If ``x`` is not a 6D tensor.
        """
        assert (
            x.dim() == 6
        ), "Only 6D tensors [batch, in_channels, grid_x, grid_y, grid_z, grid_t] accepted for 4D FNO"
        if self.coord_features:
            coord_feat = self.meshgrid(list(x.shape), x.device)
            x = torch.cat((x, coord_feat), dim=1)

        x = self.lift_network(x)
        # F.pad lists the last axis first; only the end of each axis is padded:
        # (t_before, t_after, z_before, z_after, y_before, y_after, x_before, x_after)
        x = F.pad(
            x,
            (0, self.pad[3], 0, self.pad[2], 0, self.pad[1], 0, self.pad[0]),
            mode=self.padding_type,
        )
        # Spectral layers; no activation after the last block
        for k, conv_w in enumerate(zip(self.conv_layers, self.spconv_layers)):
            conv, w = conv_w
            if k < len(self.conv_layers) - 1:
                x = self.activation_fn(conv(x) + w(x))
            else:
                x = conv(x) + w(x)

        # Crop the padding added above
        x = x[..., : self.ipad[0], : self.ipad[1], : self.ipad[2], : self.ipad[3]]
        return x

    def meshgrid(self, shape: List[int], device: torch.device) -> Tensor:
        """Creates 4D meshgrid feature

        Parameters
        ----------
        shape : List[int]
            Tensor shape; entries 0 (batch) and 2-5 (grid_x, grid_y, grid_z,
            grid_t) are used
        device : torch.device
            Device model is on

        Returns
        -------
        Tensor
            float32 tensor [batch, 4, grid_x, grid_y, grid_z, grid_t] whose
            channels hold the x, y, z and t coordinates, each linearly spaced
            in [0, 1]
        """
        bsize, size_x, size_y, size_z, size_t = (
            shape[0],
            shape[2],
            shape[3],
            shape[4],
            shape[5],
        )
        grid_x = torch.linspace(0, 1, size_x, dtype=torch.float32, device=device)
        grid_y = torch.linspace(0, 1, size_y, dtype=torch.float32, device=device)
        grid_z = torch.linspace(0, 1, size_z, dtype=torch.float32, device=device)
        grid_t = torch.linspace(0, 1, size_t, dtype=torch.float32, device=device)
        grid_x, grid_y, grid_z, grid_t = torch.meshgrid(
            grid_x, grid_y, grid_z, grid_t, indexing="ij"
        )
        grid_x = grid_x.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1, 1)
        grid_y = grid_y.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1, 1)
        grid_z = grid_z.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1, 1)
        grid_t = grid_t.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1, 1)
        return torch.cat((grid_x, grid_y, grid_z, grid_t), dim=1)


# Functions for converting between point based and grid (image) representations
def _grid_to_points1d(value: Tensor) -> Tuple[Tensor, List[int]]:
    """Flatten a [batch, C, x] grid into [batch * x, C] rows.

    Also returns the original shape so ``_points_to_grid1d`` can undo it.
    """
    grid_shape = list(value.size())
    channels_last = value.permute(0, 2, 1)
    num_channels = channels_last.size(-1)
    return channels_last.reshape(-1, num_channels), grid_shape
def _points_to_grid1d(value: Tensor, shape: List[int]) -> Tensor:
    """Inverse of ``_grid_to_points1d``: [batch * x, C] rows back to [batch, C, x].

    Only ``shape[0]`` (batch) and ``shape[2]`` (x) are read; the channel count
    comes from ``value`` itself.
    """
    batch, size_x = shape[0], shape[2]
    channels_last = value.reshape(batch, size_x, value.size(-1))
    return channels_last.permute(0, 2, 1)
def _grid_to_points2d(value: Tensor) -> Tuple[Tensor, List[int]]:
    """Flatten a [batch, C, x, y] grid into [batch * x * y, C] rows.

    Also returns the original shape so ``_points_to_grid2d`` can undo it.
    """
    grid_shape = list(value.size())
    channels_last = value.permute(0, 2, 3, 1)
    num_channels = channels_last.size(-1)
    return channels_last.reshape(-1, num_channels), grid_shape
def _points_to_grid2d(value: Tensor, shape: List[int]) -> Tensor:
    """Inverse of ``_grid_to_points2d``: rows back to [batch, C, x, y].

    Only ``shape[0]``, ``shape[2]`` and ``shape[3]`` are read; the channel
    count comes from ``value``.
    """
    batch, size_x, size_y = shape[0], shape[2], shape[3]
    channels_last = value.reshape(batch, size_x, size_y, value.size(-1))
    return channels_last.permute(0, 3, 1, 2)
def _grid_to_points3d(value: Tensor) -> Tuple[Tensor, List[int]]:
    """Flatten a [batch, C, x, y, z] grid into [batch * x * y * z, C] rows.

    Also returns the original shape so ``_points_to_grid3d`` can undo it.
    """
    grid_shape = list(value.size())
    channels_last = value.permute(0, 2, 3, 4, 1)
    num_channels = channels_last.size(-1)
    return channels_last.reshape(-1, num_channels), grid_shape
def _points_to_grid3d(value: Tensor, shape: List[int]) -> Tensor:
    """Inverse of ``_grid_to_points3d``: rows back to [batch, C, x, y, z].

    Only ``shape[0]`` and ``shape[2:5]`` are read; the channel count comes
    from ``value``.
    """
    batch, size_x, size_y, size_z = shape[0], shape[2], shape[3], shape[4]
    channels_last = value.reshape(batch, size_x, size_y, size_z, value.size(-1))
    return channels_last.permute(0, 4, 1, 2, 3)
def _grid_to_points4d(value: Tensor) -> Tuple[Tensor, List[int]]:
    """Flatten a [batch, C, x, y, z, t] grid into [batch * x * y * z * t, C] rows.

    Also returns the original shape so ``_points_to_grid4d`` can undo it.
    """
    grid_shape = list(value.size())
    channels_last = value.permute(0, 2, 3, 4, 5, 1)
    num_channels = channels_last.size(-1)
    return channels_last.reshape(-1, num_channels), grid_shape
def _points_to_grid4d(value: Tensor, shape: List[int]) -> Tensor:
    """Inverse of ``_grid_to_points4d``: rows back to [batch, C, x, y, z, t].

    Only ``shape[0]`` and ``shape[2:6]`` are read; the channel count comes
    from ``value``.
    """
    batch = shape[0]
    size_x, size_y, size_z, size_t = shape[2], shape[3], shape[4], shape[5]
    channels_last = value.reshape(
        batch, size_x, size_y, size_z, size_t, value.size(-1)
    )
    return channels_last.permute(0, 5, 1, 2, 3, 4)
# ===================================================================
# ===================================================================
# General FNO Model
# ===================================================================
# ===================================================================
[docs]@dataclass
class MetaData(ModelMetaData):
    """Model metadata for the FNO model (instantiated as ``meta`` in ``FNO.__init__``).

    The meaning of each field is defined by ``ModelMetaData`` (not shown in
    this file); the values below are the FNO-specific defaults.
    """

    name: str = "FourierNeuralOperator"
    # Optimization
    jit: bool = True
    cuda_graphs: bool = True
    amp: bool = False
    # Inference
    onnx_cpu: bool = False
    onnx_gpu: bool = False
    onnx_runtime: bool = False
    # Physics informed
    # NOTE(review): presumably the index of the variable/channel axis. The
    # encoders in this file take [batch, channels, *grid] inputs, which is
    # consistent with 1 — confirm against ModelMetaData.
    var_dim: int = 1
    func_torch: bool = False
    auto_grad: bool = False
class FNO(Module):
    """Fourier neural operator (FNO) model.

    Note
    ----
    The FNO architecture supports options for 1D, 2D, 3D and 4D fields which can
    be controlled using the `dimension` parameter.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    decoder_layers : int, optional
        Number of decoder layers, by default 1
    decoder_layer_size : int, optional
        Number of neurons in decoder layers, by default 32
    decoder_activation_fn : str, optional
        Activation function for decoder, by default "silu"
    dimension : int, optional
        Model dimensionality (supports 1, 2, 3 and 4), by default 2
    latent_channels : int, optional
        Latent features size in spectral convolutions, by default 32
    num_fno_layers : int, optional
        Number of spectral convolutional layers, by default 4
    num_fno_modes : Union[int, List[int]], optional
        Number of Fourier modes kept in spectral convolutions, by default 16
    padding : Union[int, List[int]], optional
        Domain padding for spectral convolutions, by default 8. An int pads
        every grid axis; a list gives per-axis padding.
    padding_type : str, optional
        Type of padding for spectral convolutions, by default "constant"
    activation_fn : str, optional
        Activation function, by default "gelu"
    coord_features : bool, optional
        Use coordinate grid as additional feature map, by default True

    Raises
    ------
    NotImplementedError
        If ``dimension`` is not 1, 2, 3 or 4.

    Example
    -------
    >>> # define the 2d FNO model
    >>> model = modulus.models.fno.FNO(
    ...     in_channels=4,
    ...     out_channels=3,
    ...     decoder_layers=2,
    ...     decoder_layer_size=32,
    ...     dimension=2,
    ...     latent_channels=32,
    ...     num_fno_layers=2,
    ...     padding=0,
    ... )
    >>> input = torch.randn(32, 4, 32, 32) #(N, C, H, W)
    >>> output = model(input)
    >>> output.size()
    torch.Size([32, 3, 32, 32])

    Note
    ----
    Reference: Li, Zongyi, et al. "Fourier neural operator for parametric
    partial differential equations." arXiv preprint arXiv:2010.08895 (2020).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        decoder_layers: int = 1,
        decoder_layer_size: int = 32,
        decoder_activation_fn: str = "silu",
        dimension: int = 2,
        latent_channels: int = 32,
        num_fno_layers: int = 4,
        num_fno_modes: Union[int, List[int]] = 16,
        padding: Union[int, List[int]] = 8,
        padding_type: str = "constant",
        activation_fn: str = "gelu",
        coord_features: bool = True,
    ) -> None:
        super().__init__(meta=MetaData())
        self.num_fno_layers = num_fno_layers
        self.num_fno_modes = num_fno_modes
        self.padding = padding
        self.padding_type = padding_type
        self.activation_fn = layers.get_activation(activation_fn)
        self.coord_features = coord_features

        # Pointwise decoder applied to every grid point of the latent field.
        # Built before the encoder; keep this order so parameter
        # initialisation draws from the RNG in the same sequence.
        self.decoder_net = FullyConnected(
            in_features=latent_channels,
            layer_size=decoder_layer_size,
            out_features=out_channels,
            num_layers=decoder_layers,
            activation_fn=decoder_activation_fn,
        )

        # Select the encoder and the matching grid <-> points converters.
        # The converters are stored as attributes (rather than chosen in
        # forward) so the forward pass has no dimension branching (for JIT).
        if dimension == 1:
            FNOModel = FNO1DEncoder
            self.grid_to_points = _grid_to_points1d  # For JIT
            self.points_to_grid = _points_to_grid1d  # For JIT
        elif dimension == 2:
            FNOModel = FNO2DEncoder
            self.grid_to_points = _grid_to_points2d  # For JIT
            self.points_to_grid = _points_to_grid2d  # For JIT
        elif dimension == 3:
            FNOModel = FNO3DEncoder
            self.grid_to_points = _grid_to_points3d  # For JIT
            self.points_to_grid = _points_to_grid3d  # For JIT
        elif dimension == 4:
            FNOModel = FNO4DEncoder
            self.grid_to_points = _grid_to_points4d  # For JIT
            self.points_to_grid = _points_to_grid4d  # For JIT
        else:
            raise NotImplementedError(
                "Invalid dimensionality. Only 1D, 2D, 3D and 4D FNO implemented"
            )

        self.spec_encoder = FNOModel(
            in_channels,
            num_fno_layers=self.num_fno_layers,
            fno_layer_size=latent_channels,
            num_fno_modes=self.num_fno_modes,
            padding=self.padding,
            padding_type=self.padding_type,
            activation_fn=self.activation_fn,
            coord_features=self.coord_features,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Run the spectral encoder, then decode every grid point.

        Parameters
        ----------
        x : Tensor
            Input field of shape [batch, in_channels, *grid], where ``grid``
            has ``dimension`` axes.

        Returns
        -------
        Tensor
            Output field of shape [batch, out_channels, *grid].
        """
        # Fourier encoder
        y_latent = self.spec_encoder(x)

        # Flatten the latent grid to one row per grid point for the decoder;
        # y_shape records the grid layout needed to restore it afterwards.
        y_latent, y_shape = self.grid_to_points(y_latent)

        # Decoder
        y = self.decoder_net(y_latent)

        # Convert back into grid
        y = self.points_to_grid(y, y_shape)
        return y