# Source code for physicsnemo.nn.module.fft

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import List, Optional, Tuple

import torch
import torch.fft
import torch.onnx
from torch import Tensor
from torch.autograd import Function

# Note 1: for DFT operators, the less verbose way of registering an operator is via
# `register_custom_op_symbolic`. However, it does not currently work due to
# torch.fft.rfft* functions returning Complex type which is not yet supported in ONNX.

# Note 2:
# - current ONNX Contrib implementation does not support configurable normalization, so
#   "normalized" must be 0, the normalization is done outside of Contrib ops.
#   See also comments in `_scale_output_backward` function for more details.
# - "onesided" is not configurable either - must be set to 1.
# - Contrib implementation requires DFT dimensions to be the last ones,
#   otherwise axes permutation is required.
# See:
# https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/contrib_ops/cuda/math/fft_ops.h#L19


def rfft(
    input: Tensor,
    n: Optional[int] = None,
    dim: int = -1,
    norm: Optional[str] = None,
) -> Tensor:
    """ONNX compatible method to compute the 1d Fourier transform of real-valued input.

    Parameters
    ----------
    input : Tensor
        Real input tensor
    n : Optional[int], optional
        Signal length in the transformed dimension, by default None
    dim : int, optional
        Dimension along which to take the real FFT, by default -1
    norm : Optional[str], optional
        Normalization mode with options "forward", "backward" and "ortho".
        When set to None, normalization will default to backward
        (no normalization), by default None

    Raises
    ------
    TypeError
        If `dim` is not an int when running in ONNX export mode.

    Note
    ----
    The function is equivalent to `torch.fft.rfft` when not running in ONNX
    export mode
    """
    # Outside of ONNX export, defer to the native complex-valued PyTorch op.
    if not torch.onnx.is_in_onnx_export():
        return torch.fft.rfft(input, n=n, dim=dim, norm=norm)
    # The ONNX path supports only a single integer DFT dimension.
    if not isinstance(dim, int):
        raise TypeError(f"'dim' must be an int, got {type(dim).__name__}.")
    return _rfft_onnx(input, (n,), (dim,), norm)
def rfft2(
    input: Tensor,
    s: Optional[Tuple[int]] = None,
    dim: Tuple[int] = (-2, -1),
    norm: Optional[str] = None,
) -> Tensor:
    """ONNX compatible method to compute the 2d Fourier transform of real-valued input.

    Parameters
    ----------
    input : Tensor
        Real input tensor
    s : Optional[Tuple[int]], optional
        Signal size in the transformed dimensions, by default None
    dim : Tuple[int], optional
        Dimensions along which to take the real 2D FFT, by default (-2, -1)
    norm : Optional[str], optional
        Normalization mode with options "forward", "backward" and "ortho".
        When set to None, normalization will default to backward
        (no normalization), by default None

    Raises
    ------
    ValueError
        If `dim` is not a tuple of exactly two ints when running in ONNX
        export mode.

    Note
    ----
    The function is equivalent to `torch.fft.rfft2` when not running in ONNX
    export mode
    """
    # Outside of ONNX export, defer to the native complex-valued PyTorch op.
    if not torch.onnx.is_in_onnx_export():
        return torch.fft.rfft2(input, s=s, dim=dim, norm=norm)
    # The ONNX path supports exactly two DFT dimensions.
    if not (isinstance(dim, tuple) and len(dim) == 2):
        raise ValueError(f"'dim' must be a tuple of two ints, got {dim!r}.")
    return _rfft_onnx(input, s, dim, norm)
def irfft(
    input: Tensor,
    n: Optional[int] = None,
    dim: int = -1,
    norm: Optional[str] = None,
) -> Tensor:
    """ONNX compatible method to compute the inverse of `rfft`.

    Parameters
    ----------
    input : Tensor
        Complex input tensor (real-valued pairs during ONNX export)
    n : Optional[int], optional
        Output signal length in the transformed dimension, by default None
    dim : int, optional
        Dimension along which to take the real IFFT, by default -1
    norm : Optional[str], optional
        Normalization mode with options "forward", "backward" and "ortho".
        When set to None, normalization will default to backward
        (normalize by 1/n), by default None

    Raises
    ------
    TypeError
        If `dim` is not an int when running in ONNX export mode.

    Note
    ----
    The function is equivalent to `torch.fft.irfft` when not running in ONNX
    export mode
    """
    # Outside of ONNX export, defer to the native complex-valued PyTorch op.
    if not torch.onnx.is_in_onnx_export():
        return torch.fft.irfft(input, n=n, dim=dim, norm=norm)
    # The ONNX path supports only a single integer DFT dimension.
    if not isinstance(dim, int):
        raise TypeError(f"'dim' must be an int, got {type(dim).__name__}.")
    return _irfft_onnx(input, (n,), (dim,), norm)
