Full Network Quantization

In this tutorial, we will take a sample ResNet-like network and perform full network quantization.

Goal

  1. Take a ResNet-like model and train it on the CIFAR-10 dataset.

  2. Perform full model quantization.

  3. Fine-tune to recover model accuracy.

  4. Save both the original and quantized models and convert them to ONNX.


#
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os

import tensorflow as tf
from tensorflow_quantization import quantize_model, utils

import tiny_resnet

tf.keras.backend.clear_session()

# Create folders to save TF and ONNX models
assets = utils.CreateAssetsFolders(os.path.join(os.getcwd(), "tutorials"))
assets.add_folder("simple_network_quantize_full")
# Load CIFAR10 dataset
cifar10 = tf.keras.datasets.cifar10
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

# Normalize the input image so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
nn_model_original = tiny_resnet.model()
tf.keras.utils.plot_model(nn_model_original, to_file=assets.simple_network_quantize_full.fp32 + "/model.png")
You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model/model_to_dot to work.
# Train original classification model
nn_model_original.compile(
    optimizer=tiny_resnet.optimizer(lr=1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

nn_model_original.fit(
    train_images, train_labels, batch_size=32, epochs=10, validation_split=0.1
)

# Get baseline model accuracy
_, baseline_model_accuracy = nn_model_original.evaluate(
    test_images, test_labels, verbose=0
)
baseline_model_accuracy = round(100 * baseline_model_accuracy, 2)
print("Baseline FP32 model test accuracy: {}".format(baseline_model_accuracy))
Epoch 1/10
1407/1407 [==============================] - 16s 9ms/step - loss: 1.7653 - accuracy: 0.3622 - val_loss: 1.5516 - val_accuracy: 0.4552
Epoch 2/10
1407/1407 [==============================] - 13s 9ms/step - loss: 1.4578 - accuracy: 0.4783 - val_loss: 1.3877 - val_accuracy: 0.5042
Epoch 3/10
1407/1407 [==============================] - 13s 9ms/step - loss: 1.3499 - accuracy: 0.5193 - val_loss: 1.3066 - val_accuracy: 0.5342
Epoch 4/10
1407/1407 [==============================] - 13s 9ms/step - loss: 1.2736 - accuracy: 0.5486 - val_loss: 1.2636 - val_accuracy: 0.5550
Epoch 5/10
1407/1407 [==============================] - 13s 9ms/step - loss: 1.2101 - accuracy: 0.5732 - val_loss: 1.2121 - val_accuracy: 0.5670
Epoch 6/10
1407/1407 [==============================] - 12s 9ms/step - loss: 1.1559 - accuracy: 0.5946 - val_loss: 1.1753 - val_accuracy: 0.5844
Epoch 7/10
1407/1407 [==============================] - 13s 9ms/step - loss: 1.1079 - accuracy: 0.6101 - val_loss: 1.1143 - val_accuracy: 0.6076
Epoch 8/10
1407/1407 [==============================] - 13s 9ms/step - loss: 1.0660 - accuracy: 0.6272 - val_loss: 1.0965 - val_accuracy: 0.6158
Epoch 9/10
1407/1407 [==============================] - 13s 9ms/step - loss: 1.0271 - accuracy: 0.6392 - val_loss: 1.1100 - val_accuracy: 0.6122
Epoch 10/10
1407/1407 [==============================] - 13s 9ms/step - loss: 0.9936 - accuracy: 0.6514 - val_loss: 1.0646 - val_accuracy: 0.6304
Baseline FP32 model test accuracy: 61.65
# Save TF FP32 original model
tf.keras.models.save_model(nn_model_original, assets.simple_network_quantize_full.fp32_saved_model)

# Convert FP32 model to ONNX
utils.convert_saved_model_to_onnx(
    saved_model_dir=assets.simple_network_quantize_full.fp32_saved_model,
    onnx_model_path=assets.simple_network_quantize_full.fp32_onnx_model,
)

# Quantize model
q_nn_model = quantize_model(model=nn_model_original)
q_nn_model.compile(
    optimizer=tiny_resnet.optimizer(lr=1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
INFO:tensorflow:Assets written to: /home/nvidia/PycharmProjects/tensorrt_qat/docs/source/notebooks/tutorials/simple_network_quantize_full/fp32/saved_model/assets
WARNING:tensorflow:From /home/nvidia/PycharmProjects/tensorrt_qat/venv38_tf2.8_newPR/lib/python3.8/site-packages/tf2onnx/tf_loader.py:711: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
ONNX conversion Done!
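
As a quick sanity check, the exported FP32 ONNX model can be compared against the Keras model on a few test images. The sketch below is optional and assumes the onnxruntime package is installed (pip install onnxruntime); it is not part of the toolkit itself.

# Optional: compare ONNX and Keras outputs on a few test images
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(assets.simple_network_quantize_full.fp32_onnx_model)
input_name = sess.get_inputs()[0].name
sample = test_images[:8].astype(np.float32)

onnx_logits = sess.run(None, {input_name: sample})[0]
keras_logits = nn_model_original.predict(sample)

# The two sets of logits should agree up to small numerical differences
print("Max abs diff:", np.max(np.abs(onnx_logits - keras_logits)))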
_, q_model_accuracy = q_nn_model.evaluate(test_images, test_labels, verbose=0)
q_model_accuracy = round(100 * q_model_accuracy, 2)

print(
    "Test accuracy immediately after quantization: {}, diff: {}".format(
        q_model_accuracy, round(baseline_model_accuracy - q_model_accuracy, 2)
    )
)
Test accuracy immediately after quantization: 50.45, diff: 11.2
tf.keras.utils.plot_model(q_nn_model, to_file=assets.simple_network_quantize_full.int8 + "/model.png")
You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model/model_to_dot to work.
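
Since plot_model requires pydot and graphviz (not available in the run above), the inserted quantization wrappers can also be inspected with plain Keras APIs; a minimal alternative:

# Optional: list layer names/classes to see where quantization wrappers were inserted
for layer in q_nn_model.layers:
    print(layer.name, type(layer).__name__)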
# Fine-tune quantized model
fine_tune_epochs = 2

q_nn_model.fit(
    train_images,
    train_labels,
    batch_size=32,
    epochs=fine_tune_epochs,
    validation_split=0.1,
)

_, q_model_accuracy = q_nn_model.evaluate(test_images, test_labels, verbose=0)
q_model_accuracy = round(100 * q_model_accuracy, 2)
print(
    "Accuracy after fine-tuning for {} epochs: {}".format(
        fine_tune_epochs, q_model_accuracy
    )
)
print("Baseline accuracy (for reference): {}".format(baseline_model_accuracy))
Epoch 1/2
1407/1407 [==============================] - 27s 19ms/step - loss: 0.9625 - accuracy: 0.6630 - val_loss: 1.0430 - val_accuracy: 0.6420
Epoch 2/2
1407/1407 [==============================] - 25s 18ms/step - loss: 0.9315 - accuracy: 0.6758 - val_loss: 1.0717 - val_accuracy: 0.6336
Accuracy after fine-tuning for 2 epochs: 62.27
Baseline accuracy (for reference): 61.65
# Save TF INT8 quantized model
tf.keras.models.save_model(q_nn_model, assets.simple_network_quantize_full.int8_saved_model)

# Convert INT8 model to ONNX
utils.convert_saved_model_to_onnx(
    saved_model_dir=assets.simple_network_quantize_full.int8_saved_model,
    onnx_model_path=assets.simple_network_quantize_full.int8_onnx_model,
)

tf.keras.backend.clear_session()
WARNING:absl:Found untraced functions such as conv2d_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, conv2d_1_layer_call_fn, conv2d_1_layer_call_and_return_conditional_losses, conv2d_2_layer_call_fn while saving (showing 5 of 18). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /home/nvidia/PycharmProjects/tensorrt_qat/docs/source/notebooks/tutorials/simple_network_quantize_full/int8/saved_model/assets
ONNX conversion Done!

Note

ONNX files can be visualized with Netron (https://netron.app).
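
For example, a model can be opened in Netron's local viewer directly from Python; this sketch assumes the optional netron pip package is installed (pip install netron):

# Optional: serve the quantized ONNX model in Netron's local viewer
import netron
netron.start(assets.simple_network_quantize_full.int8_onnx_model)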