Date: 2025-06-11

Now, we’ve seen that torch.compile can accelerate our code, but if we step back and think about the kNN algorithm, we’ll realize it’s not ideal. We are computing this with an N*M algorithm (every point in p1 compared to every point in p2) - and it’s expensive, particularly in memory usage. Better algorithms exist - it’s not the subject of this tutorial to get into them - and we already have a good example in NVIDIA’s RAPIDS ecosystem: cuML Nearest Neighbors. These days, integrating it into PyTorch is straightforward. Update our knn_weighted_feature_aggregation function:

    def knn_weighted_feature_aggregation(
        p1: torch.Tensor,
        p2: torch.Tensor,
        p2_features: torch.Tensor,
        k: int = 3,
        sigma: float = 0.1,
        eps: float = 1e-8,
    ) -> torch.Tensor:
        """Aggregate the features of each p1 point's k nearest p2 neighbors,
        weighted by a softmax over the negative neighbor distances."""
        # Find top-k nearest neighbors (make sure to cast to float32 for cuml)
        topk_dists, topk_idx = knn_search_with_cuml(
            p1.to(torch.float32), p2.to(torch.float32), k
        )

        # Gather neighbor features: (M, k, D_feat)
        neighbors = p2_features[topk_idx]

        # Compute weights: (M, k)
        weights = torch.softmax(-topk_dists / sigma, dim=1)

        # Weighted sum of neighbor features: (M, D_feat)
        agg = torch.sum(weights.unsqueeze(-1) * neighbors, dim=1)

        # Cast back to original dtype
        return agg.to(p1.dtype)
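To make the weighting step concrete, here is a minimal plain-Python sketch of what the softmax and weighted sum do for a single query point. It is an illustration only: the helper names softmax_weights and aggregate are hypothetical, the distances and features are made up, and it uses lists instead of tensors so it runs without a GPU.

```python
import math


def softmax_weights(dists, sigma=0.1):
    """Softmax of -d / sigma over one query point's k neighbor distances."""
    scaled = [-d / sigma for d in dists]
    shift = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - shift) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def aggregate(dists, neighbor_feats, sigma=0.1):
    """Weighted sum of the k neighbor feature vectors for one query point."""
    w = softmax_weights(dists, sigma)
    dim = len(neighbor_feats[0])
    return [
        sum(w[j] * neighbor_feats[j][c] for j in range(len(w)))
        for c in range(dim)
    ]


# Made-up data: three neighbors at increasing distance, two feature channels.
dists = [0.05, 0.10, 0.30]
feats = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

weights = softmax_weights(dists)  # approximately [0.59, 0.36, 0.05]
agg = aggregate(dists, feats)     # closest neighbor dominates the result
```

A smaller sigma concentrates the weight on the closest neighbor, while a large sigma approaches a plain average of the k neighbors.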

The difference is: replace the pointwise norm call with a knn_search_with_cuml call, and then directly get the neighbors based on the index. The rest is the same. As for the knn_search_with_cuml function, it does the real heavy lifting with calls to cuml:

    import cupy as cp
    import cuml
    import torch


    def knn_search_with_cuml(p1: torch.Tensor, p2: torch.Tensor, k: int = 3):
        # Use dlpack to move the data without copying between pytorch and cuml:
        p1 = cp.from_dlpack(p1)
        p2 = cp.from_dlpack(p2)

        # Construct the knn:
        knn = cuml.neighbors.NearestNeighbors(n_neighbors=k)

        # First pass partitions everything in p2 to make lookups fast
        knn.fit(p2)

        # Second pass uses that partition to quickly find neighbors of points in p1
        distance, indices = knn.kneighbors(p1)

        # Convert back to pytorch:
        distance = torch.from_dlpack(distance)
        indices = torch.from_dlpack(indices)

        # Return torch objects.
        return distance, indices
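For comparison, a brute-force reference in plain Python makes the N*M cost explicit: it builds one full row of distances per query point before keeping the k smallest. This is a sketch for checking results on small inputs, not a replacement for cuml. The name knn_brute_force is hypothetical, and it assumes Euclidean distance, which should be confirmed against the metric your NearestNeighbors instance is configured with.

```python
import math


def knn_brute_force(p1, p2, k=3):
    """Return (distances, indices); each has len(p1) rows of k entries."""
    all_dists, all_idx = [], []
    for q in p1:
        # One row of the len(p1) x len(p2) distance table (the N*M cost).
        row = [math.dist(q, r) for r in p2]
        # Indices of the k smallest distances, nearest first.
        order = sorted(range(len(p2)), key=row.__getitem__)[:k]
        all_idx.append(order)
        all_dists.append([row[j] for j in order])
    return all_dists, all_idx


# Made-up 2D points for illustration.
p2 = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0), (5.0, 5.0)]
p1 = [(0.1, 0.0), (4.0, 4.0)]

dists, idx = knn_brute_force(p1, p2, k=2)
# idx[0] lists the two p2 points nearest to p1[0], idx[1] those nearest to p1[1].
```

The output layout, one row per p1 point and k columns, matches what the rest of the pipeline indexes with p2_features[topk_idx].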

A couple of things to note about this function. First, it’s PyTorch in, PyTorch out: we’ve encapsulated all contact with cuml in one region of code, which will be useful later. Second, this function returns both the distances and the indices, and both are used, but in the backward pass the gradient in knn_weighted_feature_aggregation flows from the output through the distance-weighted aggregation and then through the neighbors = p2_features[topk_idx] line. The topk_idx tensor directs which rows the gradients flow to, but the indices are not themselves differentiable. Likewise, the topk_dists tensor provides weights for the gradients in the backward pass, but does not itself receive gradients. So knn_search_with_cuml does not need a derivative implementation, and the backward pass of this model just works. It “just works” quite well, too:

    Time taken in forward: 14.139 ms per iteration
    Time taken in backward: 15.842 ms per iteration
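The claim that gradients reach p2_features only through the gather can be checked by hand for a single query point. The sketch below is plain Python with made-up numbers and hypothetical helper names: the analytic gradient of the aggregated output with respect to each feature row is simply the weight of that row if it was selected, and zero otherwise, and a finite-difference check confirms it. The indices and weights are treated as constants, which is why the neighbor search itself needs no derivative.

```python
def aggregate(features, idx, weights):
    """agg[c] = sum_j weights[j] * features[idx[j]][c] for one query point."""
    dim = len(features[0])
    return [
        sum(w * features[i][c] for i, w in zip(idx, weights))
        for c in range(dim)
    ]


def grad_wrt_features(num_rows, idx, weights):
    """d(agg[c]) / d(features[i][c]): the weight of row i if selected, else 0."""
    grad = [0.0] * num_rows
    for i, w in zip(idx, weights):
        grad[i] += w
    return grad


features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
idx = [2, 0, 3]            # stands in for one row of topk_idx
weights = [0.6, 0.3, 0.1]  # stands in for one row of the softmax weights

analytic = grad_wrt_features(len(features), idx, weights)

# Finite-difference check on channel 0: perturb one feature row at a time.
h = 1e-6
base = aggregate(features, idx, weights)[0]
numeric = []
for i in range(len(features)):
    bumped = [row[:] for row in features]
    bumped[i][0] += h
    numeric.append((aggregate(bumped, idx, weights)[0] - base) / h)
```

Row 1 was never selected, so its gradient is exactly zero in both the analytic and the numeric result.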

Now, if you try to compile this you will hit a warning:

    /usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py:700: UserWarning: Graph break due to unsupported builtin cupy._core.dlpack.from_dlpack. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.

And, performance is a little worse:

    Time taken in forward: 15.697 ms per iteration
    Time taken in backward: 17.670 ms per iteration

The issue, of course, is that our function, knn_search_with_cuml, calls operations that PyTorch has no idea what to do with.

However, if you follow along with the PyTorch Custom Ops Tutorial, it’s not hard to see how to extend this. First, we have to register the function with PyTorch as a custom operator:

    @torch.library.custom_op("cuml::knn", mutates_args=())
    def knn_search_with_cuml(
        p1: torch.Tensor, p2: torch.Tensor, k: int = 3
    ) -> tuple[torch.Tensor, torch.Tensor]:
        p1 = cp.from_dlpack(p1)
        p2 = cp.from_dlpack(p2)

        knn = cuml.neighbors.NearestNeighbors(n_neighbors=k)
        knn.fit(p2)
        distance, indices = knn.kneighbors(p1)

        # Convert back to pytorch:
        distance = torch.from_dlpack(distance)
        indices = torch.from_dlpack(indices)
        return distance, indices

We also have to define a “fake” tensor function for this operator: based on the inputs, it tells PyTorch what the outputs will look like (shape, dtype, and device) without running the real computation. It’s easily done with a decorator:

    @knn_search_with_cuml.register_fake
    def _(p1, p2, k):
        assert p1.device == p2.device
        dist_output = torch.empty(p1.shape[0], k, device=p1.device, dtype=p1.dtype)
        idx_output = torch.empty(p1.shape[0], k, device=p1.device, dtype=torch.int64)
        return dist_output, idx_output

Note: We don’t even need to name this function. It’s consumed and registered with PyTorch, and PyTorch takes care of the rest.

With these changes, torch.compile will now work! You won’t actually see a significant speedup, though - in fact you’ll probably see a negligible change in performance (< 1 ms difference). The challenge here is that while the cuml implementation is much faster, it includes CUDA synchronize calls, which block the host from queuing further GPU work until they complete. Since the rest of the model is so tiny, the compilation does almost nothing to improve it: we’re now bound by kernel launch latency outside of that call. You can - and should! - run the profiles and see that the GPU is now significantly more idle than it was in the first iteration of the code. For real models, however, with much deeper and larger layers, this will not be a major issue.

If you do look at the profile, you’ll see a lot of memory operations in the cuml region of the code. Why? It has to allocate memory for itself, and while both RAPIDS and PyTorch have dedicated memory management tools to accelerate this, they are not using the same pool of memory. Fortunately, PyTorch allows you to swap in another memory allocator, and the RAPIDS Memory Manager (RMM) is easy to plug in. Add these to your imports:

    import rmm
    from rmm.allocators.torch import rmm_torch_allocator

Then, before you initialize any data or models in PyTorch, plug the RAPIDS memory manager into PyTorch:

    rmm.reinitialize(pool_allocator=True)
    torch.cuda.memory.change_current_allocator(rmm_torch_allocator)

This improves the runtime by a further ~3ms (which is > 20% faster on an already good speedup!):