Getting Started: End to End

The NVIDIA® TensorFlow 2.x Quantization toolkit provides a simple API to quantize a given Keras model. At a high level, Quantization Aware Training (QAT) is a three-step workflow, as shown below:

Pre-trained model → Quantize → Fine-tune

Initially, the network is trained on the target dataset until fully converged. The Quantization step consists of inserting Q/DQ nodes into the pre-trained network to simulate quantization during training. Note that simply adding Q/DQ nodes results in reduced accuracy, since the quantization parameters have not yet been updated for the given model. The network is then re-trained for a few epochs to recover accuracy, a step called “fine-tuning”.
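
Conceptually, each Q/DQ pair quantizes a tensor to INT8 and immediately dequantizes it back to floating point, so the rest of the graph keeps running in float while the rounding and clipping error becomes part of the training signal. The snippet below is only an illustration of this idea, assuming symmetric per-tensor INT8 quantization with a scale derived from the tensor's absolute maximum; the toolkit inserts and manages the actual Q/DQ nodes for you.

import tensorflow as tf

def fake_quantize(x: tf.Tensor, num_bits: int = 8) -> tf.Tensor:
    """Illustrative symmetric fake quantization: quantize to INT8, then dequantize."""
    qmax = 2.0 ** (num_bits - 1) - 1.0       # 127 for INT8
    scale = tf.reduce_max(tf.abs(x)) / qmax  # per-tensor scale (simplifying assumption)
    x_int = tf.clip_by_value(tf.round(x / scale), -qmax, qmax)
    return x_int * scale                     # back to float, with quantization error baked in

x = tf.constant([-1.5, -0.2, 0.0, 0.7, 2.3])
print(fake_quantize(x).numpy())  # values snapped onto the INT8 grid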

Goal

  1. Train a simple network on the Fashion MNIST dataset and save it as the baseline model.

  2. Quantize the pre-trained baseline network.

  3. Fine-tune the quantized network to recover accuracy and save it as the QAT model.


1. Train

Import required libraries and create a simple network with convolution and dense layers.

import tensorflow as tf
from tensorflow_quantization import quantize_model
from tensorflow_quantization import utils

assets = utils.CreateAssetsFolders("GettingStarted")
assets.add_folder("example")

def simple_net():
    """
    Return a simple neural network.
    """
    input_img = tf.keras.layers.Input(shape=(28, 28), name="nn_input")
    x = tf.keras.layers.Reshape(target_shape=(28, 28, 1), name="reshape_0")(input_img)
    x = tf.keras.layers.Conv2D(filters=126, kernel_size=(3, 3), name="conv_0")(x)
    x = tf.keras.layers.ReLU(name="relu_0")(x)
    x = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), name="conv_1")(x)
    x = tf.keras.layers.ReLU(name="relu_1")(x)
    x = tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), name="conv_2")(x)
    x = tf.keras.layers.ReLU(name="relu_2")(x)
    x = tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3), name="conv_3")(x)
    x = tf.keras.layers.ReLU(name="relu_3")(x)
    x = tf.keras.layers.Conv2D(filters=8, kernel_size=(3, 3), name="conv_4")(x)
    x = tf.keras.layers.ReLU(name="relu_4")(x)
    x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), name="max_pool_0")(x)
    x = tf.keras.layers.Flatten(name="flatten_0")(x)
    x = tf.keras.layers.Dense(100, name="dense_0")(x)
    x = tf.keras.layers.ReLU(name="relu_5")(x)
    x = tf.keras.layers.Dense(10, name="dense_1")(x)
    return tf.keras.Model(input_img, x, name="original")

# create model
model = simple_net()
model.summary()
Model: "original"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 nn_input (InputLayer)       [(None, 28, 28)]          0         
                                                                 
 reshape_0 (Reshape)         (None, 28, 28, 1)         0         
                                                                 
 conv_0 (Conv2D)             (None, 26, 26, 126)       1260      
                                                                 
 relu_0 (ReLU)               (None, 26, 26, 126)       0         
                                                                 
 conv_1 (Conv2D)             (None, 24, 24, 64)        72640     
                                                                 
 relu_1 (ReLU)               (None, 24, 24, 64)        0         
                                                                 
 conv_2 (Conv2D)             (None, 22, 22, 32)        18464     
                                                                 
 relu_2 (ReLU)               (None, 22, 22, 32)        0         
                                                                 
 conv_3 (Conv2D)             (None, 20, 20, 16)        4624      
                                                                 
 relu_3 (ReLU)               (None, 20, 20, 16)        0         
                                                                 
 conv_4 (Conv2D)             (None, 18, 18, 8)         1160      
                                                                 
 relu_4 (ReLU)               (None, 18, 18, 8)         0         
                                                                 
 max_pool_0 (MaxPooling2D)   (None, 9, 9, 8)           0         
                                                                 
 flatten_0 (Flatten)         (None, 648)               0         
                                                                 
 dense_0 (Dense)             (None, 100)               64900     
                                                                 
 relu_5 (ReLU)               (None, 100)               0         
                                                                 
 dense_1 (Dense)             (None, 10)                1010      
                                                                 
=================================================================
Total params: 164,058
Trainable params: 164,058
Non-trainable params: 0
_________________________________________________________________

Load the Fashion MNIST dataset, which comes pre-split into train and test sets.

# Load Fashion MNIST dataset
mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

Compile the model and train for five epochs.

# Train original classification model
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

model.fit(
    train_images, train_labels, batch_size=128, epochs=5, validation_split=0.1
)

# get baseline model accuracy
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0
)
print("Baseline test accuracy:", baseline_model_accuracy)
Epoch 1/5
422/422 [==============================] - 4s 8ms/step - loss: 0.5639 - accuracy: 0.7920 - val_loss: 0.4174 - val_accuracy: 0.8437
Epoch 2/5
422/422 [==============================] - 3s 8ms/step - loss: 0.3619 - accuracy: 0.8696 - val_loss: 0.4134 - val_accuracy: 0.8433
Epoch 3/5
422/422 [==============================] - 3s 8ms/step - loss: 0.3165 - accuracy: 0.8855 - val_loss: 0.3137 - val_accuracy: 0.8812
Epoch 4/5
422/422 [==============================] - 3s 8ms/step - loss: 0.2787 - accuracy: 0.8964 - val_loss: 0.2943 - val_accuracy: 0.8890
Epoch 5/5
422/422 [==============================] - 3s 8ms/step - loss: 0.2552 - accuracy: 0.9067 - val_loss: 0.2857 - val_accuracy: 0.8952
Baseline test accuracy: 0.888700008392334
# save TF FP32 original model
tf.keras.models.save_model(model, assets.example.fp32_saved_model)

# Convert FP32 model to ONNX
utils.convert_saved_model_to_onnx(saved_model_dir = assets.example.fp32_saved_model, onnx_model_path = assets.example.fp32_onnx_model)
INFO:tensorflow:Assets written to: GettingStarted/example/fp32/saved_model/assets
INFO:tensorflow:Assets written to: GettingStarted/example/fp32/saved_model/assets
ONNX conversion Done!
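
Optionally, the exported ONNX model can be sanity-checked before moving on. The snippet below is a minimal sketch that assumes the onnx and onnxruntime packages are installed (they are not part of the toolkit): it validates the exported graph and runs one test image through the FP32 model.

import numpy as np
import onnx
import onnxruntime as ort

# Validate the exported graph.
onnx.checker.check_model(onnx.load(assets.example.fp32_onnx_model))

# Run a single test image through the FP32 ONNX model on CPU.
session = ort.InferenceSession(
    assets.example.fp32_onnx_model, providers=["CPUExecutionProvider"]
)
input_name = session.get_inputs()[0].name
logits = session.run(None, {input_name: test_images[:1].astype(np.float32)})[0]
print("Predicted class:", np.argmax(logits, axis=-1))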

2. Quantize

Full model quantization is the simplest quantization mode. In this mode, Q/DQ nodes are inserted into all supported Keras layers with a single function call:

# Quantize model
quantized_model = quantize_model(model)

The Keras model summary shows all supported layers wrapped in their respective QDQ wrapper classes.

quantized_model.summary()
Model: "original"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 nn_input (InputLayer)       [(None, 28, 28)]          0         
                                                                 
 reshape_0 (Reshape)         (None, 28, 28, 1)         0         
                                                                 
 quant_conv_0 (Conv2DQuantiz  (None, 26, 26, 126)      1515      
 eWrapper)                                                       
                                                                 
 relu_0 (ReLU)               (None, 26, 26, 126)       0         
                                                                 
 quant_conv_1 (Conv2DQuantiz  (None, 24, 24, 64)       72771     
 eWrapper)                                                       
                                                                 
 relu_1 (ReLU)               (None, 24, 24, 64)        0         
                                                                 
 quant_conv_2 (Conv2DQuantiz  (None, 22, 22, 32)       18531     
 eWrapper)                                                       
                                                                 
 relu_2 (ReLU)               (None, 22, 22, 32)        0         
                                                                 
 quant_conv_3 (Conv2DQuantiz  (None, 20, 20, 16)       4659      
 eWrapper)                                                       
                                                                 
 relu_3 (ReLU)               (None, 20, 20, 16)        0         
                                                                 
 quant_conv_4 (Conv2DQuantiz  (None, 18, 18, 8)        1179      
 eWrapper)                                                       
                                                                 
 relu_4 (ReLU)               (None, 18, 18, 8)         0         
                                                                 
 max_pool_0 (MaxPooling2D)   (None, 9, 9, 8)           0         
                                                                 
 flatten_0 (Flatten)         (None, 648)               0         
                                                                 
 quant_dense_0 (DenseQuantiz  (None, 100)              65103     
 eWrapper)                                                       
                                                                 
 relu_5 (ReLU)               (None, 100)               0         
                                                                 
 quant_dense_1 (DenseQuantiz  (None, 10)               1033      
 eWrapper)                                                       
                                                                 
=================================================================
Total params: 164,791
Trainable params: 164,058
Non-trainable params: 733
_________________________________________________________________

Let’s check the quantized model’s accuracy immediately after Q/DQ nodes are inserted.

# Compile quantized model
quantized_model.compile(
    optimizer=tf.keras.optimizers.Adam(0.0001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
# Get accuracy immediately after QDQ nodes are inserted.
_, q_aware_model_accuracy = quantized_model.evaluate(test_images, test_labels, verbose=0)
print("Quantization test accuracy immediately after QDQ insertion:", q_aware_model_accuracy)
Quantization test accuracy immediately after QDQ insertion: 0.883899986743927

The model’s accuracy decreases a bit as soon as Q/DQ nodes are inserted, requiring fine-tuning to recover it.

Note

Since this is a very small model, the accuracy drop is small. For standard models such as ResNets, the accuracy drop immediately after Q/DQ insertion can be significant.

3. Fine-tune

Since the quantized model behaves similarly to the original Keras model, the same training recipe can be used for fine-tuning as well.

We fine-tune the model for two epochs and evaluate the model on the test dataset.

# fine tune quantized model for 2 epochs.
quantized_model.fit(
    train_images, train_labels, batch_size=32, epochs=2, validation_split=0.1
)
# Get quantized accuracy
_, q_aware_model_accuracy_finetuned = quantized_model.evaluate(test_images, test_labels, verbose=0)
print("Quantization test accuracy after fine-tuning:", q_aware_model_accuracy_finetuned)
print("Baseline test accuracy (for reference):", baseline_model_accuracy)
Epoch 1/2
1688/1688 [==============================] - 26s 15ms/step - loss: 0.1793 - accuracy: 0.9340 - val_loss: 0.2468 - val_accuracy: 0.9112
Epoch 2/2
1688/1688 [==============================] - 25s 15ms/step - loss: 0.1725 - accuracy: 0.9373 - val_loss: 0.2484 - val_accuracy: 0.9070
Quantization test accuracy after fine-tuning: 0.9075999855995178
Baseline test accuracy (for reference): 0.888700008392334

Note

If the original network is not fully converged, the fine-tuned model’s accuracy can even surpass the baseline model’s accuracy, as happens here.

# save TF INT8 quantized model
tf.keras.models.save_model(quantized_model, assets.example.int8_saved_model)

# Convert INT8 model to ONNX
utils.convert_saved_model_to_onnx(saved_model_dir = assets.example.int8_saved_model, onnx_model_path = assets.example.int8_onnx_model)

tf.keras.backend.clear_session()
WARNING:absl:Found untraced functions such as conv_0_layer_call_fn, conv_0_layer_call_and_return_conditional_losses, conv_1_layer_call_fn, conv_1_layer_call_and_return_conditional_losses, conv_2_layer_call_fn while saving (showing 5 of 14). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: GettingStarted/example/int8/saved_model/assets
INFO:tensorflow:Assets written to: GettingStarted/example/int8/saved_model/assets
ONNX conversion Done!

In this example, accuracy loss due to quantization is recovered in just two epochs.

This NVIDIA® Quantization Toolkit provides an easy interface for creating quantized networks, which can then take advantage of INT8 inference on NVIDIA® GPUs using TensorRT™.
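
The natural next step, outside the scope of this example, is to build a TensorRT engine from the exported INT8 ONNX model. The sketch below uses the TensorRT Python API (the tensorrt package, assumed to be installed separately); the ONNX path and the optimization-profile shape bounds are placeholders to adapt to your setup.

import tensorrt as trt

# Placeholder: path of the INT8 ONNX model exported above.
onnx_path = "GettingStarted/example/int8/model.onnx"

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
parser = trt.OnnxParser(network, logger)
with open(onnx_path, "rb") as f:
    if not parser.parse(f.read()):
        raise RuntimeError("Failed to parse the ONNX model")

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.INT8)  # honor the Q/DQ nodes in the graph

# The batch dimension is dynamic, so an optimization profile is required.
profile = builder.create_optimization_profile()
input_name = network.get_input(0).name
profile.set_shape(input_name, (1, 28, 28), (32, 28, 28), (128, 28, 28))
config.add_optimization_profile(profile)

serialized_engine = builder.build_serialized_network(network, config)
with open("GettingStarted/example/int8/model_qat.engine", "wb") as f:
    f.write(serialized_engine)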