Best Practices#

Performance Benchmarking using trtexec#

This section introduces how to use trtexec, a command-line tool designed for TensorRT performance benchmarking, to get the inference performance measurements of your deep learning models.

If you use the TensorRT NGC container, trtexec is installed at /opt/tensorrt/bin/trtexec.

If you manually installed TensorRT, trtexec is part of the installation.

Alternatively, you can build trtexec from source code using the TensorRT OSS repository.

Performance Benchmarking with an ONNX File#

If your model is already in the ONNX format, the trtexec tool can measure its performance directly. In this example, we will use the ResNet-50 v1 ONNX model from the ONNX model zoo to showcase how to use trtexec to measure its performance.

For example, the trtexec command to measure the performance of ResNet-50 with batch size 4 is:

trtexec --onnx=resnet50-v1-12.onnx --shapes=data:4x3x224x224 --fp16 --noDataTransfers --useCudaGraph --useSpinWait

Where:

  • The --onnx flag specifies the path to the ONNX file.

  • The --shapes flag specifies the input tensor shapes.

  • The --fp16 flag enables FP16 tactics.

  • The other flags have been added to make performance results more stable.

The value for the --shapes flag is in the format of name1:shape1,name2:shape2,... Suppose you do not know the input tensor names and shapes. You can get the information by visualizing the ONNX model using tools like Netron or by running a Polygraphy model inspection.
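When a model has several inputs, the name1:shape1,name2:shape2 string can be generated rather than typed by hand. The following is a minimal Python sketch of that formatting; the helper name format_shapes_flag is hypothetical and is not part of trtexec or TensorRT.

```python
def format_shapes_flag(shapes):
    """Build a trtexec --shapes value from a {tensor_name: dims} mapping."""
    parts = []
    for name, dims in shapes.items():
        # Dimensions are joined with 'x', e.g. (4, 3, 224, 224) -> 4x3x224x224.
        parts.append(f"{name}:{'x'.join(str(d) for d in dims)}")
    # Multiple input tensors are separated by commas.
    return "--shapes=" + ",".join(parts)


print(format_shapes_flag({"data": (4, 3, 224, 224)}))
# prints: --shapes=data:4x3x224x224
```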

For example, running polygraphy inspect model resnet50-v1-12.onnx prints out:

[I] Loading model: /home/pohanh/trt/resnet50-v1-12.onnx
[I] ==== ONNX Model ====
    Name: mxnet_converted_model | ONNX Opset: 12
    ---- 1 Graph Input(s) ----
    {data [dtype=float32, shape=('N', 3, 224, 224)]}
    ---- 1 Graph Output(s) ----
    {resnetv17_dense0_fwd [dtype=float32, shape=('N', 1000)]}
    ---- 299 Initializer(s) ----
    ---- 175 Node(s) ----

It shows that the ONNX model has a graph input tensor named data whose shape is ('N', 3, 224, 224), where 'N' represents that the dimension can be dynamic. Therefore, the trtexec flag to specify the input shapes with batch size 4 would be --shapes=data:4x3x224x224.

After running the trtexec command, trtexec will parse your ONNX file, build a TensorRT plan file, measure the performance of this plan file, and then print a performance summary as follows:

[04/25/2024-23:57:45] [I] === Performance summary ===
[04/25/2024-23:57:45] [I] Throughput: 507.399 qps
[04/25/2024-23:57:45] [I] Latency: min = 1.96301 ms, max = 1.97534 ms, mean = 1.96921 ms, median = 1.96917 ms, percentile(90%) = 1.97122 ms, percentile(95%) = 1.97229 ms, percentile(99%) = 1.97424 ms
[04/25/2024-23:57:45] [I] Enqueue Time: min = 0.0032959 ms, max = 0.0340576 ms, mean = 0.00421173 ms, median = 0.00415039 ms, percentile(90%) = 0.00463867 ms, percentile(95%) = 0.00476074 ms, percentile(99%) = 0.0057373 ms
[04/25/2024-23:57:45] [I] H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
[04/25/2024-23:57:45] [I] GPU Compute Time: min = 1.96301 ms, max = 1.97534 ms, mean = 1.96921 ms, median = 1.96917 ms, percentile(90%) = 1.97122 ms, percentile(95%) = 1.97229 ms, percentile(99%) = 1.97424 ms
[04/25/2024-23:57:45] [I] D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
[04/25/2024-23:57:45] [I] Total Host Walltime: 3.00355 s
[04/25/2024-23:57:45] [I] Total GPU Compute Time: 3.00108 s
[04/25/2024-23:57:45] [I] Explanations of the performance metrics are printed in the verbose logs.

It prints many performance metrics, but the most important are Throughput and median Latency. In this case, the ResNet-50 model with batch size 4 can run with a throughput of 507 inferences per second (2028 images per second since the batch size is 4) and a median latency of 1.969 ms.
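If you want to collect these two numbers from many runs, the summary can be parsed with a short script. The sketch below assumes the log layout shown in the summary above; the function name parse_performance_summary is hypothetical, and the regular expressions may need adjusting for other trtexec versions.

```python
import re


def parse_performance_summary(log_text):
    """Return (throughput_qps, median_latency_ms) from trtexec summary text."""
    throughput = float(re.search(r"Throughput: ([\d.]+) qps", log_text).group(1))
    # Match the end-to-end "Latency:" line only, not "H2D Latency:" or "D2H Latency:".
    latency_line = re.search(r"\[I\] Latency: (.*)", log_text).group(1)
    median = float(re.search(r"median = ([\d.]+) ms", latency_line).group(1))
    return throughput, median


# Two lines abbreviated from the FP16 summary for batch size 4.
sample_log = (
    "[04/25/2024-23:57:45] [I] Throughput: 507.399 qps\n"
    "[04/25/2024-23:57:45] [I] Latency: min = 1.96301 ms, max = 1.97534 ms, "
    "mean = 1.96921 ms, median = 1.96917 ms\n"
)

batch_size = 4
throughput_qps, median_ms = parse_performance_summary(sample_log)
# Each query processes one batch, so images/s = queries/s * batch size.
images_per_second = throughput_qps * batch_size
print(f"{throughput_qps:.0f} qps, {images_per_second:.0f} images/s, median {median_ms:.3f} ms")
```

Multiplying the unrounded throughput gives roughly 2030 images per second; the 2028 figure quoted in the text comes from rounding the throughput to 507 before multiplying by 4.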

Refer to the Advanced Performance Measurement Techniques section for explanations about what Throughput and Latency mean to your deep learning inference applications. Refer to the trtexec section for detailed explanations about other trtexec flags and other performance metrics that trtexec reports.

Performance Benchmarking with ONNX+Quantization#

To benefit from the additional performance of quantization, Quantize/Dequantize operations need to be inserted into the ONNX model to tell TensorRT where to quantize/dequantize the tensors and what scaling factors to use.

Our recommended tool for ONNX quantization is the ModelOptimizer package. You can install it by running:

pip3 install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-modelopt

Using the ModelOptimizer, you can get a quantized ONNX model by running:

python3 -m modelopt.onnx.quantization --onnx_path resnet50-v1-12.onnx --quantize_mode int8 --output_path resnet50-v1-12-quantized.onnx

It loads the original ONNX model from resnet50-v1-12.onnx, runs calibration using random data, inserts Quantize/Dequantize ops into the graph, and then saves the ONNX model with Quantize/Dequantize ops to resnet50-v1-12-quantized.onnx.

Now that the new ONNX model contains the INT8 Quantize/Dequantize ops, we can run trtexec again using a similar command:

trtexec --onnx=resnet50-v1-12-quantized.onnx --shapes=data:4x3x224x224 --stronglyTyped --noDataTransfers --useCudaGraph --useSpinWait

We use the --stronglyTyped flag instead of the --fp16 flag to require TensorRT to strictly follow the data types in the quantized ONNX model, including all the INT8 Quantize/Dequantize ops.

Here is an example output after running this trtexec command with the quantized ONNX model:

[04/26/2024-00:31:43] [I] === Performance summary ===
[04/26/2024-00:31:43] [I] Throughput: 811.74 qps
[04/26/2024-00:31:43] [I] Latency: min = 1.22559 ms, max = 1.23608 ms, mean = 1.2303 ms, median = 1.22998 ms, percentile(90%) = 1.23193 ms, percentile(95%) = 1.23291 ms, percentile(99%) = 1.23395 ms
[04/26/2024-00:31:43] [I] Enqueue Time: min = 0.00354004 ms, max = 0.00997925 ms, mean = 0.00431524 ms, median = 0.00439453 ms, percentile(90%) = 0.00463867 ms, percentile(95%) = 0.00476074 ms, percentile(99%) = 0.00512695 ms
[04/26/2024-00:31:43] [I] H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
[04/26/2024-00:31:43] [I] GPU Compute Time: min = 1.22559 ms, max = 1.23608 ms, mean = 1.2303 ms, median = 1.22998 ms, percentile(90%) = 1.23193 ms, percentile(95%) = 1.23291 ms, percentile(99%) = 1.23395 ms
[04/26/2024-00:31:43] [I] D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
[04/26/2024-00:31:43] [I] Total Host Walltime: 3.00219 s
[04/26/2024-00:31:43] [I] Total GPU Compute Time: 2.99824 s
[04/26/2024-00:31:43] [I] Explanations of the performance metrics are printed in the verbose logs.

The Throughput is 811 inferences per second, and the median Latency is 1.23 ms. The Throughput has improved by 60% compared to the FP16 performance results in the previous section.
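The 60% figure can be checked directly from the two performance summaries. This small Python sketch only repeats that arithmetic, using the throughput and median latency values printed in the FP16 and quantized logs above.

```python
# Values copied from the two performance summaries.
fp16_qps, int8_qps = 507.399, 811.74
fp16_median_ms, int8_median_ms = 1.96917, 1.22998

# Relative throughput gain of the quantized model over the FP16 model.
speedup = int8_qps / fp16_qps
improvement_pct = (speedup - 1.0) * 100.0
# The median latency shrinks by about the same factor.
latency_ratio = fp16_median_ms / int8_median_ms

print(f"speedup: {speedup:.2f}x, improvement: {improvement_pct:.0f}%")
# prints: speedup: 1.60x, improvement: 60%
```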

Per-Layer Runtime and Layer Information#

In previous sections, we described using trtexec to measure the end-to-end latency. This section will show an example of per-layer runtime and per-layer information using trtexec. This will help you determine how much latency each layer contributes to the end-to-end latency and in which layers the performance bottlenecks are.

This is an example trtexec command to print per-layer runtime and per-layer information using the quantized ResNet-50 ONNX model:

trtexec --onnx=resnet50-v1-12-quantized.onnx --shapes=data:4x3x224x224 --stronglyTyped --noDataTransfers --useCudaGraph --useSpinWait --profilingVerbosity=detailed --dumpLayerInfo --dumpProfile --separateProfileRun

The --profilingVerbosity=detailed flag enables detailed layer information capturing, the --dumpLayerInfo flag shows the per-layer information in the log, and the --dumpProfile and --separateProfileRun flags show the per-layer runtime latencies in the log.

The following code is an example log of the per-layer information for one of the convolution layers in the quantized ResNet-50 model:

Name: resnetv17_stage1_conv0_weight + resnetv17_stage1_conv0_weight_QuantizeLinear + resnetv17_stage1_conv0_fwd, LayerType: CaskConvolution, Inputs: [ { Name: resnetv17_pool0_fwd_QuantizeLinear_Output_1, Location: Device, Dimensions: [4,64,56,56], Format/Datatype: Thirty-two wide channel vectorized row major Int8 format }], Outputs: [ { Name: resnetv17_stage1_relu0_fwd_QuantizeLinear_Output, Location: Device, Dimensions: [4,64,56,56], Format/Datatype: Thirty-two wide channel vectorized row major Int8 format }], ParameterType: Convolution, Kernel: [1,1], PaddingMode: kEXPLICIT_ROUND_DOWN, PrePadding: [0,0], PostPadding: [0,0], Stride: [1,1], Dilation: [1,1], OutMaps: 64, Groups: 1, Weights: {"Type": "Int8", "Count": 4096}, Bias: {"Type": "Float", "Count": 64}, HasBias: 1, HasReLU: 1, HasSparseWeights: 0, HasDynamicFilter: 0, HasDynamicBias: 0, HasResidual: 0, ConvXAsActInputIdx: -1, BiasAsActInputIdx: -1, ResAsActInputIdx: -1, Activation: RELU, TacticName: sm80_xmma_fprop_implicit_gemm_interleaved_i8i8_i8i32_f32_nchw_vect_c_32kcrs_vect_c_32_nchw_vect_c_32_tilesize96x64x64_stage3_warpsize2x2x1_g1_tensor16x8x32_simple_t1r1s1, TacticValue: 0x483ad1560c6e5e27, StreamId: 0, Metadata: [ONNX Layer: resnetv17_stage1_conv0_fwd]

The log shows the layer name, the input and output tensor names, tensor shapes, tensor data types, convolution parameters, tactic names, and metadata. The Metadata field shows which ONNX ops this layer corresponds to. Since TensorRT has graph fusion optimizations, one engine layer may correspond to multiple ONNX ops in the original model.

The following code is an example log of the per-layer runtime latencies for the last few layers in the quantized ResNet-50 model:

[04/26/2024-00:42:55] [I]    Time(ms)     Avg.(ms)   Median(ms)   Time(%)   Layer
[04/26/2024-00:42:55] [I]       56.57       0.0255       0.0256       1.8   resnetv17_stage4_conv7_weight + resnetv17_stage4_conv7_weight_QuantizeLinear + resnetv17_stage4_conv7_fwd
[04/26/2024-00:42:55] [I]      103.86       0.0468       0.0471       3.3   resnetv17_stage4_conv8_weight + resnetv17_stage4_conv8_weight_QuantizeLinear + resnetv17_stage4_conv8_fwd
[04/26/2024-00:42:55] [I]       46.93       0.0211       0.0215       1.5   resnetv17_stage4_conv9_weight + resnetv17_stage4_conv9_weight_QuantizeLinear + resnetv17_stage4_conv9_fwd + resnetv17_stage4__plus2 + resnetv17_stage4_activation2
[04/26/2024-00:42:55] [I]       34.64       0.0156       0.0154       1.1   resnetv17_pool1_fwd
[04/26/2024-00:42:55] [I]       63.21       0.0285       0.0287       2.0   resnetv17_dense0_weight + resnetv17_dense0_weight_QuantizeLinear + transpose_before_resnetv17_dense0_fwd + resnetv17_dense0_fwd + resnetv17_dense0_bias + ONNXTRT_Broadcast + unsqueeze_node_after_resnetv17_dense0_bias + ONNXTRT_Broadcast_ONNXTRT_Broadcast_output + (Unnamed Layer* 851) [ElementWise]
[04/26/2024-00:42:55] [I]     3142.40       1.4149       1.4162     100.0   Total

It shows that the median latency of the resnetv17_pool1_fwd layer is 0.0154 ms and contributes 1.1% of the end-to-end latency. With this log, you can identify which layers take the largest portion of the end-to-end latency and are therefore the performance bottlenecks.

The Total latency reported in the per-layer runtime log is the summation of the per-layer latencies. It is typically slightly longer than the reported end-to-end latency due to the overheads caused by measuring per-layer latencies. For example, the Total median latency is 1.4162 ms, but the end-to-end latency shown in the previous section is 1.23 ms.
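To find the slowest layers in a long profile, the table can be parsed and sorted. The following Python sketch assumes the five-column layout shown in the log above (Time, Avg., Median, Time(%), Layer); the function name parse_layer_profile and the regular expression are illustrative assumptions, not part of trtexec.

```python
import re

# Four numeric columns followed by the layer name, after the "[I]" log tag.
ROW = re.compile(r"\[I\]\s+([\d.]+)\s+([\d.]+)\s+([\d.]+)\s+([\d.]+)\s+(.+)$")


def parse_layer_profile(log_text):
    """Return per-layer entries sorted by median latency, slowest first."""
    layers = []
    for line in log_text.splitlines():
        match = ROW.search(line)
        if match is None:
            continue  # Header rows and unrelated log lines do not match.
        _total_ms, _avg_ms, median_ms, percent, name = match.groups()
        if name.strip() == "Total":
            continue  # Skip the summary row.
        layers.append(
            {"name": name.strip(), "median_ms": float(median_ms), "percent": float(percent)}
        )
    return sorted(layers, key=lambda layer: layer["median_ms"], reverse=True)


# Rows copied from the per-layer runtime log above.
sample_log = """\
[04/26/2024-00:42:55] [I]    Time(ms)     Avg.(ms)   Median(ms)   Time(%)   Layer
[04/26/2024-00:42:55] [I]      103.86       0.0468       0.0471       3.3   resnetv17_stage4_conv8_weight + resnetv17_stage4_conv8_weight_QuantizeLinear + resnetv17_stage4_conv8_fwd
[04/26/2024-00:42:55] [I]       34.64       0.0156       0.0154       1.1   resnetv17_pool1_fwd
[04/26/2024-00:42:55] [I]     3142.40       1.4149       1.4162     100.0   Total
"""

ranked = parse_layer_profile(sample_log)
for layer in ranked:
    print(f"{layer['median_ms']:.4f} ms  {layer['percent']:5.1f}%  {layer['name'][:40]}")
```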

Performance Benchmarking with TensorRT Plan File#

If you construct the TensorRT INetworkDefinition using TensorRT APIs and build the plan file in a separate script, you can still use trtexec to measure the plan file’s performance.

For example, if the plan file is saved as resnet50-v1-12-quantized.plan, then you can run the trtexec command to measure the performance using this plan file:

trtexec --loadEngine=resnet50-v1-12-quantized.plan --shapes=data:4x3x224x224 --noDataTransfers --useCudaGraph --useSpinWait

The performance summary output is similar to those in the previous sections.

Duration and Number of Iterations#

By default, trtexec warms up for at least 200 ms and runs inference for at least 10 iterations or at least 3 seconds, whichever is longer. You can modify these parameters by adding the --warmUp=500, --iterations=100, and --duration=60 flags, which mean running the warm-up for at least 500 ms and running the inference for at least 100 iterations or at least 60 seconds, whichever is longer.

Refer to the trtexec section or run trtexec --help for a detailed explanation about other trtexec flags.

Advanced Performance Measurement Techniques#

Before starting any optimization effort with TensorRT, it is essential to determine what should be measured. Without measurements, it is impossible to make reliable progress or measure whether success has been achieved.

Latency: A performance measurement for network inference is how much time elapses from an input presented to the network until an output is available. This is the latency of the network for a single inference. Lower latencies are better. In some applications, low latency is a critical safety requirement. In other applications, latency is directly visible to users as a quality-of-service issue. For bulk processing, latency may not be important.

Throughput: Another performance measurement is how many inferences can be completed in a fixed time. This is the throughput of the network. Higher throughput is better. Higher throughputs indicate a more efficient utilization of fixed compute resources. For bulk processing, the total time taken will be determined by the network’s throughput.

Another way of looking at latency and throughput is to fix the maximum latency and measure throughput at that latency. A quality-of-service measurement like this can be a reasonable compromise between the user experience and system efficiency.
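As a rough illustration of fixing the maximum latency and comparing throughput, the Python sketch below picks the batch size with the highest throughput that still meets a latency budget. The measurements are hypothetical numbers chosen for illustration, not benchmark results, and the throughput formula assumes batches run back to back without overlap.

```python
def best_under_latency_budget(measurements, max_latency_ms):
    """Pick the (batch_size, latency_ms, items_per_s) with the highest throughput
    among configurations whose latency does not exceed the budget."""
    best = None
    for batch_size, latency_ms in measurements:
        if latency_ms > max_latency_ms:
            continue  # Violates the quality-of-service limit.
        # Items per second when one batch completes every latency_ms.
        throughput = batch_size * 1000.0 / latency_ms
        if best is None or throughput > best[2]:
            best = (batch_size, latency_ms, throughput)
    return best


# Hypothetical (batch_size, median_latency_ms) pairs; not measured results.
measurements = [(1, 0.8), (4, 2.0), (8, 3.5), (16, 6.5)]
print(best_under_latency_budget(measurements, max_latency_ms=4.0))
```

With these made-up numbers and a 4 ms budget, batch size 8 is selected: batch size 16 would give more throughput but exceeds the budget.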

Before measuring latency and throughput, you must choose the exact points to start and stop timing. Different points make sense depending on the network and application.

In many applications, there is a processing pipeline, and the latency and throughput of the entire pipeline can measure the overall system performance. Because the pre- and post-processing steps depend so strongly on the particular application, this section considers the latency and throughput of the network inference only.

Wall-Clock Timing#

The wall clock time (the elapsed time between the start of a computation and its end) can be useful for measuring the application’s overall throughput and latency and placing inference times in context within a larger system. C++11 provides high-precision timers in the <chrono> standard library. For example, std::chrono::system_clock represents system-wide wall-clock time, and std::chrono::high_resolution_clock measures time at the highest precision available.

The following example code snippet shows measuring network inference host time:

// C++
#include <chrono>

auto startTime = std::chrono::high_resolution_clock::now();
context->enqueueV3(stream);
cudaStreamSynchronize(stream);
auto endTime = std::chrono::high_resolution_clock::now();
float totalTime = std::chrono::duration<float, std::milli>
    (endTime - startTime).count();

# Python
import time
from cuda import cudart
err, stream = cudart.cudaStreamCreate()
start_time = time.time()
context.execute_async_v3(stream)
cudart.cudaStreamSynchronize(stream)
total_time = time.time() - start_time

If there is only one inference happening on the device at one time, then this can be a simple way of profiling the time various operations take. Inference is typically asynchronous, so ensure you add an explicit CUDA stream or device synchronization to wait for results to become available.
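A single host-side measurement is noisy, so in practice it helps to discard a few warm-up runs and report a median over many iterations. The Python sketch below shows one way to structure that; the function time_host_latency and the fake_inference stand-in are hypothetical, and in a real application run_once would enqueue the inference and then synchronize the CUDA stream.

```python
import statistics
import time


def time_host_latency(run_once, warmup_iters=10, iters=100):
    """Return per-call wall-clock latencies in milliseconds.

    run_once must block until the work is finished (for GPU inference,
    that means enqueue followed by a stream synchronization).
    """
    for _ in range(warmup_iters):
        run_once()  # Warm-up runs are executed but not timed.
    latencies_ms = []
    for _ in range(iters):
        start = time.perf_counter()  # Monotonic, high-resolution host clock.
        run_once()
        latencies_ms.append((time.perf_counter() - start) * 1000.0)
    return latencies_ms


def fake_inference():
    # Stand-in workload so the sketch runs without a GPU.
    time.sleep(0.002)


latencies = time_host_latency(fake_inference, warmup_iters=2, iters=20)
print(f"median = {statistics.median(latencies):.3f} ms over {len(latencies)} runs")
```

time.perf_counter() is used here instead of time.time() because it is monotonic and is not affected by system clock adjustments.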

CUDA Events#

One problem with timing on the host exclusively is that it requires host/device synchronization. Optimized applications may have many inferences running in parallel on the device with overlapping data movement. In addition, the synchronization adds some noise to timing measurements.

To help with these issues, CUDA provides an Event API. This API allows you to place events into CUDA streams that the GPU will time-stamp as they are encountered. Differences in timestamps can then tell you how long different operations took.

The following example code snippet shows computing the time between two CUDA events:

// C++
cudaEvent_t start, end;
cudaEventCreate(&start);
cudaEventCreate(&end);

cudaEventRecord(start, stream);
context->enqueueV3(stream);
cudaEventRecord(end, stream);

// Wait for the end event, then read the elapsed GPU time in milliseconds.
cudaEventSynchronize(end);
float totalTime;
cudaEventElapsedTime(&totalTime, start, end);

# Python
from cuda import cudart
err, stream = cudart.cudaStreamCreate()
err, start = cudart.cudaEventCreate()
err, end = cudart.cudaEventCreate()
cudart.cudaEventRecord(start, stream)
context.execute_async_v3(stream)
cudart.cudaEventRecord(end, stream)
cudart.cudaEventSynchronize(end)
err, total_time = cudart.cudaEventElapsedTime(start, end)

Built-In TensorRT Profiling#

Digging deeper into inference performance requires more fine-grained timing measurements within the optimized network.

TensorRT has a Profiler (C++, Python) interface, which you can implement to have TensorRT pass profiling information to your application. When called, the network will run in a profiling mode. After finishing the inference, the profiler object of your class is called to report the timing for each layer in the network. These timings can be used to locate bottlenecks, compare different versions of a serialized engine, and debug performance issues.

The profiling information can be collected from a regular inference enqueueV3() launch or a CUDA graph launch. Refer to IExecutionContext::setProfiler() and IExecutionContext::reportToProfiler() (C++, Python) for more information.

Layers inside a loop are compiled into a single monolithic layer; therefore, separate timings for those layers are unavailable. Also, some subgraphs (especially with Transformer-like networks) are handled by a next-generation graph optimizer that has not yet been integrated with the Profiler APIs. For those networks, use the CUDA Profiling Tools to profile per-layer performance.

An example showing how to use the IProfiler interface is provided in the common sample code (common.h).

Given an input network or plan file, you can use trtexec to profile a network with TensorRT. For more information, refer to the trtexec section.

ONNX Profiling Tools#

Nsight Deep Learning Designer is an integrated design environment for ONNX models. It is built on top of TensorRT. Its built-in profiler runs inference for an ONNX model through TensorRT and collects profiling data based on GPU performance metrics. The profiler report generated by Nsight Deep Learning Designer provides a comprehensive view of an ONNX model’s inference performance at the TensorRT layer level. Its GUI also helps developers correlate the performance of individual TensorRT layers with their originating ONNX operators.

Nsight Deep Learning Designer profiling typically begins with the GUI. Open the Nsight Deep Learning Designer application and click Start Activity from the Welcome screen. Select the target platform type from the list, and you may additionally define a remote connection if you wish to profile on a Linux or L4T target from a remote machine. Select Profile TensorRT Model as the activity type.

Start Activity Dialog

Profiler activity settings typically have analogs in trtexec and are split across four tabs in the GUI. Refer to the Nsight Deep Learning Designer documentation for details of each setting. The most frequently used settings are listed here:

Networks using dynamic shapes (Working with Dynamic Shapes) should specify an optimization profile before profiling. This can be done by editing the ONNX network within Nsight Deep Learning Designer, profiling from the command line, or (for compatible networks) setting the Inferred Batch option in the Optimizer tab. When a batch size is provided, input shapes with a single leading wildcard will be automatically populated with the batch size. This feature works with input shapes of arbitrary rank.

To start the Nsight Deep Learning Designer profiler, click the Launch button. The tool will automatically deploy TensorRT and CUDA runtime libraries to the target as needed and then generate a profiling report:

Profiling Report Overview

Nsight Deep Learning Designer includes a command-line profiler; refer to the tool documentation for usage instructions.

Understanding Nsight Deep Learning Designer Timeline View

Sample Timeline View Showing Layer Execution and GPU Activity

In the Nsight Deep Learning Designer Timeline View, each network inference stream is shown as a row of the timeline alongside collected GPU metrics such as SM activity and PCIe throughput. Each layer executed on an inference stream is depicted as a range on the corresponding timeline view. Overhead sources such as tensor memory copies or reformats are highlighted in blue.

Understanding Nsight Deep Learning Designer Layer Table

Sample Profiling Report Layer Table

The Network Metrics table view shows all TensorRT layers executed by the network, their type, dimensions, precision, and inference time. Layer inference times are provided in the table as both raw time measurements and percentages of the inference pass. You can filter the table by name. Hyperlinks in the table indicate where a layer name references nodes in the original ONNX source model. Click the hyperlink or use the drop-down menu in a selected layer’s Name column to open the original ONNX model and highlight the layer in its context.

Pop-up Showing Originating and Related ONNX Nodes for a TensorRT Layer

Selecting a range of layers in the table aggregates their statistics into a higher-level summary. Each table column is represented in the summary area with the most common values observed within the selection, sorted by frequency. Hover the mouse cursor over an information icon to see the full list of values and associated frequencies. Inference time columns are shown as minimum, maximum, mean, and total, using absolute times, and inference pass percentages as units. Total times in this context can be used to quickly sum the inference time for layers within a single execution stream. Layers in the selection do not need to be contiguous.

Sample Aggregated Statistics for a Selection Range

Understanding Nsight Deep Learning Designer Network Graphs

Sample Average Inference Latency Graph for a Simple Network

Nsight Deep Learning Designer shows the average inference latency for each type of layer in the TensorRT engine. This can highlight areas where the network spends significant time on non-critical computations.

Example Graph Showing Tensor Precisions for Layers in a Network

Nsight Deep Learning Designer also shows the precisions used for each type of layer in the TensorRT engine. This can highlight potential opportunities for network quantization and visualize the effect of setting TensorRT’s tactic precision flags.

CUDA Profiling Tools#

The recommended CUDA profiler is NVIDIA Nsight Systems. Some CUDA developers may be more familiar with nvprof and nvvp; however, these are being deprecated. Any of these profilers can be used on any CUDA program to report timing information about the kernels launched during execution, data movement between host and device, and CUDA API calls used.

Nsight Systems can be configured to report timing information for only a portion of the program’s execution or to report traditional CPU sampling profile information and GPU information.

The basic usage of Nsight Systems is first to run the command nsys profile -o <OUTPUT> <INFERENCE_COMMAND>, then open the generated <OUTPUT>.nsys-rep file in the Nsight Systems GUI to visualize the captured profiling results.

Profile Only the Inference Phase

When profiling a TensorRT application, you should enable profiling only after the engine has been built. During the build phase, all possible tactics are tried and timed. Profiling this portion of the execution will not show any meaningful performance measurements and will include all possible kernels, not the ones selected for inference. One way to limit the scope of profiling is to:

  • First phase: Structure the application to build and then serialize the engines in one phase.

  • Second phase: Load the serialized engines, run inference in a second phase, and profile this second phase only.

Suppose the application cannot serialize the engines or must run through the two phases consecutively. In that case, you can also add cudaProfilerStart() and cudaProfilerStop() CUDA APIs around the second phase and add the -c cudaProfilerApi flag to the Nsight Systems command to profile only the part between cudaProfilerStart() and cudaProfilerStop().

Understand Nsight Systems Timeline View

In the Nsight Systems Timeline View, the GPU activities are shown in the rows under CUDA HW, and the CPU activities are shown in the rows under Threads. By default, the rows under CUDA HW are collapsed. Therefore, you must click on it to expand the rows.

In a typical inference workflow, the application calls the context->enqueueV3() or context->executeV3() APIs to enqueue the jobs and then synchronizes on the stream to wait until the GPU completes the jobs. If you only look at the CPU activities, it may appear that the system is doing nothing for a while in the cudaStreamSynchronize() call. In fact, the GPU may be busy executing the enqueued jobs while the CPU waits. The following figure shows an example timeline of the inference of a query.

The trtexec tool uses a slightly more complicated approach to enqueue the jobs. It enqueues the next query while the GPU still executes the jobs from the previous query. For more information, refer to the trtexec section.

The following image shows a typical view of the normal inference workloads in the Nsight Systems timeline view, showing CPU and GPU activities on different rows.

Normal Inference Workloads in Nsight Systems Timeline View

Use the NVTX Tracing in Nsight Systems

Tracing enables the NVIDIA Tools Extension SDK (NVTX), a C-based API for marking events and ranges in your applications. It allows Nsight Compute and Nsight Systems to collect data generated by TensorRT applications.

Decoding the kernel names back to layers in the original network can be complicated. Because of this, TensorRT uses NVTX to mark a range for each layer, allowing the CUDA profilers to correlate each layer with the kernels called to implement it. In TensorRT, NVTX helps to correlate the runtime engine layer execution with CUDA kernel calls. Nsight Systems supports collecting and visualizing these events and ranges on the timeline. Nsight Compute also supports collecting and displaying the state of all active NVTX domains and ranges in a given thread when the application is suspended.

In TensorRT, each layer may launch one or more kernels to perform its operations. The exact kernels launched depend on the optimized network and the hardware present. Depending on the builder’s choices, multiple additional operations that reorder data may be interspersed with layer computations; these reformat operations may be implemented as device-to-device memory copies or custom kernels.

For example, the following screenshots are from Nsight Systems.

The Layer Execution and the Kernel Being Launched on the CPU Side

The kernels themselves run on the GPU. The following image shows the correlation between the layer execution and kernel launch on the CPU side and their execution on the GPU side.

The Kernels Run on the GPU

Control the Level of Details in NVTX Tracing

By default, TensorRT only shows layer names in the NVTX markers. At the same time, users can control the level of details by setting the ProfilingVerbosity in the IBuilderConfig when the engine is built. For example, to disable NVTX tracing, set the ProfilingVerbosity to kNONE:

// C++
builderConfig->setProfilingVerbosity(ProfilingVerbosity::kNONE);

# Python
builder_config.profiling_verbosity = trt.ProfilingVerbosity.NONE

On the other hand, you can choose to allow TensorRT to print more detailed layer information in the NVTX markers, including input and output dimensions, operations, parameters, tactic numbers, and so on, by setting the ProfilingVerbosity to kDETAILED:

// C++
builderConfig->setProfilingVerbosity(ProfilingVerbosity::kDETAILED);

# Python
builder_config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED

Note

Enabling detailed NVTX markers increases the latency of enqueueV3() calls and could result in a performance drop if the performance depends on the latency of enqueueV3() calls.

Run Nsight Systems with trtexec

Below is an example of the commands to gather Nsight Systems profiles using the trtexec tool:

trtexec --onnx=foo.onnx --profilingVerbosity=detailed --saveEngine=foo.plan
nsys profile -o foo_profile --capture-range cudaProfilerApi trtexec --profilingVerbosity=detailed --loadEngine=foo.plan --warmUp=0 --duration=0 --iterations=50

The first command builds and serializes the engine to foo.plan. The second command runs the inference using foo.plan and generates a foo_profile.nsys-rep file, which can then be opened in the Nsight Systems user interface for visualization.

The --profilingVerbosity=detailed flag allows TensorRT to show more detailed layer information in the NVTX marking, and the --warmUp=0, --duration=0, and --iterations=50 flags allow you to control how many inference iterations to run. By default, trtexec runs inference for three seconds, which may result in a large output of the nsys-rep file.

If CUDA graphs are enabled, add the --cuda-graph-trace=node flag to the nsys command to see the per-kernel runtime information:

nsys profile -o foo_profile --capture-range cudaProfilerApi --cuda-graph-trace=node trtexec --profilingVerbosity=detailed --loadEngine=foo.plan --warmUp=0 --duration=0 --iterations=50 --useCudaGraph
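If you launch profiling runs from a script, the nsys command line above can be assembled programmatically. The following Python sketch only builds the argument list from the flags shown in this section; the helper name build_nsys_command is hypothetical, and the resulting list could be passed to subprocess.run on a machine where nsys and trtexec are installed.

```python
import shlex


def build_nsys_command(output_name, inference_cmd, cuda_graph=False):
    """Wrap an inference command in an nsys profile invocation."""
    cmd = ["nsys", "profile", "-o", output_name, "--capture-range", "cudaProfilerApi"]
    if cuda_graph:
        # Needed to see per-kernel information when CUDA graphs are used.
        cmd.append("--cuda-graph-trace=node")
    return cmd + list(inference_cmd)


trtexec_cmd = [
    "trtexec",
    "--profilingVerbosity=detailed",
    "--loadEngine=foo.plan",
    "--warmUp=0",
    "--duration=0",
    "--iterations=50",
    "--useCudaGraph",
]
nsys_cmd = build_nsys_command("foo_profile", trtexec_cmd, cuda_graph=True)
print(shlex.join(nsys_cmd))
```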

(Optional) Enable GPU Metrics Sampling in Nsight Systems

On discrete GPU systems, add the --gpu-metrics-device all flag to the nsys command to sample GPU metrics, including GPU clock frequencies, DRAM bandwidth, Tensor Core utilization, and so on. If the flag is added, these GPU metrics appear in the Nsight Systems web interface.

Profiling for DLA#

To profile DLA, add the --accelerator-trace nvmedia flag when using the NVIDIA Nsight Systems CLI or enable Collect other accelerators trace when using the user interface. For example, the following command can be used with the NVIDIA Nsight Systems CLI:

nsys profile -t cuda,nvtx,nvmedia,osrt --accelerator-trace=nvmedia  --show-output=true trtexec --loadEngine=alexnet_int8.plan --warmUp=0 --duration=0 --iterations=20

Here is an example report:

  • NvMediaDLASubmit submits a DLA task for each DLA subgraph. The task’s runtime can be found in the DLA timeline under Other accelerators trace.

  • Because GPU fallback was allowed, TensorRT automatically added some CUDA kernels, like permutationKernelPLC3 and copyPackedKernel, which are used for data reformatting.

  • EGLStream APIs were executed because TensorRT uses EGLStream for data transfer between GPU memory and DLA.

To maximize GPU utilization, trtexec enqueues the queries one batch beforehand.

Sample DLA Profiling Report

The runtime of the DLA task can be found under Other Accelerator API. Some CUDA kernels and EGLStream API are called for interaction between GPU and DLA.

Sample DLA Profiling Report

Tracking Memory#

Tracking memory usage can be as important as execution performance. Usually, the device’s memory is more constrained than the host’s. To keep track of device memory, the recommended mechanism is to create a simple custom GPU allocator that internally keeps some statistics and then uses the regular CUDA memory allocation functions cudaMalloc and cudaFree.

A custom GPU allocator can be set for the builder IBuilder for network optimizations and IRuntime when deserializing engines using the IGpuAllocator APIs. One idea for the custom allocator is to keep track of the current amount of memory allocated and push an allocation event with a timestamp and other information onto a global list of allocation events. Looking through the list of allocation events allows profiling memory usage over time.
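The bookkeeping half of such an allocator can be sketched in a few lines of Python. This is only an illustration under stated assumptions: the `AllocationTracker` and `AllocationEvent` names are hypothetical, the addresses are made up, and in a real application `on_alloc` and `on_free` would be called from your IGpuAllocator implementation around the actual `cudaMalloc` and `cudaFree` calls.

```python
import time
from dataclasses import dataclass, field


@dataclass
class AllocationEvent:
    timestamp: float  # seconds, from time.monotonic()
    kind: str         # "alloc" or "free"
    address: int      # device pointer value (made up in this sketch)
    size: int         # bytes
    in_use: int       # total bytes allocated right after this event


@dataclass
class AllocationTracker:
    """Bookkeeping only; the real allocation is done elsewhere (hypothetical helper)."""
    events: list = field(default_factory=list)
    live: dict = field(default_factory=dict)  # address -> size of live allocations
    in_use: int = 0
    peak: int = 0

    def on_alloc(self, address: int, size: int) -> None:
        self.live[address] = size
        self.in_use += size
        self.peak = max(self.peak, self.in_use)
        self.events.append(
            AllocationEvent(time.monotonic(), "alloc", address, size, self.in_use))

    def on_free(self, address: int) -> None:
        size = self.live.pop(address)
        self.in_use -= size
        self.events.append(
            AllocationEvent(time.monotonic(), "free", address, size, self.in_use))


tracker = AllocationTracker()
tracker.on_alloc(0x1000, 4096)
tracker.on_alloc(0x2000, 8192)
tracker.on_free(0x1000)
print(tracker.in_use, tracker.peak, len(tracker.events))
```

Walking `tracker.events` in order gives the memory usage over time, and `tracker.peak` gives the high-water mark.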

On mobile platforms, GPU memory and CPU memory share the system memory. On devices with very limited memory, like Nano, system memory might run out with large networks, even if the required GPU memory is smaller than the system memory. In this case, increasing the system swap size could solve some problems. An example script is:

echo "######alloc swap######"
if [ ! -e /swapfile ];then
    sudo fallocate -l 4G /swapfile
    sudo chmod 600 /swapfile
    sudo mkswap /swapfile
    sudo /bin/sh -c 'echo  "/swapfile \t none \t swap \t defaults \t 0 \t 0" >> /etc/fstab'
    sudo swapon -a
fi

Hardware/Software Environment for Performance Measurements#

Performance measurements are influenced by many factors, including hardware environment differences like the machine’s cooling capability and software environment differences like GPU clock settings. This section summarizes a few items that may affect performance measurements.

Note that the items involving nvidia-smi are only supported on dGPU systems, not mobile ones.

GPU Information Query and GPU Monitoring#

While measuring performance, it is recommended that you record and monitor the GPU status in parallel to the inference workload. Having the monitoring data allows you to identify possible root causes when you see unexpected performance measurement results.

Before the inference starts, call the nvidia-smi -q command to get detailed information on the GPU, including the product name, power cap, clock settings, etc. Then, while the inference workload is running, run the nvidia-smi dmon -s pcu -f <FILE> -c <COUNT> command in parallel to print out GPU clock frequencies, power consumption, temperature, and utilization to a file. Call nvidia-smi dmon --help for more options about the nvidia-smi device monitoring tool.
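Once the monitoring file has been written, a short script can summarize it. The following is a minimal sketch that assumes the file uses the usual dmon layout: a first `#` line with column names, a second `#` line with units, then whitespace-separated rows. The sample text is illustrative rather than captured from a real run, and the exact columns depend on your driver version and the `-s` options. The `parse_dmon` and `to_number` helpers are hypothetical names.

```python
SAMPLE_DMON_LOG = """\
# gpu    pwr  gtemp  mtemp     sm    mem   mclk   pclk
# Idx      W      C      C      %      %    MHz    MHz
    0     68     45      -     97     41   5000   1590
    0     70     47      -     98     42   5000   1575
"""


def to_number(token):
    """Convert a field to float; placeholders such as '-' become None."""
    try:
        return float(token)
    except ValueError:
        return None


def parse_dmon(text):
    """Return one dict per sample row, keyed by the names in the first '#' line."""
    columns = None
    rows = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith("#"):
            # The first '#' line holds column names; later ones (units, repeats) are skipped.
            if columns is None:
                columns = line.lstrip("#").split()
            continue
        values = line.split()
        if columns is None or len(values) != len(columns):
            continue  # skip rows that do not match the header
        rows.append({name: to_number(v) for name, v in zip(columns, values)})
    return rows


rows = parse_dmon(SAMPLE_DMON_LOG)
clocks = [row["pclk"] for row in rows]
print(min(clocks), max(clocks))
```

A large spread between the minimum and maximum clock during a run is a hint that the clock was floating or being throttled while you measured.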

GPU Clock Locking and Floating Clock#

By default, the GPU clock frequency is floating, meaning it sits at the idle frequency when there is no active workload and boosts to the boost clock frequency when the workload starts. This is usually the desired behavior since it allows the GPU to generate less heat at idle and to run at maximum speed when there is an active workload.

Alternatively, you can lock the clock at a specific frequency by calling the sudo nvidia-smi -lgc <freq> command (and conversely, you can let the clock float again with the sudo nvidia-smi -rgc command). The sudo nvidia-smi -q -d SUPPORTED_CLOCKS command lists the supported clock frequencies. After the clock frequency is locked, it should stay at that frequency unless power or thermal throttling occurs, which will be explained in the next sections. When throttling kicks in, the device behaves as if the clock were floating.

Running TensorRT workloads with floating clocks or with throttling taking place can lead to more non-determinism in tactic selections and unstable performance measurements across inferences because every CUDA kernel may run at slightly different clock frequencies, depending on which frequency the driver boosts or throttles the clock to at that moment. On the other hand, running TensorRT workloads with locked clocks allows more deterministic tactic selections and consistent performance measurements. Still, the average performance will not be as good as when the clock is floating or is locked at maximum frequency with throttling taking place.

There is no definite recommendation on whether the clock should be locked or at which clock frequency to lock the GPU while running TensorRT workloads. It depends on whether deterministic and stable performance or the best average performance is desired.

GPU Power Consumption and Power Throttling#

Power throttling occurs when the average GPU power consumption reaches the power limit, which can be set by the sudo nvidia-smi -pl <power_cap> command. When this happens, the driver has to throttle the clock to a lower frequency to keep the average power consumption below the limit. The constantly changing clock frequencies may lead to unstable performance measurements if the measurements are taken within a short time, such as within 20ms.

Power throttling happens by design and is a natural phenomenon when the GPU clock is not locked or is locked at a higher frequency, especially for GPUs with lower power limits, such as NVIDIA T4 and NVIDIA A2 GPUs. To avoid performance variations caused by power throttling, you can lock the GPU clock at a lower frequency to stabilize the performance numbers. However, the average performance numbers will be lower than those with floating clocks or with the clock locked at a higher frequency, even though power throttling would happen in those cases.

Another issue with power throttling is that it may skew the performance numbers if there are gaps between inferences in your performance benchmarking applications. For example, if the application synchronizes at each inference, there will be periods when the GPU is idle between the inferences. The gaps cause the GPU to consume less power on average, so the clock is throttled less, and the GPU can run at higher clock frequencies on average. However, the throughput numbers measured this way are inaccurate because when the GPU is fully loaded with no gaps between inferences, the actual clock frequency will be lower, and the actual throughput will not reach the throughput numbers measured using the benchmarking application.

To avoid this, the trtexec tool is designed to maximize GPU execution by leaving nearly no gaps between GPU kernel executions so that it can measure the true throughput of a TensorRT workload. Therefore, if you see performance gaps between your benchmarking application and what trtexec reports, check if the power throttling and the gaps between inferences are the cause.

Lastly, power consumption can depend on the activation values, causing different performance measurements for different inputs. For example, if all the network input values are set to zeros or NaNs, the GPU consumes less power than when the inputs are normal values because of fewer bit-flips in DRAM and the L2 cache. To avoid this discrepancy, always use input values that best represent the actual value distribution when measuring the performance. The trtexec tool uses random input values by default, but you can specify the input using the --loadInputs flag. For more information, refer to the trtexec section.

GPU Temperature and Thermal Throttling#

Thermal throttling happens when the GPU temperature reaches a predefined threshold of around 85 degrees Celsius for most GPUs, and the driver has to throttle the clock to a lower frequency to prevent the GPU from overheating. You can tell this by seeing the temperature logged by the nvidia-smi dmon command gradually increasing while the inference workload runs until it reaches ~85C and the clock frequency drops.

If thermal throttling happens on actively cooled GPUs like Quadro A8000, then it is possible that the fans on the GPU are broken or obstacles are blocking the airflow.

If thermal throttling happens on passively cooled GPUs like NVIDIA A10, then it is likely that the GPUs are not properly cooled. Passively cooled GPUs require external fans or air conditioning to cool down the GPUs, and the airflow must go through the GPUs for effective cooling. Common cooling problems include installing GPUs in a server that is not designed for them or installing the wrong number of GPUs in the server. In some cases, the air flows through the “easy path” (the path with the least friction) around the GPUs instead of going through them. Fixing this requires examining the airflow in the server and installing airflow guidance if necessary.

Note that higher GPU temperature also leads to more leakage current in the circuits, which increases the power consumed by the GPU at a specific clock frequency. Therefore, for GPUs more likely to be power throttled like NVIDIA T4, poor cooling can lead to lower stabilized clock frequency with power throttling and, thus, worse performance, even if the GPU clocks have not been thermally throttled yet.

On the other hand, ambient temperature, the environment’s temperature around the server, does not usually affect GPU performance as long as the GPUs are properly cooled, except for GPUs with lower power limits whose performance may be slightly affected.

H2D/D2H Data Transfers and PCIe Bandwidth#

On dGPU systems, the input data must often be copied from the host memory to the device memory (H2D) before an inference starts, and the output data must be copied back from the device memory to the host memory (D2H) after the inference. These H2D/D2H data transfers go through PCIe buses, and they can sometimes influence the inference performance or even become the performance bottleneck. The H2D/D2H copies can also be seen in the Nsight Systems profiles, appearing as cudaMemcpy() or cudaMemcpyAsync() CUDA API calls.

To achieve maximum throughput, the H2D/D2H data transfers should run parallel to the GPU executions of other inferences so that the GPU does not sit idle when the H2D/D2H copies occur. This can be done by running multiple inferences in parallel streams or launching H2D/D2H copies in a different stream than the stream used for GPU executions and using CUDA events to synchronize between the streams. The trtexec tool shows an example of the latter implementation.

When the H2D/D2H copies run parallel to GPU executions, they can interfere with the GPU executions, especially if the host memory is pageable, which is the default case. Therefore, it is recommended that you allocate pinned host memory for the input and output data using cudaHostAlloc() or cudaMallocHost() CUDA APIs.

To check whether the PCIe bandwidth becomes the performance bottleneck, you can check the Nsight Systems profiles and see if the H2D or D2H copies of an inference query have longer latencies than the GPU execution part. If PCIe bandwidth becomes the performance bottleneck, here are a few possible solutions.

First, check whether the PCIe bus configuration of the GPU is correct in terms of which generation (for example, Gen3 or Gen4) and how many lanes (for example, x8 or x16) are used. Next, reduce the amount of data that must be transferred using the PCIe bus. For example, suppose the input images have high resolutions, and the H2D copies become the bottleneck. In that case, you can transmit JPEG-compressed images over the PCIe bus and decode the image on the GPUs before the inference workflow instead of transmitting raw pixels. Finally, consider using NVIDIA GPUDirect technology to load data directly from/to the network or the filesystems without going through the host memory.
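Before profiling, a back-of-the-envelope estimate can tell you whether the copies are likely to matter. The sketch below computes the size of a batch-4 ResNet-50 FP32 input and the time needed to move it at an assumed effective bandwidth. The bandwidth and the GPU compute time are placeholders, not measurements; substitute the values observed on your system. The helper names are hypothetical.

```python
from math import prod


def tensor_bytes(shape, bytes_per_element):
    """Size in bytes of a dense tensor with the given shape."""
    return prod(shape) * bytes_per_element


def transfer_time_ms(num_bytes, bandwidth_gb_per_s):
    """Time in milliseconds to move num_bytes at the given bandwidth (1 GB = 1e9 bytes)."""
    return num_bytes / (bandwidth_gb_per_s * 1e9) * 1e3


# Assumptions for illustration only: replace with values measured on your system.
assumed_bandwidth_gb_per_s = 12.0
assumed_gpu_compute_ms = 1.5

input_bytes = tensor_bytes((4, 3, 224, 224), 4)  # batch 4, FP32 input
h2d_ms = transfer_time_ms(input_bytes, assumed_bandwidth_gb_per_s)
copy_is_bottleneck = h2d_ms > assumed_gpu_compute_ms
print(input_bytes, round(h2d_ms, 3), copy_is_bottleneck)
```

If the estimated copy time approaches or exceeds the GPU execution time, confirm it in the Nsight Systems profile before applying the remedies above.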

In addition, if your system has AMD x86_64 CPUs, check the machine’s NUMA (Non-Uniform Memory Access) configuration with the numactl --hardware command. The PCIe bandwidth between host memory and device memory located on two different NUMA nodes is much more limited than the bandwidth between host and device memory located on the same NUMA node. Allocate the host memory on the same NUMA node as the GPU to which the data will be copied. Also, pin the CPU threads that trigger the H2D/D2H copies to that NUMA node.

Note that the host and the device share the same memory on mobile platforms, so the H2D/D2H data transfers are not required if the host memory is allocated using CUDA APIs and is pinned instead of pageable.

By default, the trtexec tool measures the latencies of the H2D/D2H data transfers, which tells the user if the H2D/D2H copies may bottleneck the TensorRT workload. However, if the H2D/D2H copies affect the stability of the GPU Compute Time, you can add the --noDataTransfers flag to disable H2D/D2H transfers and measure only the latencies of the GPU execution part.

TCC Mode and WDDM Mode#

On Windows machines, there are two driver modes: you can configure the GPU to be in either TCC mode or WDDM mode. The mode can be specified by calling the sudo nvidia-smi -dm [0|1] command, but a GPU connected to a display must not be configured into TCC mode. For more information, refer to the TCC mode documentation.

In TCC mode, the GPU is configured to focus on computation work, and graphics support like OpenGL or monitor display is disabled. This is the recommended mode for GPUs that run TensorRT inference workloads. On the other hand, the WDDM mode tends to cause GPUs to have worse and unstable performance results when running inference workloads using TensorRT.

This does not apply to Linux-based OS.

Enqueue-Bound Workloads and CUDA Graphs#

The enqueueV3() function of IExecutionContext is asynchronous. That is, it returns immediately after all the CUDA kernels are launched without waiting for the completion of CUDA kernel executions. However, in some cases, the enqueueV3() time can take longer than the actual GPU executions, causing the latency of enqueueV3() calls to become the performance bottleneck. We say that this type of workload is “enqueue-bound.” Two reasons may cause a workload to be enqueue-bound.

First, if the workload is very tiny in terms of the number of computations, such as containing convolutions with small I/O sizes, matrix multiplications with small GEMM sizes, or mostly element-wise operations throughout the network, then the workload tends to be enqueue-bound. This is because most CUDA kernels take the CPU and the driver around 5-15 microseconds to launch per kernel, so if each CUDA kernel execution time is only several microseconds long on average, the kernel launching time becomes the main performance bottleneck.

To solve this, try increasing the computation per CUDA kernel by increasing the batch size. You can also use CUDA Graphs to capture the kernel launches into a graph and launch the graph instead of calling enqueueV3().

Second, the workload is naturally enqueue-bound if it contains operations requiring device synchronizations, such as loops or if-else conditions. Increasing the batch size may help improve the throughput without increasing the latency.

In trtexec, you can tell that a workload is enqueue-bound if the reported Enqueue Time is close to or longer than the reported GPU Compute Time. In this case, it is recommended that you add the --useCudaGraph flag to enable CUDA graphs in trtexec, which will reduce the Enqueue Time as long as the workload does not contain any synchronization operations.
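The rule of thumb above can be written down as a small check. In this sketch the `closeness` threshold of 0.9 is an arbitrary assumption for what "close to" means, and the kernel count and per-kernel times are made-up numbers chosen from the ranges quoted earlier in this section.

```python
def is_enqueue_bound(enqueue_ms, gpu_compute_ms, closeness=0.9):
    """True if Enqueue Time is close to or longer than GPU Compute Time.

    `closeness` is an assumed threshold, not a value defined by trtexec.
    """
    return enqueue_ms >= closeness * gpu_compute_ms


# Hypothetical tiny network: 200 kernels, 10 us to launch each, 4 us to execute each.
num_kernels = 200
launch_total_ms = num_kernels * 10 / 1000   # CPU-side launch cost
compute_total_ms = num_kernels * 4 / 1000   # GPU-side execution time

tiny_network_bound = is_enqueue_bound(launch_total_ms, compute_total_ms)
large_network_bound = is_enqueue_bound(0.3, 1.5)
print(launch_total_ms, compute_total_ms, tiny_network_bound, large_network_bound)
```

In the first case the launch cost alone exceeds the GPU time, which is the situation where a larger batch size or CUDA graphs helps most.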

BlockingSync and SpinWait Synchronization Modes#

If performance is measured with cudaStreamSynchronize() or cudaEventSynchronize(), synchronization overhead variations may lead to performance measurement variations. This section describes the causes of the variations and how to avoid them.

When cudaStreamSynchronize() is called, there are two ways in which the driver waits until the stream is completed. If the cudaDeviceScheduleBlockingSync flag has been set with a cudaSetDeviceFlags() call, then cudaStreamSynchronize() uses the blocking-sync mechanism. Otherwise, it uses the spin-wait mechanism.

A similar idea applies to CUDA events. If a CUDA event is created with the cudaEventDefault flag, then the cudaEventSynchronize() call uses the spin-wait mechanism. If a CUDA event is created with the cudaEventBlockingSync flag, then the cudaEventSynchronize() call will use the blocking-sync mechanism.

When the blocking-sync mode is used, the host thread yields to another thread until the device work is done. This allows the CPUs to sit idle to save power or to be used by other CPU workloads when the device is still executing. However, the blocking-sync mode tends to result in relatively unstable overheads in stream/event synchronizations in some OS, leading to variations in latency measurements.

On the other hand, when the spin-wait mode is used, the host thread is constantly polling until the device work is done. Using spin-wait makes the latency measurements more stable due to shorter and more stable overhead in stream/event synchronizations. Still, it consumes some CPU computation resources and leads to more power consumption by the CPUs.

Therefore, if you want to reduce CPU power consumption or do not want the stream/event synchronizations to consume CPU resources (for example, you are running other heavy CPU workloads in parallel), use the blocking-sync mode. If you care more about stable performance measurements, use the spin-wait mode.

In trtexec, the default synchronization mechanism is the blocking-sync mode. Add the --useSpinWait flag to enable synchronization using the spin-wait mode for more stable latency measurements at the cost of higher CPU utilization and power consumption.

Optimizing TensorRT Performance#

The following sections focus on the general inference flow on GPUs and some general strategies to improve performance. These ideas apply to most CUDA programmers but may not be as obvious to developers from other backgrounds.

Batching#

The most important optimization is to compute as many results in parallel as possible using batching. In TensorRT, a batch is a collection of inputs that can all be processed uniformly. Each instance in the batch has the same shape and flows through the network similarly. Therefore, each instance can be trivially computed in parallel.

Each network layer has some overhead and synchronization required to compute forward inference. By computing more results in parallel, this overhead is amortized more efficiently. In addition, many layers are performance-limited by the smallest dimension in the input. If the batch size is one or small, this size can often be the performance-limiting dimension. For example, the fully connected layer with V inputs and K outputs can be implemented for one batch instance as a multiplication of a 1xV matrix with a VxK weight matrix. If N instances are batched, this becomes an NxV matrix multiplied by the VxK matrix. The vector-matrix multiplication becomes a matrix-matrix multiplication, which is much more efficient.
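The shape argument can be checked with a tiny pure-Python example. It says nothing about speed (a plain Python loop is not representative of GPU kernels); it only shows that N separate 1xV by VxK products give the same result as one NxV by VxK product. The `matmul` helper and the numbers are made up for illustration.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    rows, inner, cols = len(a), len(b), len(b[0])
    return [[sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(cols)]
            for i in range(rows)]


weights = [[1, 2], [3, 4], [5, 6]]                         # V = 3 inputs, K = 2 outputs
instances = [[1, 0, 0], [0, 1, 0], [1, 1, 1], [2, 0, 1]]   # N = 4 batch instances

# N vector-matrix products: one (1 x V) @ (V x K) per instance.
one_by_one = [matmul([x], weights)[0] for x in instances]

# One matrix-matrix product: (N x V) @ (V x K).
batched = matmul(instances, weights)

print(batched == one_by_one, batched)
```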

Larger batch sizes are almost always more efficient on the GPU. Extremely large batches, such as N > 2^16, can sometimes require extended index computation and should be avoided if possible. But generally, increasing the batch size improves total throughput. In addition, when the network contains MatrixMultiply layers, batch sizes of multiples of 32 tend to have the best performance for FP16 and INT8 inference because of the utilization of Tensor Cores if the hardware supports them.

On NVIDIA Ada Lovelace or later GPUs, decreasing the batch size may improve the throughput significantly if the smaller batch sizes help the GPU cache the input/output values in the L2 cache. Therefore, various batch sizes should be tried to find the batch size that provides optimal performance.

Sometimes, batching inference work is impossible due to the application’s organization. In some common applications, such as a server that makes inferences per request, it is possible to implement opportunistic batching. For each incoming request, wait for a time T. If other requests come in, batch them together. Otherwise, continue with a single-instance inference. This strategy adds fixed latency to each request but can greatly improve the system’s maximum throughput.
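The opportunistic batching strategy can be sketched with the Python standard library. This is a minimal illustration, not a production server: `collect_batch` is a hypothetical helper, the requests are plain integers, and the wait time and maximum batch size are arbitrary values.

```python
import queue
import time


def collect_batch(requests, max_batch_size, wait_seconds):
    """Block for the first request, then wait up to wait_seconds for more."""
    batch = [requests.get()]
    deadline = time.monotonic() + wait_seconds
    while len(batch) < max_batch_size:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break  # time T has elapsed: run with what we have
        try:
            batch.append(requests.get(timeout=remaining))
        except queue.Empty:
            break  # no further requests arrived within T
    return batch


pending = queue.Queue()
for request_id in range(5):
    pending.put(request_id)

first = collect_batch(pending, max_batch_size=4, wait_seconds=0.05)
second = collect_batch(pending, max_batch_size=4, wait_seconds=0.05)
print(first, second)
```

The first call fills a full batch immediately, while the second waits for the timeout and then proceeds with a single request, which is the fixed latency cost mentioned above.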

The NVIDIA Triton Inference Server provides a simple way to enable dynamic batching with TensorRT engines.

Using Batching

The batch dimension is part of the tensor dimensions, and you can specify the range of the batch sizes and the batch size to optimize the engine by adding optimization profiles. For more information, refer to the Working With Dynamic Shapes section.

Within-Inference Multi-Streaming#

In general CUDA programming, streams are a way of organizing asynchronous work. Asynchronous commands put into a stream are guaranteed to run in sequence but may execute out of order with respect to other streams. In particular, asynchronous commands in two streams may be scheduled to run concurrently (subject to hardware limitations).

In the context of TensorRT and inference, each layer of the optimized final network will require work on the GPU. However, not all layers can fully use the hardware’s computation capabilities. Scheduling requests in separate streams allows work to be scheduled immediately as the hardware becomes available without unnecessary synchronization. Even if only some layers can be overlapped, overall performance will improve.

Use the IBuilderConfig::setMaxAuxStreams() API to set the maximum number of auxiliary streams TensorRT can use to run multiple layers in parallel. The auxiliary streams are in contrast to the “main stream” provided in the enqueueV3() call. If enabled, TensorRT will run some layers on the auxiliary streams in parallel with those running on the main stream.

For example, to run the inference on at most eight streams (that is, seven auxiliary streams and one main stream) in total:

C++:

config->setMaxAuxStreams(7);

Python:

config.max_aux_streams = 7

Note that this only sets the maximum number of auxiliary streams. TensorRT may use fewer auxiliary streams than this number if it determines that using more streams does not help.

To get the actual number of auxiliary streams that TensorRT uses for an engine, run the following:

C++:

int32_t nbAuxStreams = engine->getNbAuxStreams();

Python:

num_aux_streams = engine.num_aux_streams

When an execution context is created from the engine, TensorRT automatically creates the auxiliary streams needed to run the inference. However, you can also specify the auxiliary streams you would like TensorRT to use:

C++:

int32_t nbAuxStreams = engine->getNbAuxStreams();
std::vector<cudaStream_t> streams(nbAuxStreams);
for (int32_t i = 0; i < nbAuxStreams; ++i)
{
    cudaStreamCreate(&streams[i]);
}
context->setAuxStreams(streams.data(), nbAuxStreams);

Python:

from cuda import cudart
num_aux_streams = engine.num_aux_streams
streams = []
for i in range(num_aux_streams):
    err, stream = cudart.cudaStreamCreate()
    streams.append(stream)
context.set_aux_streams(streams)

TensorRT will always insert event synchronizations between the main stream provided in the enqueueV3() call and the auxiliary streams:

  • At the beginning of the enqueueV3() call, TensorRT will ensure that all the auxiliary streams wait on the activities on the main stream.

  • At the end of the enqueueV3() call, TensorRT will ensure that the main stream waits for the activities on all the auxiliary streams.

Enabling auxiliary streams may increase memory consumption because some activation buffers can no longer be reused.

Cross-Inference Multi-Streaming#

In addition to the within-inference streaming, you can enable streaming between multiple execution contexts. For example, you can build an engine with multiple optimization profiles and create an execution context per profile. Then, call the enqueueV3() function of the execution contexts on different streams to allow them to run in parallel.

Running multiple concurrent streams often leads to several streams sharing compute resources simultaneously. This means the network may have fewer compute resources available during inference than when the TensorRT engine was optimized. This difference in resource availability can cause TensorRT to choose a suboptimal kernel for the actual runtime conditions. To mitigate this effect, you can limit the amount of available compute resources during engine creation to resemble actual runtime conditions more closely. This approach generally promotes throughput at the expense of latency. For more information, refer to the Limiting Compute Resources section.

It is also possible to use multiple host threads with streams. A common pattern is incoming requests dispatched to a pool of worker threads waiting for work. In this case, the pool of worker threads will each have one execution context and CUDA stream. Each thread will request work in its stream as the work becomes available. Each thread will synchronize with its stream to wait for results without blocking other worker threads.
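The threading structure of this pattern can be sketched without any GPU code. In the following illustration, `fake_infer` is a hypothetical stand-in for calling enqueueV3() on the worker's own stream and then synchronizing that stream; in a real application each worker would hold its own execution context and CUDA stream. The worker count and the request values are arbitrary.

```python
import queue
import threading


def fake_infer(worker_id, request):
    """Stand-in for enqueueV3() on this worker's stream followed by a stream sync."""
    return (worker_id, request * 2)


def worker(worker_id, requests, results):
    # Each worker would own one execution context and one CUDA stream here.
    while True:
        request = requests.get()
        if request is None:  # sentinel: no more work, shut down
            break
        results.put(fake_infer(worker_id, request))


requests, results = queue.Queue(), queue.Queue()
threads = [threading.Thread(target=worker, args=(i, requests, results))
           for i in range(3)]
for thread in threads:
    thread.start()

for request in range(6):
    requests.put(request)
for _ in threads:
    requests.put(None)  # one sentinel per worker
for thread in threads:
    thread.join()

outputs = sorted(results.get()[1] for _ in range(6))
print(outputs)
```

Because each worker blocks only on its own work, a slow request on one worker does not prevent the others from picking up new requests.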

CUDA Graphs#

CUDA Graphs represent a sequence (or, more generally, a graph) of kernels in a way that allows CUDA to optimize their scheduling. This can be particularly useful when your application performance is sensitive to the CPU time to queue the kernels.

Using CUDA Graphs with TensorRT Execution Context#

TensorRT’s enqueueV3() method supports CUDA graph capture for models requiring no mid-pipeline CPU interaction. For example:

C++:

// Call enqueueV3() once after an input shape change to update internal state.
context->enqueueV3(stream);

// Capture a CUDA graph instance.
cudaGraph_t graph;
cudaGraphExec_t instance;
cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
context->enqueueV3(stream);
cudaStreamEndCapture(stream, &graph);
cudaGraphInstantiate(&instance, graph, 0);

// To run inferences, launch the graph instead of calling enqueueV3().
for (int i = 0; i < iterations; ++i) {
    cudaGraphLaunch(instance, stream);
    cudaStreamSynchronize(stream);
}

Python:

from cuda import cudart
err, stream = cudart.cudaStreamCreate()

# Call execute_async_v3() once after an input shape change to update internal state.
context.execute_async_v3(stream)

# Capture a CUDA graph instance.
cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal)
context.execute_async_v3(stream)
err, graph = cudart.cudaStreamEndCapture(stream)
err, instance = cudart.cudaGraphInstantiate(graph, 0)

# To run inferences, launch the graph instead of calling execute_async_v3().
for i in range(iterations):
    cudart.cudaGraphLaunch(instance, stream)
    cudart.cudaStreamSynchronize(stream)

Limitations of CUDA Graphs#

CUDA graphs cannot handle some operations, so graph capturing may fail if the execution context contains such operations. Typical deep learning operators unsupported by CUDA graphs include loops, conditionals, and layers requiring data-dependent shapes. In these cases, cudaStreamEndCapture() will return cudaErrorStreamCapture* errors, indicating that the graph capturing has failed, but the context can continue to be used for normal inference without CUDA graphs. Refer to the CUDA Programming Guide to learn more about the limitations of CUDA graphs.

Also, when capturing a graph, it is important to account for the two-phase execution strategy used in the presence of dynamic shapes.

  1. Update the model’s internal state to account for any changes in input size.

  2. Stream work to the GPU.

The first phase requires no per-invocation work for models where input size is fixed at build time. Otherwise, if the input sizes have changed since the last invocation, some work may be required to update derived properties.

The first phase of work is not designed to be captured, and even if the capture is successful, it may increase model execution time. Therefore, after changing the shapes of inputs or the values of shape tensors, call enqueueV3() once to flush deferred updates before capturing the graph.

Graphs captured with TensorRT are specific to the input size and the state of the execution context. Modifying the context from which the graph was captured will result in undefined behavior when executing the graph. In particular, if the application is providing its own memory for activations using createExecutionContextWithoutDeviceMemory(), the memory address is also captured as part of the graph. Locations of input and output buffers are also captured as part of the graph.

Therefore, the best practice is to use one execution context per captured graph and to share memory across the contexts with createExecutionContextWithoutDeviceMemory().

trtexec allows you to check whether the TensorRT engine you built is compatible with CUDA graph capture. For more information, refer to the trtexec section.

Concurrent CUDA Activities with CUDA Graph Capture#

Launching a CUDA kernel on the CUDA legacy default stream or calling synchronous CUDA APIs like cudaMemcpy() while capturing a CUDA graph fails because these CUDA activities implicitly synchronize the CUDA streams used by TensorRT execution contexts.

To avoid breaking the CUDA graph capture, ensure other CUDA kernels are launched on non-default CUDA streams and use the asynchronous version of CUDA APIs, like cudaMemcpyAsync().

Alternatively, a CUDA stream can be created with the cudaStreamNonBlocking flag to capture the CUDA graph for an execution context. If the execution context uses auxiliary streams, make sure you also call the setAuxStreams() API using streams created with the cudaStreamNonBlocking flag. Refer to the Within-Inference Multi-Streaming section about how to set auxiliary streams in TensorRT execution contexts.

Enabling Fusion#

Layer Fusion#

TensorRT attempts to perform many different types of optimizations in a network during the build phase. In the first phase, layers are fused whenever possible. Fusions transform the network into a simpler form but preserve the same overall behavior. Internally, many layer implementations have extra parameters and options that are not directly accessible when creating the network. Instead, the fusion optimization step detects supported patterns of operations and fuses multiple layers into one layer with the internal options set.

Consider the common case of a convolution followed by ReLU activation. Creating a network with these operations involves adding a Convolution layer with addConvolutionNd and following it with an Activation layer using addActivation with an ActivationType of kRELU. The unoptimized graph will contain separate layers for convolution and activation. The internal implementation of convolution supports computing the ReLU function on the output in one step directly from the convolution kernel without requiring a second kernel call. The fusion optimization step will detect the convolution followed by ReLU, verify that the implementation supports the operations, and then fuse them into one layer.

To investigate which fusions have occurred, the builder logs its operations to the logger object provided during construction. Optimization steps are at the kINFO log level. To see these messages, ensure you log them in the ILogger callback.

Fusions are normally handled by creating a new layer with a name containing the names of both of the layers that were fused. For example, a MatrixMultiply layer (InnerProduct) named ip1 is fused with a ReLU Activation layer named relu1 to create a new layer named ip1 + relu1.

Types of Fusions#

The following list describes the types of supported fusions.

Supported Layer Fusions

  • ReLU Activation: An Activation layer performing ReLU followed by another Activation layer performing ReLU will be replaced by a single Activation layer.

  • Convolution and ReLU Activation: The Convolution layer can be of any type, and values are not restricted. The Activation layer must be of the ReLU type.

  • Convolution and GELU Activation: The input and output precision should be the same, with both of them FP16 or INT8. The Activation layer must be GELU type. TensorRT should run on an NVIDIA Turing or later device with CUDA version 10.0 or later.

  • Convolution and Clip Activation: The Convolution layer can be any type, and values are not restricted. The Activation layer must be Clip type.

  • Scale and Activation: The Scale layer, followed by an Activation layer, can be fused into a single Activation layer.

  • Convolution and ElementWise Operation: A Convolution layer followed by a simple sum, min, or max in an ElementWise layer can be fused into the Convolution layer. The sum must not use broadcasting unless the broadcasting is across the batch size.

  • Padding and Convolution/Deconvolution: If all the padding sizes are non-negative, padding followed by a Convolution or Deconvolution can be fused into a single Convolution/Deconvolution layer.

  • Shuffle and Reduce: A Shuffle layer without reshaping, followed by a Reduce layer, can be fused into a single Reduce layer. The Shuffle layer can perform permutations but cannot perform any reshape operation. The Reduce layer must have keepDimensions set.

  • Shuffle and Shuffle: Each Shuffle layer consists of a transpose, a reshape, and a second transpose. A Shuffle layer followed by another can be replaced by a single Shuffle (or nothing). If both Shuffle layers perform reshape operations, this fusion is only allowed if the second transpose of the first shuffle is the inverse of the first transpose of the second shuffle.

  • Scale: A Scale layer that adds 0, multiplies by 1, or raises to the power of 1 can be erased.

  • Convolution and Scale: A Convolution layer followed by a Scale layer that is kUNIFORM or kCHANNEL can be fused into a single convolution by adjusting the convolution weights. This fusion is disabled if the scale has a non-constant power parameter.

  • Convolution and Generic Activation: This fusion happens after the pointwise fusion mentioned below. A pointwise with one input and output can be called a generic activation layer. A convolution layer followed by a generic activation layer can be fused into a single convolution layer.

  • Reduce: A Reduce layer that performs average pooling will be replaced by a Pooling layer. The Reduce layer must have keepDimensions set and must reduce across the H and W dimensions of a CHW input format (before batching) using the kAVG operation.

  • Convolution and Pooling: The Convolution and Pooling layers must have the same precision. The Convolution layer may already have a fused activation operation from a previous fusion.

  • Depthwise Separable Convolution: A depthwise convolution with activation followed by a convolution with activation may sometimes be fused into a single optimized DepSepConvolution layer. The precision of both convolutions must be INT8, and the device’s compute capability must be 7.2 or later.

  • Softmax and Log: A Softmax layer followed by a Log operation can be fused into a single Softmax layer if the Softmax has not already been fused with a previous log operation.

  • Softmax and TopK: A Softmax layer followed by a TopK layer can be fused into a single layer. The Softmax may or may not include a Log operation.
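To make the Convolution and Scale entry more concrete, the following is a minimal numerical sketch of how a per-channel scale can be folded into convolution weights. It uses a 1x1 convolution on a single pixel in plain Python, assumes the Scale layer has a power of 1, and all function names and values are illustrative rather than TensorRT APIs.

```python
def conv1x1(x, weights, bias):
    """1x1 convolution at one pixel: x has C_in values, weights is C_out x C_in."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weights, bias)]


def scale(y, scale_c, shift_c):
    """Per-channel scale (kCHANNEL-like, power assumed to be 1): y * scale + shift."""
    return [v * s + t for v, s, t in zip(y, scale_c, shift_c)]


def fold_scale_into_conv(weights, bias, scale_c, shift_c):
    """Return new conv weights/bias that already include the scale and shift."""
    folded_w = [[w * s for w in row] for row, s in zip(weights, scale_c)]
    folded_b = [b * s + t for b, s, t in zip(bias, scale_c, shift_c)]
    return folded_w, folded_b


# Hypothetical values chosen only for illustration.
x = [1.0, -2.0, 0.5]
weights = [[0.2, -0.1, 0.4], [0.3, 0.5, -0.6]]
bias = [0.1, -0.2]
scale_c = [2.0, 0.5]
shift_c = [0.25, -1.0]

unfused = scale(conv1x1(x, weights, bias), scale_c, shift_c)
folded_w, folded_b = fold_scale_into_conv(weights, bias, scale_c, shift_c)
fused = conv1x1(x, folded_w, folded_b)
max_diff = max(abs(a - b) for a, b in zip(unfused, fused))
```

The two paths agree up to floating-point rounding, which is why the Scale layer can disappear from the optimized graph when its parameters are constant.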

Supported Reduction Operation Fusions

  • GELU: A group of Unary and ElementWise layers representing the following equations can be fused into a single GELU reduction operation.

    \(0.5x \times \left( 1+\tanh\left( \sqrt{\frac{2}{\pi}}\left( x+0.044715x^{3} \right) \right) \right)\)

    Or the alternative representation:

    \(0.5x \times \left( 1+erf\left( \frac{x}{\sqrt{2}} \right) \right)\)

  • L1Norm: A Unary layer kABS operation followed by a Reduce layer kSUM operation can be fused into a single L1Norm reduction operation.

  • Sum of Squares: A product ElementWise layer with the same input (square operation) followed by a kSUM reduction can be fused into a single square sum reduction operation.

  • L2Norm: A sum of squares operation followed by a kSQRT UnaryOperation can be fused into a single L2Norm reduction operation.

  • LogSum: A Reduce layer kSUM followed by a kLOG UnaryOperation can be fused into a single LogSum reduction operation.

  • LogSumExp: A Unary kEXP ElementWise operation followed by a LogSum fusion can be fused into a single LogSumExp reduction operation.
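As a quick sanity check on the two GELU representations listed above, the following sketch evaluates both in plain Python. It assumes the standard tanh approximation with a coefficient of sqrt(2/pi); the function names and the sample range are illustrative only.

```python
import math


def gelu_erf(x):
    """GELU written with the error function."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    """GELU written with the tanh approximation (coefficient sqrt(2/pi))."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


# Compare the two forms on an arbitrary sample grid in [-5, 5].
samples = [i / 10.0 for i in range(-50, 51)]
max_gap = max(abs(gelu_erf(x) - gelu_tanh(x)) for x in samples)
```

The two forms stay close over this range, so either subgraph of Unary and ElementWise layers expresses the same activation for fusion purposes.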

Pointwise Fusion#

Multiple adjacent Pointwise layers can be fused into a single Pointwise layer to improve performance.

The following types of Pointwise layers are supported, with some limitations:

  • Activation: Every ActivationType is supported.

  • Constant: Only constant with a single value (size == 1).

  • ElementWise: Every ElementWiseOperation is supported.

  • Pointwise: Pointwise itself is also a Pointwise layer.

  • Scale: Only ScaleMode::kUNIFORM is supported.

  • Unary: Every UnaryOperation is supported.

The size of the fused Pointwise layer is not unlimited, so some layers may not be fused.

Fusion creates a new layer with a name consisting of both fused layers. For example, an ElementWise layer named add1 is fused with a ReLU Activation layer named relu1, creating a new layer named fusedPointwiseNode(add1, relu1).
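The idea behind pointwise fusion can be illustrated with a small conceptual sketch. This is not how TensorRT implements it; it only shows that two element-wise passes with an intermediate tensor can be replaced by one pass, and it mirrors the fused-layer naming described above. The helper names are hypothetical.

```python
def unfused_add_relu(a, b):
    """Two separate passes: the add writes an intermediate that the ReLU reads back."""
    tmp = [x + y for x, y in zip(a, b)]
    return [max(v, 0.0) for v in tmp]


def fused_add_relu(a, b):
    """One pass over the data with no intermediate tensor."""
    return [max(x + y, 0.0) for x, y in zip(a, b)]


def fused_pointwise_name(*layer_names):
    """Build a name in the style the builder log uses for fused pointwise layers."""
    return "fusedPointwiseNode(" + ", ".join(layer_names) + ")"


a = [1.0, -3.0, 0.5]
b = [-2.0, 1.0, 0.25]
same_result = unfused_add_relu(a, b) == fused_add_relu(a, b)
name = fused_pointwise_name("add1", "relu1")
```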

Q/DQ Fusion#

Refer to the Explicit Quantization section for suggestions on optimizing INT8 and FP8 networks containing QuantizeLinear and DequantizeLinear layers.

Multi-Head Attention Fusion#

We highly recommend tailoring your model to the restrictions listed in this section so that Multi-Head Attention (MHA) fusion happens. MHA fusion is important because it supports large sequence lengths by significantly reducing the memory footprint from O(S^2) to O(S), where S is the sequence length. On top of that, it shares the common performance benefits of operator fusion: reduced memory traffic, better hardware utilization, and lower kernel launch and synchronization overhead.

Multi-head attention (MHA) computes softmax(Q * K^T / scale + mask) * V, where:

  • Q is query embedding

  • K is key embedding

  • V is value embedding

The shape of Q is [B, N, S_q, H], and the shapes of K and V are [B, N, S_kv, H], where:

  • B is batch size

  • N is the number of attention heads

  • H is the head/hidden size

  • S_q and S_kv are the sequence lengths of the query and key/value, respectively.
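The computation above can be written out as a small reference implementation for a single batch element and a single head. This is a plain-Python sketch for checking shapes and semantics, not a performant kernel; it assumes an additive mask and, when no scale is given, the common choice of sqrt(H), which is an assumption rather than something this section specifies.

```python
import math


def softmax(row):
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v, mask=None, scale=None):
    """softmax(Q * K^T / scale + mask) * V for q: [S_q, H], k and v: [S_kv, H]."""
    head_size = len(q[0])
    if scale is None:
        scale = math.sqrt(head_size)  # assumed default, not mandated by the text
    out = []
    for i, q_row in enumerate(q):
        # One row of the S_q x S_kv score matrix.
        scores = [sum(a * b for a, b in zip(q_row, k_row)) / scale for k_row in k]
        if mask is not None:
            scores = [s + m for s, m in zip(scores, mask[i])]
        probs = softmax(scores)
        out.append([sum(p * v_row[h] for p, v_row in zip(probs, v))
                    for h in range(len(v[0]))])
    return out


NEG_INF = float("-inf")
q = [[1.0, 0.0], [0.0, 1.0]]                 # S_q = 2, H = 2
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]     # S_kv = 3
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
mask = [[0.0, NEG_INF, NEG_INF],             # first query sees only the first key
        [0.0, 0.0, NEG_INF]]
result = attention(q, k, v, mask)
```

A naive implementation that materializes the full S_q x S_kv score matrix is what drives the O(S^2) memory footprint mentioned above.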

MHA Fusion Pattern for FP16 and BF16

TensorRT chooses the accumulation precision by default based on the input types and performance considerations. However, you can also control accumulation precision (refer to Control of Computational Precision).

MHA Fusion Pattern for FP8 and INT8

The MHA fusion captures common pointwise operators in series in MHA. These include scaling (pointwise multiply) and masking operators (pointwise addition and select) at the pointwise op list. It also covers Q/DQ fusion following MHA for certain quantization and architecture (FP16/BF16 to FP8/INT8 on Ampere).

Supported MHA Fusion Types#

| Feature | FP16 | BF16 | INT8 | FP8 |
|---|---|---|---|---|
| SM Version (V) | SM75 ≤ V ≤ SM90 | SM80 ≤ V ≤ SM90 | SM75 ≤ V ≤ SM90 | SM89, SM90 |
| Head Size (H) | 16 ≤ H ≤ 256, H % 8 == 0 | 16 ≤ H ≤ 256, H % 8 == 0 | 16, 32, 64 | 32 ≤ H ≤ 256, H % 16 == 0 |
| Sequence Length (S_q, S_kv) | No restriction | No restriction | S_{q,kv} ≤ 512 | S_q = 1 or 32 ≤ S_{q,kv} |
| Quantization | Not required | Not required | Specify Q/DQ layers in the MHA pattern for FP8 and INT8. | Specify Q/DQ layers in the MHA pattern for FP8 and INT8. |
| Accumulation Precision (BMM1) | FP16, FP32 | FP32 | INT32 | FP32 |
| Accumulation Precision (BMM2) | FP16, FP32 (if SM90 and S_q = S_kv) | FP32 | INT32 | FP32 |
| Supported Mask Type | Any masking | Any masking | Any masking | Any masking |

Any masking means the masking can employ the Select operator in TensorRT. TensorRT may decide not to fuse an MHA graph into a single kernel based on performance evaluations or other constraints.
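The restrictions in the table above can be turned into a quick pre-check before building. The following is a rough sketch only: it assumes all bounds are inclusive, reads the FP8 sequence-length rule as "S_q is 1, or both S_q and S_kv are at least 32", and uses a hypothetical function name. Passing the check does not guarantee fusion, since TensorRT may still decide not to fuse.

```python
def mha_fusion_candidate(dtype, sm, head_size, s_q, s_kv):
    """Return True if the configuration falls inside the documented restrictions.

    dtype is one of "fp16", "bf16", "int8", "fp8"; sm is the SM version as an
    integer (for example 89). Bounds are assumed to be inclusive.
    """
    if dtype == "fp16":
        return 75 <= sm <= 90 and 16 <= head_size <= 256 and head_size % 8 == 0
    if dtype == "bf16":
        return 80 <= sm <= 90 and 16 <= head_size <= 256 and head_size % 8 == 0
    if dtype == "int8":
        return (75 <= sm <= 90 and head_size in (16, 32, 64)
                and s_q <= 512 and s_kv <= 512)
    if dtype == "fp8":
        seq_ok = s_q == 1 or (s_q >= 32 and s_kv >= 32)  # assumed reading
        return (sm in (89, 90) and 32 <= head_size <= 256
                and head_size % 16 == 0 and seq_ok)
    raise ValueError(f"unsupported dtype: {dtype}")
```

For example, a head size of 72 passes the FP16 check but not the FP8 one, which matches the advice to fall back to FP16 MHA when H is not a multiple of 16.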

Supported MHA Fusions on SM100#

| Feature | FP16 | BF16 | FP8 |
|---|---|---|---|
| SM Version (V) | SM100 | SM100 | SM100 |
| Head Size (H) | 8 ≤ H ≤ 128, H % 8 == 0 | 8 ≤ H ≤ 128, H % 8 == 0 | 16 ≤ H ≤ 128, H % 16 == 0 |
| Sequence Length (S_q, S_kv) | No restriction | No restriction | No restriction |
| Quantization | Not required | Not required | Specify Q/DQ layers in the MHA pattern for FP8 and INT8. |
| Accumulation Precision (BMM1) | FP32 | FP32 | FP32 |
| Accumulation Precision (BMM2) | FP32 | FP32 | FP32 |

Example Workflow: FP8 MHA Fusion#

Assume you have an ONNX model, vit_base_patch8_224_Opset17.onnx, and calibration data, calib.npy, on your local machine.

  1. Install the TensorRT Model Optimizer.

    pip3 install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-modelopt
    
  2. Quantize a model with TensorRT model optimizer. For more information, refer to these detailed instructions.

    python3 -m modelopt.onnx.quantization \
    --onnx_path=vit_base_patch8_224_Opset17.onnx \
    --quantize_mode=<fp8|int8> \
    --calibration_data=calib.npy \
    --calibration_method=<max|entropy> \
    --output_path=vit_base_patch8_224_Opset17.quant.onnx
    
  3. Compile the quantized model with TensorRT.

    trtexec --onnx=vit_base_patch8_224_Opset17.quant.onnx \
    --saveEngine=vit_base_patch8_224_Opset17.engine \
    --stronglyTyped --skipInference --profilingVerbosity=detailed
    
  4. Run the quantized model with TensorRT.

    trtexec --loadEngine=vit_base_patch8_224_Opset17.engine \
    --useCudaGraph --noDataTransfers --useSpinWait
    

    Add the following options if you want to check if MHA is fused. MHA should be fused if you find the mha op in the output.log file.

    trtexec --loadEngine=vit_base_patch8_224_Opset17.engine \
    --profilingVerbosity=detailed --dumpLayerInfo --skipInference &> output.log
    

    Tip

    • There are two ways to set the accumulation data type to FP32:

      1. Manually set computational precision. For more information, refer to these detailed instructions.

      2. Convert your ONNX model using TensorRT Model Optimizer, which adds the Cast ops automatically.

    • If the MHA has a head size (H) that is not a multiple of 16, do not add Q/DQ ops in the MHA to fall back to the FP16 MHA for better performance.

    • Given the restrictions, compare INT8 with FP8 for MHA fusion.
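Checking the dumped layer information for the mha op can be scripted. The sketch below simply does a case-insensitive substring search; the exact layer naming in the log is not specified here, so treat the match rule and the function name as assumptions.

```python
from pathlib import Path


def find_mha_lines(log_text):
    """Return the lines of a trtexec log that mention an mha op (case-insensitive)."""
    return [line.strip() for line in log_text.splitlines() if "mha" in line.lower()]


def log_mentions_mha(log_path):
    """Read a log file such as output.log and report whether any mha op appears."""
    text = Path(log_path).read_text(errors="replace")
    return bool(find_mha_lines(text))


# Usage (after running the trtexec command that writes output.log):
#   print(log_mentions_mha("output.log"))
```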

Limiting Compute Resources#

Limiting the number of compute resources available to TensorRT during engine creation is beneficial when the reduced amount better represents the expected conditions during runtime. For example, when the GPU is expected to be performing additional work in parallel to the TensorRT engine or when the engine is expected to be run on a different GPU with fewer resources (note that the recommended approach is to build the engine on the GPU that will be used for inference, but this may not always be feasible).

You can limit the number of available compute resources with the following steps:

  1. Start the CUDA MPS control daemon.

    nvidia-cuda-mps-control -d
    
  2. Set the number of computing resources to use with the CUDA_MPS_ACTIVE_THREAD_PERCENTAGE environment variable. For example, export CUDA_MPS_ACTIVE_THREAD_PERCENTAGE=50.

  3. Build the network engine.

  4. Stop the CUDA MPS control daemon.

    echo quit | nvidia-cuda-mps-control
    

The resulting engine is optimized for the reduced number of compute cores (50% in this example) and provides better throughput when using similar conditions during inference. You are encouraged to experiment with different numbers of streams and different MPS values to determine the best performance for your network.
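The four steps above can be wrapped in a small helper so that the MPS daemon is always stopped, even if the build fails. This is a sketch under stated assumptions: the helper name is hypothetical, the build command is whatever you normally use to build the engine, and the `run` parameter exists only so the sequence can be exercised without a GPU.

```python
import os
import subprocess


def build_under_mps(build_cmd, thread_percentage=50, run=subprocess.run):
    """Start MPS, run build_cmd with a limited thread percentage, then stop MPS."""
    env = dict(os.environ,
               CUDA_MPS_ACTIVE_THREAD_PERCENTAGE=str(thread_percentage))
    # Step 1: start the CUDA MPS control daemon.
    run(["nvidia-cuda-mps-control", "-d"], check=True)
    try:
        # Steps 2 and 3: build the engine with the limit set in the environment.
        run(build_cmd, env=env, check=True)
    finally:
        # Step 4: equivalent of `echo quit | nvidia-cuda-mps-control`.
        run(["nvidia-cuda-mps-control"], input="quit\n", text=True, check=True)


# Usage (model.onnx and model.engine are placeholder file names):
#   build_under_mps(["trtexec", "--onnx=model.onnx", "--saveEngine=model.engine"])
```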

For more details about nvidia-cuda-mps-control, refer to the nvidia-cuda-mps-control documentation and the relevant GPU requirements here.

Deterministic Tactic Selection#

TensorRT runs through all the possible tactics in the engine-building phase and selects the fastest ones. Since the selection is based on the tactics’ latency measurements, TensorRT may select different tactics across different runs if some have similar latencies. Therefore, different engines built from the same INetworkDefinition may behave slightly differently regarding output values and performance. You can inspect the selected tactics of an engine by using the engine inspector APIs or by turning on verbose logging while building the engine.

If deterministic tactic selection is desired, the following lists a few suggestions that may help improve the determinism of tactic selection.

Locking GPU Clock Frequency

By default, the GPU’s clock frequency is not locked, meaning that the GPU normally sits at the idle clock frequency and only boosts to the max clock frequency when there are active GPU workloads. However, there is a latency for the clock to be boosted from the idle frequency, and that may cause performance variations while TensorRT is running through the tactics and selecting the best ones, resulting in non-deterministic tactic selections.

Therefore, locking the GPU clock frequency before building a TensorRT engine may improve the determinism of tactic selection. Refer to the Hardware/Software Environment for Performance Measurements section for more information about how to lock and monitor the GPU clock and the factors that may affect GPU clock frequencies.

Increasing Average Timing Iterations

By default, TensorRT runs each tactic for at least four iterations and takes the average latency. You can increase the number of iterations by calling the setAvgTimingIterations() API:

C++:

builderConfig->setAvgTimingIterations(8);

Python:

builder_config.avg_timing_iterations = 8

Increasing the number of average timing iterations may improve the determinism of tactic selections, but the required engine-building time will become longer.
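A toy simulation can show why averaging over more iterations tends to stabilize the choice between nearly tied tactics. This is not TensorRT's timing procedure; the latencies, the Gaussian noise model, and all names are hypothetical.

```python
import random


def pick_fastest(true_latencies, iterations, noise, rng):
    """Average noisy measurements per tactic and return the index of the fastest."""
    averages = []
    for latency in true_latencies:
        samples = [rng.gauss(latency, noise) for _ in range(iterations)]
        averages.append(sum(samples) / iterations)
    return averages.index(min(averages))


def selection_flip_rate(iterations, trials=500, seed=0):
    """Fraction of trials in which the truly slower tactic gets selected."""
    rng = random.Random(seed)
    true_latencies = [1.00, 1.02]  # hypothetical, nearly tied tactics (ms)
    picks = [pick_fastest(true_latencies, iterations, 0.05, rng)
             for _ in range(trials)]
    return sum(1 for p in picks if p != 0) / trials


rate_few = selection_flip_rate(iterations=4)
rate_many = selection_flip_rate(iterations=256)
```

In this toy model the slower tactic is picked much less often with more iterations, at the cost of proportionally more measurement time, which mirrors the trade-off described above.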

Using Timing Cache

Timing Cache records the latencies of each tactic for a specific layer configuration. The tactic latencies are reused if TensorRT encounters another layer with an identical configuration. Therefore, by reusing the same timing cache across multiple engine-building runs with the same INetworkDefinition and builder config, you can make TensorRT select an identical set of tactics in the resulting engines.

Overhead of Shape Change and Optimization Profile Switching#

After the IExecutionContext switches to a new optimization profile or the shapes of the input bindings change, TensorRT must recompute the tensor shapes throughout the network and recompute the resources needed by some tactics for the new shapes before the next inference can start. That means the first enqueueV3() call after a shape/profile change may be longer than the subsequent enqueueV3() calls.

Optimizing the cost of shape/profile switching is an active development area. However, there are still a few cases where the overhead can influence the performance of the inference applications. For example, some convolution tactics for NVIDIA Volta GPUs or older GPUs have much longer shape/profile switching overhead, even if their inference performance is the best among all the available tactics. In this case, disabling kEDGE_MASK_CONVOLUTIONS tactics from tactic sources when building the engine may reduce the overhead of shape/profile switching.
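Disabling a tactic source is a bitmask operation on the builder config. The helper below is a minimal sketch using the Python API's get_tactic_sources and set_tactic_sources; the helper name is hypothetical, and the config is passed in so the snippet itself does not need TensorRT to be importable.

```python
def disable_tactic_source(config, source):
    """Clear one tactic-source bit on a builder config and return the new mask."""
    mask = config.get_tactic_sources() & ~(1 << int(source))
    config.set_tactic_sources(mask)
    return mask


# Usage with a real builder config (requires TensorRT):
#   import tensorrt as trt
#   disable_tactic_source(config, trt.TacticSource.EDGE_MASK_CONVOLUTIONS)
```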

Optimizing Layer Performance#

The following descriptions detail how you can optimize the listed layers.

  • Gather: Use an axis of 0 to maximize the performance of a Gather layer. There are no fusions available for a Gather layer.

  • Reduce: To get the maximum performance out of a Reduce layer, perform the reduction across the last dimensions (tail reduce). This allows optimal memory read/write patterns through sequential memory locations. If doing common reduction operations, express the reduction in a way that will be fused to a single operation.

  • RNN: The loop-based API provides a much more flexible mechanism for using general layers within recurrence. The ILoopLayer recurrence enables a rich set of automatic loop optimizations, including loop fusion, unrolling, and loop-invariant code motion, to name a few. For example, significant performance gains are often obtained when multiple instances of the same MatrixMultiply layer are properly combined to maximize machine utilization after loop unrolling along the sequence dimension. This works best if you can avoid a MatrixMultiply layer with a recurrent data dependence along the sequence dimension.

  • Shuffle: Shuffle operations equivalent to identity operations on the underlying data are omitted if the input tensor is only used in the shuffle layer and the input and output tensors of this layer are not input and output tensors of the network. TensorRT does not execute additional kernels or memory copies for such operations.

  • TopK: To get the maximum performance out of a TopK layer, use small values of K, reducing the last dimension of data to allow optimal sequential memory access. Reductions along multiple dimensions at once can be simulated using a Shuffle layer to reshape the data and then appropriately reinterpret the index values.
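The TopK advice about simulating a reduction over multiple dimensions can be sketched in plain Python: flatten the dimensions (the reshape a Shuffle layer would do), take the top K over the flattened axis, then convert each flat index back to coordinates. The function name and data are illustrative only.

```python
import heapq


def topk_2d(matrix, k):
    """Top-k over both dimensions of a 2D list, returned as (value, (row, col))."""
    cols = len(matrix[0])
    # Reshape step: collapse [rows, cols] into a single last dimension.
    flat = [value for row in matrix for value in row]
    # TopK over the flattened dimension, keeping flat indices.
    top = heapq.nlargest(k, range(len(flat)), key=flat.__getitem__)
    # Reinterpret each flat index as (row, col).
    return [(flat[i], divmod(i, cols)) for i in top]


matrix = [[1, 9, 3],
          [7, 2, 8]]
top2 = topk_2d(matrix, 2)  # [(9, (0, 1)), (8, (1, 2))]
```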

For more information about layers, refer to the TensorRT Operator documentation.

Optimizing for Tensor Cores#

Tensor Core is a key technology for delivering high-performance inference on NVIDIA GPUs. In TensorRT, Tensor Core operations are supported by all compute-intensive layers: MatrixMultiply, Convolution, and Deconvolution.

Tensor Core layers tend to achieve better performance if the I/O tensor dimensions are aligned to a certain minimum granularity:

  • The alignment requirement is on the I/O channel dimension in the Convolution and Deconvolution layers.

  • In the MatrixMultiply layer, the alignment requirement is on matrix dimensions K and N in a MatrixMultiply that is M x K times K x N.

The following table captures the suggested tensor dimension alignment for better Tensor Core performance.

Types of Tensor Cores#

| Tensor Core Operation Type | Suggested Tensor Dimension Alignment in Elements |
|---|---|
| TF32 | 4 |
| FP16 | 8 for dense math, 16 for sparse math |
| INT8 | 32 |

When using Tensor Core implementations in cases where these requirements are unmet, TensorRT implicitly pads the tensors to the nearest multiple of the alignment. Rounding up the dimensions in the model definition instead gives the model extra capacity without increasing computation or memory traffic.
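When choosing channel counts or matrix dimensions in a model definition, the suggested alignments can be applied with a small rounding helper. The dictionary keys and function name below are hypothetical; the alignment values are taken from the table in this section.

```python
# Suggested alignments in elements, keyed by hypothetical labels.
ALIGNMENT = {"tf32": 4, "fp16": 8, "fp16_sparse": 16, "int8": 32}


def round_up(dim, alignment):
    """Round dim up to the nearest multiple of alignment."""
    return -(-dim // alignment) * alignment


channels_fp16 = round_up(100, ALIGNMENT["fp16"])   # 104
channels_int8 = round_up(1000, ALIGNMENT["int8"])  # 1024
already_aligned = round_up(64, ALIGNMENT["int8"])  # 64
```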

TensorRT always uses the fastest implementation for a layer, and thus, in some cases, it may not use a Tensor Core implementation even if it is available.

To check if Tensor Core is used for a layer, run Nsight Systems with the --gpu-metrics-device all flag while profiling the TensorRT application. The Tensor Core usage rate can be found in the profiling result in the Nsight Systems user interface under the SM instructions/Tensor Active row. Refer to the CUDA Profiling Tools for more information about using Nsight Systems to profile TensorRT applications.

It is impractical to expect a CUDA kernel to reach 100% Tensor Core usage since there are other overheads such as DRAM reads/writes, instruction stalls, other computation units, etc. The more computation-intensive an operation is, the higher the Tensor Core usage rate the CUDA kernel can achieve.

The following image is an example of Nsight Systems profiling.

Tensor Core Activities on an A100 GPU Running ResNet-50 with FP16 Enabled

Optimizing Plugins#

TensorRT provides a mechanism for registering custom plugins that perform layer operations. After a plugin creator is registered, you can search the registry to find the creator and add the corresponding plugin object to the network during serialization/deserialization.

Once the plugin library is loaded, all TensorRT plugins are automatically registered. For more information about custom plugins, refer to Extending TensorRT With Custom Layers.

Plugin performance depends on the CUDA code performing the plugin operation. Standard CUDA Best Practices apply. When developing plugins, starting with simple standalone CUDA applications that perform the plugin operation and verify correctness can be helpful. The plugin program can then be extended with performance measurements, more unit testing, and alternate implementations. After the code is working and optimized, it can be integrated as a plugin into TensorRT.

Supporting as many formats as possible in the plugin is important to get the best performance possible. This removes the need for internal reformat operations during the execution of the network. Refer to the Extending TensorRT With Custom Layers section for examples.

Optimizing Python Performance#

Most of the same performance considerations apply when using the Python API. When building engines, the builder optimization phase will normally be the performance bottleneck, not API calls to construct the network. Inference time should be nearly identical between the Python API and C++ API.

Setting up the input buffers in the Python API involves using pycuda or another CUDA Python library, like cupy, to transfer the data from the host to device memory. The details of how this works will depend on where the host data comes from. Internally, pycuda supports the Python Buffer Protocol, allowing efficient access to memory regions. This means that if the input data is available in a suitable format in numpy arrays or another type with support for the buffer protocol, it allows efficient access and transfer to the GPU. For even better performance, allocate a page-locked buffer using pycuda and write your final preprocessed input.

For more information about using the Python API, refer to the Python API documentation.

Improving Model Accuracy#

Depending on the builder configuration, TensorRT can execute a layer in FP32, FP16, BF16, FP8, or INT8 precision. By default, TensorRT chooses to run a layer in a precision that results in optimal performance. Sometimes, this can result in poor accuracy. Generally, running a higher-precision layer helps improve accuracy with some performance hits.

There are several steps that we can take to improve model accuracy:

  1. Validate layer outputs:

    1. Use Polygraphy to dump layer outputs and verify no NaNs or Infs. The --validate option can check for NaNs and Infs. Also, we can compare layer outputs with golden values from, for example, ONNX runtime.

    2. For FP16 and BF16, a model might require retraining to ensure that intermediate layer output can be represented in FP16/BF16 precision without overflow or underflow.

    3. For INT8, consider recalibrating with a more representative calibration data set. If your model comes from PyTorch, we also provide the TensorRT Model Optimizer for QAT in the framework besides PTQ in TensorRT. You can try both approaches and choose the one with more accuracy.

  2. Manipulate layer precision:

    1. Sometimes, running a layer with a certain precision results in incorrect output. This can be due to inherent layer constraints (for example, LayerNorm output should not be INT8) or model constraints (the output diverges, resulting in poor accuracy).

    2. You can control layer execution precision and output precision.

    3. An experimental debug precision tool can help automatically find layers to run with high precision.

  3. Use the Editable Timing Cache to select a proper tactic.

    1. When accuracy changes between two built engines for the same model, it might be due to a bad tactic being selected for a layer.

    2. Use Editable Timing Cache to dump available tactics. Update the cache with a proper one.
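The layer-output validation step can also be done by hand on dumped outputs. The following sketch checks a flat list of floats for NaNs and Infs and compares it with golden values; the tolerance defaults and function names are assumptions for illustration, not values recommended by TensorRT or Polygraphy.

```python
import math


def has_nan_or_inf(values):
    """True if any element is NaN or infinite."""
    return any(math.isnan(v) or math.isinf(v) for v in values)


def matches_golden(outputs, golden, atol=1e-3, rtol=1e-3):
    """Element-wise closeness check: |o - g| <= atol + rtol * |g|."""
    if len(outputs) != len(golden):
        return False
    return all(abs(o - g) <= atol + rtol * abs(g) for o, g in zip(outputs, golden))


# Hypothetical layer outputs and golden values from a reference runtime.
golden = [0.10, 2.50, -7.25]
fp16_like = [0.1001, 2.5004, -7.2490]
overflowed = [0.10, float("inf"), -7.25]

ok = not has_nan_or_inf(fp16_like) and matches_golden(fp16_like, golden)
bad = has_nan_or_inf(overflowed)
```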

Accuracy should not vary from run to run; once the engine is built for a specific GPU, it should produce bit-accurate outputs across multiple runs. If not, file a TensorRT bug.

Optimizing Builder Performance#

The TensorRT builder profiles each layer’s available tactics to search for the fastest inference engine plan. The builder time can be long if the model has many layers or complicated topology. The following sections provide options to reduce builder time.

Timing Cache#

TensorRT creates a layer-timing cache to reduce builder time and keep the layer-profiling information. The information it contains is specific to the targeted device, CUDA, TensorRT versions, and BuilderConfig parameters that can change the layer implementation, such as BuilderFlag::kTF32 or BuilderFlag::kREFIT.

The TensorRT builder skips profiling and reuses the cached result for the repeated layers if other layers have the same IO tensor configuration and layer parameters. If a timing query misses in the cache, the builder times the layer and updates the cache.

The timing cache can be serialized and deserialized. You can load a serialized cache from a buffer using IBuilderConfig::createTimingCache:

ITimingCache* cache =
 config->createTimingCache(cacheFile.data(), cacheFile.size());

Setting the buffer size to 0 creates a new empty timing cache.

You then attach the cache to a builder configuration before building.

config->setTimingCache(*cache, false);

Due to cache misses, the timing cache can be augmented with more information during the build. After the build, it can be serialized for use with another builder.

IHostMemory* serializedCache = cache->serialize();

If a builder does not have a timing cache attached, it creates its temporary local cache and destroys it when it is done.
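The same load, attach, and serialize flow can be sketched with the Python API, using create_timing_cache, set_timing_cache, and serialize. The helper names and the cache file path are hypothetical, and the config object is passed in so the snippet does not import TensorRT itself.

```python
from pathlib import Path


def attach_timing_cache(config, cache_path):
    """Load a serialized cache if the file exists, else start empty, and attach it."""
    path = Path(cache_path)
    data = path.read_bytes() if path.exists() else b""  # empty buffer -> new cache
    cache = config.create_timing_cache(data)
    config.set_timing_cache(cache, ignore_mismatch=False)
    return cache


def save_timing_cache(cache, cache_path):
    """Serialize the (possibly augmented) cache after the build for later reuse."""
    with open(cache_path, "wb") as f:
        f.write(cache.serialize())


# Usage with a real builder config (requires TensorRT):
#   cache = attach_timing_cache(config, "timing.cache")
#   ... build the engine ...
#   save_timing_cache(cache, "timing.cache")
```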

The compilation cache is part of the timing cache. It caches JIT-compiled code and is serialized as part of the timing cache by default. It can be disabled by setting the following BuilderFlag:

config->setFlag(BuilderFlag::kDISABLE_COMPILATION_CACHE);

Note

The timing cache supports the most frequently used layer types: Convolution, Deconvolution, Pooling, SoftMax, MatrixMultiply, ElementWise, Shuffle, and tensor memory layout conversion. More layer types will be added in future releases.

Builder Optimization Level#

Set the optimization level in the builder config to adjust how long TensorRT should spend searching for tactics with potentially better performance. By default, the optimization level is 3. Setting it to a smaller value results in much faster engine building time, but the engine’s performance may be worse. On the other hand, setting it to a larger value will increase the engine building time, but the resulting engine may perform better if TensorRT can find better tactics.

For example, to set the optimization level to 0 (the fastest):

C++:

config->setBuilderOptimizationLevel(0);

Python:

config.builder_optimization_level = 0