# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Complex, Float
from torch import Tensor
class SpectralConv1d(nn.Module):
    """1D Fourier layer. It does FFT, linear transform, and Inverse FFT.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    modes1 : int
        Number of Fourier modes to multiply, at most floor(N/2) + 1
    """

    def __init__(self, in_channels: int, out_channels: int, modes1: int):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes1 = modes1
        self.scale = 1 / (in_channels * out_channels)
        # Complex weights stored as trailing (real, imag) pairs; compl_mul1d
        # reinterprets them with torch.view_as_complex.
        self.weights1 = nn.Parameter(
            torch.empty(in_channels, out_channels, self.modes1, 2)
        )
        self.reset_parameters()

    def compl_mul1d(
        self,
        input: Complex[Tensor, "batch in_channels modes"],
        weights: Float[Tensor, "in_channels out_channels modes 2"],
    ) -> Complex[Tensor, "batch out_channels modes"]:
        """Complex multiplication

        Parameters
        ----------
        input : Tensor
            Input tensor
        weights : Tensor
            Weights tensor

        Returns
        -------
        Tensor
            Product of complex multiplication
        """
        # (batch, in_channels, modes), (in_channels, out_channels, modes)
        # -> (batch, out_channels, modes)
        cweights = torch.view_as_complex(weights)
        return torch.einsum("bix,iox->box", input, cweights)

    def forward(
        self, x: Float[Tensor, "batch in_channels x"]
    ) -> Float[Tensor, "batch out_channels x"]:
        """Apply the 1D spectral convolution.

        Parameters
        ----------
        x : Tensor
            Real input of shape (batch, in_channels, x)

        Returns
        -------
        Tensor
            Real output of shape (batch, out_channels, x); the spatial length
            of the input is preserved.
        """
        n = x.size(-1)
        x_ft = torch.fft.rfft(x)  # (batch, in_channels, n//2+1) complex
        # Multiply the lowest `modes1` Fourier modes and zero-pad the product up
        # to the full rfft length. Building the spectrum by padding (as
        # SpectralConv2d does) rather than writing into a preallocated cfloat
        # buffer keeps the product's dtype (a double-precision module now yields
        # a double result instead of being downcast to complex64) and avoids
        # slice assignment, which is awkward for torch.compile.
        out_ft = F.pad(
            self.compl_mul1d(x_ft[:, :, : self.modes1].contiguous(), self.weights1),
            (0, x_ft.size(-1) - self.modes1),
        )  # (batch, out_channels, n//2+1) complex
        # Return to physical space
        return torch.fft.irfft(out_ft, n=n)  # (batch, out_channels, n) real

    def reset_parameters(self):
        """Reset spectral weights with distribution scale*U(0,1)

        The values are written in place and drawn on the parameter's own device
        and dtype, so re-initializing after ``.to(...)`` or ``.double()`` does
        not move the weights back to CPU / float32.
        """
        with torch.no_grad():
            self.weights1.copy_(self.scale * torch.rand_like(self.weights1))
