# Source code for physicsnemo.datapipes.transforms.subsample

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Subsampling transforms for point clouds and surfaces.

Provides efficient subsampling methods for large datasets, including
Poisson-process (exponential-gap) index sampling and weighted sampling.
"""

from __future__ import annotations

from typing import Literal, Optional

import torch
from tensordict import TensorDict

from physicsnemo.datapipes.registry import register
from physicsnemo.datapipes.transforms.base import Transform


def poisson_sample_indices_fixed(
    N: int,
    k: int,
    device=None,
    generator: torch.Generator | None = None,
    *,
    replacement: bool = False,
) -> torch.Tensor:
    """
    Near-uniform sampler of indices for very large arrays.

    This function provides nearly uniform sampling for cases where the number
    of indices is very large (> 2^24) and :func:`torch.multinomial` cannot work.
    Unlike using :func:`torch.randperm`, there is no need to materialize and
    randomize the entire tensor of indices.

    The sampling draws ``k`` exponentially distributed gaps, rescales them to
    span ``[0, N)``, accumulates them and floors the cumulative positions.
    The returned indices are therefore in ascending order (they are not
    shuffled). Two modes are available:

    - ``replacement=False`` (default): each gap is constrained to be at least
      one index unit, so the resulting indices are strictly increasing and
      therefore unique. Requires ``k < N`` strictly.
    - ``replacement=True``: raw exponential gaps are used. The gaps can be
      arbitrarily small, so consecutive indices may collide after flooring,
      i.e. duplicates are possible.

    Parameters
    ----------
    N : int
        Total number of available indices.
    k : int
        Number of indices to sample.
    device : torch.device, optional
        Device for the output tensor.
    generator : torch.Generator, optional
        Random generator for reproducibility. All random draws are taken
        from this generator when it is provided.
    replacement : bool, keyword-only, default=False
        If ``False``, sample without replacement (no duplicate indices). If
        ``True``, sample with replacement (duplicates possible).

    Returns
    -------
    torch.Tensor
        Integer (``torch.long``) tensor of shape :math:`(k,)` containing the
        sampled indices in ascending order. Empty when ``k == 0``.

    Raises
    ------
    ValueError
        If ``replacement=False`` and ``k >= N``, since sampling ``k`` unique
        indices from ``N`` requires ``k < N``; or if ``N <= 0`` while
        ``k > 0``, since there is nothing to sample from.

    Examples
    --------
    >>> indices = poisson_sample_indices_fixed(1000000, 10000)
    >>> print(indices.shape)
    torch.Size([10000])
    """
    if not replacement and k >= N:
        raise ValueError(
            f"poisson_sample_indices_fixed requires k < N when "
            f"replacement=False, but got k={k} and N={N}."
        )

    if k == 0:
        # Nothing requested: avoid indexing gaps[0] on an empty tensor.
        return torch.empty(0, dtype=torch.long, device=device)

    if N <= 0:
        # Only reachable with replacement=True; clamping to max=N-1 would
        # otherwise hand back negative indices.
        raise ValueError(
            f"poisson_sample_indices_fixed requires N > 0, but got N={N}."
        )

    # Draw the exponential gaps directly from ``generator``. Filling a tensor
    # with torch.rand(generator=...) and then calling .exponential_() without
    # a generator would overwrite it using the global RNG and silently ignore
    # the caller's generator.
    #
    # float64 is used in both modes: float32 cannot represent every integer
    # above 2^24, which is exactly the regime this sampler targets, so float32
    # cumulative positions would quantize (and, without replacement, collide).
    gaps = torch.empty(k, device=device, dtype=torch.float64).exponential_(
        generator=generator
    )

    if replacement:
        # Normalize so the total cumulative sum == N.
        gaps *= N / gaps.sum()
        # Shift by half the first gap so the range starts near 0 and ends
        # below N.
        shift = gaps[0] / 2
    else:
        # Scale so the raw exponential portion sums to (N - k); adding 1 to
        # every gap then makes the minimum gap >= 1 while keeping the total
        # sum == N.
        gaps *= (N - k) / gaps.sum()
        gaps += 1.0
        # Shift by the whole first gap so the first position is exactly 0.
        shift = gaps[0]

    # Cumulative positions, shifted down towards 0.
    positions = torch.cumsum(gaps, dim=0)
    positions -= shift

    # Floor to integer indices and clamp into [0, N - 1] as a safety net.
    # Without replacement, successive positions differ by at least 1, so
    # their floors are strictly increasing and hence unique.
    return torch.clamp(positions.floor().long(), min=0, max=N - 1)


def shuffle_array(
    points: torch.Tensor,
    n_points: int,
    weights: Optional[torch.Tensor] = None,
    generator: torch.Generator | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Sample points with or without weights.

    The index selection strategy depends on the inputs:

    - If ``points`` has fewer than ``n_points`` rows, all rows are returned
      unchanged together with ``arange(N)``.
    - If ``weights`` is given, indices are drawn without replacement with
      :func:`torch.multinomial`.
    - Otherwise, for more than 2^24 rows (and ``n_points < N``) indices come
      from :func:`poisson_sample_indices_fixed`; in every other case the
      first ``n_points`` entries of a random permutation are used.

    Parameters
    ----------
    points : torch.Tensor
        Input tensor to sample from, shape :math:`(N, ...)`.
    n_points : int
        Number of points to sample.
    weights : torch.Tensor, optional
        Optional weights for sampling, shape :math:`(N,)`.
        If None, uses uniform sampling.
    generator : torch.Generator, optional
        Random generator for reproducibility.

    Returns
    -------
    sampled_points : torch.Tensor
        Sampled tensor, shape :math:`(n\\_points, ...)`.
    indices : torch.Tensor
        Selected indices, shape :math:`(n\\_points,)`.
    """
    N = points.shape[0]
    device = points.device

    if N < n_points:
        # Not enough points to subsample: hand back everything, in order.
        return points, torch.arange(N, device=device)

    if weights is not None:
        # Weighted sampling without replacement.
        indices = torch.multinomial(
            weights, n_points, replacement=False, generator=generator
        )
    elif N > 2**24 and n_points < N:
        # Very large arrays: avoid materializing a full permutation.
        # The guard on n_points is required because the without-replacement
        # sampler rejects k >= N; n_points == N falls through to randperm.
        indices = poisson_sample_indices_fixed(
            N, n_points, device=device, generator=generator
        )
    else:
        # Uniform sampling: prefix of a random permutation.
        indices = torch.randperm(N, device=device, generator=generator)[:n_points]

    return points[indices], indices