def irfft2(
    input: Tensor,
    s: Optional[Tuple[int]] = None,
    dim: Tuple[int] = (-2, -1),
    norm: Optional[str] = None,
) -> Tensor:
    """ONNX compatible method to compute the inverse of `rfft2`.

    Parameters
    ----------
    input : Tensor
        Complex input tensor (real-valued pairs during ONNX export)
    s : Optional[Tuple[int]], optional
        Output signal size in the transformed dimensions, by default None
    dim : Tuple[int], optional
        Dimensions along which to take the real 2D IFFT, by default (-2, -1)
    norm : Optional[str], optional
        Normalization mode with options "forward", "backward" and "ortho".
        When set to None, normalization will default to backward
        (normalize by 1/n), by default None

    Raises
    ------
    ValueError
        If `dim` is not a tuple of exactly two ints when running in ONNX
        export mode.

    Note
    ----
    The function is equivalent to `torch.fft.irfft2` when not running in ONNX
    export mode
    """
    # Outside of ONNX export, defer to the native complex-valued PyTorch op.
    if not torch.onnx.is_in_onnx_export():
        return torch.fft.irfft2(input, s=s, dim=dim, norm=norm)
    # The ONNX path supports exactly two DFT dimensions.
    if not (isinstance(dim, tuple) and len(dim) == 2):
        raise ValueError(f"'dim' must be a tuple of two ints, got {dim!r}.")
    return _irfft_onnx(input, s, dim, norm)
def view_as_complex(input: Tensor) -> Tensor:
    """ONNX compatible method to view input as complex tensor

    Parameters
    ----------
    input : Tensor
        The input Tensor

    Raises
    ------
    ValueError
        If input tensor's last dimension is not of size 2 during ONNX export,
        where the last dimension holds the real / imaginary components.

    Note
    ----
    The function is equivalent to `torch.view_as_complex` when not running in
    ONNX export mode
    """
    if not torch.onnx.is_in_onnx_export():
        return torch.view_as_complex(input)
    # Just return the input unchanged - during ONNX export
    # there will be no complex type.
    if input.size(-1) != 2:
        raise ValueError(
            f"Expected last dimension of size 2 (real/imag), got {input.size(-1)}."
        )
    return input
def real(input: Tensor) -> Tensor:
    """ONNX compatible method to view input as real tensor

    Parameters
    ----------
    input : Tensor
        The input Tensor

    Raises
    ------
    ValueError
        If input tensor's last dimension is not of size 2 during ONNX export,
        where the last dimension holds the real / imaginary components.

    Note
    ----
    The function is equivalent to `input.real` when not running in ONNX
    export mode
    """
    if not torch.onnx.is_in_onnx_export():
        return input.real
    # There is no complex type during ONNX export, so assuming
    # complex numbers are represented as if after `view_as_real`.
    if input.size(-1) != 2:
        raise ValueError(
            f"Expected last dimension of size 2 (real/imag), got {input.size(-1)}."
        )
    return input[..., 0]
def imag(input: Tensor) -> Tensor:
    """ONNX compatible method to view input as imaginary tensor

    Parameters
    ----------
    input : Tensor
        The input Tensor

    Raises
    ------
    ValueError
        If input tensor's last dimension is not of size 2 during ONNX export,
        where the last dimension holds the real / imaginary components.

    Note
    ----
    The function is equivalent to `input.imag` when not running in ONNX
    export mode
    """
    if not torch.onnx.is_in_onnx_export():
        return input.imag
    # There is no complex type during ONNX export, so assuming
    # complex numbers are represented as if after `view_as_real`.
    if input.size(-1) != 2:
        raise ValueError(
            f"Expected last dimension of size 2 (real/imag), got {input.size(-1)}."
        )
    return input[..., 1]
