Using the TensorRT-RTX Runtime API
In this section, we show how to run inference programmatically in TensorRT-RTX, using either the C++ or the Python API. For simplicity, the example builds a small convolutional network rather than a more modern architecture, and it omits error handling, memory management via smart pointers, and the necessary header inclusions. The C++ version is shown first, followed by the equivalent Python code.
// Define a simple logger class
class Logger : public ILogger {
    void log(Severity severity, const char* msg) noexcept override {
        if (severity <= Severity::kWARNING) std::cout << msg << std::endl;
    }
} gLogger;

// Build a CUDA engine and save it to a file
void buildEngine(const char* fileName) {
    // Create builder and an explicit-batch network definition
    auto builder = createInferBuilder(gLogger);
    auto network = builder->createNetworkV2(0U);
    // Input tensor of shape {1, 1, 3, 3}
    auto input = network->addInput("input", DataType::kFLOAT, Dims4{1, 1, 3, 3});

    // Prepare convolution weights for a 3x3 kernel, 1 input, 1 output channel
    std::vector<float> weightValues(9, 1.f);
    std::vector<float> biasValues(1, 0.f);
    Weights W{DataType::kFLOAT, weightValues.data(), 9};
    Weights B{DataType::kFLOAT, biasValues.data(), 1};

    // Add convolution layer
    auto conv = network->addConvolutionNd(*input, 1, DimsHW{3, 3}, W, B);
    conv->setStrideNd(DimsHW{1, 1});

    // Mark network output
    auto output = conv->getOutput(0);
    output->setName("output");
    network->markOutput(*output);

    // Build the engine and serialize it in one step
    auto config = builder->createBuilderConfig();
    auto engine = builder->buildSerializedNetwork(*network, *config);

    // Save the serialized engine to a file
    std::ofstream out{fileName, std::ios::binary};
    out.write(static_cast<const char*>(engine->data()), engine->size());
    out.close();

    // Clean up
    delete engine;
    delete network;
    delete config;
    delete builder;
}

// Load the engine and perform inference
void performInference(char const* fileName, const std::vector<float>& inputData) {
    auto runtime = createInferRuntime(gLogger);
    // Deserialize the engine from file
    ICudaEngine* engine{};
    std::ifstream is{fileName, std::ios::binary};
    is.seekg(0, std::ios::end);
    size_t nbBytes = is.tellg();
    is.seekg(0, std::ios::beg);
    std::vector<uint8_t> data(nbBytes);
    is.read(reinterpret_cast<char*>(data.data()), nbBytes);
    is.close();
    engine = runtime->deserializeCudaEngine(data.data(), nbBytes);
    // Create an execution context for inference
    auto execContext = engine->createExecutionContext();
    // Allocate device memory for the input and output bindings
    void* bindings[2];
    size_t inputSize = inputData.size();
    size_t outputSize = 1;
    size_t inputNbBytes = inputSize * sizeof(float);
    size_t outputNbBytes = outputSize * sizeof(float);
    cudaMalloc(&bindings[0], inputNbBytes);
    cudaMalloc(&bindings[1], outputNbBytes);
    // Populate input binding on device
    cudaMemcpy(bindings[0], inputData.data(), inputNbBytes, cudaMemcpyHostToDevice);
    // Here we show synchronous inference for simplicity; use enqueueV3()
    // for asynchronous execution
    execContext->executeV2(bindings);
    // Copy result into host memory space
    float result;
    cudaMemcpy(&result, bindings[1], outputNbBytes, cudaMemcpyDeviceToHost);
    std::cout << "Result = " << result << std::endl;
    // Clean up
    cudaFree(bindings[0]);
    cudaFree(bindings[1]);
    delete execContext;
    delete engine;
    delete runtime;
}
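The executeV2() call above runs inference synchronously. As a rough sketch of the asynchronous path mentioned in the comment, a small helper along the following lines could be used instead; the function name runAsync is our own, and it assumes that the name-based setTensorAddress()/enqueueV3() tensor-binding API of standard TensorRT is also available here.

// Hypothetical helper: run inference asynchronously on a CUDA stream.
// 'bindings' are the same device pointers allocated in performInference().
void runAsync(IExecutionContext* execContext, void* const* bindings) {
    cudaStream_t stream;
    cudaStreamCreate(&stream);
    execContext->setTensorAddress("input", bindings[0]);   // device input buffer
    execContext->setTensorAddress("output", bindings[1]);  // device output buffer
    execContext->enqueueV3(stream);                        // launch asynchronously
    cudaStreamSynchronize(stream);                         // wait for completion
    cudaStreamDestroy(stream);
}

The same workflow in Python, with pycuda handling device memory, looks as follows.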
import tensorrt_rtx as trt_rtx
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit  # creates a CUDA context on import

GLOBAL_LOGGER = trt_rtx.Logger(trt_rtx.Logger.WARNING)

def buildEngine(fileName):
    builder = trt_rtx.Builder(GLOBAL_LOGGER)
    network = builder.create_network(0)
    # Input tensor of shape (1, 1, 3, 3)
    input_t = network.add_input(name="input", dtype=trt_rtx.float32, shape=(1, 1, 3, 3))
    # 3x3 convolution with weights of one and zero bias
    w = np.full((1, 1, 3, 3), 1., dtype=np.float32)
    b = np.zeros(1, dtype=np.float32)
    conv = network.add_convolution_nd(
        input=input_t, num_output_maps=1, kernel_shape=(3, 3), kernel=w, bias=b)
    conv.stride_nd = (1, 1)
    conv.get_output(0).name = "output"
    network.mark_output(conv.get_output(0))
    # Build the engine and serialize it in one step
    config = builder.create_builder_config()
    hostMem = builder.build_serialized_network(network, config)
    # Save the serialized engine to a file
    with open(fileName, "wb") as fileOut:
        fileOut.write(hostMem)

def performInference(fileName, inputData):
    runtime = trt_rtx.Runtime(GLOBAL_LOGGER)
    # Deserialize the engine from file
    with open(fileName, "rb") as f:
        fileBytes = f.read()
    buffer = memoryview(fileBytes)
    engine = runtime.deserialize_cuda_engine(buffer)
    execContext = engine.create_execution_context()
    # Allocate device memory and copy the input to the device
    dInput = cuda.mem_alloc(inputData.nbytes)
    cuda.memcpy_htod(dInput, inputData)
    outputData = np.empty((1,), dtype=np.float32)
    dOutput = cuda.mem_alloc(outputData.nbytes)
    bindings = [int(dInput), int(dOutput)]
    # Synchronous execution, as in the C++ example
    execContext.execute_v2(bindings)
    # Copy the result back to the host and release device memory
    cuda.memcpy_dtoh(outputData, dOutput)
    dInput.free()
    dOutput.free()
    return outputData
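To tie the two functions together, a minimal driver along the lines below can be used; the engine file name is an arbitrary choice for this example. Since the kernel is a 3x3 block of ones with zero bias, a 3x3 input of ones yields a single output value of 9.0.

if __name__ == "__main__":
    engineFile = "conv.engine"  # illustrative file name
    buildEngine(engineFile)
    # Input of ones matching the network's (1, 1, 3, 3) input shape
    x = np.ones((1, 1, 3, 3), dtype=np.float32)
    y = performInference(engineFile, x)
    print("Result =", y[0])  # expected: 9.0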