Python API Documentation#

Attention

The TensorRT Python API enables developers in Python-based development environments, and those looking to experiment with TensorRT, to easily parse models (for example, from ONNX) and to generate and run PLAN files.

This section illustrates the basic usage of the Python API, assuming you are starting with an ONNX model. The onnx_resnet50.py sample illustrates this use case in more detail.

The Python API can be accessed through the tensorrt module:

import tensorrt as trt

The Build Phase#

To create a builder, you must first create a logger. The Python bindings include a simple logger implementation that logs all messages at or above a given severity to stdout. It can be used like so:

logger = trt.Logger(trt.Logger.WARNING)

Alternatively, you can define your own logger implementation by deriving from the ILogger class:

class MyLogger(trt.ILogger):
    def __init__(self):
        trt.ILogger.__init__(self)

    def log(self, severity, msg):
        pass # Your custom logging implementation here

logger = MyLogger()
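One way to fill in the log method is to forward messages to Python's standard logging module. The following is a minimal sketch of the severity mapping only, written so that it runs without TensorRT. The names SEVERITY_TO_LEVEL and forward_to_logging are hypothetical, and the sketch assumes the severity passed to log exposes a name attribute such as "WARNING".

```python
import logging

# Assumed mapping from TensorRT severity names to standard logging levels.
SEVERITY_TO_LEVEL = {
    "INTERNAL_ERROR": logging.CRITICAL,
    "ERROR": logging.ERROR,
    "WARNING": logging.WARNING,
    "INFO": logging.INFO,
    "VERBOSE": logging.DEBUG,
}


def forward_to_logging(severity_name, msg, logger=None):
    """Log msg at the level matching severity_name and return that level."""
    logger = logger or logging.getLogger("tensorrt")
    # Unknown severity names fall back to INFO rather than being dropped.
    level = SEVERITY_TO_LEVEL.get(severity_name, logging.INFO)
    logger.log(level, msg)
    return level
```

Under those assumptions, the body of MyLogger.log could be a single call such as forward_to_logging(severity.name, msg).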

You can then create a builder:

builder = trt.Builder(logger)

Building engines is intended as an offline process, so it can take significant time. The Optimizing Builder Performance section has tips on making the builder run faster.

Creating a Network Definition#

After the builder has been created, the first step in optimizing a model is to create a network definition. The network definition options are specified using a combination of flags OR-d together.

You can specify that the network should be considered strongly typed using the NetworkDefinitionCreationFlag.STRONGLY_TYPED flag. For more information, refer to the Strongly Typed Networks section.

Finally, create a network, where flag is the bitmask of the selected creation flags:

network = builder.create_network(flag)
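The flag argument is a bitmask in which each creation flag names a bit position. The sketch below illustrates how such flags are OR-d together. It uses a hypothetical stand-in enum, CreationFlag, with illustrative values so that it runs without TensorRT; combine_flags is likewise a hypothetical helper.

```python
from enum import IntEnum


class CreationFlag(IntEnum):
    """Hypothetical stand-in for trt.NetworkDefinitionCreationFlag (values are illustrative)."""
    FIRST_OPTION = 0
    SECOND_OPTION = 1


def combine_flags(*flags):
    """OR together one bit per flag, where each flag names a bit position."""
    mask = 0
    for f in flags:
        mask |= 1 << int(f)
    return mask


# A single flag selects a single bit; no flags yields 0.
flag = combine_flags(CreationFlag.SECOND_OPTION)
```

With TensorRT installed, the same idea would typically be written as flag = 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED).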

Creating a Network Definition from Scratch (Advanced)#

Instead of using a parser, you can define the network directly to TensorRT via the Network Definition API. This scenario assumes that the per-layer weights are ready in host memory to pass to TensorRT during the network creation.

The code corresponding to this section can be found in network_api_pytorch_mnist.

This example creates a simple network with Input, Convolution, Pooling, MatrixMultiply, Shuffle, Activation, and Softmax layers. This example uses a helper class to hold some of the metadata about the model:

class ModelData(object):
    INPUT_NAME = "data"
    INPUT_SHAPE = (1, 1, 28, 28)
    OUTPUT_NAME = "prob"
    OUTPUT_SIZE = 10
    DTYPE = trt.float32

In this example, the weights are imported from the PyTorch MNIST model.

weights = mnist_model.get_weights()

Create the logger, builder, and network classes.

TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(0)

Next, create the input tensor for the network, specifying the name, datatype, and shape of the tensor.

input_tensor = network.add_input(name=ModelData.INPUT_NAME, dtype=ModelData.DTYPE, shape=ModelData.INPUT_SHAPE)

Add a convolution layer, specifying the inputs, number of output maps, kernel shape, weights, bias, and stride:

conv1_w = weights["conv1.weight"].cpu().numpy()
conv1_b = weights["conv1.bias"].cpu().numpy()
conv1 = network.add_convolution_nd(
    input=input_tensor, num_output_maps=20, kernel_shape=(5, 5), kernel=conv1_w, bias=conv1_b
)
conv1.stride_nd = (1, 1)

Add a pooling layer, specifying the inputs (the output of the previous convolution layer), pooling type, window size, and stride:

pool1 = network.add_pooling_nd(input=conv1.get_output(0), type=trt.PoolingType.MAX, window_size=(2, 2))
pool1.stride_nd = trt.Dims2(2, 2)

Add the next pair of convolution and pooling layers:

conv2_w = weights["conv2.weight"].cpu().numpy()
conv2_b = weights["conv2.bias"].cpu().numpy()
conv2 = network.add_convolution_nd(pool1.get_output(0), 50, (5, 5), conv2_w, conv2_b)
conv2.stride_nd = (1, 1)

pool2 = network.add_pooling_nd(conv2.get_output(0), trt.PoolingType.MAX, (2, 2))
pool2.stride_nd = trt.Dims2(2, 2)
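It can help to track the tensor shape through the two convolution and pooling pairs, since the fully connected layer that follows depends on the flattened size. The sketch below is plain Python arithmetic, not TensorRT code. It assumes no padding on any layer, and the helper names are hypothetical.

```python
def window_out(size, window, stride=1, padding=0):
    """Output extent of a convolution or pooling window along one spatial axis."""
    return (size + 2 * padding - window) // stride + 1


def trace_mnist_shapes(input_shape=(1, 1, 28, 28)):
    """Return the NCHW shape after each convolution and pooling layer."""
    n, c, h, w = input_shape
    shapes = [input_shape]
    # (output maps, convolution kernel, pooling window) for each pair of layers
    for channels, kernel, pool in ((20, 5, 2), (50, 5, 2)):
        h, w = window_out(h, kernel), window_out(w, kernel)
        shapes.append((n, channels, h, w))  # after convolution, stride 1
        h, w = window_out(h, pool, pool), window_out(w, pool, pool)
        shapes.append((n, channels, h, w))  # after pooling, stride 2
    return shapes


shapes = trace_mnist_shapes()
# Number of values per sample once the last pooling output is flattened.
features = shapes[-1][1] * shapes[-1][2] * shapes[-1][3]
```

Under these assumptions the shapes run from (1, 1, 28, 28) to (1, 50, 4, 4), so each sample flattens to 800 values.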

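Add a Shuffle layer to reshape the input in preparation for matrix multiplication. In the following snippets, net refers to the network definition and input to the tensor feeding the fully connected layer (here, the output of pool2):

import numpy as np

batch = input.shape[0]
mm_inputs = int(np.prod(input.shape[1:]))  # flatten all non-batch dimensions
input_reshape = net.add_shuffle(input)
input_reshape.reshape_dims = trt.Dims2(batch, mm_inputs)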

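Now, add a MatrixMultiply layer. The model exporter provided transposed weights, so the TRANSPOSE option is specified for the weight operand.

filter_const = net.add_constant(trt.Dims2(nbOutputs, mm_inputs), weights["fc1.weight"].numpy())
mm = net.add_matrix_multiply(input_reshape.get_output(0), trt.MatrixOperation.NONE, filter_const.get_output(0), trt.MatrixOperation.TRANSPOSE)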

Add bias, which will broadcast across the batch dimension:

bias_const = net.add_constant(trt.Dims2(1, nbOutputs), weights["fc1.bias"].numpy())
bias_add = net.add_elementwise(mm.get_output(0), bias_const.get_output(0), trt.ElementWiseOperation.SUM)
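The arithmetic performed by the matrix multiply and bias steps can be checked in plain Python. The sketch below is an illustration only, not TensorRT code: it computes x · wᵀ + b for weights stored as (outputs, inputs), with the bias broadcast across the batch dimension. The function name and the sample values are hypothetical.

```python
def fully_connected(x, w, b):
    """Compute x @ w.T + b for x of shape (batch, k), w of shape (n, k), b of length n."""
    out = []
    for row in x:
        # Each output is the dot product of the input row with one weight row,
        # plus that output's bias (the same bias is reused for every batch row).
        out.append([
            sum(xi * wi for xi, wi in zip(row, w_row)) + b_j
            for w_row, b_j in zip(w, b)
        ])
    return out


x = [[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]]  # batch of 2, k = 3 inputs
w = [[1.0, 0.0, 0.0], [0.5, 0.5, 0.5]]  # n = 2 outputs, stored as (n, k)
b = [10.0, 20.0]
y = fully_connected(x, w, b)
```

Because the weights are stored as (n, k), the multiplication uses their transpose, which is what the transpose option on the second operand expresses.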

Add the ReLU activation layer, where fc1 denotes the fully connected layer built in the preceding steps:

relu1 = network.add_activation(input=fc1.get_output(0), type=trt.ActivationType.RELU)

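Add the final fully connected layer using a helper, add_matmul_as_fc, that wraps the Shuffle, MatrixMultiply, and bias steps shown above, and mark the output of this layer as the output of the entire network: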

fc2_w = weights['fc2.weight'].numpy()
fc2_b = weights['fc2.bias'].numpy()
fc2 = add_matmul_as_fc(network, relu1.get_output(0), ModelData.OUTPUT_SIZE, fc2_w, fc2_b)

fc2.get_output(0).name = ModelData.OUTPUT_NAME
network.mark_output(tensor=fc2.get_output(0))

The network representing the MNIST model has now been fully constructed. For instructions on how to build an engine and run inference with this network, refer to the Building an Engine and Performing Inference sections.

For more information regarding layers, refer to the TensorRT Operator documentation.

Importing a Model Using the ONNX Parser#

Now, the network definition must be populated from the ONNX representation. You can create an ONNX parser to populate the network as follows:

parser = trt.OnnxParser(network, logger)

Then, read the model file and process any errors:

success = parser.parse_from_file(model_path)
for idx in range(parser.num_errors):
    print(parser.get_error(idx))

if not success:
    pass # Error handling code here
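One possible way to fill in the error handling is to collect the parser's messages and raise an exception when parsing fails. The helper below is a hypothetical sketch: check_parse is not part of TensorRT, and it relies only on the num_errors and get_error members used above.

```python
def check_parse(parser, success):
    """Return the parser's error messages, raising RuntimeError if parsing failed."""
    errors = [str(parser.get_error(i)) for i in range(parser.num_errors)]
    if not success:
        # Include every reported error so the failure is diagnosable from the traceback.
        raise RuntimeError("ONNX parsing failed:\n" + "\n".join(errors))
    return errors
```

It could be called as check_parse(parser, parser.parse_from_file(model_path)), so that a failed parse stops the build early instead of producing an incomplete network.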

Building an Engine#

The next step is to create a build configuration specifying how TensorRT should optimize the model:

config = builder.create_builder_config()

This interface has many properties that you can set to control how TensorRT optimizes the network. One important property is the maximum workspace size. Layer implementations often require a temporary workspace, and this parameter limits the maximum size that any layer in the network can use. If insufficient workspace is provided, TensorRT may not be able to find an implementation for a layer. By default, the workspace is set to the total global memory size of the given device; restrict it when necessary, for example, when multiple engines are to be built on a single device.

config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 20) # 1 MiB
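The limit is given in bytes, so shift expressions such as 1 << 20 are a compact way to write binary sizes. The helpers below are a small hypothetical convenience, not part of TensorRT, for making the intended size explicit.

```python
def mib(n):
    """Return n mebibytes expressed in bytes."""
    return n << 20


def gib(n):
    """Return n gibibytes expressed in bytes."""
    return n << 30


# Example value: a 2 GiB workspace limit, expressed in bytes.
workspace_bytes = gib(2)
```

The resulting value could then be passed as the second argument to config.set_memory_pool_limit in place of the literal.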

After the configuration has been specified, the engine can be built and serialized with:

serialized_engine = builder.build_serialized_network(network, config)

It may be useful to save the engine to a file for future use. You can do that like so:

with open("sample.engine", "wb") as f:
    f.write(serialized_engine)

Note

Serialized engines are not portable across platforms. They are specific to the exact GPU model on which they were built (in addition to the platform).
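Because building is slow, a common pattern is to reuse a saved engine when one exists and to build only otherwise. The sketch below is one hypothetical way to do that; load_or_build and build_fn are illustrative names. It assumes build_fn returns either None on failure or an object that bytes() can convert, and that the cached file was produced on the same platform and GPU model, as the note above requires.

```python
import os


def load_or_build(path, build_fn):
    """Return engine bytes from path if present; otherwise build, cache, and return them."""
    if os.path.exists(path):
        with open(path, "rb") as f:
            return f.read()
    serialized = build_fn()
    if serialized is None:
        # Assumption: the build step signals failure by returning None.
        raise RuntimeError("Engine build failed")
    data = bytes(serialized)
    with open(path, "wb") as f:
        f.write(data)
    return data
```

For example, build_fn could be lambda: builder.build_serialized_network(network, config).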

Deserializing a Plan#

When you have a previously serialized optimized model and want to perform inference, you must first create an instance of the Runtime interface. Like the builder, the runtime requires an instance of the logger:

runtime = trt.Runtime(logger)

An engine can be deserialized using the following methods: [1]

In-memory Deserialization

This method is straightforward and suitable for smaller models or when memory isn’t a constraint. Read the plan file into a memory buffer.

with open("model.plan", "rb") as f:
    model_data = f.read()
engine = runtime.deserialize_cuda_engine(model_data)

Alternatively, you can deserialize the engine directly from the serialized engine object returned by the builder, without writing it to disk:

serialized_engine = builder.build_serialized_network(network, config)
engine = runtime.deserialize_cuda_engine(serialized_engine)

IStreamReaderV2 Deserialization

This is the most advanced method, supporting reading to both host and device pointers and enabling potential performance improvements. With this approach, reading the entire plan file into a buffer to be deserialized is unnecessary, as IStreamReaderV2 allows reading the file in chunks as needed.

class StreamReaderV2(trt.IStreamReaderV2):
    def __init__(self, bytes):
        trt.IStreamReaderV2.__init__(self)
        self.bytes = bytes
        self.len = len(bytes)
        self.index = 0

    def read(self, size, cudaStreamPtr):
        assert self.index + size <= self.len
        data = self.bytes[self.index:self.index + size]
        self.index += size
        return data

    def seek(self, offset, where):
        if where == trt.SeekPosition.SET:
            self.index = offset
        elif where == trt.SeekPosition.CUR:
            self.index += offset
        elif where == trt.SeekPosition.END:
            self.index = self.len - offset
        else:
            raise ValueError(f"Invalid seek position: {where}")

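# The reader takes the plan contents as bytes; this example loads them up front for simplicity.
with open("model.plan", "rb") as f:
    reader_v2 = StreamReaderV2(f.read())
engine = runtime.deserialize_cuda_engine(reader_v2)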

The trt.IStreamReaderV2 method is particularly beneficial for large models or when using advanced features like GPUDirect or weight streaming. It can significantly reduce engine load time and memory usage.
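To illustrate reading in chunks rather than buffering the whole plan, the sketch below keeps an open file handle and reads only the bytes requested on each call. It is a hypothetical stand-in that does not subclass trt.IStreamReaderV2, so that it runs without TensorRT. The class name and the SET, CUR, and END constants are illustrative, and the END case follows the same convention as the class above.

```python
import os


class ChunkedFileReader:
    """Hypothetical stand-in showing chunked reads from a plan file."""

    SET, CUR, END = 0, 1, 2  # illustrative stand-ins for seek positions

    def __init__(self, path):
        self.size = os.path.getsize(path)
        self.file = open(path, "rb")

    def read(self, size, cuda_stream_ptr=None):
        # Only `size` bytes are pulled from disk per call.
        return self.file.read(size)

    def seek(self, offset, where):
        if where == self.SET:
            self.file.seek(offset)
        elif where == self.CUR:
            self.file.seek(offset, os.SEEK_CUR)
        elif where == self.END:
            # Position is measured back from the end of the file.
            self.file.seek(self.size - offset)
        else:
            raise ValueError(f"Invalid seek position: {where}")

    def close(self):
        self.file.close()
```

A real reader would keep the same read and seek logic while deriving from trt.IStreamReaderV2 and comparing against trt.SeekPosition values.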

When choosing a deserialization method, consider your specific requirements:

  • For small models or simple use cases, in-memory deserialization is often sufficient.

  • For large models or when memory efficiency is crucial, consider using trt.IStreamReaderV2.

  • If you need custom file handling or streaming capabilities, trt.IStreamReaderV2 provides the necessary flexibility.

Performing Inference#

The engine holds the optimized model, but inference requires an additional state for intermediate activations. This is done via the IExecutionContext interface:

context = engine.create_execution_context()

An engine can have multiple execution contexts, allowing one set of weights to be used for multiple overlapping inference tasks. (A current exception applies when using dynamic shapes: each optimization profile can have only one execution context unless the preview feature PROFILE_SHARING_0806 is specified.)

To perform inference, you must specify buffers for inputs and outputs:

context.set_tensor_address(name, ptr)
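Every input and output tensor needs an address before inference starts. The helper below is a hypothetical sketch that walks the engine's I/O tensors and binds each one. It assumes the engine exposes num_io_tensors and get_tensor_name, and that addresses maps each tensor name to a device pointer held as an integer.

```python
def bind_io_tensors(engine, context, addresses):
    """Call set_tensor_address for every I/O tensor and return the tensor names."""
    names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
    missing = [n for n in names if n not in addresses]
    if missing:
        # Fail before inference rather than run with an unbound tensor.
        raise KeyError(f"No buffer provided for tensors: {missing}")
    for name in names:
        context.set_tensor_address(name, addresses[name])
    return names
```

With PyTorch, for example, the pointer for a CUDA tensor could be obtained from its data_ptr() method.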

Several Python packages allow you to allocate memory on the GPU, including, but not limited to, the official CUDA Python bindings, PyTorch, CuPy, and Numba.

After populating the input buffer, you can call TensorRT’s execute_async_v3 method to start inference using a CUDA stream. Whether a network executes asynchronously depends on its structure and features. Features that can cause synchronous behavior include, but are not limited to, data-dependent shapes, DLA usage, loops, and synchronous plugins.

First, create the CUDA stream. If you already have one, use a pointer to it. For example, for PyTorch CUDA streams (torch.cuda.Stream()), you can access the pointer using the cuda_stream property; for Polygraphy CUDA streams, use the ptr attribute. Alternatively, you can create a stream directly with the CUDA Python bindings by calling cudaStreamCreate().

Next, start inference:

context.execute_async_v3(stream_ptr)

It is common to enqueue asynchronous transfers (cudaMemcpyAsync()) before the kernels to move input data to the GPU if it is not already there, and after the kernels to copy the results back from the GPU.

To determine when inference (and any asynchronous transfers) is complete, use the standard CUDA synchronization mechanisms, such as events or waiting on the stream. For example, with PyTorch CUDA streams or Polygraphy CUDA streams, issue stream.synchronize(). To synchronize with streams created with the CUDA Python bindings, issue cudaStreamSynchronize(stream).
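The enqueue and wait steps can be wrapped together so that results are never read before the stream has drained. The helper below is a hypothetical sketch. It assumes execute_async_v3 takes the stream handle as its only argument and returns a boolean, and that synchronize is any callable that blocks until the stream is idle.

```python
def run_and_wait(context, stream_ptr, synchronize):
    """Enqueue inference on the stream, then block until the stream has finished."""
    if not context.execute_async_v3(stream_ptr):
        raise RuntimeError("execute_async_v3 reported failure")
    # Outputs are only safe to read once the stream has drained.
    synchronize()
```

With a PyTorch stream, for instance, this might be called as run_and_wait(context, stream.cuda_stream, stream.synchronize).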

Footnotes