Getting Started: End to End
The NVIDIA® TensorFlow 2.x Quantization Toolkit provides a simple API for quantizing a given Keras model. At a high level, Quantization Aware Training (QAT) is a three-step workflow: train, quantize, and fine-tune.
First, the network is trained on the target dataset until it fully converges. In the Quantize step, Quantize/DeQuantize (Q/DQ) nodes are inserted into the pre-trained network to simulate quantization during training. Simply adding Q/DQ nodes typically reduces accuracy, because the quantization parameters have not yet been adapted to the given model. The network is then re-trained for a few epochs to recover accuracy, a step called “fine-tuning”.
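The phrase "simulate quantization" can be made concrete with a few lines of plain Python. The sketch below mimics what a single Q/DQ pair does to a tensor, assuming symmetric per-tensor scaling to the narrow INT8 range [-127, 127] and taking the range (amax) directly from the data. `fake_quantize` is a hypothetical illustration, not a toolkit API; the toolkit's quantizers may differ, for example by using per-channel scales for weights or calibrated ranges for activations.

```python
def fake_quantize(values, num_bits=8):
    """Quantize then dequantize (Q/DQ) a list of floats with one symmetric scale.

    Assumption: symmetric, per-tensor, narrow range; amax is taken from the data.
    Returns the dequantized values and the scale that was used.
    """
    qmax = 2 ** (num_bits - 1) - 1             # 127 for 8 bits
    amax = max(abs(v) for v in values) or 1.0   # data range; avoid a zero scale
    scale = amax / qmax
    dequantized = []
    for v in values:
        q = round(v / scale)                    # Q: snap to the integer grid (Python rounds half to even)
        q = max(-qmax, min(qmax, q))            # clamp to [-qmax, qmax]
        dequantized.append(q * scale)           # DQ: map back to floating point
    return dequantized, scale


weights = [0.1, -0.73, 0.42, 1.5, -1.5]
fq, scale = fake_quantize(weights)
for w, w_q in zip(weights, fq):
    print(f"{w:+.4f} -> {w_q:+.6f}  (error {abs(w - w_q):.6f}, bound scale/2 = {scale / 2:.6f})")
```

The outputs are still floats, but each one is snapped to one of at most 255 evenly spaced levels; this rounding error is the effect that fine-tuning lets the network adapt to.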
Goal
Train a simple network on the Fashion MNIST dataset and save it as the baseline model.
Quantize the pre-trained baseline network.
Fine-tune the quantized network to recover accuracy and save it as the QAT model.
1. Train
Import the required libraries and create a simple network with convolutional and dense layers.
```python
import tensorflow as tf
from tensorflow_quantization import quantize_model
from tensorflow_quantization import utils

# Create the folder structure used below to store the saved models and ONNX files.
assets = utils.CreateAssetsFolders("GettingStarted")
assets.add_folder("example")


def simple_net():
    """
    Return a simple neural network.
    """
    input_img = tf.keras.layers.Input(shape=(28, 28), name="nn_input")
    x = tf.keras.layers.Reshape(target_shape=(28, 28, 1), name="reshape_0")(input_img)
    x = tf.keras.layers.Conv2D(filters=126, kernel_size=(3, 3), name="conv_0")(x)
    x = tf.keras.layers.ReLU(name="relu_0")(x)
    x = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), name="conv_1")(x)
    x = tf.keras.layers.ReLU(name="relu_1")(x)
    x = tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), name="conv_2")(x)
    x = tf.keras.layers.ReLU(name="relu_2")(x)
    x = tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3), name="conv_3")(x)
    x = tf.keras.layers.ReLU(name="relu_3")(x)
    x = tf.keras.layers.Conv2D(filters=8, kernel_size=(3, 3), name="conv_4")(x)
    x = tf.keras.layers.ReLU(name="relu_4")(x)
    x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), name="max_pool_0")(x)
    x = tf.keras.layers.Flatten(name="flatten_0")(x)
    x = tf.keras.layers.Dense(100, name="dense_0")(x)
    x = tf.keras.layers.ReLU(name="relu_5")(x)
    x = tf.keras.layers.Dense(10, name="dense_1")(x)
    # The model outputs raw logits (no softmax); the loss uses from_logits=True.
    return tf.keras.Model(input_img, x, name="original")


# Create the model and print its summary.
model = simple_net()
model.summary()
```
```text
Model: "original"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 nn_input (InputLayer)       [(None, 28, 28)]          0
 reshape_0 (Reshape)         (None, 28, 28, 1)         0
 conv_0 (Conv2D)             (None, 26, 26, 126)       1260
 relu_0 (ReLU)               (None, 26, 26, 126)       0
 conv_1 (Conv2D)             (None, 24, 24, 64)        72640
 relu_1 (ReLU)               (None, 24, 24, 64)        0
 conv_2 (Conv2D)             (None, 22, 22, 32)        18464
 relu_2 (ReLU)               (None, 22, 22, 32)        0
 conv_3 (Conv2D)             (None, 20, 20, 16)        4624
 relu_3 (ReLU)               (None, 20, 20, 16)        0
 conv_4 (Conv2D)             (None, 18, 18, 8)         1160
 relu_4 (ReLU)               (None, 18, 18, 8)         0
 max_pool_0 (MaxPooling2D)   (None, 9, 9, 8)           0
 flatten_0 (Flatten)         (None, 648)               0
 dense_0 (Dense)             (None, 100)               64900
 relu_5 (ReLU)               (None, 100)               0
 dense_1 (Dense)             (None, 10)                1010
=================================================================
Total params: 164,058
Trainable params: 164,058
Non-trainable params: 0
_________________________________________________________________
```
Load the Fashion MNIST dataset, which already comes split into training and test sets, and normalize the images.
```python
# Load the Fashion MNIST dataset (already split into training and test sets).
mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
```
Compile the model, train it for five epochs, and evaluate it on the test set.
```python
# Train the original classification model.
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
model.fit(
    train_images, train_labels, batch_size=128, epochs=5, validation_split=0.1
)

# Get the baseline model's test accuracy.
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0
)
print("Baseline test accuracy:", baseline_model_accuracy)
```
```text
Epoch 1/5
422/422 [==============================] - 4s 8ms/step - loss: 0.5639 - accuracy: 0.7920 - val_loss: 0.4174 - val_accuracy: 0.8437
Epoch 2/5
422/422 [==============================] - 3s 8ms/step - loss: 0.3619 - accuracy: 0.8696 - val_loss: 0.4134 - val_accuracy: 0.8433
Epoch 3/5
422/422 [==============================] - 3s 8ms/step - loss: 0.3165 - accuracy: 0.8855 - val_loss: 0.3137 - val_accuracy: 0.8812
Epoch 4/5
422/422 [==============================] - 3s 8ms/step - loss: 0.2787 - accuracy: 0.8964 - val_loss: 0.2943 - val_accuracy: 0.8890
Epoch 5/5
422/422 [==============================] - 3s 8ms/step - loss: 0.2552 - accuracy: 0.9067 - val_loss: 0.2857 - val_accuracy: 0.8952
Baseline test accuracy: 0.888700008392334
```
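The step count in the log (422/422) follows from the dataset size: Fashion MNIST has 60,000 training images, `validation_split=0.1` holds out 6,000 of them, and the remaining 54,000 are processed in batches of 128, with a partial last batch. The helper `steps_per_epoch` below is a hypothetical illustration of that arithmetic, not a Keras or toolkit function.

```python
import math


def steps_per_epoch(num_examples, batch_size, validation_split=0.0):
    """Number of training batches per epoch when Keras' validation_split is used.

    Keras holds out the *last* validation_split fraction of the arrays for
    validation; the remaining examples are split into batches, and a partial
    final batch still counts as one step.
    """
    num_train = int(num_examples * (1.0 - validation_split))
    return math.ceil(num_train / batch_size)


print(steps_per_epoch(60_000, 128, 0.1))  # → 422
print(steps_per_epoch(60_000, 32, 0.1))   # → 1688
```

The same arithmetic with `batch_size=32` gives the 1,688 steps per epoch that appear in the fine-tuning log.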
```python
# Save the original FP32 model in TF SavedModel format.
tf.keras.models.save_model(model, assets.example.fp32_saved_model)

# Convert the FP32 model to ONNX.
utils.convert_saved_model_to_onnx(
    saved_model_dir=assets.example.fp32_saved_model,
    onnx_model_path=assets.example.fp32_onnx_model,
)
```

```text
INFO:tensorflow:Assets written to: GettingStarted/example/fp32/saved_model/assets
INFO:tensorflow:Assets written to: GettingStarted/example/fp32/saved_model/assets
ONNX conversion Done!
```
2. Quantize
Full model quantization is the most basic quantization mode. In this mode, Q/DQ nodes are inserted into all supported Keras layers with a single function call:
```python
# Quantize the model: insert Q/DQ nodes in all supported layers.
quantized_model = quantize_model(model)
```
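The Keras model summary shows that every supported layer (here, the Conv2D and Dense layers) is wrapped in a quantize wrapper class, while the ReLU, Reshape, pooling, and Flatten layers are left unwrapped.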
```python
quantized_model.summary()
```

```text
Model: "original"
_________________________________________________________________
 Layer (type)                          Output Shape              Param #
=================================================================
 nn_input (InputLayer)                 [(None, 28, 28)]          0
 reshape_0 (Reshape)                   (None, 28, 28, 1)         0
 quant_conv_0 (Conv2DQuantizeWrapper)  (None, 26, 26, 126)       1515
 relu_0 (ReLU)                         (None, 26, 26, 126)       0
 quant_conv_1 (Conv2DQuantizeWrapper)  (None, 24, 24, 64)        72771
 relu_1 (ReLU)                         (None, 24, 24, 64)        0
 quant_conv_2 (Conv2DQuantizeWrapper)  (None, 22, 22, 32)        18531
 relu_2 (ReLU)                         (None, 22, 22, 32)        0
 quant_conv_3 (Conv2DQuantizeWrapper)  (None, 20, 20, 16)        4659
 relu_3 (ReLU)                         (None, 20, 20, 16)        0
 quant_conv_4 (Conv2DQuantizeWrapper)  (None, 18, 18, 8)         1179
 relu_4 (ReLU)                         (None, 18, 18, 8)         0
 max_pool_0 (MaxPooling2D)             (None, 9, 9, 8)           0
 flatten_0 (Flatten)                   (None, 648)               0
 quant_dense_0 (DenseQuantizeWrapper)  (None, 100)               65103
 relu_5 (ReLU)                         (None, 100)               0
 quant_dense_1 (DenseQuantizeWrapper)  (None, 10)                1033
=================================================================
Total params: 164,791
Trainable params: 164,058
Non-trainable params: 733
_________________________________________________________________
```
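Comparing the two summaries shows that the trainable parameter count is unchanged (164,058) and that the wrappers add 733 non-trainable parameters. The arithmetic below uses only the numbers printed in the two summaries. The per-layer pattern it finds, 2 × output channels + 3, is consistent with per-channel range variables for the weights plus a few per-tensor variables for the layer input, but that interpretation is an assumption on our part, not something this guide states.

```python
# Per-layer parameter counts copied from the two model summaries above.
baseline = {"conv_0": 1260, "conv_1": 72640, "conv_2": 18464, "conv_3": 4624,
            "conv_4": 1160, "dense_0": 64900, "dense_1": 1010}
quantized = {"conv_0": 1515, "conv_1": 72771, "conv_2": 18531, "conv_3": 4659,
             "conv_4": 1179, "dense_0": 65103, "dense_1": 1033}
# Number of output channels (Conv2D filters / Dense units) of each wrapped layer.
out_channels = {"conv_0": 126, "conv_1": 64, "conv_2": 32, "conv_3": 16,
                "conv_4": 8, "dense_0": 100, "dense_1": 10}

# Extra (non-trainable) parameters added by each quantize wrapper.
extra = {name: quantized[name] - baseline[name] for name in baseline}
for name, n in extra.items():
    c = out_channels[name]
    print(f"{name:8s} +{n:4d} params   2 * {c} + 3 = {2 * c + 3}")

print("total extra:", sum(extra.values()))  # → total extra: 733
# The same pattern (2 * output channels + 3) holds for every wrapped layer.
print(all(extra[n] == 2 * out_channels[n] + 3 for n in extra))  # → True
```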
Compile the quantized model and check its accuracy immediately after the Q/DQ nodes are inserted.
```python
# Compile the quantized model (with a lower learning rate for fine-tuning).
quantized_model.compile(
    optimizer=tf.keras.optimizers.Adam(0.0001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Get accuracy immediately after Q/DQ nodes are inserted.
_, q_aware_model_accuracy = quantized_model.evaluate(test_images, test_labels, verbose=0)
print("Quantization test accuracy immediately after QDQ insertion:", q_aware_model_accuracy)
```

```text
Quantization test accuracy immediately after QDQ insertion: 0.883899986743927
```
The model's accuracy drops slightly (from 0.8887 to 0.8839 in this run) as soon as the Q/DQ nodes are inserted, so fine-tuning is needed to recover it.
Note
Because this is a very small model, the accuracy drop is small. For standard models such as ResNets, the drop immediately after Q/DQ insertion can be significant.
3. Fine-tune
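Because the quantized model behaves like the original Keras model, essentially the same training recipe can be used for fine-tuning. The main differences here are a lower learning rate (the quantized model was compiled with Adam at 0.0001, versus the Adam default of 0.001 used for the baseline), a smaller batch size (32 instead of 128), and fewer epochs.
We fine-tune the model for two epochs and evaluate it on the test dataset.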
```python
# Fine-tune the quantized model for 2 epochs.
quantized_model.fit(
    train_images, train_labels, batch_size=32, epochs=2, validation_split=0.1
)

# Get the quantized model's test accuracy after fine-tuning.
_, q_aware_model_accuracy_finetuned = quantized_model.evaluate(test_images, test_labels, verbose=0)
print("Quantization test accuracy after fine-tuning:", q_aware_model_accuracy_finetuned)
print("Baseline test accuracy (for reference):", baseline_model_accuracy)
```

```text
Epoch 1/2
1688/1688 [==============================] - 26s 15ms/step - loss: 0.1793 - accuracy: 0.9340 - val_loss: 0.2468 - val_accuracy: 0.9112
Epoch 2/2
1688/1688 [==============================] - 25s 15ms/step - loss: 0.1725 - accuracy: 0.9373 - val_loss: 0.2484 - val_accuracy: 0.9070
Quantization test accuracy after fine-tuning: 0.9075999855995178
Baseline test accuracy (for reference): 0.888700008392334
```
Note
If the original network is not fully converged, the additional training during fine-tuning can push the quantized model's accuracy above the original model's, as happens in this run.
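To put these accuracy differences in absolute terms: the Fashion MNIST test split contains 10,000 images, so each accuracy printed above corresponds to a count of correctly classified images. The helper `compare_to_baseline` below is a hypothetical illustration (not part of the toolkit or Keras) that converts the accuracies from this run into image counts and percentage-point changes relative to the baseline.

```python
TEST_SET_SIZE = 10_000  # number of images in the Fashion MNIST test split

# Test accuracies copied from the evaluation outputs in this guide.
ACCURACIES = {
    "baseline (FP32)": 0.888700008392334,
    "after Q/DQ insertion": 0.883899986743927,
    "after fine-tuning": 0.9075999855995178,
}


def compare_to_baseline(accuracies, baseline_key, n=TEST_SET_SIZE):
    """Return (name, correctly classified images, change in percentage points)."""
    base = accuracies[baseline_key]
    return [(name, round(acc * n), round((acc - base) * 100, 2))
            for name, acc in accuracies.items()]


for name, correct, delta in compare_to_baseline(ACCURACIES, "baseline (FP32)"):
    print(f"{name:<22} {correct:>5}/{TEST_SET_SIZE}  {delta:+.2f} pts")
```

With these numbers, Q/DQ insertion costs 48 correctly classified test images (-0.48 points), and after fine-tuning the quantized model classifies 189 more test images correctly than the baseline (+1.89 points).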
```python
# Save the quantized (Q/DQ) model in TF SavedModel format.
tf.keras.models.save_model(quantized_model, assets.example.int8_saved_model)

# Convert the quantized model to ONNX.
utils.convert_saved_model_to_onnx(
    saved_model_dir=assets.example.int8_saved_model,
    onnx_model_path=assets.example.int8_onnx_model,
)
tf.keras.backend.clear_session()
```

```text
WARNING:absl:Found untraced functions such as conv_0_layer_call_fn, conv_0_layer_call_and_return_conditional_losses, conv_1_layer_call_fn, conv_1_layer_call_and_return_conditional_losses, conv_2_layer_call_fn while saving (showing 5 of 14). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: GettingStarted/example/int8/saved_model/assets
INFO:tensorflow:Assets written to: GettingStarted/example/int8/saved_model/assets
ONNX conversion Done!
```
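A quick way to confirm that the exported INT8 ONNX file actually carries the simulated quantization is to count its Q/DQ nodes. This sketch assumes the `onnx` Python package is installed and that the Q/DQ nodes are exported as the standard ONNX `QuantizeLinear`/`DequantizeLinear` operators; the script and the `count_qdq_ops` function are hypothetical. Pass the path stored in `assets.example.int8_onnx_model` as the command-line argument.

```python
import sys
from collections import Counter

QDQ_OP_TYPES = ("QuantizeLinear", "DequantizeLinear")


def count_qdq_ops(op_types):
    """Count QuantizeLinear/DequantizeLinear entries in an iterable of ONNX op_type strings."""
    counts = Counter(op_types)
    return {op: counts[op] for op in QDQ_OP_TYPES}


if __name__ == "__main__" and len(sys.argv) > 1:
    import onnx  # only needed when inspecting a real model file

    model = onnx.load(sys.argv[1])   # path to the exported .onnx file
    onnx.checker.check_model(model)  # raises if the graph is malformed
    print(count_qdq_ops(node.op_type for node in model.graph.node))
```

Run against the FP32 ONNX file, both counts should be zero, since no Q/DQ nodes were inserted into the original model.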
In this example, accuracy loss due to quantization is recovered in just two epochs.
The NVIDIA® TensorFlow 2.x Quantization Toolkit provides an easy interface for creating quantized networks, which can then take advantage of INT8 inference on NVIDIA® GPUs using TensorRT™.