Source code for physicsnemo.nn.functional.interpolation.grid_to_point_interpolation.grid_to_point_interpolation

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import List, Tuple

import torch
from torch import Tensor

from physicsnemo.core.function_spec import FunctionSpec

from ._torch_impl import interpolation_torch
from ._warp_impl import interpolation_warp


class GridToPointInterpolation(FunctionSpec):
    r"""Interpolate values from a structured grid at query point locations.

    This functional evaluates a scalar or multi-channel field defined on a regular
    Cartesian grid at arbitrary query points in 1D, 2D, or 3D.

    For a query point :math:`\mathbf{x}` and a grid field :math:`f`, interpolation
    is computed as a weighted sum over local stencil points:

    .. math::

       \hat{f}(\mathbf{x}) = \sum_{i \in \mathcal{N}(\mathbf{x})}
       w_i(\mathbf{x})\, f_i

    where :math:`\mathcal{N}(\mathbf{x})` is the interpolation neighborhood and
    :math:`w_i(\mathbf{x})` are interpolation weights.

    The interpolation mode controls the stencil and weights:

    - ``nearest_neighbor``: nearest grid point (piecewise constant, 1-point stencil)
    - ``linear``: multilinear interpolation (2^d stencil in d dimensions)
    - ``smooth_step_1``: multilinear-style interpolation with smooth-step weights
      :math:`3t^2 - 2t^3`
    - ``smooth_step_2``: multilinear-style interpolation with quintic smooth-step
      weights :math:`t^3(6t^2 - 15t + 10)`
    - ``gaussian``: local Gaussian weighting over a larger fixed stencil

    Notes
    -----
    - Grid spacing and extents are provided by ``grid``.
    - The ``warp`` and ``torch`` backends are intended to be numerically aligned.
    - ``warp`` is the default dispatch path for ``grid_to_point_interpolation``.
    - The deprecated ``interpolation`` alias defaults to ``torch`` unless an
      explicit ``implementation`` is provided.

    Parameters
    ----------
    query_points : torch.Tensor
        Points at which interpolation is to be performed. The benchmark inputs
        below use shape ``(num_points, dims)``.
    context_grid : torch.Tensor
        Source grid from which values are interpolated. The benchmark inputs
        below use a single leading channel, i.e. ``(1, *grid_sizes)``.
    grid : list[tuple[float, float, int]]
        One ``(start, end, num_points)`` tuple per spatial dimension describing
        the grid's range and resolution.
    interpolation_type : str, optional
        Interpolation method name, by default ``"smooth_step_2"``.
    mem_speed_trade : bool, optional
        Trade-off between memory usage and speed, by default ``True``.
    implementation : {"warp", "torch"} or None
        Implementation to use. When ``None``, dispatch selects the available
        implementation.

    Returns
    -------
    torch.Tensor
        Interpolated values at ``query_points``.

    """

    # (label, dims, grid points per dim, number of query points, interpolation type)
    _BENCHMARK_CASES = (
        ("1d-nearest-g2048-n8192", 1, 2048, 8192, "nearest_neighbor"),
        ("1d-linear-g2048-n8192", 1, 2048, 8192, "linear"),
        ("2d-smooth1-g128-n1024", 2, 128, 1024, "smooth_step_1"),
        ("2d-smooth2-g128-n1024", 2, 128, 1024, "smooth_step_2"),
        ("3d-linear-g32-n512", 3, 32, 512, "linear"),
        ("3d-smooth2-g32-n512", 3, 32, 512, "smooth_step_2"),
        ("3d-gaussian-g32-n512", 3, 32, 512, "gaussian"),
    )
    # Backward tolerances are looser than forward ones (the two backends'
    # gradients are compared with these).
    _COMPARE_ATOL = 5e-5
    _COMPARE_RTOL = 1e-4
    _COMPARE_BACKWARD_ATOL = 2e-2
    _COMPARE_BACKWARD_RTOL = 5e-2

    @FunctionSpec.register(name="warp", required_imports=("warp>=0.6.0",), rank=0)
    def warp_forward(
        query_points: Tensor,
        context_grid: Tensor,
        grid: List[Tuple[float, float, int]],
        interpolation_type: str = "smooth_step_2",
        mem_speed_trade: bool = True,
    ) -> Tensor:
        return interpolation_warp(
            query_points,
            context_grid,
            grid,
            interpolation_type=interpolation_type,
            mem_speed_trade=mem_speed_trade,
        )

    @FunctionSpec.register(name="torch", rank=1, baseline=True)
    def torch_forward(
        query_points: Tensor,
        context_grid: Tensor,
        grid: List[Tuple[float, float, int]],
        interpolation_type: str = "smooth_step_2",
        mem_speed_trade: bool = True,
    ) -> Tensor:
        return interpolation_torch(
            query_points,
            context_grid,
            grid,
            interpolation_type=interpolation_type,
            mem_speed_trade=mem_speed_trade,
        )

    @staticmethod
    def _build_case_inputs(
        dims: int,
        grid_size: int,
        num_points: int,
        device: torch.device,
        requires_grad: bool = False,
    ) -> Tuple[Tensor, Tensor, List[Tuple[float, float, int]]]:
        """Build ``(query_points, context_grid, grid)`` for one benchmark case.

        The grid spans ``[-1, 2]`` with ``grid_size`` points in each of ``dims``
        dimensions. The field is ``sin(sum_k x_k ** k)`` with a single leading
        channel. Query points lie on the diagonal of ``[0, 1]^dims``, inside the
        grid. When ``requires_grad`` is set, both returned tensors are marked as
        requiring gradients (used by the backward benchmarks).
        """
        grid = [(-1.0, 2.0, grid_size)] * dims
        axes = [torch.linspace(lo, hi, n, device=device) for lo, hi, n in grid]
        mesh_grid = torch.stack(torch.meshgrid(axes, indexing="ij"), dim=0)

        # Use a different power per axis so every dimension contributes
        # distinguishably to the field.
        context_grid = torch.zeros_like(mesh_grid[0:1])
        for power, coord in enumerate(mesh_grid, start=1):
            context_grid = context_grid + coord.unsqueeze(0) ** power
        context_grid = torch.sin(context_grid)

        diagonal = torch.linspace(0.0, 1.0, num_points, device=device)
        query_points = torch.stack([diagonal] * dims, dim=-1)

        if requires_grad:
            context_grid.requires_grad_(True)
            query_points.requires_grad_(True)
        return query_points, context_grid, grid

    @classmethod
    def make_inputs_forward(cls, device: torch.device | str = "cpu"):
        """Yield ``(label, args, kwargs)`` forward benchmark inputs per case."""
        device = torch.device(device)
        for label, dims, grid_size, num_points, interp_name in cls._BENCHMARK_CASES:
            yield (
                label,
                cls._build_case_inputs(dims, grid_size, num_points, device),
                {"interpolation_type": interp_name, "mem_speed_trade": True},
            )

    @classmethod
    def make_inputs_backward(cls, device: torch.device | str = "cpu"):
        """Yield ``(label, args, kwargs)`` backward benchmark inputs per case.

        Identical to :meth:`make_inputs_forward` except that ``query_points`` and
        ``context_grid`` require gradients.
        """
        device = torch.device(device)
        for label, dims, grid_size, num_points, interp_name in cls._BENCHMARK_CASES:
            yield (
                label,
                cls._build_case_inputs(
                    dims, grid_size, num_points, device, requires_grad=True
                ),
                {"interpolation_type": interp_name, "mem_speed_trade": True},
            )

    @classmethod
    def compare_forward(cls, output: torch.Tensor, reference: torch.Tensor) -> None:
        """Assert that a forward output matches the reference within tolerance."""
        torch.testing.assert_close(
            output,
            reference,
            atol=cls._COMPARE_ATOL,
            rtol=cls._COMPARE_RTOL,
        )

    @classmethod
    def compare_backward(cls, output: torch.Tensor, reference: torch.Tensor) -> None:
        """Assert that a backward output matches the reference within tolerance."""
        torch.testing.assert_close(
            output,
            reference,
            atol=cls._COMPARE_BACKWARD_ATOL,
            rtol=cls._COMPARE_BACKWARD_RTOL,
        )


