C++ API Documentation#
The TensorRT C++ API lets developers import, calibrate, generate, and deploy networks using C++. Networks can be imported directly from ONNX, or created programmatically by instantiating individual layers and setting parameters and weights directly.
To view API changes between releases, refer to the TensorRT GitHub repository and use the compare tool.
This section illustrates the basic usage of the C++ API, assuming you start with an ONNX model. The sampleOnnxMNIST sample illustrates this use case in more detail. For the same workflow using the Python API, refer to Python API Documentation.
The C++ API can be accessed through the header NvInfer.h and is in the nvinfer1 namespace. For example, a simple application might begin with:
#include "NvInfer.h"
using namespace nvinfer1;
Interface classes in the TensorRT C++ API begin with the prefix I, such as ILogger and IBuilder.
A CUDA context is automatically created the first time TensorRT calls CUDA if none exists before that point. However, it is generally preferable to create and configure the CUDA context yourself before the first call to TensorRT.
To illustrate object lifetimes, most code in this chapter does not use smart pointers; however, their use is recommended with TensorRT interfaces.
The Build Phase#
The build phase uses the builder to optimize a model and produce an engine. To create a builder:
Instantiate the ILogger interface. This example captures all warning messages but ignores informational messages:
#include <iostream>
class Logger : public ILogger
{
void log(Severity severity, const char* msg) noexcept override
{
// suppress info-level messages
if (severity <= Severity::kWARNING)
std::cout << msg << std::endl;
}
} logger;
Create the builder:
IBuilder* builder = createInferBuilder(logger);
Building engines is intended as an offline process, so it can take significant time. The Reducing Engine Build Time section has tips on making the builder run faster.
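The severity <= Severity::kWARNING test in the logger above works because more severe levels carry smaller numeric values. The sketch below shows that filter in isolation. The enum is a local stand-in written for this example (assumed to mirror the ordering of ILogger::Severity), and shouldLog is a hypothetical helper, not part of TensorRT.

```cpp
#include <cstdint>

// Stand-in for nvinfer1::ILogger::Severity, written for this example only.
// Assumption: a lower numeric value means a more severe message.
enum class Severity : int32_t
{
    kINTERNAL_ERROR = 0,
    kERROR = 1,
    kWARNING = 2,
    kINFO = 3,
    kVERBOSE = 4
};

// Hypothetical helper: true when a message at `severity` should be printed
// by a logger that reports everything up to and including `threshold`.
bool shouldLog(Severity severity, Severity threshold)
{
    return severity <= threshold;
}
```

With a threshold of kWARNING, errors and warnings pass the filter while kINFO and kVERBOSE messages are dropped. Raising the threshold to kVERBOSE can help when diagnosing build problems.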
Creating a Network Definition#
After the builder has been created, the first step in optimizing a model is to create a network definition:
Specify the network creation options using a combination of flags OR’d together (or 0 for none). Note that all networks are strongly typed in TensorRT 11, so you need not set the kSTRONGLY_TYPED flag (a warning will be emitted if you do). For more information, refer to the Strongly Typed Networks section.
Create the network:
INetworkDefinition* network = builder->createNetworkV2(flags);
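As a minimal sketch of what "flags OR'd together" means, the code below builds a bitmask from enumerators that each name a bit position. ExampleFlag, toMask, and combineFlags are hypothetical names invented for this illustration. It is an assumption that the creation-flag enumerators are bit positions in the same way; check the TensorRT headers for the real enumerators.

```cpp
#include <cstdint>
#include <initializer_list>

// Hypothetical flag enum for illustration; each enumerator is a bit position.
enum class ExampleFlag : int32_t
{
    kOPTION_A = 0,
    kOPTION_B = 1,
    kOPTION_C = 2
};

// Convert one enumerator into its single-bit mask.
constexpr uint32_t toMask(ExampleFlag flag)
{
    return 1U << static_cast<uint32_t>(flag);
}

// OR together any number of flags; an empty list yields 0 (no options).
uint32_t combineFlags(std::initializer_list<ExampleFlag> flags)
{
    uint32_t mask = 0;
    for (ExampleFlag flag : flags)
    {
        mask |= toMask(flag);
    }
    return mask;
}
```

An empty list produces 0, which corresponds to the "0 for none" case described above.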
Creating a Network Definition from Scratch (Advanced)
Instead of using a parser, you can define the network directly to TensorRT through the Network Definition API.
This example creates a simple network with Input, Convolution, Pooling, MatrixMultiply, Shuffle, Activation, and Softmax layers. It also loads the weights into a weightMap data structure, which is used in the following code.
Create the builder and network objects. Note that the logger is initialized using the logger.cpp file common to all C++ samples. The C++ sample helper classes and functions can be found in the common.h header file.
auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(0));
Add the Input layer to the network by specifying the input tensor’s name, datatype, and full dimensions. A network can have multiple inputs, although in this sample, there is only one:
auto data = network->addInput(INPUT_BLOB_NAME, datatype, Dims4{1, 1, INPUT_H, INPUT_W});
Add the Convolution layer with hidden layer input nodes, strides, and weights for filter and bias.
// addInput returns an ITensor*, so the input tensor is passed directly.
auto conv1 = network->addConvolutionNd(
*data, 20, DimsHW{5, 5}, weightMap["conv1filter"], weightMap["conv1bias"]);
conv1->setStrideNd(DimsHW{1, 1});
Note
Weights passed to TensorRT layers are in host memory.
Add the Pooling layer; note that the output from the previous layer is passed as input.
auto pool1 = network->addPoolingNd(*conv1->getOutput(0), PoolingType::kMAX, DimsHW{2, 2});
pool1->setStrideNd(DimsHW{2, 2});
Add a Shuffle layer to reshape the input in preparation for matrix multiplication:
// The pooled tensor is the input to the reshape.
auto input = pool1->getOutput(0);
int32_t const batch = input->getDimensions().d[0];
int32_t const mmInputs = input->getDimensions().d[1] * input->getDimensions().d[2] * input->getDimensions().d[3];
auto inputReshape = network->addShuffle(*input);
inputReshape->setReshapeDimensions(Dims{2, {batch, mmInputs}});
Add a MatrixMultiply layer. The model exporter provided transposed weights, so the kTRANSPOSE option is specified.
IConstantLayer* filterConst = network->addConstant(Dims{2, {nbOutputs, mmInputs}}, weightMap["ip1filter"]);
auto mm = network->addMatrixMultiply(*inputReshape->getOutput(0), MatrixOperation::kNONE, *filterConst->getOutput(0), MatrixOperation::kTRANSPOSE);
Add the bias, which will broadcast across the batch dimension.
auto biasConst = network->addConstant(Dims{2, {1, nbOutputs}}, weightMap["ip1bias"]);
auto biasAdd = network->addElementWise(*mm->getOutput(0), *biasConst->getOutput(0), ElementWiseOperation::kSUM);
Add the ReLU Activation layer:
auto relu1 = network->addActivation(*biasAdd->getOutput(0), ActivationType::kRELU);
Add the SoftMax layer to calculate the final probabilities:
auto prob = network->addSoftMax(*relu1->getOutput(0));
Add a name for the output of the SoftMax layer so that the tensor can be bound to a memory buffer at inference time:
prob->getOutput(0)->setName(OUTPUT_BLOB_NAME);
Mark it as the output of the entire network:
network->markOutput(*prob->getOutput(0));
The network representing the MNIST model has now been fully constructed. For instructions on how to build an engine and run an inference with this network, refer to the Building an Engine and Performing Inference sections.
For more information regarding layers, refer to the TensorRT Operator documentation.
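The reshape step in the network above depends on the tensor shape that reaches it, so it can help to check the arithmetic by hand. The sketch below is plain C++ and makes two assumptions that the text does not state: the input is 28x28 (the usual MNIST size for INPUT_H and INPUT_W), and neither the convolution nor the pooling layer uses padding. windowOutput and flattenedSize are hypothetical helper names.

```cpp
#include <array>
#include <cstdint>

// Output extent of an unpadded sliding window (convolution or pooling)
// along one spatial axis.
constexpr int64_t windowOutput(int64_t input, int64_t kernel, int64_t stride)
{
    return (input - kernel) / stride + 1;
}

// Number of values per batch item in an NCHW tensor: C * H * W.
int64_t flattenedSize(std::array<int64_t, 4> const& nchw)
{
    return nchw[1] * nchw[2] * nchw[3];
}
```

Under those assumptions, the 5x5 convolution with stride 1 maps 28 to windowOutput(28, 5, 1) = 24, the 2x2 pooling with stride 2 maps 24 to windowOutput(24, 2, 2) = 12, and with 20 output maps the value of mmInputs is 20 * 12 * 12 = 2880.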
Importing a Model Using the ONNX Parser#
Now, the network definition must be populated from the ONNX representation.
The ONNX parser API is in the file NvOnnxParser.h, and the parser is in the nvonnxparser C++ namespace.
Include the ONNX parser API.
#include "NvOnnxParser.h"
using namespace nvonnxparser;
Create an ONNX parser to populate the network.
IParser* parser = createParser(*network, logger);
Read the model file and process any errors.
parser->parseFromFile(modelFile,
static_cast<int32_t>(ILogger::Severity::kWARNING));
for (int32_t i = 0; i < parser->getNbErrors(); ++i)
{
std::cout << parser->getError(i)->desc() << std::endl;
}
An important aspect of a TensorRT network definition is that it contains pointers to model weights, which the builder copies into the optimized engine. Since the network was created using the parser, the parser owns the memory occupied by the weights, so the parser object should not be deleted until after the builder has run.
Importing a Model Using the ONNX Parser with Custom Weights#
The ONNX parser API allows users to provide their own weights, also known as ONNX initializers, in host memory to override any weights found in the model. Instead of parsing the model immediately, use the following sequence:
#include "NvOnnxParser.h"
using namespace nvonnxparser;
IParser* parser = createParser(*network, logger);
Load the model into the parser.
parser->loadModelProto(modelData, modelSize);
Provide the name, pointer to data, and size of the data for the parser to use instead of the one found in the model. You can call this step multiple times to override multiple weights. These pointers must remain in scope until the parser is destroyed.
parser->loadInitializer(name, data, dataSize);
Begin parsing with the user-defined weights.
parser->parseModelProto();
The same idea extends to the IParserRefitter class, and similar APIs can be used to provide custom weights when refitting an engine built from an ONNX model. For more information, refer to the Refitting a Weight-Stripped Engine Directly from ONNX section.
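Because the override pointers must stay valid until the parser is destroyed, it can be convenient to keep all override buffers in one owner object whose lifetime clearly exceeds the parser's. The following is a minimal sketch of that idea in plain C++. WeightStore and WeightView are hypothetical names, not TensorRT or ONNX parser types. Storing the size in bytes is also an assumption, so check NvOnnxParser.h for the unit that loadInitializer expects.

```cpp
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical view of one weight buffer: a raw pointer plus its size in bytes.
struct WeightView
{
    void const* data;
    std::size_t sizeInBytes;
};

// Hypothetical owner of override weights. Buffers live as long as the store,
// so the pointers it hands out stay valid while the store outlives the parser.
class WeightStore
{
public:
    // Take ownership of the values and return a view of the stored buffer.
    // Adding the same name again replaces the buffer and invalidates old views.
    WeightView add(std::string const& name, std::vector<float> values)
    {
        std::vector<float>& stored = mBuffers[name];
        stored = std::move(values);
        return WeightView{stored.data(), stored.size() * sizeof(float)};
    }

private:
    // std::map entries are not relocated when other entries are inserted.
    std::map<std::string, std::vector<float>> mBuffers;
};
```

Declaring the store before the parser and destroying it afterwards keeps every pointer passed to the parser alive for the required duration.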
Building an Engine#
The next step is to create a build configuration specifying how TensorRT should optimize the model:
Create a builder configuration object. This interface has many properties that you can set to control how TensorRT optimizes the network.
IBuilderConfig* config = builder->createBuilderConfig();
Set the maximum workspace size. Layer implementations often require a temporary workspace, and this parameter limits the maximum size that any layer in the network can use. If insufficient workspace is provided, TensorRT might not be able to find an implementation for a layer. By default, the workspace is set to the total global memory size of the given device; restrict it when necessary, such as when multiple engines are to be built on a single device.
config->setMemoryPoolLimit(MemoryPoolType::kWORKSPACE, 1U << 20);
Set the maximum shared memory allocation for the CUDA backend implementation. This limit matters when TensorRT needs to coexist with other applications, such as when both TensorRT and DirectX use the GPU concurrently.
config->setMemoryPoolLimit(MemoryPoolType::kTACTIC_SHARED_MEMORY, 48 << 10);
Build the serialized engine:
IHostMemory* serializedModel = builder->buildSerializedNetwork(*network, *config);
Release objects that are no longer needed. The serialized engine contains the necessary copies of the weights, so the parser, network definition, builder configuration, and builder can be safely deleted:
delete parser;
delete network;
delete config;
delete builder;
Save the engine to disk if needed, then release the buffer that held the serialized engine:
delete serializedModel;
Note
Serialized engines are not cross-platform portable. They are specific to the exact GPU model on which they were built (in addition to the platform).
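The "save the engine to disk" step above is not shown in code. One way to do it is sketched below, using only the standard library. writePlanFile is a hypothetical helper, not a TensorRT function. It takes a raw pointer and a byte count, which is what the data() and size() members of the serialized buffer provide.

```cpp
#include <cstddef>
#include <fstream>
#include <string>

// Hypothetical helper: write `size` bytes starting at `data` to `path`.
// Returns false if the file cannot be opened or the write fails.
bool writePlanFile(std::string const& path, void const* data, std::size_t size)
{
    std::ofstream out(path, std::ios::binary | std::ios::trunc);
    if (!out)
    {
        return false;
    }
    out.write(static_cast<char const*>(data), static_cast<std::streamsize>(size));
    out.flush();
    return out.good();
}
```

A call such as writePlanFile("model.plan", serializedModel->data(), serializedModel->size()) would go before the serialized buffer is deleted.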
Deserializing a Plan#
To load a previously serialized plan and run inference:
Create a runtime. Like the builder, the runtime requires a logger:
IRuntime* runtime = createInferRuntime(logger);
Warning
Engine files are executable artifacts that contain compiled CUDA tactics. Deserialize only engines you built yourself or received over a trusted channel. Never deserialize an engine file from an untrusted source.
After creating the runtime, deserialize the plan using one of the following approaches:
Deserialize from an in-memory buffer. This method is straightforward and suitable for smaller models or when memory is not a constraint:
std::vector<char> modelData = readModelFromFile("model.plan");
ICudaEngine* engine = runtime->deserializeCudaEngine(modelData.data(), modelData.size());
Deserialize using IStreamReaderV2. This method supports custom file handling and weight streaming, and can reduce peak memory usage by reading the plan in chunks as needed:
class MyStreamReaderV2 : public IStreamReaderV2 {
// Custom implementation with support for device memory reading
};
MyStreamReaderV2 readerV2("model.plan");
ICudaEngine* engine = runtime->deserializeCudaEngine(readerV2);
The IStreamReaderV2 approach is particularly beneficial for large models or when using advanced features like GPUDirect or weight streaming. It can significantly reduce engine load time and memory usage.
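The in-memory example above calls readModelFromFile, which is not a TensorRT function and is not defined in this section. One possible implementation is sketched below using only the standard library. It reads the entire plan into host memory, so peak memory usage is at least the size of the file.

```cpp
#include <cstddef>
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>

// One possible implementation of the readModelFromFile helper (an assumption;
// the helper is not part of TensorRT). Reads the whole file into memory.
std::vector<char> readModelFromFile(std::string const& path)
{
    // Open at the end so tellg() reports the file size.
    std::ifstream file(path, std::ios::binary | std::ios::ate);
    if (!file)
    {
        throw std::runtime_error("Cannot open plan file: " + path);
    }
    std::streamsize const size = file.tellg();
    file.seekg(0, std::ios::beg);

    std::vector<char> buffer(static_cast<std::size_t>(size));
    if (size > 0 && !file.read(buffer.data(), size))
    {
        throw std::runtime_error("Cannot read plan file: " + path);
    }
    return buffer;
}
```

Throwing on failure keeps an empty buffer from being handed to the runtime by mistake. Returning an empty vector and checking it at the call site would work equally well.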
When choosing a deserialization method, consider your specific requirements:
For small models or simple use cases, in-memory deserialization is often sufficient.
For large models or when memory efficiency is crucial, consider using IStreamReaderV2.
If you need custom file handling or weight streaming capabilities, IStreamReaderV2 provides the necessary flexibility.
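To make "reading the plan in chunks" concrete, the sketch below shows the read-and-seek logic that a custom reader would typically wrap. It is a standalone class built on std::ifstream and deliberately does not derive from IStreamReaderV2. The real interface has its own method signatures (including support for reading into device memory), which should be taken from the TensorRT headers. FileChunkReader and its method names are hypothetical.

```cpp
#include <cstdint>
#include <fstream>
#include <string>

// Hypothetical chunked file reader; NOT a TensorRT interface implementation.
class FileChunkReader
{
public:
    explicit FileChunkReader(std::string const& path)
        : mFile(path, std::ios::binary)
    {
    }

    bool isOpen() const
    {
        return mFile.is_open();
    }

    // Copy up to nbBytes into destination and return the number of bytes
    // actually read. A short count means the end of the file was reached.
    int64_t read(void* destination, int64_t nbBytes)
    {
        if (!mFile.is_open() || destination == nullptr || nbBytes <= 0)
        {
            return 0;
        }
        mFile.read(static_cast<char*>(destination), nbBytes);
        return static_cast<int64_t>(mFile.gcount());
    }

    // Move the read position to an absolute offset from the start of the file.
    bool seekFromStart(int64_t offset)
    {
        mFile.clear(); // reset EOF/fail bits left by a previous short read
        mFile.seekg(offset, std::ios::beg);
        return !mFile.fail();
    }

private:
    std::ifstream mFile;
};
```

Because the caller asks for one chunk at a time, only the chunk buffer needs to be resident in host memory rather than the entire plan.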
Performing Inference#
After the engine is loaded, run inference through an execution context:
For production threading boundaries, refer to the Thread-Safety Deny-List. For memory bounding and multi-tenant OOM prevention, refer to Bounding TensorRT Memory in Production. For CUDA error isolation between tenants, refer to Cross-Context CUDA Error Isolation.
Create an execution context:
IExecutionContext *context = engine->createExecutionContext();
An engine can have multiple execution contexts, allowing one set of weights to be used for multiple overlapping inference tasks.
Set device buffer addresses for input and output tensors using setTensorAddress. Use the tensor names you assigned when building the network:
context->setTensorAddress(INPUT_NAME, inputBuffer);
context->setTensorAddress(OUTPUT_NAME, outputBuffer);
If the engine was built with dynamic shapes, specify the input shapes:
context->setInputShape(INPUT_NAME, inputDims);
Start inference on a CUDA stream using enqueueV3:
context->enqueueV3(stream);
Whether a network executes asynchronously depends on its structure and features. Features that can cause synchronous behavior include, but are not limited to, data-dependent shapes, DLA usage, loops, and synchronous plugins. It is common to enqueue data transfers with cudaMemcpyAsync() before and after the kernels to move data to and from the GPU if it is not already there.
To determine when the kernels (and possibly cudaMemcpyAsync()) are complete, use standard CUDA synchronization mechanisms such as events or waiting on the stream.
Complete End-to-End Example#
The snippets above introduce each C++ API class in isolation. The program below stitches them into a single copy-paste-ready file that goes from an ONNX model to a serialized engine to a live inference call. Save it as trt_end_to_end.cpp, build with the command line at the bottom, and run it against a single-input, single-output ONNX model that takes a 1x3x224x224 float input, for example the ResNet-50 v1 model referenced earlier in this section.
#include <cassert>
#include <fstream>
#include <iostream>
#include <memory>
#include <random>
#include <stdexcept>
#include <vector>
#include <cuda_runtime.h>
#include "NvInfer.h"
#include "NvOnnxParser.h"
using namespace nvinfer1;
class Logger : public ILogger
{
void log(Severity severity, char const* msg) noexcept override
{
if (severity <= Severity::kWARNING)
{
std::cout << msg << std::endl;
}
}
} gLogger;
// Build a strongly typed TensorRT engine from an ONNX file and save it to disk.
std::vector<char> buildEngine(char const* onnxPath, char const* enginePath)
{
std::unique_ptr<IBuilder> builder{createInferBuilder(gLogger)};
// All networks are strongly typed in TensorRT 11, so no creation flags are set.
std::unique_ptr<INetworkDefinition> network{builder->createNetworkV2(0)};
std::unique_ptr<nvonnxparser::IParser> parser{
nvonnxparser::createParser(*network, gLogger)};
if (!parser->parseFromFile(
onnxPath, static_cast<int32_t>(ILogger::Severity::kWARNING)))
{
for (int32_t i = 0; i < parser->getNbErrors(); ++i)
{
std::cerr << parser->getError(i)->desc() << std::endl;
}
throw std::runtime_error("Failed to parse ONNX file");
}
std::unique_ptr<IBuilderConfig> config{builder->createBuilderConfig()};
config->setMemoryPoolLimit(MemoryPoolType::kWORKSPACE, 1ULL << 30); // 1 GiB
std::unique_ptr<IHostMemory> serialized{
builder->buildSerializedNetwork(*network, *config)};
if (!serialized)
{
throw std::runtime_error("Engine build failed");
}
std::ofstream out(enginePath, std::ios::binary);
out.write(static_cast<char const*>(serialized->data()), serialized->size());
return std::vector<char>(
static_cast<char const*>(serialized->data()),
static_cast<char const*>(serialized->data()) + serialized->size());
}
// Deserialize the engine and run a single inference on the supplied input.
std::vector<float> runInference(
std::vector<char> const& engineBytes, std::vector<float> const& hostInput,
Dims const& inputDims)
{
std::unique_ptr<IRuntime> runtime{createInferRuntime(gLogger)};
std::unique_ptr<ICudaEngine> engine{
runtime->deserializeCudaEngine(engineBytes.data(), engineBytes.size())};
if (!engine)
{
throw std::runtime_error("Engine deserialization failed");
}
std::unique_ptr<IExecutionContext> context{engine->createExecutionContext()};
char const* const inputName = engine->getIOTensorName(0);
char const* const outputName = engine->getIOTensorName(1);
context->setInputShape(inputName, inputDims);
Dims const outputDims = context->getTensorShape(outputName);
size_t outputCount = 1;
for (int32_t i = 0; i < outputDims.nbDims; ++i)
{
outputCount *= outputDims.d[i];
}
std::vector<float> hostOutput(outputCount);
void* dInput{nullptr};
void* dOutput{nullptr};
cudaMalloc(&dInput, hostInput.size() * sizeof(float));
cudaMalloc(&dOutput, hostOutput.size() * sizeof(float));
cudaStream_t stream;
cudaStreamCreate(&stream);
cudaMemcpyAsync(dInput, hostInput.data(), hostInput.size() * sizeof(float),
cudaMemcpyHostToDevice, stream);
context->setTensorAddress(inputName, dInput);
context->setTensorAddress(outputName, dOutput);
if (!context->enqueueV3(stream))
{
throw std::runtime_error("enqueueV3 failed");
}
cudaMemcpyAsync(hostOutput.data(), dOutput,
hostOutput.size() * sizeof(float), cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);
cudaFree(dInput);
cudaFree(dOutput);
cudaStreamDestroy(stream);
return hostOutput;
}
int main(int argc, char** argv)
{
char const* const onnxPath = (argc > 1) ? argv[1] : "resnet50-v1-12.onnx";
char const* const enginePath = "model.engine";
auto engineBytes = buildEngine(onnxPath, enginePath);
// Replace this with a preprocessed image batch in your own application.
std::vector<float> hostInput(1 * 3 * 224 * 224);
std::mt19937 rng{0};
std::uniform_real_distribution<float> dist{0.0f, 1.0f};
for (auto& v : hostInput) v = dist(rng);
auto hostOutput = runInference(
engineBytes, hostInput, Dims4{1, 3, 224, 224});
size_t topClass = 0;
for (size_t i = 1; i < hostOutput.size(); ++i)
{
if (hostOutput[i] > hostOutput[topClass]) topClass = i;
}
std::cout << "Output size: " << hostOutput.size() << "\n";
std::cout << "Top-1 class index: " << topClass << "\n";
return 0;
}
Build and run, replacing the include and library paths with your TensorRT install location:
g++ -std=c++17 trt_end_to_end.cpp \
-I/usr/include/x86_64-linux-gnu -I/usr/local/cuda/include \
-L/usr/lib/x86_64-linux-gnu -L/usr/local/cuda/lib64 \
-lnvinfer -lnvonnxparser -lcudart \
-o trt_end_to_end
./trt_end_to_end resnet50-v1-12.onnx
Each invocation builds and serializes model.engine, then runs one inference on a random tensor sized for ResNet-50.
Once you have a working baseline, swap the random input for a preprocessed image batch, plug the program into your application, and use the focused sections above when you need finer control over network construction, custom weights, alternate deserialization paths, or asynchronous inference scheduling.
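The end-to-end program sizes its output buffer by multiplying the output dimensions together. When you adapt it to engines with dynamic shapes, it can be worth guarding that multiplication. The sketch below is plain C++ and assumes that a negative extent signals a dimension that has not been resolved yet. elementCount is a hypothetical helper that takes the extents as a std::vector rather than a TensorRT Dims object.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: multiply all extents to get the element count.
// Returns -1 if any extent is negative, which this sketch treats as
// "shape not fully resolved yet".
int64_t elementCount(std::vector<int64_t> const& extents)
{
    int64_t count = 1;
    for (int64_t extent : extents)
    {
        if (extent < 0)
        {
            return -1;
        }
        count *= extent;
    }
    return count;
}
```

Checking for -1 before allocating avoids sizing a buffer from a shape that is still unknown, for example when the input shape has not been set on the execution context.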
Set Internal Library Path API#
TensorRT internally loads builder resource libraries (libnvinfer_builder_resource_*.so) from the system library path (for example, LD_LIBRARY_PATH). If you load TensorRT from a custom location using dlopen, these resource libraries might not be found automatically.
Use nvinfer1::setInternalLibraryPath to specify the directory where TensorRT should look for its internal resource libraries at runtime.
// Load TensorRT from a custom location (dlopen is declared in <dlfcn.h>)
void* handle = dlopen("/path/to/custom/libnvinfer.so", RTLD_LAZY);
// Tell TensorRT where to find its internal resource libraries
nvinfer1::setInternalLibraryPath("/path/to/custom");
// Create builder and build engine
Next Steps#
See also
- Optimizing Performance
Benchmarking methodology and best practices for maximizing inference throughput and latency.
- Working with Dynamic Shapes
Building engines that handle variable input dimensions at runtime.
- Accuracy Considerations
Understanding precision trade-offs and mitigating accuracy loss with reduced precision.
- Working with Quantized Types
INT8, FP8, and FP4 quantization workflows including PTQ and QAT.
- Advanced Features
Strongly typed networks, layer precision control, DLA, and custom plugins.