Advanced Topics#

Version Compatibility#

By default, TensorRT engines are compatible only with the version of TensorRT with which they are built. With appropriate build-time configuration, engines can be built that are compatible with later TensorRT versions. TensorRT engines built with TensorRT 8 will also be compatible with TensorRT 9 and TensorRT 10 runtimes, but not vice versa. However, version-compatible engines may be slower than engines built for the default runtime.

Version compatibility is supported from version 8.6; the plan must be built with version 8.6 or later, and the runtime must be version 8.6 or later.

When using version compatibility, the API supported at runtime for an engine is the intersection of the API supported in the version with which it was built and the API of the version used to run it. TensorRT removes APIs only on major version boundaries, so this is not a concern within a major version. However, users wishing to use TensorRT 8 or TensorRT 9 engines with TensorRT 10 must migrate away from removed APIs and are advised to avoid the deprecated APIs.

The recommended approach to creating a version-compatible engine is to build as follows:

builderConfig.setFlag(BuilderFlag::kVERSION_COMPATIBLE);
IHostMemory* plan = builder->buildSerializedNetwork(network, config);

builder_config.set_flag(tensorrt.BuilderFlag.VERSION_COMPATIBLE)
plan = builder.build_serialized_network(network, config)

The request for a version-compatible engine causes a copy of the lean runtime to be added to the plan. When you deserialize the plan, TensorRT will recognize that it contains a runtime copy. It loads the runtime to deserialize and execute the rest of the plan. Because this results in code being loaded and run from the plan in the context of the owning process, you should only deserialize trusted plans this way. To indicate to TensorRT that you trust the plan, call:

runtime->setEngineHostCodeAllowed(true);

runtime.engine_host_code_allowed = True

The flag for trusted plans is also required if you are packaging plugins in the plan. For more information, refer to the Plugin Shared Libraries section.
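Putting these pieces together, a minimal Python sketch of building a version-compatible plan and later deserializing it as a trusted plan might look like the following; the ONNX path is a placeholder, and the network/config creation mirrors the snippets above:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)

# Build a version-compatible plan ("model.onnx" is a placeholder path).
builder = trt.Builder(logger)
network = builder.create_network(0)
parser = trt.OnnxParser(network, logger)
parser.parse_from_file("model.onnx")

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
plan = builder.build_serialized_network(network, config)

# Later, possibly with a newer TensorRT: mark the plan as trusted, then deserialize.
runtime = trt.Runtime(logger)
runtime.engine_host_code_allowed = True  # required because the plan carries the lean runtime
engine = runtime.deserialize_cuda_engine(plan)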

Manually Loading the Runtime#

The previous approach (Version Compatibility) packages a copy of the runtime with every plan, which can be prohibitive if your application uses many models. An alternative approach is to manage the runtime loading yourself. For this approach, build version-compatible plans as explained in the previous section, but also set an additional flag to exclude the lean runtime.

builderConfig.setFlag(BuilderFlag::kVERSION_COMPATIBLE);
builderConfig.setFlag(BuilderFlag::kEXCLUDE_LEAN_RUNTIME);
IHostMemory* plan = builder->buildSerializedNetwork(network, config);

builder_config.set_flag(tensorrt.BuilderFlag.VERSION_COMPATIBLE)
builder_config.set_flag(tensorrt.BuilderFlag.EXCLUDE_LEAN_RUNTIME)
plan = builder.build_serialized_network(network, config)

To run this plan, you must have access to the lean runtime for the version with which it was built. Suppose you have built the plan with TensorRT 8.6, and your application is linked against TensorRT 10. You can load the plan as follows.

IRuntime* v10Runtime = createInferRuntime(logger);
IRuntime* v8ShimRuntime = v10Runtime->loadRuntime(v8RuntimePath);
ICudaEngine* engine = v8ShimRuntime->deserializeCudaEngine(v8plan);

v10_runtime = tensorrt.Runtime(logger)
v8_shim_runtime = v10_runtime.load_runtime(v8_runtime_path)
engine = v8_shim_runtime.deserialize_cuda_engine(v8_plan)

The runtime will translate TensorRT 10 API calls for the TensorRT 8.6 runtime, checking to ensure that the call is supported and performing any necessary parameter remapping.

Loading from Storage#

TensorRT can load the shared runtime library directly from memory on most OSs. However, on Linux kernels before 3.17, a temporary directory is required. Use the IRuntime::setTempfileControlFlags and IRuntime::setTemporaryDirectory APIs to control TensorRT’s use of these mechanisms.
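A sketch of these controls in Python, assuming the snake_case properties tempfile_control_flags and temporary_directory and the TempfileControlFlag enum mirror the C++ APIs named above (treat the exact names as an assumption):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)

# Restrict TensorRT to in-memory loading only (no temporary files on disk);
# the enum member name is assumed to mirror the C++ TempfileControlFlag value.
runtime.tempfile_control_flags = 1 << int(trt.TempfileControlFlag.ALLOW_IN_MEMORY_FILES)

# Alternatively, allow temporary files but place them in a directory you control.
# runtime.tempfile_control_flags = 1 << int(trt.TempfileControlFlag.ALLOW_TEMPORARY_FILES)
# runtime.temporary_directory = "/tmp/my_app_trt"  # placeholder path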

Using Version Compatibility with the ONNX Parser#

When building a version-compatible engine from a TensorRT network definition generated using TensorRT’s ONNX parser, you must specify that the parser use the native InstanceNormalization implementation instead of the plugin implementation.

To do this, use the IParser::setFlag() API.

auto *parser = nvonnxparser::createParser(network, logger);
parser->setFlag(nvonnxparser::OnnxParserFlag::kNATIVE_INSTANCENORM);

parser = trt.OnnxParser(network, logger)
parser.set_flag(trt.OnnxParserFlag.NATIVE_INSTANCENORM)

In addition, the parser may require plugins to fully implement all ONNX operators used in the network. In particular, if the network is used to build a version-compatible engine, some plugins may need to be included (either serialized with the engine or provided externally and explicitly loaded).

To query the list of plugin libraries needed to implement a particular parsed network, use the IParser::getUsedVCPluginLibraries API:

auto *parser = nvonnxparser::createParser(network, logger);
parser->setFlag(nvonnxparser::OnnxParserFlag::kNATIVE_INSTANCENORM);
parser->parseFromFile(filename, static_cast<int>(ILogger::Severity::kINFO));
int64_t nbPluginLibs;
char const* const* pluginLibs = parser->getUsedVCPluginLibraries(nbPluginLibs);

parser = trt.OnnxParser(network, logger)
parser.set_flag(trt.OnnxParserFlag.NATIVE_INSTANCENORM)
status = parser.parse_from_file(filename)
plugin_libs = parser.get_used_vc_plugin_libraries()

Refer to the Plugin Shared Libraries section for instructions on using the resulting library list to serialize the plugins or package them externally.
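For example, the returned list can be forwarded to the builder configuration so the plugins are serialized with the version-compatible plan; this sketch assumes plugins_to_serialize is the Python counterpart of IBuilderConfig::setPluginsToSerialize, and "model.onnx" is a placeholder path:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(0)

parser = trt.OnnxParser(network, logger)
parser.set_flag(trt.OnnxParserFlag.NATIVE_INSTANCENORM)
parser.parse_from_file("model.onnx")

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)

# Package any VC plugin libraries the parser reports into the plan itself.
plugin_libs = parser.get_used_vc_plugin_libraries()
if plugin_libs:
    config.plugins_to_serialize = plugin_libs  # assumed Python property name

plan = builder.build_serialized_network(network, config)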

Hardware Compatibility#

By default, TensorRT engines are only compatible with the type of device where they were built. With build-time configuration, engines that are compatible with other types of devices can be built. Currently, hardware compatibility is supported only for Ampere and later device architectures and is not supported on NVIDIA DRIVE OS or JetPack.

For example, to build an engine compatible with all Ampere and newer architectures, configure the IBuilderConfig as follows:

config->setHardwareCompatibilityLevel(nvinfer1::HardwareCompatibilityLevel::kAMPERE_PLUS);
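The equivalent setting in Python, as a minimal sketch using the standard tensorrt bindings:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
config = builder.create_builder_config()

# Allow the resulting engine to run on Ampere and all later GPU architectures.
config.hardware_compatibility_level = trt.HardwareCompatibilityLevel.AMPERE_PLUS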

When building in hardware compatibility mode, TensorRT excludes tactics that are not hardware compatible, such as those that use architecture-specific instructions or require more shared memory than is available on some devices. Thus, a hardware-compatible engine may have lower throughput and/or higher latency than its non-hardware-compatible counterpart. The degree of this performance impact depends on the network architecture and input sizes.

Compatibility Checks#

TensorRT records the major, minor, patch, and build versions of the library used to create the plan in the plan. If these do not match the runtime version used to deserialize the plan, it will fail to deserialize. When using version compatibility, the check will be performed by the lean runtime deserializing the plan data. By default, that lean runtime is included in the plan, and the match is guaranteed to succeed.

TensorRT also records the compute capability (major and minor versions) in the plan and checks it against the GPU on which the plan is being loaded. If they do not match, the plan will fail to deserialize. This ensures that kernels selected during the build phase are present and can run. When using hardware compatibility, the check is relaxed; with HardwareCompatibilityLevel::kAMPERE_PLUS, the check will ensure that the compute capability is greater than or equal to 8.0 rather than checking for an exact match.

TensorRT additionally checks the following properties and will issue a warning if they do not match, except when using hardware compatibility:

  • Global memory bus width

  • L2 cache size

  • Maximum shared memory per block and multiprocessor

  • Texture alignment requirement

  • Number of multiprocessors

  • Whether the GPU device is integrated or discrete

If GPU clock speeds differ between engine serialization and runtime systems, the tactics chosen by the serialization system may not be optimal for the runtime system and may incur some performance degradation.

If it is impossible to build a TensorRT engine for each type of GPU, you can select several GPUs to build engines with and run the engine on different GPUs with the same architecture. For example, among the NVIDIA RTX 40xx GPUs, you can build an engine with RTX 4080 and an engine with RTX 4060. At runtime, you can use the RTX 4080 engine on an RTX 4090 GPU and the 4060 engine on an RTX 4070 GPU. In most cases, the engine will run without functional issues and with only a small performance drop compared to running the engine built with the same GPU.

However, deserialization may fail if the engine requires a large amount of device memory and the available memory is smaller than it was when the engine was built. In this case, it is recommended to build the engine on the smaller GPU or to build it on the larger device with limited compute resources.

The safety runtime can deserialize engines generated in an environment where the major, minor, patch, and build versions of TensorRT do not match exactly in some cases. For more information, refer to the NVIDIA DRIVE OS 6.5 Developer Guide.

Refitting an Engine#

TensorRT can refit an engine with new weights without having to rebuild it. However, the option to do so must be specified when building:

...
config->setFlag(BuilderFlag::kREFIT);
builder->buildSerializedNetwork(network, config);

Later, you can create a Refitter object:

ICudaEngine* engine = ...;
IRefitter* refitter = createInferRefitter(*engine, gLogger);

Then, update the weights. For example, to update a set of weights named Conv Layer Kernel Weights:

Weights newWeights = ...;
refitter->setNamedWeights("Conv Layer Kernel Weight",
                    newWeights);

The new weights should have the same count as the original weights used to build the engine. setNamedWeights returns false if something goes wrong, such as a wrong weights name or a change in the weights count.

You can use INetworkDefinition::setWeightsName() to name weights at build time - the ONNX parser uses this API to associate the weights with the names used in the ONNX model. Otherwise, TensorRT will name the weights internally based on the related layer names and weight roles.

You can also pass GPU weights to the refitter via:

Weights newBiasWeights = ...;
refitter->setNamedWeights("Conv Layer Bias Weight", newBiasWeights, TensorLocation::kDEVICE);

Because of how the engine is optimized, if you change some weights, you might have to supply some other weights, too. The interface can tell you what additional weights must be supplied.

This typically requires two calls to IRefitter::getMissingWeights, first to get the number of weights objects that must be supplied, and second to get their names.

int32_t const n = refitter->getMissingWeights(0, nullptr);
std::vector<const char*> weightsNames(n);
refitter->getMissingWeights(n, weightsNames.data());

You can supply the missing weights in any order:

for (int32_t i = 0; i < n; ++i)
    refitter->setNamedWeights(weightsNames[i], Weights{...});

The set of missing weights returned is complete, in the sense that supplying only the missing weights does not generate a need for any additional weights.

Once all the weights have been provided, you can update the engine:

bool success = refitter->refitCudaEngine();
assert(success);

If the refit returns false, check the log for a diagnostic; perhaps the issue is about weights that are still missing. There is also an async version, refitCudaEngineAsync, that can accept a stream parameter.

You can update the weights memory directly and then call refitCudaEngine or refitCudaEngineAsync in another iteration. If weights pointers need to be changed, call setNamedWeights to override the previous setting. Call unsetNamedWeights to unset previously set weights so that they will not be used in later refitting, and it becomes safe to release these weights.

After refitting is done, you can then delete the refitter:

delete refitter;

The engine behaves like it was built from a network updated with the new weights. After refitting the engine, the previously created execution context can continue to be used.

To view all refittable weights in an engine, use refitter->getAllWeights(...), which is similar to how getMissingWeights was used above.
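A compact Python sketch of the refit flow described above; the engine path, weight names, and NumPy arrays are placeholders, and the engine is assumed to have been built with BuilderFlag.REFIT:

import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)

with open("refittable.engine", "rb") as f:  # placeholder path
    engine = runtime.deserialize_cuda_engine(f.read())

refitter = trt.Refitter(engine, logger)

# Update one set of weights; the name and shape are placeholders.
new_kernel = np.random.rand(64, 3, 3, 3).astype(np.float32)
assert refitter.set_named_weights("Conv Layer Kernel Weights", trt.Weights(new_kernel))

# Supply whatever else the optimizer now requires.
my_weight_lookup = {}  # placeholder: weight name -> numpy array with the original count
for name in refitter.get_missing_weights():
    refitter.set_named_weights(name, trt.Weights(my_weight_lookup[name]))

assert refitter.refit_cuda_engine()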

Weight-Stripping#

When refit is enabled, all the constant weights in the network can be updated after the engine is built. However, refitting the engine with new weights introduces a cost and a potential runtime impact. The inability to constant-fold weights may prevent the builder from performing some optimizations.

This cost is unavoidable when the weights with which the engine will be refitted are unknown at build time. However, in some scenarios, the weights are known. For example, you may use TensorRT as one of multiple back ends to execute an ONNX model and wish to avoid an additional copy of weights in the TensorRT plan.

The weight-stripping build configuration enables this scenario; when enabled, TensorRT enables refit only for constant weights that do not impact the builder’s ability to optimize and produce an engine with the same runtime performance as a non-refittable engine. Those weights are then omitted from the serialized engine, resulting in a small plan file that can be refitted at runtime using the weights from the ONNX model.

The trtexec tool provides the --stripWeights flag for building the weight-stripped engine. For more information, refer to the trtexec section.

The following steps show how to refit the weights for weight-stripped engines. When working with ONNX models, the ONNX parser library can perform the refit automatically. For more information, refer to the Refitting a Weight-Stripped Engine Directly from ONNX section.

  1. Set the corresponding builder flag to enable the weight-stripped build. Here, the kSTRIP_PLAN flag works with either kREFIT or kREFIT_IDENTICAL. It defaults to the latter. The kREFIT_IDENTICAL flag instructs the TensorRT builder to optimize under the assumption that the engine will be refitted with weights identical to those provided at build time. The kSTRIP_PLAN flag minimizes plan size by stripping out the refittable weights.

...
config->setFlag(BuilderFlag::kSTRIP_PLAN);
config->setFlag(BuilderFlag::kREFIT_IDENTICAL);
builder->buildSerializedNetwork(network, config);

config.flags |= 1 << int(trt.BuilderFlag.STRIP_PLAN)
config.flags |= 1 << int(trt.BuilderFlag.REFIT_IDENTICAL)
builder.build_serialized_network(network, config)
  2. After the engine is built, save the plan file and distribute it to the installer.

  3. On the client side, when you launch the network for the first time, update all the weights in the engine. Since all the weights in the engine plan were removed, use the getAllWeights API.

int32_t const n = refitter->getAllWeights(0, nullptr);

all_weights = refitter.get_all_weights()
  4. Update the weights one by one.

for (int32_t i = 0; i < n; ++i)
    refitter->setNamedWeights(weightsNames[i], Weights{...});

for name in wts_list:
    refitter.set_named_weights(name, weights[name])
  5. Save the full engine plan file.

auto serializationConfig = SampleUniquePtr<nvinfer1::ISerializationConfig>(cudaEngine->createSerializationConfig());
auto serializationFlag = serializationConfig->getFlags();
serializationFlag &= ~(1 << static_cast<uint32_t>(nvinfer1::SerializationFlag::kEXCLUDE_WEIGHTS));
serializationConfig->setFlags(serializationFlag);
auto hostMemory = SampleUniquePtr<nvinfer1::IHostMemory>(cudaEngine->serializeWithConfig(*serializationConfig));

serialization_config = engine.create_serialization_config()
serialization_config.flags &= ~(1 << int(trt.SerializationFlag.EXCLUDE_WEIGHTS))
binary = engine.serialize_with_config(serialization_config)

The application can now use the new full engine plan file for future inference.
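The client-side steps above can be combined into one short Python sketch; the file paths and the weight source are placeholders, and the stripped plan is assumed to have been built with the flags shown in step 1:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)

with open("stripped.plan", "rb") as f:  # placeholder path
    engine = runtime.deserialize_cuda_engine(f.read())

refitter = trt.Refitter(engine, logger)

# weight_source is a placeholder for wherever the original weights live on the client.
weight_source = {}
for name in refitter.get_all_weights():
    refitter.set_named_weights(name, weight_source[name])
assert refitter.refit_cuda_engine()

# Re-serialize a full plan (weights included) for future runs.
serialization_config = engine.create_serialization_config()
serialization_config.flags &= ~(1 << int(trt.SerializationFlag.EXCLUDE_WEIGHTS))
with open("full.plan", "wb") as f:
    f.write(engine.serialize_with_config(serialization_config))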

Refitting a Weight-Stripped Engine Directly from ONNX#

When working with weight-stripped engines created from ONNX models, the refit process can be done automatically with the IParserRefitter class from the ONNX parser library. The following steps show how to create the class and run the refit process.

  1. Create your engine as described in Weight-Stripping, and create an IRefitter object.

IRefitter* refitter = createInferRefitter(*engine, gLogger);

refitter = trt.Refitter(engine, TRT_LOGGER)
  2. Create an IParserRefitter object.

IParserRefitter* parserRefitter = createParserRefitter(*refitter, gLogger);

parser_refitter = trt.OnnxParserRefitter(refitter, TRT_LOGGER)
  3. Call the refitFromFile() function of the IParserRefitter. Ensure that the ONNX model is identical to the one used to create the weight-stripped engine. This function will return true if all the stripped weights are found in the ONNX model; otherwise, it will return false.

bool result = parserRefitter->refitFromFile("path_to_onnx_model");

result = parser_refitter.refit_from_file("path_to_onnx_model")
  4. Call the refit function of the IRefitter to complete the refit process.

refitSuccess = refitter->refitCudaEngine();

refit_success = refitter.refit_cuda_engine()
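Putting the four steps above together, a minimal Python sketch (the engine and ONNX paths are placeholders):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)

with open("stripped.plan", "rb") as f:  # placeholder path
    engine = runtime.deserialize_cuda_engine(f.read())

refitter = trt.Refitter(engine, logger)
parser_refitter = trt.OnnxParserRefitter(refitter, logger)

# The ONNX file must be identical to the one the stripped engine was built from.
assert parser_refitter.refit_from_file("model.onnx")
assert refitter.refit_cuda_engine()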

Weight-Stripping with the Lean Runtime#

Additionally, you can leverage the lean runtime to further reduce the package size for the weight-stripped engine. The lean runtime is the same runtime used in version-compatible engines. Its original purpose is to allow you to generate a TensorRT engine with version X and load it with an application built with version Y. The lean runtime library is relatively small, approximately 40 MiB. Therefore, software distributors on top of TensorRT only need to ship the weight-stripped engine along with the 40 MiB lean runtime when the weights are already available on the target customer machine.

The recommended approach to build the engine is as follows:

builderConfig.setFlag(BuilderFlag::kVERSION_COMPATIBLE);
builderConfig.setFlag(BuilderFlag::kEXCLUDE_LEAN_RUNTIME);
builderConfig.setFlag(BuilderFlag::kSTRIP_PLAN);
IHostMemory* plan = builder->buildSerializedNetwork(network, config);

builder_config.set_flag(tensorrt.BuilderFlag.VERSION_COMPATIBLE)
builder_config.set_flag(tensorrt.BuilderFlag.EXCLUDE_LEAN_RUNTIME)
builder_config.set_flag(tensorrt.BuilderFlag.STRIP_PLAN)
plan = builder.build_serialized_network(network, config)

Load the engine with the shared lean runtime library path:

runtime->loadRuntime("your_lean_runtime_full_path");

runtime.load_runtime("your_lean_runtime_full_path")

For more information about the lean runtime, refer to the Version Compatibility section.

Fine-Grained Refit Build#

When using the kREFIT builder configuration, all weights are marked as refittable. This is useful when it is difficult to distinguish between trainable and untrainable weights. However, marking all weights as refittable can lead to a performance trade-off. This is because certain optimizations are broken when weights are marked as refittable. For example, in the case of the GELU expression, TensorRT can encode all GELU coefficients in a single CUDA kernel. However, if all coefficients are marked as refittable, TensorRT may no longer be able to fuse the Conv-GELU operations into a single kernel. To address this, we have introduced the fine-grained refit API. This API provides precise control over which weights are marked as refittable, allowing for more efficient optimization.

Here is an example of marking weights as refittable in the INetworkDefinition:

...
network->setWeightsName(Weights(weights), "conv1_filter");
network->markWeightsRefittable("conv1_filter");
assert(network->areWeightsMarkedRefittable("conv1_filter"));

...
network.set_weights_name(conv_filter, "conv1_filter")
network.mark_weights_refittable("conv1_filter")
assert network.are_weights_marked_refittable("conv1_filter")

Later, we need to update the builder configuration like this:

...
config->setFlag(BuilderFlag::kREFIT_INDIVIDUAL);
builder->buildSerializedNetwork(network, config);

...
config.set_flag(trt.BuilderFlag.REFIT_INDIVIDUAL)
builder.build_serialized_network(network, config)

The remaining refit code follows the same steps as the refit-all-weights workflow.

Stripping Weights with Fine-Grained Refit Build#

The fine-grained refit build also works with the weights stripping flag. To run this, we must enable both builder flags in the code after marking the necessary weights as refittable.

Here is an example:

...
config->setFlag(BuilderFlag::kSTRIP_PLAN);
config->setFlag(BuilderFlag::kREFIT_INDIVIDUAL);
builder->buildSerializedNetwork(network, config);

config.flags |= 1 << int(trt.BuilderFlag.STRIP_PLAN)
config.flags |= 1 << int(trt.BuilderFlag.REFIT_INDIVIDUAL)
builder.build_serialized_network(network, config)

The remaining refit and inference code is the same as in the Weight-Stripping sections.

Algorithm Selection and Reproducible Builds#

The default behavior of TensorRT’s optimizer is to choose the algorithms that globally minimize the execution time of the engine. It does this by timing each implementation, and sometimes, when implementations have similar timings, system noise may determine which will be chosen on any particular run of the builder. Different implementations will typically use different orders of accumulation of floating point values, and two implementations may use different algorithms or even run at different precisions. Thus, different invocations of the builder will typically not result in engines that return bit-identical results.

Sometimes, it is important to have a deterministic build or to recreate an earlier build’s algorithm choices. In previous versions of TensorRT, these requirements were met by implementing IAlgorithmSelector. In current versions, the editable timing cache is used instead.

When the engine is being built for the first time, you supply the BuilderFlag::kEDITABLE_TIMING_CACHE flag to TensorRT to enable the editable cache. At the same time, you enable and retain the logs and cache files. The logs will provide the name, key, available tactics, and the selected tactic for each model layer. The cache file will record the decisions made by TensorRT.

The next time the same engine is built, supply the same flag to TensorRT and use the ITimingCache::update interface to edit the cache, selecting tactics for specific layers. Then, pass the cache to TensorRT. During the build, TensorRT will use the newly assigned tactics. Unlike the earlier mechanism, only one tactic can be assigned to each layer.
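A sketch of the save-and-reuse part of this workflow in Python, assuming the flag name mirrors BuilderFlag::kEDITABLE_TIMING_CACHE; the tactic-editing step itself (ITimingCache::update) is only indicated by a comment because its arguments depend on the layer keys and tactic IDs recorded in your logs:

import tensorrt as trt

logger = trt.Logger(trt.Logger.VERBOSE)  # verbose logs record layer keys and chosen tactics
builder = trt.Builder(logger)
network = builder.create_network(0)
inp = network.add_input("x", trt.float32, (1, 3, 8, 8))
relu = network.add_activation(inp, trt.ActivationType.RELU)
network.mark_output(relu.get_output(0))

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.EDITABLE_TIMING_CACHE)

# First build: start from an empty cache and save it afterwards.
cache = config.create_timing_cache(b"")
config.set_timing_cache(cache, ignore_mismatch=False)
plan = builder.build_serialized_network(network, config)
with open("timing.cache", "wb") as f:
    f.write(config.get_timing_cache().serialize())

# Later builds: reload the cache, optionally edit tactic choices for selected
# layers via the ITimingCache update interface, and build again deterministically.
with open("timing.cache", "rb") as f:
    cache = config.create_timing_cache(f.read())
config.set_timing_cache(cache, ignore_mismatch=False)
plan = builder.build_serialized_network(network, config)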

Strongly Typed Networks#

By default, TensorRT autotunes tensor types to generate the fastest engine. This can result in accuracy loss when model accuracy requires a layer to run with higher precision than TensorRT chooses. One approach is to use the ILayer::setPrecision and ILayer::setOutputType APIs to control a layer’s I/O types and, hence, its execution precision. This approach works, but figuring out which layers must be run at high precision to get the best accuracy can be challenging.

An alternative approach is to specify low precision use in the model, such as Automatic mixed precision training or quantization-aware training, and have TensorRT adhere to the precision specifications. TensorRT will still autotune over different data layouts to find an optimal set of kernels for the network.

When you specify to TensorRT that a network is strongly typed, it infers a type for each intermediate and output tensor using the rules in the operator type specification. Inferred types are adhered to while building the engine. As types are not autotuned, an engine built from a strongly typed network can be slower than one where TensorRT chooses tensor types. On the other hand, the build time may improve as fewer kernel alternatives are evaluated.

Strongly typed networks are not supported with DLA.

You can create a strongly typed network as follows:

IBuilder* builder = ...;
INetworkDefinition* network = builder->createNetworkV2(1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kSTRONGLY_TYPED));

builder = trt.Builder(...)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED))

For strongly typed networks, the layer APIs setPrecision and setOutputType are not permitted, nor are the builder precision flags kFP16, kBF16, kFP8, kINT8, kINT4, and kFP4. The builder flag kTF32 is permitted as it controls TF32 Tensor Core usage for FP32 types rather than controlling the use of TF32 data types.

Reduced Precision in Weakly-Typed Networks#

Network-Level Control of Precision#

By default, TensorRT works with 32-bit precision but can also execute operations using 16-bit and 8-bit quantized floating points. Using lower precision requires less memory and enables faster computation.

Reduced precision support depends on your hardware (refer to Hardware and Precision). You can query the builder to check the precision support on a platform:

if (builder->platformHasFastFp16()) { ... };

if builder.platform_has_fast_fp16:
    ...

Setting flags in the builder configuration informs TensorRT that it may select lower-precision implementations:

config->setFlag(BuilderFlag::kFP16);

config.set_flag(trt.BuilderFlag.FP16)

There are three precision flags: FP16, INT8, and TF32, and they may be enabled independently. TensorRT will still choose a higher-precision kernel if it results in a lower runtime or if no low-precision implementation exists.

When TensorRT chooses a precision for a layer, it automatically converts weights as necessary to run the layer.

While using FP16 and TF32 precisions is relatively straightforward, working with INT8 adds additional complexity. For more information, refer to the Working with Quantized Types section.

Note that even if the precision flags are enabled, the engine’s input/output bindings default to FP32. Refer to the I/O Formats section for information on how to set the data types and formats of the input/output bindings.

Layer-Level Control of Precision#

The builder flags provide permissive, coarse-grained control. However, sometimes, part of a network requires a higher dynamic range or is sensitive to numerical precision. You can constrain the input and output types per layer:

layer->setPrecision(DataType::kFP16);

layer.precision = trt.float16

This provides a preferred type (here, DataType::kFP16) for the inputs and outputs.

You may further set preferred types for the layer’s outputs:

layer->setOutputType(out_tensor_index, DataType::kFLOAT);

layer.set_output_type(out_tensor_index, trt.float32)

The computation will use the same floating-point type as the inputs. Most TensorRT implementations have the same floating-point types for input and output; however, Convolution, Deconvolution, and FullyConnected can support quantized INT8 input and unquantized FP16 or FP32 output, as sometimes working with higher-precision outputs from quantized inputs is necessary to preserve accuracy.

Setting the precision constraint hints to TensorRT that it should select a layer implementation whose inputs and outputs match the preferred types, inserting reformat operations if the outputs of the previous layer and the inputs to the next layer do not match the requested types. Note that TensorRT will only be able to select an implementation with these types if they are also enabled using the flags in the builder configuration.

By default, TensorRT chooses such an implementation only if it results in a higher-performance network. If another implementation is faster, TensorRT will use it and issue a warning. You can override this behavior by preferring the type constraints in the builder configuration.

config->setFlag(BuilderFlag::kPREFER_PRECISION_CONSTRAINTS);

config.set_flag(trt.BuilderFlag.PREFER_PRECISION_CONSTRAINTS)

If the constraints are preferred, TensorRT obeys them unless there is no implementation with the preferred precision constraints, in which case it issues a warning and uses the fastest available implementation.

To change the warning to an error, use OBEY instead of PREFER:

config->setFlag(BuilderFlag::kOBEY_PRECISION_CONSTRAINTS);

config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)

sampleINT8API illustrates the use of reduced precision with these APIs.

Precision constraints are optional - you can query whether a constraint has been set using layer->precisionIsSet() in C++ or layer.precision_is_set in Python. If a precision constraint is not set, the result returned from layer->getPrecision() in C++ or reading the precision attribute in Python is not meaningful. Output type constraints are similarly optional.

If no constraints are set using the ILayer::setPrecision or ILayer::setOutputType APIs, then BuilderFlag::kPREFER_PRECISION_CONSTRAINTS and BuilderFlag::kOBEY_PRECISION_CONSTRAINTS are ignored, and the layer is free to choose precision and output types from the allowed builder precisions.

Note that the ITensor::setType() API does not set the precision constraint of a tensor unless it is one of the input/output tensors of the network. Also, there is a distinction between layer->setOutputType() and layer->getOutput(i)->setType(). The former is an optional type constraining the implementation TensorRT will choose for a layer. The latter specifies the type of a network’s input/output and is ignored if the tensor is not a network input/output. If they are different, TensorRT will insert a cast to ensure that both specifications are respected. Thus, if you call setOutputType() for a layer that produces a network output, you should generally configure the corresponding network output to have the same type.
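A small self-contained Python sketch tying these pieces together: a layer is constrained to FP16, the corresponding network output type is set to match, and the constraints are made mandatory (the network itself is a trivial placeholder):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(0)

inp = network.add_input("x", trt.float32, (1, 3, 224, 224))
relu = network.add_activation(inp, trt.ActivationType.RELU)

# Constrain the layer to run in FP16 and produce an FP16 output ...
relu.precision = trt.float16
relu.set_output_type(0, trt.float16)

# ... and keep the network output type consistent with the layer's output type.
out = relu.get_output(0)
network.mark_output(out)
out.dtype = trt.float16

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)  # the constrained type must also be enabled here
config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
plan = builder.build_serialized_network(network, config)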

TF32#

TensorRT allows the use of TF32 Tensor Cores by default. When computing inner products, such as during convolution or matrix multiplication, TF32 execution does the following:

  • Rounds the FP32 multiplicands to FP16 precision but keeps the FP32 dynamic range.

  • Computes an exact product of the rounded multiplicands.

  • Accumulates the products in an FP32 sum.

TF32 Tensor Cores can speed up networks using FP32, typically with no loss of accuracy. It is more robust than FP16 for models that require an HDR (high dynamic range) for weights or activations.

There is no guarantee that TF32 Tensor Cores are used, and there is no way to force the implementation to use them - TensorRT can fall back to FP32 at any time and always falls back if the platform does not support TF32. However, you can disable their use by clearing the TF32 builder flag.

config->clearFlag(BuilderFlag::kTF32);

config.clear_flag(trt.BuilderFlag.TF32)

Setting the environment variable NVIDIA_TF32_OVERRIDE=0 when building an engine disables the use of TF32 despite setting BuilderFlag::kTF32. When set to 0, this environment variable overrides any defaults or programmatic configuration of NVIDIA libraries, so they never accelerate FP32 computations with TF32 Tensor Cores. This is meant to be a debugging tool only, and no code outside NVIDIA libraries should change the behavior based on this environment variable. Any other setting besides 0 is reserved for future use.

Warning

Setting the environment variable NVIDIA_TF32_OVERRIDE to a different value when running the engine can cause unpredictable precision/performance effects. It is best left unset when an engine is run.

Note

Unless your application requires the higher dynamic range provided by TF32, FP16 will be a better solution since it almost always yields faster performance.

BF16#

TensorRT supports the bfloat16 (brain float) floating point format on NVIDIA Ampere and later architectures. Like other precisions, it can be enabled using the corresponding builder flag:

config->setFlag(BuilderFlag::kBF16);

config.set_flag(trt.BuilderFlag.BF16)

Note that not all layers support bfloat16. For more information, refer to the TensorRT Operator documentation.

Control of Computational Precision#

Sometimes, it is desirable to control the internal precision of the computation in addition to setting the input and output precisions for an operator. TensorRT selects the computational precision by default based on the layer input type and global performance considerations.

There are two layers where TensorRT provides additional capabilities to control computational precision:

The INormalizationLayer provides a setPrecision method to control the precision of accumulation. By default, to avoid overflow errors, TensorRT accumulates in FP32, even in mixed precision mode, regardless of builder flags. You can use this method to specify FP16 accumulation instead.

For the IMatrixMultiplyLayer, TensorRT, by default, selects accumulation precision based on the input types and performance considerations. However, the accumulation type is guaranteed to have a range at least as great as the input types. When using strongly typed mode, you can enforce FP32 precision for FP16 GEMMs by casting the inputs to FP32. TensorRT recognizes this pattern and fuses the casts with the GEMM, resulting in a single kernel with FP16 inputs and FP32 accumulation.

Creating a Graph for FP32 Accumulation Request
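A sketch of the cast pattern described for IMatrixMultiplyLayer, in a strongly typed network: FP16 inputs are cast to FP32 before the GEMM, which TensorRT can recognize and fuse into an FP16-input, FP32-accumulation kernel. Shapes and tensor names are placeholders:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED))

a = network.add_input("a", trt.float16, (32, 64))
b = network.add_input("b", trt.float16, (64, 128))

# Cast the FP16 operands to FP32 so the GEMM accumulates in FP32.
a32 = network.add_cast(a, trt.float32).get_output(0)
b32 = network.add_cast(b, trt.float32).get_output(0)

mm = network.add_matrix_multiply(a32, trt.MatrixOperation.NONE,
                                 b32, trt.MatrixOperation.NONE)

# Cast back to FP16 if a half-precision result is desired downstream.
out = network.add_cast(mm.get_output(0), trt.float16).get_output(0)
network.mark_output(out)

config = builder.create_builder_config()
plan = builder.build_serialized_network(network, config)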

I/O Formats#

TensorRT optimizes a network using many different data formats. To allow efficient data passing between TensorRT and a client application, these underlying data formats are exposed at network I/O boundaries, for Tensors marked as network input or output, and when passing data to and from plugins. For other tensors, TensorRT picks formats that result in the fastest overall execution and may insert reformats to improve performance.

You can assemble an optimal data pipeline by profiling the available I/O formats in combination with the formats most efficient for the operations preceding and following TensorRT.

To specify I/O formats, set one or more formats as a bitmask.

The following example sets the input tensor format to TensorFormat::kHWC8. Note that this format only works for DataType::kHALF, so the data type must be set accordingly.

auto formats = 1U << TensorFormat::kHWC8;
network->getInput(0)->setAllowedFormats(formats);
network->getInput(0)->setType(DataType::kHALF);

formats = 1 << int(tensorrt.TensorFormat.HWC8)
network.get_input(0).allowed_formats = formats
network.get_input(0).dtype = tensorrt.DataType.HALF

Note that calling setAllowedFormats() or setType() on a tensor that is not a network input/output has no effect and is ignored by TensorRT.

sampleIOFormats illustrates how to specify I/O formats using C++.

The following table shows the supported formats.

Supported I/O Formats#

| Format | kINT32 | kFLOAT | kHALF | kINT8 | kBOOL | kUINT8 | kINT64 | BF16 | FP8 | FP4/INT4 |
|---|---|---|---|---|---|---|---|---|---|---|
| kLINEAR | Only for GPU | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| kCHW2 | No | No | Only for GPU | No | No | No | No | Yes | No | No |
| kCHW4 | No | No | Yes | Yes | No | No | No | Yes | No | No |
| kHWC8 | No | No | Only for GPU | No | No | No | No | Only for GPU | No | No |
| kCHW16 | No | No | Yes | No | No | No | No | No | No | No |
| kCHW32 | No | Only for GPU | Only for GPU | Yes | No | No | No | No | No | No |
| kDHWC8 | No | No | Only for GPU | No | No | No | No | Only for GPU | No | No |
| kCDHW32 | No | No | Only for GPU | Only for GPU | No | No | No | No | No | No |
| kHWC | No | Only for GPU | No | No | No | Yes | No | No | No | No |
| kDLA_LINEAR | No | No | Only for DLA | Only for DLA | No | No | No | No | No | No |
| kDLA_HWC4 | No | No | Only for DLA | Only for DLA | No | No | No | No | No | No |
| kHWC16 | No | No | Only for NVIDIA Ampere GPUs and later | Only for GPU | No | No | No | No | Only for GPU | No |
| kDHWC | No | Only for GPU | No | No | No | No | No | No | No | No |

Note that for the vectorized formats, the channel dimension must be zero-padded to the multiple of the vector size. For example, if an input binding has dimensions of [16,3,224,224], kHALF data type, and kHWC8 format, then the actual required size of the binding buffer would be 16*8*224*224*sizeof(half) bytes, even though the engine->getBindingDimension() API will return tensor dimensions as [16,3,224,224]. The values in the padded part (that is, where C=3,4,…,7 in this example) must be filled with zeros.
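The padding rule can be made concrete with a small helper; this sketch simply reproduces the arithmetic from the example above for vectorized formats:

import numpy as np

def vectorized_buffer_size(shape, channel_axis, vector_size, dtype):
    """Size in bytes of an I/O buffer whose channel axis is padded to the vector size."""
    padded = list(shape)
    padded[channel_axis] = -(-shape[channel_axis] // vector_size) * vector_size  # round up
    return int(np.prod(padded)) * np.dtype(dtype).itemsize

# [16,3,224,224] in kHWC8/kHALF: C is padded from 3 to 8,
# so the buffer needs 16*8*224*224*2 bytes.
print(vectorized_buffer_size((16, 3, 224, 224), channel_axis=1, vector_size=8, dtype=np.float16))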

Refer to the Data Format Descriptions section for how the data are laid out in memory for these formats.

Sparsity#

NVIDIA Ampere Architecture GPUs support Structured Sparsity. To use this feature and achieve higher inference performance, the weights must have at least two zeros in every four-entry vector. For TensorRT, the requirements are:

  • For Convolution, for each output channel and each spatial pixel in the kernel weights, every four input channels must have at least two zeros. In other words, assuming that the kernel weights have the shape [K, C, R, S] and C % 4 == 0, then the requirement is verified using the following algorithm:

    hasSparseWeights = True
    for k in range(0, K):
        for r in range(0, R):
            for s in range(0, S):
                for c_packed in range(0, C // 4):
                    if numpy.count_nonzero(weights[k, c_packed*4:(c_packed+1)*4, r, s]) > 2:
                        hasSparseWeights = False
    
  • For MatrixMultiply, of which Constant produces an input, every four elements of the reduction axis (K) must have at least two zeros.

Polygraphy (polygraphy inspect sparsity) can detect whether the operation weights in an ONNX model follow the 2:4 structured sparsity pattern.
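For illustration, here is a naive NumPy check-and-prune sketch for convolution weights: it zeros the two smallest-magnitude values in every group of four input channels so the 2:4 pattern is satisfied. This is not the Automatic Sparsity (ASP) tool and does nothing to preserve accuracy:

import numpy as np

def prune_2_4(weights):
    """Force [K, C, R, S] convolution weights into the 2:4 structured-sparsity pattern
    by zeroing the two smallest-magnitude values in every group of four input channels."""
    K, C, R, S = weights.shape
    assert C % 4 == 0
    groups = weights.transpose(0, 2, 3, 1).reshape(-1, 4).copy()  # each row is 4 input channels
    drop = np.argsort(np.abs(groups), axis=1)[:, :2]              # two smallest per row
    np.put_along_axis(groups, drop, 0.0, axis=1)
    return groups.reshape(K, R, S, C).transpose(0, 3, 1, 2)

w = np.random.randn(8, 16, 3, 3).astype(np.float32)
w_sparse = prune_2_4(w)
rows = w_sparse.transpose(0, 2, 3, 1).reshape(-1, 4)
assert (np.count_nonzero(rows, axis=1) <= 2).all()  # every 4-channel group has >= 2 zeros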

To enable the sparsity feature, set the kSPARSE_WEIGHTS flag in the builder config and make sure that kFP16 or kINT8 modes are enabled. For example:

config->setFlag(BuilderFlag::kSPARSE_WEIGHTS);
config->setFlag(BuilderFlag::kFP16);
config->setFlag(BuilderFlag::kINT8);

config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)
config.set_flag(trt.BuilderFlag.FP16)
config.set_flag(trt.BuilderFlag.INT8)

At the end of the TensorRT logs, when the TensorRT engine is built, TensorRT reports which layers contain weights that meet the structured sparsity requirement and in which layers TensorRT selected tactics that use structured sparsity. Sometimes, tactics with structured sparsity can be slower than normal tactics, in which case TensorRT will choose the normal tactics. The following output shows an example of TensorRT logs with information about sparsity:

[03/23/2021-00:14:05] [I] [TRT] (Sparsity) Found 3 layer(s) eligible to use sparse tactics: conv1, conv2, conv3
[03/23/2021-00:14:05] [I] [TRT] (Sparsity) Chose 2 layer(s) using sparse tactics: conv2, conv3

Forcing kernel weights to have structured sparsity patterns can lead to accuracy loss. Refer to the Automatic Sparsity tool in PyTorch section to recover lost accuracy with further fine-tuning.

To measure inference performance with structured sparsity using trtexec, refer to the trtexec section.

Empty Tensors#

TensorRT supports empty tensors. A tensor is an empty tensor if it has one or more dimensions with a length of zero. Zero-length dimensions usually get no special treatment. If a rule works for a dimension of length L for an arbitrary positive value of L, it usually works for L=0, too.

For example, when concatenating two tensors with dimensions [x,y,z] and [x,y,w] along the last axis, the result has dimensions [x,y,z+w], regardless of whether x, y, z, or w is zero.

Implicit broadcast rules remain unchanged since only unit-length dimensions are special for broadcast. For example, given two tensors with dimensions [1,y,z] and [x,1,z], their sum computed by IElementWiseLayer has dimensions [x,y,z], regardless of whether x, y, or z is zero.

If an engine binding is an empty tensor, it still needs a non-null memory address, and different tensors should have different addresses. This is consistent with the C++ rule that every object has a unique address. For example, new float[0] returns a non-null pointer. If using a memory allocator that might return a null pointer for zero bytes, ask for at least one byte instead.
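For example, when sizing device allocations for bindings, you can clamp zero-byte requests to one byte so that empty tensors still get distinct, non-null addresses; a tiny sketch of that rule:

import numpy as np

def binding_alloc_size(shape, dtype):
    """Bytes to request for a binding; at least 1 so empty tensors get a non-null address."""
    nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
    return max(nbytes, 1)

print(binding_alloc_size((8, 0, 128), np.float32))  # zero-length dimension -> ask for 1 byte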

Refer to the TensorRT Operator documentation for any special handling of empty tensors per layer.

Reusing Input Buffers#

TensorRT allows specifying a CUDA event to be signaled once the input buffers are free to be reused. This allows the application to immediately refill the input buffer region for the next inference in parallel with finishing the current inference. For example:

context->setInputConsumedEvent(&inputReady);

context.set_input_consumed_event(inputReady)

Engine Inspector#

TensorRT provides the IEngineInspector API to inspect the information inside a TensorRT engine. Call the createEngineInspector() from a deserialized engine to create an engine inspector, and then call getLayerInformation() or getEngineInformation() inspector APIs to get the information of a specific layer in the engine or the entire engine, respectively. You can print out the information of the first layer of the given engine, as well as the overall information of the engine, as follows:

auto inspector = std::unique_ptr<IEngineInspector>(engine->createEngineInspector());
inspector->setExecutionContext(context); // OPTIONAL
std::cout << inspector->getLayerInformation(0, LayerInformationFormat::kJSON); // Print the information of the first layer in the engine.
std::cout << inspector->getEngineInformation(LayerInformationFormat::kJSON); // Print the information of the entire engine.

inspector = engine.create_engine_inspector()
inspector.execution_context = context # OPTIONAL
print(inspector.get_layer_information(0, LayerInformationFormat.JSON)) # Print the information of the first layer in the engine.
print(inspector.get_engine_information(LayerInformationFormat.JSON)) # Print the information of the entire engine.

Note that the level of detail in the engine/layer information depends on the ProfilingVerbosity builder config setting when the engine is built. By default, ProfilingVerbosity is set to kLAYER_NAMES_ONLY, so only the layer names will be printed. If ProfilingVerbosity is set to kNONE, then no information will be printed; if it is set to kDETAILED, then detailed information will be printed.
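For example, to get detailed per-layer JSON, set the verbosity at build time and iterate over the layers at inspection time; this Python sketch builds a trivial placeholder network so it is self-contained:

import json
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(0)
inp = network.add_input("x", trt.float32, (1, 3, 8, 8))
relu = network.add_activation(inp, trt.ActivationType.RELU)
network.mark_output(relu.get_output(0))

config = builder.create_builder_config()
config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED  # keep detailed layer info in the engine

plan = builder.build_serialized_network(network, config)
engine = trt.Runtime(logger).deserialize_cuda_engine(plan)

inspector = engine.create_engine_inspector()
for i in range(engine.num_layers):
    layer_json = json.loads(inspector.get_layer_information(i, trt.LayerInformationFormat.JSON))
    print(layer_json["Name"], layer_json.get("LayerType"))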

Below are some examples of layer information printed by getLayerInformation() API depending on the ProfilingVerbosity setting:

1"node_of_gpu_0/res4_0_branch2a_1 + node_of_gpu_0/res4_0_branch2a_bn_1 + node_of_gpu_0/res4_0_branch2a_bn_2"
 1{
 2    "Name": "node_of_gpu_0/res4_0_branch2a_1 + node_of_gpu_0/res4_0_branch2a_bn_1 + node_of_gpu_0/res4_0_branch2a_bn_2",
 3    "LayerType": "CaskConvolution",
 4    "Inputs": [
 5    {
 6        "Name": "gpu_0/res3_3_branch2c_bn_3",
 7        "Dimensions": [16,512,28,28],
 8        "Format/Datatype": "Thirty-two wide channel vectorized row major Int8 format."
 9    }],
10    "Outputs": [
11    {
12        "Name": "gpu_0/res4_0_branch2a_bn_2",
13        "Dimensions": [16,256,28,28],
14        "Format/Datatype": "Thirty-two wide channel vectorized row major Int8 format."
15    }],
16    "ParameterType": "Convolution",
17    "Kernel": [1,1],
18    "PaddingMode": "kEXPLICIT_ROUND_DOWN",
19    "PrePadding": [0,0],
20    "PostPadding": [0,0],
21    "Stride": [1,1],
22    "Dilation": [1,1],
23    "OutMaps": 256,
24    "Groups": 1,
25    "Weights": {"Type": "Int8", "Count": 131072},
26    "Bias": {"Type": "Float", "Count": 256},
27    "AllowSparse": 0,
28    "Activation": "RELU",
29    "HasBias": 1,
30    "HasReLU": 1,
31    "TacticName": "sm80_xmma_fprop_implicit_gemm_interleaved_i8i8_i8i32_f32_nchw_vect_c_32kcrs_vect_c_32_nchw_vect_c_32_tilesize256x128x64_stage4_warpsize4x2x1_g1_tensor16x8x32_simple_t1r1s1_epifadd",
32    "TacticValue": "0x11bde0e1d9f2f35d"
33    }

In addition, when the engine is built with dynamic shapes, the dynamic dimensions in the engine information will be shown as -1, and the tensor format information will not be shown because these fields depend on the actual shape at the inference phase. To get the engine information for a specific inference shape, create an IExecutionContext, set all the input dimensions to the desired shapes, and then call inspector->setExecutionContext(context). After the context is set, the inspector will print the engine information for the specific shape set in the context.

The trtexec tool provides the --profilingVerbosity, --dumpLayerInfo, and --exportLayerInfo flags for getting engine information for a given engine. Refer to the trtexec section for more details.

Currently, only binding information and layer information, including the dimensions of the intermediate tensors, precisions, formats, tactic indices, layer types, and layer parameters, are included in the engine information. In future TensorRT versions, more information may be added to the engine inspector output as new keys in the output JSON object. More specifications about the keys and the fields in the inspector output will also be provided.

In addition, some subgraphs are handled by a next-generation graph optimizer that is not yet integrated with the engine inspector. Therefore, the layer information within these layers is not yet shown. This will be improved in a future version of TensorRT.

Engine Graph Visualization with Nsight Deep Learning Designer#

When detailed TensorRT engine layer information is exported to a JSON file with the --exportLayerInfo option, the engine’s computation graph may be visualized with Nsight Deep Learning Designer. Open the application, and from the File menu, select Open File, then choose the .trt.json file containing the exported metadata.

Layers in a TensorRT Engine Generated from an Object Detection Network

The Layer Explorer window allows you to search for a particular layer or explore the layers in the network. The Parameter Editor window lets you view the selected layer’s metadata.

Optimizer Callbacks#

The optimizer callback API feature allows you to monitor the progress of the TensorRT build process, for example, to provide user feedback in interactive applications. To enable progress monitoring, create an object that implements the IProgressMonitor interface, then attach it to the IBuilderConfig, for example:

builderConfig->setProgressMonitor(&monitor);

builder_config.progress_monitor = monitor

Optimization is divided into hierarchically nested phases, each consisting of several steps. At the start of each phase, the phaseStart() method of IProgressMonitor is called, telling you the phase name and how many steps it has. The stepComplete() function is called when each step completes, and phaseFinish() is called when the phase finishes.

Returning false from stepComplete() cleanly forces the build to terminate early.
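A minimal Python sketch of such a monitor, assuming the Python bindings expose the C++ callbacks named above as snake_case methods of trt.IProgressMonitor and that IBuilderConfig exposes the progress_monitor property:

import tensorrt as trt

class ConsoleProgressMonitor(trt.IProgressMonitor):
    def __init__(self):
        trt.IProgressMonitor.__init__(self)

    def phase_start(self, phase_name, parent_phase, num_steps):
        print(f"start {phase_name} ({num_steps} steps)")

    def step_complete(self, phase_name, step):
        print(f"  {phase_name}: step {step} done")
        return True  # return False to cleanly abort the build

    def phase_finish(self, phase_name):
        print(f"finish {phase_name}")

# Attach it to the builder configuration before building.
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
config = builder.create_builder_config()
config.progress_monitor = ConsoleProgressMonitor()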

Preview Features#

The preview feature API is an extension of IBuilderConfig that allows the gradual introduction of new features to TensorRT. Selected new features are exposed under this API, allowing you to opt in or out. A preview feature remains in preview status for one or two TensorRT release cycles and is then either integrated as a mainstream feature or dropped. When a preview feature is fully integrated into TensorRT, it is no longer controllable through the preview API.

Preview features are defined using a 32-bit PreviewFeature enumeration. Feature identifiers are a concatenation of the feature name and the TensorRT version:

<FEATURE_NAME>_XXYY

XX and YY are the major and minor versions of the TensorRT release, respectively, which first introduced the feature. The major and minor versions are specified using two digits with leading-zero padding when necessary.

Suppose the semantics of a preview feature change from one TensorRT release to another. In that case, the older preview feature is deprecated, and the revised feature is assigned a new enumeration value and name.

Deprecated preview features are marked per the deprecation policy.

For more information about the C++ API, refer to nvinfer1::PreviewFeature, IBuilderConfig::setPreviewFeature, and IBuilderConfig::getPreviewFeature.

The Python API has similar semantics, using the PreviewFeature enum and the set_preview_feature and get_preview_feature functions.

Debug Tensors#

The debug tensor feature allows you to inspect intermediate tensors as the network executes. There are a few key differences between using debug tensors and marking all required tensors as outputs:

  1. Marking all tensors as outputs requires you to provide memory to store tensors in advance, while debug tensors can be turned off during runtime if unneeded.

  2. When debug tensors are turned off, the performance impact on the execution of the network is minimized.

  3. For a debug tensor in a loop, values are emitted every time it is written.

To enable this feature, perform the following steps:

  1. Mark the target tensors before the network is compiled.

networkDefinition->markDebug(&tensor);

network.mark_debug(tensor)
  2. Define a DebugListener class deriving from IDebugListener and implement the virtual function for processing the tensor.

virtual void processDebugTensor(
    void const* addr,
    TensorLocation location,
    DataType type,
    Dims const& shape,
    char const* name,
    cudaStream_t stream) = 0;

process_debug_tensor(self, addr, location, type, shape, name, stream)

When the function is invoked during execution, the debug tensor is passed via the parameters:

location: TensorLocation of the tensor
addr: pointer to the buffer
type: data type of the tensor
shape: shape of the tensor
name: name of the tensor
stream: CUDA stream object

The data will be in linear format.

  3. Attach your listener to IExecutionContext.

executionContext->setDebugListener(&debugListener);

execution_context.set_debug_listener(debug_listener)
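A sketch combining the steps above in Python; the listener signature follows the one shown in step 2, while the names of the attach call (set_debug_listener) and the per-tensor toggle (set_debug_state) are assumed to mirror the C++ API, and "my_tensor" is a placeholder tensor name:

import tensorrt as trt

class PrintingDebugListener(trt.IDebugListener):
    def __init__(self):
        trt.IDebugListener.__init__(self)

    def process_debug_tensor(self, addr, location, type, shape, name, stream):
        # The data at addr is in linear format; copy it off the given stream before use.
        print(f"debug tensor {name}: shape={shape}, dtype={type}, location={location}")

# After building an engine whose network marked "my_tensor" with mark_debug():
# context = engine.create_execution_context()
# listener = PrintingDebugListener()
# context.set_debug_listener(listener)         # assumed Python name for setDebugListener
# context.set_debug_state("my_tensor", True)   # turn this debug tensor on at runtime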

Weight Streaming#

The weight streaming feature allows you to offload some weights from device memory to host memory. During network execution, these weights are streamed from the host to the device as needed. This technique can free up device memory, enabling you to run larger models or process larger batch sizes.

To enable this feature, during engine building, create a network with the kSTRONGLY_TYPED flag and set the kWEIGHT_STREAMING flag in the builder config:

builder->createNetworkV2(1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kSTRONGLY_TYPED));
config->setFlag(BuilderFlag::kWEIGHT_STREAMING);

builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED))
config.set_flag(trt.BuilderFlag.WEIGHT_STREAMING)

During runtime, deserialization allocates a host buffer to store all the weights instead of uploading them directly to the device. This can increase the host’s peak memory usage. You can use IStreamReaderV2 to deserialize directly from the engine file, avoiding the need for a temporary buffer, which helps reduce peak memory usage. IStreamReaderV2 replaces the existing IStreamReader deserialization method.

After deserializing the engine, set the device memory budget for weights by:

engine->setWeightStreamingBudgetV2(size);

engine.weight_streaming_budget_v2 = size

The following APIs can help to determine the budget:

  • getStreamableWeightsSize() returns the total size of streamable weights.

  • getWeightStreamingScratchMemorySize() returns the extra scratch memory size for a context when weight streaming is enabled.

  • getDeviceMemorySizeV2() returns the total scratch memory size required by a context. If this API is called before enabling weight streaming by setWeightStreamingBudgetV2(), the return value will not include the extra scratch memory size required by weight streaming, which can be obtained using getWeightStreamingScratchMemorySize(). Otherwise, it will include this extra memory.

Additionally, you can combine information about the current free device memory size, context number, and other allocation needs.

TensorRT can also automatically determine a memory budget by getWeightStreamingAutomaticBudget(). However, due to limited information about the user’s specific memory allocation requirements, this automatically determined budget may be suboptimal and potentially lead to out-of-memory errors.

If the budget set by setWeightStreamingBudgetV2 is larger than the total size of streamable weights obtained by getStreamableWeightsSize(), the budget will be clipped to the total size, effectively disabling weight streaming.

You can query the budget that has been set using getWeightStreamingBudgetV2().

The budget can be adjusted by setting it again when there is no active context for the engine.

After setting the budget, TensorRT will automatically determine which weights to retain on the device memory to maximize the overlap between computation and weight fetching.
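A short Python sketch of setting the budget after deserialization; streamable_weights_size is assumed to be the Python counterpart of getStreamableWeightsSize(), and the engine path is a placeholder:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)

with open("weight_streaming.plan", "rb") as f:  # placeholder path
    engine = runtime.deserialize_cuda_engine(f.read())

total = engine.streamable_weights_size          # assumed property name
engine.weight_streaming_budget_v2 = total // 2  # keep roughly half the weights resident on the GPU

context = engine.create_execution_context()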

Cross-Platform Compatibility#

By default, TensorRT engines can only be executed on the same platform (operating system and CPU architecture) where they were built. With build-time configuration, engines can be built to be compatible with other types of platforms. For example, to build an engine on Linux x86_64 platforms and expect the engine to run on Windows x86_64 platforms, configure the IBuilderConfig as follows:

config->setRuntimePlatform(nvinfer1::RuntimePlatform::kWINDOWS_AMD64);

The cross-platform engine might have performance differences from the natively built engine on the target platform. Additionally, it cannot run on the host platform it was built on.

When building a cross-platform engine that also requires version forward compatibility, kEXCLUDE_LEAN_RUNTIME must be set to exclude the target platform lean runtime.
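In Python, the same configuration might look like this, assuming runtime_platform and the RuntimePlatform enum are the counterparts of the C++ setRuntimePlatform API:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
config = builder.create_builder_config()

# Build on Linux x86_64, target Windows x86_64; the engine will not run on the build host.
config.runtime_platform = trt.RuntimePlatform.WINDOWS_AMD64  # assumed enum/property names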

Tiling Optimization#

Tiling optimization enables cross-kernel tiled inference. This technique leverages on-chip caching for continuous kernels in addition to kernel-level tiling. It can significantly enhance performance on platforms constrained by memory bandwidth.

To activate tiling optimization, perform the following steps:

  1. Set the tiling optimization level. Use the following API to specify the duration TensorRT should dedicate to searching for a more effective tiling solution that could improve performance:

    builderConfig->setTilingOptimizationLevel(level);
    

    The optimization level is set to 0 by default, which means TensorRT will not perform any tiling optimization.

    Increasing the level enables TensorRT to explore various strategies and larger search spaces for enhanced performance. However, note that this may significantly increase the engine build time.

  2. Configure the L2 cache limit for tiling. Use the following API to provide TensorRT with an estimate of the L2 cache resources that can be allocated for the current engine during runtime:

    builderConfig->setL2LimitForTiling(size);
    

    This API is a hint that tells TensorRT how much L2 cache can be considered dedicated to the current TensorRT engine at runtime. This helps TensorRT apply a better tiling solution when multiple tasks run concurrently on one GPU. Note that the usage of the L2 cache depends on the workload and heuristics; TensorRT may not apply this limit for all layers.

    TensorRT manages the default value.