Source code for physicsnemo.datapipes.transforms.spatial

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Spatial transforms for mesh and grid processing.

Provides generic transforms for spatial operations including bounding box
filtering, grid creation, k-NN neighbor computation, and center of mass calculation.
"""

from __future__ import annotations

from typing import Optional

import torch
from tensordict import TensorDict

from physicsnemo.datapipes.registry import register
from physicsnemo.datapipes.transforms.base import Transform
from physicsnemo.nn.functional import knn


@register()
class BoundingBoxFilter(Transform):
    r"""
    Filter points outside a spatial bounding box.

    Removes points that fall outside specified min/max bounds and applies the
    same filtering to dependent arrays to maintain correspondence. This is
    useful for focusing on specific regions of interest or removing outliers
    from simulation data.

    The bounds are exclusive: a point is kept only if every coordinate is
    strictly greater than ``bbox_min`` and strictly less than ``bbox_max``.
    Points lying exactly on a face of the box are therefore removed.

    Parameters
    ----------
    input_keys : list[str]
        List of coordinate tensor keys to filter. Keys that are not present
        in the sample are skipped.
    bbox_min : torch.Tensor or sequence of float
        Minimum corner of bounding box, shape :math:`(3,)`. Non-tensor
        sequences are converted with :func:`torch.as_tensor`.
    bbox_max : torch.Tensor or sequence of float
        Maximum corner of bounding box, shape :math:`(3,)`. Non-tensor
        sequences are converted with :func:`torch.as_tensor`.
    dependent_keys : list[str], optional
        Optional list of keys to filter using the same mask. These maintain
        correspondence with the filtered coordinates. Keys that are not
        present in the sample are skipped.

    Notes
    -----
    Dependent arrays are always re-filtered from the original sample, once
    per coordinate key in ``input_keys``. When several coordinate keys are
    present, the mask of the last one processed determines the dependent
    outputs.

    Examples
    --------
    >>> transform = BoundingBoxFilter(
    ...     input_keys=["volume_mesh_centers"],
    ...     bbox_min=torch.tensor([-1.0, -1.0, -1.0]),
    ...     bbox_max=torch.tensor([1.0, 1.0, 1.0]),
    ...     dependent_keys=["volume_fields", "sdf_nodes"]
    ... )
    >>> sample = TensorDict({
    ...     "volume_mesh_centers": torch.randn(10000, 3) * 2,  # Some outside bbox
    ...     "volume_fields": torch.randn(10000, 4)
    ... })
    >>> result = transform(sample)
    >>> # Only points within bbox remain
    """

    def __init__(
        self,
        input_keys: list[str],
        bbox_min: torch.Tensor | list[float] | tuple[float, ...],
        bbox_max: torch.Tensor | list[float] | tuple[float, ...],
        *,
        dependent_keys: Optional[list[str]] = None,
    ) -> None:
        """
        Initialize the bounding box filter transform.

        Parameters
        ----------
        input_keys : list[str]
            List of coordinate tensor keys to filter.
        bbox_min : torch.Tensor or sequence of float
            Minimum corner of bounding box, shape :math:`(3,)`.
        bbox_max : torch.Tensor or sequence of float
            Maximum corner of bounding box, shape :math:`(3,)`.
        dependent_keys : list[str], optional
            Optional list of keys to filter using the same mask. These
            maintain correspondence with the filtered coordinates.
        """
        super().__init__()
        self.input_keys = input_keys
        # torch.as_tensor returns an existing tensor unchanged, so tensor
        # arguments behave exactly as before; plain lists/tuples are converted
        # here instead of failing later on ``.to(...)`` in ``__call__``.
        self.bbox_min = torch.as_tensor(bbox_min)
        self.bbox_max = torch.as_tensor(bbox_max)
        self.dependent_keys = dependent_keys or []

    def __call__(self, data: TensorDict) -> TensorDict:
        """
        Apply bounding box filtering to the sample.

        Parameters
        ----------
        data : TensorDict
            Input TensorDict containing coordinate and dependent data.

        Returns
        -------
        TensorDict
            TensorDict with filtered points.
        """
        updates = {}
        for coord_key in self.input_keys:
            if coord_key not in data:
                continue
            coords = data[coord_key]

            # Move the bounds next to the coordinates before comparing.
            bbox_min = self.bbox_min.to(coords.device)
            bbox_max = self.bbox_max.to(coords.device)

            # Strict comparisons: points on the box boundary are excluded.
            # A point is kept only if all of its coordinates are inside.
            inside = ((coords > bbox_min) & (coords < bbox_max)).all(dim=-1)

            updates[coord_key] = coords[inside]

            # Dependent arrays are indexed from the original sample so they
            # stay row-aligned with the filtered coordinates.
            for dep_key in self.dependent_keys:
                if dep_key in data:
                    updates[dep_key] = data[dep_key][inside]

        return data.update(updates)

    def __repr__(self) -> str:
        """
        Return string representation.

        Returns
        -------
        str
            String representation of the transform.
        """
        return (
            f"BoundingBoxFilter(input_keys={self.input_keys}, "
            f"dependent_keys={self.dependent_keys})"
        )


@register()
class CreateGrid(Transform):
    r"""
    Create a regular 3D spatial grid.

    Generates a uniform grid spanning a bounding box, used for latent space
    representations, interpolation grids, or structured spatial queries.
    Both corners of the box are included as grid points along each axis.

    The grid is returned as a flat array of points of shape
    :math:`(n_x n_y n_z, 3)`, ordered with ``"ij"`` (matrix) indexing:
    the x index varies slowest and the z index fastest.

    Parameters
    ----------
    output_key : str
        Key to store the generated grid.
    resolution : tuple[int, int, int]
        Grid resolution as (nx, ny, nz).
    bbox_min : torch.Tensor or sequence of float
        Minimum corner of bounding box, shape :math:`(3,)`. Non-tensor
        sequences are converted with :func:`torch.as_tensor`.
    bbox_max : torch.Tensor or sequence of float
        Maximum corner of bounding box, shape :math:`(3,)`. Non-tensor
        sequences are converted with :func:`torch.as_tensor`.

    Examples
    --------
    >>> transform = CreateGrid(
    ...     output_key="grid",
    ...     resolution=(64, 64, 64),
    ...     bbox_min=torch.tensor([-1.0, -1.0, -1.0]),
    ...     bbox_max=torch.tensor([1.0, 1.0, 1.0])
    ... )
    >>> sample = TensorDict({})
    >>> result = transform(sample)
    >>> print(result["grid"].shape)
    torch.Size([262144, 3])
    """

    def __init__(
        self,
        output_key: str,
        resolution: tuple[int, int, int],
        bbox_min: torch.Tensor | list[float] | tuple[float, ...],
        bbox_max: torch.Tensor | list[float] | tuple[float, ...],
    ) -> None:
        """
        Initialize the grid creation transform.

        Parameters
        ----------
        output_key : str
            Key to store the generated grid.
        resolution : tuple[int, int, int]
            Grid resolution as (nx, ny, nz).
        bbox_min : torch.Tensor or sequence of float
            Minimum corner of bounding box, shape :math:`(3,)`.
        bbox_max : torch.Tensor or sequence of float
            Maximum corner of bounding box, shape :math:`(3,)`.
        """
        super().__init__()
        self.output_key = output_key
        self.resolution = resolution
        # torch.as_tensor leaves tensors untouched and converts lists/tuples,
        # so ``.to(device)`` in ``__call__`` works for either input form.
        self.bbox_min = torch.as_tensor(bbox_min)
        self.bbox_max = torch.as_tensor(bbox_max)

    def __call__(self, data: TensorDict) -> TensorDict:
        """
        Create grid and add to sample.

        The grid is created on the sample's device, or on CPU when the
        TensorDict has no device.

        Parameters
        ----------
        data : TensorDict
            Input TensorDict.

        Returns
        -------
        TensorDict
            TensorDict with generated grid added under ``output_key``.
        """
        device = data.device if data.device is not None else torch.device("cpu")

        bbox_min = self.bbox_min.to(device)
        bbox_max = self.bbox_max.to(device)
        nx, ny, nz = self.resolution

        # 1D sample positions per axis; linspace includes both endpoints.
        x = torch.linspace(bbox_min[0], bbox_max[0], nx, device=device)
        y = torch.linspace(bbox_min[1], bbox_max[1], ny, device=device)
        z = torch.linspace(bbox_min[2], bbox_max[2], nz, device=device)

        # "ij" indexing keeps the (nx, ny, nz) axis order; flattening each
        # component in the same order keeps the three coordinates aligned.
        xv, yv, zv = torch.meshgrid(x, y, z, indexing="ij")
        grid = torch.stack([xv.flatten(), yv.flatten(), zv.flatten()], dim=-1)

        return data.update({self.output_key: grid})

    def __repr__(self) -> str:
        """
        Return string representation.

        Returns
        -------
        str
            String representation of the transform.
        """
        return f"CreateGrid(output_key={self.output_key}, resolution={self.resolution})"


@register()
class KNearestNeighbors(Transform):
    r"""
    Compute k-nearest neighbors in a point cloud.

    Finds the k nearest neighbors for each query point and extracts
    corresponding coordinates and other attributes. Useful for local feature
    aggregation in mesh networks and spatial interpolation.

    The following keys are written to the sample, where ``<prefix>`` is
    ``output_prefix``:

    - ``<prefix>_indices`` and ``<prefix>_distances``: the output of ``knn``,
      always with all ``k`` neighbor columns, regardless of
      ``drop_first_neighbor``.
    - ``<prefix>_coords``: the coordinates of the neighbors, gathered from
      ``points``.
    - ``<prefix>_<key>``: the gathered values of each ``key`` in
      ``extract_keys`` that is present in the sample. Missing keys are
      skipped silently.

    Parameters
    ----------
    points_key : str
        Key for reference points to search, shape :math:`(N, 3)`.
    queries_key : str
        Key for query points, shape :math:`(M, 3)`.
    k : int
        Number of nearest neighbors to find.
    output_prefix : str, default="neighbors"
        Prefix for output keys.
    extract_keys : list[str], optional
        Optional list of keys to extract for neighbors (e.g.,
        ``["normals", "areas"]``). If None, only extracts coordinates.
    drop_first_neighbor : bool, default=False
        If True, discard the first neighbor column before gathering
        coordinates and extracted features, leaving ``k - 1`` neighbors per
        query in those outputs. This is intended for queries drawn from
        ``points``, where the first neighbor is the query itself. It assumes
        ``knn`` returns neighbors nearest-first.

    Examples
    --------
    >>> transform = KNearestNeighbors(
    ...     points_key="surface_mesh_centers",
    ...     queries_key="surface_mesh_centers_subsampled",
    ...     k=11,
    ...     output_prefix="surface_neighbors",
    ...     extract_keys=["surface_normals", "surface_areas"]
    ... )
    >>> sample = TensorDict({
    ...     "surface_mesh_centers": torch.randn(10000, 3),
    ...     "surface_mesh_centers_subsampled": torch.randn(1000, 3),
    ...     "surface_normals": torch.randn(10000, 3),
    ...     "surface_areas": torch.rand(10000)
    ... })
    >>> result = transform(sample)
    >>> # Creates: surface_neighbors_indices, surface_neighbors_distances,
    >>> # surface_neighbors_coords, surface_neighbors_surface_normals,
    >>> # surface_neighbors_surface_areas
    """

    def __init__(
        self,
        points_key: str,
        queries_key: str,
        k: int,
        *,
        output_prefix: str = "neighbors",
        extract_keys: Optional[list[str]] = None,
        drop_first_neighbor: bool = False,
    ) -> None:
        """
        Initialize the k-NN transform.

        Parameters
        ----------
        points_key : str
            Key for reference points to search, shape :math:`(N, 3)`.
        queries_key : str
            Key for query points, shape :math:`(M, 3)`.
        k : int
            Number of nearest neighbors to find.
        output_prefix : str, default="neighbors"
            Prefix for output keys.
        extract_keys : list[str], optional
            Optional list of keys to extract for neighbors (e.g.,
            ``["normals", "areas"]``). If None, only extracts coordinates.
        drop_first_neighbor : bool, default=False
            If True, drop the first neighbor column from the gathered
            coordinates and features. Indices and distances keep all ``k``
            columns.
        """
        super().__init__()
        self.points_key = points_key
        self.queries_key = queries_key
        self.k = k
        self.output_prefix = output_prefix
        self.extract_keys = extract_keys or []
        self.drop_first_neighbor = drop_first_neighbor

    def __call__(self, data: TensorDict) -> TensorDict:
        """
        Compute k-NN and extract neighbor features.

        Parameters
        ----------
        data : TensorDict
            Input TensorDict containing points and queries.

        Returns
        -------
        TensorDict
            TensorDict with neighbor indices, distances, and features added.

        Raises
        ------
        KeyError
            If points or queries keys are not found in the data.
        """
        if self.points_key not in data:
            raise KeyError(f"Points key '{self.points_key}' not found")
        if self.queries_key not in data:
            raise KeyError(f"Queries key '{self.queries_key}' not found")

        points = data[self.points_key]
        queries = data[self.queries_key]

        neighbor_indices, neighbor_distances = knn(
            points=points,
            queries=queries,
            k=self.k,
        )

        # Indices and distances are stored unmodified (all k columns).
        updates = {
            f"{self.output_prefix}_indices": neighbor_indices,
            f"{self.output_prefix}_distances": neighbor_distances,
        }

        # Slice the indices *before* gathering, so the dropped column (the
        # query itself, when queries come from points) is never materialized.
        # Assumes neighbor indices are laid out as (M, k) - TODO confirm
        # against the knn implementation.
        gather_indices = (
            neighbor_indices[..., 1:] if self.drop_first_neighbor else neighbor_indices
        )

        updates[f"{self.output_prefix}_coords"] = points[gather_indices]

        for key in self.extract_keys:
            if key in data:
                updates[f"{self.output_prefix}_{key}"] = data[key][gather_indices]

        return data.update(updates)

    def __repr__(self) -> str:
        """
        Return string representation.

        Returns
        -------
        str
            String representation of the transform.
        """
        return (
            f"KNearestNeighbors(points_key={self.points_key}, "
            f"queries_key={self.queries_key}, k={self.k})"
        )


@register()
class CenterOfMass(Transform):
    r"""
    Compute the (optionally area-weighted) center of mass of a point cloud.

    When ``areas_key`` is given, each point is weighted by its area, which
    suits mesh data where each point represents a cell with a specific area.
    When ``areas_key`` is None, the unweighted mean of the coordinates is
    returned.

    Parameters
    ----------
    coords_key : str
        Key for coordinates, shape :math:`(N, 3)`.
    output_key : str
        Key to store the computed center of mass, shape :math:`(3,)`.
    areas_key : str, optional
        Keyword-only key for per-point area weights, shape :math:`(N,)` or
        :math:`(N, 1)`. If None (the default), points are weighted equally.

    Examples
    --------
    >>> transform = CenterOfMass(
    ...     coords_key="stl_centers",
    ...     areas_key="stl_areas",
    ...     output_key="center_of_mass"
    ... )
    >>> sample = TensorDict({
    ...     "stl_centers": torch.randn(5000, 3),
    ...     "stl_areas": torch.rand(5000)
    ... })
    >>> result = transform(sample)
    >>> print(result["center_of_mass"].shape)
    torch.Size([3])
    """

    def __init__(
        self,
        coords_key: str,
        output_key: str,
        *,
        areas_key: str | None = None,
    ) -> None:
        """
        Initialize the center of mass transform.

        Parameters
        ----------
        coords_key : str
            Key for coordinates, shape :math:`(N, 3)`.
        output_key : str
            Key to store the computed center of mass, shape :math:`(3,)`.
        areas_key : str, optional
            Key for area weights, shape :math:`(N,)` or :math:`(N, 1)`.
            If None, the unweighted mean of the coordinates is used.
        """
        super().__init__()
        self.coords_key = coords_key
        self.areas_key = areas_key
        self.output_key = output_key

    def __call__(self, data: TensorDict) -> TensorDict:
        """
        Compute center of mass for the sample.

        Parameters
        ----------
        data : TensorDict
            Input TensorDict containing coordinates and, if ``areas_key`` is
            set, area weights.

        Returns
        -------
        TensorDict
            TensorDict with the center of mass, reduced over the point
            dimension (dim 0), stored under ``output_key``.

        Raises
        ------
        KeyError
            If the coordinates key, or the areas key when one is configured,
            is not found in the data.
        """
        if self.coords_key not in data:
            raise KeyError(f"Coordinates key '{self.coords_key}' not found")
        coords = data[self.coords_key]

        if self.areas_key is None:
            center_of_mass = coords.mean(dim=0)
        else:
            if self.areas_key not in data:
                raise KeyError(f"Areas key '{self.areas_key}' not found")
            # Reshape to a column so that both (N,) and (N, 1) weights
            # broadcast row-wise against (N, 3) coordinates. Unsqueezing an
            # (N, 1) tensor would give (N, 1, 1) and broadcast to (N, N, 3).
            weights = data[self.areas_key].reshape(-1, 1)
            center_of_mass = (coords * weights).sum(dim=0) / weights.sum()

        return data.update({self.output_key: center_of_mass})

    def __repr__(self) -> str:
        """
        Return string representation.

        Returns
        -------
        str
            String representation of the transform.
        """
        return (
            f"CenterOfMass(coords_key={self.coords_key}, output_key={self.output_key})"
        )