def _rfft_onnx(
    input: Tensor,
    s: Optional[Tuple[Optional[int]]],
    dim: Tuple[int],
    norm: Optional[str],
) -> Tensor:
    """ONNX-export implementation of `rfft`/`rfft2` based on Contrib Rfft ops.

    Permutes the DFT dimensions to be innermost when necessary (the Contrib
    ops require DFT over the last dims), runs the unnormalized one-sided FFT,
    then applies the normalization implied by `norm` outside the op.
    """
    if s is not None:
        _check_padding_rfft(s, dim, input.size())
    ndim = len(dim)
    if ndim not in [1, 2]:
        raise ValueError(ndim)
    # Whether to permute axes when DFT axes are not the last ones.
    perm = not _is_last_dims(dim, input.ndim)
    if perm:
        perm_in, perm_out = _create_axes_perm(input.ndim, dim)
        # Add a dimension to account for complex output.
        perm_out.append(len(perm_out))
        # Transpose -> RFFT -> Transpose (inverse).
        input = input.permute(perm_in)
    rfft_func = OnnxRfft if ndim == 1 else OnnxRfft2
    output = rfft_func.apply(input)
    output = _scale_output_forward(output, norm, input.size(), ndim)
    if perm:
        output = output.permute(perm_out)
    return output


def _irfft_onnx(
    input: Tensor,
    s: Optional[Tuple[Optional[int]]],
    dim: Tuple[int],
    norm: Optional[str],
) -> Tensor:
    """ONNX-export implementation of `irfft`/`irfft2` based on Contrib Irfft ops.

    Mirrors `_rfft_onnx` but accounts for the extra trailing dimension that
    holds the real/imaginary pairs of the complex input.
    """
    if s is not None:
        _check_padding_irfft(s, dim, input.size())
    ndim = len(dim)
    if ndim not in [1, 2]:
        raise ValueError(ndim)
    # Whether to permute axes when DFT axis is not the last.
    perm = not _is_last_dims(dim, input.ndim)
    if perm:
        # Do not include last dimension (input is complex).
        perm_in, perm_out = _create_axes_perm(input.ndim - 1, dim)
        # Add a dimension to account for complex input.
        perm_in.append(len(perm_in))
        # Transpose -> IRFFT -> Transpose (inverse).
        input = input.permute(perm_in)
    irfft_func = OnnxIrfft if ndim == 1 else OnnxIrfft2
    output = irfft_func.apply(input)
    output = _scale_output_backward(output, norm, input.size(), ndim)
    if perm:
        output = output.permute(perm_out)
    return output


def _contrib_rfft(g: torch.Graph, input: torch.Value, ndim: int) -> torch.Value:
    """Emits a com.microsoft::Rfft node into the ONNX graph.

    Normalization and one-sidedness are fixed by the Contrib op (see Note 2
    at the top of the file); normalization is applied outside this op.
    """
    if ndim not in [1, 2]:
        raise ValueError(ndim)
    # See https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Rfft
    output = g.op(
        "com.microsoft::Rfft",
        input,
        normalized_i=0,
        onesided_i=1,
        signal_ndim_i=ndim,
    )
    return output


def _contrib_irfft(g: torch.Graph, input: torch.Value, ndim: int) -> torch.Value:
    """Emits a com.microsoft::Irfft node into the ONNX graph."""
    if ndim not in [1, 2]:
        raise ValueError(ndim)
    # See https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Irfft
    output = g.op(
        "com.microsoft::Irfft",
        input,
        normalized_i=0,
        onesided_i=1,
        signal_ndim_i=ndim,
    )
    return output


def _is_last_dims(dim: Tuple[int], inp_ndim: int) -> bool:
    """Returns True if `dim` refers to the innermost `len(dim)` axes, in order."""
    ndim = len(dim)
    for i, idim in enumerate(dim):
        # This takes care of both positive and negative axis indices.
        if idim % inp_ndim != inp_ndim - ndim + i:
            return False
    return True


