ResNet50 V1

This guide assumes that the toolkit and its base requirements have been installed, including access to the ImageNet dataset. Please refer to “Requirements” in the examples folder.

1. Initial settings

import os
import tensorflow as tf
from tensorflow_quantization.quantize import quantize_model
from tensorflow_quantization.custom_qdq_cases import ResNetV1QDQCase
from tensorflow_quantization.utils import convert_saved_model_to_onnx
HYPERPARAMS = {
    "tfrecord_data_dir": "/media/Data/ImageNet/train-val-tfrecord",  # ImageNet train/val TFRecords
    "batch_size": 64,
    "epochs": 2,
    "steps_per_epoch": 500,
    "train_data_size": None,  # None = use all available training data
    "val_data_size": None,    # None = use all available validation data
    "save_root_dir": "./weights/resnet_50v1_jupyter"  # output dir for saved models and ONNX files
}
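
The save directory may not exist yet; creating it up front avoids errors when saving models later. A one-line sketch using the standard library:

# Create the output directory tree if it does not already exist.
os.makedirs(HYPERPARAMS["save_root_dir"], exist_ok=True)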

Load data

from examples.data.data_loader import load_data
train_batches, val_batches = load_data(HYPERPARAMS, model_name="resnet_v1")
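
As a quick sanity check, you can peek at one batch before training. This sketch assumes load_data returns batched tf.data.Dataset objects of (image, label) pairs:

# Expect image batches of shape (64, 224, 224, 3) for ResNet50's default input size.
for images, labels in train_batches.take(1):
    print("images:", images.shape, "labels:", labels.shape)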

2. Baseline model

Instantiate

model = tf.keras.applications.ResNet50(
    include_top=True,
    weights="imagenet",
    classes=1000,
    classifier_activation="softmax",
)
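
A quick check that the pretrained weights loaded as expected; ResNet50 with the top classifier has roughly 25.6M parameters:

# Print layer and parameter counts as a lightweight sanity check.
print("layers: {}, parameters: {:,}".format(len(model.layers), model.count_params()))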

Evaluate

def compile_model(model, lr=0.001):
    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=lr),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=["accuracy"],
    )

compile_model(model)
_, baseline_model_accuracy = model.evaluate(val_batches)
print("Baseline val accuracy: {:.3f}%".format(baseline_model_accuracy*100))
781/781 [==============================] - 41s 51ms/step - loss: 1.0481 - accuracy: 0.7504
Baseline val accuracy: 75.044%
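
A full validation pass takes 781 steps; while iterating, evaluate accepts a steps argument to score only a subset, at the cost of a noisier estimate:

# Quick estimate on the first 100 validation batches instead of all 781.
_, approx_accuracy = model.evaluate(val_batches, steps=100)
print("Approximate val accuracy: {:.3f}%".format(approx_accuracy * 100))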

Save and convert to ONNX

model_save_path = os.path.join(HYPERPARAMS["save_root_dir"], "saved_model_baseline")
model.save(model_save_path)
convert_saved_model_to_onnx(saved_model_dir=model_save_path,
                            onnx_model_path=model_save_path + ".onnx")
INFO:tensorflow:Assets written to: ./weights/resnet_50v1_jupyter/saved_model_baseline/assets
ONNX conversion Done!
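
If the onnx package is installed, the exported file can be validated before deployment. A minimal sketch:

import onnx

# Load the exported model and run ONNX's structural validator.
onnx_model = onnx.load(model_save_path + ".onnx")
onnx.checker.check_model(onnx_model)
print("ONNX model is structurally valid")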

3. Quantization-Aware Training model

Quantize

q_model = quantize_model(model, custom_qdq_cases=[ResNetV1QDQCase()])
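
quantize_model returns a regular Keras model in which eligible layers are wrapped with quantization logic, with ResNetV1QDQCase adjusting Q/DQ placement for ResNet's residual connections. To confirm that wrapping took place, you can scan the layer types; the "quantize" name filter below is an assumption about the wrapper class names, so check q_model.summary() for the actual names:

# Count layers whose class name suggests a quantization wrapper.
# NOTE: the "quantize" substring is an assumed naming convention.
wrapped = [layer.name for layer in q_model.layers
           if "quantize" in type(layer).__name__.lower()]
print("{} quantized layers out of {}".format(len(wrapped), len(q_model.layers)))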

Fine-tune

compile_model(q_model)
# train_batches is already a batched dataset, so fit() needs no batch_size argument.
q_model.fit(
    train_batches,
    validation_data=val_batches,
    steps_per_epoch=HYPERPARAMS["steps_per_epoch"],
    epochs=HYPERPARAMS["epochs"]
)
Epoch 1/2
500/500 [==============================] - 425s 838ms/step - loss: 0.4075 - accuracy: 0.8898 - val_loss: 1.0451 - val_accuracy: 0.7497
Epoch 2/2
500/500 [==============================] - 420s 840ms/step - loss: 0.3960 - accuracy: 0.8918 - val_loss: 1.0392 - val_accuracy: 0.7511
<keras.callbacks.History at 0x7f9cec1e60d0>
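
The History repr above is fit's return value; capturing it lets you read per-epoch metrics programmatically instead of parsing console logs. A sketch of the same call with the result kept:

# Keep the History object to access per-epoch metrics after training.
history = q_model.fit(
    train_batches,
    validation_data=val_batches,
    steps_per_epoch=HYPERPARAMS["steps_per_epoch"],
    epochs=HYPERPARAMS["epochs"]
)
print("Final val accuracy: {:.3f}%".format(history.history["val_accuracy"][-1] * 100))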

Evaluate

_, qat_model_accuracy = q_model.evaluate(val_batches)
print("QAT val accuracy: {:.3f}%".format(qat_model_accuracy*100))
781/781 [==============================] - 179s 229ms/step - loss: 1.0392 - accuracy: 0.7511
QAT val accuracy: 75.114%

Save and convert to ONNX

q_model_save_path = os.path.join(HYPERPARAMS["save_root_dir"], "saved_model_qat")
q_model.save(q_model_save_path)
convert_saved_model_to_onnx(saved_model_dir=q_model_save_path,
                            onnx_model_path=q_model_save_path + ".onnx")
WARNING:absl:Found untraced functions such as conv1_conv_layer_call_fn, conv1_conv_layer_call_and_return_conditional_losses, conv2_block1_1_conv_layer_call_fn, conv2_block1_1_conv_layer_call_and_return_conditional_losses, conv2_block1_2_conv_layer_call_fn while saving (showing 5 of 140). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: ./weights/resnet_50v1_jupyter/saved_model_qat/assets
ONNX conversion Done!
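
To verify that the export preserved the quantization information, you can count QuantizeLinear/DequantizeLinear ops in the ONNX graph; these are the standard ONNX Q/DQ operators that TensorRT consumes:

import onnx

# The QAT export should contain QuantizeLinear/DequantizeLinear (Q/DQ) node pairs.
qat_onnx = onnx.load(q_model_save_path + ".onnx")
qdq_nodes = [node for node in qat_onnx.graph.node
             if node.op_type in ("QuantizeLinear", "DequantizeLinear")]
print("{} Q/DQ nodes found in the exported graph".format(len(qdq_nodes)))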

4. QAT vs Baseline comparison

print("Baseline vs QAT: {:.3f}% vs {:.3f}%".format(baseline_model_accuracy*100, qat_model_accuracy*100))

acc_diff = (qat_model_accuracy - baseline_model_accuracy)*100
acc_diff_sign = "" if acc_diff == 0 else ("-" if acc_diff < 0 else "+")
print("Accuracy difference of {}{:.3f}%".format(acc_diff_sign, abs(acc_diff)))
Baseline vs QAT: 75.044% vs 75.114%
Accuracy difference of +0.070%
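
In an automated pipeline, the same comparison can gate the run. A sketch assuming a 0.5% tolerance; pick a threshold that fits your accuracy budget:

# Fail fast if QAT costs more accuracy than the (assumed) tolerance allows.
TOLERANCE = 0.5  # percentage points; tune per use case
assert acc_diff >= -TOLERANCE, \
    "QAT accuracy dropped by {:.3f}% (tolerance: {}%)".format(abs(acc_diff), TOLERANCE)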

Note

For the full workflow, including TensorRT™ deployment, please refer to examples/resnet.