Source code for physicsnemo.datapipes.transforms.geometric

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Geometric transforms for spatial data processing.

Provides transforms for computing signed distance fields, normals,
and applying spatial invariances (translation, scaling).
"""

from __future__ import annotations

from typing import Optional

import torch
from tensordict import TensorDict

from physicsnemo.datapipes.registry import register
from physicsnemo.datapipes.transforms.base import Transform
from physicsnemo.nn.functional import signed_distance_field


[docs] @register() class ComputeSDF(Transform): r""" Compute signed distance field from a mesh. Computes the signed distance from query points to the nearest point on a triangular mesh surface. Optionally returns the closest points on the mesh surface for each query point. Parameters ---------- input_keys : list[str] List of keys containing query points to compute SDF for. Each tensor should have shape :math:`(N, 3)`. output_key : str Key to store the computed SDF values. mesh_coords_key : str Key for mesh vertex coordinates, shape :math:`(M, 3)`. mesh_faces_key : str Key for mesh face indices (flattened), shape :math:`(F*3,)`. use_winding_number : bool, default=True If True, use winding number for sign determination. closest_points_key : str, optional Optional key to store closest points on mesh. Examples -------- >>> transform = ComputeSDF( ... input_keys=["volume_mesh_centers"], ... output_key="sdf_nodes", ... mesh_coords_key="stl_coordinates", ... mesh_faces_key="stl_faces", ... closest_points_key="closest_points" ... ) >>> sample = TensorDict({ ... "volume_mesh_centers": torch.randn(10000, 3), ... "stl_coordinates": torch.randn(5000, 3), ... "stl_faces": torch.randint(0, 5000, (10000,)) ... }) >>> result = transform(sample) >>> print(result["sdf_nodes"].shape) torch.Size([10000, 1]) """ def __init__( self, input_keys: list[str], output_key: str, mesh_coords_key: str, mesh_faces_key: str, *, use_winding_number: bool = True, closest_points_key: Optional[str] = None, ) -> None: """ Initialize the SDF computation transform. Parameters ---------- input_keys : list[str] List of keys containing query points to compute SDF for. Each tensor should have shape :math:`(N, 3)`. output_key : str Key to store the computed SDF values. mesh_coords_key : str Key for mesh vertex coordinates, shape :math:`(M, 3)`. mesh_faces_key : str Key for mesh face indices (flattened), shape :math:`(F*3,)`. use_winding_number : bool, default=True If True, use winding number for sign determination. closest_points_key : str, optional Optional key to store closest points on mesh. """ super().__init__() self.input_keys = input_keys self.output_key = output_key self.mesh_coords_key = mesh_coords_key self.mesh_faces_key = mesh_faces_key self.use_winding_number = use_winding_number self.closest_points_key = closest_points_key def __call__(self, data: TensorDict) -> TensorDict: """ Compute SDF for the sample. Parameters ---------- data : TensorDict Input TensorDict containing mesh and query point data. Returns ------- TensorDict TensorDict with computed SDF values added. Raises ------ KeyError If mesh or query point keys are not found in the data. """ # Get mesh data if self.mesh_coords_key not in data: raise KeyError(f"Mesh coordinates key '{self.mesh_coords_key}' not found") if self.mesh_faces_key not in data: raise KeyError(f"Mesh faces key '{self.mesh_faces_key}' not found") mesh_coords = data[self.mesh_coords_key] mesh_faces = data[self.mesh_faces_key].to(torch.int32) updates = {} # Compute SDF for each input key for key in self.input_keys: if key not in data: raise KeyError(f"Input key '{key}' not found") query_points = data[key] # Compute SDF and closest points sdf, closest_points = signed_distance_field( mesh_coords, mesh_faces, query_points, use_sign_winding_number=self.use_winding_number, ) # Store SDF with output key (add suffix if multiple inputs) if len(self.input_keys) == 1: updates[self.output_key] = sdf.reshape(-1, 1) if self.closest_points_key is not None: updates[self.closest_points_key] = closest_points else: suffix = f"_{key}" updates[f"{self.output_key}{suffix}"] = sdf.reshape(-1, 1) if self.closest_points_key is not None: updates[f"{self.closest_points_key}{suffix}"] = closest_points return data.update(updates) def __repr__(self) -> str: """ Return string representation. Returns ------- str String representation of the transform. """ return f"ComputeSDF(input_keys={self.input_keys}, output_key={self.output_key})"
[docs] @register() class ComputeNormals(Transform): r""" Compute normal vectors from closest points. Computes normalized direction vectors from query points to their closest points on a surface. Handles zero-distance edge cases by falling back to center of mass direction. Parameters ---------- positions_key : str Key for position tensor, shape :math:`(N, 3)`. closest_points_key : str Key for closest points tensor, shape :math:`(N, 3)`. center_of_mass_key : str Key for center of mass, shape :math:`(1, 3)` or :math:`(3,)`. output_key : str Key to store computed normals. handle_zero_distance : bool, default=True If True, use center_of_mass fallback for zero distances. Examples -------- >>> transform = ComputeNormals( ... positions_key="volume_mesh_centers", ... closest_points_key="closest_points", ... center_of_mass_key="center_of_mass", ... output_key="volume_normals" ... ) """ def __init__( self, positions_key: str, closest_points_key: str, center_of_mass_key: str, output_key: str, *, handle_zero_distance: bool = True, ) -> None: """ Initialize the normal computation transform. Parameters ---------- positions_key : str Key for position tensor, shape :math:`(N, 3)`. closest_points_key : str Key for closest points tensor, shape :math:`(N, 3)`. center_of_mass_key : str Key for center of mass, shape :math:`(1, 3)` or :math:`(3,)`. output_key : str Key to store computed normals. handle_zero_distance : bool, default=True If True, use center_of_mass fallback for zero distances. """ super().__init__() self.positions_key = positions_key self.closest_points_key = closest_points_key self.center_of_mass_key = center_of_mass_key self.output_key = output_key self.handle_zero_distance = handle_zero_distance def __call__(self, data: TensorDict) -> TensorDict: """ Compute normals for the sample. Parameters ---------- data : TensorDict Input TensorDict containing position and closest point data. Returns ------- TensorDict TensorDict with computed normals added. """ positions = data[self.positions_key] closest_points = data[self.closest_points_key] center_of_mass = data[self.center_of_mass_key] # Ensure center_of_mass has shape (1, 3) if center_of_mass.ndim == 1: center_of_mass = center_of_mass.unsqueeze(0) # Compute initial normals normals = positions - closest_points if self.handle_zero_distance: # Handle zero-distance points (on or very close to surface) distance_to_closest = torch.norm(normals, dim=-1) null_points = distance_to_closest < 1e-6 # For null points, use direction from center of mass if null_points.any(): normals[null_points] = positions[null_points] - center_of_mass # Normalize norm = torch.norm(normals, dim=-1, keepdim=True) + 1e-6 normals = normals / norm return data.update({self.output_key: normals}) def __repr__(self) -> str: """ Return string representation. Returns ------- str String representation of the transform. """ return ( f"ComputeNormals(positions_key={self.positions_key}, " f"output_key={self.output_key})" )
[docs] @register() class Translate(Transform): r""" Apply a translation by adding or subtracting a center point. By default, this will ADD the translation. But you can also use the subtract mode: this is particularly useful when composed with CenterOfMass: you can compute the CoM and apply a translation as a CoM subtraction to center the data at the origin. Parameters ---------- input_keys : list[str] List of position tensor keys to translate. center_key_or_value : str or torch.Tensor Either a key name (str) for a tensor in the sample, or a fixed tensor value to add/subtract. subtract : bool, default=False If False (default), ADD the translation (data + center). If True, SUBTRACT the translation (data - center). Use subtract=True when centering data around a reference point like center of mass. Examples -------- Add mode (default) - shift points by a fixed offset: >>> transform = Translate( ... input_keys=["positions"], ... center_key_or_value=torch.tensor([1.0, 2.0, 3.0]) ... ) >>> # result["positions"] = original + [1, 2, 3] Subtract mode - center points by subtracting center of mass: >>> transform = Translate( ... input_keys=["volume_mesh_centers", "surface_mesh_centers"], ... center_key_or_value="center_of_mass", ... subtract=True ... ) >>> # result["positions"] = original - center_of_mass """ def __init__( self, input_keys: list[str], center_key_or_value: str | torch.Tensor, *, subtract: bool = False, ) -> None: """ Initialize the translation transform. Parameters ---------- input_keys : list[str] List of position tensor keys to translate. center_key_or_value : str or torch.Tensor Either a key name (str) for a tensor in the sample, or a fixed tensor value to add/subtract. subtract : bool, default=False If False (default), ADD the translation (data + center). If True, SUBTRACT the translation (data - center). Use subtract=True when centering data around a reference point like center of mass. """ super().__init__() self.input_keys = input_keys self.center_key_or_value = center_key_or_value self.is_key = isinstance(center_key_or_value, str) self.subtract = subtract def __call__(self, data: TensorDict) -> TensorDict: """ Apply translation to the sample. Parameters ---------- data : TensorDict Input TensorDict containing position data. Returns ------- TensorDict TensorDict with translated positions. Raises ------ KeyError If the center key is not found in the data. TypeError If center_key_or_value is not a string or torch.Tensor. """ # Get center value if isinstance(self.center_key_or_value, str): if self.center_key_or_value not in data: raise KeyError(f"Center key '{self.center_key_or_value}' not found") center = data[self.center_key_or_value] else: if not isinstance(self.center_key_or_value, torch.Tensor): raise TypeError( f"center_key_or_value should be torch.Tensor but got {type(self.center_key_or_value)}" ) center = self.center_key_or_value # Move to same device as data if needed if data.device is not None and center.device != data.device: center = center.to(data.device) # Ensure center has shape (1, 3) or (1, D) if center.ndim == 1: center = center.unsqueeze(0) # Apply translation to all keys updates = {} for key in self.input_keys: if key in data: if self.subtract: updates[key] = data[key] - center else: updates[key] = data[key] + center return data.update(updates) def __repr__(self) -> str: """ Return string representation. Returns ------- str String representation of the transform. """ mode = "subtract" if self.subtract else "add" return ( f"Translate(input_keys={self.input_keys}, " f"center={self.center_key_or_value}, mode={mode})" )
[docs] @register() class Scale(Transform): r""" Apply a scale factor by multiplying or dividing by a reference scale. By default, this will MULTIPLY by the scale factor. But you can also use the divide mode: this is particularly useful for normalizing data to make the representation scale invariant (e.g., dividing by a characteristic length scale). Parameters ---------- input_keys : list[str] List of position tensor keys to scale. scale : torch.Tensor Scale factor tensor, shape :math:`(1, D)` or :math:`(D,)`. divide : bool, default=False If False (default), MULTIPLY by the scale (data * scale). If True, DIVIDE by the scale (data / scale). Use divide=True when normalizing data by a reference scale. Examples -------- Multiply mode (default) - scale up positions by 2x: >>> transform = Scale( ... input_keys=["positions"], ... scale=torch.tensor([2.0, 2.0, 2.0]) ... ) >>> # result["positions"] = original * [2, 2, 2] Divide mode - normalize by a reference scale: >>> transform = Scale( ... input_keys=["volume_mesh_centers", "geometry_coordinates"], ... scale=torch.tensor([1.0, 1.0, 1.0]), ... divide=True ... ) >>> # result["positions"] = original / scale """ def __init__( self, input_keys: list[str], scale: torch.Tensor, *, divide: bool = False, ) -> None: """ Initialize the scale transform. Parameters ---------- input_keys : list[str] List of position tensor keys to scale. scale : torch.Tensor Scale factor tensor, shape :math:`(1, D)` or :math:`(D,)`. divide : bool, default=False If False (default), MULTIPLY by the scale (data * scale). If True, DIVIDE by the scale (data / scale). Use divide=True when normalizing data by a reference scale. """ super().__init__() self.input_keys = input_keys self.scale = scale self.divide = divide def __call__(self, data: TensorDict) -> TensorDict: """ Apply scaling to the data. Parameters ---------- data : TensorDict Input TensorDict containing position data. Returns ------- TensorDict TensorDict with scaled positions. """ scale = self.scale # Ensure scale has batch dimension if scale.ndim == 1: scale = scale.unsqueeze(0) # Move scale to same device as data if needed if data.device is not None and scale.device != data.device: scale = scale.to(data.device) # Apply scaling to all keys updates = {} for key in self.input_keys: if key in data: if self.divide: updates[key] = data[key] / scale else: updates[key] = data[key] * scale return data.update(updates) def __repr__(self) -> str: """ Return string representation. Returns ------- str String representation of the transform. """ mode = "divide" if self.divide else "multiply" return ( f"Scale(input_keys={self.input_keys}, " f"scale_shape={self.scale.shape}, mode={mode})" )