{ "cells": [ { "cell_type": "markdown", "source": [ "(tut_one)=\n", "\n", "# **Full Network Quantization** \n", "\n", "In this tutorial, we will take a sample network with ResNet-like network and perform ``full`` network quantization.\n", "\n", "\n", "```{eval-rst}\n", "\n", ".. admonition:: Goal\n", " :class: note\n", "\n", " #. Take a resnet-like model and train on cifar10 dataset.\n", " #. Perform full model quantization.\n", " #. Fine-tune to recover model accuracy.\n", " #. Save both original and quantized model while performing ONNX conversion.\n", "\n", "```\n", "---" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 1, "outputs": [], "source": [ "#\n", "# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n", "# SPDX-License-Identifier: Apache-2.0\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# http://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "#\n", "\n", "import tensorflow as tf\n", "from tensorflow_quantization import quantize_model\n", "import tiny_resnet\n", "from tensorflow_quantization import utils\n", "import os\n", "\n", "tf.keras.backend.clear_session()\n", "\n", "# Create folders to save TF and ONNX models\n", "assets = utils.CreateAssetsFolders(os.path.join(os.getcwd(), \"tutorials\"))\n", "assets.add_folder(\"simple_network_quantize_full\")" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 2, "outputs": [], "source": [ "# Load CIFAR10 dataset\n", "cifar10 = tf.keras.datasets.cifar10\n", "(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()\n", "\n", "# Normalize the input image so that each pixel value is between 0 and 1.\n", "train_images = train_images / 255.0\n", "test_images = test_images / 255.0" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 3, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model/model_to_dot to work.\n" ] } ], "source": [ "nn_model_original = tiny_resnet.model()\n", "tf.keras.utils.plot_model(nn_model_original, to_file = assets.simple_network_quantize_full.fp32 + \"/model.png\")" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 4, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/10\n", "1407/1407 [==============================] - 16s 9ms/step - loss: 1.7653 - accuracy: 0.3622 - val_loss: 1.5516 - val_accuracy: 0.4552\n", "Epoch 2/10\n", "1407/1407 [==============================] - 13s 9ms/step - loss: 1.4578 - accuracy: 0.4783 - val_loss: 1.3877 - val_accuracy: 0.5042\n", "Epoch 3/10\n", "1407/1407 [==============================] - 13s 9ms/step - loss: 1.3499 - accuracy: 0.5193 - val_loss: 1.3066 - val_accuracy: 
0.5342\n", "Epoch 4/10\n", "1407/1407 [==============================] - 13s 9ms/step - loss: 1.2736 - accuracy: 0.5486 - val_loss: 1.2636 - val_accuracy: 0.5550\n", "Epoch 5/10\n", "1407/1407 [==============================] - 13s 9ms/step - loss: 1.2101 - accuracy: 0.5732 - val_loss: 1.2121 - val_accuracy: 0.5670\n", "Epoch 6/10\n", "1407/1407 [==============================] - 12s 9ms/step - loss: 1.1559 - accuracy: 0.5946 - val_loss: 1.1753 - val_accuracy: 0.5844\n", "Epoch 7/10\n", "1407/1407 [==============================] - 13s 9ms/step - loss: 1.1079 - accuracy: 0.6101 - val_loss: 1.1143 - val_accuracy: 0.6076\n", "Epoch 8/10\n", "1407/1407 [==============================] - 13s 9ms/step - loss: 1.0660 - accuracy: 0.6272 - val_loss: 1.0965 - val_accuracy: 0.6158\n", "Epoch 9/10\n", "1407/1407 [==============================] - 13s 9ms/step - loss: 1.0271 - accuracy: 0.6392 - val_loss: 1.1100 - val_accuracy: 0.6122\n", "Epoch 10/10\n", "1407/1407 [==============================] - 13s 9ms/step - loss: 0.9936 - accuracy: 0.6514 - val_loss: 1.0646 - val_accuracy: 0.6304\n", "Baseline FP32 model test accuracy: 61.65\n" ] } ], "source": [ "# Train original classification model\n", "nn_model_original.compile(\n", " optimizer=tiny_resnet.optimizer(lr=1e-4),\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=[\"accuracy\"],\n", ")\n", "\n", "nn_model_original.fit(\n", " train_images, train_labels, batch_size=32, epochs=10, validation_split=0.1\n", ")\n", "\n", "# Get baseline model accuracy\n", "_, baseline_model_accuracy = nn_model_original.evaluate(\n", " test_images, test_labels, verbose=0\n", ")\n", "baseline_model_accuracy = round(100 * baseline_model_accuracy, 2)\n", "print(\"Baseline FP32 model test accuracy: {}\".format(baseline_model_accuracy))" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 5, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /home/nvidia/PycharmProjects/tensorrt_qat/docs/source/notebooks/tutorials/simple_network_quantize_full/fp32/saved_model/assets\n", "WARNING:tensorflow:From /home/nvidia/PycharmProjects/tensorrt_qat/venv38_tf2.8_newPR/lib/python3.8/site-packages/tf2onnx/tf_loader.py:711: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use `tf.compat.v1.graph_util.extract_sub_graph`\n", "ONNX conversion Done!\n" ] } ], "source": [ "# Save TF FP32 original model\n", "tf.keras.models.save_model(nn_model_original, assets.simple_network_quantize_full.fp32_saved_model)\n", "\n", "# Convert FP32 model to ONNX\n", "utils.convert_saved_model_to_onnx(saved_model_dir = assets.simple_network_quantize_full.fp32_saved_model, onnx_model_path = assets.simple_network_quantize_full.fp32_onnx_model)\n", "\n", "# Quantize model\n", "q_nn_model = quantize_model(model=nn_model_original)\n", "q_nn_model.compile(\n", " optimizer=tiny_resnet.optimizer(lr=1e-4),\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=[\"accuracy\"],\n", ")" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 6, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test accuracy immediately after quantization:50.45, diff:11.199999999999996\n" ] } ], "source": [ "_, q_model_accuracy = q_nn_model.evaluate(test_images, 
test_labels, verbose=0)\n", "q_model_accuracy = round(100 * q_model_accuracy, 2)\n", "\n", "print(\n", " \"Test accuracy immediately after quantization: {}, diff: {}\".format(\n", " q_model_accuracy, round(baseline_model_accuracy - q_model_accuracy, 2)\n", " )\n", ")" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 7, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model/model_to_dot to work.\n" ] } ], "source": [ "# Plot the quantized model architecture\n", "tf.keras.utils.plot_model(q_nn_model, to_file=assets.simple_network_quantize_full.int8 + \"/model.png\")" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 8, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2\n", "1407/1407 [==============================] - 27s 19ms/step - loss: 0.9625 - accuracy: 0.6630 - val_loss: 1.0430 - val_accuracy: 0.6420\n", "Epoch 2/2\n", "1407/1407 [==============================] - 25s 18ms/step - loss: 0.9315 - accuracy: 0.6758 - val_loss: 1.0717 - val_accuracy: 0.6336\n", "Accuracy after fine-tuning for 2 epochs: 62.27\n", "Baseline accuracy (for reference): 61.65\n" ] } ], "source": [ "# Fine-tune the quantized model to recover accuracy\n", "fine_tune_epochs = 2\n", "\n", "q_nn_model.fit(\n", " train_images,\n", " train_labels,\n", " batch_size=32,\n", " epochs=fine_tune_epochs,\n", " validation_split=0.1,\n", ")\n", "\n", "_, q_model_accuracy = q_nn_model.evaluate(test_images, test_labels, verbose=0)\n", "q_model_accuracy = round(100 * q_model_accuracy, 2)\n", "print(\n", " \"Accuracy after fine-tuning for {} epochs: {}\".format(\n", " fine_tune_epochs, q_model_accuracy\n", " )\n", ")\n", "print(\"Baseline accuracy (for reference): {}\".format(baseline_model_accuracy))" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 9, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:Found untraced functions such as conv2d_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, conv2d_1_layer_call_fn, conv2d_1_layer_call_and_return_conditional_losses, conv2d_2_layer_call_fn while saving (showing 5 of 18). 
These functions will not be directly callable after loading.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /home/nvidia/PycharmProjects/tensorrt_qat/docs/source/notebooks/tutorials/simple_network_quantize_full/int8/saved_model/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /home/nvidia/PycharmProjects/tensorrt_qat/docs/source/notebooks/tutorials/simple_network_quantize_full/int8/saved_model/assets\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "ONNX conversion Done!\n" ] } ], "source": [ "# Save the TF INT8 quantized model\n", "tf.keras.models.save_model(q_nn_model, assets.simple_network_quantize_full.int8_saved_model)\n", "\n", "# Convert the INT8 model to ONNX\n", "utils.convert_saved_model_to_onnx(saved_model_dir=assets.simple_network_quantize_full.int8_saved_model, onnx_model_path=assets.simple_network_quantize_full.int8_onnx_model)\n", "\n", "tf.keras.backend.clear_session()" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, 
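{ "cell_type": "markdown", "source": [ "The exported graph carries explicit QuantizeLinear/DequantizeLinear (QDQ) nodes, which is what lets TensorRT build a true INT8 engine from this ONNX file. The snippet below is a minimal sketch to confirm that the QDQ nodes survived the export; it assumes the `onnx` package (a dependency of `tf2onnx`) is importable:\n", "\n", "```python\n", "import onnx\n", "\n", "# Load the exported INT8 ONNX model and verify the graph is well-formed\n", "onnx_model = onnx.load(assets.simple_network_quantize_full.int8_onnx_model)\n", "onnx.checker.check_model(onnx_model)\n", "\n", "# The set of op types should include QuantizeLinear and DequantizeLinear\n", "print(sorted({node.op_type for node in onnx_model.graph.node}))\n", "```" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "markdown", "source": [ "```{note}\n", "ONNX files can be visualized with [Netron](https://netron.app/).\n", "```" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } } ], "metadata": { "interpreter": { "hash": "4442e1c252d743d7d1ab28567e302ebe8a15da81acb5d7e7894db75e10bdb29d" }, "kernelspec": { "display_name": "Python 3.8.10 ('base')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }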