
Copyright © 2026, NVIDIA Corporation.

Link-time Optimization


Background

What is JIT LTO?

JIT LTO (Just-In-Time Link-Time Optimization) is a CUDA compilation strategy that assembles kernels at runtime. Pre-compiling every possible kernel variant would make the binary size explode combinatorially. JIT LTO instead compiles kernel fragments separately and, on demand, links only the fragments that a specific kernel configuration needs.

Fragment Terminology

A fragment is a self-contained, compilable unit of CUDA code that can be linked with other fragments to form a complete kernel. In the JIT LTO system:

  • Entrypoint Fragment: The main kernel function that serves as the entry point. This is always the __global__ kernel function.
  • Device Function Fragments: Separate fragments containing device functions (e.g., distance computations, filters, post-processing) that are called by the entrypoint kernel.
  • Fragment Key: A unique identifier for a fragment, typically constructed from template parameters and configuration values.
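  • Fatbin: The compiled binary representation of a fragment. Fatbins for static fragments are embedded in the executable at build time. UDF fragments (Step 7b) are compiled at runtime with NVRTC.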

The key advantage is that device functions can be compiled independently and reused across multiple kernel entrypoints, reducing compilation time and binary size.

How It Works

  1. Build Time: Fragments are compiled into fatbins and embedded in the executable.
  2. Runtime: When a kernel needs to be launched:
    • The planner identifies which fragments are needed based on the configuration
    • Fragments are loaded from the embedded fatbins
    • nvJitLink links the fragments together, applying link-time optimization across them
    • The linked kernel is cached and launched
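The runtime steps above can be modeled on the host in plain C++. The sketch below is a toy and not the cuVS implementation. `ToyFragmentRegistry` and `request_link` are hypothetical, the fragment key strings are made up, and "linking" is simulated by concatenating fragment bytes in request order. The point it illustrates is that every fragment the planner requests must already be embedded before the link step can run.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for the fatbins embedded at build time:
// fragment key -> fragment bytes (a string here instead of real LTO IR).
using ToyFragmentRegistry = std::map<std::string, std::string>;

// Simulates the runtime half: look up each requested fragment, fail loudly
// if one was never embedded, then "link" them into a single image.
std::string request_link(ToyFragmentRegistry const& registry,
                         std::vector<std::string> const& fragment_keys)
{
  std::string linked_image;
  for (auto const& key : fragment_keys) {
    auto it = registry.find(key);
    if (it == registry.end()) {
      throw std::runtime_error("fragment not embedded: " + key);
    }
    linked_image += it->second;  // real code would hand the fatbins to nvJitLink
  }
  return linked_image;
}
```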

Walkthrough Example

Let’s walk through creating a JIT LTO kernel system for a search kernel with templated device functions.

Step 1: Define the Kernel and Device Functions

We start with a kernel that has templated device functions that we want to separate into fragments:

search_kernel.cuh:

#pragma once

#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <cuda/std/limits>

namespace example::detail {

// Device function for distance computation
template <typename T>
__device__ float compute_distance_euclidean(T a, T b) {
  T diff = a - b;
  return diff * diff;
}

template <typename T>
__device__ float compute_distance_inner_product(T a, T b) {
  return -a * b;  // Negative for max inner product search
}

// Device function for filtering
template <typename IdxT>
__device__ bool apply_filter_none(uint32_t query_id, IdxT node_id, void* filter_data) {
  return true;
}

template <typename IdxT>
__device__ bool apply_filter_bitset(uint32_t query_id, IdxT node_id, void* filter_data) {
  // Simplified - actual implementation would check bitset
  return true;
}

// Main kernel - will use generic extern device functions.
// compute_distance<T> and apply_filter<IdxT> are not defined in this file; they must be
// declared before this point, which is what the extern declarations in Step 4 add.
template <typename T, typename OutT, typename IdxT, bool UseOptimizedPath, int Veclen>
__device__ void search_kernel_impl(
  const T* dataset,
  const T* queries,
  IdxT* results,
  OutT* distances,  // Output distance type
  uint32_t num_queries,
  uint32_t dataset_size,
  void* filter_data) {

  uint32_t query_id = blockIdx.x * blockDim.x + threadIdx.x;
  if (query_id >= num_queries) return;

  // cuda::std::numeric_limits is usable in device code without --expt-relaxed-constexpr
  OutT best_dist = cuda::std::numeric_limits<OutT>::max();
  IdxT best_idx = 0;

  for (IdxT i = 0; i < dataset_size; ++i) {
    // Call generic extern device functions (implementations linked from fragments)
    if (!apply_filter<IdxT>(query_id, i, filter_data)) continue;

    OutT dist = static_cast<OutT>(compute_distance<T>(queries[query_id], dataset[i]));

    // Use optimized path if enabled
    if constexpr (UseOptimizedPath) {
      // Optimized implementation
      if (dist < best_dist) {
        best_dist = dist;
        best_idx = i;
      }
    } else {
      // Standard implementation
      if (dist < best_dist) {
        best_dist = dist;
        best_idx = i;
      }
    }
  }

  results[query_id] = best_idx;
  distances[query_id] = best_dist;
}

}  // namespace example::detail

Step 2: Create Device Function Fragments

We’ll create separate header files for each device function variant. Each implements the generic function signature that the kernel expects:

compute_distance_euclidean.cuh:

#pragma once

namespace example::detail {

// Implements the generic compute_distance function for euclidean distance
template <typename T>
__device__ float compute_distance(T a, T b) {
  T diff = a - b;
  return diff * diff;
}

}  // namespace example::detail

compute_distance_inner_product.cuh:

#pragma once

namespace example::detail {

// Implements the generic compute_distance function for inner product
template <typename T>
__device__ float compute_distance(T a, T b) {
  return -a * b;  // Negative for max inner product search
}

}  // namespace example::detail

filter_none.cuh:

#pragma once

#include <cstdint>

namespace example::detail {

// Implements the generic apply_filter function for no filtering
template <typename IdxT>
__device__ bool apply_filter(uint32_t query_id, IdxT node_id, void* filter_data) {
  return true;
}

}  // namespace example::detail

filter_bitset.cuh:

#pragma once

#include <cstdint>

namespace example::detail {

// Implements the generic apply_filter function for bitset filtering
template <typename IdxT>
__device__ bool apply_filter(uint32_t query_id, IdxT node_id, void* filter_data) {
  // Placeholder: a real implementation would test the bit for node_id in filter_data
  return true;
}

}  // namespace example::detail

Step 3: Create JSON Matrix Files

JSON matrix files define all the parameter combinations that need to be compiled. The build system uses these to generate .cu files from .cu.in templates.

How JSON Cross-Product Works:

  • The build system computes a modified Cartesian product (cross-product) of all parameter combinations.
  • Leaf nodes are the actual values. The format accepts strings, numbers, booleans, or null, but use strings only, even for numbers (for example "1").
  • Related values can be grouped in a dictionary whose entries are single values. A key whose value holds such dictionaries (that is, any key in their ancestry) does not appear in the final product, so prefix it with _ to mark it as grouping-only.
  • Keys whose values are only leaf nodes do appear in the final product and must not be prefixed with _.
  • The matrix product algorithm warns automatically if this naming convention (_ prefix or not) is not followed.
  • Each group expands to create multiple combinations, and all groups are cross-multiplied.

For example, if you have:

{
  "_data_type": [{"data_type": "float"}, {"data_type": "__half"}],
  "_index": [{"idx_type": "uint32_t"}, {"idx_type": "int64_t"}],
  "capacity": ["1", "2"]
}

This generates 2 × 2 × 2 = 8 combinations:

  • {data_type: "float", idx_type: "uint32_t", capacity: "1"}
  • {data_type: "float", idx_type: "uint32_t", capacity: "2"}
  • {data_type: "float", idx_type: "int64_t", capacity: "1"}
  • … and so on

When a group contains nested arrays (like veclen: ["1", "4"]), those are also expanded within that group before the cross-product is computed.
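The expansion rules above can be modeled in a few lines of C++, assuming the JSON has already been parsed. The helpers `Combo`, `Axis`, `cross`, `leaf_list`, `expand_alternative`, and `group` are hypothetical and are not the build system's code. Grouping keys such as `_data_type` never appear in the output. A group alternative with nested groups or lists is expanded first, and then all top-level axes are cross-multiplied.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

using Combo = std::map<std::string, std::string>;  // one output combination
using Axis = std::vector<Combo>;                   // alternatives for one key or group

// Cartesian product: pick one alternative from each axis and merge them.
std::vector<Combo> cross(std::vector<Axis> const& axes)
{
  std::vector<Combo> result{Combo{}};
  for (auto const& axis : axes) {
    std::vector<Combo> next;
    for (auto const& partial : result) {
      for (auto const& alt : axis) {
        Combo merged = partial;
        merged.insert(alt.begin(), alt.end());
        next.push_back(std::move(merged));
      }
    }
    result = std::move(next);
  }
  return result;
}

// A non-underscore key with a list of leaf values, e.g. "capacity": ["1", "2"].
Axis leaf_list(std::string const& key, std::vector<std::string> const& values)
{
  Axis axis;
  for (auto const& v : values) axis.push_back(Combo{{key, v}});
  return axis;
}

// One dictionary inside a "_group": its fixed values times its own nested axes.
Axis expand_alternative(Combo const& fixed, std::vector<Axis> const& nested)
{
  std::vector<Axis> axes{Axis{fixed}};
  axes.insert(axes.end(), nested.begin(), nested.end());
  return cross(axes);
}

// A "_group" is the union of its expanded alternatives.
Axis group(std::vector<Axis> const& alternatives)
{
  Axis all;
  for (auto const& a : alternatives) all.insert(all.end(), a.begin(), a.end());
  return all;
}
```

Feeding it the small example above yields 8 combinations, and feeding it the structure of search_kernel_matrix.json yields 24, with __half never paired with double.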

compute_distance_matrix.json

{
  "_distance_type": [
    {
      "distance_name": "euclidean",
      "header_file": "example/jit_lto_kernels/compute_distance_euclidean.cuh"
    },
    {
      "distance_name": "inner_product",
      "header_file": "example/jit_lto_kernels/compute_distance_inner_product.cuh"
    }
  ],
  "_data_type": [
    {
      "data_type": "float",
      "type_abbrev": "f"
    },
    {
      "data_type": "__half",
      "type_abbrev": "h"
    }
  ]
}

filter_matrix.json

{
  "filter_name": [
    "filter_none",
    "filter_bitset"
  ],
  "_index": [
    {
      "idx_type": "uint32_t",
      "idx_abbrev": "ui"
    },
    {
      "idx_type": "int64_t",
      "idx_abbrev": "l"
    }
  ]
}

search_kernel_matrix.json

This example demonstrates conditional combinations: OutT can be float or double when T is float, but only float when T is __half.

{
  "_data_type": [
    {
      "data_type": "float",
      "type_abbrev": "f",
      "_output_type": [
        {
          "out_type": "float",
          "out_abbrev": "f"
        },
        {
          "out_type": "double",
          "out_abbrev": "d"
        }
      ]
    },
    {
      "data_type": "__half",
      "type_abbrev": "h",
      "_output_type": [
        {
          "out_type": "float",
          "out_abbrev": "f"
        }
      ]
    }
  ],
  "_index": [
    {
      "idx_type": "uint32_t",
      "idx_abbrev": "ui"
    },
    {
      "idx_type": "int64_t",
      "idx_abbrev": "l"
    }
  ],
  "_optimized": [
    {
      "optimized_name": "optimized",
      "optimized_value": "true",
      "veclen": ["1", "4"]
    },
    {
      "optimized_name": "standard",
      "optimized_value": "false",
      "veclen": ["8", "16"]
    }
  ]
}

This generates 24 combinations (3 data/output type combinations × 2 index types × 4 optimized/veclen combinations):

  • float + float + uint32_t + optimized + veclen=1
  • float + float + uint32_t + optimized + veclen=4
  • float + float + uint32_t + standard + veclen=8
  • float + float + uint32_t + standard + veclen=16
  • float + double + uint32_t + optimized + veclen=1
  • float + double + uint32_t + optimized + veclen=4
  • float + double + uint32_t + standard + veclen=8
  • float + double + uint32_t + standard + veclen=16
  • __half + float + uint32_t + optimized + veclen=1
  • __half + float + uint32_t + optimized + veclen=4
  • __half + float + uint32_t + standard + veclen=8
  • __half + float + uint32_t + standard + veclen=16
  • … and the same with int64_t (total: 24 combinations)
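The size argument from the introduction can be made concrete for these three matrices. The code counts how many fragments get compiled and how many distinct fully linked kernels those fragments can form. It is back-of-the-envelope arithmetic that assumes a compute_distance fragment is only linked with entries of the same data type, and a filter fragment only with entries of the same index type. The struct and function names are illustrative.

```cpp
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <string>
#include <vector>

struct EntryFragment { std::string data_type; std::string idx_type; };
struct DeviceFragment { std::string name; std::string type; };  // type = data or index type

struct FragmentCounts {
  std::size_t fragments;         // fatbins compiled at build time
  std::size_t linkable_kernels;  // distinct entry + distance + filter combinations
};

FragmentCounts count_walkthrough_fragments()
{
  // search_kernel_matrix.json: (T, OutT) pairs (float, float), (float, double), (__half, float);
  // only T matters for matching, so OutT is dropped here.
  const std::vector<std::string> entry_data_types{"float", "float", "__half"};
  const std::vector<std::string> idx_types{"uint32_t", "int64_t"};
  const int optimized_veclen_variants = 4;

  std::vector<EntryFragment> entries;
  for (auto const& d : entry_data_types)
    for (auto const& i : idx_types)
      for (int v = 0; v < optimized_veclen_variants; ++v) entries.push_back({d, i});

  std::vector<DeviceFragment> distances;  // compute_distance_matrix.json
  for (auto const& m : {"euclidean", "inner_product"})
    for (auto const& d : {"float", "__half"}) distances.push_back({m, d});

  std::vector<DeviceFragment> filters;  // filter_matrix.json
  for (auto const& f : {"filter_none", "filter_bitset"})
    for (auto const& i : idx_types) filters.push_back({f, i});

  // Count type-compatible (entry, distance, filter) triples.
  std::size_t kernels = 0;
  for (auto const& e : entries)
    for (auto const& dist : distances)
      for (auto const& filt : filters)
        if (dist.type == e.data_type && filt.type == e.idx_type) ++kernels;

  return {entries.size() + distances.size() + filters.size(), kernels};
}
```

The result is 32 compiled fragments (24 + 4 + 4) covering 96 linkable kernels (24 entries × 2 metrics × 2 filters). In this model, each entry fragment is compiled once rather than once per metric/filter pair.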

Step 4: Create .cu.in Template Files

The .cu.in files are templates that get instantiated for each combination in the JSON matrix. They contain explicit template instantiations.

compute_distance_kernel.cu.in

/*
 * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
 * SPDX-License-Identifier: Apache-2.0
 */
#include <cuda_fp16.h>  // defines __half, needed by the "__half" data_type combinations

#include "@header_file@"

namespace example::detail {
// Instantiate the generic compute_distance device function template.
// The specific implementation (euclidean or inner_product) comes from the header.
template __device__ float compute_distance<@data_type@>(@data_type@, @data_type@);
}  // namespace example::detail

filter_kernel.cu.in

/*
 * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
 * SPDX-License-Identifier: Apache-2.0
 */
#include <cstdint>  // uint32_t

#include "example/jit_lto_kernels/@filter_name@.cuh"

namespace example::detail {
// Instantiate the generic apply_filter device function template.
// The specific implementation (filter_none or filter_bitset) comes from the header.
template __device__ bool apply_filter<@idx_type@>(uint32_t, @idx_type@, void*);
}  // namespace example::detail
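The build system fills in the @name@ placeholders of the .cu.in templates above once per matrix combination. The guide does not show the exact mechanism. The following is a hypothetical, host-only model of the substitution, in the spirit of CMake's configure_file with @ONLY. `substitute_placeholders` is an illustrative name, and rejecting unknown keys is a design choice of this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>

// Replace every @key@ in a .cu.in template with its value from one matrix combination.
std::string substitute_placeholders(std::string const& tmpl,
                                    std::map<std::string, std::string> const& values)
{
  std::string out;
  std::size_t pos = 0;
  while (pos < tmpl.size()) {
    std::size_t open = tmpl.find('@', pos);
    std::size_t close = open == std::string::npos ? open : tmpl.find('@', open + 1);
    if (close == std::string::npos) {  // no (complete) placeholder left
      out += tmpl.substr(pos);
      break;
    }
    out += tmpl.substr(pos, open - pos);
    std::string key = tmpl.substr(open + 1, close - open - 1);
    auto it = values.find(key);
    if (it == values.end()) throw std::out_of_range("no value for @" + key + "@");
    out += it->second;
    pos = close + 1;
  }
  return out;
}
```

Failing on an unknown key surfaces a typo in a template (for example @idx_typ@) at generation time rather than as a confusing compile error in the generated .cu file.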

Update search_kernel.cuh with Extern Declarations

The kernel header needs to declare generic extern device functions so the kernel code can call them. The specific implementations will be linked from fragments at runtime:

search_kernel.cuh:

#pragma once

#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <cuda/std/limits>

namespace example::detail {

// Forward declare generic extern device functions that will be linked from fragments.
// The specific implementations (euclidean, inner_product, etc.) are resolved at link time.
template <typename T>
extern __device__ float compute_distance(T, T);

template <typename IdxT>
extern __device__ bool apply_filter(uint32_t, IdxT, void*);

// Main kernel - uses generic extern device functions
template <typename T, typename OutT, typename IdxT, bool UseOptimizedPath, int Veclen>
__device__ void search_kernel_impl(
  const T* dataset,
  const T* queries,
  IdxT* results,
  OutT* distances,  // Output distance type
  uint32_t num_queries,
  uint32_t dataset_size,
  void* filter_data) {

  uint32_t query_id = blockIdx.x * blockDim.x + threadIdx.x;
  if (query_id >= num_queries) return;

  // cuda::std::numeric_limits is usable in device code without --expt-relaxed-constexpr
  OutT best_dist = cuda::std::numeric_limits<OutT>::max();
  IdxT best_idx = 0;

  for (IdxT i = 0; i < dataset_size; ++i) {
    // Call generic extern device functions (specific implementations linked from fragments)
    if (!apply_filter<IdxT>(query_id, i, filter_data)) continue;

    OutT dist = static_cast<OutT>(compute_distance<T>(queries[query_id], dataset[i]));

    // Use optimized path if enabled
    if constexpr (UseOptimizedPath) {
      // Optimized implementation
      if (dist < best_dist) {
        best_dist = dist;
        best_idx = i;
      }
    } else {
      // Standard implementation
      if (dist < best_dist) {
        best_dist = dist;
        best_idx = i;
      }
    }
  }

  results[query_id] = best_idx;
  distances[query_id] = best_dist;
}

}  // namespace example::detail

search_kernel.cu.in

The .cu.in file only defines an extern "C" __global__ entry point, which instantiates search_kernel_impl with the values of one matrix combination:

/*
 * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
 * SPDX-License-Identifier: Apache-2.0
 */
#include "example/jit_lto_kernels/search_kernel.cuh"

namespace example::detail {
// extern "C" entry point for this combination; its body instantiates the kernel template
extern "C" __global__ void search_kernel(
  const @data_type@* dataset, const @data_type@* queries, @idx_type@* results, @out_type@* distances,
  uint32_t num_queries, uint32_t dataset_size, void* filter_data)
{
  search_kernel_impl<@data_type@, @out_type@, @idx_type@, @optimized_value@, @veclen@>(
    dataset,
    queries,
    results,
    distances,
    num_queries,
    dataset_size,
    filter_data);
}
}  // namespace example::detail

Note: The kernel uses generic function templates (compute_distance<T> and apply_filter<IdxT>) that are resolved at link time. The specific implementations (euclidean vs inner_product, filter_none vs filter_bitset) are provided by the fragments that get linked together.

Step 5: Create Fragment Tags for Embedding

Fragment tags identify the compiled fatbins so that they can be embedded at build time and loaded at runtime; the linker uses them to find and include the relevant fatbins. When calling generate_jit_lto_kernels(), pass a FRAGMENT_TAG_FORMAT argument, which constructs the tag type from the given placeholders, and a FRAGMENT_TAG_HEADER_FILES argument, which lists one or more header files that declare the fragment tags. The JIT+LTO system then automatically generates and compiles a .cpp file that registers each fragment under the provided tag.

Important: When requesting fragments from the AlgorithmPlanner, we use tags (like tag_f, tag_h) instead of real types (like float, __half) in the add_static_fragment template parameters. This avoids including heavy headers that define the actual types, significantly improving compilation times. The tags are lightweight empty structs that serve only as compile-time identifiers.

registration_tags.hpp

#pragma once

struct tag_h{};
struct tag_f{};
struct tag_d{};
struct tag_ui{};
struct tag_l{};

struct tag_metric_euclidean {};
struct tag_metric_inner_product {};

struct tag_filter_none {};
struct tag_filter_bitset {};

template <typename DataTag, typename OutTag, typename IdxTag, bool Optimized, int Veclen>
struct fragment_tag_search {};

template <typename DistanceTag, typename DataTag>
struct fragment_tag_compute_distance {};

template <typename FilterTag, typename IndexTag>
struct fragment_tag_filter {};

Step 6: Create the Planner

The planner is responsible for:

  1. Identifying which fragments are needed for a given configuration
  2. Building a unique key for the fragment combination
  3. Requesting the fragments from the fragment database
  4. Linking them together to create a launchable kernel

CRITICAL: The fragment keys constructed in the planner methods must match EXACTLY with the keys used in the corresponding FRAGMENT_TAG_FORMAT argument. Any mismatch will result in runtime linking failures.
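To make "match exactly" concrete, here is a host-only model in which both sides of the match are reduced to strings. This is an illustration only. In cuVS the planner passes tag types directly and the registration side is generated from FRAGMENT_TAG_FORMAT. The `key_part` and `fragment_key` traits and the "@filter_name@_@idx_abbrev@" format are hypothetical.

```cpp
#include <cassert>
#include <string>

struct tag_ui {};
struct tag_l {};
struct tag_filter_none {};
struct tag_filter_bitset {};

template <typename FilterTag, typename IndexTag>
struct fragment_tag_filter {};

// Hypothetical trait: the text each tag contributes to a key.
template <typename Tag> struct key_part;
template <> struct key_part<tag_ui> { static std::string get() { return "ui"; } };
template <> struct key_part<tag_l> { static std::string get() { return "l"; } };
template <> struct key_part<tag_filter_none> { static std::string get() { return "filter_none"; } };
template <> struct key_part<tag_filter_bitset> { static std::string get() { return "filter_bitset"; } };

// Planner side: key derived from the fragment tag's template arguments.
template <typename FragmentTag> struct fragment_key;
template <typename F, typename I>
struct fragment_key<fragment_tag_filter<F, I>> {
  static std::string get() { return key_part<F>::get() + "_" + key_part<I>::get(); }
};

// Build side: key produced from one filter_matrix.json combination using the
// hypothetical format "@filter_name@_@idx_abbrev@".
std::string build_side_key(std::string const& filter_name, std::string const& idx_abbrev)
{
  return filter_name + "_" + idx_abbrev;
}
```

If the build side accidentally used @idx_type@ ("uint32_t") where the planner side expects the abbreviation ("ui"), the two keys would differ. That is the kind of mismatch the warning above describes.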

search_planner.hpp:

#pragma once

#include <cuvs/detail/jit_lto/AlgorithmPlanner.hpp>
#include <cuvs/detail/jit_lto/FragmentEntry.hpp>
#include <cuvs/detail/jit_lto/MakeFragmentKey.hpp>
#include <cuvs/detail/jit_lto/registration_tags.hpp>
#include <memory>
#include <string>
#include <utility>

struct SearchPlanner : AlgorithmPlanner {
  // One cache per planner type, shared by every SearchPlanner instance
  inline static LauncherJitCache launcher_jit_cache{};

  SearchPlanner()
    : AlgorithmPlanner("search_kernel", launcher_jit_cache)
  {
  }

  template <typename DataTag, typename OutTag, typename IdxTag, bool Optimized, int Veclen>
  void add_search_function()
  {
    add_static_fragment<fragment_tag_search<DataTag, OutTag, IdxTag, Optimized, Veclen>>();
  }

  template <typename DistanceTag, typename DataTag>
  void add_compute_distance_function()
  {
    add_static_fragment<fragment_tag_compute_distance<DistanceTag, DataTag>>();
  }

  template <typename FilterTag, typename IndexTag>
  void add_filter_function()
  {
    add_static_fragment<fragment_tag_filter<FilterTag, IndexTag>>();
  }

  // Same as add_fragment(std::move(fragment)); distinct names are for readability at call sites.
  void add_metric_udf_fragment(std::unique_ptr<UDFFatbinFragment> fragment)
  {
    add_fragment(std::move(fragment));
  }

  void add_filter_udf_fragment(std::unique_ptr<UDFFatbinFragment> fragment)
  {
    add_fragment(std::move(fragment));
  }
};
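The launcher_jit_cache member is an inline static so that every SearchPlanner instance shares it. search_jit creates a new planner per call, and the second call with the same fragments can then reuse the already linked kernel. The host-only sketch below models that effect. `ToyJitCache`, `ToyLauncher`, and `ToyPlanner` are hypothetical, and keying the cache on the ordered list of fragment keys is an assumption of the sketch, not how cuVS builds its cache key.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct ToyLauncher { std::string linked_from; };

struct ToyJitCache {
  std::map<std::string, std::shared_ptr<ToyLauncher>> launchers;
  int link_count = 0;  // how many times "linking" actually ran
};

class ToyPlanner {
 public:
  ToyPlanner(std::string entrypoint, ToyJitCache& cache)
    : entrypoint_(std::move(entrypoint)), cache_(cache) {}

  void add_fragment_key(std::string key) { keys_.push_back(std::move(key)); }

  std::shared_ptr<ToyLauncher> get_launcher()
  {
    std::string combined = entrypoint_;
    for (auto const& k : keys_) combined += "|" + k;
    auto it = cache_.launchers.find(combined);
    if (it != cache_.launchers.end()) return it->second;  // hit: no relink
    ++cache_.link_count;                                   // miss: link once
    auto launcher = std::make_shared<ToyLauncher>(ToyLauncher{combined});
    cache_.launchers.emplace(combined, launcher);
    return launcher;
  }

 private:
  std::string entrypoint_;
  ToyJitCache& cache_;
  std::vector<std::string> keys_;
};

// Mirrors SearchPlanner: one static cache shared by all instances.
struct ToySearchPlanner : ToyPlanner {
  inline static ToyJitCache cache{};
  ToySearchPlanner() : ToyPlanner("search_kernel", cache) {}
};
```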

Step 7: Integrate with Code Path

Now we integrate the planner into the actual search function:

search_jit.cuh:

#pragma once

#include "search_planner.hpp"
#include <cuvs/detail/jit_lto/registration_tags.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/resource/cuda_stream.hpp>

#include <cuda_fp16.h>

#include <cstdint>
#include <type_traits>

namespace example::detail {

enum class DistanceType { Euclidean, InnerProduct };
enum class FilterType { None, Bitset };

// Type tag helpers. Each one is an else-if constexpr chain, so exactly one return
// statement is instantiated and an unsupported type hits a readable static_assert.
template <typename T>
constexpr auto get_data_type_tag() {
  if constexpr (std::is_same_v<T, float>) {
    return tag_f{};
  } else if constexpr (std::is_same_v<T, __half>) {
    return tag_h{};
  } else {
    static_assert(!sizeof(T*), "unsupported data type");
  }
}

template <typename IdxT>
constexpr auto get_idx_type_tag() {
  if constexpr (std::is_same_v<IdxT, uint32_t>) {
    return tag_ui{};
  } else if constexpr (std::is_same_v<IdxT, int64_t>) {
    return tag_l{};
  } else {
    static_assert(!sizeof(IdxT*), "unsupported index type");
  }
}

template <typename OutType>
constexpr auto get_out_type_tag() {
  if constexpr (std::is_same_v<OutType, float>) {
    return tag_f{};
  } else if constexpr (std::is_same_v<OutType, double>) {
    return tag_d{};
  } else {
    static_assert(!sizeof(OutType*), "unsupported output type");
  }
}

template <DistanceType Metric>
constexpr auto get_metric_tag() {
  if constexpr (Metric == DistanceType::Euclidean) {
    return tag_metric_euclidean{};
  } else if constexpr (Metric == DistanceType::InnerProduct) {
    return tag_metric_inner_product{};
  } else {
    // Always false, but dependent on Metric, so it only fires for an unhandled enumerator
    static_assert(sizeof(std::integral_constant<DistanceType, Metric>) == 0,
                  "extend get_metric_tag when adding DistanceType enumerators");
  }
}

template <FilterType Filter>
constexpr auto get_filter_tag() {
  if constexpr (Filter == FilterType::None) {
    return tag_filter_none{};
  } else if constexpr (Filter == FilterType::Bitset) {
    return tag_filter_bitset{};
  } else {
    static_assert(sizeof(std::integral_constant<FilterType, Filter>) == 0,
                  "extend get_filter_tag when adding FilterType enumerators");
  }
}

template <typename T, typename OutT, typename IdxT, DistanceType Metric, FilterType Filter, bool Optimized, int Veclen>
void search_jit(
  raft::device_resources const& handle,
  const T* dataset,
  const T* queries,
  IdxT* results,
  OutT* distances,
  uint32_t num_queries,
  uint32_t dataset_size,
  void* filter_data = nullptr) {

  using data_tag = decltype(get_data_type_tag<T>());
  using idx_tag = decltype(get_idx_type_tag<IdxT>());
  using out_tag = decltype(get_out_type_tag<OutT>());
  using metric_tag = decltype(get_metric_tag<Metric>());
  using filter_tag = decltype(get_filter_tag<Filter>());

  // Request the entry fragment plus one compute_distance and one apply_filter fragment
  // (Step 6 helpers). Optimized and Veclen travel as non-type template arguments of
  // fragment_tag_search.
  SearchPlanner planner;

  planner.add_search_function<data_tag, out_tag, idx_tag, Optimized, Veclen>();
  planner.add_compute_distance_function<metric_tag, data_tag>();
  planner.add_filter_function<filter_tag, idx_tag>();

  // Get the launcher (this will build/link fragments if needed)
  auto launcher = planner.get_launcher();

  // Launch configuration: one thread per query
  dim3 block(256);
  dim3 grid((num_queries + block.x - 1) / block.x);

  // Launch the kernel - arguments are passed directly
  launcher->dispatch(
    raft::resource::get_cuda_stream(handle),
    grid,
    block,
    0,  // shared memory size
    dataset,
    queries,
    results,
    distances,
    num_queries,
    dataset_size,
    filter_data);
}

}  // namespace example::detail
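The tag helpers and the grid computation in search_jit.cuh are plain host-side C++, so the pattern can be checked without a GPU. This host-only sketch redeclares the tag structs locally so that it compiles on its own and uses a hypothetical `always_false_v` helper. It shows the else-if constexpr chain, in which only one return statement is instantiated per type, and the ceiling division used for the launch grid (`grid_size_for` is an illustrative name).

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

struct tag_f {};
struct tag_d {};
struct tag_ui {};
struct tag_l {};

template <typename>
inline constexpr bool always_false_v = false;

template <typename IdxT>
constexpr auto get_idx_type_tag()
{
  if constexpr (std::is_same_v<IdxT, std::uint32_t>) {
    return tag_ui{};
  } else if constexpr (std::is_same_v<IdxT, std::int64_t>) {
    return tag_l{};
  } else {
    static_assert(always_false_v<IdxT>, "unsupported index type");
  }
}

template <typename OutT>
constexpr auto get_out_type_tag()
{
  if constexpr (std::is_same_v<OutT, float>) {
    return tag_f{};
  } else if constexpr (std::is_same_v<OutT, double>) {
    return tag_d{};
  } else {
    static_assert(always_false_v<OutT>, "unsupported output type");
  }
}

// Blocks needed for one thread per query (ceiling division).
// Assumes num_queries + block_size - 1 does not overflow 32 bits.
constexpr std::uint32_t grid_size_for(std::uint32_t num_queries, std::uint32_t block_size)
{
  return (num_queries + block_size - 1) / block_size;
}
```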

Step 7b: Example — NVRTC UDFs for compute_distance and apply_filter

What you’re building. The same search kernel as Steps 1–7 still calls compute_distance / apply_filter, but for a UDF build those symbols are not taken from prebuilt matrix fatbins: you compile a small NVRTC program per hook at runtime and register it with the planner so LTO links it next to the entry fragment.

How the pieces connect, in order:

  • Entry fatbin: the kernel calls the compute_distance / apply_filter templates, and the shared header only declares them.
  • Per-hook NVRTC TU: a macro supplies the device body plus a string factory, and host glue appends the forwarding function and the explicit instantiation.
  • Planner: the compiled fatbin is registered with add_*_udf_fragment().

1. Shared header — forward declarations

The entry TU matches Steps 1–7: the templates are only declared here, and their definitions come from a separately compiled fragment that is resolved at link time.

#include <cstdint>

namespace example::detail {

template <typename T>
__device__ float compute_distance(T q, T d);

template <typename IdxT>
__device__ bool apply_filter(uint32_t query_id, IdxT node_id, void* filter_data);

}  // namespace example::detail

2. NVRTC source — macros and string factories

Use function-like macros so that you edit only the { ... } body. The preprocessor still emits a real __device__ template for NVCC, while NAME_udf() / NAME_filter_udf() build the CUDA source text that NVRTC compiles. The distance string also defines compute_distance_udf_impl, which calls NAME_distance; the host-side instantiate_compute_distance_udf then appends only the forwarding compute_distance and its explicit instantiation. The filter string already contains the forwarding apply_filter, so instantiate_apply_filter_udf appends only the explicit instantiation of apply_filter<IdxT>.

Macro definitions (shared header; include before the invocations):

#include <cstdint>
#include <sstream>
#include <string>

#define EXAMPLE_PP_CAT_(a, b) a##b
#define EXAMPLE_PP_CAT(a, b) EXAMPLE_PP_CAT_(a, b)
#define EXAMPLE_PP_STR_(x) #x
#define EXAMPLE_PP_STR(x) EXAMPLE_PP_STR_(x)

// NAME_udf(): NVRTC program defines NAME_distance and compute_distance_udf_impl; host appends
// instantiate_compute_distance_udf (forwarding compute_distance + explicit inst only).
#define EXAMPLE_UDF_DISTANCE(NAME, BODY)                                                  \
  template <typename T>                                                                   \
  __device__ float EXAMPLE_PP_CAT(NAME, _distance)(T q, T d) BODY                         \
                                                                                          \
  inline std::string EXAMPLE_PP_CAT(NAME, _udf)()                                         \
  {                                                                                       \
    return std::string("#include <cuda_runtime.h>\n"                                      \
                       "namespace example::detail {\n"                                    \
                       "template <typename T>\n"                                          \
                       "__device__ float " EXAMPLE_PP_STR(EXAMPLE_PP_CAT(NAME, _distance)) \
                       "(T q, T d) ") +                                                   \
           std::string(#BODY) +                                                           \
           std::string("\n"                                                               \
                       "template <typename T>\n"                                          \
                       "__device__ float compute_distance_udf_impl(T q, T d) {\n"         \
                       " return ") +                                                      \
           std::string(EXAMPLE_PP_STR(EXAMPLE_PP_CAT(NAME, _distance))) +                 \
           std::string("(q, d);\n"                                                        \
                       "}\n}\n");                                                         \
  }

// Forwarding compute_distance + explicit inst only (user metric stays inside NAME_udf()).
inline std::string instantiate_compute_distance_udf(char const* t_type)
{
  std::ostringstream oss;
  oss << "\nnamespace example::detail {\n"
      << "template <typename T>\n"
      << "__device__ float compute_distance(T q, T d) {\n"
      << " return compute_distance_udf_impl(q, d);\n"
      << "}\n"
      << "template __device__ float compute_distance<" << t_type << ">(" << t_type << ", " << t_type
      << ");\n"
      << "}\n";
  return oss.str();
}

// Device NAME_filter + filter_udf(); append instantiate_apply_filter_udf for apply_filter<IdxT>.
#define EXAMPLE_UDF_FILTER(NAME, BODY)                                                             \
  template <typename IdxT>                                                                         \
  __device__ bool EXAMPLE_PP_CAT(NAME, _filter)(uint32_t query_id, IdxT node_id, void* filter_data) \
    BODY                                                                                           \
                                                                                                   \
  inline std::string EXAMPLE_PP_CAT(NAME, _filter_udf)()                                           \
  {                                                                                                \
    return std::string("#include <cuda_runtime.h>\n"                                               \
                       "namespace example::detail {\n"                                             \
                       "template <typename IdxT>\n"                                                \
                       "__device__ bool " EXAMPLE_PP_STR(EXAMPLE_PP_CAT(NAME, _filter))            \
                       "(uint32_t query_id, IdxT node_id, void* filter_data) ") +                  \
           std::string(#BODY) +                                                                    \
           std::string("\n"                                                                        \
                       "template <typename IdxT>\n"                                                \
                       "__device__ bool apply_filter(uint32_t query_id, IdxT node_id, "            \
                       "void* filter_data) {\n"                                                    \
                       " return " EXAMPLE_PP_STR(EXAMPLE_PP_CAT(NAME, _filter))                    \
                       "(query_id, node_id, filter_data);\n"                                       \
                       "}\n"                                                                       \
                       "}\n");                                                                     \
  }

// Call after NAME_filter_udf() string is concatenated, before compile.
inline std::string instantiate_apply_filter_udf(char const* idx_type)
{
  std::ostringstream oss;
  oss << "\nnamespace example::detail {\n"
      << "template __device__ bool apply_filter<" << idx_type << ">(uint32_t, " << idx_type
      << ", void*);\n"
      << "}\n";
  return oss.str();
}

Invocations (file scope — each call expands the macro once):

EXAMPLE_UDF_DISTANCE(my_l2, {
  T diff = q - d;
  return diff * diff;
})

EXAMPLE_UDF_FILTER(my_pass, {
  (void)query_id;
  (void)node_id;
  (void)filter_data;
  return true;
})

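Avoid raw " inside #BODY unless you splice that part with raw-string concatenation around #BODY. Also keep top-level commas out of BODY: the preprocessor splits macro arguments at every comma that is not inside parentheses, so a body such as { float a = 1, b = 2; ... } is passed as extra arguments and the two-parameter macro fails to expand. Parenthesize such expressions or split the declarations.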
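The EXAMPLE_UDF_* macros rely on one body being used twice: once as real code and once, through the # operator, as source text. Below is a host-only sketch of the same dual-emission technique. The DEMO_* names are hypothetical, and __device__ is omitted so it builds with any C++ compiler. It swaps the named BODY parameter for __VA_ARGS__, a variant that tolerates top-level commas in the body, which the two-parameter EXAMPLE_UDF_* macros do not.

```cpp
#include <cassert>
#include <string>

#define DEMO_STR_(x) #x
#define DEMO_STR(x) DEMO_STR_(x)

// One body, emitted twice: as a real function and, via #__VA_ARGS__, as source text.
#define DEMO_UDF(NAME, ...)                                              \
  inline float NAME(float q, float d) __VA_ARGS__                        \
  inline std::string NAME##_source()                                     \
  {                                                                      \
    return std::string("float " DEMO_STR(NAME) "(float q, float d) ") +  \
           #__VA_ARGS__;                                                 \
  }

// The comma between the lo and hi declarations would break a two-parameter macro.
DEMO_UDF(my_l1, {
  float lo = q < d ? q : d, hi = q < d ? d : q;
  return hi - lo;
})
```

Here my_l1 is callable as ordinary code, and my_l1_source() returns the same function as text, ready to be concatenated into a larger program string.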

3. Host — type_name, glue, compile, register

Each full NVRTC program is one string: *_udf() plus instantiate_* output. type_name<U>() must return the exact token the entry TU uses (e.g. float, uint32_t).

#include <cstdint>
#include <type_traits>

template <typename U>
constexpr const char* type_name()
{
  using T = std::remove_cv_t<std::remove_reference_t<U>>;
  if constexpr (std::is_same_v<T, float>) {
    return "float";
  } else if constexpr (std::is_same_v<T, uint32_t>) {
    return "uint32_t";
  } else {
    static_assert(std::is_same_v<T, void>, "add a branch for each concrete T / IdxT you use");
    return "";
  }
}
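search_kernel_matrix.json also uses double and int64_t, so type_name<U>() needs a branch for each type you actually launch with. The host-only sketch below adds those branches. `apply_filter_instantiation` is a hypothetical helper that reproduces the explicit-instantiation line instantiate_apply_filter_udf emits for a given IdxT, which makes spellings easy to check in a unit test. __half is left out, because supporting it would also require the NVRTC program to see the __half definition.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <type_traits>

template <typename>
inline constexpr bool always_false_v = false;

// Spelling of each concrete type as it must appear in the NVRTC source.
template <typename U>
constexpr const char* type_name()
{
  using T = std::remove_cv_t<std::remove_reference_t<U>>;
  if constexpr (std::is_same_v<T, float>) {
    return "float";
  } else if constexpr (std::is_same_v<T, double>) {
    return "double";
  } else if constexpr (std::is_same_v<T, std::uint32_t>) {
    return "uint32_t";
  } else if constexpr (std::is_same_v<T, std::int64_t>) {
    return "int64_t";
  } else {
    static_assert(always_false_v<T>, "add a branch for each concrete T / IdxT you use");
    return "";
  }
}

// Same text that instantiate_apply_filter_udf writes inside its namespace block.
template <typename IdxT>
std::string apply_filter_instantiation()
{
  std::string t = type_name<IdxT>();
  return "template __device__ bool apply_filter<" + t + ">(uint32_t, " + t + ", void*);";
}
```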

4. Planner — extend Step 7 for UDF vs static

Step 7 used only static fragments. Add #include <string>, extend the enums, tags, and get_*_tag helpers as shown next, and keep the UDF glue in the same TU as search_jit. Then replace the two unconditional add_compute_distance_function / add_filter_function calls with the planner block that follows the enum and tag changes:

// Widen config: static metrics/filters vs NVRTC-compiled ones.
// Extend the DistanceType / FilterType enums from Step 7:
enum class DistanceType { Euclidean, InnerProduct, MetricUdf };
enum class FilterType { None, Bitset, FilterUdf };

// Tags for the NVRTC paths; they only select a branch and never name a static fragment.
struct tag_metric_custom_udf {};
struct tag_filter_custom_udf {};

template <DistanceType Metric>
constexpr auto get_metric_tag() {
  if constexpr (Metric == DistanceType::Euclidean) {
    return tag_metric_euclidean{};
  } else if constexpr (Metric == DistanceType::InnerProduct) {
    return tag_metric_inner_product{};
  } else if constexpr (Metric == DistanceType::MetricUdf) {
    return tag_metric_custom_udf{};
  } else {
    // Always false, but dependent on Metric, so it only fires for an unhandled enumerator
    static_assert(sizeof(std::integral_constant<DistanceType, Metric>) == 0,
                  "extend get_metric_tag when adding DistanceType enumerators");
  }
}

template <FilterType Filter>
constexpr auto get_filter_tag() {
  if constexpr (Filter == FilterType::None) {
    return tag_filter_none{};
  } else if constexpr (Filter == FilterType::Bitset) {
    return tag_filter_bitset{};
  } else if constexpr (Filter == FilterType::FilterUdf) {
    return tag_filter_custom_udf{};
  } else {
    static_assert(sizeof(std::integral_constant<FilterType, Filter>) == 0,
                  "extend get_filter_tag when adding FilterType enumerators");
  }
}

The planner block inside search_jit then becomes:

SearchPlanner planner;
planner.add_search_function<data_tag, out_tag, idx_tag, Optimized, Veclen>();

// Metric: NVRTC TU vs prebuilt matrix fragment (Step 6 helpers).
if constexpr (std::is_same_v<metric_tag, tag_metric_custom_udf>) {
  std::string metric_udf_code = my_l2_udf();
  metric_udf_code += instantiate_compute_distance_udf(type_name<T>());
  planner.add_metric_udf_fragment(nvrtc_compiler().compile(metric_udf_code, metric_udf_code));
} else {
  planner.add_compute_distance_function<metric_tag, data_tag>();
}

// Filter: NVRTC TU vs prebuilt matrix fragment.
if constexpr (std::is_same_v<filter_tag, tag_filter_custom_udf>) {
  std::string filter_udf_code = my_pass_filter_udf();
  filter_udf_code += instantiate_apply_filter_udf(type_name<IdxT>());
  planner.add_filter_udf_fragment(nvrtc_compiler().compile(filter_udf_code, filter_udf_code));
} else {
  planner.add_filter_function<filter_tag, idx_tag>();
}

auto launcher = planner.get_launcher();

Use DistanceType::MetricUdf / FilterType::FilterUdf only when you want the NVRTC branches; otherwise keep Euclidean / None for the original static path.

Pitfalls and constraints

  • Do not register the same hook through both UDF APIs (add_metric_udf_fragment / add_filter_udf_fragment) and the Step 6 static helpers (add_compute_distance_function / add_filter_function): they pull different fatbins and you will duplicate device definitions.
  • The NVRTC program must define every template the entry calls and emit matching template __device__ ... explicit instantiations for each concrete specialization (e.g. compute_distance<float>, apply_filter<uint32_t>). Prefer small host helpers (instantiate_* + type_name) for type spellings instead of hard-coding index types inside macro strings.
  • One NVRTC compile per logical TU; do not concatenate unrelated UDFs into one program string.
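The first pitfall can be enforced in code rather than by review. Below is a minimal sketch, assuming you wrap the planner calls yourself; `HookGuard` and `HookSource` are hypothetical and not part of cuVS. Call set_metric() next to each add_metric_udf_fragment / add_compute_distance_function call and set_filter() next to each filter registration. A double registration then becomes an immediate exception instead of duplicate device definitions at link time.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

enum class HookSource { Unset, Static, Udf };

// Tracks which source (static fragment or NVRTC UDF) supplied each hook.
class HookGuard {
 public:
  void set_metric(HookSource s) { set(metric_, s, "compute_distance"); }
  void set_filter(HookSource s) { set(filter_, s, "apply_filter"); }
  bool complete() const
  {
    return metric_ != HookSource::Unset && filter_ != HookSource::Unset;
  }

 private:
  static void set(HookSource& slot, HookSource s, char const* hook)
  {
    if (slot != HookSource::Unset) {
      throw std::logic_error(std::string(hook) + " registered twice");
    }
    slot = s;
  }

  HookSource metric_ = HookSource::Unset;
  HookSource filter_ = HookSource::Unset;
};
```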

Key Concepts

Fragment Tags

Fragment tags uniquely identify fragments. They’re simple lightweight types that are passed as the sole template parameter to StaticFatbinFragmentEntry:

template <typename OutT>
struct fragment_tag_get_score {};

Fragment tags may themselves take template parameters in order to uniquely identify them. Typically, one fragment tag template will correspond to a single function, and a fragment tag template specialization will correspond to a function specialization.

When a fatbin is compiled and embedded in C++ code, a translation unit specializes StaticFatbinFragmentEntry to specify its data and length static fields:

using _FragmentEntry = StaticFatbinFragmentEntry<fragment_tag_get_score<uint32_t>>;

template <>
const uint8_t* const _FragmentEntry::data = embedded_fatbin;

template <>
const size_t _FragmentEntry::length = sizeof(embedded_fatbin);

Then, an AlgorithmPlanner can call add_static_fragment() with the fragment tag (NOT the StaticFatbinFragmentEntry specialization) as the sole template parameter:

template <typename OutTag>
void add_get_score_function()
{
  add_static_fragment<fragment_tag_get_score<OutTag>>();
}

At build time, the linker takes care of finding and including the static fragments that have been specified by the algorithm planner.
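The tag-to-entry lookup can be illustrated with a self-contained host-side sketch. StaticFatbinFragmentEntry is simplified here to the two static members described above. MiniPlanner is a hypothetical stand-in, not the cuVS AlgorithmPlanner, and the byte array is a placeholder rather than a real fatbin. The primary template only declares data and length, so requesting a tag that no translation unit has specialized produces an undefined-symbol error at link time. That is how the linker "finds" fragments.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Simplified stand-in: the primary template declares but never defines its
// members, so only specialized tags can be linked.
template <typename FragmentTag>
struct StaticFatbinFragmentEntry {
  static const std::uint8_t* const data;
  static const std::size_t length;
};

template <typename OutT>
struct fragment_tag_get_score {};

// Placeholder bytes standing in for an embedded fatbin (not real fatbin data).
constexpr std::uint8_t embedded_fatbin[] = {0x01, 0x02, 0x03, 0x04};

// What a generated translation unit would provide for one tag specialization.
using GetScoreEntry = StaticFatbinFragmentEntry<fragment_tag_get_score<std::uint32_t>>;
template <>
const std::uint8_t* const GetScoreEntry::data = embedded_fatbin;
template <>
const std::size_t GetScoreEntry::length = sizeof(embedded_fatbin);

// Hypothetical mini-planner: takes the fragment tag (not the entry type) and
// records the fragment's bytes for a later nvJitLink step.
struct MiniPlanner {
  struct Fragment {
    const std::uint8_t* data;
    std::size_t length;
  };
  std::vector<Fragment> fragments;

  template <typename FragmentTag>
  void add_static_fragment()
  {
    using Entry = StaticFatbinFragmentEntry<FragmentTag>;
    fragments.push_back({Entry::data, Entry::length});
  }
};
```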

Registration Tags

Registration tags are type-safe identifiers used to organize fragments. They’re typically empty structs:

struct tag_f {};   // float
struct tag_h {};   // half
struct tag_ui {};  // uint32_t
struct tag_l {};   // int64_t

These tags are used in registerAlgorithm<>() to create a hierarchical organization of fragments.
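The following sketch shows one way such tags give compile-time safety: a trait maps each value type to its tag. The trait registration_tag is hypothetical and not part of the documented API. The half tag is omitted so the example compiles without CUDA headers. A type with no mapping fails to compile, so it cannot produce a mismatched runtime lookup key.

```cpp
#include <cstdint>
#include <type_traits>

struct tag_f {};   // float
struct tag_ui {};  // uint32_t
struct tag_l {};   // int64_t

// Hypothetical trait: value type -> registration tag. The primary template
// is intentionally undefined, so unsupported types are compile errors.
template <typename T>
struct registration_tag;
template <>
struct registration_tag<float> {
  using type = tag_f;
};
template <>
struct registration_tag<std::uint32_t> {
  using type = tag_ui;
};
template <>
struct registration_tag<std::int64_t> {
  using type = tag_l;
};

template <typename T>
using registration_tag_t = typename registration_tag<T>::type;

static_assert(std::is_same_v<registration_tag_t<float>, tag_f>, "float maps to tag_f");
```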

AlgorithmLauncher

The AlgorithmLauncher is the runtime handle for a linked kernel. It:

  • Holds a cudaKernel_t handle to the linked kernel
  • Provides call() and call_cooperative() methods to launch the kernel
  • Manages the lifetime of the cudaLibrary_t that contains the kernel

Best Practices

  1. Minimize Includes: JIT LTO fragments should have minimal includes, especially avoiding host-side headers. Extract device-only code into separate headers.

  2. Fragment Granularity: Strike a balance between too many small fragments (more per-fragment overhead) and too few large fragments (less reuse). Device functions that are reused across multiple kernels are good candidates for separate fragments.

  3. Naming Consistency: Ensure fragment tags match exactly between registration and lookup. Use helper functions to construct tags consistently.

  4. Type Safety: Use registration tags to provide compile-time type safety and avoid runtime string mismatches.

  5. Caching: Each planner type should hold a static LauncherJitCache and pass it to AlgorithmPlanner; get_launcher() then reuses linked kernels for the same fragment key within that cache.
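The caching practice can be sketched as a per-planner-type static map from fragment key to linked launcher. This is not LauncherJitCache. FakeLauncher, LauncherCacheSketch, and planner_cache are hypothetical names, and the link callback stands in for the nvJitLink step.

```cpp
#include <cstddef>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

// Stand-in for a linked kernel (the real AlgorithmLauncher wraps a
// cudaKernel_t and owns its cudaLibrary_t).
struct FakeLauncher {
  std::string key;
};

// Hypothetical cache keyed by the fragment key of a planned kernel.
class LauncherCacheSketch {
 public:
  using LinkFn = std::function<std::shared_ptr<FakeLauncher>()>;

  std::shared_ptr<FakeLauncher> get_or_link(const std::string& fragment_key, const LinkFn& link)
  {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = cache_.find(fragment_key);
    if (it != cache_.end()) { return it->second; }  // hit: reuse the linked kernel
    auto launcher = link();                         // miss: link once for this key
    cache_.emplace(fragment_key, launcher);
    return launcher;
  }

  std::size_t size() const
  {
    std::lock_guard<std::mutex> lock(mutex_);
    return cache_.size();
  }

 private:
  mutable std::mutex mutex_;
  std::unordered_map<std::string, std::shared_ptr<FakeLauncher>> cache_;
};

// One static cache per planner type: each PlannerType gets its own instance.
template <typename PlannerType>
LauncherCacheSketch& planner_cache()
{
  static LauncherCacheSketch cache;
  return cache;
}
```

For simplicity, this sketch holds the lock while linking, which serializes concurrent misses. A production cache might link outside the lock and tolerate duplicate work instead.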

Example: IVF Flat

IVF Flat uses JIT LTO with:

  • Metric fragments: Euclidean and inner product distance computations (16 fatbins)
  • Post-lambda fragments: Identity, sqrt, and compose post-processing (3 fatbins)
  • Interleaved scan fragments: Main search kernel with various configurations (320 fatbins)
  • Filter fragments: None and bitset filters (2 fatbins)

Total: 341 fatbins that can be combined into many more kernel variants at runtime.

Step 8: Integrate with CMake Build System

To integrate JIT LTO kernels into the CMake build system, add calls to generate_jit_lto_kernels() in your main CMakeLists.txt file (typically in cpp/CMakeLists.txt).

The generate_jit_lto_kernels() function (defined in cmake/modules/generate_jit_lto_kernels.cmake) takes:

  • NAME_FORMAT: Format string for generated kernel names (using @variable@ syntax)
  • MATRIX_JSON_FILE: Path to the JSON matrix file
  • KERNEL_INPUT_FILE: Path to the .cu.in template
  • FRAGMENT_TAG_FORMAT: Format string for fragment tag type (using @variable@ syntax)
  • FRAGMENT_TAG_HEADER_FILES: List of header files that provide the fragment tag types. Each entry may be enclosed in angle brackets (<...>) or double quotes ("..."); entries with neither are automatically enclosed in double quotes.
  • OUTPUT_DIRECTORY: Where generated files are placed
  • KERNEL_LINK_LIBRARIES: Interface library with compilation settings

Call generate_jit_lto_kernels() once for each fragment type (compute_distance, filter, search_kernel, etc.). The function reads the JSON matrix, computes the cross-product of all combinations, generates .cu and .cpp files from the templates, compiles them into fatbins, and returns a list of generated source files that should be added to your JIT LTO library target.
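To make the cross-product and @variable@ substitution steps concrete, here is a minimal sketch written in C++ purely for illustration. It is not how generate_jit_lto_kernels() is implemented, which is CMake code. The Matrix type stands in for the JSON matrix file, whose exact schema is not reproduced here. The substitution mimics the spirit of CMake's configure_file(... @ONLY).

```cpp
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Ordered list of (dimension name, possible values); stand-in for the JSON matrix.
using Matrix = std::vector<std::pair<std::string, std::vector<std::string>>>;
// One chosen value per dimension.
using Combination = std::map<std::string, std::string>;

// Cartesian product of every dimension's values.
std::vector<Combination> cross_product(const Matrix& matrix)
{
  std::vector<Combination> result{Combination{}};
  for (const auto& [name, values] : matrix) {
    std::vector<Combination> next;
    for (const auto& partial : result) {
      for (const auto& value : values) {
        Combination combo = partial;
        combo[name] = value;
        next.push_back(std::move(combo));
      }
    }
    result = std::move(next);
  }
  return result;
}

// Replace every @name@ token with the combination's value for that name.
std::string substitute(std::string text, const Combination& combo)
{
  for (const auto& [name, value] : combo) {
    const std::string token = "@" + name + "@";
    for (std::size_t pos = text.find(token); pos != std::string::npos;
         pos = text.find(token, pos + value.size())) {
      text.replace(pos, token.size(), value);
    }
  }
  return text;
}
```

Each Combination would produce one generated .cu file, so the number of generated files is the product of the dimension sizes.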

See the cuVS cpp/CMakeLists.txt file for a complete example of how to set up the interface library, call generate_jit_lto_kernels() for each fragment type, and create the final library target.

Summary

JIT LTO enables:

  • Reduced binary size: Compile fragments once, combine many ways
  • Faster compilation: Fragments compile independently
  • Runtime flexibility: Link fragments on-demand based on configuration
  • Code reuse: Device function fragments shared across kernels

The process involves:

  1. Separating device functions into fragment headers
  2. Creating JSON matrices defining parameter combinations
  3. Creating .cu.in templates for explicit instantiations
  4. Creating fragment tag types for fatbin registration
  5. Creating a planner to manage fragment dependencies
  6. Integrating the planner into the code path to launch kernels
  7. Adding CMake integration to generate and compile all fragment variants

Fragment Architecture

JIT LTO kernels are split into fragments: fatbins that each contain an individual piece of code and can be strung together, instead of instantiating the whole kernel at once. Each fragment only needs to be multiplied out over the dimensions (template parameters) that the fragment itself uses, rather than over every dimension of the kernel as a whole. At runtime, nvJitLink combines these fragments into the final program.
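As a worked illustration with made-up sizes (not cuVS's actual matrices), suppose a kernel has three template parameters A, B and C with 4, 5 and 8 possible values. Suppose also that C is used only by one device function, which additionally depends on A. Instantiating the kernel as a whole multiplies over all three dimensions. Splitting that device function into its own fragment multiplies each fragment only over its own dimensions:

$$
N_{\text{monolithic}} = |A|\,|B|\,|C| = 4 \cdot 5 \cdot 8 = 160
$$

$$
N_{\text{split}} = \underbrace{|A|\,|B|}_{\text{kernel fragment}} + \underbrace{|A|\,|C|}_{\text{device-function fragment}} = 4 \cdot 5 + 4 \cdot 8 = 52
$$

The saving comes from replacing one product over every dimension with a sum of smaller products.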

In JIT LTO, there are two kinds of code: algorithms and adapters. Algorithms are, roughly speaking, code that actually does the work: searching, sorting, or even something as simple as initializing variables. Adapters do nothing by themselves. They are thin wrappers around algorithms that exist only to reduce the number of template parameters the caller needs to know about. As a rule of thumb, assume that algorithm code is expensive to multiply over a matrix, so such multiplication should be minimized, while adapter code is cheap to multiply.

An algorithm function is a function that contains real code for the algorithm, and an adapter function merely calls an algorithm function with more template parameters than the adapter function itself has. An algorithm file contains algorithm code, and an adapter file contains adapter code.

Here is an example of an algorithm file that contains an algorithm function:

// is_divisible_impl.cuh
template <typename T, T Divisor>
__device__ bool is_divisible_impl(T value)
{
  return value % Divisor == 0;
}

Here is an example of an adapter file that contains an adapter function:

#include "device_functions.cuh"  // is_divisible
#include "is_divisible_impl.cuh" // is_divisible_impl
namespace {
using data_t = @data_type@;
constexpr data_t divisor = @divisor@;
}  // namespace
template <>
__device__ bool is_divisible<data_t>(data_t value)
{
  return is_divisible_impl<data_t, divisor>(value);
}

This is the most common pattern that you will see in NVIDIA cuVS’s JIT LTO code. Note that any code that calls is_divisible() does not need to know the value of Divisor, which allows the caller to be multiplied over fewer dimensions, thus reducing the amount of code generated.

Note that in the above adapter file, @data_type@ and @divisor@ are build-time substitutions performed by CMake. Each generated file receives one combination of values from the matrix cross product. All of the substitutions are grouped together in a single anonymous namespace, which makes them easy to find. Prefer this to sprinkling substitutions throughout the code.
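The same algorithm/adapter split can be tried on the host. This sketch drops __device__ and puts everything in one translation unit. The fixed uint32_t and divisor 3 play the role of one matrix combination. count_divisible is a hypothetical caller. It is instantiated per T only and never sees the divisor.

```cpp
#include <cstdint>

// device_functions.cuh analog: the caller-facing declaration only knows T.
template <typename T>
bool is_divisible(T value);

// is_divisible_impl.cuh analog: algorithm code exposing every parameter.
template <typename T, T Divisor>
bool is_divisible_impl(T value)
{
  return value % Divisor == 0;
}

// Adapter analog: these two values would come from @data_type@ / @divisor@.
namespace {
using data_t = std::uint32_t;
constexpr data_t divisor = 3;
}  // namespace

template <>
bool is_divisible<data_t>(data_t value)
{
  return is_divisible_impl<data_t, divisor>(value);
}

// Hypothetical caller: multiplied over T only, independent of the divisor.
template <typename T>
int count_divisible(const T* values, int n)
{
  int count = 0;
  for (int i = 0; i < n; ++i) {
    if (is_divisible<T>(values[i])) { ++count; }
  }
  return count;
}
```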

Here is an example with two algorithm files:

// greater_than_impl.cuh
#include "device_impl_functions.cuh" // filter

template <typename T, T Comparand>
__device__ bool filter(T value)
{
  return value > Comparand;
}

// less_than_impl.cuh
#include "device_impl_functions.cuh" // filter

template <typename T, T Comparand>
__device__ bool filter(T value)
{
  return value < Comparand;
}

And here is the accompanying adapter file:

#include "@op_name@_impl.cuh" // filter
namespace {
using data_t = @data_type@;
constexpr data_t comparand = @comparand@;
}  // namespace
template __device__ bool filter<data_t, comparand>(data_t value);

This is another common pattern that you will see in NVIDIA cuVS JIT LTO. Note that the adapter file does not contain any adapter functions. It merely explicitly instantiates whichever algorithm function it includes, and the @op_name@ substitution selects which algorithm file that is.

When a piece of algorithm code is used in multiple kernels, it should be split into its own shared fragment. At this point, it becomes important to also distinguish algorithm fragments and adapter fragments. An algorithm fragment contains an algorithm function that exposes all of the relevant template parameters, and this fragment is shared between multiple kernels. An adapter fragment is specific to a kernel. If a kernel wishes to invoke the same shared algorithm multiple times in the same run with different template parameters, it can employ multiple adapter fragments to accomplish this. Consider the following header file:

// filter.cuh

template <typename T, T Comparand>
__device__ bool filter_less_than(T value);

template <typename T, T Comparand>
__device__ bool filter_greater_than(T value);

And the following adapter files:

#include "device_functions.cuh" // filter_first_pass
#include "filter.cuh"           // filter_less_than, filter_greater_than
namespace {
using data_t = @data_type@;
constexpr data_t comparand = @comparand@;
}  // namespace
template <>
__device__ bool filter_first_pass<data_t>(data_t value)
{
  return filter_@op_name@<data_t, comparand>(value);
}

#include "device_functions.cuh" // filter_second_pass
#include "filter.cuh"           // filter_less_than, filter_greater_than
namespace {
using data_t = @data_type@;
constexpr data_t comparand = @comparand@;
}  // namespace
template <>
__device__ bool filter_second_pass<data_t>(data_t value)
{
  return filter_@op_name@<data_t, comparand>(value);
}

And the following algorithm file:

#include "device_functions.cuh" // filter_first_pass, filter_second_pass

template <typename T>
__device__ bool filter_all_passes(T value)
{
  return filter_first_pass<T>(value) && filter_second_pass<T>(value);
}

Note that filter_first_pass and filter_second_pass both invoke one of the filter functions, but which one they invoke is decided independently for each. Also note that neither of the adapter fragments contains the underlying algorithm code, but rather links against the corresponding shared algorithm fragments.
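This independence can be checked with a host-side, single-TU sketch. __device__ is dropped, and the shared algorithms are defined inline instead of living in separate shared fragments. The first pass is hard-wired to "greater than 10" and the second to "less than 20". Both values are arbitrary and stand in for two independent @op_name@ / @comparand@ substitutions.

```cpp
#include <cstdint>

// filter.cuh analog: shared algorithm code with every parameter exposed.
template <typename T, T Comparand>
bool filter_less_than(T value) { return value < Comparand; }

template <typename T, T Comparand>
bool filter_greater_than(T value) { return value > Comparand; }

// device_functions.cuh analog: declarations that expose only T.
template <typename T>
bool filter_first_pass(T value);
template <typename T>
bool filter_second_pass(T value);

// Adapter analogs: each pass picks its own operation and comparand.
template <>
bool filter_first_pass<std::int32_t>(std::int32_t value)
{
  return filter_greater_than<std::int32_t, 10>(value);
}

template <>
bool filter_second_pass<std::int32_t>(std::int32_t value)
{
  return filter_less_than<std::int32_t, 20>(value);
}

// Algorithm analog: instantiated once per T, unaware of either choice.
template <typename T>
bool filter_all_passes(T value)
{
  return filter_first_pass<T>(value) && filter_second_pass<T>(value);
}
```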

The key to minimizing code generation is to minimize the number of dimensions that any given fragment needs to be multiplied out over. If a section of algorithm code uses lots of template parameters, try to separate out sections that use only a subset of these parameters, put them into their own fragment, and remove the corresponding template parameters from the caller. Make judicious use of adapter code to accomplish this. An adapter function should only have the template parameters that appear in its signature, whereas an algorithm function should have all of the template parameters that appear in its signature or its implementation.

Unoptimized algorithm:

#include "filter_less_than.cuh"

template <typename T, T Comparand>
__device__ size_t find_first(T* values, size_t count)
{
  for (size_t i = 0; i < count; i++) {
    if (filter_less_than_impl<T, Comparand>(values[i])) {
      return i;
    }
  }

  // Could not find any
  return count;
}

Note that the algorithm includes the Comparand template parameter, which means the entire algorithm has to be multiplied out over all the possible values of this parameter.

Optimized algorithm:

#include "device_functions.cuh"

template <typename T>
__device__ size_t find_first(T* values, size_t count)
{
  for (size_t i = 0; i < count; i++) {
    if (filter_less_than<T>(values[i])) {
      return i;
    }
  }

  // Could not find any
  return count;
}

We are now using an adapter function called filter_less_than (possibly inside an adapter fragment) to invoke filter_less_than_impl (which may be inside a shared algorithm fragment). This hides the Comparand parameter from find_first. The entire algorithm therefore no longer needs to be multiplied over all possible values of Comparand; only the filter_less_than adapter and the filter_less_than_impl algorithm do.