Working with Quantized Types#

Introduction to Quantization#

TensorRT supports the use of low-precision types to represent quantized floating point values. The quantization scheme is symmetric quantization—quantized values are represented in signed INT8, FP8E4M3 (FP8 for short), signed INT4, or FP4E2M1 (FP4 for short), and the transformation from quantized to unquantized values is simply a multiplication. In the reverse direction, quantization uses the reciprocal scale, followed by clamping and rounding (for integers) or casting (for FP8 and FP4).

TensorRT quantizes activations and weights for INT8, FP8, and FP4. Weight-only quantization is supported for INT4.

Quantization Workflows#

TensorRT Model Optimizer is a library that helps produce QAT models that TensorRT can optimize. The toolkit's PTQ recipe can also perform PTQ on both PyTorch and ONNX models.

There are two workflows for creating quantized networks:

Post-training quantization (PTQ)

Derives scale factors after training the network. TensorRT provides a workflow for PTQ called calibration. It measures the distribution of activations within each activation tensor as the network executes on representative input data and then uses that distribution to estimate scale values for each tensor.

Quantization-aware training (QAT)

Computes the scale factors during training using fake-quantization, simulating the quantization and dequantization processes. This allows the training process to compensate for the effects of the quantization and dequantization operations.

Explicit vs Implicit Quantization#

Note

Implicit quantization is deprecated. You should use the TensorRT Model Optimizer to create models with explicit quantization.

Quantized networks can be processed in two (mutually exclusive) ways: using either implicit or explicit quantization. The main difference between the two processing modes is whether you require explicit control over quantization or let the TensorRT builder choose which operations and tensors to quantize (implicit). The sections below provide more details. Implicit quantization is only supported when quantizing for INT8. It cannot be used together with strong typing (because types are not auto-tuned, and the only method to convert activations to and from INT8 is via Quantize (Q) and Dequantize (DQ) operators).

TensorRT uses explicit quantization mode when a network has QuantizeLayer and DequantizeLayer layers. TensorRT uses implicit quantization mode when there are no QuantizeLayer or DequantizeLayer layers in the network, and INT8 is enabled in the builder configuration. Only INT8 is supported in implicit quantization mode.

In implicitly quantized networks, each activation tensor candidate for quantization has an associated scale deduced by a calibration process or assigned by the API function setDynamicRange. TensorRT will use this scale if it decides to quantize the tensor.

When processing implicitly quantized networks, TensorRT treats the model as a floating-point model when applying the graph optimizations and uses INT8 opportunistically to optimize layer execution time. If a layer runs faster in INT8 and has assigned quantization scales on its data inputs and outputs, then a kernel with INT8 precision is assigned to that layer. Otherwise, a high-precision floating-point (FP32, FP16, or BF16) kernel is assigned. Where a high-precision floating point is required for accuracy at the expense of performance, this can be specified using the APIs Layer::setOutputType and Layer::setPrecision.
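For example, a minimal Python sketch of these per-layer precision controls might look like the following. It assumes an already-populated `network`, a builder `config`, and a layer index `i` you want to constrain; the builder flag at the end asks TensorRT to honor the constraints rather than treat them as hints.

import tensorrt as trt

# Hypothetical: constrain one layer to FP32 under implicit quantization.
# `network`, `config`, and the layer index `i` are assumed to exist already.
layer = network.get_layer(i)
layer.precision = trt.float32            # request FP32 compute for this layer
layer.set_output_type(0, trt.float32)    # request FP32 for output index 0

# Ask the builder to obey the precision constraints instead of ignoring them.
config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)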

In explicitly quantized networks, the quantization and dequantization operations are represented explicitly by IQuantizeLayer (C++, Python) and IDequantizeLayer (C++, Python) nodes in the graph - these will henceforth be referred to as Q/DQ nodes. By contrast with implicit quantization, the explicit form specifies exactly where conversion to and from a quantized type is performed, and the optimizer will perform only conversions to and from quantized types that are dictated by the semantics of the model, even if:

  • Adding extra conversions could increase layer precision (for example, choosing an FP16 kernel implementation over a quantized type implementation).

  • Adding or removing conversions results in a faster engine (for example, choosing a quantized type kernel implementation to execute a layer specified as having high precision or vice versa).

ONNX uses an explicitly quantized representation: when a model in PyTorch or TensorFlow is exported to ONNX, each fake-quantization operation in the framework’s graph is exported as Q, followed by DQ. Since TensorRT preserves the semantics of these layers, users can expect accuracy that is very close to that seen in the framework. While optimizations preserve the arithmetic semantics of quantization and dequantization operators, they may change the order of floating-point operations in the model, so results will not be bitwise identical.

TensorRT’s PTQ capability generates a calibration cache with implicit quantization. By contrast, performing either QAT or PTQ in a deep learning framework and then exporting to ONNX will result in an explicitly quantized model.

Implicit vs Explicit Quantization#

  • Supported quantized data types:
    • Implicit quantization (deprecated): INT8
    • Explicit quantization: INT8, FP8, INT4, FP4

  • User control over precision:
    • Implicit: global builder flags and per-layer precision APIs
    • Explicit: encoded directly in the model

  • API:
    • Implicit: model + scales (dynamic range API), or model + calibration data
    • Explicit: model with Q/DQ layers

  • Quantization scales:
    • Implicit:
      • Weights: set by TensorRT (internal), per-channel quantization, INT8 range [-127, 127]
      • Activations: set by calibration or specified by the user, per-tensor quantization, INT8 range [-128, 127]
    • Explicit:
      • Weights and activations: specified using Q/DQ ONNX operators
      • Ranges: INT8 [-128, 127], FP8 [-448, 448], INT4 [-8, 7], FP4 [-6, 6]
      • Activations use per-tensor quantization; weights use per-tensor, per-channel, or block quantization


Quantization Schemes#

Given scale \(\text{s}\), we can represent the quantization and dequantization operations as follows:

\(x_{q}=quantize\left(x, s \right)=roundWithTiesToEven(clip(\frac{x}{s}, -128,127))\)

Where:

  • \(\text{x}\) is a high-precision floating point value to be quantized.

  • \(x_{q}\) is a quantized INT8 value in range [-128,127]. For more information, refer to the Explicit vs Implicit Quantization section.

  • \(\text{roundWithTiesToEven}\) is described here.

\(\text{x}=dequantize\left(x_{q}, s\right)=x_{q}\ast s\)

In explicit quantization, you are responsible for choosing all scales. In implicit quantization mode, you configure or determine the activation scale using one of TensorRT’s calibration algorithms. TensorRT computes the weight scale according to the following formula:

\(\text{s}=\frac{max(abs(\mathrm{x}_{min}^{ch}),abs(\mathrm{x}_{max}^{ch}))}{127}\)

Where \({x}_{min}^{ch}\) and \({x}_{max}^{ch}\) are floating point minimum and maximum values for channel \(\text{ch}\) of the weights tensor.
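As a concrete illustration, here is a minimal NumPy sketch of the symmetric INT8 quantize/dequantize operations and the per-channel weight-scale formula above. It is illustrative only; TensorRT performs these steps internally, and the example weight shape is arbitrary.

import numpy as np

def quantize_int8(x, scale):
    # x / s, rounded with ties-to-even (np.round rounds halves to even), then clipped to INT8.
    return np.clip(np.round(x / scale), -128, 127).astype(np.int8)

def dequantize_int8(x_q, scale):
    return x_q.astype(np.float32) * scale

def per_channel_weight_scale(w, channel_axis=0):
    # s = max(|w_min_ch|, |w_max_ch|) / 127, computed independently per channel.
    reduce_axes = tuple(ax for ax in range(w.ndim) if ax != channel_axis)
    return np.abs(w).max(axis=reduce_axes) / 127.0

w = np.random.randn(8, 16, 3, 3).astype(np.float32)   # example KCRS weights
s = per_channel_weight_scale(w)                        # shape (8,): one scale per output channel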

Only explicit quantization is supported when using FP8; therefore, you are responsible for the values of the quantization scales.

\(x_{q}=quantize\left(x, s \right)=castToFp8(clip(\frac{x}{s}, -448,448))\)

Where:

  • \(\text{x}\) is a high-precision floating point value to be quantized.

  • \(x_{q}\) is a quantized FP8E4M3 value in the range [-448, 448].

  • \(\text{s}\) is the quantization scale expressed using a 16-bit or 32-bit floating point.

  • \(\text{castToFp8}\) rounds to the nearest value representable in FP8E4M3; ties are rounded to an even number, as described here.

\(\text{x}=dequantize\left(x_{q}, s\right)=x_{q}\ast s\)

Using FP8 and INT8 in the same network is not allowed.

Only explicit quantization is supported when using INT4, and you are therefore responsible for the values of the quantization scales.

\(x_{q}=quantize\left(x, s \right)=roundWithTiesToEven(clip(\frac{x}{s}, -8,7))\)

Where:

  • \(\text{x}\) is a high-precision floating point value to be quantized.

  • \(x_{q}\) is a quantized INT4 value in the range [-8, 7].

  • \(\text{s}\) is the quantization scale expressed using a 16-bit or 32-bit floating point.

  • \(\text{roundWithTiesToEven}\) is described here.

\(\text{x}=dequantize\left(x_{q}, s\right)=x_{q}\ast s\)

TensorRT only supports INT4 for weight quantization (Q/DQ Layer-Placement Recommendations).

Only explicit quantization is supported when using FP4, and you are therefore responsible for the values of the quantization scales.

\(x_{q}=quantize\left(x, s \right)=castToFp4(clip(\frac{x}{s}, -6,6))\)

Where:

  • \(\text{x}\) is a high-precision floating point value to be quantized.

  • \(x_{q}\) is a quantized FP4 value in the range [-6, 6].

  • \(\text{s}\) is the quantization scale expressed using a 16-bit or 32-bit floating point.

  • \(\text{castToFp4}\) rounds to the nearest value representable in FP4E2M1; ties are rounded to an even number, as described here.

\(\text{x}=dequantize\left(x_{q}, s\right)=x_{q}\ast s\)

When quantizing FP4 activations, Dynamic Quantization is recommended.

Quantization Modes#

There are three supported quantization scale granularities:

  1. Per-tensor quantization: a single scale value (scalar) is used to scale the entire tensor.

  2. Per-channel quantization: a scale tensor is broadcast along the given axis - for convolutional neural networks, this is typically the channel axis.

  3. Block quantization: the tensor is divided into fixed-size 1-dimensional blocks along a single dimension. A scale factor is defined for each block.

The quantization scale must contain all positive high-precision float coefficients (FP32, FP16, or BF16). The rounding method is round-to-nearest with ties-to-even, and values are clamped to the valid range: [-128, 127] for INT8, [-448, 448] for FP8, [-8, 7] for INT4, and [-6, 6] for FP4.

With explicit quantization, activations can only be quantized using per-tensor quantization. Weights can be quantized in any of the quantization modes.

In implicit quantization, weights are quantized by TensorRT during engine optimization, and only per-channel quantization is used. TensorRT quantizes weights for convolution, deconvolution, fully connected layers, and MatMul, where the second input is constant, and both input matrices are 2D.

When using per-channel quantization with Convolutions, the quantization axis must be the output-channel axis. For example, when the weights of 2D convolution are described using KCRS notation, K is the output-channel axis, and the weights quantization can be described as:

For each k in K:
    For each c in C:
        For each r in R:
            For each s in S:
                output[k,c,r,s] := clamp(round(input[k,c,r,s] / scale[k]))

The scale is a vector of coefficients and must have the same size as the quantization axis.

Dequantization is performed similarly except for the pointwise operation that is defined as:

output[k,c,r,s] := input[k,c,r,s] * scale[k]
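The per-channel loops above can also be expressed as a vectorized NumPy sketch (illustrative only), broadcasting the per-K scale vector across the C, R, and S dimensions:

import numpy as np

def quantize_weights_per_channel(w, scale):
    # w: (K, C, R, S) weights; scale: (K,) vector with one coefficient per output channel.
    s = scale[:, None, None, None]                       # broadcast along C, R, S
    return np.clip(np.round(w / s), -128, 127).astype(np.int8)

def dequantize_weights_per_channel(w_q, scale):
    return w_q.astype(np.float32) * scale[:, None, None, None]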

Block Quantization

In block quantization, elements are grouped into 1D blocks, with all elements in a block sharing a common scale factor. Block quantization is supported for inputs of up to 3 dimensions.

INT4 block quantization supports weight-only quantization (WoQ).

FP4 block quantization supports both weights and activations. To minimize the quantization error, it is recommended to use Dynamic Quantization for activations.

When using block quantization, the scale tensor dimensions equal the data tensor dimensions except for one dimension over which blocking is performed (the blocking axis). For example, given a 2-D RS weights input, R (dimension 0) as the blocking axis and B as the block size, the scale in the blocking axis is repeated according to the block size and can be described like this:

For each r in R:
    For each s in S:
        output[r,s] = clamp(round(input[r,s] / scale[r//B, s]))

The scale is a 2D array of coefficients with dimensions (R//B, S).

Dequantization is performed similarly, except for the pointwise operation that is defined as:

output[r,s] = input[r,s] * scale[r//B, s]
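The same blocking scheme can be sketched in NumPy as follows (illustrative only; the INT4 values are held in an int8 array here because NumPy has no 4-bit type):

import numpy as np

def block_quantize(w, scale, block_size):
    # w: (R, S) weights; scale: (R // block_size, S); blocking along dimension 0.
    r_idx = np.arange(w.shape[0]) // block_size          # row -> block index
    s = scale[r_idx, :]                                  # expand scales to the data shape
    return np.clip(np.round(w / s), -8, 7).astype(np.int8)   # INT4 range

def block_dequantize(w_q, scale, block_size):
    r_idx = np.arange(w_q.shape[0]) // block_size
    return w_q.astype(np.float32) * scale[r_idx, :]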

Setting Dynamic Range#

The dynamic range API is only applicable to INT8 quantization.

TensorRT provides APIs to directly set the dynamic range (the range that must be represented by the quantized tensor) to support implicit quantization where these values have been calculated outside TensorRT.

The API allows setting the dynamic range for a tensor using minimum and maximum values. Since TensorRT currently supports only symmetric range, the scale is calculated using max(abs(min_float), abs(max_float)). Note that when abs(min_float) != abs(max_float), TensorRT uses a larger dynamic range than configured, which may increase the rounding error.

You can set the dynamic range for a tensor as follows:

// C++
tensor->setDynamicRange(min_float, max_float);

# Python
tensor.dynamic_range = (min_float, max_float)

sampleINT8API illustrates the use of these APIs in C++.
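For example, a minimal Python sketch (assuming the per-tensor ranges were computed outside TensorRT and are keyed by tensor name; the names and values below are hypothetical) could set the dynamic range on every matching network tensor as follows:

import tensorrt as trt

# Hypothetical ranges computed outside TensorRT, keyed by tensor name.
ranges = {"input": (-2.5, 2.5), "conv1_out": (-6.0, 6.0)}

def set_dynamic_ranges(network, ranges):
    # Network inputs.
    for i in range(network.num_inputs):
        tensor = network.get_input(i)
        if tensor.name in ranges:
            tensor.dynamic_range = ranges[tensor.name]
    # Outputs of every layer.
    for i in range(network.num_layers):
        layer = network.get_layer(i)
        for j in range(layer.num_outputs):
            tensor = layer.get_output(j)
            if tensor.name in ranges:
                tensor.dynamic_range = ranges[tensor.name]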

Post-Training Quantization Using Calibration#

Note

This section describes deprecated APIs. It is recommended to use explicit quantization.

Calibration is only applicable to INT8 quantization.

In post-training quantization, TensorRT computes a scale value for each tensor in the network. This process, called calibration, requires you to supply representative input data on which TensorRT runs the network to collect statistics for each activation tensor.

The amount of input data required is application-dependent, but experiments indicate that about 500 images are sufficient for calibrating ImageNet classification networks.

Given the statistics for an activation tensor, deciding on the best scale value is not an exact science - it requires balancing two sources of error in the quantized representation: discretization error (which increases as the range represented by each quantized value becomes larger) and truncation error (where values are clamped to the limits of the representable range). Thus, TensorRT provides multiple calibrators that calculate the scale differently. Older calibrators also performed layer fusion for the GPU to optimize away unneeded tensors before calibration. This can be problematic when using DLA, where fusion patterns may be different; this behavior can be overridden using the kCALIBRATE_BEFORE_FUSION quantization flag.

Calibration batch size can also affect the truncation error for IInt8EntropyCalibrator2 and IInt8EntropyCalibrator. For example, calibrating using multiple small batches of calibration data may result in reduced histogram resolution and poor scale values. For each calibration step, TensorRT updates the histogram distribution for each activation tensor. Suppose it encounters a value in the activation tensor larger than the current histogram max. In that case, the histogram range is increased by a power of two to accommodate the new maximum value. This approach works well unless the histogram reallocates in the last calibration step, resulting in a final histogram with half of its bins empty. Such a histogram can produce poor calibration scales. This also makes calibration susceptible to the order of calibration batches; a different order can increase the histogram size at different points, producing slightly different calibration scales. To avoid this issue, calibrate with as large a single batch as possible, and ensure that calibration batches are well randomized and have a similar distribution.

IInt8EntropyCalibrator2

Entropy calibration chooses the tensor’s scale factor to optimize the quantized tensor’s information-theoretic content and usually suppresses outliers in the distribution. This is the current and recommended entropy calibrator and is required for DLA. Calibration happens before Layer fusion by default. Calibration batch size may impact the final result. It is recommended for CNN-based networks.

IInt8MinMaxCalibrator

This calibrator uses the entire range of the activation distribution to determine the scale factor. It works better for NLP tasks. Calibration happens before Layer fusion by default. This is recommended for networks such as NVIDIA BERT (an optimized version of Google’s official implementation).

IInt8EntropyCalibrator

This is the original entropy calibrator. It is less complicated than the LegacyCalibrator and typically produces better results. The calibration batch size may impact the final result. By default, calibration happens after Layer fusion.

IInt8LegacyCalibrator

This calibrator is for compatibility with TensorRT 2.0 EA. It requires user parameterization and is a fallback option if the other calibrators yield poor results. Calibration happens after Layer fusion by default. You can customize this calibrator to implement percentile max. For example, 99.99% percentile max is observed to have the best accuracy for NVIDIA BERT and NeMo ASR model QuartzNet.

When building an INT8 engine, the builder performs the following steps:

  1. Build a 32-bit engine, run it on the calibration set, and record a histogram for each tensor of the distribution of activation values.

  2. Build from the histograms a calibration table providing a scale value for each tensor.

  3. Build the INT8 engine from the calibration table and the network definition.

Calibration can be slow; therefore, the output of step 2 (the calibration table) can be cached and reused. This is useful when building the same network multiple times on a given platform and is supported by all calibrators.

Before running calibration, TensorRT queries the calibrator implementation to see if it has access to a cached table. If so, it proceeds directly to step 3. Cached data is passed as a pointer and length.

The calibration cache data is portable across different devices as long as the calibration happens before layer fusion. Specifically, the calibration cache is portable when using the IInt8EntropyCalibrator2 or IInt8MinMaxCalibrator calibrators or when QuantizationFlag::kCALIBRATE_BEFORE_FUSION is set. For example, this can simplify the workflow by building the calibration table on a machine with a discrete GPU and then reusing it on an embedded platform. Fusions are not guaranteed the same across platforms or devices, so calibrating after layer fusion may not result in a portable calibration cache. The calibration cache is, in general, not portable across TensorRT releases.

TensorRT uses symmetric quantization with a quantization scale calculated using the maximum absolute values found in the weight tensor. For convolution, deconvolution, and fully connected weights, scales are per-channel.

Note

When the builder is configured to use INT8 I/O, TensorRT still expects calibration data to be in FP32. You can create FP32 calibration data by casting the INT8 I/O calibration data to FP32 precision. Also, FP32-cast calibration data should be in the range [-128.0F, 127.0F] so that it can be converted to INT8 data without any precision loss.

INT8 calibration can be used along with the dynamic range APIs. Setting the dynamic range manually overrides the dynamic range generated from INT8 calibration.

Note

Calibration is deterministic - that is, if you provide TensorRT with the same input to calibration in the same order on the same device, the scales generated will be the same across different runs. The data in the calibration cache will be bitwise identical when generated using the same device with the same batch size when provided with identical calibration inputs. The exact data in the calibration cache is not guaranteed to be bitwise identical when generated using different devices, batch sizes, or calibration inputs.

INT8 Calibration Using C++#

To provide calibration data to TensorRT, the IInt8Calibrator interface must be implemented.

The builder invokes the calibrator as follows:

  1. First, it calls getBatchSize() to determine the size of the input batch to expect.

  2. Then, it repeatedly calls getBatch() to obtain batches of input. Batches must be exactly the size returned by getBatchSize(). When there are no more batches, getBatch() must return false.

  3. After you have implemented the calibrator, you can configure the builder to use it:

    config->setInt8Calibrator(calibrator.get());
    
  4. Implement the writeCalibrationCache() and readCalibrationCache() methods to cache the calibration table.

Calibration Using Python#

The following steps illustrate creating an INT8 calibrator object using the Python API; a complete calibrator sketch follows the steps.

  1. Import TensorRT.

    import tensorrt as trt
    
  2. Similar to test/validation datasets, use a set of input files as a calibration dataset. Ensure that the calibration files are representative of the overall inference data. For TensorRT to use the calibration files, you must create a batchstream object. A batchstream object is used to configure the calibrator.

    NUM_IMAGES_PER_BATCH = 5
    batchstream = ImageBatchStream(NUM_IMAGES_PER_BATCH, calibration_files)
    
  3. Create an Int8_calibrator object with input node names and batch stream.

    Int8_calibrator = EntropyCalibrator(["input_node_name"], batchstream)
    
  4. Set INT8 mode and INT8 calibrator.

    config.set_flag(trt.BuilderFlag.INT8)
    config.int8_calibrator = Int8_calibrator
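Putting the pieces together, here is a minimal calibrator sketch. It assumes PyCUDA for device memory, an in-memory list of contiguous float32 NumPy batches named calibration_batches, and a single network input; it is not the same class as the EntropyCalibrator/ImageBatchStream example above.

import numpy as np
import pycuda.autoinit  # noqa: F401  (creates a CUDA context)
import pycuda.driver as cuda
import tensorrt as trt

class MyEntropyCalibrator(trt.IInt8EntropyCalibrator2):
    def __init__(self, batches, cache_file="calibration.cache"):
        trt.IInt8EntropyCalibrator2.__init__(self)
        self.batches = iter(batches)                 # iterable of contiguous float32 arrays
        self.cache_file = cache_file
        self.batch_size = batches[0].shape[0]
        self.device_input = cuda.mem_alloc(batches[0].nbytes)

    def get_batch_size(self):
        return self.batch_size

    def get_batch(self, names):
        try:
            data = np.ascontiguousarray(next(self.batches))
        except StopIteration:
            return None                              # no more calibration data
        cuda.memcpy_htod(self.device_input, data)
        return [int(self.device_input)]              # one device pointer per input tensor

    def read_calibration_cache(self):
        try:
            with open(self.cache_file, "rb") as f:
                return f.read()
        except FileNotFoundError:
            return None

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)

# Hypothetical usage: calibration_batches is a list of float32 NCHW arrays.
# config.int8_calibrator = MyEntropyCalibrator(calibration_batches)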
    

Quantization Noise Reduction#

For networks with implicit quantization, TensorRT attempts to reduce quantization noise in the output by forcing some layers near the network outputs to run in FP32, even if INT8 implementations are available.

The heuristic attempts to ensure that INT8 quantization is smoothed by the summation of multiple quantized values. Layers considered “smoothing layers” are convolution, deconvolution, fully connected, and matrix multiplication layers encountered before the network output. For example, if a network consists of a series of (convolution + activation + shuffle) subgraphs and the network output has type FP32, the last convolution will output FP32 precision, even if INT8 is allowed and faster.

The heuristic does not apply in the following scenarios:

  • The network output has type INT8.

  • An operation on the path (inclusively) from the last smoothing layer to the output is constrained by ILayer::setOutputType or ILayer::setPrecision to output INT8.

  • There is no smoothing layer with a path to the output, or that path has an intervening plugin layer.

  • The network uses explicit quantization.

Explicit Quantization#

When TensorRT detects the presence of Q/DQ layers in a network, it builds an engine using explicit-precision processing logic, and precision-control build flags are not required.

In explicit quantization, network representation changes to and from the quantized data type are explicit; therefore, INT8 and FP8 must not be used as type constraints.

For a strongly typed network, builder flags are neither required nor allowed.
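As an illustration, the Python sketch below (with hypothetical shapes and scale values) adds an explicit Q/DQ pair on an activation using the network definition API; a Q/DQ pair on the weight path would be added the same way to control the weight scale.

import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(0)   # explicit batch is the default in recent TensorRT releases

x = network.add_input("input", trt.float32, (1, 64, 56, 56))

# Per-tensor activation scale as a 0-D build-time constant (hypothetical value).
s_act = network.add_constant((), trt.Weights(np.array([0.023], dtype=np.float32))).get_output(0)

q = network.add_quantize(x, s_act)                    # IQuantizeLayer: FP32 -> INT8
dq = network.add_dequantize(q.get_output(0), s_act)   # IDequantizeLayer: INT8 -> FP32

# High-precision weights; a Q/DQ pair on the weight path (omitted here) would set their scale.
w = trt.Weights(np.random.randn(64, 64, 3, 3).astype(np.float32))
conv = network.add_convolution_nd(dq.get_output(0), 64, (3, 3), w, trt.Weights())
conv.padding_nd = (1, 1)
network.mark_output(conv.get_output(0))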

Quantized Weights#

Weights of Q/DQ models may be specified using a high-precision data type (FP32, FP16, or BF16) or a low-precision quantized type (INT8, FP8, INT4, FP4). When TensorRT builds an engine, high-precision weights are quantized using the IQuantizeLayer scale, which operates on the weights. The quantized (low-precision) weights are stored in the engine plan file. When using pre-quantized weights (low precision), an IDequantizeLayer is required between the weights and the linear operator using the weights.

INT4 and FP4 quantized weights are stored by packing two elements per byte. The first element is stored in the 4 least significant bits, and the second is stored in the 4 most significant bits, as illustrated in the diagram below.

4-bit packing (logical tensor on the left; physical layout on the right).

The diagram below shows an example of packing a (2, 3) 4-bit tensor.

An example packed 4-bit (2, 3) tensor.

ONNX Support#

When a model trained in PyTorch or TensorFlow using Quantization Aware Training (QAT) is exported to ONNX, each fake-quantization operation in the framework’s graph is exported as a pair of QuantizeLinear and DequantizeLinear ONNX operators. When TensorRT imports ONNX models, the ONNX QuantizeLinear operator is imported as an IQuantizeLayer instance, and the ONNX DequantizeLinear operator is imported as an IDequantizeLayer instance.

ONNX introduced support for QuantizeLinear and DequantizeLinear in opset 10, and a quantization-axis attribute was added in opset 13 (required for per-channel quantization). PyTorch 1.8 introduced support for exporting PyTorch models to ONNX using opset 13.
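A minimal PyTorch sketch of this export path is shown below. It is illustrative only: the scale, zero point, and shapes are hypothetical, and a real workflow would use a QAT/PTQ toolkit to insert and calibrate the fake-quantization operations rather than hard-coding them.

import torch
import torch.nn as nn

class FakeQuantConv(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.scale = 0.02      # hypothetical calibrated activation scale
        self.zero_point = 0    # symmetric quantization

    def forward(self, x):
        # Exported to ONNX as a QuantizeLinear/DequantizeLinear pair.
        x = torch.fake_quantize_per_tensor_affine(
            x, self.scale, self.zero_point, quant_min=-128, quant_max=127)
        return self.conv(x)

model = FakeQuantConv().eval()
dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy, "fake_quant_conv.onnx", opset_version=13)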

ONNX opset 19 added four FP8 formats, of which TensorRT supports E4M3FN (referred to as tensor(float8e4m3fn) in the ONNX operator schema). The latest PyTorch version (PyTorch 2.0) does not support FP8 formats, nor does it support export to ONNX using opset 19.

To bridge the gap, TransformerEngine exports its FP8 quantization functions as custom ONNX Q/DQ operators that belong to the “trt” domain (TRT_FP8QuantizeLinear and TRT_FP8DequantizeLinear). TensorRT can parse both the custom operators and standard opset 19 Q/DQ operators; however, note that opset 19 is not fully supported by TensorRT. Other tools, like ONNX Runtime, cannot parse the custom operators. ONNX opset 21 added support for the INT4 data type and block quantization. ONNX opset 23 added support for the FP4E2M1 type.

Warning

The ONNX GEMM operator is an example that can be quantized per channel. PyTorch torch.nn.Linear layers are exported as an ONNX GEMM operator with (K, C) weights layout and the transB GEMM attribute enabled (this transposes the weights before performing the GEMM operation). TensorFlow, on the other hand, pre-transposes the weights (C, K) before ONNX export:

  • PyTorch: \(y=xW^{T}\)

  • TensorFlow: \(y=xW\)

TensorRT, therefore, transposes PyTorch weights. TensorRT quantizes the weights before they are transposed, so GEMM layers originating from ONNX QAT models that were exported from PyTorch use dimension 0 for per-channel quantization (axis K = 0), while models originating from TensorFlow use dimension 1 (axis K = 1).

TensorRT does not support pre-quantized ONNX models that use INT8/FP8 quantized operators. Specifically, the following ONNX quantized operators are not supported and generate an import error if they are encountered when TensorRT imports the ONNX model:

TensorRT Processing of Q/DQ Networks#

When TensorRT optimizes a network in Q/DQ mode, the optimization process is limited to optimizations that do not change the arithmetic correctness of the network. Bit-level accuracy is rarely possible since the order of floating-point operations can produce different results (for example, rewriting \(\text{a}\ast s+b\ast s\) as \(\left(a+b \right)\ast s\) is a valid optimization). Allowing these differences is fundamental to backend optimization in general, and this also applies to converting a graph with Q/DQ layers to use quantized operations.

Q/DQ layers control the compute and data precision of a network. An IQuantizeLayer instance converts a high-precision floating-point tensor to a quantized tensor by employing quantization, and an IDequantizeLayer instance converts a quantized tensor to a high-precision floating-point tensor using dequantization. TensorRT expects a Q/DQ layer pair on each input of quantizable layers. Quantizable layers are deep-learning layers that can be converted to quantized layers by fusing with IQuantizeLayer and IDequantizeLayer instances. When TensorRT performs these fusions, it replaces the quantizable layers with quantized layers that operate on quantized data using compute operations suitable for quantized types.

For the diagrams used in this chapter, green designates low precision (quantized), and blue designates high precision. Arrows represent network activation tensors, and squares represent network layers.

A quantizable ``AveragePool`` layer (in blue) is fused with DQ and Q layers. All three layers are replaced by a quantized ``AveragePool`` layer (in green).

During network optimization, TensorRT moves Q/DQ layers in Q/DQ propagation. The goal in propagation is to maximize the proportion of the graph that can be processed at low precision. Thus, TensorRT propagates Q nodes backward (quantization happens as early as possible) and DQ nodes forward (so dequantization happens as late as possible). Q-layers can swap places with layers that commute with Quantization, and DQ-layers can swap places with layers that commute with Dequantization.

A layer \(\text{Op}\) commutes with quantization if \(\text{Q}\left(Op\left(x \right) \right)==\text{Op}\left(Q\left(x \right) \right)\)

Similarly, a layer \(\text{Op}\) commutes with dequantization if \(\text{Op}\left(DQ\left(x \right) \right)==\text{DQ}\left(Op\left(x \right) \right)\)

The following diagram illustrates DQ forward propagation and Q backward propagation. These are legal rewrites of the model because Max Pooling has an INT8 implementation and commutes with DQ and Q.

An illustration depicting a DQ forward-propagation and Q backward-propagation

To understand Max Pooling commutation, let us look at the output of the maximum-pooling operation applied to some arbitrary input. Max Pooling is applied to groups of input coefficients and outputs the coefficient with the maximum value. For group i composed of coefficients \(\left\{x_{0}..x_{m} \right\}\):

\(output_{i}:=max\left( \left\{ x_{0},x_{1},...,x_{m} \right\} \right)=max\left( \left\{ max\left( \left\{ max\left( \left\{ x_{0},x_{1} \right\} \right),x_{2} \right\} \right),...,x_{m} \right\} \right)\)

It is, therefore, enough to look at two arbitrary coefficients without loss of generality (WLOG): \(x_{j}=max\left( \left\{ x_{j},x_{k} \right\} \right)\ \text{for}\ x_{j}\ge x_{k}\)

For the quantization function \(\text{Q}\left( a,scale,x_{max},x_{min} \right):=truncate\left( round\left( \frac{a}{scale} \right),x_{max},x_{min}\right)\), with \(scale\gt 0\), note that (without providing proof and using simplified notation): \(\text{Q}\left( x_{j},scale \right)\ge \text{Q}\left( x_{k},scale \right)\ \text{for}\ x_{j}\ge x_{k}\)

Therefore: \(\text{max}\left( \left\{ \text{Q}\left( x_{j},scale \right),\text{Q}\left( x_{k},scale \right) \right\} \right)=\text{Q}\left( x_{j},scale \right)\ \text{for}\ x_{j}\ge x_{k}\)

However, by definition: \(\text{Q}\left( max\left( \left\{ x_{j},x_{k} \right\} \right),scale \right)=\text{Q}\left( x_{j},scale \right)\ \text{for}\ x_{j}\ge x_{k}\)

Function \(\text{max}\) commutes with quantization, and so does Max Pooling.

Similarly, for dequantization, given the function \(\text{DQ}\left( a,scale \right):=a\ast scale\) with \(\text{scale}\gt 0\), it can be shown that: \(\text{max}\left( \left\{ \text{DQ}\left(x_{j},scale \right),\text{DQ}\left( x_{k},scale \right) \right\} \right)=\text{DQ}\left( x_{j},scale \right)=\text{DQ}\left( \text{max}\left( \left\{ x_{j},x_{k} \right\} \right),scale \right)\ \text{for}\ x_{j}\ge x_{k}\)

There is a distinction between how quantizable layers and commuting layers are processed. Both layers can be computed in INT8/FP8, but quantizable layers also fuse with a DQ input and a Q output layer. For example, an AveragePooling layer (quantizable) does not commute with either Q or DQ, so it is quantized using Q/DQ fusion, as illustrated in the first diagram. This is in contrast to how Max Pooling (commuting) is quantized.

Weight-Only Quantization#

Weight-only quantization (WoQ) is an optimization useful when memory bandwidth limits the performance of GEMM operations or when GPU memory is scarce. In WoQ, GEMM weights are quantized to INT4 precision while the GEMM input data and compute operation remain high precision. TensorRT’s WoQ kernels read the 4-bit weights from memory and dequantize them before performing the dot product in high precision.

Weight-only Quantization (WoQ)

WoQ is available only for INT4 block quantization with GEMM layers. The GEMM data input is specified in high-precision (FP32, FP16, BF16), and the weights are quantized using Q/DQ as usual. TensorRT creates an engine with INT4 weights and a high-precision GEMM operation. The engine reads the low-precision weights and dequantizes them before performing the GEMM operation in high-precision.

Q/DQ Layer-Placement Recommendations#

The placement of Q/DQ layers in a network affects performance and accuracy. Aggressive quantization can lead to degradation in model accuracy because of the error introduced by quantization. But quantization also enables latency reductions. Listed here are some recommendations for placing Q/DQ layers in your network.

Note that older devices may not have low-precision kernel implementations for all layers, and you may encounter a could not find any implementation error while building your engine. To resolve this, remove the Q/DQ nodes that quantize the failing layers.

Quantize all inputs of weighted operations (Convolution, Transposed Convolution, and GEMM). Quantizing the weights and activations reduces bandwidth requirements and enables INT8 computation to accelerate bandwidth-limited and compute-limited layers.

Two examples of how TensorRT fuses convolutional layers. On the left, only the input is quantized. On the right, both the input and output are quantized.

By default, do not quantize the outputs of weighted operations. It is sometimes useful to preserve the higher-precision dequantized output. For example, if the linear operation is followed by an activation function (SiLU, in the following diagram), it requires higher precision input to produce acceptable accuracy.

Example of a linear operation followed by an activation function

Do not simulate batch normalization and ReLU fusions in the training framework because TensorRT optimizations guarantee the preservation of these operations’ arithmetic semantics.

Batch normalization is fused with convolution and ReLU while keeping the same execution order defined in the pre-fusion network. There is no need to simulate BN-folding in the training network.

Quantize the residual input in skip connections. TensorRT can fuse element-wise addition following weighted layers, which is useful for models with skip connections like ResNet and EfficientNet. The precision of the first input to the element-wise addition layer determines the fusion output’s precision.

For example, in the following diagram, the precision of \(x_{f^{}}^{1}\) is a floating point, so the output of the fused convolution is limited to the floating point, and the trailing Q-layer cannot be fused with the convolution.

The precision of :math:`x_{f^{}}^{1}` is a floating point, so the output of the fused convolution is limited to the floating point, and the trailing Q-layer cannot be fused with the convolution.

In contrast, when \(x_{f^{}}^{1}\) is quantized to INT8, as depicted in the following diagram, the output of the fused convolution is also INT8, and the trailing Q-layer is fused with the convolution.

When :math:`x_{f^{}}^{1}` is quantized to INT8, the output of the fused convolution is also INT8, and the trailing Q-layer is fused with the convolution.

For extra performance, try quantizing layers that do not commute with Q/DQ. Currently, non-weighted layers with INT8 inputs also require INT8 outputs, so quantize both inputs and outputs.

An example of quantizing a quantizable operation. An element-wise addition is fused with the input DQs and the output Q.

Performance can decrease if TensorRT cannot fuse the operations with the surrounding Q/DQ layers, so be conservative when adding Q/DQ nodes and experiment with accuracy and TensorRT performance in mind.

The following figure shows suboptimal fusions (the highlighted light green background rectangles) that can result from extra Q/DQ operations. The convolution is fused separately from the element-wise addition because each of them is surrounded by Q/DQ pairs.

An example of suboptimal quantization fusions: contrast the suboptimal fusion in A and the optimal fusion in B. The extra pair of Q/DQ operations (highlighted with a glowing green border) forces the separation of the convolution from the element-wise addition.

Use per-tensor quantization for activations and per-channel quantization for weights. This configuration has been demonstrated empirically to lead to the best quantization accuracy.

You can further optimize engine latency by enabling FP16. TensorRT attempts to use FP16 instead of FP32 whenever possible (this is not currently supported for all layer types).

Q/DQ Limitations#

A few Q/DQ graph-rewrite optimizations that TensorRT performs compare the values of quantization scales between two or more Q/DQ layers and only perform the graph rewrite if the compared quantization scales are equal. When a refittable TensorRT engine is refitted, the scales of Q/DQ nodes can be assigned new values. During the refitting of Q/DQ engines, TensorRT checks whether Q/DQ layers that participated in scale-dependent optimizations have been assigned new values that break those optimizations, and throws an exception if so.

An example showing scales of Q1 and Q2 are compared for equality, and if equal, they are allowed to propagate backward. If the engine is refitted with new values for Q1 and Q2 such that Q1 != Q2, then an exception aborts the refitting process.

Q/DQ Interaction with Plugins#

Plugins extend TensorRT’s capabilities by allowing the replacement of a group of layers with a custom and proprietary implementation. You can decide what functionality to include in the plugin and what to leave for TensorRT to handle.

The same applies to a TensorRT network with Q/DQ layers. When a plugin consumes quantized inputs (INT8/FP8) and generates quantized outputs, the input DQ and output Q nodes must be included in the plugin and removed from the network.

Consider a simple case of a sequential graph consisting of a single INT8 plugin (aptly named MyInt8Plugin) sandwiched between two convolution layers (ignoring weights quantization):

\(\text{Input}\gt \text{Q}\to \text{DQ}\gt \text{Conv}\gt \text{Q}\to \text{DQ_i}\gt \text{MyInt8Plugin}\gt \text{Q_o}\to \text{DQ}\gt \text{Conv}\gt \text{Output}\)

The \(\gt\) arrows indicate activation tensors with FP32 precision, and the \(\to\) arrows indicate INT8 precision.

When TensorRT optimizes this graph, it fuses the layers to the following graph (square brackets indicate TensorRT fusions):

\(\text{Input}\gt \text{Q}\to \left[\text{DQ}\to \text{Conv} \to \text{Q} \right]\to \text{DQ_i}\gt \text{MyInt8Plugin}\gt \text{Q_o}\to \left[ \text{DQ}\to \text{Conv} \right]\gt \text{Output}\)

In the graph above, the plugin consumes and generates FP32 inputs and outputs. Since MyInt8Plugin uses INT8 precision, the next step is to manually fold DQ_i and Q_o into MyInt8Plugin and then call setOutputType(kINT8) on this plugin layer; TensorRT will then see a network like this:

\(\text{Input}\gt \text{Q}\to \text{DQ}\gt \text{Conv}\gt \text{Q}\to \text{MyInt8Plugin}\to \text{DQ}\gt \text{Conv}\gt \text{Output}\)

Which it will fuse to:

\(\text{Input}\gt \text{Q}\to \left[ \text{DQ}\to \text{Conv}\to \text{Q} \right]\gt \text{MyInt8Plugin}\to \left[ \text{DQ}\to \text{Conv} \right]\gt \text{Output}\)

When “manually fusing” DQ_i, you take the input quantization scale and give it to your plugin so it will know how to dequantize (if needed) the input. The same applies to using the scale from Q_o to quantize your plugin’s output.

QAT Networks Using TensorFlow#

You can use the TensorRT Model Optimizer to perform QAT in TensorFlow 2 Keras models following NVIDIA’s QAT recipe. This leads to optimal model acceleration with TensorRT on NVIDIA GPUs and hardware accelerators.

TensorFlow 1 does not support per-channel quantization (PCQ), which is recommended for weights to preserve the model’s accuracy.

QAT Networks Using PyTorch#

PyTorch 1.8.0 and later support exporting the ONNX QuantizeLinear and DequantizeLinear operators with per-channel scales.

You can use the TensorRT Model Optimizer to calibrate INT8, perform QAT and PTQ for the various precisions that TensorRT supports, and export to ONNX.
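For instance, a PTQ sketch using Model Optimizer's PyTorch quantization API might look like the following. It assumes an existing model and a representative calib_dataloader; the configuration name follows Model Optimizer's documented INT8 preset, but check the toolkit's documentation for the precision you need.

import modelopt.torch.quantization as mtq

def forward_loop(model):
    # Run representative data through the model so the scales can be calibrated.
    for batch in calib_dataloader:
        model(batch)

# Post-training quantization with an INT8 configuration; QAT and other precisions
# (for example, FP8) follow the same pattern with a different config and a training loop.
model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)

# The resulting model contains fake-quantization ops and can be exported to ONNX
# (producing Q/DQ nodes), for example with torch.onnx.export.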

QAT Networks Using TransformerEngine#

We provide TransformerEngine, an open-source library for accelerating transformer models’ training, inference, and exporting. It includes APIs for building a Transformer layer and a framework-agnostic library in C++, including structs and kernels needed for FP8 support. Modules provided by TransformerEngine internally maintain scaling factors and other values needed for FP8 training. You can use TransformerEngine to train a mixed precision model, export an ONNX model, and use TensorRT to run inference on this ONNX model.

Dynamic Quantization#

Dynamic Quantization is a form of Block Quantization in which the scales are computed during inference according to the input data. It produces two outputs: quantized data and per-block scales.

Dynamic Quantization has two main benefits:

  1. Accuracy: With Dynamic Quantization, a scale is selected to map only the dynamic range of a single block to the quantized type. Since the dynamic range of a single block is often much smaller than the dynamic range of the entire tensor, the quantization error is reduced. This is most significant for sub-byte quantized types due to the small range of values representable in these data types.

  2. Reduced PTQ overhead: Since the scales are automatically computed during inference, the user is not required to calibrate the scales according to sample data.

For each block, the scale is computed by:

\(\text{scale}=max_{i\in \left\{ 0...blockSize-1 \right\}}\left( \frac{abs\left( x_{i} \right)}{qTypeMax} \right)\)

Where:

  • \(\text{qTypeMax}\) is the maximum value in the quantized type (for example, 6 for FP4E2M1).

TensorRT supports a form of Dynamic Quantization called Dynamic Double Quantization, in which the computed scales are also quantized. Putting together the scale computation and scale quantization for a single block:

\(scale_{quantized}=quantize\left( max_{i\in \left\{ 0...blockSize-1 \right\}}\left( \frac{abs\left( x_{i} \right)}{qTypeMax} \right),scale=globalSf \right)\)

Where:

  • \(\text{globalSf}\) is an offline-calibrated per-tensor quantization scale (scalar).

  • \(\text{qTypeMax}\) is the maximum value representable in the quantized type used for data.

The scale computation is repeated for each block, computing a total of \(\frac{inputVolume}{blockSize}\) block scales.

TensorRT currently supports Dynamic Double Quantization only for FP4 data and FP8 scales. Using \(qTypeMax=6\) and the FP8 range of [-448,448], the quantized scale can be written as:

\(scale_{fp8}=castToFp8\left( \frac{max_{i\in \left( 0...blockSize-1 \right)}\left( abs\left(x_{i} \right) \right)}{6\ast globalSf} \right)\)

The quantized data is computed using block quantization using the computed scales.

To dequantize data that was quantized using Dynamic Double Quantization, two consecutive Dequantize operations must occur: the first to dequantize the scales using per-tensor quantization, and the second to dequantize the data.

\(data_{DQ}=dequantize\left( data_{Q},dequantize\left( scale_{Q},scale=globalSf \right) \right)\)
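The following NumPy sketch puts the dynamic double quantization and two-step dequantization together. It is illustrative only: the FP4 cast is simulated with a grid of FP4E2M1-representable values, the FP8 cast is reduced to range clamping, and the block size is arbitrary.

import numpy as np

FP4_POS = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)
FP4_GRID = np.concatenate([-FP4_POS[:0:-1], FP4_POS])   # all FP4E2M1-representable values

def cast_to_fp4(x):
    # Round each element to the nearest representable FP4 value (ties handling simplified).
    return FP4_GRID[np.abs(x[..., None] - FP4_GRID).argmin(axis=-1)]

def cast_to_fp8(x):
    # Simplified stand-in for an FP8E4M3 cast: clamp to its finite range only.
    return np.clip(x, -448.0, 448.0)

def dynamic_double_quantize(x, global_sf, block_size=16):
    blocks = x.reshape(-1, block_size)
    # Per-block scale, itself quantized with the per-tensor global scale.
    scale_q = cast_to_fp8(np.abs(blocks).max(axis=1) / (6.0 * global_sf))
    eff = np.maximum(scale_q * global_sf, 1e-12)          # effective per-block data scale
    data_q = cast_to_fp4(np.clip(blocks / eff[:, None], -6.0, 6.0))
    return data_q, scale_q

def dynamic_double_dequantize(data_q, scale_q, global_sf):
    # First dequantize the scales with the per-tensor global scale, then the data.
    return data_q * (scale_q * global_sf)[:, None]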

An example showing the fusion of Dynamic Double Quantization with GEMM; where: ``hp`` is high precision and ``sf`` is scale factors

Quantized Types Rounding Modes#

  • GPU:
    • Compute kernel quantization: round-to-nearest-with-ties-to-even (INT8, FP8, INT4, FP4)
    • Weights quantization (FP32 to INT8/FP8/INT4/FP4), quantized network (QAT): round-to-nearest-with-ties-to-even
    • Weights quantization (FP32 to INT8/FP8/INT4/FP4), dynamic range API / calibration: round-to-nearest-with-ties-to-positive-infinity (INT8 only)

  • DLA:
    • Compute kernel quantization: round-to-nearest-with-ties-to-even
    • Weights quantization (FP32 to INT8/FP8/INT4/FP4), quantized network (QAT): N/A
    • Weights quantization (FP32 to INT8/FP8/INT4/FP4), dynamic range API / calibration: round-to-nearest-with-ties-to-even (INT8 only)