def _check_padding_rfft(
    sizes: Tuple[Optional[int]], dim: Tuple[int], inp_sizes: Tuple[int]
) -> None:
    """Verifies the requested RFFT signal sizes require no padding/trimming.

    Raises RuntimeError when any requested size differs from the input size,
    since the Contrib RFFT op does not support pad/trim yet.
    """
    if len(sizes) != len(dim):
        raise ValueError(f"{sizes}, {dim}")
    for i, s in enumerate(sizes):
        if s is None or s < 0:
            continue
        # Current Contrib RFFT does not support pad/trim yet.
        if s != inp_sizes[dim[i]]:
            raise RuntimeError(
                f"Padding/trimming is not yet supported, "
                f"got sizes {sizes}, DFT dims {dim}, "
                f"input dims {inp_sizes}."
            )


def _check_padding_irfft(
    sizes: Tuple[Optional[int]], dim: Tuple[int], inp_sizes: Tuple[int]
) -> None:
    """Verifies the requested IRFFT signal sizes require no padding/trimming.

    The last DFT dim of a one-sided spectrum maps to an output size of
    2 * (n_half - 1); all other dims must match the input exactly.
    """
    if len(sizes) != len(dim):
        raise ValueError(f"{sizes}, {dim}")
    # All but last dims must be equal to input dims.
    for i, s in enumerate(sizes[:-1]):
        if s is None or s < 0:
            continue
        # Current Contrib RFFT does not support pad/trim yet.
        if s != inp_sizes[dim[i]]:
            raise RuntimeError(
                f"Padding/trimming is not yet supported, "
                f"got sizes {sizes}, DFT dims {dim}, "
                f"input dims {inp_sizes}."
            )
    # Check last dim.
    s = sizes[-1]
    if s is not None and s > 0:
        expected_size = 2 * (inp_sizes[dim[-1]] - 1)
        if s != expected_size:
            raise RuntimeError(
                f"Padding/trimming is not yet supported, got sizes {sizes}"
                f", DFT dims {dim}, input dims {inp_sizes}"
                f", expected last size {expected_size}."
            )


def _create_axes_perm(ndim: int, dims: Tuple[int]) -> Tuple[List[int], List[int]]:
    """Creates permuted axes indices for RFFT/IRFFT operators."""
    perm_in = list(range(ndim))
    perm_out = list(perm_in)
    # Move indices to the right to make 'dims' as innermost dimensions.
    for i in range(-1, -(len(dims) + 1), -1):
        perm_in[dims[i]], perm_in[i] = perm_in[i], perm_in[dims[i]]
    # Move indices to the left to restore original shape.
    for i in range(-len(dims), 0):
        perm_out[dims[i]], perm_out[i] = perm_out[i], perm_out[dims[i]]
    return perm_in, perm_out


def _scale_output_forward(
    output: Tensor, norm: Optional[str], sizes: torch.Size, ndim: int
) -> Tensor:
    """Scales the RFFT output according to norm parameter."""
    norm = "backward" if norm is None else norm
    if norm not in ["forward", "backward", "ortho"]:
        raise ValueError(norm)
    # No normalization for "backward" in RFFT ops.
    if norm in ["forward", "ortho"]:
        # Assuming DFT dimensions are the last. This is required by the
        # current Contrib ops, so the axes permutation of the input is done
        # accordingly.
        # Note: math.prod returns a Python int, so use float()/math.sqrt
        # (not Tensor methods) to build the scalar denominator.
        dft_size = float(math.prod(sizes[-ndim:]))
        denom = math.sqrt(dft_size) if norm == "ortho" else dft_size
        output = output / denom
    return output