class SpectralConv2d(nn.Module):
    """2D Fourier layer. It does FFT, linear transform, and Inverse FFT.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    modes1 : int
        Number of Fourier modes to multiply in first dimension, at most floor(N/2) + 1
    modes2 : int
        Number of Fourier modes to multiply in second dimension, at most floor(N/2) + 1
    """

    def __init__(self, in_channels: int, out_channels: int, modes1: int, modes2: int):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes1 = modes1
        self.modes2 = modes2
        self.scale = 1 / (in_channels * out_channels)
        # Complex weights stored as trailing (real, imag) pairs; compl_mul2d
        # reinterprets them with torch.view_as_complex.
        weight_shape = (in_channels, out_channels, self.modes1, self.modes2, 2)
        # weights1 acts on the low rows of the spectrum, weights2 on the high
        # (negative-frequency) rows; see forward.
        self.weights1 = nn.Parameter(torch.empty(weight_shape))
        self.weights2 = nn.Parameter(torch.empty(weight_shape))
        self.reset_parameters()

    def compl_mul2d(
        self,
        input: Complex[Tensor, "batch in_channels modes1 modes2"],
        weights: Float[Tensor, "in_channels out_channels modes1 modes2 2"],
    ) -> Complex[Tensor, "batch out_channels modes1 modes2"]:
        """Complex multiplication

        Parameters
        ----------
        input : Tensor
            Input tensor
        weights : Tensor
            Weights tensor

        Returns
        -------
        Tensor
            Product of complex multiplication
        """
        # (batch, in_channels, modes1, modes2), (in_channels, out_channels, modes1, modes2)
        # -> (batch, out_channels, modes1, modes2)
        cweights = torch.view_as_complex(weights)
        return torch.einsum("bixy,ioxy->boxy", input, cweights)

    def forward(
        self, x: Float[Tensor, "batch in_channels h w"]
    ) -> Float[Tensor, "batch out_channels h w"]:
        """Apply the 2D spectral convolution.

        Parameters
        ----------
        x : Tensor
            Real input of shape (batch, in_channels, h, w)

        Returns
        -------
        Tensor
            Real output of shape (batch, out_channels, h, w)
        """
        x_ft = torch.fft.rfft2(x)  # (batch, in_channels, h, w//2+1) complex
        h, w = x_ft.size(-2), x_ft.size(-1)  # h=h, w=w//2+1
        # Initialize output in frequency space
        out_ft = torch.zeros(
            x.size(0), self.out_channels, h, w, dtype=torch.cfloat, device=x.device
        )  # (batch, out_channels, h, w) complex
        # Accumulate Fourier modes. Use .contiguous() on sliced complex tensors and
        # padding (not slice assignment) for torch.compile compatibility.
        # Slice assignment causes gradient stride issues in the Inductor backward pass.
        # Pad format: (left, right, top, bottom) for last 2 dims
        # Low rows [:modes1, :modes2]
        out_ft = out_ft + F.pad(
            self.compl_mul2d(
                x_ft[:, :, : self.modes1, : self.modes2].contiguous(), self.weights1
            ),
            (0, w - self.modes2, 0, h - self.modes1),
        )
        # High rows [-modes1:, :modes2]
        out_ft = out_ft + F.pad(
            self.compl_mul2d(
                x_ft[:, :, -self.modes1 :, : self.modes2].contiguous(), self.weights2
            ),
            (0, w - self.modes2, h - self.modes1, 0),
        )
        # Return to physical space
        return torch.fft.irfft2(
            out_ft, s=(x.size(-2), x.size(-1))
        )  # (batch, out_channels, h, w) real

    def reset_parameters(self):
        """Reset spectral weights with distribution scale*U(0,1)

        The values are written in place and drawn on each parameter's own
        device and dtype, so re-initializing after ``.to(...)`` or ``.double()``
        does not move the weights back to CPU / float32.
        """
        with torch.no_grad():
            for weights in (self.weights1, self.weights2):
                weights.copy_(self.scale * torch.rand_like(weights))
