Best Practices#

Performance Benchmarking using tensorrt_rtx#

This section describes how to use tensorrt_rtx, a command-line tool designed for TensorRT-RTX performance benchmarking, to measure the inference performance of your deep learning models. If you installed TensorRT-RTX manually, tensorrt_rtx is included in the installation.

Performance Benchmarking with an ONNX File#

If your model is already in the ONNX format, the tensorrt_rtx tool can measure its performance directly. In this example, we will use the ResNet-50 v1 ONNX model from the ONNX model zoo to showcase how to use tensorrt_rtx to measure its performance.

For example, the tensorrt_rtx command to measure the performance of ResNet-50 with batch size 4 is:

tensorrt_rtx --onnx=resnet50-v1-12.onnx --shapes=data:4x3x224x224 --noDataTransfers --useCudaGraph --useSpinWait

Where:

  • The --onnx flag specifies the path to the ONNX file.

  • The --shapes flag specifies the input tensor shapes.

  • The other flags have been added to make performance results more stable.

The value for the --shapes flag is in the format of name1:shape1,name2:shape2,... You can get the input tensor names and shape profiles by visualizing the ONNX model using tools like Netron or by running a Polygraphy model inspection.
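For example, assuming Polygraphy is installed, the following command prints the model's inputs and outputs, including tensor names and shapes:

polygraphy inspect model resnet50-v1-12.onnx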

After running the tensorrt_rtx command, tensorrt_rtx will parse your ONNX file, build a TensorRT-RTX plan file, measure the performance of this plan file, and then print a performance summary. This summary includes many performance metrics, but the most important are throughput and median latency.

Per-Layer Runtime and Layer Information#

In previous sections, we described using tensorrt_rtx to measure end-to-end latency. This section shows an example of printing per-layer runtimes and per-layer information using tensorrt_rtx. This helps you determine how much latency each layer contributes to the end-to-end latency and identify performance bottlenecks.

This is an example tensorrt_rtx command to print per-layer runtime and per-layer information using the quantized ResNet-50 ONNX model:

tensorrt_rtx --onnx=resnet50-v1-12-quantized.onnx --shapes=data:4x3x224x224 --noDataTransfers --useSpinWait --profilingVerbosity=detailed --dumpLayerInfo --verbose --dumpProfile --separateProfileRun

Where:

  • The --profilingVerbosity=detailed flag enables detailed layer information capturing.

  • The --dumpLayerInfo flag shows the per-layer information in the log.

  • The --dumpProfile and --separateProfileRun flags show the per-layer runtime latencies in the log.

Duration and Number of Iterations#

By default, tensorrt_rtx warms up for at least 200 ms and runs inference for at least 10 iterations or at least 3 seconds, whichever is longer. You can modify these parameters by adding the --warmUp=500, --iterations=100, and --duration=60 flags, which mean running the warm-up for at least 500 ms and running the inference for at least 100 iterations or at least 60 seconds, whichever is longer.
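For example, applying these settings to the ResNet-50 command shown earlier:

tensorrt_rtx --onnx=resnet50-v1-12.onnx --shapes=data:4x3x224x224 --noDataTransfers --useCudaGraph --useSpinWait --warmUp=500 --iterations=100 --duration=60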

You can run tensorrt_rtx --help for a detailed explanation about other tensorrt_rtx flags.

CUDA Profiling Tools#

The recommended CUDA profiler is NVIDIA Nsight Systems. This profiler can be used on any CUDA program to report timing information about the kernels launched during execution, data movement between host and device, and the CUDA API calls used.

Nsight Systems can be configured to report timing information for only a portion of the program’s execution or to report traditional CPU sampling profile information and GPU information.

The basic usage of Nsight Systems is first to run the command nsys profile -o <OUTPUT> <INFERENCE_COMMAND>, then open the generated <OUTPUT>.nsys-rep file in the Nsight Systems GUI to visualize the captured profiling results.

Using the NVTX Tracing in Nsight Systems#

NVTX tracing relies on the NVIDIA Tools Extension SDK (NVTX), a C-based API for marking events and ranges in your applications. It allows Nsight Compute and Nsight Systems to collect the data generated by TensorRT-RTX applications.

Decoding the kernel names back to layers in the original network can be complicated. Because of this, TensorRT-RTX uses NVTX to mark a range for each layer, allowing the CUDA profilers to correlate each layer with the kernels called to implement it. In TensorRT-RTX, NVTX helps to correlate the runtime engine layer execution with CUDA kernel calls. Nsight Systems supports collecting and visualizing these events and ranges on the timeline. Nsight Compute also supports collecting and displaying the state of all active NVTX domains and ranges in a given thread when the application is suspended.

In TensorRT-RTX, each layer may launch one or more kernels to perform its operations. The exact kernels launched depend on the optimized network and the hardware present. Depending on the builder’s choices, multiple additional operations that reorder data may be interspersed with layer computations; these reformat operations may be implemented as device-to-device memory copies or custom kernels.

Controlling the Level of Details in NVTX Tracing#

By default, TensorRT-RTX only shows layer names in the NVTX markers, but you can control the level of detail by setting the ProfilingVerbosity in the IBuilderConfig when the engine is built. For example, to disable NVTX tracing, set the ProfilingVerbosity to kNONE:

builderConfig->setProfilingVerbosity(ProfilingVerbosity::kNONE);

import tensorrt_rtx as trt
builder_config.profiling_verbosity = trt.ProfilingVerbosity.NONE

On the other hand, you can choose to allow TensorRT-RTX to print more detailed layer information in the NVTX markers, including input and output dimensions, operations, parameters, tactic numbers, and so on, by setting the ProfilingVerbosity to kDETAILED:

builderConfig->setProfilingVerbosity(ProfilingVerbosity::kDETAILED);

import tensorrt_rtx as trt
builder_config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED

Note

Enabling detailed NVTX markers increases the latency of enqueueV3() calls and could result in a performance drop if the performance depends on the latency of enqueueV3() calls.

Running Nsight Systems with tensorrt_rtx#

The following example code shows how to gather Nsight Systems profiles using the tensorrt_rtx tool:

tensorrt_rtx --onnx=foo.onnx --profilingVerbosity=detailed --saveEngine=foo.plan
nsys profile -o foo_profile --capture-range cudaProfilerApi tensorrt_rtx --profilingVerbosity=detailed --loadEngine=foo.plan --warmUp=0 --duration=0 --iterations=50

The first command builds and serializes the engine to foo.plan, and the second command runs the inference using foo.plan and generates a foo_profile.nsys-rep file which can then be opened in the Nsight Systems user interface for visualization.

The --profilingVerbosity=detailed flag allows TensorRT-RTX to show more detailed layer information in the NVTX markers, and the --warmUp=0, --duration=0, and --iterations=50 flags let you control how many inference iterations to run. By default, tensorrt_rtx runs inference for three seconds, which may result in a very large nsys-rep file.

Optional: Enabling GPU Metrics Sampling in Nsight Systems#

On discrete GPU systems, add the --gpu-metrics-device all flag to the nsys command to sample GPU metrics, including GPU clock frequencies, DRAM bandwidth, Tensor Core utilization, and so on. If the flag is added, these GPU metrics appear in the Nsight Systems GUI.

Hardware and Software Environment for Performance Measurements#

Performance measurements are influenced by many factors, including hardware environment differences like the machine’s cooling capability and software environment differences like GPU clock settings. This section summarizes a few items that may affect performance measurements.

GPU Monitoring#

While measuring performance, it is recommended that you record and monitor the GPU status in parallel to the inference workload. Having the monitoring data allows you to identify possible root causes when you see unexpected performance measurement results.

Before the inference starts, issue the nvidia-smi -q command to get detailed information on the GPU, including the product name, power cap, clock settings, and so on. Then, while the inference workload is running, run the nvidia-smi dmon -s pcu -f <FILE> -c <COUNT> command in parallel to print out GPU clock frequencies, power consumption, temperature, and utilization to a file. Issue nvidia-smi dmon --help for more options about the nvidia-smi device monitoring tool.
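For example, to collect 120 samples of GPU status to a file while the benchmark runs (the file name and sample count here are only illustrative):

nvidia-smi dmon -s pcu -f gpu_status.log -c 120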

GPU Clock Locking#

By default, the GPU clock frequency is floating, meaning that it sits at a low idle frequency when there is no active workload and boosts to the maximum clock frequency when a workload starts. This is usually the desired behavior since it allows the GPU to generate less heat at idle and to run at maximum speed when there is an active workload.

Alternatively, you can lock the clock at a specific frequency by calling the nvidia-smi -lgc <freq> command (and conversely, you can let the clock float again with the nvidia-smi -rgc command).

Note

On Linux, sudo may be required as a prefix to the command, and on Windows, the command prompt may require administrator privileges.

The nvidia-smi -q -d SUPPORTED_CLOCKS command can find the supported clock frequencies. After the clock frequency is locked, it should stay at that frequency unless power or thermal throttling occurs, which will be explained in the next sections. When the throttling kicks in, the device behaves like the clock floats.
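For example, on Linux, you can query the supported frequencies, lock the clock, and later let it float again (the frequency shown is only illustrative and must be one of the values reported as supported for your GPU):

nvidia-smi -q -d SUPPORTED_CLOCKS
sudo nvidia-smi -lgc 1500
sudo nvidia-smi -rgc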

Running TensorRT-RTX workloads with floating clocks or with throttling taking place can lead to unstable performance measurements across inferences because every CUDA kernel may run at slightly different clock frequencies, depending on which frequency the driver boosts or throttles the clock to at that moment. On the other hand, running TensorRT-RTX workloads with locked clocks allows more consistent performance measurements. Still, the average performance will not be as good as when the clock is floating or is locked at maximum frequency with throttling taking place.

There is no definite recommendation on whether to lock the clock, or which frequency to lock the GPU to, while running TensorRT-RTX workloads. It depends on whether stable performance or the best average performance is desired.

GPU Power Throttling#

Power throttling occurs when the average GPU power consumption reaches the power limit, which can be set by the sudo nvidia-smi -pl <power_cap> command. When this happens, the driver has to throttle the clock to a lower frequency to keep the average power consumption below the limit. The constantly changing clock frequencies may lead to unstable performance measurements if the measurements are taken within a short time, such as within 20ms.

Power throttling happens by design and is a natural phenomenon when the GPU clock is not locked or is locked at a higher frequency, especially for GPUs with lower power limits. To avoid performance variations caused by power throttling, you can lock the GPU clock at a lower frequency to stabilize the performance numbers. However, the average performance numbers will be lower than those with floating clocks or with the clock locked at a higher frequency, even though power throttling would occur in those cases.

Another issue with power throttling is that it may skew the performance numbers if there are gaps between inferences in your performance benchmarking applications. For example, if the application synchronizes at each inference, there will be periods when the GPU is idle between the inferences. The gaps cause the GPU to consume less power on average, so the clock is throttled less, and the GPU can run at higher clock frequencies on average. However, the throughput numbers measured this way are inaccurate because when the GPU is fully loaded with no gaps between inferences, the actual clock frequency will be lower, and the actual throughput will not reach the throughput numbers measured using the benchmarking application.

To avoid this, the tensorrt_rtx tool is designed to maximize GPU execution by leaving nearly no gaps between GPU kernel executions so that it can measure the true throughput of a TensorRT-RTX workload. Therefore, if you see performance gaps between your benchmarking application and what tensorrt_rtx reports, check if the power throttling and the gaps between inferences are the cause.

Lastly, power consumption can depend on the activation values, so performance measurements can differ depending on the inputs used. For example, if all the network input values are set to zeros or NaNs, the GPU consumes less power than when the inputs are normal values because there are fewer bit-flips in DRAM and the L2 cache. To avoid this discrepancy, always use input values that best represent the actual value distribution when measuring performance. The tensorrt_rtx tool uses random input values by default, but you can specify the input using the --loadInputs flag.
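For example, assuming the flag uses a name:file syntax similar to the --shapes flag (input_data.bin here is a hypothetical file containing the raw input tensor bytes):

tensorrt_rtx --onnx=resnet50-v1-12.onnx --shapes=data:4x3x224x224 --loadInputs=data:input_data.bin --noDataTransfers --useSpinWait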

GPU Thermal Throttling#

Thermal throttling happens when the GPU temperature reaches a predefined threshold (around 85 degrees Celsius for most GPUs) and the driver throttles the clock to a lower frequency to prevent the GPU from overheating. You can identify this by observing the temperature logged by the nvidia-smi dmon command gradually increase while the inference workload runs until it approaches the temperature threshold and the clock frequency drops. Thermal throttling can be prevented or mitigated by improving cooling for the GPU.

Note that a higher GPU temperature also leads to more leakage current in the circuits, which increases the power consumed by the GPU at a specific clock frequency. Therefore, for GPUs more likely to be power throttled, poor cooling can lead to lower stabilized clock frequency with power throttling and, thus, worse performance, even if the GPU clocks have not been thermally throttled yet.

Synchronization Modes#

If performance is measured with cudaStreamSynchronize() or cudaEventSynchronize(), synchronization overhead variations may lead to performance measurement variations. This section describes the causes of the variations and how to avoid them.

When cudaStreamSynchronize() is called, there are two ways in which the driver waits until the stream is complete. If the cudaDeviceScheduleBlockingSync flag has been set with a cudaSetDeviceFlags() call, then cudaStreamSynchronize() uses the blocking-sync mechanism. Otherwise, it uses the spin-wait mechanism.

A similar idea applies to CUDA events. If a CUDA event is created with the cudaEventDefault flag, then the cudaEventSynchronize() call uses the spin-wait mechanism. If a CUDA event is created with the cudaEventBlockingSync flag, then the cudaEventSynchronize() call will use the blocking-sync mechanism.

When the blocking-sync mode is used, the host thread yields to another thread until the device work is done. This allows the CPUs to sit idle to save power or to be used by other CPU workloads while the device is still executing. However, the blocking-sync mode tends to result in relatively unstable overheads in stream/event synchronizations on some operating systems, leading to variations in latency measurements.

On the other hand, when the spin-wait mode is used, the host thread polls constantly until the device work is done. Using spin-wait makes the latency measurements more stable due to shorter and more stable overhead in stream/event synchronizations. However, it also introduces additional CPU overhead. Therefore, if you want to reduce CPU power consumption or do not want the stream/event synchronizations to consume CPU resources (for example, you are running other heavy CPU workloads in parallel), use the blocking-sync mode. If you care more about stable performance measurements, use the spin-wait mode.
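If you control the synchronization mechanism in your own benchmarking application, you can select it through the CUDA runtime. The following is a minimal sketch using the cuda-python bindings used elsewhere in this section; the flags shown are standard CUDA runtime flags, and which mode is appropriate depends on the trade-offs described above.

from cuda import cudart

# Opt the device into blocking-sync behavior for cudaStreamSynchronize().
# Call this early, before the CUDA context for the device is initialized.
err, = cudart.cudaSetDeviceFlags(cudart.cudaDeviceScheduleBlockingSync)

# Per-event choice: blocking-sync for cudaEventSynchronize() on this event...
err, blocking_event = cudart.cudaEventCreateWithFlags(cudart.cudaEventBlockingSync)

# ...versus the default spin-wait behavior.
err, spin_event = cudart.cudaEventCreateWithFlags(cudart.cudaEventDefault)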

In tensorrt_rtx, the default synchronization mechanism is blocking-sync mode. Add the --useSpinWait flag to synchronize using spin-wait mode for more stable latency measurements, at the cost of higher CPU utilization and power consumption.

Optimizing TensorRT-RTX Performance#

The following sections focus on the general inference flow on GPUs and some general strategies to improve performance.

Network Precision#

TensorRT-RTX networks are strongly typed, meaning that operations execute at the precision specified in the source model. Inference generally performs best at lower precisions. Whenever possible, use lower precisions to take advantage of the improved performance, keeping in mind the hardware-specific support for some data types:

  • FP4 is only supported on NVIDIA Blackwell (SM120) and above

  • FP8 is only supported on NVIDIA Ada (SM89) and above

  • BF16 and INT4 are only supported on NVIDIA Ampere (SM80) and above

  • FP16, FP32, and integer types other than INT4 are supported on all supported hardware

To enjoy the additional performance benefit of lower-precision types via quantization, Quantize/Dequantize operations need to be inserted into the ONNX model to tell TensorRT-RTX where to quantize/dequantize the tensors and what scaling factors to use.

For more information, refer to the Working with Quantized Types section.

Dynamic Shape Specialization#

TensorRT-RTX supports networks with dynamic shapes. Dynamic shapes allow applications to defer specifying some or all tensor dimensions, for network inputs and outputs, until runtime.

A key advantage of the dynamic shapes implementation in TensorRT-RTX is that shape-specialized kernels are generated and compiled at runtime based on actual usage. When a new shape is passed to the network, new kernels optimized for this shape are generated in a background thread during inference. This means that when profiling performance with a new shape, it is important to use warmup runs to allow the shape-specialized kernels to compile, or to use the eager DynamicShapesKernelSpecializationStrategy so that the fallback kernels are skipped entirely. Applications can also use warmup runs to compile specialized kernels for expected shapes at startup, avoiding reduced performance the first time a shape is seen.
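For example, a minimal warmup sketch in Python might look like the following. The tensor name ("data"), the shapes, and the surrounding engine, execution context, stream, and buffer setup are assumptions for illustration and are not shown; set_input_shape() follows the TensorRT-style Python API and may differ in your version.

from cuda import cudart

# Hypothetical warmup loop: run each expected shape once (or a few times) before
# timing, so the background compilation of shape-specialized kernels can finish.
# Assumes `context` and `stream` already exist and input/output buffers are bound.
for shape in [(1, 3, 224, 224), (4, 3, 224, 224)]:
    context.set_input_shape("data", shape)
    context.execute_async_v3(stream)
cudart.cudaStreamSynchronize(stream)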

For more information, refer to the Working with Dynamic Shapes section.

Runtime Caching#

At runtime, TensorRT-RTX compiles hardware-specific kernels using Just-In-Time (JIT) compilation. Runtime caching stores the compiled kernels, allowing future invocations, potentially even on a different system, to avoid repeated recompilation of the kernels. This can reduce application startup cost and enable peak performance out-of-the-box.

// Create a runtime cache.
auto runtimeCache = std::unique_ptr<nvinfer1::IRuntimeCache>(runtimeConfig->createRuntimeCache());

// Set the runtime cache in runtime configuration.
runtimeConfig->setRuntimeCache(*runtimeCache);

# Create a runtime cache.
runtimeCache = runtimeConfig.create_runtime_cache()

# Set the runtime cache in runtime configuration.
runtimeConfig.set_runtime_cache(runtimeCache)

The execution context will use the attached runtime cache for all inference executions, and JIT compiled kernels are added to the cache. Optionally the cache can be serialized for persistence across application invocations and installations.

// Serialize the runtime cache.
auto serializedRuntimeCache = std::unique_ptr<nvinfer1::IHostMemory>(runtimeCache->serialize());

# Serialize the runtime cache.
serializedRuntimeCache = runtimeCache.serialize()
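For example, a minimal sketch for persisting the serialized cache to disk, assuming the serialized object exposes the Python buffer protocol (as TensorRT's IHostMemory does); the file name is arbitrary:

# Write the serialized runtime cache to a file so a later run can reload it.
with open("runtime.cache", "wb") as f:
    f.write(serializedRuntimeCache)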

For more information, refer to the Working with Runtime Cache section.

Batching#

In TensorRT-RTX, a batch is a collection of inputs that can all be processed uniformly. Each instance in the batch has the same shape and flows through the network similarly. Therefore, each instance can be trivially computed in parallel.

Each network layer will have some overhead and synchronization required to compute forward inference. By computing more results in parallel, this overhead is paid off more efficiently. In addition, many layers are performance-limited by the smallest dimension in the input. If the batch size is one or small, this size can often be the performance-limiting dimension.

On NVIDIA Ada Lovelace or later GPUs, decreasing the batch size may improve the throughput significantly if the smaller batch sizes help the GPU cache the input/output values in the L2 cache. Therefore, various batch sizes should be tried to find the batch size that provides optimal performance.
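For example, to compare throughput at two batch sizes with the ResNet-50 model used earlier (these commands assume the model's batch dimension is dynamic, which the batch-4 example above already relies on):

tensorrt_rtx --onnx=resnet50-v1-12.onnx --shapes=data:1x3x224x224 --noDataTransfers --useSpinWait
tensorrt_rtx --onnx=resnet50-v1-12.onnx --shapes=data:8x3x224x224 --noDataTransfers --useSpinWait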

Within-Inference Multi-Streaming#

In CUDA programming, streams are a way of organizing asynchronous work. Asynchronous commands put into a stream are guaranteed to run in sequence but may execute out of order with respect to other streams. In particular, asynchronous commands in two streams may be scheduled to run concurrently (subject to hardware limitations).

In the context of TensorRT-RTX and inference, each layer of the optimized final network will require work on the GPU. However, not all layers can fully use the hardware’s computation capabilities. Scheduling requests in separate streams allows work to be scheduled immediately as the hardware becomes available without unnecessary synchronization. Even if only some layers can be overlapped, overall performance will improve.

Use the IBuilderConfig::setMaxAuxStreams() API to set the maximum number of auxiliary streams TensorRT-RTX can use to run multiple layers in parallel. The auxiliary streams are used in addition to the main stream provided in the enqueueV3() call. If enabled, TensorRT-RTX will run some layers on the auxiliary streams in parallel to those running on the main stream.

For example, to run the inference on at most eight streams (that is, seven auxiliary streams and one mainstream) in total:

config->setMaxAuxStreams(7);

config.max_aux_streams = 7

Note that this only sets the maximum number of auxiliary streams. However, TensorRT-RTX may use fewer auxiliary streams than this number if it determines that using more streams does not help.

To get the actual number of auxiliary streams that TensorRT-RTX uses for an engine, run the following:

int32_t nbAuxStreams = engine->getNbAuxStreams();

num_aux_streams = engine.num_aux_streams

When an execution context is created from the engine, TensorRT-RTX automatically creates the auxiliary streams needed to run the inference. However, you can also specify the auxiliary streams you would like TensorRT-RTX to use:

int32_t nbAuxStreams = engine->getNbAuxStreams();
std::vector<cudaStream_t> streams(nbAuxStreams);
for (int32_t i = 0; i < nbAuxStreams; ++i)
{
    cudaStreamCreate(&streams[i]);
}
context->setAuxStreams(streams.data(), nbAuxStreams);

from cuda import cudart
num_aux_streams = engine.num_aux_streams
streams = []
for i in range(num_aux_streams):
    err, stream = cudart.cudaStreamCreate()
    streams.append(stream)
context.set_aux_streams(streams)

TensorRT-RTX will always insert event synchronizations between the main stream provided in the enqueueV3() call and the auxiliary streams:

  • At the beginning of the enqueueV3() call, TensorRT-RTX ensures that all the auxiliary streams wait on the activities on the main stream.

  • At the end of the enqueueV3() call, TensorRT-RTX ensures that the main stream waits for the activities on all the auxiliary streams.

Enabling auxiliary streams may increase memory consumption because some activation buffers can no longer be reused.

CUDA Graphs#

CUDA Graphs represent a sequence (or, more generally, a graph) of kernels in a way that allows CUDA to optimize their scheduling. This can be particularly useful when enqueue and launch times take longer than the actual GPU executions, causing the latency of enqueueV3() calls to become the performance bottleneck. We say that this type of workload is “enqueue-bound.”

A workload tends to be enqueue-bound if it is very small in terms of the number of computations, such as containing convolutions with small I/O sizes, matrix multiplications with small GEMM sizes, or mostly element-wise operations throughout the network. This is because most CUDA kernels take the CPU and the driver around 5-15 microseconds to launch, so if each CUDA kernel's execution time is only a few microseconds on average, the kernel launch time becomes the main performance bottleneck.

In tensorrt_rtx, you can tell that a workload is enqueue-bound if the reported Enqueue Time is close to or longer than the reported GPU Compute Time. In this case, it is recommended that you use CUDA Graphs to run the network.

To solve this, you can increase the computations in each CUDA kernel by increasing the batch size, or you can use CUDA Graphs to capture the kernel launches into a graph and launch the graph instead of calling enqueueV3().
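For example, to benchmark the foo.plan engine built in the Nsight Systems example with CUDA graphs enabled:

tensorrt_rtx --loadEngine=foo.plan --useCudaGraph --noDataTransfers --useSpinWait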

Using CUDA Graphs with TensorRT-RTX Execution Context#

TensorRT-RTX’s enqueueV3() method supports CUDA graph capture for models requiring no mid-pipeline CPU interaction. For example:

// Call enqueueV3() once after an input shape change to update internal state.
context->enqueueV3(stream);

// Capture a CUDA graph instance
cudaGraph_t graph;
cudaGraphExec_t instance;
cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
context->enqueueV3(stream);
cudaStreamEndCapture(stream, &graph);
cudaGraphInstantiate(&instance, graph, 0);

// To run inferences, launch the graph instead of calling enqueueV3().
for (int i = 0; i < iterations; ++i) {
    cudaGraphLaunch(instance, stream);
    cudaStreamSynchronize(stream);
}

from cuda import cudart
err, stream = cudart.cudaStreamCreate()

# Call execute_async_v3() once after an input shape change to update internal state.
context.execute_async_v3(stream)

# Capture a CUDA graph instance
cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureModeGlobal)
context.execute_async_v3(stream)
err, graph = cudart.cudaStreamEndCapture(stream)
err, instance = cudart.cudaGraphInstantiate(graph, 0)

# To run inferences, launch the graph instead of calling execute_async_v3().
for i in range(iterations):
    cudart.cudaGraphLaunch(instance, stream)
    cudart.cudaStreamSynchronize(stream)

Limitations of CUDA Graphs#

CUDA graphs cannot handle some operations, so graph capture may fail if the execution context contains such operations. Typical deep learning operators unsupported by CUDA graphs include loops, conditionals, and layers requiring data-dependent shapes. In these cases, cudaStreamEndCapture() will return cudaErrorStreamCapture* errors, indicating that the graph capturing has failed, but the context can continue to be used for normal inference without CUDA graphs. Refer to the CUDA Programming Guide to learn more about the limitations of CUDA graphs.

When capturing a graph which uses dynamic shapes, it is important to account for shape-specialized kernel compilation to avoid capturing less performant fallback kernels. Applications can use eager kernel specialization to force execution using specialized kernels rather than waiting for parallel compilation to complete. When using tensorrt_rtx, the flag --specializeStrategyDS=eager forces eager compilation. Applications can also set this behavior using the API:

runtimeConfig->setDynamicShapesKernelSpecializationStrategy(nvinfer1::DynamicShapesKernelSpecializationStrategy::kEAGER);

import tensorrt_rtx as trt
runtime_config.dynamic_shapes_kernel_specialization_strategy = trt.DynamicShapesKernelSpecializationStrategy.EAGER

It is also important to account for the two-phase execution strategy used in the presence of dynamic shapes.

  1. Update the model’s internal state to account for any changes in input size.

  2. Stream work to the GPU, compiling shape-specialized kernels in parallel.

The first phase requires no per-invocation work for models where input size is fixed at build time. Otherwise, if the input sizes have changed since the last invocation, some work may be required to update derived properties.

The first phase of work is not designed to be captured, and even if the capture is successful, it is likely to increase model execution time. Therefore, after changing the shapes of inputs or the values of shape tensors, call enqueueV3() to warm up the graph before capturing.

Graphs captured with TensorRT-RTX are specific to the input size and the state of the execution context. Modifying the context from which the graph was captured will result in undefined behavior when executing the graph. In particular, if the application is managing memory for activations using create_execution_context(tensorrt_rtx.ExecutionContextAllocationStrategy(2)), the memory address is also captured as part of the graph. Locations of input and output buffers are also captured as part of the graph.

Therefore, the best practice is to use one execution context per captured graph and to share memory across the contexts with create_execution_context(tensorrt_rtx.ExecutionContextAllocationStrategy(2)).

Concurrent CUDA Activities with CUDA Graph Capture#

Launching a CUDA kernel on the CUDA legacy default stream or calling synchronous CUDA APIs like cudaMemcpy() while capturing a CUDA graph fails because these CUDA activities implicitly synchronize the CUDA streams used by TensorRT-RTX execution contexts.

To avoid breaking the CUDA graph capture, ensure other CUDA kernels are launched on non-default CUDA streams and use the asynchronous version of CUDA APIs, like cudaMemcpyAsync().

Alternatively, a CUDA stream can be created with the cudaStreamNonBlocking flag to capture the CUDA graph for an execution context. If the execution context uses auxiliary streams, make sure you also call the setAuxStreams() API using streams created with the cudaStreamNonBlocking flag.
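For example, a minimal sketch using the cuda-python bindings (the variable name is illustrative):

from cuda import cudart

# Create a non-blocking stream so that capture on it is not broken by implicit
# synchronization with the legacy default stream.
err, capture_stream = cudart.cudaStreamCreateWithFlags(cudart.cudaStreamNonBlocking)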