{ "cells": [ { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "(GS_ETE)=\n", "\n", "# **Getting Started: End to End**\n", "\n", "The NVIDIA(R) TensorFlow 2.x Quantization toolkit provides a simple API to quantize a given Keras model. At a high level, Quantization Aware Training (QAT) is a three-step workflow, as shown below:\n", "\n", "```{eval-rst}\n", ".. mermaid::\n", "\n", " flowchart LR\n", " id1(Pre-trained model) --> id2(Quantize) --> id3(Fine-tune)\n", "\n", "```\n", "Initially, the network is trained on the target dataset until fully converged. The quantization step consists of inserting Q/DQ nodes in the pre-trained network to simulate quantization during training. Note that simply adding Q/DQ nodes will result in reduced accuracy since the quantization parameters are not yet updated for the given model. The network is then re-trained for a few epochs to recover accuracy in a step called \"fine-tuning\".\n", "\n", "```{eval-rst}\n", "\n", ".. admonition:: Goal\n", " :class: note\n", "\n", " #. Train a simple network on the Fashion MNIST dataset and save it as the baseline model.\n", " #. Quantize the pre-trained baseline network.\n", " #. Fine-tune the quantized network to recover accuracy and save it as the QAT model.\n", "\n", "```\n", "---\n", "\n", "## 1. Train\n", "Import the required libraries and create a simple network with convolution and dense layers." 
] }, { "cell_type": "code", "execution_count": 37, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"original\"\n", "_________________________________________________________________\n", " Layer (type) Output Shape Param # \n", "=================================================================\n", " nn_input (InputLayer) [(None, 28, 28)] 0 \n", " \n", " reshape_0 (Reshape) (None, 28, 28, 1) 0 \n", " \n", " conv_0 (Conv2D) (None, 26, 26, 126) 1260 \n", " \n", " relu_0 (ReLU) (None, 26, 26, 126) 0 \n", " \n", " conv_1 (Conv2D) (None, 24, 24, 64) 72640 \n", " \n", " relu_1 (ReLU) (None, 24, 24, 64) 0 \n", " \n", " conv_2 (Conv2D) (None, 22, 22, 32) 18464 \n", " \n", " relu_2 (ReLU) (None, 22, 22, 32) 0 \n", " \n", " conv_3 (Conv2D) (None, 20, 20, 16) 4624 \n", " \n", " relu_3 (ReLU) (None, 20, 20, 16) 0 \n", " \n", " conv_4 (Conv2D) (None, 18, 18, 8) 1160 \n", " \n", " relu_4 (ReLU) (None, 18, 18, 8) 0 \n", " \n", " max_pool_0 (MaxPooling2D) (None, 9, 9, 8) 0 \n", " \n", " flatten_0 (Flatten) (None, 648) 0 \n", " \n", " dense_0 (Dense) (None, 100) 64900 \n", " \n", " relu_5 (ReLU) (None, 100) 0 \n", " \n", " dense_1 (Dense) (None, 10) 1010 \n", " \n", "=================================================================\n", "Total params: 164,058\n", "Trainable params: 164,058\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ], "source": [ "import tensorflow as tf\n", "from tensorflow_quantization import quantize_model\n", "from tensorflow_quantization import utils\n", "\n", "assets = utils.CreateAssetsFolders(\"GettingStarted\")\n", "assets.add_folder(\"example\")\n", "\n", "def simple_net():\n", " \"\"\"\n", " Return a simple neural network.\n", " \"\"\"\n", " input_img = tf.keras.layers.Input(shape=(28, 28), name=\"nn_input\")\n", " x = tf.keras.layers.Reshape(target_shape=(28, 28, 1), name=\"reshape_0\")(input_img)\n", " x = 
tf.keras.layers.Conv2D(filters=126, kernel_size=(3, 3), name=\"conv_0\")(x)\n", " x = tf.keras.layers.ReLU(name=\"relu_0\")(x)\n", " x = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), name=\"conv_1\")(x)\n", " x = tf.keras.layers.ReLU(name=\"relu_1\")(x)\n", " x = tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), name=\"conv_2\")(x)\n", " x = tf.keras.layers.ReLU(name=\"relu_2\")(x)\n", " x = tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3), name=\"conv_3\")(x)\n", " x = tf.keras.layers.ReLU(name=\"relu_3\")(x)\n", " x = tf.keras.layers.Conv2D(filters=8, kernel_size=(3, 3), name=\"conv_4\")(x)\n", " x = tf.keras.layers.ReLU(name=\"relu_4\")(x)\n", " x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), name=\"max_pool_0\")(x)\n", " x = tf.keras.layers.Flatten(name=\"flatten_0\")(x)\n", " x = tf.keras.layers.Dense(100, name=\"dense_0\")(x)\n", " x = tf.keras.layers.ReLU(name=\"relu_5\")(x)\n", " x = tf.keras.layers.Dense(10, name=\"dense_1\")(x)\n", " return tf.keras.Model(input_img, x, name=\"original\")\n", "\n", "# create model\n", "model = simple_net()\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Load the Fashion MNIST dataset, which comes already split into train and test sets." ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Load Fashion MNIST dataset\n", "mnist = tf.keras.datasets.fashion_mnist\n", "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n", "\n", "# Normalize the input images so that each pixel value is between 0 and 1.\n", "train_images = train_images / 255.0\n", "test_images = test_images / 255.0 " ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Compile the model and train for five epochs." 
] }, { "cell_type": "code", "execution_count": 39, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n", "422/422 [==============================] - 4s 8ms/step - loss: 0.5639 - accuracy: 0.7920 - val_loss: 0.4174 - val_accuracy: 0.8437\n", "Epoch 2/5\n", "422/422 [==============================] - 3s 8ms/step - loss: 0.3619 - accuracy: 0.8696 - val_loss: 0.4134 - val_accuracy: 0.8433\n", "Epoch 3/5\n", "422/422 [==============================] - 3s 8ms/step - loss: 0.3165 - accuracy: 0.8855 - val_loss: 0.3137 - val_accuracy: 0.8812\n", "Epoch 4/5\n", "422/422 [==============================] - 3s 8ms/step - loss: 0.2787 - accuracy: 0.8964 - val_loss: 0.2943 - val_accuracy: 0.8890\n", "Epoch 5/5\n", "422/422 [==============================] - 3s 8ms/step - loss: 0.2552 - accuracy: 0.9067 - val_loss: 0.2857 - val_accuracy: 0.8952\n", "Baseline test accuracy: 0.888700008392334\n" ] } ], "source": [ "# Train original classification model\n", "model.compile(\n", " optimizer=\"adam\",\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=[\"accuracy\"],\n", ")\n", "\n", "model.fit(\n", " train_images, train_labels, batch_size=128, epochs=5, validation_split=0.1\n", ")\n", "\n", "# get baseline model accuracy\n", "_, baseline_model_accuracy = model.evaluate(\n", " test_images, test_labels, verbose=0\n", ")\n", "print(\"Baseline test accuracy:\", baseline_model_accuracy)" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: GettingStarted/example/fp32/saved_model/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: GettingStarted/example/fp32/saved_model/assets\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "ONNX conversion Done!\n" ] } ], 
"source": [ "# save TF FP32 original model\n", "tf.keras.models.save_model(model, assets.example.fp32_saved_model)\n", "\n", "# Convert FP32 model to ONNX\n", "utils.convert_saved_model_to_onnx(saved_model_dir = assets.example.fp32_saved_model, onnx_model_path = assets.example.fp32_onnx_model)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## 2. Quantize\n", "\n", "Full model quantization is the most basic quantization mode someone can follow. In this mode, Q/DQ nodes are inserted in all supported keras layers, with a single function call:" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Quantize model\n", "quantized_model = quantize_model(model)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Keras model summary shows all supported layers wrapped into QDQ wrapper class." ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"original\"\n", "_________________________________________________________________\n", " Layer (type) Output Shape Param # \n", "=================================================================\n", " nn_input (InputLayer) [(None, 28, 28)] 0 \n", " \n", " reshape_0 (Reshape) (None, 28, 28, 1) 0 \n", " \n", " quant_conv_0 (Conv2DQuantiz (None, 26, 26, 126) 1515 \n", " eWrapper) \n", " \n", " relu_0 (ReLU) (None, 26, 26, 126) 0 \n", " \n", " quant_conv_1 (Conv2DQuantiz (None, 24, 24, 64) 72771 \n", " eWrapper) \n", " \n", " relu_1 (ReLU) (None, 24, 24, 64) 0 \n", " \n", " quant_conv_2 (Conv2DQuantiz (None, 22, 22, 32) 18531 \n", " eWrapper) \n", " \n", " relu_2 (ReLU) (None, 22, 22, 32) 0 \n", " \n", " quant_conv_3 (Conv2DQuantiz (None, 20, 20, 16) 4659 \n", " eWrapper) \n", " \n", " relu_3 (ReLU) (None, 20, 20, 16) 0 \n", " \n", " quant_conv_4 
(Conv2DQuantiz (None, 18, 18, 8) 1179 \n", " eWrapper) \n", " \n", " relu_4 (ReLU) (None, 18, 18, 8) 0 \n", " \n", " max_pool_0 (MaxPooling2D) (None, 9, 9, 8) 0 \n", " \n", " flatten_0 (Flatten) (None, 648) 0 \n", " \n", " quant_dense_0 (DenseQuantiz (None, 100) 65103 \n", " eWrapper) \n", " \n", " relu_5 (ReLU) (None, 100) 0 \n", " \n", " quant_dense_1 (DenseQuantiz (None, 10) 1033 \n", " eWrapper) \n", " \n", "=================================================================\n", "Total params: 164,791\n", "Trainable params: 164,058\n", "Non-trainable params: 733\n", "_________________________________________________________________\n" ] } ], "source": [ "quantized_model.summary()" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Let's check the quantized model's accuracy immediately after Q/DQ nodes are inserted." ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Quantization test accuracy immediately after QDQ insertion: 0.883899986743927\n" ] } ], "source": [ "# Compile quantized model\n", "quantized_model.compile(\n", " optimizer=tf.keras.optimizers.Adam(0.0001),\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=[\"accuracy\"],\n", ")\n", "# Get accuracy immediately after QDQ nodes are inserted.\n", "_, q_aware_model_accuracy = quantized_model.evaluate(test_images, test_labels, verbose=0)\n", "print(\"Quantization test accuracy immediately after QDQ insertion:\", q_aware_model_accuracy)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "The model's accuracy decreases a bit as soon as Q/DQ nodes are inserted, requiring fine-tuning to recover it.\n", "\n", "```{note}\n", "\n", "Since this is a very small model, accuracy drop is small. 
For standard models like ResNets, the accuracy drop immediately after QDQ insertion can be significant.\n", "\n", "```\n", "\n", "## 3. Fine-tune\n", "Since the quantized model behaves similarly to the original Keras model, the same training recipe can be used for fine-tuning as well.\n", "\n", "We fine-tune the model for two epochs and evaluate it on the test dataset." ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2\n", "1688/1688 [==============================] - 26s 15ms/step - loss: 0.1793 - accuracy: 0.9340 - val_loss: 0.2468 - val_accuracy: 0.9112\n", "Epoch 2/2\n", "1688/1688 [==============================] - 25s 15ms/step - loss: 0.1725 - accuracy: 0.9373 - val_loss: 0.2484 - val_accuracy: 0.9070\n", "Quantization test accuracy after fine-tuning: 0.9075999855995178\n", "Baseline test accuracy (for reference): 0.888700008392334\n" ] } ], "source": [ "# fine tune quantized model for 2 epochs.\n", "quantized_model.fit(\n", " train_images, train_labels, batch_size=32, epochs=2, validation_split=0.1\n", ")\n", "# Get quantized accuracy\n", "_, q_aware_model_accuracy_finetuned = quantized_model.evaluate(test_images, test_labels, verbose=0)\n", "print(\"Quantization test accuracy after fine-tuning:\", q_aware_model_accuracy_finetuned)\n", "print(\"Baseline test accuracy (for reference):\", baseline_model_accuracy)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "```{note}\n", "\n", "If the original network was not fully converged, the fine-tuned model's accuracy can surpass the original model's accuracy.\n", "\n", "```" ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:Found untraced functions such as conv_0_layer_call_fn, 
conv_0_layer_call_and_return_conditional_losses, conv_1_layer_call_fn, conv_1_layer_call_and_return_conditional_losses, conv_2_layer_call_fn while saving (showing 5 of 14). 