@register()
class SubsamplePoints(Transform):
    r"""
    Subsample points from large point clouds or meshes.

    This transform applies coordinated subsampling to multiple tensor fields,
    ensuring that the same points are selected across all specified keys.
    Useful for downsampling large volumetric data or point clouds while
    maintaining correspondence between coordinates and field values.

    Supports two sampling algorithms:

    - ``"poisson_fixed"``: Near-uniform sampling for very large datasets
      (> 2^24 points). For inputs with at most 2^24 points this setting
      falls back to the same random-permutation sampling as ``"uniform"``.
    - ``"uniform"``: Standard uniform sampling

    Optionally supports weighted sampling (e.g., area-weighted for surface
    meshes) by providing a ``weights_key``. When weights are provided they
    take precedence over ``algorithm``.

    Parameters
    ----------
    input_keys : list[str]
        List of tensor keys to subsample. All must have the same first
        dimension size.
    n_points : int
        Number of points to sample.
    algorithm : {"poisson_fixed", "uniform"}, default="poisson_fixed"
        Sampling algorithm to use.
    weights_key : str, optional
        Optional key for sampling weights (e.g., ``"surface_areas"`` for
        area-weighted surface sampling). When provided, samples are drawn
        according to the weights distribution.

    Examples
    --------
    Uniform sampling:

    >>> transform = SubsamplePoints(
    ...     input_keys=["volume_mesh_centers", "volume_fields"],
    ...     n_points=10000,
    ...     algorithm="poisson_fixed"
    ... )
    >>> sample = TensorDict({
    ...     "volume_mesh_centers": torch.randn(100000, 3),
    ...     "volume_fields": torch.randn(100000, 5)
    ... })
    >>> result = transform(sample)
    >>> print(result["volume_mesh_centers"].shape)
    torch.Size([10000, 3])

    Weighted sampling:

    >>> transform = SubsamplePoints(
    ...     input_keys=["surface_mesh_centers", "surface_fields", "surface_normals"],
    ...     n_points=5000,
    ...     algorithm="uniform",
    ...     weights_key="surface_areas"
    ... )
    >>> sample = TensorDict({
    ...     "surface_mesh_centers": torch.randn(20000, 3),
    ...     "surface_fields": torch.randn(20000, 2),
    ...     "surface_normals": torch.randn(20000, 3),
    ...     "surface_areas": torch.rand(20000)
    ... })
    >>> result = transform(sample)
    >>> print(result["surface_mesh_centers"].shape)
    torch.Size([5000, 3])

    Notes
    -----
    All specified keys must have the same size in their first dimension.
    The same indices are applied to all keys to maintain correspondence.

    Only the keys listed in ``input_keys`` are subsampled. The weights tensor
    itself is left untouched unless ``weights_key`` is also listed in
    ``input_keys``.
    """

    def __init__(
        self,
        input_keys: list[str],
        n_points: int,
        *,
        algorithm: Literal["poisson_fixed", "uniform"] = "poisson_fixed",
        weights_key: Optional[str] = None,
    ) -> None:
        """
        Initialize the subsample transform.

        Parameters
        ----------
        input_keys : list[str]
            List of tensor keys to subsample. All must have the same first
            dimension size.
        n_points : int
            Number of points to sample.
        algorithm : {"poisson_fixed", "uniform"}, default="poisson_fixed"
            Sampling algorithm to use.
        weights_key : str, optional
            Optional key for sampling weights (e.g., ``"surface_areas"`` for
            area-weighted surface sampling). When provided, samples are drawn
            according to the weights distribution.
        """
        super().__init__()
        self.input_keys = input_keys
        self.n_points = n_points
        self.algorithm = algorithm
        self.weights_key = weights_key
        # Passed to every random draw in __call__. This class only ever sets
        # it to None; presumably the surrounding pipeline assigns a seeded
        # generator when reproducibility is needed (TODO confirm).
        self._generator: torch.Generator | None = None

    def __call__(self, data: TensorDict) -> TensorDict:
        """
        Apply subsampling to the TensorDict.

        One set of indices is drawn and applied along the first dimension of
        every key in ``input_keys``. The subsampled tensors are written back
        with ``data.update(...)`` and the result of that call is returned.
        ``data`` is returned as-is when ``input_keys`` is empty or when the
        tensors already hold ``n_points`` points or fewer.

        Parameters
        ----------
        data : TensorDict
            Input TensorDict containing fields to subsample.

        Returns
        -------
        TensorDict
            TensorDict with subsampled fields.

        Raises
        ------
        KeyError
            If a required key (input key or weights key) is not found in the
            data.
        ValueError
            If keys have inconsistent first dimension sizes, or if the
            weights tensor's first dimension differs from that of the input
            keys.
        """
        if not self.input_keys:
            return data

        # Check that all keys are present. Membership is tested against
        # data.keys() as in the rest of this transform.
        for key in self.input_keys:
            if key not in data.keys():
                raise KeyError(
                    f"Key '{key}' not found in data. "
                    f"Available keys: {list(data.keys())}"
                )

        # The first key defines the number of points N.
        first_key = self.input_keys[0]
        first_tensor = data[first_key]
        N = first_tensor.shape[0]

        # Check that all keys have the same first dimension
        for key in self.input_keys[1:]:
            if data[key].shape[0] != N:
                raise ValueError(
                    f"All keys must have the same first dimension. "
                    f"Key '{first_key}' has {N}, but '{key}' has {data[key].shape[0]}"
                )

        # Skip if already fewer points than requested
        if N <= self.n_points:
            return data

        # Get weights if provided
        weights = None
        if self.weights_key is not None:
            if self.weights_key not in data.keys():
                raise KeyError(
                    f"Weights key '{self.weights_key}' not found in data. "
                    f"Available keys: {list(data.keys())}"
                )
            weights = data[self.weights_key]
            # A weights tensor of a different length would either restrict
            # sampling to a prefix of the points or yield out-of-range
            # indices, so reject it up front.
            if weights.shape[0] != N:
                raise ValueError(
                    f"Weights key '{self.weights_key}' must have the same first "
                    f"dimension as the input keys. Key '{first_key}' has {N}, "
                    f"but '{self.weights_key}' has {weights.shape[0]}"
                )

        # Sample indices. N > n_points is guaranteed from here on.
        device = first_tensor.device
        if weights is not None:
            # Weighted sampling without replacement. Only the indices are
            # needed here; the gather happens once per key below.
            indices = torch.multinomial(
                weights,
                self.n_points,
                replacement=False,
                generator=self._generator,
            )
        elif self.algorithm == "poisson_fixed" and N > 2**24:
            indices = poisson_sample_indices_fixed(
                N,
                self.n_points,
                device=device,
                generator=self._generator,
                replacement=False,
            )
        else:
            # Use uniform sampling: prefix of a random permutation.
            indices = torch.randperm(
                N,
                device=device,
                generator=self._generator,
            )[: self.n_points]

        # Apply the same indices to all keys to keep them in correspondence.
        updates = {key: data[key][indices] for key in self.input_keys}

        return data.update(updates)

    def __repr__(self) -> str:
        """
        Return string representation.

        Returns
        -------
        str
            String representation of the transform. ``weights_key`` is only
            included when it is set.
        """
        weights_str = f", weights_key={self.weights_key}" if self.weights_key else ""
        return (
            f"SubsamplePoints(input_keys={self.input_keys}, n_points={self.n_points}, "
            f"algorithm={self.algorithm}{weights_str})"
        )