Neighbor Functionals
KNN
- physicsnemo.nn.functional.knn(
- points: Float[Tensor, 'num_points dim'],
- queries: Float[Tensor, 'num_queries dim'],
- k: int,
Perform a k-nearest-neighbor search on torch tensors. The search can run directly in PyTorch or use the RAPIDS cuML or SciPy implementations.
When no implementation is specified, auto-dispatch selects the preferred implementation for the device of the input tensors.
- Parameters:
points (torch.Tensor) – Tensor of shape (N, D) containing the points to search from.
queries (torch.Tensor) – Tensor of shape (M, D) containing the points to search for.
k (int) – Number of nearest neighbors to return for each query point.
implementation ({"cuml", "torch", "scipy"} or None) – Implementation to use for the search. When None, the preferred implementation for the input device is selected, falling back to torch when it is unavailable.
- Returns:
indices (torch.Tensor) – Tensor of shape (M, k) containing the indices of the k nearest neighbors for each query point.
distances (torch.Tensor) – Tensor of shape (M, k) containing the distances to the k nearest neighbors for each query point.
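As a rough illustration of the shapes above, a brute-force search can be written in plain PyTorch with torch.cdist and torch.topk. This is a minimal sketch, not the library's implementation: knn_reference is a hypothetical name, and it assumes Euclidean (not squared) distances and returns (indices, distances) in the documented order.

```python
import torch


def knn_reference(points: torch.Tensor, queries: torch.Tensor, k: int):
    """Brute-force k-nearest-neighbor search; returns (indices, distances), each (M, k)."""
    # Pairwise Euclidean distances between every query and every point: (M, N).
    dists = torch.cdist(queries, points)
    # The k smallest distances in each row, sorted nearest first.
    distances, indices = torch.topk(dists, k, dim=1, largest=False, sorted=True)
    return indices, distances


# Points may have any dimension D; here D=2.
points = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [5.0, 5.0]])
queries = torch.tensor([[0.1, 0.0], [4.0, 4.0]])
indices, distances = knn_reference(points, queries, k=2)
print(indices.tolist())  # [[0, 1], [3, 2]]
```

The full (M, N) distance matrix makes this O(M·N) in memory, which is exactly the cost that the cuML and SciPy paths are meant to avoid for large inputs.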
Radius Search
- physicsnemo.nn.functional.radius_search(
- points: Float[Tensor, '*batch num_points 3'],
- queries: Float[Tensor, '*batch num_queries 3'],
- radius: float,
- max_points: int | None = None,
- return_dists: bool = False,
- return_points: bool = False,
Performs a radius-based neighbor search to find points within a specified radius of query points.
It can use a brute-force method in PyTorch or an accelerated spatial-decomposition method in Warp.
Accepts both unbatched inputs of shape (N, 3) and batched inputs of shape (B, N, 3). When unbatched inputs are provided, they are treated as B=1 internally and the batch dimension is stripped from the output. Only ranks 2 and 3 are accepted; higher-rank inputs raise ValueError.
The output layout depends on max_points:
- If max_points is None, the function finds ALL points within the radius and returns flattened indices, (optionally) distances, and (optionally) points. For unbatched inputs the indices have shape (2, T), where T is the total number of neighbors found across all queries; row 0 holds the query index and row 1 the point index. For batched inputs the indices have shape (3, T), where row 0 is the batch index, row 1 the query index, and row 2 the point index.
- If max_points is not None, the function finds the max_points closest points within the radius and returns statically sized indices, (optionally) distances, and (optionally) points. For unbatched inputs the indices have shape (M, max_points); for batched inputs they have shape (B, M, max_points). Unused slots are filled with 0.
Because the output shape is dynamic when max_points=None, the function is incompatible with torch.compile in that case. When max_points is set, it is compatible with torch.compile regardless of backend.
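The two output layouts can be reproduced with a brute-force PyTorch sketch. radius_search_reference is a hypothetical helper written for illustration, not physicsnemo code: it returns indices only, builds the full (B, M, N) distance matrix, and always keeps the nearest neighbors, so it models the documented shapes rather than either backend's exact behavior.

```python
import torch


def radius_search_reference(points, queries, radius, max_points=None):
    """Brute-force radius search returning indices in the documented layouts.

    Hypothetical illustration only: builds the full distance matrix, so memory
    grows as O(B * M * N).
    """
    if points.dim() not in (2, 3) or queries.dim() != points.dim():
        raise ValueError("points and queries must both be rank 2 or rank 3")
    unbatched = points.dim() == 2
    if unbatched:
        # Treat unbatched inputs as B=1 and strip the batch dimension at the end.
        points, queries = points.unsqueeze(0), queries.unsqueeze(0)

    dists = torch.cdist(queries, points)  # (B, M, N)
    within = dists <= radius  # points at or inside the radius count as neighbors

    if max_points is None:
        # Dynamic size: one column per neighbor, rows = (batch, query, point).
        indices = within.nonzero().T  # (3, T)
        return indices[1:] if unbatched else indices  # (2, T) when unbatched

    # Static size: keep the closest max_points neighbors per query.
    masked = dists.masked_fill(~within, float("inf"))
    k = min(max_points, points.shape[-2])
    top_d, top_i = torch.topk(masked, k, dim=-1, largest=False)
    # Unused slots -> 0 (note that 0 is also a valid point index).
    top_i = top_i.masked_fill(torch.isinf(top_d), 0)
    out = torch.zeros(*masked.shape[:-1], max_points, dtype=torch.long, device=points.device)
    out[..., :k] = top_i
    return out.squeeze(0) if unbatched else out  # (M, max_points) or (B, M, max_points)


points = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
queries = torch.tensor([[0.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
print(radius_search_reference(points, queries, radius=1.5).tolist())  # [[0, 0, 1], [0, 1, 2]]
print(radius_search_reference(points, queries, radius=1.5, max_points=2).tolist())  # [[0, 1], [2, 0]]
```

The first call shows why max_points=None breaks torch.compile: T (here 3) depends on the data. The second call always has shape (M, max_points), whatever the data.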
The backends are not guaranteed to produce identical output, for two reasons. First, if max_points is smaller than the number of neighbors found, the subset of points selected may be nondeterministic. Second, when max_points is None or exceeds the number of neighbors found, the backends may order the neighbors differently. Do not rely on the exact order of the neighbors in the outputs.
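Because neighbor order can differ between backends, one way to compare two flattened (max_points=None) outputs is to sort their columns lexicographically by (batch, query, point) first. canonical_order is a hypothetical helper, sketched under that assumption; it applies the same permutation to any per-neighbor tensors (such as distances or points) passed alongside the indices. It does not address the fixed-size layout or backends that select different subsets.

```python
import torch


def canonical_order(indices: torch.Tensor, *per_neighbor: torch.Tensor):
    """Sort flattened neighbor columns lexicographically by their index rows.

    indices has shape (2, T) or (3, T); each extra tensor has T along dim 0.
    """
    order = torch.arange(indices.shape[1], device=indices.device)
    # Lexicographic sort via repeated stable sorts, least significant row first.
    for row in reversed(range(indices.shape[0])):
        _, perm = torch.sort(indices[row, order], stable=True)
        order = order[perm]
    return indices[:, order], tuple(t[order] for t in per_neighbor)


# The same neighbors listed in two different orders compare equal once canonicalized.
a = torch.tensor([[0, 0, 1], [0, 1, 2]])
b = torch.tensor([[1, 0, 0], [2, 1, 0]])
print(torch.equal(canonical_order(a)[0], canonical_order(b)[0]))  # True
```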
Note
With the Warp backend, reduced-precision inputs are automatically cast to float32, and results are returned in the original precision.
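The casting described in the note follows a common upcast-compute-downcast pattern. The sketch below is hypothetical and not taken from the Warp backend: it assumes "reduced precision" means float16 and bfloat16, and that only floating-point results (points, distances) are cast back, with integer indices left unchanged.

```python
import torch

# Assumption: "reduced precision" means these dtypes.
REDUCED_PRECISION = (torch.float16, torch.bfloat16)


def call_with_float32(fn, points, queries, **kwargs):
    """Run fn in float32 for reduced-precision inputs, then restore the input dtype.

    Hypothetical sketch of the casting pattern; integer outputs such as indices
    are returned unchanged.
    """
    orig_dtype = points.dtype
    if orig_dtype not in REDUCED_PRECISION:
        return fn(points, queries, **kwargs)
    out = fn(points.float(), queries.float(), **kwargs)
    outs = out if isinstance(out, tuple) else (out,)
    # Cast floating-point outputs back to the caller's precision.
    restored = tuple(t.to(orig_dtype) if t.is_floating_point() else t for t in outs)
    return restored if isinstance(out, tuple) else restored[0]
```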
- Parameters:
points (torch.Tensor) – The reference point cloud tensor of shape (N, 3) or (B, N, 3).
queries (torch.Tensor) – The query points tensor of shape (M, 3) or (B, M, 3).
radius (float) – The search radius. Points within or at this radius of a query point are considered neighbors.
max_points (int | None, optional) – Maximum number of neighbors to return for each query point. If None, returns all neighbors within the radius. Defaults to None. See the description above for how this choice affects the output shapes.
return_dists (bool, optional) – If True, returns the distances to the neighbor points. Defaults to False.
return_points (bool, optional) – If True, returns the actual neighbor points in addition to their indices. Defaults to False.
implementation (str, optional) – Explicit implementation name ("warp" or "torch"). Defaults to None, in which case the implementation is selected automatically by dispatch rank.
- Returns:
Neighbor indices are always returned first. Additional tensors are appended, in the order below, when requested:
- indices (always): neighbor indices
- points (optional): neighbor points, when return_points=True
- distances (optional): neighbor distances, when return_dists=True
- Return type:
tuple | torch.Tensor
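The return convention can be sketched for the unbatched, flattened layout. gather_neighbor_outputs is a hypothetical helper, not part of physicsnemo. It assumes, from the return type, that a bare tensor comes back when neither flag is set, and it uses Euclidean distances.

```python
import torch


def gather_neighbor_outputs(points, queries, indices, return_points=False, return_dists=False):
    """Package optional outputs in the documented order: indices, points, distances.

    Hypothetical sketch for the unbatched, flattened layout: indices is (2, T),
    row 0 = query index, row 1 = point index.
    """
    neighbor_points = points[indices[1]]  # (T, 3)
    outputs = [indices]
    if return_points:
        outputs.append(neighbor_points)
    if return_dists:
        # Euclidean distance from each neighbor to the query that found it: (T,)
        outputs.append(torch.linalg.norm(neighbor_points - queries[indices[0]], dim=-1))
    # Assumed: a bare tensor when no optional output is requested, else a tuple.
    return outputs[0] if len(outputs) == 1 else tuple(outputs)
```

With both flags set, points come before distances, even though return_dists precedes return_points in the signature.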
- Raises:
KeyError – If an explicit implementation name is not registered.
ImportError – If the selected implementation is unavailable.
ValueError – If inputs are not rank 2 or 3.
Benchmarks (ASV)