# Adding Torch Custom Ops in cuDNN Frontend
Best practices for wrapping cuDNN graph ops as PyTorch custom ops with minimal CPU overhead.
## File location

Custom ops live in `python/cudnn/experimental/ops/`. Each op gets its own file
(e.g., `sdpa.py`, `rmsnorm.py`, `layernorm.py`, `moe.py`). Export from
`python/cudnn/experimental/ops/__init__.py`.
## Registration: use torch.Library, NOT @torch.library.custom_op

`@torch.library.custom_op` adds ~36 us of per-call overhead vs ~10 us for direct
`torch.Library` registration (PyTorch issue #139500). Always use the direct API:

```python
import torch

_lib = torch.library.Library("cudnn", "DEF")

# Define schema: must list all tensor and scalar args with types
_lib.define(
    "my_op(Tensor x, Tensor w, float eps, bool training=False, "
    "Tensor? bias=None) -> (Tensor, Tensor)"
)

# Register CUDA implementation
def _my_op_impl(x, w, eps, training=False, bias=None):
    # ... build/cache graph, execute ...
    return output, aux

_lib.impl("my_op", _my_op_impl, "CUDA")

# Register fake (for torch.compile shape inference)
@torch.library.register_fake("cudnn::my_op")
def _my_op_fake(x, w, eps, training=False, bias=None):
    return torch.empty_like(x), torch.empty(...)

# Register autograd (if backward is needed)
torch.library.register_autograd(
    "cudnn::my_op", _my_op_backward, setup_context=_my_op_setup_ctx
)
```
Users call the op via `torch.ops.cudnn.my_op(x, w, eps)`, or you can wrap it in a public function.
## Graph caching

Build graphs once per unique (shape, stride, dtype, config) tuple. Cache them in a module-level dict:

```python
from typing import Dict

_cache: Dict[tuple, tuple] = {}

def _make_cache_key(x, w, eps, has_bias):
    return (
        "my_op",
        tuple(x.shape), tuple(x.stride()), x.dtype,
        tuple(w.shape), tuple(w.stride()), w.dtype,
        eps, has_bias, x.device,
    )
```
Include device in the key — different GPUs may get different engine plans.
## Graph building pattern

Use `tensor_like()` for automatic shape/stride/dtype inference from DLPack tensors:

```python
graph = cudnn.pygraph(handle=handle, ...)
X = graph.tensor_like(x)
W = graph.tensor_like(w)
# ... chain ops ...
Y.set_output(True).set_data_type(cudnn.data_type.BFLOAT16)

graph.validate()
graph.build_operation_graph()
graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
graph.check_support()
graph.build_plans()  # also prepares the variant pack template
```
## Execute: use the sorted_ptrs path

Cache `uid_order` alongside the graph. Use `_execute_with_ptrs` instead of the dict-based `execute`:

```python
if cache_key not in _cache:
    graph, ws_size = _build_graph(...)
    uid_order = graph._get_variant_pack_uids_sorted()
    _cache[cache_key] = (graph, ws_size, uid_order)

graph, ws_size, uid_order = _cache[cache_key]

# Allocate workspace per-call (PyTorch's caching allocator recycles efficiently)
workspace = torch.empty(max(ws_size, 1), device=x.device, dtype=torch.uint8)

# Build uid -> tensor map, extract sorted ptrs
uid_to_tensor = {X.get_uid(): x, W.get_uid(): w, Y.get_uid(): y_out}
ptrs = [uid_to_tensor[uid].data_ptr() for uid in uid_order]
graph._execute_with_ptrs(ptrs, workspace.data_ptr(), int(handle))
```
Do NOT cache workspace tensors — they can race on different CUDA streams. Allocate per-call; PyTorch’s caching allocator recycles the allocation without hitting cudaMalloc (~2 us overhead). This matches PyTorch’s own conv/SDPA pattern.
## UID management

Use an `IntEnum` for explicit UIDs; it makes the code self-documenting and keeps cache keys stable:

```python
from enum import IntEnum

class _UIDs(IntEnum):
    X = 1
    W = 2
    BIAS = 3
    Y = 100
    # ...
```

Set UIDs explicitly when building the graph:

```python
X = graph.tensor(dim=..., stride=..., data_type=..., uid=int(_UIDs.X))
```

Alternatively, use `tensor_like`, which auto-assigns UIDs in insertion order (1, 2, 3, ...).
## Handle management

Cache one handle per device. Call `set_stream` on every call (it costs ~5 us):

```python
_handles = {}

def _get_handle(device):
    if device not in _handles:
        _handles[device] = cudnn.create_handle()
    cudnn.set_stream(
        handle=_handles[device],
        stream=torch.cuda.current_stream(device).cuda_stream,
    )
    return _handles[device]
```
Future optimization: skip set_stream if stream hasn’t changed since last call.
## Autograd pattern

```python
def _setup_context(ctx, inputs, output):
    x, w, eps, training, bias = inputs
    y, aux = output
    ctx.save_for_backward(x, w, y, aux)
    ctx.eps = eps  # save non-tensor args as ctx attributes

def _backward(ctx, dY, d_aux):
    x, w, y, aux = ctx.saved_tensors
    dX, dW = torch.ops.cudnn.my_op_bwd(dY, x, w, y, aux, ctx.eps)
    return dX, dW, None, None, None  # None for non-differentiable args

torch.library.register_autograd(
    "cudnn::my_op", _backward, setup_context=_setup_context
)
```
## Public API wrapper

Provide a user-friendly function that matches PyTorch conventions:

```python
def my_op(x, w, eps=1e-5, bias=None):
    """cuDNN-accelerated my_op. Matches the torch.nn.functional.my_op API."""
    # input validation
    # ...
    y, _aux = torch.ops.cudnn.my_op(x, w, eps, bias=bias)
    return y
```
## Performance checklist

- [ ] Use `torch.Library.define`/`impl`, not `@torch.library.custom_op`
- [ ] Cache graph + `uid_order` in a module-level dict
- [ ] Use `graph._execute_with_ptrs(sorted_ptrs)`, not `graph.execute(dict)`
- [ ] Use explicit UIDs (IntEnum) for stable cache keys
- [ ] Cache the cuDNN handle per device
- [ ] Allocate workspace per-call (do NOT cache it; stream safety)
- [ ] Consider an `out=` param for output tensors in performance-critical paths
- [ ] Include the device in the cache key
- [ ] Test with tiny tensors to measure CPU overhead in isolation
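The last item can be sketched as a micro-benchmark helper; the function name and iteration count are illustrative, and `torch.relu` stands in for the custom op so the sketch runs on CPU:

```python
import time
import torch

def measure_cpu_overhead(fn, *args, iters=1000):
    """Rough per-call CPU cost of fn, in microseconds.

    Use tiny tensors so kernel time is negligible; on GPU, call
    torch.cuda.synchronize() before and after the timed loop.
    """
    fn(*args)  # warm up: graph build, caches, lazy init
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - start) / iters * 1e6

# Stand-in workload; replace with torch.ops.cudnn.my_op on a GPU box
x = torch.ones(8)
us = measure_cpu_overhead(torch.relu, x)
print(f"{us:.1f} us/call")
```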
## CPU overhead budget (per call, Blackwell release build)

| Component | Cost (us) | Notes |
|---|---|---|
| torch.ops dispatch + autograd | ~25 | PyTorch-side, unavoidable with torch.Library |
| `set_stream` | ~5 | pybind11 cross-language call |
| workspace + output allocation | ~2 each | consider `out=` param |
| cache key build + lookup | ~1.5 | tuple construction + dict hash |
| uid→tensor dict + list comp | ~1 | Python overhead |
| graph execute | ~19 | 1.7 us FE + 0.8 us varpack + 5.6 us backend |
| **Total (well-optimized)** | ~52 | vs ~18 us for native ATen ops |
The ~34 us gap vs native ATen is torch.ops dispatcher + autograd overhead.
For inference without torch.compile, calling `_my_op_impl` directly saves ~25 us.
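One way to exploit that is an inference fast path that bypasses the dispatcher when autograd cannot be needed. A sketch with stand-in implementations (`my_op_fast` and the stub bodies are illustrative, not part of the library):

```python
import torch

def _my_op_impl(x, w, eps, training=False, bias=None):
    # stand-in for the real cuDNN-backed implementation
    return x * w + eps, x

def _my_op_dispatched(x, w, eps):
    # stand-in for torch.ops.cudnn.my_op (dispatcher + autograd, ~25 us)
    return _my_op_impl(x, w, eps)

def my_op_fast(x, w, eps=1e-5):
    # Skip the dispatcher when no input can require grad (inference path)
    if torch.is_grad_enabled() and (x.requires_grad or w.requires_grad):
        return _my_op_dispatched(x, w, eps)
    return _my_op_impl(x, w, eps)

with torch.no_grad():
    y, _aux = my_op_fast(torch.ones(3), torch.full((3,), 2.0), eps=0.0)
print(y)  # tensor([2., 2., 2.])
```

Under `torch.compile`, always go through `torch.ops` so the fake registration is used; a real wrapper should also check `torch.compiler.is_compiling()` before taking the fast path.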