# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Subsampling transforms for point clouds and surfaces.
Provides efficient subsampling methods for large datasets, including
near-uniform index sampling driven by exponentially distributed
(Poisson-process) gaps, and weighted sampling.
"""
from __future__ import annotations
from typing import Literal, Optional
import torch
from tensordict import TensorDict
from physicsnemo.datapipes.registry import register
from physicsnemo.datapipes.transforms.base import Transform
def poisson_sample_indices_fixed(
    N: int,
    k: int,
    device=None,
    generator: torch.Generator | None = None,
    *,
    replacement: bool = False,
) -> torch.Tensor:
    """
    Near-uniform sampler of indices for very large arrays.

    This function provides nearly uniform sampling for cases where the number
    of indices is very large (> 2^24) and :func:`torch.multinomial` cannot work.
    Unlike using :func:`torch.randperm`, there is no need to materialize and
    randomize the entire tensor of indices.

    The sampling draws ``k + 1`` exponentially distributed gaps and normalizes
    them by their total. The cumulative sums of the first ``k`` normalized gaps
    are distributed like ``k`` sorted uniform points, and the extra gap is the
    slack after the last point. As a result, neither end of the range is
    favoured, and a single draw (``k == 1``) is uniform. Two modes are
    available:

    - ``replacement=False`` (default): the ``k`` sorted positions are spread
      over ``N - k + 1`` integer slots, and the ``i``-th position is shifted
      by ``i``. The resulting indices are strictly increasing and therefore
      unique. Requires ``k < N`` strictly.
    - ``replacement=True``: the sorted positions are spread over ``[0, N)``
      and floored directly. Consecutive indices may collide, so duplicates
      are possible.

    All random numbers come from ``generator`` when one is given, so seeding
    it makes the result reproducible regardless of the global RNG state.

    Parameters
    ----------
    N : int
        Total number of available indices.
    k : int
        Number of indices to sample.
    device : torch.device, optional
        Device for the output tensor.
    generator : torch.Generator, optional
        Random generator for reproducibility.
    replacement : bool, keyword-only, default=False
        If ``False``, sample without replacement (no duplicate indices). If
        ``True``, sample with replacement (duplicates possible).

    Returns
    -------
    torch.Tensor
        Tensor of shape :math:`(k,)` and dtype ``torch.long`` containing
        sampled indices in ``[0, N)``. The indices are non-decreasing
        (strictly increasing when ``replacement=False``).

    Raises
    ------
    ValueError
        In any of these cases:

        - ``k < 0``;
        - ``replacement=False`` and ``k >= N``, since sampling ``k`` unique
          indices from ``N`` requires ``k < N``;
        - ``replacement=True``, ``k > 0`` and ``N <= 0`` (nothing to draw
          from).

    Examples
    --------
    >>> indices = poisson_sample_indices_fixed(1000000, 10000)
    >>> print(indices.shape)
    torch.Size([10000])
    """
    if k < 0:
        raise ValueError(
            f"poisson_sample_indices_fixed requires k >= 0, but got k={k}."
        )
    if not replacement and k >= N:
        raise ValueError(
            f"poisson_sample_indices_fixed requires k < N when "
            f"replacement=False, but got k={k} and N={N}."
        )
    if replacement and k > 0 and N <= 0:
        raise ValueError(
            f"poisson_sample_indices_fixed cannot draw k={k} indices from an "
            f"empty range (N={N})."
        )
    if k == 0:
        # Nothing to draw. Returning early also avoids indexing or
        # normalizing an empty set of gaps.
        return torch.empty(0, dtype=torch.long, device=device)

    # Draw k + 1 exponential gaps. Fill them with ``exponential_`` directly,
    # passing the generator, because ``exponential_`` ignores any values
    # already in the tensor and would otherwise use the global RNG.
    # float64 keeps integer resolution for N far beyond 2^24; float32 cannot
    # even represent every integer above 2^24.
    gaps = torch.empty(k + 1, device=device, dtype=torch.float64).exponential_(
        generator=generator
    )
    # Normalizing by the total of all k + 1 gaps keeps every cumulative
    # position strictly below the upper end of the range.
    total = gaps.sum()
    cumulative = torch.cumsum(gaps[:-1], dim=0)

    if replacement:
        # k sorted positions in [0, N). Flooring can map several positions
        # to the same index, which is allowed here. The clamp guards against
        # rounding up to N.
        positions = cumulative * (N / total)
        return torch.clamp(positions.floor().long(), min=0, max=N - 1)

    # Without replacement: spread the positions over N - k + 1 integer slots,
    # then add i to the i-th slot. Slots are non-decreasing and the added
    # offsets strictly increase, so the indices are strictly increasing. The
    # largest possible index is (N - k) + (k - 1) = N - 1.
    n_slots = N - k + 1
    slots = torch.clamp(
        (cumulative * (n_slots / total)).floor().long(), min=0, max=n_slots - 1
    )
    # cummax enforces monotonic slots even if a parallel cumsum (e.g. on GPU)
    # rounds one prefix slightly below its predecessor.
    slots = torch.cummax(slots, dim=0).values
    return slots + torch.arange(k, device=device)
def shuffle_array(
    points: torch.Tensor,
    n_points: int,
    weights: Optional[torch.Tensor] = None,
    generator: torch.Generator | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Sample points with or without weights.

    Parameters
    ----------
    points : torch.Tensor
        Input tensor to sample from, shape :math:`(N, ...)`.
    n_points : int
        Number of points to sample.
    weights : torch.Tensor, optional
        Optional weights for sampling, shape :math:`(N,)`.
        If None, uses uniform sampling. Weighted sampling goes through
        :func:`torch.multinomial` without replacement, which supports at most
        2^24 categories.
    generator : torch.Generator, optional
        Random generator for reproducibility.

    Returns
    -------
    sampled_points : torch.Tensor
        Sampled tensor, shape :math:`(n\\_points, ...)`. If ``N < n_points``,
        ``points`` itself is returned unchanged.
    indices : torch.Tensor
        Selected indices, shape :math:`(n\\_points,)`. If ``N < n_points``,
        this is ``arange(N)``.
    """
    N = points.shape[0]
    device = points.device

    if N < n_points:
        # Not enough points: return all points, in their original order.
        return points, torch.arange(N, device=device)

    if weights is not None:
        # Weighted sampling without replacement.
        indices = torch.multinomial(
            weights, n_points, replacement=False, generator=generator
        )
    elif N > 2**24 and n_points < N:
        # Very large array: avoid materializing a full permutation of N
        # indices. poisson_sample_indices_fixed requires k < N strictly, so the
        # n_points == N case falls through to randperm. That case must touch
        # every index anyway.
        indices = poisson_sample_indices_fixed(
            N, n_points, device=device, generator=generator
        )
    else:
        # Uniform sampling: take the first n_points of a random permutation.
        indices = torch.randperm(N, device=device, generator=generator)[:n_points]

    return points[indices], indices
