# Source code for physicsnemo.nn.module.fourier_layers

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

from .mlp_layers import Mlp


def fourier_encode(coords: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Vectorized Fourier feature encoding.

    Every coordinate is multiplied by every frequency, and the sines and
    cosines of those products are returned.

    Args:
        coords: Tensor of coordinates with shape ``(..., D)``; any number of
            leading dimensions (including none) is supported.
        freqs: 1-D tensor of ``F`` frequencies, shape ``(F,)``.

    Returns:
        Tensor of shape ``(..., 2 * F * D)``. The last dimension holds all
        sine terms followed by all cosine terms. Within each half the products
        are ordered frequency-major: index ``f * D + d`` corresponds to
        ``freqs[f] * coords[..., d]``.
    """
    num_dims = coords.shape[-1]
    num_freqs = freqs.shape[0]
    lead_shape = coords.shape[:-1]
    # (..., 1, D) * (F, 1) -> (..., F, D): each frequency scales every coordinate.
    scaled = coords.unsqueeze(-2) * freqs.unsqueeze(-1)
    # Flatten the (F, D) pair frequency-major into a single feature axis.
    scaled = scaled.reshape(*lead_shape, num_freqs * num_dims)
    return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1)
class FourierMLP(nn.Module):
    """MLP that optionally Fourier-encodes its input features.

    When ``fourier_features`` is True, the input ``x`` is concatenated with
    ``fourier_encode(x, self.freqs)`` along the last dimension before being
    passed to the MLP. The MLP therefore sees
    ``input_features * (1 + 2 * num_modes)`` features. Otherwise the MLP
    consumes ``x`` unchanged.

    Args:
        input_features: Size of the last dimension of the input.
        base_layer: Width of both hidden layers of the MLP and of its output.
        fourier_features: Whether to Fourier-encode the input features.
        num_modes: Number of frequencies used for the encoding. Only used when
            ``fourier_features`` is True.
        activation: Activation passed to ``Mlp`` as ``act_layer``.
    """

    def __init__(
        self,
        input_features: int,
        base_layer: int,
        fourier_features: bool,
        num_modes: int,
        activation: nn.Module | str,
    ):
        super().__init__()
        self.fourier_features = fourier_features
        if self.fourier_features:
            # Each input feature contributes itself plus one sin and one cos
            # term per mode.
            input_features_calculated = input_features * (1 + 2 * num_modes)
            # Frequencies exp(t) for t evenly spaced over [0, pi], i.e. from 1
            # to e**pi. A buffer moves with the module and is saved in
            # state_dict, but is not a trainable parameter.
            self.register_buffer(
                "freqs", torch.exp(torch.linspace(0, math.pi, num_modes))
            )
        else:
            input_features_calculated = input_features

        self.mlp = Mlp(
            in_features=input_features_calculated,
            hidden_features=[
                base_layer,
                base_layer,
            ],
            out_features=base_layer,
            act_layer=activation,
            drop=0.0,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Optionally append Fourier features to ``x``, then apply the MLP."""
        if self.fourier_features:
            x = torch.cat((x, fourier_encode(x, self.freqs)), dim=-1)
        return self.mlp(x)
class FourierLayer(nn.Module):
    """Fourier layer used in the Fourier feature network.

    Holds a fixed (non-trainable) frequency matrix ``B`` stored as a buffer of
    shape ``[in_features, n_freq_vectors]``. It maps an input ``x`` to
    ``cat([sin(2*pi * x @ B), cos(2*pi * x @ B)], dim=-1)``.

    Args:
        in_features: Size of the last dimension of the input.
        frequencies: Either an explicit array-like of frequency vectors with
            shape ``[n_freq_vectors, in_features]``, or a specification whose
            first element is a mode string:

            * ``("gaussian", std, nr_freq)``: ``nr_freq`` vectors whose entries
              are drawn from ``N(0, 1)`` via NumPy's global RNG and multiplied
              by ``std``.
            * ``(mode, values)``: ``mode`` may contain any combination of
              ``"full"``, ``"axis"`` and ``"diagonal"``. The resulting vector
              sets are concatenated in that order:

              - ``"full"``: every combination of ``values`` across the
                ``in_features`` axes (``len(values) ** in_features`` vectors).
              - ``"axis"``: each value placed on a single axis, zeros
                elsewhere (``len(values) * in_features`` vectors).
              - ``"diagonal"``: each value on all axes at once, emitted
                ``in_features`` times (``len(values) * in_features`` vectors).

    Raises:
        ValueError: If a mode string is given that matches none of the
            supported modes.
    """

    def __init__(
        self,
        in_features: int,
        frequencies,
    ) -> None:
        super().__init__()
        # TODO: Need a more robust way to specify these params.
        if isinstance(frequencies[0], str):
            np_f = self._frequencies_from_spec(in_features, frequencies)
        else:
            np_f = frequencies  # [nr_freq, in_features]
        freq_matrix = torch.tensor(np_f, dtype=torch.get_default_dtype())
        # Stored transposed as [in_features, n_freq_vectors] so that forward
        # is a plain x @ frequencies.
        self.register_buffer("frequencies", freq_matrix.t().contiguous())

    @staticmethod
    def _frequencies_from_spec(in_features: int, spec) -> np.ndarray:
        """Build the ``[n_freq_vectors, in_features]`` matrix described by ``spec``."""
        mode = spec[0]
        if "gaussian" in mode:
            nr_freq = spec[2]
            return np.random.normal(0, 1, size=(nr_freq, in_features)) * spec[1]

        nr_freq = len(spec[1])
        values = np.array(spec[1])
        blocks = []
        if "full" in mode:
            # Cartesian product of `values` over all in_features axes.
            grids = np.meshgrid(*[values for _ in range(in_features)], indexing="ij")
            blocks.append(
                np.reshape(
                    np.stack(grids, axis=-1),
                    (nr_freq**in_features, in_features),
                )
            )
        if "axis" in mode:
            # One vector per (value, axis) pair, non-zero on that axis only.
            axis_f = np.zeros((nr_freq, in_features, in_features))
            for i in range(in_features):
                axis_f[:, i, i] = np.reshape(values, (nr_freq,))
            blocks.append(np.reshape(axis_f, (nr_freq * in_features, in_features)))
        if "diagonal" in mode:
            # NOTE(review): tiling over both trailing axes and then reshaping
            # yields each [v, v, ..., v] row in_features times (duplicate rows).
            # Kept as-is, because changing it would change out_features() and
            # break existing checkpoints.
            diag_f = np.reshape(values, (nr_freq, 1, 1))
            diag_f = np.tile(diag_f, (1, in_features, in_features))
            blocks.append(np.reshape(diag_f, (nr_freq * in_features, in_features)))
        if not blocks:
            raise ValueError(
                f"Unsupported frequency mode {mode!r}; expected a string containing "
                "'gaussian', 'full', 'axis' or 'diagonal'."
            )
        return np.concatenate(blocks, axis=-2)

    def out_features(self) -> int:
        """Number of output features: one sin and one cos per frequency vector."""
        return int(self.frequencies.size(1) * 2)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``cat([sin(2*pi*x@B), cos(2*pi*x@B)], dim=-1)``."""
        x_hat = torch.matmul(x, self.frequencies)
        x_sin = torch.sin(2.0 * math.pi * x_hat)
        x_cos = torch.cos(2.0 * math.pi * x_hat)
        x_i = torch.cat([x_sin, x_cos], dim=-1)
        return x_i
class FourierFilter(nn.Module):
    """Fourier filter used in the multiplicative filter network.

    Computes ``sin(2*pi * x @ (weight_scale * frequency) + phase)``. Here
    ``frequency`` is a trainable ``[in_features, layer_size]`` matrix and
    ``phase`` is a trainable vector of length ``layer_size``.

    Args:
        in_features: Size of the last dimension of the input.
        layer_size: Number of output features.
        nr_layers: Used to scale the frequencies by ``1 / sqrt(nr_layers + 1)``.
        input_scale: Overall frequency scale factor.
    """

    def __init__(
        self,
        in_features: int,
        layer_size: int,
        nr_layers: int,
        input_scale: float,
    ) -> None:
        super().__init__()
        self.weight_scale = input_scale / math.sqrt(nr_layers + 1)
        self.frequency = nn.Parameter(torch.empty(in_features, layer_size))
        # The phase is kept 1-D ([layer_size]) and added through broadcasting.
        # A [1, layer_size] shape was avoided because it had issues with
        # batched tensors in FuncArch.
        self.phase = nn.Parameter(torch.empty(layer_size))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Xavier-uniform init for ``frequency``; uniform in [-pi, pi] for ``phase``."""
        nn.init.xavier_uniform_(self.frequency)
        nn.init.uniform_(self.phase, a=-math.pi, b=math.pi)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the sinusoidal filter to ``x`` of shape ``(..., in_features)``."""
        scaled_freq = self.weight_scale * self.frequency
        angle = torch.matmul(x, 2.0 * math.pi * scaled_freq)
        return torch.sin(angle + self.phase)
class GaborFilter(nn.Module):
    """Gabor filter used in the multiplicative filter network.

    For each output unit ``k`` it computes a Gaussian envelope
    ``exp(-0.5 * gamma_k * ||x - mu[:, k]||^2)`` and multiplies it by the
    sinusoid ``sin(2*pi * x @ freq[:, k] + phase_k)``, where
    ``freq = weight_scale * frequency * sqrt(gamma)``.

    Args:
        in_features: Size of the last dimension of the input.
        layer_size: Number of output features.
        nr_layers: Used to scale the frequencies by ``1 / sqrt(nr_layers + 1)``.
        input_scale: Overall frequency scale factor.
        alpha: Shape parameter of the Gamma distribution used to init ``gamma``.
        beta: Rate parameter of that Gamma distribution (scale is ``1 / beta``).
    """

    def __init__(
        self,
        in_features: int,
        layer_size: int,
        nr_layers: int,
        input_scale: float,
        alpha: float,
        beta: float,
    ) -> None:
        super().__init__()
        self.layer_size = layer_size
        self.alpha = alpha
        self.beta = beta
        self.weight_scale = input_scale / math.sqrt(nr_layers + 1)
        matrix_shape = (in_features, layer_size)
        # Registration order is frequency, phase, mu, gamma.
        self.frequency = nn.Parameter(torch.empty(matrix_shape))
        self.phase = nn.Parameter(torch.empty(layer_size))
        self.mu = nn.Parameter(torch.empty(matrix_shape))
        self.gamma = nn.Parameter(torch.empty(layer_size))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize all parameters.

        ``frequency`` gets Xavier-uniform init, ``phase`` is uniform in
        [-pi, pi], ``mu`` is uniform in [-1, 1], and ``gamma`` is sampled from
        Gamma(alpha, rate=beta).
        """
        nn.init.xavier_uniform_(self.frequency)
        nn.init.uniform_(self.phase, a=-math.pi, b=math.pi)
        nn.init.uniform_(self.mu, a=-1.0, b=1.0)
        # NOTE(review): gamma is drawn from NumPy's global RNG, so
        # torch.manual_seed alone does not make it reproducible.
        gamma_scale = 1.0 / self.beta
        samples = np.random.gamma(self.alpha, gamma_scale, self.layer_size)
        with torch.no_grad():
            self.gamma.copy_(torch.from_numpy(samples))

    def forward(self, x: Tensor) -> Tensor:
        """Apply the Gabor filter to ``x`` of shape ``(..., in_features)``."""
        freq = self.weight_scale * (self.frequency * self.gamma.sqrt())
        # (..., in_features, 1) - (in_features, layer_size)
        #     -> (..., in_features, layer_size)
        offset = x.unsqueeze(-1) - self.mu
        # Norm over dim -2 (not 1) for compatibility with BatchedTensor.
        sq_dist = torch.square(offset.norm(p=2, dim=-2))
        envelope = torch.exp(-0.5 * sq_dist * self.gamma)
        carrier = torch.sin(torch.matmul(x, 2.0 * math.pi * freq) + self.phase)
        return envelope * carrier