Release Notes#

These release notes describe the key features, software enhancements and improvements, and known issues for this TensorRT release.

To review the TensorRT documentation for previous versions, refer to the TensorRT Archived Documentation.

To review the TensorRT documentation for version 10.8.0 and later, choose a version from the version selector in the bottom-left navigation.

TensorRT 10.8.0#

These are the TensorRT 10.8.0 Release Notes, which apply to x86 Linux and Windows users, and Arm-based CPU cores for Server Base System Architecture (SBSA) users on Linux. This release includes several fixes from the previous TensorRT releases and additional changes.

Announcements

  • This release supports NVIDIA Blackwell GPUs, such as the GeForce 50-series. B200 and GB200 NVL have limited support in this release and should be considered early access.

  • Starting with TensorRT 10.8, the minimum glibc version for the Linux x86_64 build is 2.28. TensorRT 10.8 is expected to be compatible with RedHat 8.x (and derivatives) and newer RedHat distributions. It is also expected to be compatible with Ubuntu 20.04 and newer Ubuntu distributions. This aligns with a similar change to the minimum glibc version made by CUDA, starting with CUDA 12.8.

  • TensorRT CUDA 12.8 builds for all platforms are compiled using the newer GCC C++11 ABI. We do not expect this change to be visible to customers, and it does not affect TensorRT API or ABI compatibility.

Key Features and Enhancements

This TensorRT release includes the following key features and enhancements.

  • Added support for E2M1 FP4 data type on NVIDIA Blackwell GPUs using explicit quantization. For more information, refer to the Working with Quantized Types section.
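
    For illustration only, a minimal C++ sketch of explicit FP4 quantization with a build-time constant scale; the three-argument addQuantize/addDequantize overloads and the nvinfer1::DataType::kFP4 enumerator used here are assumptions, and scaleDims, scaleWeights, and weightsTensor are hypothetical names, so verify the exact API against the TensorRT 10.8 reference.

    // Build-time constant scale factor (required when the quantize output type is FP4).
    nvinfer1::ITensor* scale =
        network->addConstant(scaleDims, scaleWeights)->getOutput(0);
    // Quantize to FP4 (E2M1), then dequantize back to FP16 for the consuming layer.
    nvinfer1::IQuantizeLayer* q =
        network->addQuantize(*weightsTensor, *scale, nvinfer1::DataType::kFP4);
    nvinfer1::IDequantizeLayer* dq =
        network->addDequantize(*q->getOutput(0), *scale, nvinfer1::DataType::kHALF);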

  • Support was added for tiling optimization, which enables cross-kernel tiled inference. For more information, refer to the Tiling Optimization section.

  • A new layer type, ICumulativeLayer, was added, which computes successive reductions across an axis of a tensor.
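
    As an illustration, a minimal C++ sketch of adding a cumulative sum across axis 1 of an existing tensor named input; the exact addCumulative signature and the CumulativeOperation::kSUM enumerator are assumptions modeled on the other INetworkDefinition::add* APIs, so check the API reference for the authoritative form.

    // The axis is supplied as a 0-D INT32 constant tensor; axisValue must outlive the engine build.
    static int32_t const axisValue{1};
    nvinfer1::ITensor* axis = network->addConstant(
        nvinfer1::Dims{0, {}},
        nvinfer1::Weights{nvinfer1::DataType::kINT32, &axisValue, 1})->getOutput(0);
    // Inclusive, forward cumulative sum along the chosen axis.
    nvinfer1::ICumulativeLayer* cumSum = network->addCumulative(
        *input, *axis, nvinfer1::CumulativeOperation::kSUM,
        /*exclusive=*/false, /*reverse=*/false);
    nvinfer1::ITensor* output = cumSum->getOutput(0);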

  • A new sample, sampleEditableTimingCache, was added to demonstrate how to modify the timing cache to build an engine with the desired tactics. For more information, refer to the Create a Deterministic Build using an Editable Timing Cache section.

Breaking ABI Changes

  • There is an ABI breakage in INetworkDefinition. Applications linked against previous versions of TensorRT 10.x using INetworkDefinition APIs may not work correctly with TensorRT 10.8 unless relinked. This will be fixed in the next release, which will be ABI-compatible with 10.x builds (except 10.8).

Compatibility

Limitations

  • There is a known issue with using the markDebug API to mark multiple graph input tensors as debug tensors.

  • There are no optimized FP8 Convolutions for Group Convolutions and Depthwise Convolutions. Therefore, INT8 is still recommended for ConvNets containing these convolution ops.

  • The FP8 convolutions only support input/output channel counts that are multiples of 16. Otherwise, TensorRT will fall back to non-FP8 convolutions.

  • The FP8 convolutions do not support kernel sizes larger than 32 (for example, 7x7 convolutions); FP16 or FP32 fallback kernels will be used instead, with suboptimal performance. Therefore, for better performance, do not add FP8 Q/DQ ops before convolutions with large kernel sizes.

  • The accumulation dtype for the batched GEMMs in the FP8 MHA must be FP32.

    • This can be achieved by adding Cast (to FP32) ops before the batched GEMM and Cast (to FP16) after the batched GEMM.

    • Alternatively, you can convert your ONNX model using TensorRT Model Optimizer, which adds the Cast ops automatically.
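
    As an illustration (one of several possible ways), a C++ sketch of inserting these casts when building the network directly with INetworkDefinition; q and k are hypothetical placeholders for the inputs of the first batched GEMM.

    // Cast the batched-GEMM operands to FP32 so the GEMM accumulates in FP32 ...
    nvinfer1::ITensor* qF32 = network->addCast(*q, nvinfer1::DataType::kFLOAT)->getOutput(0);
    nvinfer1::ITensor* kF32 = network->addCast(*k, nvinfer1::DataType::kFLOAT)->getOutput(0);
    nvinfer1::IMatrixMultiplyLayer* qk = network->addMatrixMultiply(
        *qF32, nvinfer1::MatrixOperation::kNONE,
        *kF32, nvinfer1::MatrixOperation::kTRANSPOSE);
    // ... then cast the attention scores back to FP16 before the softmax.
    nvinfer1::ITensor* scores =
        network->addCast(*qk->getOutput(0), nvinfer1::DataType::kHALF)->getOutput(0);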

  • There cannot be any pointwise operations between the first batched GEMM and the softmax inside FP8 MHAs, such as an attention mask. This will be improved in future TensorRT releases.

  • The FP8 MHA fusions only support head sizes that are multiples of 16. If the MHA has a head size that is not a multiple of 16, do not add Q/DQ ops in the MHA so that it falls back to the FP16 MHA, which gives better performance.

  • On QNX, networks that are segmented into a large number of DLA loadables may fail during inference.

  • The DLA compiler can remove identity transposes but cannot fuse multiple adjacent transpose layers into a single transpose layer (likewise for reshaping). For example, given a TensorRT IShuffleLayer consisting of two non-trivial transposes and an identity reshape in between, the shuffle layer is translated into two consecutive DLA transpose layers unless the user merges the transposes manually in the model definition in advance.

  • nvinfer1::UnaryOperation::kROUND or nvinfer1::UnaryOperation::kSIGN operations of IUnaryLayer are not supported in the implicit batch mode.

  • For networks containing normalization layers, particularly if deploying with mixed precision, target the latest ONNX opset containing the corresponding function ops, such as opset 17 for LayerNormalization or opset 18 for GroupNormalization. Numerical accuracy using function ops is superior to the corresponding implementation with primitive ops for normalization layers.

  • Weight streaming mainly supports GEMM-based networks like Transformers for now. Convolution-based networks may have only a few weights that can be streamed.

  • When two convolutions with INT8-QDQ and residual add share the same weight, constant weight fusion will not occur. Make a copy of the shared weight for better performance.

  • When building the nonZeroPlugin sample on Windows, you may need to modify the CUDA version specified in the BuildCustomizations paths in the vcxproj file to match the installed version of CUDA.

  • The scale factor must be a build-time constant if QuantizeLayer is used with the output FP4 data type.

  • The weights used in INT4 weights-only quantization (WoQ) cannot be refitted.

  • DynamicQuantizeLayer fails when the blocking axis is not the innermost dimension.

  • The high-precision weights used in FP4 double quantization are not refittable.

Deprecated API Lifetime

  • APIs deprecated in TensorRT 10.8 will be retained until 2/2026.

  • APIs deprecated in TensorRT 10.7 will be retained until 12/2025.

  • APIs deprecated in TensorRT 10.6 will be retained until 11/2025.

  • APIs deprecated in TensorRT 10.5 will be retained until 10/2025.

  • APIs deprecated in TensorRT 10.4 will be retained until 9/2025.

  • APIs deprecated in TensorRT 10.3 will be retained until 8/2025.

  • APIs deprecated in TensorRT 10.2 will be retained until 7/2025.

  • APIs deprecated in TensorRT 10.1 will be retained until 5/2025.

  • APIs deprecated in TensorRT 10.0 will be retained until 3/2025.

Refer to the API documentation (C++, Python) for instructions on updating your code to remove the use of deprecated features.

Deprecated and Removed Features

The following features have been deprecated or removed in TensorRT 10.8.0.

  • Deprecated the Algorithm Selector API, including:

    • IAlgorithmIOInfo

    • IAlgorithmVariant

    • IAlgorithmContext

    • IAlgorithm

    • IAlgorithmSelector

    • setAlgorithmSelector

    • getAlgorithmSelector

    Use editable mode in ITimingCache instead.

  • Deleted the sampleAlgorithmSelector sample.

  • On NVIDIA Blackwell and later platforms, TensorRT will drop cuDNN support for the following categories of plugins:

    • User-written IPluginV2Ext, IPluginV2DynamicExt, and IPluginV2IOExt plugins that are dependent on cuDNN handles provided by TensorRT (via the attachToContext() API).

    • TensorRT standard plugins that use cuDNN, specifically:

      • InstanceNormalization_TRT (versions 1, 2, and 3)

      • GroupNormalizationPlugin (version 1)

    Note

    TensorRT’s native INormalizationLayer supersedes these normalization plugins. TensorRT support for cuDNN-dependent plugins remains unchanged on pre-Blackwell platforms.

  • Deprecated INetworkDefinition::addPluginV2 and IPluginV2Layer. These are superseded by INetworkDefinition::addPluginV3 and IPluginV3Layer, respectively.

Fixed Issues

  • Errors could occur when an IConstantLayer was followed by an ICastLayer. This issue has been fixed.

  • Addressed performance regression for TensorRT 10.x with respect to TensorRT 8.6 for networks involving data-dependent shapes, such as non-max suppression or non-zero operations.

  • Some TF32 convolution tactics may have caused a CUDA illegal memory access error if the input or output tensor had more than 2^30 elements. The workaround was to disable TF32 and use a different precision like FP32 or FP16. This issue has been fixed.

  • There were up to 30% performance gaps between the fused Multi-Head Attention (MHA) kernel built with dynamic sequence lengths and the fused MHA kernel built with static sequence lengths on Hopper GPUs if the maximum sequence length was much greater than the optimal sequence length set in the optimization profiles. This issue has been fixed.

Known Issues

Functional

  • When running OSS demoBERT FP16 inference on H20 GPUs, different batch sizes may generate different outputs given the same input values. This can be worked around by using a fixed batch size.

  • There is a known accuracy issue running certain networks on NVIDIA HGX H20.

  • Inputs to the IRecurrenceLayer must always have the same shape. This means that ONNX models with loops whose recurrence inputs change shapes will be rejected.

  • If TensorRT 8.6 or 9.x was installed using the Python Package Index (PyPI), you cannot upgrade TensorRT to 10.x using PyPI. You must first uninstall TensorRT using pip uninstall tensorrt tensorrt-libs tensorrt-bindings, then reinstall TensorRT using pip install tensorrt. This will remove the previous TensorRT version and install the latest TensorRT 10.x. This step is required because the suffix -cuXX was added to the Python package names, which prevents the upgrade from working properly.

  • CUDA compute sanitizer may report racecheck hazards for some legacy kernels. However, related kernels do not have functional issues at runtime.

  • The compute sanitizer initcheck tool may flag false positive Uninitialized __global__ memory read errors when running TensorRT applications on NVIDIA Hopper GPUs. These errors can be safely ignored and will be fixed in an upcoming CUDA release.

  • Multi-head attention fusion might not occur, which can affect performance, if the number of heads is small.

  • An occurrence of use-after-free in NVRTC has been fixed in CUDA 12.1. When using NVRTC from CUDA 12.0 together with the TensorRT static library, you may encounter a crash in certain scenarios. Linking the NVRTC and PTXJIT compiler from CUDA 12.1 or newer will resolve this issue.

  • There are known issues reported by the Valgrind memory leak check tool when detecting potential memory leaks from TensorRT applications. To suppress these issues, provide a Valgrind suppression file with the following contents when running the Valgrind memory leak check tool, and add the option --keep-debuginfo=yes to the Valgrind command line.

    {
        Memory leak errors with dlopen
        Memcheck:Leak
        match-leak-kinds: definite
        ...
        fun:*dlopen*
        ...
    }
    {
        Memory leak errors with nvrtc
        Memcheck:Leak
        match-leak-kinds: definite
        fun:malloc
        obj:*libnvrtc.so*
        ...
    }
    
  • SM 7.5 and earlier devices may not have INT8 implementations for all layers with Q/DQ nodes. In this case, you will encounter a "could not find any implementation" error while building your engine. To resolve this, remove the Q/DQ nodes that quantize the failing layers.

  • Installing the cuda-compat-11-4 package may interfere with CUDA-enhanced compatibility and cause TensorRT to fail even when the driver is r465. The workaround is to remove the cuda-compat-11-4 package or upgrade the driver to r470.

  • For some networks, using a batch size of 4096 may cause accuracy degradation on DLA.

  • For broadcasting elementwise layers running on DLA with GPU fallback enabled, with one NxCxHxW input and one Nx1x1x1 input, there is a known accuracy issue if at least one of the inputs is consumed in kDLA_LINEAR format. It is recommended to explicitly set the input formats of such elementwise layers to different tensor formats.

  • Exclusive padding with kAVERAGE pooling is not supported.

  • Asynchronous CUDA calls are not supported in the user-defined processDebugTensor function for the debug tensor feature due to a bug in Windows 10.

  • The inplace_add mini-sample of the quickly_deployable_plugins Python sample may produce incorrect outputs on Windows. This will be fixed in a future release.

  • When linking with libcudart_static.a using a RedHat gcc-toolset-11 or earlier compiler, you may encounter an issue where exception handling does not work. When a throw or exception happens, the catch is ignored, and an abort is raised, killing the program. This may be related to a linker bug causing the eh_frame_hdr ELF segment to be empty. You can work around this issue by using a newer linker, such as the one from gcc-toolset-13.

  • On the NVIDIA Blackwell platform, the format attribute in the PluginTensorDesc parameters of PluginV3's onShapeChange and enqueue functions is kLINEAR (0); however, the tensor's physical format is correct (as configured in configurePlugin and setTactic).

  • Exceptions thrown from PluginV3's enqueue on the NVIDIA Blackwell platform may escape the exception-handling routine. If any related abort or crash occurs, the workaround is to wrap the enqueue body in a try-catch block, as shown below.
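
    For example, a hedged sketch of guarding an IPluginV3OneRuntime::enqueue implementation so that no exception escapes; the body is placeholder logic, not a complete plugin.

    int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
                    nvinfer1::PluginTensorDesc const* outputDesc,
                    void const* const* inputs, void* const* outputs,
                    void* workspace, cudaStream_t stream) noexcept override
    {
        try
        {
            // Launch the plugin's kernels on `stream` here.
            return 0;
        }
        catch (...)
        {
            // Swallow the exception and report failure instead of letting it escape enqueue.
            return -1;
        }
    }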

  • TensorRT may exit if inputs with invalid values are provided to the RoiAlign plugin (ROIAlign_TRT), especially if there is inconsistency in the indices specified in the batch_indices input and the actual batch size used.

  • The engine build process may encounter failures when handling ScatterND operations with empty indices.

  • The Valgrind Memcheck tool may report memory leaks when TensorRT builds the engine on pre-Blackwell GPUs, especially if the model contains convolution layers.

  • The thread sanitizer tool may report data races of CPU threads when TensorRT is building the engine.

  • On the NVIDIA Blackwell platform, engine build may fail when a tensor with a data-dependent shape is passed into an IShuffleLayer to be reshaped, an ISliceLayer with dynamic axes, an IGatherLayer, or a layer where the data-dependent dimensions of the tensor would be subject to arithmetic or logical operations.

  • Dynamic quantization with a non-innermost block axis is not supported.

  • The sampleEditableTimingCache sample may not compile on SLES 15 when the default GCC version is 7.5.0. This is due to a missing header, <charconv>, required for complete C++17 support. This issue can be resolved by installing GCC 11 and switching your default compiler to GCC 11. The following commands, run as root, may assist you with changing the default GCC version.

    # install gcc11 development tools
    zypper install gcc11 gcc11-c++
    
    # add gcc11 as an alternative compiler
    update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 60
    update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-11 60
    
    # change the default compiler to gcc11
    update-alternatives --config gcc
    update-alternatives --config g++
    
  • Convolution/Deconvolution does not support non-zero spatial output dimensions with corresponding zero input dimensions. TensorRT cannot detect this issue. When the model contains such Convolution/Deconvolution operations, users might see a crash without a corresponding error message.

Performance

  • There is an increase in GPU memory usage due to fusions breaking between some GEMMs and other layers. In TensorRT 10.9, we will further evaluate whether this trade-off truly benefits performance.

  • There is known CPU peak memory usage regression with the roberta_base engine on Ampere GPUs compared to TensorRT 10.7.

  • Up to 10% inference performance regression for ViT because the GEMM tactic selection does not find the best kernel. The workaround is to allow more tactics to be evaluated by setting setMaxNbTactics to 200; with trtexec, this can be worked around by adding --maxTactics=200.
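
    For reference, a minimal sketch of the builder-API form of this workaround, assuming config points to an existing nvinfer1::IBuilderConfig and that setMaxNbTactics is the setter named above:

    // Let the builder evaluate up to 200 tactics per layer instead of the default budget.
    config->setMaxNbTactics(200);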

  • Up to 37% inference performance regression for the cortanaasr_s128_bunk_e128 network on Hopper GPUs compared to TensorRT 10.7 in a CUDA 12.8 environment.

  • Up to 100 MB context memory size regression compared to TensorRT 8.6 on Hopper GPUs for CRNN (Convolutional Recurrent Neural Network) models. Inference performance is not affected.

  • Up to 9% inference performance regression for the StableDiffusion v2.0/2.1 VAE network in FP16 precision on Hopper GPUs compared to TensorRT 10.6 in a CUDA 11.8 environment. This issue can be fixed by upgrading CUDA to 12.6.

  • Up to 60% performance regression compared to TensorRT 8.6 on Ampere GPUs for group convolutions with N channels per group, where N is not a power of 2. This can be worked around by padding N to the next power of 2.

  • Up to 22% context memory size regression for HiFi-GAN networks in INT8 precision compared to TensorRT 10.5 on Ampere GPUs.

  • Up to 7% performance regression for Megatron networks in FP16 precision compared to TensorRT 10.6 for BS1 and Seq128 on H100 GPUs.

  • Up to 10% performance regression for BERT networks exported from TensorFlow2 in FP16 precision compared to TensorRT 10.4 for BS1 and Seq128 on A16 GPUs.

  • Up to 16% regression in context memory usage for StableDiffusion XL VAE network in FP8 precision on H100 GPUs compared to TensorRT 10.3 due to a necessary functional fix.

  • Up to 15% regression in context memory usage for networks containing InstanceNorm and Activation ops compared to TensorRT 10.0.

  • Up to 15% CPU memory usage regression for mbart-cnn/mamba-370m in FP16 precision and OOTB mode on NVIDIA Ada Lovelace GPUs compared to TensorRT 10.2.

  • Up to 6% performance regression for BERT/Megatron networks in FP16 precision compared to TensorRT 10.2 for BS1 and Seq128 on H100 GPUs.

  • Up to 6% performance regression for Bidirectional LSTM in FP16 precision on H100 GPUs compared to TensorRT 10.2.

  • Up to 25% performance regression when running TensorRT-LLM without the attention plugin. The current recommendation is always to enable the attention plugin when using TensorRT-LLM.

  • There are known performance gaps between engines built with REFIT enabled and engines built with REFIT disabled.

  • Up to 60 MB engine size fluctuations for the BERT-Large INT8-QDQ model on Orin due to unstable tactic selection among the available tactics.

  • Up to 16% performance regression for BasicUNet, DynUNet, and HighResNet in INT8 precision compared to TensorRT 9.3.

  • Up to 40-second increase in engine building for BART networks on NVIDIA Hopper GPUs.

  • Up to 20-second increase in engine building for some large language models (LLMs) on NVIDIA Ampere GPUs.

  • Up to 2.5x build time increase compared to TensorRT 9.0 for certain Bert-like models due to additional tactics available for evaluation.

  • Up to 13% performance drop for the CortanaASR model on NVIDIA Ampere GPUs compared to TensorRT 8.5.

  • Up to 18% performance drop for the ShuffleNet model on A30/A40 compared to TensorRT 8.5.1.

  • Convolution on a tensor with an implicitly data-dependent shape may run significantly slower than on other tensors of the same size. Refer to the Glossary for the definition of implicitly data-dependent shapes.

  • Up to 5% performance drop for networks using sparsity in FP16 precision.

  • Up to 6% performance regression compared to TensorRT 8.5 on OpenRoadNet in FP16 precision on NVIDIA A10 GPUs.

  • Up to 70% performance regression compared to TensorRT 8.6 on BERT networks in INT8 precision with FP16 disabled on L4 GPUs. Enable FP16 and disable INT8 in the builder config to work around this.
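
    With the C++ builder API, this workaround is, for example (assuming config points to an existing nvinfer1::IBuilderConfig):

    // Prefer FP16 kernels and drop INT8 to avoid the regression on L4 GPUs.
    config->setFlag(nvinfer1::BuilderFlag::kFP16);
    config->clearFlag(nvinfer1::BuilderFlag::kINT8);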

  • In explicitly quantized networks, a group convolution with a Q/DQ pair before but no Q/DQ pair after is expected to run with INT8-IN-FP32-OUT mixed precision. However, NVIDIA Hopper may fall back to FP32-IN-FP32-OUT if the input channel count is small.

  • Engines built with kREFIT or kREFIT_IDENTICAL have performance regressions compared with non-refit engines where convolution layers are present within a branch or loop and the precision is FP16/INT8. This issue will be addressed in future releases.