class SpectralConv3d(nn.Module):
    """3D Fourier layer. It does FFT, linear transform, and Inverse FFT.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    modes1 : int
        Number of Fourier modes to multiply in first dimension, at most floor(N/2) + 1
    modes2 : int
        Number of Fourier modes to multiply in second dimension, at most floor(N/2) + 1
    modes3 : int
        Number of Fourier modes to multiply in third dimension, at most floor(N/2) + 1
    """

    def __init__(
        self, in_channels: int, out_channels: int, modes1: int, modes2: int, modes3: int
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3
        self.scale = 1 / (in_channels * out_channels)
        # Complex weights stored as trailing (real, imag) pairs; compl_mul3d
        # reinterprets them with torch.view_as_complex. One weight tensor per
        # low/high combination of the first two dims (see forward).
        weight_shape = (
            in_channels,
            out_channels,
            self.modes1,
            self.modes2,
            self.modes3,
            2,
        )
        self.weights1 = nn.Parameter(torch.empty(weight_shape))
        self.weights2 = nn.Parameter(torch.empty(weight_shape))
        self.weights3 = nn.Parameter(torch.empty(weight_shape))
        self.weights4 = nn.Parameter(torch.empty(weight_shape))
        self.reset_parameters()

    def compl_mul3d(
        self,
        input: Complex[Tensor, "batch in_channels modes1 modes2 modes3"],
        weights: Float[Tensor, "in_channels out_channels modes1 modes2 modes3 2"],
    ) -> Complex[Tensor, "batch out_channels modes1 modes2 modes3"]:
        """Complex multiplication

        Parameters
        ----------
        input : Tensor
            Input tensor
        weights : Tensor
            Weights tensor

        Returns
        -------
        Tensor
            Product of complex multiplication
        """
        # (batch, in_channels, modes1, modes2, modes3),
        # (in_channels, out_channels, modes1, modes2, modes3)
        # -> (batch, out_channels, modes1, modes2, modes3)
        cweights = torch.view_as_complex(weights)
        return torch.einsum("bixyz,ioxyz->boxyz", input, cweights)

    def forward(
        self, x: Float[Tensor, "batch in_channels d1 d2 d3"]
    ) -> Float[Tensor, "batch out_channels d1 d2 d3"]:
        """Apply the 3D spectral convolution.

        Parameters
        ----------
        x : Tensor
            Real input of shape (batch, in_channels, d1, d2, d3)

        Returns
        -------
        Tensor
            Real output of shape (batch, out_channels, d1, d2, d3)
        """
        x_ft = torch.fft.rfftn(
            x, dim=[-3, -2, -1]
        )  # (batch, in_channels, d1, d2, d3//2+1) complex
        d1, d2, d3 = x_ft.size(-3), x_ft.size(-2), x_ft.size(-1)  # d3 = d3//2+1
        # Initialize output in frequency space
        out_ft = torch.zeros(
            x.size(0),
            self.out_channels,
            d1,
            d2,
            d3,
            dtype=torch.cfloat,
            device=x.device,
        )  # (batch, out_channels, d1, d2, d3) complex
        # Accumulate Fourier modes. Use .contiguous() on sliced complex tensors and
        # padding (not slice assignment) for torch.compile compatibility.
        # Slice assignment causes gradient stride issues in the Inductor backward pass.
        # Pad format for 3D: (d3_left, d3_right, d2_top, d2_bottom, d1_front, d1_back)
        pad_d3 = d3 - self.modes3
        pad_d2 = d2 - self.modes2
        pad_d1 = d1 - self.modes1
        # [:modes1, :modes2, :modes3]
        out_ft = out_ft + F.pad(
            self.compl_mul3d(
                x_ft[:, :, : self.modes1, : self.modes2, : self.modes3].contiguous(),
                self.weights1,
            ),
            (0, pad_d3, 0, pad_d2, 0, pad_d1),
        )
        # [-modes1:, :modes2, :modes3]
        out_ft = out_ft + F.pad(
            self.compl_mul3d(
                x_ft[:, :, -self.modes1 :, : self.modes2, : self.modes3].contiguous(),
                self.weights2,
            ),
            (0, pad_d3, 0, pad_d2, pad_d1, 0),
        )
        # [:modes1, -modes2:, :modes3]
        out_ft = out_ft + F.pad(
            self.compl_mul3d(
                x_ft[:, :, : self.modes1, -self.modes2 :, : self.modes3].contiguous(),
                self.weights3,
            ),
            (0, pad_d3, pad_d2, 0, 0, pad_d1),
        )
        # [-modes1:, -modes2:, :modes3]
        out_ft = out_ft + F.pad(
            self.compl_mul3d(
                x_ft[:, :, -self.modes1 :, -self.modes2 :, : self.modes3].contiguous(),
                self.weights4,
            ),
            (0, pad_d3, pad_d2, 0, pad_d1, 0),
        )
        # Return to physical space
        return torch.fft.irfftn(
            out_ft, s=(x.size(-3), x.size(-2), x.size(-1))
        )  # (batch, out_channels, d1, d2, d3) real

    def reset_parameters(self):
        """Reset spectral weights with distribution scale*U(0,1)

        The values are written in place and drawn on each parameter's own
        device and dtype, so re-initializing after ``.to(...)`` or ``.double()``
        does not move the weights back to CPU / float32.
        """
        with torch.no_grad():
            for weights in (self.weights1, self.weights2, self.weights3, self.weights4):
                weights.copy_(self.scale * torch.rand_like(weights))
class SpectralConv4d(nn.Module):
    """4D Fourier layer. It does FFT, linear transform, and Inverse FFT.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    modes1 : int
        Number of Fourier modes to multiply in first dimension, at most floor(N/2) + 1
    modes2 : int
        Number of Fourier modes to multiply in second dimension, at most floor(N/2) + 1
    modes3 : int
        Number of Fourier modes to multiply in third dimension, at most floor(N/2) + 1
    modes4 : int
        Number of Fourier modes to multiply in fourth dimension, at most floor(N/2) + 1
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        modes1: int,
        modes2: int,
        modes3: int,
        modes4: int,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3
        self.modes4 = modes4
        self.scale = 1 / (in_channels * out_channels)
        # Complex weights stored as trailing (real, imag) pairs; compl_mul4d
        # reinterprets them with torch.view_as_complex. One weight tensor per
        # low/high combination of the first three dims (see forward).
        weight_shape = (
            in_channels,
            out_channels,
            self.modes1,
            self.modes2,
            self.modes3,
            self.modes4,
            2,
        )
        self.weights1 = nn.Parameter(torch.empty(weight_shape))
        self.weights2 = nn.Parameter(torch.empty(weight_shape))
        self.weights3 = nn.Parameter(torch.empty(weight_shape))
        self.weights4 = nn.Parameter(torch.empty(weight_shape))
        self.weights5 = nn.Parameter(torch.empty(weight_shape))
        self.weights6 = nn.Parameter(torch.empty(weight_shape))
        self.weights7 = nn.Parameter(torch.empty(weight_shape))
        self.weights8 = nn.Parameter(torch.empty(weight_shape))
        self.reset_parameters()

    def compl_mul4d(
        self,
        input: Complex[Tensor, "batch in_channels modes1 modes2 modes3 modes4"],
        weights: Float[
            Tensor, "in_channels out_channels modes1 modes2 modes3 modes4 2"
        ],
    ) -> Complex[Tensor, "batch out_channels modes1 modes2 modes3 modes4"]:
        """Complex multiplication

        Parameters
        ----------
        input : Tensor
            Input tensor
        weights : Tensor
            Weights tensor

        Returns
        -------
        Tensor
            Product of complex multiplication
        """
        # (batch, in_channels, modes1, modes2, modes3, modes4),
        # (in_channels, out_channels, modes1, modes2, modes3, modes4)
        # -> (batch, out_channels, modes1, modes2, modes3, modes4)
        cweights = torch.view_as_complex(weights)
        return torch.einsum("bixyzt,ioxyzt->boxyzt", input, cweights)

    def forward(
        self, x: Float[Tensor, "batch in_channels d1 d2 d3 d4"]
    ) -> Float[Tensor, "batch out_channels d1 d2 d3 d4"]:
        """Apply the 4D spectral convolution.

        Parameters
        ----------
        x : Tensor
            Real input of shape (batch, in_channels, d1, d2, d3, d4)

        Returns
        -------
        Tensor
            Real output of shape (batch, out_channels, d1, d2, d3, d4)
        """
        x_ft = torch.fft.rfftn(
            x, dim=[-4, -3, -2, -1]
        )  # (batch, in_channels, d1, d2, d3, d4//2+1) complex
        d1, d2, d3, d4 = (
            x_ft.size(-4),
            x_ft.size(-3),
            x_ft.size(-2),
            x_ft.size(-1),
        )  # d4 = d4//2+1
        # Initialize output in frequency space
        out_ft = torch.zeros(
            x.size(0),
            self.out_channels,
            d1,
            d2,
            d3,
            d4,
            dtype=torch.cfloat,
            device=x.device,
        )  # (batch, out_channels, d1, d2, d3, d4) complex
        # Accumulate Fourier modes. Use .contiguous() on sliced complex tensors and
        # padding (not slice assignment) for torch.compile compatibility.
        # Slice assignment causes gradient stride issues in the Inductor backward pass.
        # Pad format for 4D: (d4_left, d4_right, d3_front, d3_back, d2_top, d2_bottom, d1_near, d1_far)
        pad_d4 = d4 - self.modes4
        pad_d3 = d3 - self.modes3
        pad_d2 = d2 - self.modes2
        pad_d1 = d1 - self.modes1
        # [:modes1, :modes2, :modes3, :modes4]
        out_ft = out_ft + F.pad(
            self.compl_mul4d(
                x_ft[
                    :, :, : self.modes1, : self.modes2, : self.modes3, : self.modes4
                ].contiguous(),
                self.weights1,
            ),
            (0, pad_d4, 0, pad_d3, 0, pad_d2, 0, pad_d1),
        )
        # [-modes1:, :modes2, :modes3, :modes4]
        out_ft = out_ft + F.pad(
            self.compl_mul4d(
                x_ft[
                    :, :, -self.modes1 :, : self.modes2, : self.modes3, : self.modes4
                ].contiguous(),
                self.weights2,
            ),
            (0, pad_d4, 0, pad_d3, 0, pad_d2, pad_d1, 0),
        )
        # [:modes1, -modes2:, :modes3, :modes4]
        out_ft = out_ft + F.pad(
            self.compl_mul4d(
                x_ft[
                    :, :, : self.modes1, -self.modes2 :, : self.modes3, : self.modes4
                ].contiguous(),
                self.weights3,
            ),
            (0, pad_d4, 0, pad_d3, pad_d2, 0, 0, pad_d1),
        )
        # [:modes1, :modes2, -modes3:, :modes4]
        out_ft = out_ft + F.pad(
            self.compl_mul4d(
                x_ft[
                    :, :, : self.modes1, : self.modes2, -self.modes3 :, : self.modes4
                ].contiguous(),
                self.weights4,
            ),
            (0, pad_d4, pad_d3, 0, 0, pad_d2, 0, pad_d1),
        )
        # [-modes1:, -modes2:, :modes3, :modes4]
        out_ft = out_ft + F.pad(
            self.compl_mul4d(
                x_ft[
                    :, :, -self.modes1 :, -self.modes2 :, : self.modes3, : self.modes4
                ].contiguous(),
                self.weights5,
            ),
            (0, pad_d4, 0, pad_d3, pad_d2, 0, pad_d1, 0),
        )
        # [-modes1:, :modes2, -modes3:, :modes4]
        out_ft = out_ft + F.pad(
            self.compl_mul4d(
                x_ft[
                    :, :, -self.modes1 :, : self.modes2, -self.modes3 :, : self.modes4
                ].contiguous(),
                self.weights6,
            ),
            (0, pad_d4, pad_d3, 0, 0, pad_d2, pad_d1, 0),
        )
        # [:modes1, -modes2:, -modes3:, :modes4]
        out_ft = out_ft + F.pad(
            self.compl_mul4d(
                x_ft[
                    :, :, : self.modes1, -self.modes2 :, -self.modes3 :, : self.modes4
                ].contiguous(),
                self.weights7,
            ),
            (0, pad_d4, pad_d3, 0, pad_d2, 0, 0, pad_d1),
        )
        # [-modes1:, -modes2:, -modes3:, :modes4]
        out_ft = out_ft + F.pad(
            self.compl_mul4d(
                x_ft[
                    :, :, -self.modes1 :, -self.modes2 :, -self.modes3 :, : self.modes4
                ].contiguous(),
                self.weights8,
            ),
            (0, pad_d4, pad_d3, 0, pad_d2, 0, pad_d1, 0),
        )
        # Return to physical space
        return torch.fft.irfftn(
            out_ft, s=(x.size(-4), x.size(-3), x.size(-2), x.size(-1))
        )  # (batch, out_channels, d1, d2, d3, d4) real

    def reset_parameters(self):
        """Reset spectral weights with distribution scale*U(0,1)

        The values are written in place and drawn on each parameter's own
        device and dtype, so re-initializing after ``.to(...)`` or ``.double()``
        does not move the weights back to CPU / float32. Weights are drawn in
        order weights1 ... weights8.
        """
        with torch.no_grad():
            for weights in (
                self.weights1,
                self.weights2,
                self.weights3,
                self.weights4,
                self.weights5,
                self.weights6,
                self.weights7,
                self.weights8,
            ):
                weights.copy_(self.scale * torch.rand_like(weights))
# ==========================================
# Utils for PINO exact gradients
# ==========================================
def fourier_derivatives(x: Tensor, ell: List[float]) -> Tuple[Tensor, Tensor]:
    """
    Fourier derivative function for PINO

    Computes first and second spatial derivatives of ``x`` along every spatial
    dimension with the FFT, treating each spatial dimension as periodic.

    Parameters
    ----------
    x : Tensor
        Input of shape (batch, channels, *spatial)
    ell : List[float]
        Physical domain length of each spatial dimension; one entry per
        spatial dimension of ``x``

    Returns
    -------
    Tuple[Tensor, Tensor]
        ``(wx, wxx)``: first and second derivatives. For each, the results for
        the individual spatial dimensions are concatenated along dim 1, giving
        shape (batch, channels * len(ell), *spatial)

    Raises
    ------
    ValueError
        If ``len(ell)`` does not match the number of spatial dims of ``x``
    """
    # check that input shape matches domain length
    if len(x.shape) - 2 != len(ell):
        raise ValueError("input shape doesn't match domain dims")
    # set pi from numpy
    pi = float(np.pi)
    # get needed dims
    n = x.shape[2:]
    dim = len(ell)
    fft_dims = list(range(2, dim + 2))
    # get device
    device = x.device
    # compute fourier transform
    x_h = torch.fft.fftn(x, dim=fft_dims)
    # Integer wavenumbers in FFT order: [0, 1, ..., ceil(n/2)-1, -floor(n/2), ..., -1].
    # The positive half ends at (nx + 1) // 2 so odd sizes are correct too
    # (n=5 -> [0, 1, 2, -2, -1]); even sizes keep the Nyquist mode at -n/2.
    # Each vector is reshaped to broadcast along its own spatial dimension.
    k_x = []
    for i, nx in enumerate(n):
        k_x.append(
            torch.cat(
                (
                    torch.arange(start=0, end=(nx + 1) // 2, step=1, device=device),
                    torch.arange(start=-(nx // 2), end=0, step=1, device=device),
                ),
                0,
            ).reshape((i + 2) * [1] + [nx] + (dim - i - 1) * [1])
        )
    # compute first and second derivatives in fourier space: (i k 2pi/L)^p * x_h
    j = torch.complex(
        torch.tensor([0.0], device=device), torch.tensor([1.0], device=device)
    )  # Cuda graphs does not work here
    wx_h = [j * k_x_i * x_h * (2 * pi / ell[i]) for i, k_x_i in enumerate(k_x)]
    wxx_h = [
        j * k_x_i * wx_h_i * (2 * pi / ell[i])
        for i, (wx_h_i, k_x_i) in enumerate(zip(wx_h, k_x))
    ]
    # inverse fourier transform out; input is assumed real, so keep the real part
    wx = torch.cat(
        [torch.fft.ifftn(wx_h_i, dim=fft_dims).real for wx_h_i in wx_h],
        dim=1,
    )
    wxx = torch.cat(
        [torch.fft.ifftn(wxx_h_i, dim=fft_dims).real for wxx_h_i in wxx_h],
        dim=1,
    )
    return (wx, wxx)
def calc_latent_derivatives(
    x: Tensor, domain_length: Union[float, List[float]] = 2
) -> Tuple[List[Tensor], List[Tensor]]:
    """
    Compute first and second order derivatives of latent variables

    The spatial dimensions are replicate-padded by ``(n - 1) // 2`` on each
    side before the spectral derivative, because the FFT treats the domain as
    periodic. The result is then trimmed back to the original extent.

    Parameters
    ----------
    x : Tensor
        Latent variables of shape (batch, channels, *spatial)
    domain_length : float or List[float], optional
        Physical length of each spatial dimension. A scalar is used for every
        dimension, by default 2

    Returns
    -------
    Tuple[List[Tensor], List[Tensor]]
        First and second derivatives, each split into one tensor of shape
        (batch, channels, *spatial) per spatial dimension
    """
    spatial = list(x.shape[2:])
    dim = len(spatial)
    # A scalar length (including the default) applies to every dimension
    if isinstance(domain_length, (int, float)):
        domain_length = [domain_length] * dim
    # Padd domain by factor of 2 for non-periodic domains
    padd = [(n - 1) // 2 for n in spatial]
    # Scale domain length by padding amount
    domain_length = [
        domain_length[i] * (2 * padd[i] + spatial[i]) / spatial[i]
        for i in range(dim)
    ]
    # F.pad takes (left, right) pairs starting from the LAST dimension, so the
    # per-dimension amounts are emitted in reverse order, each twice.
    padding = [p for p in reversed(padd) for _ in range(2)]
    x_p = F.pad(x, padding, mode="replicate")
    dx, ddx = fourier_derivatives(x_p, domain_length)
    # Trim padded domain. start:start+n (rather than p:-p) also works when the
    # padding is zero, i.e. for spatial sizes <= 2.
    crop = (Ellipsis,) + tuple(slice(p, p + n) for p, n in zip(padd, spatial))
    dx = dx[crop]
    ddx = ddx[crop]
    dx_list = torch.split(dx, x.shape[1], dim=1)
    ddx_list = torch.split(ddx, x.shape[1], dim=1)
    return dx_list, ddx_list
def first_order_pino_grads(
    u: Tensor,
    ux: List[Tensor],
    weights_1: Tensor,
    weights_2: Tensor,
    bias_1: Tensor,
) -> List[Tensor]:  # pragma: no cover
    """
    Compute first order derivatives of output variables

    Propagates the spatial derivatives ``ux`` of the latent input ``u`` through
    a pointwise two-layer decoder with the chain rule. The hidden activation is
    treated as ``tanh``: its derivative ``1 / cosh**2`` is used below.

    Parameters
    ----------
    u : Tensor
        Decoder input of shape (batch, in_channels, *spatial), 1 to 3 spatial dims
    ux : List[Tensor]
        First derivatives of ``u``, one tensor shaped like ``u`` per direction
    weights_1 : Tensor
        First-layer weight in ``conv1d`` layout, assumed (hidden, in_channels, 1);
        it is unsqueezed for 2D/3D inputs
    weights_2 : Tensor
        Second-layer weight in ``conv1d`` layout, assumed (out, hidden, 1).
        Output channels are summed by the contraction below.
    bias_1 : Tensor
        First-layer bias

    Returns
    -------
    List[Tensor]
        One derivative of shape (batch, 1, *spatial) per entry of ``ux``

    Raises
    ------
    ValueError
        If ``u`` does not have 1, 2 or 3 spatial dimensions
    """
    dim = len(u.shape) - 2
    if dim not in (1, 2, 3):
        raise ValueError(
            f"first_order_pino_grads supports 1 to 3 spatial dims, got {dim}"
        )
    # dim for einsum
    dim_str = "xyz"[:dim]
    # Promote the conv1d-style weights to the layout of the matching convolution
    for _ in range(dim - 1):
        weights_1 = weights_1.unsqueeze(-1)
        weights_2 = weights_2.unsqueeze(-1)
    conv = (F.conv1d, F.conv2d, F.conv3d)[dim - 1]
    # compute first layer (hidden pre-activation)
    u_hidden = conv(u, weights_1, bias_1)
    # compute derivative hidden layer (d tanh / dz = 1 / cosh^2)
    diff_tanh = 1 / torch.cosh(u_hidden) ** 2
    # compute diff(f(g)): gradient of the decoder output w.r.t. each input channel
    diff_fg = torch.einsum(
        f"mi{dim_str},bm{dim_str},km{dim_str}->bi{dim_str}",
        weights_1,
        diff_tanh,
        weights_2,
    )
    # compute diff(f(g)) * diff(g) for every direction, keeping a channel dim
    return [
        torch.einsum(f"bi{dim_str},bi{dim_str}->b{dim_str}", diff_fg, w).unsqueeze(1)
        for w in ux
    ]
def second_order_pino_grads(
    u: Tensor,
    ux: List[Tensor],
    uxx: List[Tensor],
    weights_1: Tensor,
    weights_2: Tensor,
    bias_1: Tensor,
) -> List[Tensor]:  # pragma: no cover
    """
    Compute second order derivatives of output variables

    Applies the second-order chain rule through a pointwise two-layer decoder
    whose hidden activation is treated as ``tanh``. Its first and second
    derivatives, ``1 / cosh**2`` and ``-2 tanh / cosh**2``, are used below.
    For each direction the result is ``ux^T H ux + grad . uxx``, where ``H`` and
    ``grad`` are the decoder's Hessian and gradient w.r.t. its input channels.

    Parameters
    ----------
    u : Tensor
        Decoder input of shape (batch, in_channels, *spatial), 1 to 3 spatial dims
    ux : List[Tensor]
        First derivatives of ``u``, one tensor shaped like ``u`` per direction
    uxx : List[Tensor]
        Second derivatives of ``u``, paired element-wise with ``ux``
    weights_1 : Tensor
        First-layer weight in ``conv1d`` layout, assumed (hidden, in_channels, 1);
        it is unsqueezed for 2D/3D inputs
    weights_2 : Tensor
        Second-layer weight in ``conv1d`` layout, assumed (1, hidden, 1)
    bias_1 : Tensor
        First-layer bias

    Returns
    -------
    List[Tensor]
        One second derivative of shape (batch, 1, *spatial) per (ux, uxx) pair

    Raises
    ------
    ValueError
        If ``u`` does not have 1, 2 or 3 spatial dimensions
    """
    dim = len(u.shape) - 2
    if dim not in (1, 2, 3):
        raise ValueError(
            f"second_order_pino_grads supports 1 to 3 spatial dims, got {dim}"
        )
    # dim for einsum
    dim_str = "xyz"[:dim]
    # Promote the conv1d-style weights to the layout of the matching convolution
    for _ in range(dim - 1):
        weights_1 = weights_1.unsqueeze(-1)
        weights_2 = weights_2.unsqueeze(-1)
    conv = (F.conv1d, F.conv2d, F.conv3d)[dim - 1]
    # compute first layer (hidden pre-activation)
    u_hidden = conv(u, weights_1, bias_1)
    # compute derivative hidden layer (d tanh / dz = 1 / cosh^2)
    diff_tanh = 1 / torch.cosh(u_hidden) ** 2
    # compute diff(f(g)): gradient of the decoder output w.r.t. each input channel
    diff_fg = torch.einsum(
        f"mi{dim_str},bm{dim_str},km{dim_str}->bi{dim_str}",
        weights_1,
        diff_tanh,
        weights_2,
    )
    # double derivative of hidden layer (d^2 tanh / dz^2 = -2 tanh / cosh^2)
    diff_diff_tanh = -2 * diff_tanh * torch.tanh(u_hidden)
    # Hidden-layer curvature weighted by the output layer; loop invariant.
    # NOTE(review): this broadcast yields a (batch, hidden, ...) tensor only when
    # weights_2 has a single output channel; with more output channels the
    # leading dim would not be the batch dim.
    hidden_curvature = weights_2 * diff_diff_tanh
    # compute diff(g) * hessian(f) * diff(g)
    vxx1 = [
        torch.einsum(
            f"bi{dim_str},mi{dim_str},bm{dim_str},mj{dim_str},bj{dim_str}->b{dim_str}",
            w,
            weights_1,
            hidden_curvature,
            weights_1,
            w,
        )
        for w in ux
    ]
    # compute diff(f) * hessian(g)
    vxx2 = [
        torch.einsum(f"bi{dim_str},bi{dim_str}->b{dim_str}", diff_fg, w)
        for w in uxx
    ]
    # combine and restore a channel dim: (batch, 1, *spatial)
    return [torch.unsqueeze(a + b, dim=1) for a, b in zip(vxx1, vxx2)]