def _scale_output_backward(
    output: Tensor, norm: Optional[str], sizes: torch.Size, ndim: int
) -> Tensor:
    """Scales the IRFFT output according to norm parameter."""
    norm = "backward" if norm is None else norm
    if norm not in ["forward", "backward", "ortho"]:
        raise ValueError(norm)
    # Things get interesting here: Contrib IRFFT op uses cuFFT cufftXtExec
    # followed by a custom CUDA kernel (`_Normalize`) which always performs
    # normalization (division by N) which means "norm" is essentially
    # always "backward" here. So we need to cancel this normalization
    # when norm is "forward" or "ortho".
    if norm in ["forward", "ortho"]:
        # Last dimension is complex numbers representation.
        # Second-to-last dim corresponds to last dim in RFFT transform.
        # This is required by the current Contrib ops,
        # so the axes permutation of the input is done previously.
        if not len(sizes) >= ndim + 1:
            raise ValueError(
                f"Expected at least {ndim + 1} dimensions, got {len(sizes)}."
            )
        dft_size = math.prod(sizes[-(ndim + 1) : -2])
        dft_size *= 2 * (sizes[-2] - 1)
        # dft_size is a Python int; convert with float()/math.sqrt rather
        # than Tensor methods.
        dft_size = float(dft_size)
        # Since cuFFT scales by 1/dft_size, replace this scale with the
        # appropriate one.
        scale = dft_size if norm == "forward" else math.sqrt(dft_size)
        output = scale * output
    return output
class OnnxRfft(Function):
    """Auto-grad function to mimic rfft for ONNX exporting

    Note
    ----
    Should only be called during an ONNX export
    """

    @staticmethod
    def forward(ctx, input: Tensor) -> Tensor:
        """Computes the unnormalized one-sided FFT over the last dimension."""
        if torch.onnx.is_in_onnx_export():
            # Mirror the behavior of Contrib RFFT, which transforms the last
            # dim with no normalization; return real/imag pairs since ONNX
            # has no complex type.
            spectrum = torch.fft.rfft(input, dim=-1, norm="backward")
            return torch.view_as_real(spectrum)
        raise ValueError("Must be called only during ONNX export.")

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        """Symbolic representation for onnx graph"""
        return _contrib_rfft(g, input, ndim=1)
class OnnxRfft2(Function):
    """Auto-grad function to mimic rfft2 for ONNX exporting

    Note
    ----
    Should only be called during an ONNX export
    """

    @staticmethod
    def forward(ctx, input: Tensor) -> Tensor:
        """Computes the unnormalized one-sided 2D FFT over the last two dims."""
        # Raise ValueError (not AssertionError) for consistency with the
        # other Onnx* Function guards in this module.
        if not torch.onnx.is_in_onnx_export():
            raise ValueError("Must be called only during ONNX export.")
        # We need to mimic the behavior of Contrib RFFT which assumes
        # DFT of last dims and no normalization.
        y = torch.fft.rfft2(input, dim=(-2, -1), norm="backward")
        return torch.view_as_real(y)

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        """Symbolic representation for onnx graph"""
        return _contrib_rfft(g, input, ndim=2)
class OnnxIrfft(Function):
    """Auto-grad function to mimic irfft for ONNX exporting

    Note
    ----
    Should only be called during an ONNX export
    """

    @staticmethod
    def forward(ctx, input: Tensor) -> Tensor:
        """Computes the inverse one-sided FFT over the last dimension."""
        if torch.onnx.is_in_onnx_export():
            # Mirror the behavior of Contrib IRFFT, which transforms the last
            # dim with the implicit 1/n ("backward") normalization.
            return torch.fft.irfft(
                torch.view_as_complex(input), dim=-1, norm="backward"
            )
        raise ValueError("Must be called only during ONNX export.")

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        """Symbolic representation for onnx graph"""
        return _contrib_irfft(g, input, ndim=1)
class OnnxIrfft2(Function):
    """Auto-grad function to mimic irfft2 for ONNX exporting.

    Note
    ----
    Should only be called during an ONNX export
    """

    @staticmethod
    def forward(ctx, input: Tensor) -> Tensor:
        """Computes the inverse one-sided 2D FFT over the last two dims."""
        # Raise ValueError (not AssertionError) for consistency with the
        # other Onnx* Function guards in this module.
        if not torch.onnx.is_in_onnx_export():
            raise ValueError("Must be called only during ONNX export.")
        # We need to mimic the behavior of Contrib IRFFT which assumes
        # DFT of last dims and 1/n normalization.
        return torch.fft.irfft2(
            torch.view_as_complex(input), dim=(-2, -1), norm="backward"
        )

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        """Symbolic representation for onnx graph"""
        return _contrib_irfft(g, input, ndim=2)