{ "cells": [ { "cell_type": "markdown", "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } }, "source": [ "# **ResNet50 V1**\n", "\n", "This assumes that our toolkits and its base requirements have been met, including access to the ImageNet dataset. Please refer to [\"Requirements\"](https://gitlab-master.nvidia.com/sagshelke/tensorrt_qat/-/tree/main/examples#requirements) in the `examples` folder." ] }, { "cell_type": "markdown", "source": [ "## 1. Initial settings" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 10, "outputs": [], "source": [ "import os\n", "import tensorflow as tf\n", "from tensorflow_quantization.quantize import quantize_model\n", "from tensorflow_quantization.custom_qdq_cases import ResNetV1QDQCase\n", "from tensorflow_quantization.utils import convert_saved_model_to_onnx" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "HYPERPARAMS = {\n", " \"tfrecord_data_dir\": \"/media/Data/ImageNet/train-val-tfrecord\",\n", " \"batch_size\": 64,\n", " \"epochs\": 2,\n", " \"steps_per_epoch\": 500,\n", " \"train_data_size\": None,\n", " \"val_data_size\": None,\n", " \"save_root_dir\": \"./weights/resnet_50v1_jupyter\"\n", "}" ] }, { "cell_type": "markdown", "source": [ "### Load data" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "from examples.data.data_loader import load_data\n", "train_batches, val_batches = load_data(HYPERPARAMS, model_name=\"resnet_v1\")" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } }, "source": [ "## 2. 
Baseline model" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } }, "source": [ "### Instantiate" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "model = tf.keras.applications.ResNet50(\n", " include_top=True,\n", " weights=\"imagenet\",\n", " classes=1000,\n", " classifier_activation=\"softmax\",\n", ")" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } }, "source": [ "### Evaluate" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "781/781 [==============================] - 41s 51ms/step - loss: 1.0481 - accuracy: 0.7504\n", "Baseline val accuracy: 75.044%\n" ] } ], "source": [ "def compile_model(model, lr=0.001):\n", " model.compile(\n", " optimizer=tf.keras.optimizers.SGD(learning_rate=lr),\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n", " metrics=[\"accuracy\"],\n", " )\n", "\n", "compile_model(model)\n", "_, baseline_model_accuracy = model.evaluate(val_batches)\n", "print(\"Baseline val accuracy: {:.3f}%\".format(baseline_model_accuracy*100))" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } }, "source": [ "### Save and convert to ONNX" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: ./weights/resnet_50v1_jupyter/saved_model_baseline/assets\n", "ONNX conversion Done!\n" ] } ], "source": [ "model_save_path = os.path.join(HYPERPARAMS[\"save_root_dir\"], \"saved_model_baseline\")\n", "model.save(model_save_path)\n", "convert_saved_model_to_onnx(saved_model_dir=model_save_path,\n", " onnx_model_path=model_save_path + \".onnx\")" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } }, "source": [ "## 3. 
Quantization-Aware Training model" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } }, "source": [ "### Quantize" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "q_model = quantize_model(model, custom_qdq_cases=[ResNetV1QDQCase()])" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } }, "source": [ "### Fine-tune" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2\n", "500/500 [==============================] - 425s 838ms/step - loss: 0.4075 - accuracy: 0.8898 - val_loss: 1.0451 - val_accuracy: 0.7497\n", "Epoch 2/2\n", "500/500 [==============================] - 420s 840ms/step - loss: 0.3960 - accuracy: 0.8918 - val_loss: 1.0392 - val_accuracy: 0.7511\n" ] }, { "data": { "text/plain": "" }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "compile_model(q_model)\n", "q_model.fit(\n", " train_batches,\n", " validation_data=val_batches,\n", " batch_size=HYPERPARAMS[\"batch_size\"],\n", " steps_per_epoch=HYPERPARAMS[\"steps_per_epoch\"],\n", " epochs=HYPERPARAMS[\"epochs\"]\n", ")" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } }, "source": [ "### Evaluate" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "781/781 [==============================] - 179s 229ms/step - loss: 1.0392 - accuracy: 0.7511\n", "QAT val accuracy: 75.114%\n" ] } ], "source": [ "_, qat_model_accuracy = q_model.evaluate(val_batches)\n", "print(\"QAT val accuracy: {:.3f}%\".format(qat_model_accuracy*100))" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } }, "source": [ "### Save and convert to ONNX" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:Found untraced functions such as conv1_conv_layer_call_fn, conv1_conv_layer_call_and_return_conditional_losses, conv2_block1_1_conv_layer_call_fn, conv2_block1_1_conv_layer_call_and_return_conditional_losses, conv2_block1_2_conv_layer_call_fn while saving (showing 5 of 140). These functions will not be directly callable after loading.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: ./weights/resnet_50v1_jupyter/saved_model_qat/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: ./weights/resnet_50v1_jupyter/saved_model_qat/assets\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "ONNX conversion Done!\n" ] } ], "source": [ "q_model_save_path = os.path.join(HYPERPARAMS[\"save_root_dir\"], \"saved_model_qat\")\n", "q_model.save(q_model_save_path)\n", "convert_saved_model_to_onnx(saved_model_dir=q_model_save_path,\n", " onnx_model_path=q_model_save_path + \".onnx\")" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } }, "source": [ "## 4. 
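{ "cell_type": "markdown", "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } }, "source": [ "Optionally, the sketch below inspects the exported QAT graph and counts the `QuantizeLinear`/`DequantizeLinear` (Q/DQ) nodes that quantization inserted; the baseline export should contain none of these. It assumes the `onnx` package is installed." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import onnx\n", "from collections import Counter\n", "\n", "# Count Q/DQ nodes in the exported QAT graph.\n", "onnx_graph = onnx.load(q_model_save_path + \".onnx\").graph\n", "op_counts = Counter(node.op_type for node in onnx_graph.node)\n", "print(\"QuantizeLinear nodes:\", op_counts[\"QuantizeLinear\"])\n", "print(\"DequantizeLinear nodes:\", op_counts[\"DequantizeLinear\"])" ] },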
{ "cell_type": "markdown", "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } }, "source": [ "## 4. QAT vs Baseline comparison" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Baseline vs QAT: 75.044% vs 75.114%\n", "Accuracy difference of +0.070%\n" ] } ], "source": [ "print(\"Baseline vs QAT: {:.3f}% vs {:.3f}%\".format(baseline_model_accuracy*100, qat_model_accuracy*100))\n", "\n", "acc_diff = (qat_model_accuracy - baseline_model_accuracy)*100\n", "acc_diff_sign = \"\" if acc_diff == 0 else (\"-\" if acc_diff < 0 else \"+\")\n", "print(\"Accuracy difference of {}{:.3f}%\".format(acc_diff_sign, abs(acc_diff)))" ] }, { "cell_type": "markdown", "source": [ "```{note}\n", "\n", "For the full workflow, including TensorRT™ deployment, please refer to [examples/resnet](https://github.com/NVIDIA/TensorRT/tree/main/tools/tensorflow-quantization/examples/resnet).\n", "\n", "```" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 1 }