Optimizing TensorRT Performance#

The following sections focus on the general inference flow on GPUs and some general strategies for improving performance. These ideas are familiar to most CUDA programmers but may not be as obvious to developers from other backgrounds.

Batching#

The most important optimization is to compute as many results in parallel as possible using batching. In TensorRT, a batch is a collection of inputs that can all be processed uniformly. Each instance in the batch has the same shape and flows through the network similarly. Therefore, each instance can be trivially computed in parallel.

Each network layer has some overhead and synchronization cost to compute forward inference. By computing more results in parallel, this overhead is amortized more efficiently. In addition, many layers are performance-limited by the smallest dimension in the input. If the batch size is one or small, this size can often be the performance-limiting dimension. For example, a fully connected layer with V inputs and K outputs can be implemented for one batch instance as a matrix multiplication of a 1xV matrix with a VxK weight matrix. If N instances are batched, this becomes an NxV matrix multiplied by the VxK matrix. The vector-matrix multiply becomes a matrix-matrix multiply, which is much more efficient.
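The effect of batching on the fully connected example can be sketched with plain NumPy (an illustration only, not TensorRT code):

```python
import numpy as np

# A fully connected layer with V inputs and K outputs applied to
# one instance versus a batch of N instances.
V, K, N = 64, 32, 8
W = np.random.rand(V, K).astype(np.float32)      # VxK weight matrix

single = np.random.rand(1, V).astype(np.float32)  # one batch instance
batch = np.random.rand(N, V).astype(np.float32)   # N instances batched

# One instance: a 1xV by VxK vector-matrix multiply -> 1xK output.
out_single = single @ W

# N instances: an NxV by VxK matrix-matrix multiply -> NxK output,
# which the hardware executes far more efficiently per instance.
out_batch = batch @ W

assert out_single.shape == (1, K)
assert out_batch.shape == (N, K)
```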

Larger batch sizes are almost always more efficient on the GPU. Extremely large batches, such as N > 2^16, can sometimes require extended index computation and should be avoided if possible. But generally, increasing the batch size improves total throughput. In addition, when the network contains MatrixMultiply layers, batch sizes of multiples of 32 tend to have the best performance for FP16 and INT8 inference because of the utilization of Tensor Cores if the hardware supports them.

On NVIDIA Ada Lovelace or later GPUs, decreasing the batch size can improve the throughput significantly if the smaller batch sizes help the GPU cache the input/output values in the L2 cache. Therefore, various batch sizes should be tried to find the batch size that provides optimal performance.

Sometimes, batching inference work is impossible due to the application’s organization. In some common applications, such as a server that makes inferences per request, it is possible to implement opportunistic batching. For each incoming request, wait for a time T. If other requests come in, batch them together. Otherwise, continue with a single-instance inference. This strategy adds fixed latency to each request but can greatly improve the system’s maximum throughput.
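The opportunistic batching strategy above can be sketched as follows. This is an illustrative Python sketch, not a TensorRT or Triton API: the queue, the wait time, and the batch cap are all assumptions, and the handoff to inference is left as a placeholder.

```python
import queue
import time

def opportunistic_batcher(requests, wait_s=0.005, max_batch=8):
    """Group queued requests into batches: after the first request of a
    batch arrives, wait up to wait_s for more before dispatching."""
    q = queue.Queue()
    for r in requests:
        q.put(r)

    batches = []
    while not q.empty():
        batch = [q.get()]                    # first request starts a batch
        deadline = time.monotonic() + wait_s
        while len(batch) < max_batch:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            try:                             # wait for more work until the deadline
                batch.append(q.get(timeout=remaining))
            except queue.Empty:
                break
        batches.append(batch)                # hand the batch to inference here
    return batches

# Ten queued requests with an 8-instance cap form batches of 8 and 2.
sizes = [len(b) for b in opportunistic_batcher(range(10))]  # -> [8, 2]
```

Each request waits at most `wait_s` beyond its arrival, which is the fixed latency cost the text describes.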

The NVIDIA Triton Inference Server provides a simple way to enable dynamic batching with TensorRT engines.

Using Batching

The batch dimension is part of the tensor dimensions, and you can specify the range of the batch sizes and the batch size to optimize the engine by adding optimization profiles. For more information, refer to the Working with Dynamic Shapes section.

Within-Inference Multi-Streaming#

In general, CUDA streams are a way of organizing asynchronous work. Asynchronous commands put into a stream are guaranteed to run in sequence but can execute out of order with respect to other streams. In particular, asynchronous commands in two streams can be scheduled to run concurrently (subject to hardware limitations).

In the context of TensorRT and inference, each layer of the optimized final network will require work on the GPU. However, not all layers can fully use the hardware’s computation capabilities. Scheduling requests in separate streams allows work to be scheduled immediately as the hardware becomes available without unnecessary synchronization. Even if only some layers can be overlapped, overall performance will improve.

Use the IBuilderConfig::setMaxAuxStreams() API to set the maximum number of auxiliary streams TensorRT can use to run multiple layers in parallel. The auxiliary streams are in contrast to the "main stream" provided in the enqueueV3() call: if enabled, TensorRT will run some layers on the auxiliary streams in parallel to those running on the main stream.

For example, to run the inference on at most eight streams (that is, seven auxiliary streams and one main stream) in total:

C++:

    config->setMaxAuxStreams(7);

Python:

    config.max_aux_streams = 7

Note that this only sets the maximum number of auxiliary streams; TensorRT may use fewer than this number if it determines that using more streams does not help.

To get the actual number of auxiliary streams that TensorRT uses for an engine, run the following:

C++:

    int32_t nbAuxStreams = engine->getNbAuxStreams();

Python:

    num_aux_streams = engine.num_aux_streams

When an execution context is created from the engine, TensorRT automatically creates the auxiliary streams needed to run the inference. However, you can also specify the auxiliary streams you would like TensorRT to use:

C++:

    int32_t nbAuxStreams = engine->getNbAuxStreams();
    std::vector<cudaStream_t> streams(nbAuxStreams);
    for (int32_t i = 0; i < nbAuxStreams; ++i)
    {
        cudaStreamCreate(&streams[i]);
    }
    context->setAuxStreams(streams.data(), nbAuxStreams);

Python:

    from cuda import cudart
    num_aux_streams = engine.num_aux_streams
    streams = []
    for i in range(num_aux_streams):
        err, stream = cudart.cudaStreamCreate()
        streams.append(stream)
    context.set_aux_streams(streams)

TensorRT will always insert event synchronizations between the main stream provided to the enqueueV3() call and the auxiliary streams:

  • At the beginning of the enqueueV3() call, TensorRT will ensure that all the auxiliary streams wait on the activities on the main stream.

  • At the end of the enqueueV3() call, TensorRT will ensure that the main stream waits for the activities on all the auxiliary streams.

Enabling auxiliary streams can increase memory consumption because some activation buffers can no longer be reused.

Cross-Inference Multi-Streaming#

In addition to the within-inference streaming, you can enable streaming between multiple execution contexts. For example, you can build an engine with multiple optimization profiles and create an execution context per profile. Then, call the enqueueV3() function of the execution contexts on different streams to allow them to run in parallel.

Running multiple concurrent streams often leads to several streams sharing compute resources simultaneously. This means the network can have fewer compute resources available during inference than when the TensorRT engine was optimized. This difference in resource availability can cause TensorRT to choose a suboptimal kernel for the actual runtime conditions. To mitigate this effect, you can limit the amount of available compute resources during engine creation to resemble actual runtime conditions more closely. This approach generally promotes throughput at the expense of latency. For more information, refer to the Limiting Compute Resources section.

It is also possible to use multiple host threads with streams. A common pattern is incoming requests dispatched to a pool of worker threads waiting for work. In this case, the pool of worker threads will each have one execution context and CUDA stream. Each thread will request work in its stream as the work becomes available. Each thread will synchronize with its stream to wait for results without blocking other worker threads.
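The worker-pool pattern above can be sketched as below. This is an illustrative Python sketch, not TensorRT code: a placeholder computation stands in for each worker's execution context, stream, and per-stream synchronization, and the pool size is an arbitrary assumption.

```python
import queue
import threading

def run_worker_pool(requests, num_workers=4):
    """Dispatch incoming requests to a pool of worker threads. In a real
    application each worker would own one TensorRT execution context and
    one CUDA stream, enqueue work on its stream, and synchronize with
    that stream only, so it never blocks the other workers."""
    work = queue.Queue()
    results = []
    lock = threading.Lock()

    def worker():
        while True:
            item = work.get()
            if item is None:            # sentinel: no more work
                break
            out = item * 2              # placeholder for enqueue + stream sync
            with lock:
                results.append(out)

    threads = [threading.Thread(target=worker) for _ in range(num_workers)]
    for t in threads:
        t.start()
    for r in requests:
        work.put(r)
    for _ in threads:
        work.put(None)                  # one sentinel per worker
    for t in threads:
        t.join()
    return sorted(results)

outputs = run_worker_pool(range(5))     # -> [0, 2, 4, 6, 8]
```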

CUDA Graphs#

CUDA Graphs represent a sequence (or, more generally, a graph) of kernels in a way that allows CUDA to optimize their scheduling. This can be particularly useful when your application's performance is sensitive to the CPU time required to enqueue the kernels.

Using CUDA Graphs with TensorRT Execution Context#

TensorRT’s enqueueV3() method supports CUDA graph capture for models requiring no mid-pipeline CPU interaction. For example:

C++:

    // Call enqueueV3() once after an input shape change to update internal state.
    context->enqueueV3(stream);

    // Capture a CUDA graph instance
    cudaGraph_t graph;
    cudaGraphExec_t instance;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    context->enqueueV3(stream);
    cudaStreamEndCapture(stream, &graph);
    cudaGraphInstantiate(&instance, graph, 0);

    // To run inferences, launch the graph instead of calling enqueueV3().
    for (int i = 0; i < iterations; ++i) {
        cudaGraphLaunch(instance, stream);
        cudaStreamSynchronize(stream);
    }
Python:

    from cuda import cudart
    err, stream = cudart.cudaStreamCreate()

    # Call execute_async_v3() once after an input shape change to update internal state.
    context.execute_async_v3(stream)

    # Capture a CUDA graph instance
    cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureModeGlobal)
    context.execute_async_v3(stream)
    err, graph = cudart.cudaStreamEndCapture(stream)
    err, instance = cudart.cudaGraphInstantiate(graph, 0)

    # To run inferences, launch the graph instead of calling execute_async_v3().
    for i in range(iterations):
        cudart.cudaGraphLaunch(instance, stream)
        cudart.cudaStreamSynchronize(stream)

Limitations of CUDA Graphs#

CUDA graphs cannot handle some operations, so graph capturing can fail if the execution context contains such operations. Typical deep learning operators unsupported by CUDA graphs include loops, conditionals, and layers requiring data-dependent shapes. In these cases, cudaStreamEndCapture() will return cudaErrorStreamCapture* errors, indicating that the graph capturing has failed, but the context can continue to be used for normal inference without CUDA graphs. Refer to the CUDA Programming Guide to learn more about the limitations of CUDA graphs.

Also, when capturing a graph, it is important to account for the two-phase execution strategy used in the presence of dynamic shapes.

  1. Update the model’s internal state to account for any changes in input size.

  2. Stream work to the GPU.

The first phase requires no per-invocation work for models where input size is fixed at build time. Otherwise, if the input sizes have changed since the last invocation, some work can be required to update derived properties.

The first phase of work is not designed to be captured, and even if the capture is successful, it can increase model execution time. Therefore, after changing the shapes of inputs or the values of shape tensors, call enqueueV3() once to flush deferred updates before capturing the graph.

Graphs captured with TensorRT are specific to the input size and the state of the execution context. Modifying the context from which the graph was captured will result in undefined behavior when executing the graph. In particular, if the application is providing its own memory for activations using createExecutionContextWithoutDeviceMemory(), the memory address is captured as part of the graph. Locations of input and output buffers are also captured as part of the graph.

Therefore, the best practice is to use one execution context per captured graph and to share memory across the contexts with createExecutionContextWithoutDeviceMemory().

trtexec allows you to check whether the TensorRT engine you built is compatible with CUDA graph capture. For more information, refer to the trtexec section.

Concurrent CUDA Activities with CUDA Graph Capture#

Launching a CUDA kernel on the CUDA legacy default stream or calling synchronous CUDA APIs like cudaMemcpy() while capturing a CUDA graph fails because these CUDA activities implicitly synchronize the CUDA streams used by TensorRT execution contexts.

To avoid breaking the CUDA graph capture, ensure other CUDA kernels are launched on non-default CUDA streams and use the asynchronous version of CUDA APIs, like cudaMemcpyAsync().

Alternatively, a CUDA stream can be created with the cudaStreamNonBlocking flag to capture the CUDA graph for an execution context. If the execution context uses auxiliary streams, make sure you also call the setAuxStreams() API using streams created with the cudaStreamNonBlocking flag. Refer to the Within-Inference Multi-Streaming section about how to set auxiliary streams in TensorRT execution contexts.

Enabling Fusion#

Layer Fusion#

TensorRT attempts to perform many different types of optimizations in a network during the build phase. In the first phase, layers are fused whenever possible. Fusions transform the network into a simpler form but preserve the same overall behavior. Internally, many layer implementations have extra parameters and options that are not directly accessible when creating the network. Instead, the fusion optimization step detects supported patterns of operations and fuses multiple layers into one layer with an internal options set.

Consider the common case of a convolution followed by a ReLU activation. Creating a network with these operations involves adding a Convolution layer with addConvolutionNd and following it with an Activation layer using addActivation with an ActivationType of kRELU. The unoptimized graph will contain separate layers for convolution and activation. The internal implementation of convolution supports computing the ReLU function on the output in one step directly from the convolution kernel without requiring a second kernel call. The fusion optimization step will detect the convolution followed by ReLU, verify that the implementation supports the operations, and then fuse them into one layer.

To investigate which fusions have occurred, the builder logs its operations to the logger object provided during construction. Optimization steps are at the kINFO log level. To view these messages, ensure you log them in the ILogger callback.

Fusions are normally handled by creating a new layer with a name containing the names of both of the layers that were fused. For example, a MatrixMultiply layer (InnerProduct) named ip1 is fused with a ReLU Activation layer named relu1 to create a new layer named ip1 + relu1.

Types of Fusions#

The following list describes the types of supported fusions.

Supported Layer Fusions

  • ReLU Activation: An Activation layer performing ReLU followed by another Activation layer performing ReLU will be replaced by a single ReLU Activation layer.

  • Convolution and ReLU Activation: The Convolution layer can be of any type, and values are not restricted. The Activation layer must be of the ReLU type.

  • Convolution and GELU Activation: The input and output precision should be the same, with both FP16 or INT8. The Activation layer must be of the GELU type. TensorRT must be running on an NVIDIA Turing or later GPU with CUDA 10.0 or later.

  • Convolution and Clip Activation: The Convolution layer can be any type, and values are not restricted. The Activation layer must be Clip type.

  • Scale and Activation: The Scale layer, followed by an Activation layer, can be fused into a single Activation layer.

  • Convolution and ElementWise Operation: A Convolution layer followed by a simple sum, min, or max in an ElementWise layer can be fused into the Convolution layer. The sum must not use broadcasting unless the broadcasting is across the batch size.

  • Padding and Convolution/Deconvolution: If all the padding sizes are non-negative, padding followed by a Convolution or Deconvolution can be fused into a single Convolution/Deconvolution layer.

  • Shuffle and Reduce: A Shuffle layer without reshaping, followed by a Reduce layer, can be fused into a single Reduce layer. The Shuffle layer can perform permutations but cannot perform any reshape operation. The Reduce layer must have keepDimensions set.

  • Shuffle and Shuffle: Each Shuffle layer consists of a transpose, a reshape, and a second transpose. A Shuffle layer followed by another can be replaced by a single Shuffle (or nothing). If both Shuffle layers perform reshape operations, this fusion is only allowed if the second transpose of the first shuffle is the inverse of the first transpose of the second shuffle.

  • Scale: A Scale layer that adds 0, multiplies by 1, or computes powers of 1 can be erased.

  • Convolution and Scale: A Convolution layer followed by a kUNIFORM or kCHANNEL Scale layer can be fused into a single convolution by adjusting the convolution weights. This fusion is disabled if the scale has a non-constant power parameter.

  • Convolution and Generic Activation: This fusion happens after the pointwise fusion mentioned below. A pointwise with one input and output can be called a generic activation layer. A convolution layer followed by a generic activation layer can be fused into a single convolution layer.

  • Reduce: A Reduce layer that performs average pooling will be replaced by a Pooling layer. The Reduce layer must have keepDimensions set and be reduced across the H and W dimensions from the CHW input format before batching, using the kAVG operation.

  • Convolution and Pooling: The Convolution and Pooling layers must have the same precision. The Convolution layer can already have a fused activation operation from a previous fusion.

  • Depthwise Separable Convolution: A depthwise convolution with activation followed by a convolution with activation can sometimes be fused into a single optimized DepSepConvolution layer. The precision of both convolutions must be INT8, and the device's compute capability must be 7.2 or later.

  • Softmax and Log: A Softmax layer followed by a Log layer can be fused into a single Softmax layer if the Softmax has not already been fused with a previous log operation.

  • Softmax and TopK: A Softmax layer followed by a TopK layer can be fused into a single layer. The Softmax can optionally include a Log operation.
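The Shuffle and Shuffle case above can be illustrated with NumPy (not TensorRT code): two permutation-only shuffles are equivalent to one shuffle whose permutation is the composition of the two.

```python
import numpy as np

# Two permutation-only shuffles (no reshape) collapse into one.
x = np.arange(24).reshape(2, 3, 4)

p1 = (1, 2, 0)  # first Shuffle: pure transpose
p2 = (0, 2, 1)  # second Shuffle: pure transpose

two_shuffles = x.transpose(p1).transpose(p2)

# Composed permutation: axis i of the result is axis p1[p2[i]] of x.
fused = tuple(p1[i] for i in p2)
one_shuffle = x.transpose(fused)

assert np.array_equal(two_shuffles, one_shuffle)
```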

Supported Reduction Operation Fusions

  • GELU: A group of Unary and ElementWise layers representing the following equations can be fused into a single GELU reduction operation.

    \(0.5x \times \left( 1+\tanh\left( \sqrt{\frac{2}{\pi}}\left( x+0.044715x^{3} \right) \right) \right)\)

    Or the alternative representation:

    \(0.5x \times \left( 1+erf\left( \frac{x}{\sqrt{2}} \right) \right)\)

  • L1Norm: A Unary layer kABS operation followed by a Reduce layer kSUM operation can be fused into a single L1Norm reduction operation.

  • Sum of Squares: A product ElementWise layer with the same input (square operation) followed by a kSUM reduction can be fused into a single square sum reduction operation.

  • L2Norm: A sum of squares operation followed by a kSQRT UnaryOperation can be fused into a single L2Norm reduction operation.

  • LogSum: A Reduce layer kSUM followed by a kLOG UnaryOperation can be fused into a single LogSum reduction operation.

  • LogSumExp: A Unary kEXP ElementWise operation followed by a LogSum fusion can be fused into a single LogSumExp reduction operation.
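The two GELU forms above can be checked numerically in plain Python, using the standard sqrt(2/pi) coefficient of the tanh approximation:

```python
import math

# The tanh-based GELU expression approximates the exact erf-based one.
def gelu_tanh(x):
    return 0.5 * x * (1.0 + math.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def gelu_erf(x):
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

# The approximation error is tiny over typical activation ranges.
for v in [-3.0, -1.0, -0.5, 0.0, 0.5, 1.0, 3.0]:
    assert abs(gelu_tanh(v) - gelu_erf(v)) < 1e-3
```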

Pointwise Fusion#

Multiple adjacent Pointwise layers can be fused into a single Pointwise layer to improve performance.

The following types of Pointwise layers are supported, with some limitations:

  • Activation: Every ActivationType is supported.

  • Constant: Only constant with a single value (size == 1).

  • ElementWise: Every ElementWiseOperation is supported.

  • Pointwise: Pointwise itself is also a Pointwise layer.

  • Scale: Only support ScaleMode::kUNIFORM.

  • Unary: Every UnaryOperation is supported.

The size of the fused Pointwise layer is limited, so some Pointwise layers cannot be fused.

Fusion creates a new layer with a name consisting of both fused layers. For example, an ElementWise layer named add1 is fused with a ReLU Activation layer named relu1, creating a new layer named fusedPointwiseNode(add1, relu1).
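What Pointwise fusion buys can be sketched with NumPy (an illustration, not TensorRT code): two elementwise operations collapse into a single per-element pass, avoiding a materialized intermediate tensor.

```python
import numpy as np

a = np.array([-2.0, -0.5, 1.0, 3.0])
b = np.array([1.0, -1.0, 0.5, -4.0])

# Unfused: two passes, with the intermediate tensor materialized.
tmp = a + b                        # "add1"
unfused = np.maximum(tmp, 0.0)     # "relu1"

# Fused: one traversal computes relu(add(a, b)) per element, the
# analogue of the fusedPointwiseNode(add1, relu1) layer.
fused = np.array([max(x + y, 0.0) for x, y in zip(a, b)])

assert np.array_equal(unfused, fused)
```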

Q/DQ Fusion#

Refer to the Explicit Quantization section for suggestions on optimizing INT8 and FP8 networks containing QuantizeLinear and DequantizeLinear layers.

Limiting Compute Resources#

Limiting the number of compute resources available to TensorRT during engine creation is beneficial when the reduced amount better represents the expected conditions during runtime, for example, when the GPU is expected to perform additional work in parallel to the TensorRT engine or when the engine is expected to run on a different GPU with fewer resources (note that the recommended approach is to build the engine on the GPU that will be used for inference, but this is not always feasible).

You can limit the number of available compute resources with the following steps:

  1. Start the CUDA MPS control daemon.

    nvidia-cuda-mps-control -d
    
  2. Set the number of computing resources to use with the CUDA_MPS_ACTIVE_THREAD_PERCENTAGE environment variable. For example, export CUDA_MPS_ACTIVE_THREAD_PERCENTAGE=50.

  3. Build the network engine.

  4. Stop the CUDA MPS control daemon.

    echo quit | nvidia-cuda-mps-control
    

The resulting engine is optimized to the reduced number of compute cores (50% in this example) and provides better throughput when using similar conditions during inference. You are encouraged to experiment with different amounts of streams and different MPS values to determine the best performance for your network.

For more details about nvidia-cuda-mps-control, refer to the nvidia-cuda-mps-control documentation and the relevant GPU requirements.

Deterministic Tactic Selection#

TensorRT runs through all the possible tactics in the engine-building phase and selects the fastest ones. Since the selection is based on the tactics’ latency measurements, TensorRT can select different tactics across different runs if some have similar latencies. As a result, different engines built from the same INetworkDefinition can behave slightly differently regarding output values and performance. You can inspect the selected tactics of an engine by using the engine inspector APIs or by turning on verbose logging while building the engine.

If deterministic tactic selection is desired, the following lists a few suggestions that can help improve the determinism of tactic selection.

Locking GPU Clock Frequency

By default, the GPU’s clock frequency is not locked, meaning that the GPU normally sits at the idle clock frequency and only boosts to the max clock frequency when there are active GPU workloads. However, there is a latency for the clock to be boosted from the idle frequency, and that can cause performance variations while TensorRT is running through the tactics and selecting the best ones, resulting in non-deterministic tactic selections.

Therefore, locking the GPU clock frequency before building a TensorRT engine can improve the determinism of tactic selection. Refer to the Hardware/Software Environment for Performance Measurements section for more information about how to lock and monitor the GPU clock and the factors that can affect GPU clock frequencies.

Increasing Average Timing Iterations

By default, TensorRT runs each tactic for at least four iterations and takes the average latency. You can increase the number of iterations by calling the setAvgTimingIterations() API:

C++:

    builderConfig->setAvgTimingIterations(8);

Python:

    builder_config.avg_timing_iterations = 8

Increasing the number of average timing iterations can improve the determinism of tactic selections, but the required engine-building time will become longer.

Using Timing Cache

A timing cache records the latencies of each tactic for a specific layer configuration. The tactic latencies are reused if TensorRT encounters another layer with an identical configuration. Therefore, by reusing the same timing cache across multiple engine building runs with the same INetworkDefinition and builder config, you can make TensorRT select an identical set of tactics in the resulting engines.