Migrating from TensorRT 4
TensorRT 5.0 introduced an all-new Python API. The Python bindings were entirely rewritten, with significant changes and improvements. This page highlights some of these changes and outlines the steps you can take to migrate your TensorRT 4.0 Python code to more recent versions of TensorRT.
Legacy Compatibility
It is possible to continue using the legacy API by importing the TensorRT legacy package.
Simply replace import tensorrt as trt (for example) with import tensorrt.legacy as trt in your scripts.
Eventually, legacy support will be dropped, so it is still advisable to migrate to the new API.
Submodules
The new API removes the infer, parsers, utils, and lite submodules. Instead, all functionality is now included in the top-level tensorrt module.
Create and Destroy Functions
The new API no longer includes create and destroy functions. Constructors are used for object creation, and destruction is instead handled by scoping (that is, when an object goes out of scope, the garbage collector can free it). It is also possible to destroy objects manually with del trt_object. Note that del only decrements the reference count of the object; it is still up to the garbage collector to free it. If you really want to free memory immediately, you can instead invoke the object’s __del__ function.
The above changes apply to the following classes:
Legacy API | New API |
G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.WARNING) | TRT_LOGGER = trt.Logger(trt.Logger.WARNING) |
builder = trt.infer.create_infer_builder(logger) | builder = trt.Builder(logger) |
runtime = trt.infer.create_infer_runtime(logger) | runtime = trt.Runtime(logger) |
parser = trt.parsers.caffeparser.create_caffe_parser() | parser = trt.CaffeParser() |
parser = trt.parsers.uffparser.create_uff_parser() | parser = trt.UffParser() |
parser.destroy() # Or any other TensorRT object | del parser # with ... as ... clauses strongly preferred (see below) |
It is generally advisable to scope objects using with ... as ... clauses. TensorRT classes implement __exit__ in such a way that instances are destroyed immediately. Therefore, unlike the del operator, with ... as ... clauses free memory immediately when a TensorRT object goes out of scope.
For example, building an engine from an ONNX file using with ... as ... might look something like this:
import tensorrt as trt
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
ONNX_MODEL = "mnist.onnx"
def build_engine():
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
        # Configure the builder here.
        builder.max_workspace_size = 2**30
        # Parse the model to create a network.
        with open(ONNX_MODEL, 'rb') as model:
            parser.parse(model.read())
        # Build and return the engine. Note that the builder, network and parser
        # are destroyed when this function returns.
        return builder.build_cuda_engine(network)
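The difference between del and with ... as ... described above can be illustrated without TensorRT. The following is a minimal sketch in plain Python; FakeResource and its release method are hypothetical stand-ins that only mimic the behavior described in this section, and are not part of the TensorRT API.

```python
class FakeResource:
    """Hypothetical stand-in for a TensorRT object; not part of the TensorRT API."""

    def __init__(self, name):
        self.name = name
        self.released = False

    def release(self):
        # Stands in for freeing the underlying native memory.
        self.released = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Release as soon as the with block ends, regardless of other references.
        self.release()
        return False  # Do not suppress exceptions.


# `del` only removes one name; the object survives while `alias` refers to it.
resource = FakeResource("parser")
alias = resource
del resource
released_after_del = alias.released  # Nothing has been released yet.

# A with block releases on exit even though `kept` still refers to the object.
with FakeResource("builder") as scoped:
    kept = scoped
released_after_with = kept.released  # __exit__ ran at the end of the block.
```

In this sketch the with block is the only construct that guarantees release at a known point, which is why it is the preferred pattern.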
Data Types
Data types have been given short-hand aliases to better align with naming conventions established by Python libraries such as NumPy.
Legacy API | New API |
trt.infer.DataType.FLOAT | trt.float32 |
trt.infer.DataType.HALF | trt.float16 |
trt.infer.DataType.INT32 | trt.int32 |
trt.infer.DataType.INT8 | trt.int8 |
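When migrating a larger code base, the renames in the table above could be applied mechanically. The following is a minimal sketch of such a rewrite, not an official migration tool; the helper name rewrite_dtypes is hypothetical, and the pattern assumes the module was imported as trt.

```python
import re

# Mapping taken from the table above: legacy enum member -> new alias.
LEGACY_DTYPES = {
    "FLOAT": "float32",
    "HALF": "float16",
    "INT32": "int32",
    "INT8": "int8",
}

# Matches e.g. `trt.infer.DataType.HALF`; assumes the module is imported as `trt`.
_DTYPE_PATTERN = re.compile(r"\btrt\.infer\.DataType\.(FLOAT|HALF|INT32|INT8)\b")


def rewrite_dtypes(source):
    """Return `source` with legacy data type names replaced by the new aliases."""
    return _DTYPE_PATTERN.sub(lambda m: "trt." + LEGACY_DTYPES[m.group(1)], source)


migrated = rewrite_dtypes("dtype = trt.infer.DataType.HALF")
```

Text that does not match the pattern is left untouched, so the helper can be run over a whole script.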
Getters and Setters
Where possible, the new API exposes attributes rather than getters and setters. Importantly, getters that take an argument, and setters that either take multiple arguments or have a return value, have been preserved, because they do not correspond directly to a single object attribute.
In general, attribute and function names align with NumPy naming conventions. That is, shape is used in place of get_dimensions, dtype instead of get_type, nbytes instead of size, num instead of nb, and so on.
In most cases, it is enough to simply drop the get_ or set_ prefix:
Legacy API | New API |
builder.set_fp16_mode(True) | builder.fp16_mode = True |
int8_mode = builder.get_int8_mode() | int8_mode = builder.int8_mode |
builder.set_max_workspace_size(1 << 20) | builder.max_workspace_size = 1 << 20 |
In other cases, you can apply the aforementioned rules:
Legacy API | New API |
num_inputs = network.get_nb_inputs() | num_inputs = network.num_inputs |
input_shape = network.get_input(0).get_dimensions() | input_shape = network.get_input(0).shape |
input_type = network.get_input(0).get_type() | input_type = network.get_input(0).dtype |
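The naming rules above can be summarized as a small heuristic. The sketch below is only an illustration of those rules; the function attribute_name is hypothetical, and it cannot tell whether a given getter was preserved as a function (for example get_input, which takes an argument), so its output should be checked against the API reference.

```python
# Special cases named in the text above: legacy getter -> new attribute.
SPECIAL_CASES = {
    "get_dimensions": "shape",
    "get_type": "dtype",
}


def attribute_name(legacy_name):
    """Guess the new attribute name for a legacy getter or setter name."""
    if legacy_name in SPECIAL_CASES:
        return SPECIAL_CASES[legacy_name]
    name = legacy_name
    # Drop a leading get_ or set_ prefix, if present.
    for prefix in ("get_", "set_"):
        if name.startswith(prefix):
            name = name[len(prefix):]
            break
    # `nb` becomes `num`, e.g. get_nb_inputs -> num_inputs.
    if name.startswith("nb_"):
        name = "num_" + name[len("nb_"):]
    return name


guessed = attribute_name("get_nb_inputs")
```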
tensorrt.Dims, tensorrt.Permutation Behave Like Iterables
The tensorrt.Dims class and all its subclasses, as well as the tensorrt.Permutation class, are interchangeable with Python iterables (such as tuples and lists).
Where previously you would have written code like this:
# Legacy API
dims = network.get_input(0).get_dimensions().to_DimsCHW() # Need to explicitly convert to DimsCHW.
arr = np.random.randint(5, size=(dims.C(), dims.H(), dims.W())) # Cannot directly use Dims where iterables are expected.
dims2 = engine.get_binding_dimensions(1).to_DimsCHW()
# Compute volume. This will differ between different subclasses of Dims.
elt_count = dims2.vol() # This will not work for DimsHW or DimsNCHW, for example.
# Or equivalently,
elt_count = dims2.C() * dims2.H() * dims2.W()
It is now possible to write:
# New API
dims = network.get_input(0).shape # All Dims subclasses behave like iterables, so we don't care about the exact subclass.
arr = np.random.randint(5, size=dims) # Dims act exactly like any other iterable.
dims2 = engine.get_binding_shape(1)
# Compute volume. Works for any iterable, including lists and tuples.
elt_count = trt.volume(dims2)
# Or equivalently,
elt_count = functools.reduce(lambda x, y: x * y, dims2) # In Python 3, reduce lives in functools (import functools).
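Because the volume computation only relies on iteration, it can be tried out on ordinary tuples and lists. The following is a minimal sketch; the volume helper here is a local stand-in written for illustration (it is not trt.volume itself), and the shape values are made up.

```python
from functools import reduce
from operator import mul


def volume(shape):
    """Multiply together the entries of any iterable of dimensions."""
    # The initial value 1 is a choice made in this sketch for empty shapes.
    return reduce(mul, shape, 1)


chw_tuple = (3, 224, 224)  # Example shape; the values are illustrative.
chw_list = [3, 224, 224]
elt_count = volume(chw_tuple)
```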
Lightweight tensorrt.Weights
Previously, the tensorrt.Weights class would perform deep copies of any buffers used to create weights. To better align with the C++ API, and for the sake of efficiency, the new bindings no longer create these deep copies, but instead increment the reference count of the existing buffer. Therefore, modifying the buffer used to create a tensorrt.Weights object will also modify the tensorrt.Weights object.
Note that the tensorrt.ICudaEngine will still create its own copies of weights internally. The above only applies to tensorrt.Weights created before engine construction (when using the Network API, for example).
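The practical consequence of sharing a buffer instead of copying it can be shown with the standard library alone. The sketch below is an analogy, not TensorRT code: a memoryview plays the role of an object that refers to an existing buffer, and an explicit array copy plays the role of the old deep-copy behavior.

```python
from array import array

# A buffer of four float32 values standing in for layer weights.
buffer = array("f", [1.0, 2.0, 3.0, 4.0])

shared = memoryview(buffer)  # No copy: refers to the same memory as `buffer`.
copied = array("f", buffer)  # Deep copy: owns separate memory.

buffer[0] = 10.0             # Modify the original buffer in place.

shared_first = shared[0]     # The change is visible through the view.
copied_first = copied[0]     # The copy still holds the original value.
```

If weights must stay fixed while the source buffer is reused, making an explicit copy before handing the buffer over is one way to get the old behavior back.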
tensorrt.Weights Behave like NumPy Arrays
The tensorrt.Weights class is now interchangeable with NumPy arrays. More concretely, this means that NumPy arrays can be used directly wherever tensorrt.Weights objects are expected; no explicit conversion is required. Furthermore, the API now returns NumPy arrays instead of Weights objects.
tensorrt.CaffeParser Returns NumPy Arrays
The parse_binary_proto function of the tensorrt.CaffeParser now returns NumPy arrays instead of BinaryProtoBlob objects.
enqueue Is Now execute_async
To better match CUDA naming conventions, tensorrt.IExecutionContext’s enqueue function has been renamed to execute_async.
Keyword Arguments and Default Parameters
All functions and constructors in the new API now use named parameters. Additionally, they include default arguments where possible to better mirror the C++ API.
Where previously you had to pass function parameters in the correct order:
# Legacy API
context.enqueue(batch_size, bindings, stream.handle, None)
You can now use keyword arguments instead:
# New API
context.execute_async(stream_handle=stream.handle, bindings=bindings, batch_size=batch_size)
Serializing and Deserializing Engines
The tensorrt.IHostMemory class now implements Python’s buffer interface.
This means that serialized engines can be treated like any other Python buffer.
Serializing An Engine
# New API
with open("sample.engine", "wb") as f:
    f.write(engine.serialize())
Deserializing An Engine
# New API
with open("sample.engine", "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
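The file handling in the two snippets above works for any object that supports the buffer interface. The following is a minimal sketch of the same round trip without TensorRT; the memoryview stands in for the object returned by engine.serialize(), and its contents are made-up bytes rather than a real engine.

```python
import os
import tempfile

# Stand-in for the object returned by engine.serialize(); the bytes are made up.
serialized = memoryview(bytearray(b"\x00\x01fake-engine-bytes"))

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "sample.engine")

    # file.write() accepts any object that supports the buffer protocol.
    with open(path, "wb") as f:
        f.write(serialized)

    # Reading back yields bytes, which is what a deserializer would consume.
    with open(path, "rb") as f:
        restored = f.read()

round_trip_ok = restored == bytes(serialized)
```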
Migrating Plugins
Using Pybind11
SWIG-based wrappers are incompatible with TensorRT’s pybind11-based bindings. In order to migrate existing plugins, you need to write a pybind11-based wrapper for the plugin. Typically, this involves writing bindings for the plugin itself, as well as the PluginFactory. For more details, refer to the fc_plugin_caffe_mnist sample.
You can refer to the Pybind11 User Guide for detailed information on pybind11 usage.
Using the Plugin Registry
Alternatively, it is possible to bypass the need for bindings completely by using the plugin registry. This involves two steps:
- Build a shared object file for the plugin. Make sure the REGISTER_TENSORRT_PLUGIN macro is used in the plugin implementation.
- In Python, load the file created above using ctypes.CDLL().
At this point, you can add plugins to your network using the tensorrt.IPluginRegistry, which can be retrieved with tensorrt.get_plugin_registry(). For more details, refer to the uff_custom_plugin sample.
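The loading step can be sketched as follows. This is a minimal example under stated assumptions: the helper load_plugin_library and the file name libmyplugin.so are hypothetical, and the TensorRT calls are left as comments because they require an installed TensorRT and a built plugin.

```python
import ctypes
import os


def load_plugin_library(path):
    """Load a plugin shared object so that its registration code runs."""
    if not os.path.isfile(path):
        raise FileNotFoundError("Plugin library not found: " + path)
    # Loading the library runs its static initializers, which is where
    # REGISTER_TENSORRT_PLUGIN is expected to register the plugin.
    return ctypes.CDLL(path)


# Hypothetical usage; the file name below is an assumption.
# load_plugin_library("./libmyplugin.so")
# registry = trt.get_plugin_registry()
```

Checking for the file first gives a clearer error than the OSError that ctypes.CDLL raises for a missing library.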