Using the TensorRT-RTX Runtime API

In this section, we show how to build an engine and run inference programmatically in TensorRT-RTX, using either the C++ or the Python API. For clarity, the examples use a minimal convolutional network consisting of a single 3x3 convolution rather than a modern full-scale architecture, and they omit error handling, memory management via smart pointers, and similar production concerns. The C++ example appears first, followed by its Python equivalent.

#include <NvInfer.h>
#include <cuda_runtime_api.h>
#include <fstream>
#include <iostream>
#include <vector>

using namespace nvinfer1;

// Define a simple logger class
class Logger : public ILogger {
    void log(Severity severity, const char* msg) noexcept override {
        if (severity <= Severity::kWARNING) std::cout << msg << std::endl;
    }
} gLogger;

// Build a CUDA engine and save it to a file
void buildEngine(const char* fileName){
    // Create the builder and an explicit-batch network
    auto builder = createInferBuilder(gLogger);
    auto network = builder->createNetworkV2(0U);
    // Input tensor of shape {1, 1, 3, 3}
    auto input = network->addInput("input", DataType::kFLOAT, Dims4{1,1,3,3});

    // Prepare convolution weights for a 3x3 kernel, 1 input, 1 output channel
    std::vector<float> weightValues(9, 1.f);
    std::vector<float> biasValues(1, 0.f);
    Weights W{DataType::kFLOAT, weightValues.data(), 9};
    Weights B{DataType::kFLOAT, biasValues.data(), 1};

    // Add convolution layer
    auto conv = network->addConvolutionNd(*input, 1, DimsHW{3,3}, W, B);
    conv->setStrideNd(DimsHW{1,1});

    // Mark network output
    auto output = conv->getOutput(0);
    output->setName("output");
    network->markOutput(*output);

    // Build the engine and serialize it in one step
    auto config = builder->createBuilderConfig();
    // Optionally cap the builder's workspace memory pool (1 MiB here)
    config->setMemoryPoolLimit(MemoryPoolType::kWORKSPACE, 1 << 20);
    auto serializedEngine = builder->buildSerializedNetwork(*network, *config);

    // Save the serialized engine to a file
    std::ofstream out{fileName, std::ios::binary};
    out.write(static_cast<const char*>(serializedEngine->data()), serializedEngine->size());
    out.close();

    // Clean up
    delete serializedEngine;
    delete network;
    delete config;
    delete builder;
}

// Load the engine and perform inference
void performInference(char const* fileName, const std::vector<float>& inputData){
    auto runtime = createInferRuntime(gLogger);
    // Deserialize the engine from file
    std::ifstream is{fileName, std::ios::binary};
    is.seekg(0, std::ios::end);
    size_t nbBytes = is.tellg();
    is.seekg(0, std::ios::beg);
    std::vector<char> data(nbBytes);
    is.read(data.data(), nbBytes);
    is.close();
    ICudaEngine* engine = runtime->deserializeCudaEngine(data.data(), nbBytes);
    // Create an execution context for inference
    auto execContext = engine->createExecutionContext();
    // Allocate device buffers for the input and output bindings
    void* bindings[2];
    size_t inputSize = inputData.size();
    size_t outputSize = 1;
    size_t inputNbBytes = inputSize * sizeof(float);
    size_t outputNbBytes = outputSize * sizeof(float);
    cudaMalloc(&bindings[0], inputNbBytes);
    cudaMalloc(&bindings[1], outputNbBytes);
    // Populate input binding on device
    cudaMemcpy(bindings[0], inputData.data(), inputNbBytes, cudaMemcpyHostToDevice);
    // Here we show synchronous inference for simplicity; see the enqueueV3()
    // sketch after this example for asynchronous execution
    execContext->executeV2(bindings);
    // Copy result into host memory space
    float result;
    cudaMemcpy(&result, bindings[1], outputNbBytes, cudaMemcpyDeviceToHost);
    std::cout << "Result = " << result << std::endl;
    // Clean up
    cudaFree(bindings[0]);
    cudaFree(bindings[1]);
    delete execContext;
    delete engine;
    delete runtime;
}
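
The synchronous executeV2() call above is the simplest path. For asynchronous execution, the comment in performInference() points to enqueueV3(), which addresses I/O tensors by name rather than through a bindings array. The following is a minimal sketch of that pattern; the function name performInferenceAsync and its parameters are illustrative only, and the tensor names and device buffers are assumed to be the same as in the example above.

// Minimal sketch of asynchronous inference with enqueueV3()
// (illustrative helper; dInput/dOutput are device buffers allocated as above)
void performInferenceAsync(ICudaEngine* engine, void* dInput, void* dOutput){
    auto execContext = engine->createExecutionContext();
    cudaStream_t stream;
    cudaStreamCreate(&stream);
    // enqueueV3() looks up I/O tensors by name instead of using a bindings array
    execContext->setTensorAddress("input", dInput);
    execContext->setTensorAddress("output", dOutput);
    // The call returns as soon as the work is enqueued on the stream
    execContext->enqueueV3(stream);
    // Wait for the GPU to finish before reading the output buffer
    cudaStreamSynchronize(stream);
    cudaStreamDestroy(stream);
    delete execContext;
}

Combined with cudaMemcpyAsync() on the same stream, this pattern lets host-to-device copies, inference, and device-to-host copies overlap with other work instead of blocking the calling thread.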
The equivalent workflow using the Python API, with pycuda handling device memory, looks as follows:

import tensorrt_rtx as trt_rtx
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit

GLOBAL_LOGGER = trt_rtx.Logger(trt_rtx.Logger.WARNING)

def buildEngine(fileName):
    builder = trt_rtx.Builder(GLOBAL_LOGGER)
    network = builder.create_network(0)
    input_t = network.add_input(name="input", dtype=trt_rtx.float32, shape=(1,1,3,3))
    # 3x3 all-ones kernel, 1 input and 1 output channel, zero bias
    w = np.full((1,1,3,3), 1., dtype=np.float32)
    b = np.zeros(1, dtype=np.float32)
    conv = network.add_convolution_nd(
            input=input_t, num_output_maps=1, kernel_shape=(3,3), kernel=w, bias=b)
    conv.stride_nd = (1,1)
    conv.get_output(0).name = "output"
    network.mark_output(conv.get_output(0))
    config = builder.create_builder_config()
    # Optionally cap the builder's workspace memory pool (1 MiB here)
    config.set_memory_pool_limit(trt_rtx.MemoryPoolType.WORKSPACE, 1 << 20)
    hostMem = builder.build_serialized_network(network, config)
    with open(fileName, "wb") as fileOut:
        fileOut.write(hostMem)

def performInference(fileName, inputData):
    runtime = trt_rtx.Runtime(GLOBAL_LOGGER)
    with open(fileName, "rb") as f:
        fileBytes = f.read()
    engine = runtime.deserialize_cuda_engine(fileBytes)
    execContext = engine.create_execution_context()
    # Allocate device buffers and copy the input to the device
    dInput = cuda.mem_alloc(inputData.nbytes)
    cuda.memcpy_htod(dInput, inputData)
    outputData = np.empty((1,), dtype=np.float32)
    dOutput = cuda.mem_alloc(outputData.nbytes)
    bindings = [int(dInput), int(dOutput)]
    execContext.execute_v2(bindings)
    cuda.memcpy_dtoh(outputData, dOutput)
    dInput.free()
    dOutput.free()
    return outputData
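
In both the C++ and the Python version, the 3x3 all-ones kernel with zero bias simply sums the nine input elements, so an input of nine 1.0 values should yield a result of 9.0. This makes a convenient end-to-end sanity check after building and deserializing the engine.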