NVIDIA PhysicsNeMo Core (Latest Release)

Torch Compile and External Kernels

The Challenge: Getting the Best of Both Worlds

Scientific AI applications often face a performance dilemma: torch.compile can accelerate PyTorch models, but many scientific workloads require specialized libraries outside the PyTorch ecosystem—like RAPIDS cuML for accelerated k-nearest neighbors, NVIDIA Warp for differentiable physics simulation, or other domain-specific kernels. Using these external libraries typically causes “graph breaks” in torch.compile, limiting the potential performance benefits.

Note

torch.compile was introduced in PyTorch 2.0, so older versions of PyTorch are not compatible with this tutorial. The torch.library.custom_op API used below requires PyTorch 2.4 or newer.

This tutorial demonstrates how to solve this challenge by integrating external kernels with torch.compile using PyTorch’s custom operator API, enabling you to leverage both PyTorch’s graph optimization and high-performance external libraries simultaneously.

What You’ll Learn

By the end of this tutorial, you’ll understand how to:

  • Register external library functions as PyTorch custom operators

  • Enable torch.compile to work seamlessly with libraries like cuML and Warp

  • Implement custom backward passes for external kernels

  • Share memory pools between PyTorch and RAPIDS for additional performance gains

  • Achieve significant speedups (10x+ in our toy examples) without sacrificing torch.compile benefits

Who is this for?

This is an advanced tutorial for AI developers who are actively working on new models, applications, data pipelines, and similar projects. Strong familiarity with PyTorch is a prerequisite, and basic familiarity with unstructured data operations (like k-nearest neighbors) is helpful. You should also know the basics of how torch.compile works and how to apply it to your code.


What does torch.compile do?

If you’re interested in torch.compile, you’ve probably already found the PyTorch tutorial on it. At a high level, torch.compile lets PyTorch capture your model's code into a graph (just in time, the first time the code runs), find places where kernels can be optimized or combined, and apply those optimizations. The performance gain depends heavily on the application: kernel fusion (for example, a convolution followed by an activation) can reduce runtime by folding a memory-bound operation into a compute-bound one, so the intermediate result never makes a round trip through GPU memory. Performance gains also depend on compute precision, because the thresholds for what is “compute-bound” and what is “memory-bound” shift with precision. Lower precisions have smaller memory footprints (so less memory bandwidth is needed) and can use dedicated processing units like Tensor Cores for faster math.

That tutorial, however, focuses on pure PyTorch functionality. In PhysicsNeMo workloads, we often need tools that live outside the PyTorch ecosystem, yet with large, complex, end-to-end models we still want the performance benefits of torch.compile. The rest of this tutorial looks at exactly how to solve that problem.

This tutorial is broken into two parts. First, we’ll work on a k-nearest-neighbors (kNN) style problem, which we can accelerate with cuML. Second, we’ll take a closer look at custom backward passes under torch.compile (and you’ll learn why one wasn’t necessary in the first example, even though we’re training!)

Introducing the Application

For demonstration purposes, we’ve invented a small operator that works on point-cloud-like data. This means the input to the operator is a 2D torch.Tensor of 3D points, of shape (N, 3), unstructured and unordered. You can find small-scale point-cloud data in, for example, the “ModelNet-10 - Princeton 3D Object Dataset”, but the exact features of the data aren’t the focus of this tutorial, so don’t worry about the details of the inputs and outputs.

PhysicsNeMo users will recognize similar ideas in architectures such as DoMINO and FigConvNet, and local aggregation of points is a well-studied topic in graph neural networks on point clouds. We’re not using those models here, however; this is a fully independent example application.

We start with a simple, 3-layer MLP:

import torch


class MLP(torch.nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        A simple 3-layer MLP that takes in a tensor of
        shape (N, input_dim) and outputs a tensor of
        shape (N, output_dim).
        """
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

This MLP is used twice, in a simple model:

class kNN_Projector(torch.nn.Module):
    def __init__(self, k: int, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        # Learns features on the reference point cloud p2:
        self.proj = MLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=hidden_dim)
        self.k = k

        # Projects the aggregated features at p1 to the output space:
        self.proj_out = MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=output_dim)

    def forward(self, p1: torch.Tensor, p2: torch.Tensor) -> torch.Tensor:
        """
        Accept two point clouds, p1 and p2. Compute a learnable projection on p2 to
        learn features. Then, use a kNN-weighted aggregation to project those features
        onto p1.
        """
        p2_features = self.proj(p2)

        p1_features = knn_weighted_feature_aggregation(p1, p2, p2_features, k=self.k)

        return self.proj_out(p1_features)

In basic terms, this model operates on two point clouds. A reference set of points, p2, has features learned on it by the first MLP. Then, using the k nearest neighbors in p2 of each point in p1, those features are projected onto the locations in p1 (knn_weighted_feature_aggregation). Finally, the aggregated features are projected to a final latent space by a second MLP. The details of the projection look like this:

def knn_weighted_feature_aggregation(
    p1: torch.Tensor,
    p2: torch.Tensor,
    p2_features: torch.Tensor,
    k: int = 3,
    sigma: float = 0.1,
    eps: float = 1e-8
    ) -> torch.Tensor:
    """
    Perform differentiable kNN-weighted feature aggregation.

    Args:
        p1 (torch.Tensor): Query points, shape (M, D)
        p2 (torch.Tensor): Reference points, shape (N, D)
        p2_features (torch.Tensor): Features at reference points, shape (N, D_feat)
        k (int): Number of neighbors
        sigma (float): RBF temperature parameter
        eps (float): Numerical stability for normalization (unused: softmax
            already normalizes the weights)

    Returns:
        torch.Tensor: Aggregated features at p1, shape (M, D_feat)
    """
    # Compute pairwise distances: (M, N)
    dists = torch.norm(p1[:, None, :] - p2[None, :, :], dim=-1)

    # Find top-k nearest neighbors: (M, k)
    topk_dists, topk_idx = torch.topk(dists, k=k, dim=1, largest=False)

    # Gather neighbor features: (M, k, D_feat)
    neighbors = p2_features[topk_idx]

    # Compute weights: (M, k)
    weights = torch.softmax(-topk_dists / sigma, dim=1)

    # Weighted sum of neighbor features: (M, D_feat)
    agg = torch.sum(weights.unsqueeze(-1) * neighbors, dim=1)

    return agg

You may recognize this as a brute-force kNN, followed by softmax weights that decrease with the distance between each query point and its neighbors.

See also

Don’t read into the high level algorithm too closely! Remember, we’re here in this tutorial to talk about computational performance. This is just a made-up example that uses a kNN.

Just for completeness, so you can run this example on your own, here are some helper functions needed to initialize the data and, optionally, ensure deterministic inputs:

def generate_data(n_cloud_points, n_grid_points, target_features, dtype=torch.bfloat16):
    device = torch.device("cuda")

    # Make a random point cloud of shape (n_cloud_points, 3):
    point_cloud = torch.randn(n_cloud_points, 3, device=device, requires_grad=False, dtype=dtype)

    # And a regular grid of 3D points on [-1, 1]^3, with n_grid_points per axis:
    x = torch.linspace(-1, 1, n_grid_points, device=device, dtype=dtype)
    y = torch.linspace(-1, 1, n_grid_points, device=device, dtype=dtype)
    z = torch.linspace(-1, 1, n_grid_points, device=device, dtype=dtype)

    # Create 3D meshgrid
    X, Y, Z = torch.meshgrid(x, y, z, indexing='ij')

    # Flatten and stack to get grid points as an (n_grid_points**3, 3) tensor
    grid_points = torch.stack([X.flatten(), Y.flatten(), Z.flatten()], dim=1)

    # Random target features on the grid: (n_grid_points**3, target_features)
    grid_features = torch.randn(grid_points.shape[0], target_features, device=device, requires_grad=False, dtype=dtype)

    return point_cloud, grid_points, grid_features

def set_seed(seed: int):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def train_step(model, optimizer, grid_points, point_cloud, grid_features):
    # Pretend to train the model!
    optimizer.zero_grad()
    output = model.forward(grid_points, point_cloud)
    loss = torch.mean((output - grid_features)**2)
    loss.backward()
    optimizer.step()
    return loss

To run this, you’ll need a function like the one below to measure performance. It uses the PhysicsNeMo profiler to enable PyTorch profiling quickly and easily. You’ll want from physicsnemo.utils.profiling import Profiler at the top level of your Python script (along with import torch!)

import torch
from physicsnemo.utils.profiling import Profiler


def measure_performance(model, inputs, warmup_iters, benchmark_iters, profile=False):

    grid_points, point_cloud, grid_features = inputs

    profiler = Profiler()
    if profile:
        profiler.enable("torch")

    # Make a dummy optimizer:
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Warm up (for a compiled model, this also triggers compilation):
    for i in range(warmup_iters):
        # Forward only, under no_grad to match the forward benchmark below
        # (torch.compile guards on grad mode, so no_grad calls need their own compilation):
        with torch.no_grad():
            model.forward(grid_points, point_cloud)

        # Training:
        loss = train_step(model, optimizer, grid_points, point_cloud, grid_features)

    torch.cuda.synchronize()

    with profiler:

        with torch.no_grad():
            torch.cuda.synchronize()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            start_event.record()
            # Benchmark the forward pass
            for i in range(benchmark_iters):
                output = model.forward(grid_points, point_cloud)
            end_event.record()
            torch.cuda.synchronize()

        print(f"Time taken in forward: {start_event.elapsed_time(end_event) / benchmark_iters:.3f} ms per iteration")

        # Benchmark the training loop:
        torch.cuda.synchronize()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        start_event.record()
        # Benchmark a full training step (forward + backward + optimizer step),
        # reported as "backward" below:
        for i in range(benchmark_iters):
            loss = train_step(model, optimizer, grid_points, point_cloud, grid_features)

        end_event.record()
        torch.cuda.synchronize()

        print(f"Time taken in backward: {start_event.elapsed_time(end_event) / benchmark_iters:.3f} ms per iteration")

Finally, we can execute the script like this:

if __name__ == "__main__":

    set_seed(42)

    target_features = 1
    n_grid_points = 30
    n_cloud_points = 100000
    dtype = torch.float32

    point_cloud, grid_points, grid_features = generate_data(n_cloud_points, n_grid_points, target_features, dtype=dtype)
    print(point_cloud.shape)
    print(grid_points.shape)
    print(grid_features.shape)

    model = kNN_Projector(k=7, input_dim=3, hidden_dim=25, output_dim=target_features).cuda().to(dtype)

    warmup_iters = 5
    benchmark_iters = 15

    measure_performance(model, (grid_points, point_cloud, grid_features), warmup_iters, benchmark_iters, profile=False)

On an A100 GPU, we see performance like this:

torch.Size([100000, 3])
torch.Size([27000, 3])
torch.Size([27000, 1])
Time taken in forward: 144.045 ms per iteration
Time taken in backward: 144.758 ms per iteration

Adding model = torch.compile(model), with no other changes, roughly halves the time per iteration:

Time taken in forward: 74.237 ms per iteration
Time taken in backward: 74.657 ms per iteration

Why? It’s worth exploring exactly what enabled a 2x speedup in this pretend application. If you run it with profiling enabled (profile=True) and compare the two profiles, with and without compilation, a few top kernels stand out clearly.

See also

Want to learn more about how to profile your pytorch code? Check out our profiling tutorial: Profiling Applications in PhysicsNeMo

Uncompiled Application Performance (Top Operations):

Name                                      Self CUDA  Self CUDA %  CUDA total  CUDA time avg  CUDA Mem   Self CUDA Mem  # of Calls
aten::topk                                1.864s     41.69%       1.864s      62.133ms       64.91 Mb   64.19 Mb       30
aten::linalg_vector_norm                  1.670s     37.36%       1.670s      55.672ms       301.76 Gb  301.76 Gb      30
reduce_kernel<512, 1, ...>                1.670s     37.36%       1.670s      3.480ms        0 b        0 b            480
radixFindKthValues<float, ...>            1.125s     25.17%       1.125s      9.377ms        0 b        0 b            120
aten::sub                                 904.517ms  20.23%       904.517ms   20.100ms       905.27 Gb  905.27 Gb      45

Compiled Application Performance (Top Operations):

Name                                      Self CUDA  Self CUDA %  CUDA total  CUDA time avg  CUDA Mem   Self CUDA Mem  # of Calls
aten::topk                                1.863s     83.53%       1.863s      54.802ms       64.91 Mb   64.91 Mb       34
radixFindKthValues<float, ...>            1.125s     50.42%       1.125s      9.373ms        0 b        0 b            120
Torch-Compiled Region: 0/0                0.000us    0.00%        1.112s      74.163ms       436.00 Mb  0 b            15
CompiledFunction                          0.000us    0.00%        1.112s      74.163ms       436.00 Mb  403.54 Mb      15
gatherTopK<float, unsigned>               682.819ms  30.61%       682.819ms   22.761ms       0 b        0 b            30
triton_poi_fused_linalg_vector_norm_sub_0 352.928ms  15.82%       352.928ms   11.764ms       0 b        0 b            30

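Take note of the top kernels before compilation: aten::topk was (and still is) dominant. But right before we call topk in user code, we compute the pairwise distances between all points in the two point clouds. Uncompiled, that is two separate operations: aten::sub materializes the full (M, N, 3) difference tensor (20.1 ms per call on average), and aten::linalg_vector_norm reduces it to distances (55.7 ms per call). In the compiled version, both are fused into a single Triton kernel that shows up under a different name, triton_poi_fused_linalg_vector_norm_sub_0, and takes 11.8 ms per call. This doesn’t account for all of the difference, though it’s a large part of it. To learn more about understanding profiling results, check out Profiling Applications in PhysicsNeMo.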
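To see why fusing the subtraction and the norm matters, here is a back-of-the-envelope estimate of the tensor sizes for this run (M = 27,000 grid points, N = 100,000 cloud points, float32). It assumes the profiler's “Gb” means GiB; under that reading, the distance size matches the aten::linalg_vector_norm row (301.76 over 30 calls is about 10.06 per call), and 30 calls of the difference size (30 x 30.17) account for the 905.27 in the aten::sub row. The traffic figures are rough: they ignore the small input reads and any caching, and tensor_gib is just a helper for this sketch.

```python
GiB = 2**30


def tensor_gib(*shape, bytes_per_elem=4):
    # Size in GiB of a dense tensor with this shape (float32 by default).
    numel = 1
    for s in shape:
        numel *= s
    return numel * bytes_per_elem / GiB


M, N, D = 27_000, 100_000, 3  # grid points (p1), cloud points (p2), coordinates

diff_gib = tensor_gib(M, N, D)  # p1[:, None, :] - p2[None, :, :], written by aten::sub
dist_gib = tensor_gib(M, N)     # distances written by aten::linalg_vector_norm

# Rough per-call memory traffic for the large tensors (ignoring the tiny inputs):
unfused = diff_gib + diff_gib + dist_gib  # write the difference, read it back, write distances
fused = dist_gib                          # a fused kernel only writes the distances

print(f"difference tensor: {diff_gib:.2f} GiB, distances: {dist_gib:.2f} GiB")
print(f"traffic unfused: {unfused:.1f} GiB, fused: {fused:.1f} GiB ({unfused / fused:.0f}x less)")
```

The fused kernel never writes the 30 GiB difference tensor, which is consistent with the large drop in time spent on the distance computation after compilation.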

Improving Performance with RAPIDS

We’ve seen that torch.compile can accelerate our code, but if we step back and think about the kNN algorithm, we’ll realize it’s not ideal. We compute it with an O(M·N) brute-force algorithm (every point in p1 compared to every point in p2), which is expensive, particularly in memory usage. Better algorithms exist (getting into them is not the subject of this tutorial), and a good implementation already exists in NVIDIA’s RAPIDS ecosystem: cuML Nearest Neighbors. Integrating it with PyTorch is straightforward. Update the knn_weighted_feature_aggregation function:

def knn_weighted_feature_aggregation(
    p1: torch.Tensor,
    p2: torch.Tensor,
    p2_features: torch.Tensor,
    k: int = 3,
    sigma: float = 0.1,
    eps: float = 1e-8
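    ) -> torch.Tensor:
    """
    kNN-weighted feature aggregation, with the neighbor search done by cuML.
    Arguments and shapes are the same as in the brute-force version above.
    """
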
    # Find top-k nearest neighbors (Make sure to cast to float32 for cuml)
    topk_dists, topk_idx = knn_search_with_cuml(p1.to(torch.float32), p2.to(torch.float32), k)

    # Gather neighbor features: (M, k, D_feat)
    neighbors = p2_features[topk_idx]

    # Compute weights: (M, k)
    weights = torch.softmax(-topk_dists / sigma, dim=1)

    # Weighted sum of neighbor features: (M, D_feat)
    agg = torch.sum(weights.unsqueeze(-1) * neighbors, dim=1)

    # Cast back to original dtype
    return agg.to(p1.dtype)

The difference: the brute-force pairwise distance computation and the torch.topk call are replaced by a single knn_search_with_cuml call, which returns both the distances to and the indices of the nearest neighbors; we then gather the neighbor features directly by index. The rest is the same. The knn_search_with_cuml function does the real heavy lifting with calls to cuML:

import cupy as cp
import cuml
import torch


def knn_search_with_cuml(p1: torch.Tensor, p2: torch.Tensor, k: int = 3):
    # Use DLPack to share the data between PyTorch and CuPy without copying:
    p1 = cp.from_dlpack(p1)
    p2 = cp.from_dlpack(p2)

    # Construct the kNN estimator:
    knn = cuml.neighbors.NearestNeighbors(n_neighbors=k)
    # First pass: fit on p2 so that lookups are fast
    knn.fit(p2)

    # Second pass: find the k nearest neighbors in p2 of each point in p1
    distance, indices = knn.kneighbors(p1)

    # Convert back to PyTorch (again via DLPack, without copying):
    distance = torch.from_dlpack(distance)
    indices = torch.from_dlpack(indices)

    # Return torch objects.
    return distance, indices

A couple of things to note about this function. First, it’s PyTorch in, PyTorch out: all contact with cuML is encapsulated in one region of code, which will be useful later. Second, it returns both the distances and the indices, and both are used, but neither needs a gradient. In knn_weighted_feature_aggregation, the gradient flows back from the aggregated output through the distance-weighted sum and then through the neighbors = p2_features[topk_idx] gather into p2_features. The topk_idx tensor directs which rows the gradients flow to, but integer indices are not themselves differentiable. Likewise, topk_dists provides the weights that scale the gradients in the backward pass, but since the point coordinates p1 and p2 do not require gradients, nothing needs to flow back through the distances. So knn_search_with_cuml does not need a derivative implementation, and the backward pass of this model just works. It “just works” quite well, too:

Time taken in forward: 14.139 ms per iteration
Time taken in backward: 15.842 ms per iteration
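To see concretely why knn_search_with_cuml needs no backward, here is a small CPU-only sketch. It uses a hypothetical knn_no_grad helper (a brute-force search under torch.no_grad) in place of the cuML call, so, like the tensors coming back from cuML, its outputs carry no autograd history. The aggregation lines are the same as in knn_weighted_feature_aggregation.

```python
import torch


def knn_no_grad(p1, p2, k):
    # Hypothetical stand-in for knn_search_with_cuml: a brute-force search under
    # no_grad, so its outputs carry no autograd history (like tensors from cuML).
    with torch.no_grad():
        dists = torch.cdist(p1, p2)                          # (M, N)
        return torch.topk(dists, k=k, dim=1, largest=False)  # values and indices, (M, k) each


torch.manual_seed(0)
p1 = torch.randn(5, 3)                                # query points, no grad needed
p2 = torch.randn(20, 3)                               # reference points, no grad needed
p2_features = torch.randn(20, 4, requires_grad=True)  # learned features

topk_dists, topk_idx = knn_no_grad(p1, p2, k=3)
neighbors = p2_features[topk_idx]                          # (5, 3, 4): differentiable gather
weights = torch.softmax(-topk_dists / 0.1, dim=1)          # (5, 3): constants for autograd
agg = torch.sum(weights.unsqueeze(-1) * neighbors, dim=1)  # (5, 4)

agg.sum().backward()
print(topk_dists.requires_grad, p2_features.grad.shape)
```

Gradients still reach p2_features through the gather: each row receives the sum of the weights of the query points that selected it, and rows that no query selected get zero gradient.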

Now, if you try to compile this you will hit a warning:

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py:700: UserWarning: Graph break due to unsupported builtin cupy._core.dlpack.from_dlpack. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.

And, performance is a little worse:

Time taken in forward: 15.697 ms per iteration
Time taken in backward: 17.670 ms per iteration

The issue, of course, is that knn_search_with_cuml calls operations (from CuPy and cuML) that PyTorch cannot trace, so torch.compile has to break the graph around them.

Registering External Ops With PyTorch

The fix, following the PyTorch Custom Ops Tutorial, is to register the function with PyTorch as a custom operator:

@torch.library.custom_op("cuml::knn", mutates_args=())
def knn_search_with_cuml(p1: torch.Tensor, p2: torch.Tensor, k: int = 3) -> tuple[torch.Tensor, torch.Tensor]:
    p1 = cp.from_dlpack(p1)
    p2 = cp.from_dlpack(p2)

    knn = cuml.neighbors.NearestNeighbors(n_neighbors=k)
    knn.fit(p2)

    distance, indices = knn.kneighbors(p1)

    # convert back to pytorch:
    distance = torch.from_dlpack(distance)
    indices = torch.from_dlpack(indices)

    return distance, indices

We also have to define a “fake” implementation of this function: given the inputs, it tells PyTorch what the outputs will look like (shapes, dtypes, and devices) without running cuML. torch.compile uses it while tracing. It’s easily done with a decorator:

@knn_search_with_cuml.register_fake
def _(p1, p2, k):
    assert p1.device == p2.device

    dist_output = torch.empty(p1.shape[0], k, device=p1.device, dtype=p1.dtype)
    idx_output = torch.empty(p1.shape[0], k, device=p1.device, dtype=torch.int64)

    return dist_output, idx_output

Note

We don’t even need to name this function. It’s consumed and registered with PyTorch, and PyTorch takes care of the rest.

With these changes, torch.compile works! You won’t see a significant speedup, though; in fact, you’ll probably see a negligible change in performance (< 1 ms difference). The challenge here is that while the cuML implementation is much faster, it includes CUDA synchronize calls, which block the host until the GPU has finished its queued work, so the CPU cannot run ahead and queue up kernels. Since the rest of the model is so tiny, compilation does almost nothing to improve it: outside of that call, we’re now bound by kernel launch latency. You can (and should!) run the profiles and confirm that the GPU is now significantly more idle than it was in the first version of the code. For real models, with much deeper and larger layers, this will not be a major issue.

Performance Bonus! Shared Memory Pools

If you do look at the profile, you’ll see a lot of memory operations in the cuML region of the code. Why? cuML has to allocate memory for itself, and while both RAPIDS and PyTorch have dedicated memory management tools to speed up allocation, they are not using the same pool of memory. Fortunately, PyTorch lets you swap in another memory allocator, and the RAPIDS Memory Manager (RMM) is easy to plug in. Add these to your imports:

import rmm
from rmm.allocators.torch import rmm_torch_allocator

Then, before you initialize any data or models in PyTorch, plug the RAPIDS memory manager into PyTorch:

rmm.reinitialize(pool_allocator=True)
torch.cuda.memory.change_current_allocator(rmm_torch_allocator)

This improves the runtime by a further ~3 ms per iteration (about 20% less time, on top of an already good speedup!):

# Not Compiled:
Time taken in forward: 11.185 ms per iteration
Time taken in backward: 12.818 ms per iteration

# Compiled:
Time taken in forward: 11.163 ms per iteration
Time taken in backward: 12.571 ms per iteration

What about the backwards pass?

Implementing Custom Backward Passes with NVIDIA Warp

Yes, the backward pass is supported too. Our original example didn’t actually need a backward pass for the kNN operator, so let’s write a new kernel that does need one and see how the backward pass is incorporated. Instead of a kNN, we’ll use all points within a specified radius to compute the same distance-weighted feature aggregation. In PyTorch, it looks like this:

def radius_bounded_feature_aggregation(
    p1: torch.Tensor,
    p2: torch.Tensor,
    p2_features: torch.Tensor,
    radius: float,
    sigma: float,
    ) -> torch.Tensor:
    """
    Perform differentiable radius-bounded feature aggregation.

    Args:
        p1 (torch.Tensor): Query points, shape (M, D)
        p2 (torch.Tensor): Reference points, shape (N, D)
        p2_features (torch.Tensor): Features at reference points, shape (N, D_feat)
        radius (float): Radius for neighbor search
        sigma (float): RBF temperature parameter

    Returns:
        torch.Tensor: Aggregated features at p1, shape (M, D_feat)
    """

    # Compute pairwise distances: (M, N)
    dists = torch.norm(p1[:,None,:] - p2[None,:,:], dim=-1)

    # Create mask for neighbors within radius
    mask = dists <= radius

    # Compute weights from all distances first
    weights = torch.softmax(-dists / sigma, dim=-1)

    # Apply mask to zero out weights for points outside radius
    weights = torch.where(mask, weights, torch.zeros_like(weights))

    # Renormalize weights so they sum to 1 for each query point
    weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-8)

    # Weighted sum of all reference features: (M, D_feat)
    agg = torch.sum(weights.unsqueeze(-1) * p2_features.unsqueeze(0), dim=1)

    return agg

Everything else about the pretend application can stay the same: an MLP on the features before the aggregation, and an MLP on the aggregated features. You may need to decrease the number of points in the point cloud, though: the pure-PyTorch version of this operator is more memory-hungry than the kNN version, since it materializes several (M, N) tensors plus an (M, N, D_feat) product in the weighted sum.
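A rough estimate shows why the point count has to come down. The sketch below sizes the (M, N, D_feat) tensor created by weights.unsqueeze(-1) * p2_features.unsqueeze(0) in radius_bounded_feature_aggregation, and a single (M, N) intermediate such as dists. It assumes float32 and that D_feat equals hidden_dim = 25 from the earlier script; product_gib and dists_gib are helpers for this sketch only.

```python
GiB = 2**30

M = 27_000   # 30 x 30 x 30 grid points (p1)
D_FEAT = 25  # assumed: hidden_dim from the earlier script
BYTES = 4    # float32


def product_gib(n_cloud_points):
    # weights.unsqueeze(-1) * p2_features.unsqueeze(0) has shape (M, N, D_feat).
    return M * n_cloud_points * D_FEAT * BYTES / GiB


def dists_gib(n_cloud_points):
    # One (M, N) float32 intermediate, e.g. dists or the masked weights.
    return M * n_cloud_points * BYTES / GiB


for n in (100_000, 25_000):
    print(f"N={n}: (M, N, D_feat) product {product_gib(n):.2f} GiB, "
          f"each (M, N) tensor {dists_gib(n):.2f} GiB")
```

A matrix product, weights @ p2_features, computes the same weighted sum without materializing the (M, N, D_feat) tensor.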

With 25k points in the point cloud, and grids of 30x30x30 points (27k), the uncompiled timing looks like this:

Time taken in forward: 161.396 ms per iteration
Time taken in backward: 260.719 ms per iteration

Compiled, performance is much better (roughly 4x faster in the forward pass and 3x faster in the training step):

Time taken in forward: 38.718 ms per iteration
Time taken in backward: 86.226 ms per iteration

To do better still, we’ll turn to NVIDIA’s Warp kernel language. Warp is designed for differentiable physics simulation, and if we write this kernel directly in Warp, it will generate the adjoint (that is, the backward pass) for us automatically. The key to a much more efficient version is Warp’s HashGrid object, a spatial hash that finds the points within a radius without comparing every pair of points.
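Warp's HashGrid internals aren't covered in this tutorial, but the idea behind a spatial hash grid can be sketched in a few lines of pure Python. This is an illustration of uniform-grid binning, not Warp's implementation: build_grid and radius_query are hypothetical helpers, and the sketch assumes the cell size is at least the query radius, so only the 27 cells around a query need to be scanned.

```python
import math
import random
from collections import defaultdict


def build_grid(points, cell_size):
    # Bin each point's index by the integer (i, j, k) cell that contains it.
    grid = defaultdict(list)
    for idx, p in enumerate(points):
        key = tuple(math.floor(c / cell_size) for c in p)
        grid[key].append(idx)
    return grid


def radius_query(grid, points, query, radius, cell_size):
    # Assumes cell_size >= radius: every neighbor then lies in the 3x3x3 block
    # of cells around the query's cell, so only those cells are scanned.
    cx, cy, cz = (math.floor(c / cell_size) for c in query)
    found = []
    for dx in (-1, 0, 1):
        for dy in (-1, 0, 1):
            for dz in (-1, 0, 1):
                for idx in grid.get((cx + dx, cy + dy, cz + dz), []):
                    # Candidates from nearby cells still need an exact distance check,
                    # like the `if dist > radius: continue` in the Warp kernel.
                    if math.dist(points[idx], query) <= radius:
                        found.append(idx)
    return sorted(found)


random.seed(0)
pts = [tuple(random.uniform(-1, 1) for _ in range(3)) for _ in range(2000)]
grid = build_grid(pts, cell_size=0.2)
print(radius_query(grid, pts, (0.1, -0.3, 0.5), radius=0.2, cell_size=0.2))
```

Each query checks only a handful of nearby candidates instead of all N points, which is why this approach scales far better than building the full (M, N) distance matrix.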

import warp as wp


@wp.kernel
def warp_radius_bounded_feature_aggregation(
    query_points: wp.array(dtype=wp.vec3),
    ref_points:  wp.array(dtype=wp.vec3),
    ref_features: wp.array2d(dtype=wp.float32),
    grid: wp.uint64,
    radius: float,
    sigma: float,
    output_features: wp.array2d(dtype=wp.float32),
):
    # Get the thread ID:
    tid = wp.tid()

    # Get position from query points
    pos = query_points[tid]

    feature_dim = ref_features.shape[1]
    local_output = output_features[tid]

    # Find all the neighbors using the hash grid:
    neighbors = wp.hash_grid_query(id=grid, point=pos, max_dist=radius)

    weight = float(0.0)
    # Loop through neighbors. Compute a weighted distance and accumulate it,
    # but also track the weight to normalize at the end.
    for index in neighbors:
        # Get the neighbor position:
        pos2 = ref_points[index]

        # Compute the distance:
        dist = wp.length(pos - pos2)

        if dist > radius:
            continue

        # Get the features at this index:
        feature = ref_features[index]

        # Compute the weight:
        this_weight = wp.exp( - dist / sigma)

        # Accumulate the weight, and weight * feature
        weight += this_weight
        # Accumulate weight * feature for each feature channel
        for j in range(feature_dim):
            local_output[j] += feature[j] * this_weight

    if weight > 0.0:
        # Normalize by the total weight
        for j in range(feature_dim):
            local_output[j] /= (weight + 1e-8)

    # Write the output
    for j in range(feature_dim):
        output_features[tid,j] = local_output[j]

Note

The function above is a Warp kernel, not a regular Python function. It’s written in Python, but Warp compiles it to CUDA and launches it like any other kernel, all from Python itself. Read more about Warp kernels in the Warp documentation.

This kernel replaces nearly all of the PyTorch code: it finds all the points within the radius and directly accumulates their weighted features into the output. Launching a Warp kernel from PyTorch is straightforward, thanks to Warp’s interoperability with other frameworks, including PyTorch:

@torch.library.custom_op("warp::radius_bounded_feature_aggregation", mutates_args=())
def radius_bounded_feature_aggregation_impl(
    query_points: torch.Tensor,
    ref_points: torch.Tensor,
    ref_features: torch.Tensor,
    radius: float,
    sigma: float,
) -> torch.Tensor:


    # Convert to warp
    # We can only build and query the points in wp.vec3 format:
    wp_query_points = wp.from_torch(
        query_points,
        dtype=wp.vec3, # vec3 here!
        requires_grad=query_points.requires_grad
    )
    wp_ref_points = wp.from_torch(
        ref_points,
        dtype=wp.vec3, # and here!
        requires_grad=ref_points.requires_grad
    )
    wp_ref_features = wp.from_torch(
        ref_features,
        dtype=wp.float32, # but, the features we keep as an array!
        requires_grad=ref_features.requires_grad
    )


    # In the data generation, we had set the grid to range over -1 to 1 with 30 points.
    # We can use that to dictate the grid size:
    grid_size = int(2. / radius) + 1

    # **In general** to do this, you'd have to incur a performance penalty:
    # 1. Find the min max of all the points in your query set.
    # 2. Divide the range by the radius to get the grid size.
    # 3. Move the grid size to the CPU (which is blocking if done wrong)
    # 4. Construct the grid on the CPU.

    # But, it's **likely** you'd be constructing the grid once and caching it, anyways.



    # Build the grid used in the kernel:
    hash_grid = wp.HashGrid(grid_size, grid_size, grid_size)

    # This actually loops over the points and does the hashing
    hash_grid.build(wp_ref_points, radius)


    # Allocate output space (with pytorch!) and convert to warp:
    output_features = torch.zeros(
        (query_points.shape[0], ref_features.shape[1], ),
        device=ref_features.device,
        dtype=ref_features.dtype,
        requires_grad=ref_features.requires_grad
    )

    wp_output = wp.from_torch(
        output_features,
        dtype=wp.float32,
        requires_grad=ref_features.requires_grad
    )

    # Launch the kernel:
    feature_dim = ref_features.shape[1]
    wp.launch(
        warp_radius_bounded_feature_aggregation,
        inputs=[
            wp_query_points,
            wp_ref_points,
            wp_ref_features,
            hash_grid.id,
            radius,
            sigma,
        ],
        outputs =[
            wp_output,
        ],
        dim=[query_points.shape[0]],
    )


    # return the output features:
    return output_features

For more details on Warp, its interface, and the tools available, head over to the Warp documentation.
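The Warp kernel and the PyTorch reference compute their weights differently: the kernel sums exp(-d / sigma) over neighbors within the radius and divides by the total weight, while the PyTorch version takes a softmax over all points, zeros out those beyond the radius, and renormalizes. The common softmax denominator cancels in the renormalization, so the two should agree up to the small epsilon terms. Here is a quick CPU check in pure PyTorch; the function names are hypothetical, and torch.cdist stands in for the explicit broadcast-and-norm.

```python
import torch


def masked_softmax_version(p1, p2, feats, radius, sigma):
    # Mirrors radius_bounded_feature_aggregation: softmax over all points,
    # zero out points beyond the radius, then renormalize.
    dists = torch.cdist(p1, p2)
    weights = torch.softmax(-dists / sigma, dim=-1)
    weights = torch.where(dists <= radius, weights, torch.zeros_like(weights))
    weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-8)
    return torch.sum(weights.unsqueeze(-1) * feats.unsqueeze(0), dim=1)


def kernel_style_version(p1, p2, feats, radius, sigma):
    # Mirrors the per-query loop in the Warp kernel: accumulate exp(-d / sigma)
    # and exp(-d / sigma) * feature over in-radius neighbors, then divide.
    dists = torch.cdist(p1, p2)
    w = torch.exp(-dists / sigma) * (dists <= radius)
    total = w.sum(dim=-1, keepdim=True)
    out = (w @ feats) / (total + 1e-8)
    # Query points with no neighbors keep a zero output, as in the kernel.
    return torch.where(total > 0, out, torch.zeros_like(out))


torch.manual_seed(0)
p1, p2, feats = torch.randn(64, 3), torch.randn(200, 3), torch.randn(200, 5)
a = masked_softmax_version(p1, p2, feats, radius=1.0, sigma=0.5)
b = kernel_style_version(p1, p2, feats, radius=1.0, sigma=0.5)
print(torch.allclose(a, b, atol=1e-5))
```

A check like this is also a useful sanity test for the Warp output itself once it runs on a GPU.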

If you followed along with the first half, you’ll recognize the decorator on the function: @torch.library.custom_op("warp::radius_bounded_feature_aggregation", mutates_args=()). The fake registration is still necessary, too:

@radius_bounded_feature_aggregation_impl.register_fake
def _(
    query_points: torch.Tensor,
    ref_points: torch.Tensor,
    ref_features: torch.Tensor,
    radius: float,
    sigma: float,
) -> torch.Tensor:

    assert query_points.is_cuda
    assert ref_points.is_cuda
    assert ref_features.is_cuda

    # Match the real op: the output takes its dtype and device from ref_features.
    output = torch.empty(
        (query_points.shape[0], ref_features.shape[1]),
        device=ref_features.device,
        dtype=ref_features.dtype)

    return output

Now, we do the fun part: going backwards. The actual backwards pass to launch the warp kernel is very similar to the forwards pass:

@torch.library.custom_op("warp::radius_bounded_feature_aggregation_bwd", mutates_args=())
def radius_bounded_feature_aggregation_bwd_impl(
    query_points: torch.Tensor,
    ref_points: torch.Tensor,
    ref_features: torch.Tensor,
    radius: float,
    sigma: float,
    output_features: torch.Tensor,
    grad_outputs: torch.Tensor,
) -> torch.Tensor:

    # This function only needs to compute the gradient of ref_features,
    # based on grad_outputs (the gradient of the forward output).
    # The gradients of all other inputs are None.

    # Convert to warp:
    wp_query_points = wp.from_torch(
        query_points,
        dtype=wp.vec3,
        requires_grad=False
    )
    # We can only build and query the points in float32:
    wp_ref_points = wp.from_torch(
        ref_points,
        dtype=wp.vec3,
        requires_grad=False
    )
    wp_ref_features = wp.from_torch(
        ref_features,
        dtype=wp.float32,
        requires_grad=False,
    )

    #########################################################
    # Warp writes the feature gradients into the array passed
    # in adj_inputs, accumulating into it, so give it a fresh
    # zero-filled buffer. Reusing ref_features.grad instead
    # risks mixing in gradients that autograd accumulates there.
    #########################################################
    grad_ref_features = torch.zeros_like(ref_features)
    wp_grad_ref_features = wp.from_torch(grad_ref_features, dtype=wp.float32)

    # In the data generation below, we set the grid to range over -1 to 1 with 30 points.
    # We can use that to dictate the grid size:
    grid_size = int(2. / radius) + 1

    # Build the grid used in the kernel:
    # In a real application, you'd cache and retrieve this in the backward pass.
    # The trick would be to make sure the hash_grid object persists but is not
    # actually in the torch interface (it would break things for the compiler!)
    hash_grid = wp.HashGrid(grid_size, grid_size, grid_size)

    # We're rebuilding here just to make the implementation straightforward
    hash_grid.build(wp_ref_points, radius)

    wp_output = wp.from_torch(output_features)

    # grad_outputs can arrive non-contiguous (e.g. expanded from a .sum()),
    # so make it contiguous before handing it to Warp:
    wp_grad_outputs = wp.from_torch(grad_outputs.contiguous())

    # Launch the kernel:
    wp.launch(
        warp_radius_bounded_feature_aggregation,
        inputs=[
            wp_query_points,
            wp_ref_points,
            wp_ref_features,
            hash_grid.id,
            radius,
            sigma,
        ],
        outputs=[
            wp_output,
        ],
        adj_inputs=[
            None,
            None,
            wp_grad_ref_features,
            None,
            None,
            None,
        ],
        adj_outputs=[
            wp_grad_outputs,
        ],
        adjoint=True,  ############ Pay attention here! Launch the kernel adjoint!
        dim=[query_points.shape[0]],
    )

    # Return the gradient of the features:
    return grad_ref_features

Registering the fake is just as straightforward as it was for the forward pass:

@radius_bounded_feature_aggregation_bwd_impl.register_fake
def _(
    query_points: torch.Tensor,
    ref_points: torch.Tensor,
    ref_features: torch.Tensor,
    radius: float,
    sigma: float,
    output_features: torch.Tensor,
    grad_outputs: torch.Tensor,
) -> torch.Tensor:
    # The gradient of ref_features has the same shape, dtype, and device as ref_features.
    grad_ref_features = torch.empty_like(ref_features)
    return grad_ref_features

Connecting all the pieces

So far, we haven't actually told PyTorch that this is a backward-pass function; we've just declared a similar function that happens to have bwd in its name. To connect everything to PyTorch's autograd system, we need two more steps. First, we set up the context, which saves the forward-pass objects needed for the backward calculation. Second, we tell PyTorch exactly which function is the backward pass and which function sets up the context, using a simple one-liner to register them with the autograd system:

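# Use this to save any inputs or outputs for the backward pass.
def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
    query_points, ref_points, ref_features, radius, sigma = inputs
    out_features = output
    ctx.save_for_backward(query_points, ref_points, ref_features, out_features)

    ctx.radius = radius
    ctx.sigma = sigma

# The backward worker gets the context plus the gradient of the forward output, and must
# return one gradient per forward input (None for the non-differentiable ones).
# Assumed minimal version: unpack the saved tensors and call the backward custom op.
def radius_bounded_feature_aggregation_backward_worker(ctx, grad_output):
    query_points, ref_points, ref_features, out_features = ctx.saved_tensors
    grad_ref_features = radius_bounded_feature_aggregation_bwd_impl(
        query_points, ref_points, ref_features, ctx.radius, ctx.sigma, out_features, grad_output
    )
    return None, None, grad_ref_features, None, None

# And this connects the forward pass to the backward pass with pytorch.
radius_bounded_feature_aggregation_impl.register_autograd(radius_bounded_feature_aggregation_backward_worker, setup_context=setup_context)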

Note

If you've written a torch.autograd.Function subclass before, you're familiar with the context, though it may look different in this form. PyTorch recommends this syntax because it also works with its functional API (torch.func), but if you don't need that, you're welcome to stick with tried-and-true inheritance from torch.autograd.Function.
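For comparison, here is a minimal sketch of the subclass style. It uses the same assumed Gaussian-weighted aggregation and a NumPy stand-in for the external kernel. The class name RadiusBoundedAggregation and the helper _gaussian_weights are hypothetical. Here forward takes ctx as its first argument and saves the inputs, and backward rebuilds the weights from them, mirroring how the Warp backward rebuilds its hash grid.

```python
import numpy as np
import torch


def _gaussian_weights(query, ref, radius, sigma):
    # Dense Gaussian weights, zeroed at or beyond `radius` (assumed kernel).
    d2 = ((query[:, None, :] - ref[None, :, :]) ** 2).sum(axis=-1)
    weights = np.exp(-d2 / (2.0 * sigma**2))
    weights[d2 >= radius**2] = 0.0
    return weights


class RadiusBoundedAggregation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, query_points, ref_points, ref_features, radius, sigma):
        # Save tensors with save_for_backward, plain Python values on ctx.
        ctx.save_for_backward(query_points, ref_points, ref_features)
        ctx.radius, ctx.sigma = radius, sigma
        w = _gaussian_weights(
            query_points.detach().cpu().numpy(), ref_points.detach().cpu().numpy(), radius, sigma
        )
        out = w @ ref_features.detach().cpu().numpy()
        return torch.from_numpy(out).to(device=ref_features.device, dtype=ref_features.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        query_points, ref_points, ref_features = ctx.saved_tensors
        w = _gaussian_weights(
            query_points.detach().cpu().numpy(), ref_points.detach().cpu().numpy(),
            ctx.radius, ctx.sigma,
        )
        grad = w.T @ grad_output.detach().cpu().numpy()
        grad_features = torch.from_numpy(grad).to(device=ref_features.device, dtype=ref_features.dtype)
        # One gradient per forward input (after ctx); None for the non-differentiable ones.
        return None, None, grad_features, None, None
```

Call it as RadiusBoundedAggregation.apply(q, r, f, radius, sigma). If you later need torch.func transforms, PyTorch requires the other autograd.Function form, where forward does not take ctx and a separate setup_context staticmethod saves state.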

If you're not familiar with how autograd works in PyTorch, check out their excellent guide.

The rest of the application proceeds just like before. With Warp, though, performance is even better than with compiled PyTorch:

Time taken in forward: 13.511 ms per iteration
Time taken in backward: 29.953 ms per iteration

Success!
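The section reports per-iteration timings but doesn't show how they were measured. As an assumption about one reasonable approach, not the tutorial's own benchmark code, the hypothetical helper time_forward_backward below averages wall-clock forward and backward times. It discards warm-up iterations, which absorb one-time costs such as torch.compile tracing. It also synchronizes the GPU before each clock read, because CUDA kernels launch asynchronously.

```python
import time
from typing import Callable, Sequence

import torch


def time_forward_backward(
    fn: Callable[..., torch.Tensor],
    inputs: Sequence[torch.Tensor],
    n_iters: int = 20,
    n_warmup: int = 3,
) -> tuple[float, float]:
    """Return (forward_ms, backward_ms) averaged over n_iters timed iterations."""

    def sync() -> None:
        # Wait for queued GPU work so the clock measures finished kernels.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    fwd_total = bwd_total = 0.0
    for i in range(n_warmup + n_iters):
        for tensor in inputs:
            tensor.grad = None  # start each iteration without accumulated gradients
        sync()
        t0 = time.perf_counter()
        out = fn(*inputs)
        sync()
        t1 = time.perf_counter()
        out.sum().backward()  # backward time includes the cheap .sum() reduction
        sync()
        t2 = time.perf_counter()
        if i >= n_warmup:
            fwd_total += t1 - t0
            bwd_total += t2 - t1
    return 1e3 * fwd_total / n_iters, 1e3 * bwd_total / n_iters
```

Timing the eager, compiled, and custom-op versions with the same helper and the same inputs keeps the comparison fair.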

Conclusion

In this tutorial, we saw how to integrate high-performance external code into PyTorch applications and how to combine it with torch.compile. Some key takeaways:

  • torch.compile is generally great for performance. Use it unless there is a reason you can’t!

  • If the reason you can't use torch.compile is that you have to leave the PyTorch ecosystem for a better-performing kernel, use this method to register your wrapper functions so the compiler can incorporate them seamlessly.

  • If you're calling out to other libraries in the NVIDIA ecosystem, like RAPIDS, you can get even better performance by sharing a memory manager! Check out the RAPIDS Memory Manager, which plugs into PyTorch (as well as CuPy and Numba).

  • You can incorporate a backward pass just as easily as a forward pass. In this tutorial we leveraged the autograd capabilities of Warp, but in other cases you can of course write the backward pass yourself.

Without these techniques, users with many scientific AI workloads face a choice: use torch.compile on their model, including its inefficient functions, or skip torch.compile and call fast, accelerated libraries in the NVIDIA Python ecosystem. In reality, you don't need to make that choice: you can have both performance enhancements.
© Copyright 2023, NVIDIA PhysicsNeMo Team. Last updated on Jun 11, 2025.