Quick Start Guide to cuDNN

Note

This page is a short, code-first introduction to cuDNN. It shows a complete example of the same simple matrix multiplication workflow in Python and C++. For detailed explanations of cuDNN concepts, refer to the Developer Guide.

You can use cuDNN in Python or in C++.

Start from a small PyTorch linear layer on a GPU:

import torch
import torch.nn as nn

b, m, n, k = 16, 32, 64, 128
x = torch.randn(b, m, k, device="cuda", dtype=torch.bfloat16)
linear = nn.Linear(k, n, bias=False, device="cuda", dtype=torch.bfloat16)
y = linear(x)
# same:
# y = torch.matmul(x, linear.weight.T)

The input has shape [B, M, K] and is projected along the last axis to width N, a pattern common in LLMs (batch B, sequence M, hidden K, output width N). Conceptually, the operation is a batched matrix multiplication. PyTorch expresses that concept compactly and dispatches to its own kernels.
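To make the index pattern concrete, here is a minimal pure-Python sketch of the same projection on nested lists. The helper name batched_matmul and the tiny sizes are assumptions chosen for illustration; this is not how PyTorch or cuDNN compute the result, only a restatement of the math.

```python
def batched_matmul(x, w):
    """Multiply each [M, K] slice of x by one shared [K, N] matrix w."""
    k, n = len(w), len(w[0])
    out = []
    for x_slice in x:                      # loop over the batch dimension B
        rows = []
        for row in x_slice:                # loop over the M rows of this slice
            assert len(row) == k, "inner dimensions must match"
            rows.append([sum(row[ki] * w[ki][ni] for ki in range(k)) for ni in range(n)])
        out.append(rows)
    return out


# Illustrative sizes only: B=2, M=1, K=2, N=3.
x = [[[1.0, 2.0]], [[3.0, 4.0]]]
w = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]
y = batched_matmul(x, w)
print(y)
```

The weight matrix has no batch dimension of its own; every batch slice reuses it, which is the role the [1, K, N] operand plays in the cuDNN listings.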

The PyTorch block shown previously is a baseline for the same math, not a suggestion that cuDNN should match it line-for-line. Frameworks typically hide the graph construction and much of the dtype and layout detail that the cuDNN frontend asks you to state explicitly.

The following listing runs the same matrix multiplication through the cuDNN graph API.

import cudnn
import torch

def compare_results(actual: torch.Tensor, expected: torch.Tensor):
    rtol = 1e-2
    atol = 1e-2
    _b, m, n = actual.shape
    assert expected.shape == actual.shape

    # count the number of close elements
    close_mask = torch.isclose(actual, expected, atol=atol, rtol=rtol, equal_nan=True)
    num_el = actual.numel()
    close_cnt = close_mask.detach().sum().cpu().item()
    # find the max diff and location
    max_diff = (actual - expected).abs().max().cpu().item()
    max_diff_idx = (actual - expected).abs().argmax().cpu().item()
    max_diff_idx = (max_diff_idx // (m * n), max_diff_idx % (m * n) // n, max_diff_idx % n)

    print(f"Percentage of close elements: {100 * close_cnt / num_el:.1f}%")
    print(f"Max absolute difference: {max_diff}")
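    # Both tensors live on the GPU here, so label them by origin rather than by device.
    print(f"At index {list(max_diff_idx)}"
        f"  cuDNN={actual[max_diff_idx].item()}, reference={expected[max_diff_idx].item()}")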


b, m, n, k = 16, 32, 64, 128
a_dev = torch.randn(b, m, k, device="cuda", dtype=torch.bfloat16)
b_dev = torch.randn(1, k, n, device="cuda", dtype=torch.bfloat16)

# Start of core cuDNN code
with cudnn.Graph(
    io_data_type=torch.bfloat16,
    compute_data_type=torch.float32,
    inputs=["matmul::A", "matmul::B"],
    outputs=["out"],
) as graph:
    c_cudnn = graph.matmul(name="matmul", A=a_dev, B=b_dev)
    c_cudnn.set_name("out").set_output(True)

handle = cudnn.create_handle()
c_dev = graph(a_dev, b_dev, handle=handle)
# End of core cuDNN code
c_ref = torch.matmul(a_dev.to(torch.float32), b_dev.to(torch.float32)).to(torch.bfloat16)

compare_results(c_dev, c_ref)

The cuDNN example stays at the matrix multiplication level (no nn.Module wrapper). It is more explicit than PyTorch about tensor roles and dtypes.

Core cuDNN Path

Everything from the with cudnn.Graph(...) context through the c_dev = graph(...) call.

Supporting Checks

compare_results, the float32 torch.matmul reference, and the printout show that c_dev matches PyTorch within a loose numerical tolerance. They are a confidence check, not production overhead you must carry everywhere.
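The closeness rule behind that check is the one torch.isclose documents and the C++ listing spells out by hand: an element passes when |actual - expected| <= atol + rtol * |expected|. A minimal scalar sketch, with hypothetical helper names is_close and close_fraction, looks like this:

```python
def is_close(actual: float, expected: float, rtol: float = 1e-2, atol: float = 1e-2) -> bool:
    """Asymmetric tolerance test: the relative term scales with the expected value."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


def close_fraction(actual, expected, rtol=1e-2, atol=1e-2):
    """Fraction of element pairs that pass the tolerance test."""
    assert len(actual) == len(expected)
    hits = sum(is_close(a, e, rtol, atol) for a, e in zip(actual, expected))
    return hits / len(actual)


# The middle pair differs by 0.05, which exceeds 0.01 + 0.01 * 2.0 = 0.03.
frac = close_fraction([1.000, 2.05, 100.5], [1.005, 2.00, 100.0])
print(f"{100 * frac:.1f}% close")
```

Because the relative term grows with the magnitude of the expected value, a difference of 0.5 passes near 100 while a difference of 0.05 fails near 2.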

Beyond a single matrix multiplication, the graph-based API is how the frontend composes multiple operations and reuses a built graph. That workflow does not map one-to-one to a one-line nn.Linear call.

Graph Basics

A cuDNN graph is a DAG of ops over tensors. A built graph maps to CUDA work on the device. Support depends on the GPU and the library version. When defining a graph, specify at least:

  • Which tensors are the graph outputs

  • The I/O dtypes for the graph boundaries

  • The per-node compute dtype (often FP32 accumulation for matrix multiplication)

  • The intermediate dtypes between nodes (only when the graph has multiple nodes)

The sample sets default I/O and compute dtypes on the graph context. With a single matrix multiplication node, no intermediate dtype is needed. Shapes and strides for a_dev and b_dev are inferred from the PyTorch tensors passed into the graph builder. At definition time, only metadata is used. The tensor values are read at execution.
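The shape and stride metadata is simple to reproduce by hand. The sketch below, using a hypothetical helper named row_major_strides, computes the strides of a packed row-major tensor; for the sample shapes they match what a contiguous PyTorch tensor reports and what the C++ listing passes to set_stride.

```python
def row_major_strides(shape):
    """Strides, in elements, of a densely packed row-major tensor."""
    strides = [1] * len(shape)
    # Walk from the second-to-last axis outward; each stride is the
    # product of all extents to its right.
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)


b, m, n, k = 16, 32, 64, 128
a_strides = row_major_strides((b, m, k))  # metadata for a_dev
b_strides = row_major_strides((1, k, n))  # metadata for b_dev
print(a_strides, b_strides)
```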

Build and Validate

When the cudnn.Graph context exits, the frontend validates and builds the graph. Invalid DAGs or unsupported patterns raise an error when the context exits. Output metadata for c_cudnn is filled in during that step even if it was not set explicitly.

Execution and Reuse

After the build, the graph object can be called repeatedly like a function. The inputs and outputs lists fix the argument order. String names (for example, matmul::A for argument A of node matmul, and out for the marked output) wire the graph ports to the call arguments. The PyTorch tensors passed to the graph builder contribute only metadata at build time; the tensors passed in the call supply the live buffers at run time (the sample passes the same a_dev and b_dev in both places).
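The naming convention can be pictured with a small pure-Python sketch. The helpers split_port and bind_arguments are hypothetical and only mirror the convention described above (node::argument names, positional binding in declaration order); they are not the frontend's implementation.

```python
def split_port(name: str):
    """Split 'node::argument' into its parts; a bare name such as 'out' has no node."""
    node, sep, arg = name.partition("::")
    return (node, arg) if sep else (None, node)


def bind_arguments(port_names, call_args):
    """Pair positional call arguments with port names, in declaration order."""
    if len(port_names) != len(call_args):
        raise TypeError(f"expected {len(port_names)} arguments, got {len(call_args)}")
    return dict(zip(port_names, call_args))


inputs = ["matmul::A", "matmul::B"]
# Strings stand in for the tensors that would be passed at call time.
bound = bind_arguments(inputs, ["a_dev buffer", "b_dev buffer"])
print(bound)
print(split_port("matmul::A"), split_port("out"))
```

Swapping the two entries of inputs would swap which call argument feeds A and which feeds B, which is why the list order matters.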

The graph setup runs on the host. Reuse the built graph when shapes, layouts, and dtypes stay fixed, so the setup cost is amortized. Pass a cuDNN handle into the call (for example, from cudnn.create_handle()) to select the device and stream context, or pass handle="auto" for a graph-owned handle (less sharing across graphs).
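One common way to amortize that setup cost is to cache built graphs under a key made of everything the build depends on. The sketch below is an assumption about application-level structure, not part of cuDNN: build_graph_stub stands in for the real build, and get_graph is a hypothetical lookup helper. A real key would also need to cover strides or layout.

```python
build_calls = 0


def build_graph_stub(key):
    """Stand-in for an expensive host-side graph build (hypothetical)."""
    global build_calls
    build_calls += 1
    return f"graph for {key}"


_graph_cache = {}


def get_graph(shape_a, shape_b, dtype):
    """Return a cached graph for this problem signature, building it once."""
    key = (tuple(shape_a), tuple(shape_b), dtype)
    if key not in _graph_cache:
        _graph_cache[key] = build_graph_stub(key)
    return _graph_cache[key]


g1 = get_graph((16, 32, 128), (1, 128, 64), "bfloat16")
g2 = get_graph((16, 32, 128), (1, 128, 64), "bfloat16")  # same signature: cache hit
g3 = get_graph((8, 32, 128), (1, 128, 64), "bfloat16")   # new batch size: rebuild
print(build_calls)
```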

compare_results() reports how closely c_dev matches c_ref. c_ref uses float32 accumulation then casts back to bfloat16 so the comparison aligns with the graph’s mixed-precision matrix multiplication.
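To see why a tolerance of 1e-2 is loose enough for bfloat16 output, it helps to emulate the final cast. The sketch below assumes round-to-nearest-even and uses a hypothetical helper, round_to_bf16, that works on the float32 bit pattern; it handles finite values only and is not the conversion routine used by PyTorch or cuDNN.

```python
import struct


def round_to_bf16(value: float) -> float:
    """Round a finite value to the nearest bfloat16 (ties to even), as a Python float."""
    bits = struct.unpack("<I", struct.pack("<f", value))[0]  # float32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)                       # round to nearest, ties to even
    bits &= 0xFFFF0000                                        # keep the top 16 bits only
    return struct.unpack("<f", struct.pack("<I", bits))[0]


exact = 1.01
rounded = round_to_bf16(exact)
rel_err = abs(rounded - exact) / abs(exact)
print(rounded, rel_err)
```

bfloat16 keeps 8 significant bits, so a single rounding of a normal value introduces a relative error of at most 2^-8 (about 0.4%). Accumulating in higher precision and rounding once at the end keeps the result within that bound of the exact sum, well inside the 1e-2 tolerances used by compare_results.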

The C++ example uses the same batched matrix multiplication pattern as the Python example: tensors shaped [B, M, K] and [1, K, N], with bf16 on the device.

The device tensors use nv_bfloat16. The host-side reference math uses float because many host CPUs lack fast bf16 paths and std::bfloat16_t remains optional even in recent C++ standards.

Full Listing

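#include <cmath>    // std::fabs
#include <cstdint>  // int64_t
#include <iostream>
#include <memory>   // std::shared_ptr
#include <random>
#include <unordered_map>
#include <vector>

#include <cuda_bf16.h>      // nv_bfloat16, __float2bfloat16
#include <cuda_runtime.h>   // from $CUDA_HOME/include
#include <cudnn_frontend.h> // from $CUDNN_FRONTEND_HOME/include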

void compare_results(const float* actual, const std::vector<float>& expected, int64_t B, int64_t M, int64_t N) {
    float rtol = 1e-2f;
    float atol = 1e-2f;
    int64_t num_close = 0;
    float max_diff = 0.0f;
    int64_t max_diff_idx = 0;

    for (int64_t i = 0; i < B * M * N; ++i) {
        float diff = std::fabs(actual[i] - expected[i]);
        if (diff > max_diff) {
            max_diff = diff;
            max_diff_idx = i;
        }
        if (diff <= rtol * std::fabs(expected[i]) + atol) {
            num_close++;
        }
    }

    int64_t b_max = max_diff_idx / (M * N);
    int64_t m_max = (max_diff_idx % (M * N)) / N;
    int64_t n_max = max_diff_idx % N;

    std::cout << "Percentage of close elements: " << 100.0f * num_close / (B * M * N) << "%\n";
    std::cout << "Max absolute difference: " << max_diff << "\n";
    std::cout << "At index [" << b_max << ", " << m_max << ", " << n_max << "]"
            << "  GPU=" << actual[max_diff_idx] << "  CPU=" << expected[max_diff_idx] << "\n";
}

int main() {
    namespace fe = cudnn_frontend;

    //
    // 1. Problem dimensions
    //    A : [b, m, k]  (row-major)
    //    B : [1, k, n]  (row-major)
    //    C : [b, m, n]  (row-major, output)
    //
    int64_t const b = 16;
    int64_t const m = 32;
    int64_t const n = 64;
    int64_t const k = 128;

    //
    // 2. Allocate buffers and fill with random BF16 data
    //
    nv_bfloat16* A_dev;
    nv_bfloat16* B_dev;
    float* C_dev;
    cudaMallocManaged(&A_dev, b * m * k * sizeof(nv_bfloat16));
    cudaMallocManaged(&B_dev, 1 * k * n * sizeof(nv_bfloat16));
    cudaMallocManaged(&C_dev, b * m * n * sizeof(float));

    // Host copies kept for CPU reference computation.
    std::vector<float> A_host(b * m * k);
    std::vector<float> B_host(1 * k * n);

    // Initialize the CPU reference tensors with random values, then make a copy to the GPU.
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<float> dis(0.0f, 1.0f);
    for (auto& e : A_host) e = dis(gen);
    for (auto& e : B_host) e = dis(gen);

    for (int64_t i = 0; i < b * m * k; ++i) {
        A_dev[i] = __float2bfloat16(A_host[i]);
    }
    for (int64_t i = 0; i < 1 * k * n; ++i) {
        B_dev[i] = __float2bfloat16(B_host[i]);
    }
    cudaDeviceSynchronize();

    //
    // 3. Describe the graph
    //
    fe::graph::Graph graph{};
    graph.set_intermediate_data_type(fe::DataType_t::FLOAT)
         .set_compute_data_type(fe::DataType_t::FLOAT); // optional

    // Input tensor A: BF16, shape [b, m, k], row-major strides
    auto A = graph.tensor(fe::graph::Tensor_attributes()
                            .set_name("A")
                            .set_dim({b, m, k})
                            .set_stride({m * k, k, 1})
                            .set_data_type(fe::DataType_t::BFLOAT16));

    // Input tensor B: BF16, shape [1, k, n], row-major strides
    auto B = graph.tensor(fe::graph::Tensor_attributes()
                            .set_name("B")
                            .set_dim({1, k, n})
                            .set_stride({k * n, n, 1})
                            .set_data_type(fe::DataType_t::BFLOAT16));

    // Matmul attributes: compute data type is FP32 (accumulate in higher precision)
    auto matmul_attr = fe::graph::Matmul_attributes().set_compute_data_type(fe::DataType_t::FLOAT);

    // Output tensor C: FP32, shape inferred by the graph
    auto C = graph.matmul(A, B, matmul_attr);
    C->set_output(true).set_data_type(fe::DataType_t::FLOAT);

    //
    // 4. Validate and build the graph
    //
    cudnnHandle_t handle;
    cudnnCreate(&handle);
    auto status = graph.build(handle, {fe::HeurMode_t::A, fe::HeurMode_t::FALLBACK});
    if (!status.is_good()) {
        std::cerr << "build() failed: " << status.get_message() << "\n";
        return 1;
    }

    //
    // 5. Execute
    //
    void* workspace_dev = nullptr;
    int64_t workspace_size = graph.get_workspace_size();
    if (workspace_size > 0) {
        cudaMalloc(&workspace_dev, workspace_size);
    }

    // variant pack: map the cudnn tensor to device buffers
    std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
        {A, A_dev},
        {B, B_dev},
        {C, C_dev},
    };

    // execute the graph with handle, workspace, and variant pack
    status = graph.execute(handle, variant_pack, workspace_dev);
    if (!status.is_good()) {
        std::cerr << "graph.execute() failed: " << status.get_message() << "\n";
        return 1;
    }
    cudaDeviceSynchronize();

    std::cout << "Matmul executed successfully.\n";
    std::cout << "  A (bf16) : [" << b << ", " << m << ", " << k << "]\n";
    std::cout << "  B (bf16) : [" << 1 << ", " << k << ", " << n << "]\n";
    std::cout << "  C (fp32) : [" << b << ", " << m << ", " << n << "]\n";

    //
    // 6. Compute the CPU reference result
    //
    std::vector<float> C_cpu(b * m * n);
    for (int64_t bi = 0; bi < b; ++bi) {
        for (int64_t mi = 0; mi < m; ++mi) {
            for (int64_t ni = 0; ni < n; ++ni) {
                float acc = 0.0f;
                for (int64_t ki = 0; ki < k; ++ki) {
                    float a_val = A_host[bi * m * k + mi * k + ki];
                    float b_val = B_host[ki * n + ni];
                    acc += a_val * b_val;
                }
                C_cpu[bi * m * n + mi * n + ni] = acc;
            }
        }
    }

    compare_results(C_dev, C_cpu, b, m, n);

    //
    // 7. Clean up
    //
    cudnnDestroy(handle);
    cudaFree(A_dev);
    cudaFree(B_dev);
    cudaFree(C_dev);
    if (workspace_dev) {
        cudaFree(workspace_dev);
    }
    return 0;
}

The core cuDNN code is in Steps 3-5 in this example.

Includes and Roles

  • <cuda_runtime.h> covers allocation, synchronization, and device types.

  • <cudnn_frontend.h> supplies the graph API and helpers used below.

Tensor Layout

  • A_dev is logically [b, m, k].

  • B_dev is [1, k, n] and broadcasts over the batch of A.

  • C_dev is the result and is [b, m, n].

The row-major packing is implied by the stride tuples in the graph and not by pointer types alone.
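The relationship between a logical index, the stride tuple, and the flat buffer position can be sketched in a few lines of Python. The helpers flat_offset and unravel are hypothetical names; unravel assumes a densely packed row-major tensor and mirrors the index decoding in compare_results.

```python
def flat_offset(index, strides):
    """Position of a logical index in the flat buffer, in elements."""
    return sum(i * s for i, s in zip(index, strides))


def unravel(flat, shape):
    """Inverse of flat_offset for a densely packed row-major tensor."""
    index = []
    for extent in reversed(shape):
        index.append(flat % extent)
        flat //= extent
    return tuple(reversed(index))


b, m, n, k = 16, 32, 64, 128
a_strides = (m * k, k, 1)   # strides declared for A in the graph
b_strides = (k * n, n, 1)   # strides declared for B in the graph

a_off = flat_offset((2, 3, 5), a_strides)
# B has batch extent 1, so every batch of A reads B at batch index 0.
b_off = flat_offset((0, 7, 9), b_strides)
print(a_off, b_off, unravel(a_off, (b, m, k)))
```

The B offset reduces to ki * n + ni, which is the expression the host reference loop in the listing uses to read B_host.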

Allocation

main fixes b, m, n, and k, allocates unified memory for A_dev, B_dev, and C_dev, and fills the host float vectors with random data before converting it to bf16 in the managed buffers.

nv_bfloat16* A_dev;
nv_bfloat16* B_dev;
float* C_dev;
cudaMallocManaged(&A_dev, b * m * k * sizeof(nv_bfloat16));
cudaMallocManaged(&B_dev, 1 * k * n * sizeof(nv_bfloat16));
cudaMallocManaged(&C_dev, b * m * n * sizeof(float));

The layout is a contract with the graph description. The same flat length can represent different logical shapes, so the strides given to the graph must stay consistent with the intended math.
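A small Python sketch shows how one flat buffer yields different logical tensors depending on the shape and strides attached to it. The helper name view and the six-element buffer are illustrative assumptions.

```python
def view(buffer, shape, strides):
    """Read a flat buffer as a 2-D tensor with the given shape and element strides."""
    rows, cols = shape
    return [[buffer[r * strides[0] + c * strides[1]] for c in range(cols)]
            for r in range(rows)]


buffer = [0, 1, 2, 3, 4, 5]

as_2x3 = view(buffer, (2, 3), (3, 1))             # packed row-major [2, 3]
as_3x2 = view(buffer, (3, 2), (2, 1))             # packed row-major [3, 2]
as_3x2_transposed = view(buffer, (3, 2), (1, 3))  # transpose of the [2, 3] view
print(as_2x3, as_3x2, as_3x2_transposed)
```

All three views read the same six values, yet the last two share a shape and still disagree element by element. A pointer and a length alone cannot tell them apart, which is why the graph needs explicit strides.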

Host Reference Buffers

  • A_host and B_host store the random float values used by the CPU reference.

  • The loops copy into A_dev and B_dev by using __float2bfloat16.

  • After the host loops fill the managed buffers, cudaDeviceSynchronize() acts as a barrier before graph work begins.

Graph Object

  1. Construct a fe::graph::Graph (default construction).

  2. Optionally set graph-wide dtype defaults.

    This sample chains set_intermediate_data_type and set_compute_data_type to FLOAT so accumulation stays in FP32.

fe::graph::Graph graph{};
graph.set_intermediate_data_type(fe::DataType_t::FLOAT)
     .set_compute_data_type(fe::DataType_t::FLOAT);

set_compute_data_type(FLOAT) forces FP32 accumulation for the matrix multiplication. The same knob can be set per node instead of on the graph.

Ports and Operation

  1. Declare tensors through graph.tensor.

  2. Attach a Matmul_attributes object.

  3. Wire graph.matmul.

  4. Mark the output tensor and dtype on C.

The graph-level defaults apply unless a node overrides them.

Build

  1. Create a cudnnHandle_t.

  2. Call graph.build with a heuristic list (see main in the listing).

  3. Check status.is_good() before execution.

  4. Use status.get_message() on failure.

For simplicity, this example uses a single-call build path instead of an explicit validate/plan pipeline.

Workspace

get_workspace_size() reports the scratch requirement. Allocate the device memory when the size is non-zero.

void* workspace_dev = nullptr;
int64_t workspace_size = graph.get_workspace_size();
if (workspace_size > 0) {
    cudaMalloc(&workspace_dev, workspace_size);
}

The scratch buffers typically use plain device allocations. Unified memory is unnecessary because the workspace is not inspected from the host.

Variant Pack

  1. Map each Tensor_attributes shared pointer to its live device pointer.

  2. Call graph.execute with the handle, variant pack, and workspace.

std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
    {A, A_dev}, {B, B_dev}, {C, C_dev},
};
status = graph.execute(handle, variant_pack, workspace_dev);

The outputs land in the buffers bound in the variant pack (here, C_dev). Check the returned status the same way as for the earlier steps.

Host Reference and Check

  1. The nested loops fill C_cpu using float32 accumulation.

  2. compare_results reports how closely C_dev matches C_cpu.

Teardown

  1. Free the device buffers and workspace.

  2. Destroy the cuDNN handle created for the build/execute path.

cudnnDestroy(handle);
cudaFree(A_dev);
cudaFree(B_dev);
cudaFree(C_dev);
if (workspace_dev) {
    cudaFree(workspace_dev);
}

Compile

Point nvcc at the cuDNN frontend headers, the cuDNN headers, and the cuDNN library directory. Example:

nvcc -std=c++17 -I path/to/cudnn_frontend/include -I path/to/cudnn/include -L path/to/cudnn/lib -lcudnn -lnvrtc -o simple_matmul simple_matmul.cpp

Before launching the binary, ensure that the cuDNN shared library directory is in LD_LIBRARY_PATH (Linux) or the equivalent search path.

# Optional: add the cuDNN library directory to the loader search path
export LD_LIBRARY_PATH=/path/to/cudnn/lib:$LD_LIBRARY_PATH
# run the executable
./simple_matmul

For the full build and run instructions, see Build and run cuDNN.