@register()
class SubsamplePoints(Transform):
    r"""
    Subsample points from large point clouds or meshes.

    This transform applies coordinated subsampling to multiple tensor fields,
    ensuring that the same points are selected across all specified keys.
    Useful for downsampling large volumetric data or point clouds while
    maintaining correspondence between coordinates and field values.

    Supports two sampling algorithms, both without replacement:

    - ``"poisson_fixed"``: near-uniform exponential-gap sampling via
      :func:`poisson_sample_indices_fixed` for very large datasets
      (> 2^24 points). At or below that size it behaves like ``"uniform"``.
    - ``"uniform"``: standard uniform sampling from :func:`torch.randperm`.

    Optionally supports weighted sampling (e.g., area-weighted for surface meshes)
    by providing a ``weights_key``. When weights are used, ``algorithm`` is
    ignored and indices are drawn with :func:`torch.multinomial`.

    Parameters
    ----------
    input_keys : list[str]
        List of tensor keys to subsample. All must have the same
        first dimension size.
    n_points : int
        Number of points to sample.
    algorithm : {"poisson_fixed", "uniform"}, default="poisson_fixed"
        Sampling algorithm to use.
    weights_key : str, optional
        Optional key for sampling weights (e.g., ``"surface_areas"``
        for area-weighted surface sampling). When provided, samples
        are drawn according to the weights distribution.

    Raises
    ------
    ValueError
        If ``algorithm`` is not one of the supported values, or if
        ``n_points`` is negative.

    Examples
    --------
    Uniform sampling:

    >>> transform = SubsamplePoints(
    ...     input_keys=["volume_mesh_centers", "volume_fields"],
    ...     n_points=10000,
    ...     algorithm="poisson_fixed"
    ... )
    >>> sample = TensorDict({
    ...     "volume_mesh_centers": torch.randn(100000, 3),
    ...     "volume_fields": torch.randn(100000, 5)
    ... })
    >>> result = transform(sample)
    >>> print(result["volume_mesh_centers"].shape)
    torch.Size([10000, 3])

    Weighted sampling:

    >>> transform = SubsamplePoints(
    ...     input_keys=["surface_mesh_centers", "surface_fields", "surface_normals"],
    ...     n_points=5000,
    ...     algorithm="uniform",
    ...     weights_key="surface_areas"
    ... )
    >>> sample = TensorDict({
    ...     "surface_mesh_centers": torch.randn(20000, 3),
    ...     "surface_fields": torch.randn(20000, 2),
    ...     "surface_normals": torch.randn(20000, 3),
    ...     "surface_areas": torch.rand(20000)
    ... })
    >>> result = transform(sample)
    >>> print(result["surface_mesh_centers"].shape)
    torch.Size([5000, 3])

    Notes
    -----
    All specified keys must have the same size in their first dimension.
    The same indices are applied to all keys to maintain correspondence.
    If the inputs already have at most ``n_points`` points, the data is
    returned unchanged. :func:`torch.multinomial` cannot handle more than
    2^24 categories, so weighted sampling is limited to that many points.
    """

    # Above this many points, "poisson_fixed" avoids materializing a full
    # torch.randperm permutation of every index.
    _POISSON_THRESHOLD = 2**24
    # Accepted values of the ``algorithm`` argument.
    _ALGORITHMS = ("poisson_fixed", "uniform")

    def __init__(
        self,
        input_keys: list[str],
        n_points: int,
        *,
        algorithm: Literal["poisson_fixed", "uniform"] = "poisson_fixed",
        weights_key: Optional[str] = None,
    ) -> None:
        """
        Initialize the subsample transform.

        Parameters
        ----------
        input_keys : list[str]
            List of tensor keys to subsample. All must have the same
            first dimension size.
        n_points : int
            Number of points to sample.
        algorithm : {"poisson_fixed", "uniform"}, default="poisson_fixed"
            Sampling algorithm to use.
        weights_key : str, optional
            Optional key for sampling weights (e.g., ``"surface_areas"``
            for area-weighted surface sampling). When provided, samples
            are drawn according to the weights distribution.

        Raises
        ------
        ValueError
            If ``algorithm`` is not supported or ``n_points`` is negative.
        """
        super().__init__()
        # Reject unknown algorithms up front; otherwise a typo would silently
        # fall back to uniform sampling.
        if algorithm not in self._ALGORITHMS:
            raise ValueError(
                f"Unsupported algorithm '{algorithm}'. "
                f"Expected one of {list(self._ALGORITHMS)}."
            )
        if n_points < 0:
            raise ValueError(f"n_points must be non-negative, but got {n_points}.")
        self.input_keys = input_keys
        self.n_points = n_points
        self.algorithm = algorithm
        self.weights_key = weights_key
        self._generator: torch.Generator | None = None

    def __call__(self, data: TensorDict) -> TensorDict:
        """
        Apply subsampling to the TensorDict.

        Parameters
        ----------
        data : TensorDict
            Input TensorDict containing fields to subsample.

        Returns
        -------
        TensorDict
            TensorDict with subsampled fields. It is returned unchanged if
            ``input_keys`` is empty or there are at most ``n_points`` points.

        Raises
        ------
        KeyError
            If a required key (or the weights key) is not found in the data.
        ValueError
            If keys have inconsistent first dimension sizes, or if the weights
            do not provide exactly one value per point.
        """
        if not self.input_keys:
            return data

        # Check that all keys are present.
        for key in self.input_keys:
            if key not in data.keys():
                raise KeyError(
                    f"Key '{key}' not found in data. "
                    f"Available keys: {list(data.keys())}"
                )

        # The first key determines the number of points.
        first_key = self.input_keys[0]
        first_tensor = data[first_key]
        N = first_tensor.shape[0]

        # Check that all keys have the same first dimension.
        for key in self.input_keys[1:]:
            if data[key].shape[0] != N:
                raise ValueError(
                    f"All keys must have the same first dimension. "
                    f"Key '{first_key}' has {N}, but '{key}' has {data[key].shape[0]}"
                )

        # Skip if there are already no more points than requested.
        if N <= self.n_points:
            return data

        device = first_tensor.device
        if self.weights_key is not None:
            indices = self._sample_weighted(data, N, device)
        elif self.algorithm == "poisson_fixed" and N > self._POISSON_THRESHOLD:
            indices = poisson_sample_indices_fixed(
                N,
                self.n_points,
                device=device,
                generator=self._generator,
                replacement=False,
            )
        else:
            indices = torch.randperm(
                N,
                device=device,
                generator=self._generator,
            )[: self.n_points]

        # Apply the same indices to every key to keep fields aligned.
        return data.update({key: data[key][indices] for key in self.input_keys})

    def _sample_weighted(
        self, data: TensorDict, N: int, device: torch.device
    ) -> torch.Tensor:
        """
        Draw ``self.n_points`` distinct indices with probability proportional
        to the weights stored under ``self.weights_key``.

        Raises
        ------
        KeyError
            If the weights key is missing.
        ValueError
            If the weights do not contain exactly ``N`` values.
        """
        if self.weights_key not in data.keys():
            raise KeyError(
                f"Weights key '{self.weights_key}' not found in data. "
                f"Available keys: {list(data.keys())}"
            )
        weights = data[self.weights_key]
        # Weights of the wrong length would produce indices that do not match
        # the points (or that fall out of range), so fail clearly instead.
        if weights.numel() != N:
            raise ValueError(
                f"Weights key '{self.weights_key}' must provide one weight per "
                f"point ({N}), but has shape {tuple(weights.shape)}."
            )
        # Flatten e.g. (N, 1) weights and put them on the same device as the
        # data being indexed. torch.multinomial needs floating-point input.
        weights = weights.reshape(N).to(device)
        if not weights.is_floating_point():
            weights = weights.float()
        return torch.multinomial(
            weights, self.n_points, replacement=False, generator=self._generator
        )

    def __repr__(self) -> str:
        """
        Return string representation.

        Returns
        -------
        str
            String representation of the transform.
        """
        weights_str = f", weights_key={self.weights_key}" if self.weights_key else ""
        return (
            f"SubsamplePoints(input_keys={self.input_keys}, n_points={self.n_points}, "
            f"algorithm={self.algorithm}{weights_str})"
        )