Partial Network Quantization

This example shows how the NVIDIA TensorFlow 2.x Quantization Toolkit can be used to quantize only a few layers in a TensorFlow 2.x model.

Goal

  1. Take a ResNet-like model and train it on the CIFAR-10 dataset.

  2. Quantize only layers named 'conv2d_2' and 'dense' in the model.

  3. Fine-tune to recover model accuracy.

  4. Save both the original and the quantized model, converting each to ONNX.

Background

The specific layers to quantize are passed to the quantize_model function as an object of the QuantizationSpec class, and the quantization mode is set to "partial" in the quantize_model call.

Adding layers with a single input to QuantizationSpec is straightforward. For multi-input layers, however, the toolkit also provides the flexibility to quantize only specific inputs.

For example, suppose a user wants to quantize the layers named conv2d_2 and add.

The default behavior of the Add layer class is NOT to quantize any input. None of the inputs to the Add layer is quantized when the following code snippet is used.

q_spec = QuantizationSpec()
layer_name = ['conv2d_2']
q_spec.add(name=layer_name)

q_model = quantize_model(model, quantization_mode="partial", quantization_spec=q_spec)

However, when a layer of the Add class is passed via the QuantizationSpec object without a quantization index, all of its inputs are quantized.

q_spec = QuantizationSpec()
layer_name = ['conv2d_2', 'add']
q_spec.add(name=layer_name)

q_model = quantize_model(model, quantization_mode="partial", quantization_spec=q_spec)

Code to quantize the input at a specific index (in this case, 1) of the add layer could look as follows.

q_spec = QuantizationSpec()
layer_name = ['conv2d_2', 'add']
layer_quantization_index = [None, [1]]
q_spec.add(name=layer_name, quantization_index=layer_quantization_index)

q_model = quantize_model(model, quantization_mode="partial", quantization_spec=q_spec)
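To see which input index corresponds to which tensor for a multi-input layer, the layer's inputs can be inspected directly. A minimal sketch using standard Keras attributes (the index order follows the layer's input list):

add_layer = model.get_layer('add')
for i, t in enumerate(add_layer.input):
    print(i, t.name)  # index usable as quantization_index, and the tensor feeding it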

Layer names can be found in the output of the model.summary() function for Functional and Sequential models. For subclassed models, use the KerasModelTraveller class from tensorflow_quantization.utils.
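For instance, a minimal sketch of listing layer names with the standard Keras API (tiny_resnet.model() is the model builder used later in this tutorial):

model = tiny_resnet.model()
model.summary()  # layer names appear in the "Layer (type)" column

# Equivalent programmatic listing
print([layer.name for layer in model.layers])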

Refer to the Python API documentation for more details.


#
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import tensorflow as tf
from tensorflow_quantization import quantize_model, QuantizationSpec
from tensorflow_quantization.custom_qdq_cases import ResNetV1QDQCase
import tiny_resnet
import os
from tensorflow_quantization import utils

tf.keras.backend.clear_session()

# Create folders to save TF and ONNX models
assets = utils.CreateAssetsFolders(os.path.join(os.getcwd(), "tutorials"))
assets.add_folder("simple_network_quantize_partial")
# Load CIFAR10 dataset
cifar10 = tf.keras.datasets.cifar10
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

# Normalize the input image so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
nn_model_original = tiny_resnet.model()
tf.keras.utils.plot_model(nn_model_original, to_file = assets.simple_network_quantize_partial.fp32 + "/model.png")
[Image: architecture diagram of the original FP32 model (model.png)]
# Train original classification model
nn_model_original.compile(
    optimizer=tiny_resnet.optimizer(lr=1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

nn_model_original.fit(
    train_images, train_labels, batch_size=32, epochs=10, validation_split=0.1
)

# Get baseline model accuracy
_, baseline_model_accuracy = nn_model_original.evaluate(
    test_images, test_labels, verbose=0
)
baseline_model_accuracy = round(100 * baseline_model_accuracy, 2)
print("Baseline FP32 model test accuracy:", baseline_model_accuracy)
Epoch 1/10
1407/1407 [==============================] - 18s 10ms/step - loss: 1.7617 - accuracy: 0.3615 - val_loss: 1.5624 - val_accuracy: 0.4300
Epoch 2/10
1407/1407 [==============================] - 14s 10ms/step - loss: 1.4876 - accuracy: 0.4645 - val_loss: 1.4242 - val_accuracy: 0.4762
Epoch 3/10
1407/1407 [==============================] - 14s 10ms/step - loss: 1.3737 - accuracy: 0.5092 - val_loss: 1.3406 - val_accuracy: 0.5202
Epoch 4/10
1407/1407 [==============================] - 14s 10ms/step - loss: 1.2952 - accuracy: 0.5396 - val_loss: 1.2768 - val_accuracy: 0.5398
Epoch 5/10
1407/1407 [==============================] - 14s 10ms/step - loss: 1.2370 - accuracy: 0.5649 - val_loss: 1.2466 - val_accuracy: 0.5560
Epoch 6/10
1407/1407 [==============================] - 14s 10ms/step - loss: 1.1857 - accuracy: 0.5812 - val_loss: 1.2052 - val_accuracy: 0.5718
Epoch 7/10
1407/1407 [==============================] - 14s 10ms/step - loss: 1.1442 - accuracy: 0.5972 - val_loss: 1.1836 - val_accuracy: 0.5786
Epoch 8/10
1407/1407 [==============================] - 14s 10ms/step - loss: 1.1079 - accuracy: 0.6091 - val_loss: 1.1356 - val_accuracy: 0.5978
Epoch 9/10
1407/1407 [==============================] - 14s 10ms/step - loss: 1.0702 - accuracy: 0.6220 - val_loss: 1.1244 - val_accuracy: 0.5940
Epoch 10/10
1407/1407 [==============================] - 14s 10ms/step - loss: 1.0354 - accuracy: 0.6368 - val_loss: 1.1019 - val_accuracy: 0.6108
Baseline FP32 model test accuracy: 61.46
# Save TF FP32 original model
tf.keras.models.save_model(nn_model_original, assets.simple_network_quantize_partial.fp32_saved_model)

# Convert FP32 model to ONNX
utils.convert_saved_model_to_onnx(saved_model_dir = assets.simple_network_quantize_partial.fp32_saved_model, onnx_model_path = assets.simple_network_quantize_partial.fp32_onnx_model)
INFO:tensorflow:Assets written to: /home/sagar/nvidia/2021/Customers/Waymo/QAT/tensorrt_qat/docs/source/notebooks/tutorials/simple_network_quantize_partial/fp32/saved_model/assets
WARNING:tensorflow:From /home/sagar/miniconda3/lib/python3.8/site-packages/tf2onnx/tf_loader.py:711: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
ONNX conversion Done!
# Quantize model
# 1.1 Create a QuantizationSpec object to quantize only the two layers named 'conv2d_2' and 'dense'
qspec = QuantizationSpec()
layer_name = ['conv2d_2', 'dense']
qspec.add(name=layer_name)
# 1.2 Call quantize model function
q_nn_model = quantize_model(
    model=nn_model_original, quantization_mode="partial", quantization_spec=qspec)

q_nn_model.compile(
    optimizer=tiny_resnet.optimizer(lr=1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
[I] Layer `conv2d` is not quantized. Partial quantization is enabled and layer name is not in user provided QuantizationSpec class object
[I] Layer `re_lu` is not quantized. Partial quantization is enabled and layer name is not in user provided QuantizationSpec class object
[I] Layer `conv2d_1` is not quantized. Partial quantization is enabled and layer name is not in user provided QuantizationSpec class object
[I] Layer `re_lu_1` is not quantized. Partial quantization is enabled and layer name is not in user provided QuantizationSpec class object
[I] Layer `re_lu_2` is not quantized. Partial quantization is enabled and layer name is not in user provided QuantizationSpec class object
[I] Layer `conv2d_3` is not quantized. Partial quantization is enabled and layer name is not in user provided QuantizationSpec class object
[I] Layer `re_lu_3` is not quantized. Partial quantization is enabled and layer name is not in user provided QuantizationSpec class object
[I] Layer `conv2d_4` is not quantized. Partial quantization is enabled and layer name is not in user provided QuantizationSpec class object
[I] Layer `re_lu_4` is not quantized. Partial quantization is enabled and layer name is not in user provided QuantizationSpec class object
[I] Layer `conv2d_5` is not quantized. Partial quantization is enabled and layer name is not in user provided QuantizationSpec class object
[I] Layer `conv2d_6` is not quantized. Partial quantization is enabled and layer name is not in user provided QuantizationSpec class object
[I] Layer `re_lu_5` is not quantized. Partial quantization is enabled and layer name is not in user provided QuantizationSpec class object
[I] Layer `flatten` is not quantized. Partial quantization is enabled and layer name is not in user provided QuantizationSpec class object
[I] Layer `re_lu_6` is not quantized. Partial quantization is enabled and layer name is not in user provided QuantizationSpec class object
[I] Layer `dense_1` is not quantized. Partial quantization is enabled and layer name is not in user provided QuantizationSpec class object
tf.keras.utils.plot_model(q_nn_model, to_file = assets.simple_network_quantize_partial.int8 + "/model.png")
[Image: architecture diagram of the partially quantized model (model.png)]
_, q_model_accuracy = q_nn_model.evaluate(test_images, test_labels, verbose=0)
q_model_accuracy = round(100 * q_model_accuracy, 2)
print(
    "Test accuracy immediately after quantization:{}, diff:{}".format(
        q_model_accuracy, (baseline_model_accuracy - q_model_accuracy)
    )
)
Test accuracy immediately after quantization:58.96, diff:2.5
# Fine-tune quantized model
fine_tune_epochs = 2
q_nn_model.fit(
    train_images,
    train_labels,
    batch_size=32,
    epochs=fine_tune_epochs,
    validation_split=0.1,
)
_, q_model_accuracy = q_nn_model.evaluate(test_images, test_labels, verbose=0)
q_model_accuracy = round(100 * q_model_accuracy, 2)
print(
    "Accuracy after fine tuning for {} epochs :{}".format(
        fine_tune_epochs, q_model_accuracy
    )
)
Epoch 1/2
1407/1407 [==============================] - 20s 14ms/step - loss: 1.0074 - accuracy: 0.6480 - val_loss: 1.0854 - val_accuracy: 0.6194
Epoch 2/2
1407/1407 [==============================] - 19s 14ms/step - loss: 0.9799 - accuracy: 0.6583 - val_loss: 1.0782 - val_accuracy: 0.6242
Accuracy after fine tuning for 2 epochs :62.0
# Save TF INT8 original model
tf.keras.models.save_model(q_nn_model, assets.simple_network_quantize_partial.int8_saved_model)

# Convert INT8 model to ONNX
utils.convert_saved_model_to_onnx(saved_model_dir = assets.simple_network_quantize_partial.int8_saved_model, onnx_model_path = assets.simple_network_quantize_partial.int8_onnx_model)

tf.keras.backend.clear_session()
WARNING:absl:Found untraced functions such as conv2d_2_layer_call_fn, conv2d_2_layer_call_and_return_conditional_losses, dense_layer_call_fn, dense_layer_call_and_return_conditional_losses while saving (showing 4 of 4). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /home/sagar/nvidia/2021/Customers/Waymo/QAT/tensorrt_qat/docs/source/notebooks/tutorials/simple_network_quantize_partial/int8/saved_model/assets
ONNX conversion Done!
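As a quick sanity check, the exported graph can be inspected to confirm that Q/DQ nodes appear only around the two selected layers. A minimal sketch, assuming the onnx Python package is installed:

import onnx

onnx_model = onnx.load(assets.simple_network_quantize_partial.int8_onnx_model)
qdq_nodes = [n for n in onnx_model.graph.node
             if n.op_type in ("QuantizeLinear", "DequantizeLinear")]
print("Q/DQ nodes in the exported graph:", len(qdq_nodes))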

Note

ONNX files can be visualized with Netron.
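For a programmatic launch, a minimal sketch assuming the netron pip package is installed:

import netron
netron.start(assets.simple_network_quantize_partial.int8_onnx_model)  # serves the model in a browser tab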