Neighbor Functionals
KNN
- physicsnemo.nn.functional.knn(
- points: Float[Tensor, 'num_points dim'],
- queries: Float[Tensor, 'num_queries dim'],
- k: int,
Performs a k-nearest-neighbor search on torch tensors. The search can run directly in torch, in SciPy, or through the RAPIDS cuML algorithm.
Auto-dispatch selects the preferred implementation for the device of the input tensors.
- Parameters:
points (torch.Tensor) – Tensor of shape (N, D) containing the points to search from.
queries (torch.Tensor) – Tensor of shape (M, D) containing the points to search for.
k (int) – Number of nearest neighbors to return for each query point.
implementation ({"cuml", "torch", "scipy"} or None) – Implementation to use for the search. When None, the preferred implementation for the input device is selected, falling back to torch when that implementation is unavailable.
- Returns:
indices (torch.Tensor) – Tensor of shape (M, k) containing the indices of the k nearest neighbors for each query point.
distances (torch.Tensor) – Tensor of shape (M, k) containing the distances to the k nearest neighbors for each query point.
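To illustrate what the search computes, here is a minimal brute-force reference in plain Python. It is not the library's implementation: the helper name knn_reference is hypothetical, and it assumes Euclidean distance, since this page does not state the metric. It returns indices and distances of shape (M, k) in the same order as knn.

```python
import math
from typing import List, Sequence, Tuple

Point = Sequence[float]


def knn_reference(
    points: Sequence[Point], queries: Sequence[Point], k: int
) -> Tuple[List[List[int]], List[List[float]]]:
    """Hypothetical brute-force kNN: returns (indices, distances), each of shape (M, k)."""
    if k > len(points):
        raise ValueError("k cannot exceed the number of points")
    indices, distances = [], []
    for q in queries:
        # Assumed metric: Euclidean distance from this query to every point.
        dists = [math.dist(q, p) for p in points]
        # Sort point indices by distance and keep the k closest.
        order = sorted(range(len(points)), key=lambda j: dists[j])[:k]
        indices.append(order)
        distances.append([dists[j] for j in order])
    return indices, distances


points = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0), (5.0, 5.0)]
queries = [(0.1, 0.0), (4.0, 4.0)]
idx, dist = knn_reference(points, queries, k=2)
# idx == [[0, 1], [3, 2]]
```

With tensors, the same brute-force idea is commonly written with torch.cdist followed by torch.topk(..., largest=False); this page does not say whether the torch backend does exactly that.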
Radius Search
- physicsnemo.nn.functional.radius_search(
- points: Float[Tensor, 'num_points 3'],
- queries: Float[Tensor, 'num_queries 3'],
- radius: float,
- max_points: int | None = None,
- return_dists: bool = False,
- return_points: bool = False,
Performs a radius-based neighbor search to find points within a specified radius of each query point.
The search can use a brute-force method with PyTorch or an accelerated spatial-decomposition method with Warp.
This function does not currently accept a batch index.
This function behaves differently depending on the value of max_points. If max_points is None, the function finds ALL points within the radius and returns flattened indices, (optionally) distances, and (optionally) points. The indices have shape (2, N), where N is the total number of neighbors found across all queries. Row 0 of the indices holds the index of the query point, and row 1 holds the index of the neighbor point in points.
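To make the flattened (2, N) layout concrete, here is a brute-force sketch in plain Python. The helper name radius_search_all is hypothetical; it assumes Euclidean distance and uses the inclusive radius described for the radius parameter. The neighbor order it produces is only one possible ordering.

```python
import math
from typing import List, Sequence


def radius_search_all(
    points: Sequence[Sequence[float]],
    queries: Sequence[Sequence[float]],
    radius: float,
) -> List[List[int]]:
    """Hypothetical sketch: returns [query_row, neighbor_row], i.e. shape (2, N)."""
    query_idx: List[int] = []
    neighbor_idx: List[int] = []
    for i, q in enumerate(queries):
        for j, p in enumerate(points):
            # Points at exactly `radius` count as neighbors (inclusive test).
            if math.dist(q, p) <= radius:
                query_idx.append(i)
                neighbor_idx.append(j)
    return [query_idx, neighbor_idx]


pts = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (3.0, 0.0, 0.0)]
qs = [(0.0, 0.0, 0.0), (2.5, 0.0, 0.0)]
pairs = radius_search_all(pts, qs, 1.0)
# pairs == [[0, 0, 1], [0, 1, 2]]
```

The neighbors of query i are the entries of row 1 at the columns where row 0 equals i.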
If max_points is not None, the function finds the max_points closest points within the radius and returns statically sized arrays of indices, (optionally) distances, and (optionally) points. The indices have shape (queries.shape[0], max_points), and row i holds the neighbors of queries[i]. If a query has fewer than max_points neighbors within the radius, the unused index slots are set to 0, and the corresponding distances and points are also set to 0. Because 0 is also a valid point index, a padded slot cannot be told apart from a genuine neighbor at index 0 by the index alone.
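The padded layout used when max_points is set can be sketched the same way. The helper radius_search_padded below is hypothetical: it assumes Euclidean distance, keeps the max_points closest hits per query, breaks ties by point index (the real backends may select differently), and omits neighbor points (return_points) for brevity.

```python
import math
from typing import List, Sequence, Tuple


def radius_search_padded(
    points: Sequence[Sequence[float]],
    queries: Sequence[Sequence[float]],
    radius: float,
    max_points: int,
) -> Tuple[List[List[int]], List[List[float]]]:
    """Hypothetical sketch: returns (indices, distances), each of shape (M, max_points)."""
    indices, distances = [], []
    for q in queries:
        # Gather (distance, index) pairs inside the radius (inclusive).
        hits = []
        for j, p in enumerate(points):
            d = math.dist(q, p)
            if d <= radius:
                hits.append((d, j))
        # Keep the closest max_points hits; sorting tuples breaks ties by index.
        hits = sorted(hits)[:max_points]
        pad = max_points - len(hits)
        # Unused slots are filled with 0 for both indices and distances.
        indices.append([j for _, j in hits] + [0] * pad)
        distances.append([d for d, _ in hits] + [0.0] * pad)
    return indices, distances


pts = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (3.0, 0.0, 0.0)]
qs = [(0.0, 0.0, 0.0), (2.5, 0.0, 0.0)]
idx, dist = radius_search_padded(pts, qs, 1.0, max_points=3)
# idx == [[0, 1, 0], [2, 0, 0]]
```

Every row has the same length regardless of how many neighbors were found, which is what makes this mode compatible with static-shape tracing.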
Because the shape when max_points=None is dynamic, this function is incompatible with torch.compile in that case. When max_points is set, this function is compatible with torch.compile regardless of backend.
The backends are not guaranteed to produce identical output, for two reasons. First, if max_points is lower than the number of neighbors found, the selection of which neighbors are returned may be nondeterministic. Second, when max_points is None or max_points is greater than the number of neighbors, the two backends may order the outputs differently. Do not rely on the exact order of the neighbors in the outputs.
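Because only the set of neighbors is meaningful across backends, results obtained with max_points=None are best compared order-independently. A minimal sketch with a hypothetical helper, neighbor_pairs, that works on the nested-list form of the (2, N) indices (for a torch tensor, indices.tolist() produces that form). It does not apply to truncated max_points results, where the selected neighbors themselves may differ.

```python
from typing import List, Set, Tuple


def neighbor_pairs(indices: List[List[int]]) -> Set[Tuple[int, int]]:
    """Hypothetical helper: turn [query_row, neighbor_row] into a set of (query, neighbor) pairs."""
    query_row, neighbor_row = indices
    return set(zip(query_row, neighbor_row))


# Two results containing the same neighbors in a different order.
a = [[0, 0, 1], [0, 1, 2]]
b = [[1, 0, 0], [2, 1, 0]]
assert neighbor_pairs(a) == neighbor_pairs(b)
```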
Note
With the Warp backend, reduced-precision inputs are automatically cast to float32, and results are returned in the original precision.
- Parameters:
points (torch.Tensor) – The reference point cloud tensor of shape (N, 3) where N is the number of points.
queries (torch.Tensor) – The query points tensor of shape (M, 3) where M is the number of query points.
radius (float) – The search radius. Points at a distance less than or equal to this radius from a query point are considered neighbors.
max_points (int | None, optional) – Maximum number of neighbors to return for each query point. If None, returns all neighbors within radius. Defaults to None. See the description above for the behavior in each case.
return_dists (bool, optional) – If True, returns the distances to the neighbor points. Defaults to False.
return_points (bool, optional) – If True, returns the actual neighbor points in addition to their indices. Defaults to False.
implementation (str, optional) – Explicit implementation name (“warp” or “torch”). Defaults to None, which selects by rank.
- Returns:
Neighbor indices are always returned first. Additional tensors are appended when requested:
  - indices (always): Neighbor indices
  - points (optional): Neighbor points when return_points=True
  - distances (optional): Neighbor distances when return_dists=True
- Return type:
tuple | torch.Tensor
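The return order matters when unpacking: points come before distances, even though return_dists precedes return_points in the signature. The hypothetical pack_outputs helper below only mimics that documented ordering. It also assumes, based on the tuple | torch.Tensor return type, that the bare indices are returned when no extras are requested.

```python
from typing import Any, Optional


def pack_outputs(
    indices: Any,
    points: Optional[Any] = None,
    distances: Optional[Any] = None,
    return_points: bool = False,
    return_dists: bool = False,
) -> Any:
    """Hypothetical helper mirroring the documented order: indices, then points, then distances."""
    extras = []
    if return_points:
        extras.append(points)
    if return_dists:
        extras.append(distances)
    # Assumption: with no extras requested, the bare indices are returned, not a 1-tuple.
    return (indices, *extras) if extras else indices
```

Under this ordering, a call requesting both extras is unpacked as indices, points, distances.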
- Raises:
KeyError – If an explicit implementation name is not registered.
ImportError – If the selected implementation is unavailable.
Benchmarks (ASV)