# Public functional entry point built from the ``GridToPointInterpolation`` spec.
# Backend selection among the registered implementations ("warp", "torch")
# is handled by ``FunctionSpec.make_function`` (defined outside this module).
grid_to_point_interpolation = GridToPointInterpolation.make_function(
    "grid_to_point_interpolation"
)


def interpolation(*args, **kwargs):
    """Deprecated alias for ``grid_to_point_interpolation``.

    Emits a :class:`DeprecationWarning` (attributed to the caller's frame) and
    forwards all positional and keyword arguments to
    ``grid_to_point_interpolation``.

    Unlike the new entry point, this alias dispatches to the ``torch``
    implementation when no ``implementation`` keyword is given. Any explicitly
    passed ``implementation`` value, including ``None``, is forwarded unchanged.

    Returns
    -------
    Whatever ``grid_to_point_interpolation`` returns for the same arguments.
    """
    warnings.warn(
        "`interpolation` is deprecated and will be removed in a future release. "
        "Use `grid_to_point_interpolation` instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    # Preserve historical default behavior for the deprecated alias while still
    # allowing explicit backend selection overrides.
    kwargs.setdefault("implementation", "torch")
    return grid_to_point_interpolation(*args, **kwargs)
# Names exported by ``from ... import *``; ``interpolation`` is kept only as a
# deprecated alias.
__all__ = [
    "GridToPointInterpolation",
    "grid_to_point_interpolation",
    "interpolation",
]