ResNet50 V1
This example assumes that the toolkit and its base requirements are installed, including access to the ImageNet dataset. Please refer to “Requirements” in the examples folder.
1. Initial settings
import os
import tensorflow as tf
from tensorflow_quantization.quantize import quantize_model
from tensorflow_quantization.custom_qdq_cases import ResNetV1QDQCase
from tensorflow_quantization.utils import convert_saved_model_to_onnx
HYPERPARAMS = {
    "tfrecord_data_dir": "/media/Data/ImageNet/train-val-tfrecord",  # ImageNet TFRecords; adjust to your setup
    "batch_size": 64,
    "epochs": 2,
    "steps_per_epoch": 500,  # QAT fine-tunes on 500 batches per epoch, not full passes over the data
    "train_data_size": None,
    "val_data_size": None,
    "save_root_dir": "./weights/resnet_50v1_jupyter"
}
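Because `steps_per_epoch` caps each epoch, QAT here fine-tunes on a small slice of the training set rather than on full passes. The arithmetic below makes that budget explicit, assuming `load_data` batches at `batch_size`. The ImageNet-1k training-set size (1,281,167 images) is an outside figure used for illustration, not something this notebook reads from the data.

```python
# Fine-tuning budget implied by HYPERPARAMS (values copied from the dict above).
batch_size = 64
steps_per_epoch = 500
epochs = 2

# Assumption: the standard ImageNet-1k (ILSVRC2012) training split.
IMAGENET_TRAIN_IMAGES = 1_281_167

images_per_epoch = steps_per_epoch * batch_size   # images seen in one QAT epoch
total_images = images_per_epoch * epochs           # images seen over the whole QAT run
fraction = total_images / IMAGENET_TRAIN_IMAGES    # share of one full pass over the train split

print(f"{images_per_epoch} images/epoch, {total_images} total, "
      f"{fraction:.1%} of one ImageNet epoch")
```

In other words, the QAT run below touches roughly one twentieth of a single ImageNet epoch.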
Load data
from examples.data.data_loader import load_data
train_batches, val_batches = load_data(HYPERPARAMS, model_name="resnet_v1")
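The loader is called with `model_name="resnet_v1"`, which suggests model-specific preprocessing. Keras' own `tf.keras.applications.resnet50.preprocess_input` uses "caffe"-style preprocessing: swap RGB to BGR, then subtract the per-channel ImageNet mean without scaling. Whether `load_data` applies exactly this is an assumption (check `examples/data/data_loader.py`); the per-pixel sketch below, with the hypothetical helper `caffe_preprocess_pixel`, only shows what that transform does.

```python
# Assumption: ResNet50 V1 inputs use Keras' "caffe"-style preprocessing.
# Per-channel ImageNet means in BGR order, as used by that mode.
IMAGENET_MEAN_BGR = (103.939, 116.779, 123.68)


def caffe_preprocess_pixel(rgb):
    """Hypothetical helper: map one RGB pixel in [0, 255] to zero-centred BGR, without scaling."""
    r, g, b = rgb
    bgr = (b, g, r)  # RGB -> BGR channel swap
    return tuple(c - m for c, m in zip(bgr, IMAGENET_MEAN_BGR))


# The ImageNet mean colour (given here in RGB order) maps to zero in every channel.
print(caffe_preprocess_pixel((123.68, 116.779, 103.939)))
# A pure-red pixel: blue and green become negative, red becomes positive.
print(caffe_preprocess_pixel((255, 0, 0)))
```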
2. Baseline model
Instantiate
model = tf.keras.applications.ResNet50(
    include_top=True,
    weights="imagenet",
    classes=1000,
    classifier_activation="softmax",
)
Evaluate
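def compile_model(model, lr=0.001):
    # SGD + sparse categorical cross-entropy: labels are integer class indices
    # and the model outputs softmax probabilities (from_logits defaults to False).
    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=lr),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=["accuracy"],
    )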
compile_model(model)
_, baseline_model_accuracy = model.evaluate(val_batches)
print("Baseline val accuracy: {:.3f}%".format(baseline_model_accuracy*100))
781/781 [==============================] - 41s 51ms/step - loss: 1.0481 - accuracy: 0.7504
Baseline val accuracy: 75.044%
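The `loss` and `accuracy` values reported by `evaluate` come from the two choices in `compile_model`. With integer labels and softmax outputs, sparse categorical cross-entropy is the negative log-probability assigned to the true class, and Keras resolves `"accuracy"` to top-1 sparse categorical accuracy. The sketch below, with the hypothetical helper `sparse_ce_and_accuracy`, reproduces both on a toy batch; it ignores the small epsilon clipping Keras applies to probabilities.

```python
import math


def sparse_ce_and_accuracy(probs, labels):
    """Hypothetical helper: mean sparse categorical cross-entropy and top-1 accuracy.

    probs:  per-example class probabilities (softmax outputs, i.e. from_logits=False).
    labels: integer class indices, one per example.
    """
    losses, correct = [], 0
    for p, y in zip(probs, labels):
        losses.append(-math.log(p[y]))                     # low probability on the true class -> high loss
        predicted = max(range(len(p)), key=p.__getitem__)  # argmax over classes
        correct += int(predicted == y)
    return sum(losses) / len(losses), correct / len(labels)


probs = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]]
labels = [0, 1]  # first example predicted correctly, second one not
loss, acc = sparse_ce_and_accuracy(probs, labels)
print(f"loss={loss:.4f} accuracy={acc:.2f}")
```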
Save and convert to ONNX
model_save_path = os.path.join(HYPERPARAMS["save_root_dir"], "saved_model_baseline")
model.save(model_save_path)
convert_saved_model_to_onnx(saved_model_dir=model_save_path,
                            onnx_model_path=model_save_path + ".onnx")
INFO:tensorflow:Assets written to: ./weights/resnet_50v1_jupyter/saved_model_baseline/assets
ONNX conversion Done!
3. Quantization-Aware Training model
Quantize
q_model = quantize_model(model, custom_qdq_cases=[ResNetV1QDQCase()])
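`quantize_model` inserts quantize/dequantize (Q/DQ) operations that simulate int8 arithmetic in floating point during fine-tuning. The sketch below assumes symmetric, narrow-range int8 (integers in [-127, 127], scale = amax / 127), with one amax per output channel for weights and a single amax per tensor otherwise, a common setup for TensorRT; the toolkit's exact defaults should be checked in its documentation. `fake_quant` and `per_channel_amax` are hypothetical helpers, not toolkit APIs.

```python
def fake_quant(x, amax, num_bits=8):
    """Hypothetical helper: symmetric, narrow-range quantize-dequantize of one float.

    The scale maps [-amax, amax] onto the integers [-qmax, qmax],
    e.g. [-127, 127] for 8 bits.
    """
    qmax = 2 ** (num_bits - 1) - 1
    scale = amax / qmax
    q = round(x / scale)            # round to nearest (Python: ties to even)
    q = max(-qmax, min(qmax, q))    # clip values outside [-amax, amax]
    return q * scale                # dequantize: the result stays a float


def per_channel_amax(weights):
    """Hypothetical helper: one amax per output channel (here, per row of a 2-D list)."""
    return [max(abs(w) for w in row) for row in weights]


weights = [[0.5, -1.0, 0.25], [0.02, 0.01, -0.04]]

# Per-channel: each row gets its own scale, so the small second row keeps its resolution.
per_channel = [[fake_quant(w, a) for w in row]
               for row, a in zip(weights, per_channel_amax(weights))]

# Per-tensor: one amax (1.0) shared by all weights, so the second row is quantized coarsely.
tensor_amax = max(per_channel_amax(weights))
per_tensor = [[fake_quant(w, tensor_amax) for w in row] for row in weights]

print(per_channel[1])
print(per_tensor[1])
```

Comparing the two printed rows shows why per-channel scales are usually preferred for convolution weights whose channels have very different ranges.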
Fine-tune
compile_model(q_model)
q_model.fit(
    train_batches,
    validation_data=val_batches,
    batch_size=HYPERPARAMS["batch_size"],  # no effect if train_batches is an already-batched tf.data.Dataset
    steps_per_epoch=HYPERPARAMS["steps_per_epoch"],
    epochs=HYPERPARAMS["epochs"]
)
Epoch 1/2
500/500 [==============================] - 425s 838ms/step - loss: 0.4075 - accuracy: 0.8898 - val_loss: 1.0451 - val_accuracy: 0.7497
Epoch 2/2
500/500 [==============================] - 420s 840ms/step - loss: 0.3960 - accuracy: 0.8918 - val_loss: 1.0392 - val_accuracy: 0.7511
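Fine-tuning through Q/DQ nodes works even though rounding has zero gradient almost everywhere, because QAT normally uses a straight-through estimator (STE): the backward pass treats rounding as the identity inside the clipping range and zeroes the gradient where values were clipped. The sketch below shows that usual STE formulation with a hypothetical helper `fake_quant_grad`; the toolkit's exact gradient (for example, whether amax is also trained) may differ.

```python
def fake_quant_grad(x, amax):
    """Hypothetical helper: STE gradient of a clip-and-round fake-quant w.r.t. its input.

    Inside [-amax, amax] the upstream gradient passes through unchanged;
    outside, where the forward pass clipped the value, it is zeroed.
    """
    return 1.0 if -amax <= x <= amax else 0.0


upstream = [0.3, -0.2, 0.5]   # gradients arriving from the next layer
inputs = [0.4, -1.5, 0.9]     # forward-pass inputs; -1.5 lies outside [-1, 1]
grads = [g * fake_quant_grad(x, amax=1.0) for g, x in zip(upstream, inputs)]
print(grads)
```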
Evaluate
_, qat_model_accuracy = q_model.evaluate(val_batches)
print("QAT val accuracy: {:.3f}%".format(qat_model_accuracy*100))
781/781 [==============================] - 179s 229ms/step - loss: 1.0392 - accuracy: 0.7511
QAT val accuracy: 75.114%
Save and convert to ONNX
q_model_save_path = os.path.join(HYPERPARAMS["save_root_dir"], "saved_model_qat")
q_model.save(q_model_save_path)
convert_saved_model_to_onnx(saved_model_dir=q_model_save_path,
                            onnx_model_path=q_model_save_path + ".onnx")
WARNING:absl:Found untraced functions such as conv1_conv_layer_call_fn, conv1_conv_layer_call_and_return_conditional_losses, conv2_block1_1_conv_layer_call_fn, conv2_block1_1_conv_layer_call_and_return_conditional_losses, conv2_block1_2_conv_layer_call_fn while saving (showing 5 of 140). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: ./weights/resnet_50v1_jupyter/saved_model_qat/assets
ONNX conversion Done!
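In the exported QAT model, the fake-quant nodes are expected to appear as ONNX `QuantizeLinear`/`DequantizeLinear` pairs, which TensorRT can consume for int8 inference. The sketch below follows the per-element semantics of those two operators from the ONNX specification for int8 outputs (round half to even, saturate to [-128, 127]), assuming `zero_point = 0` as in symmetric quantization. The helper names are hypothetical, not calls into `onnx` or `onnxruntime`.

```python
def quantize_linear(x, scale, zero_point=0):
    """Hypothetical helper mirroring ONNX QuantizeLinear for int8 output."""
    q = round(x / scale) + zero_point   # Python round() is also round-half-to-even
    return max(-128, min(127, q))       # saturate to the int8 range


def dequantize_linear(q, scale, zero_point=0):
    """Hypothetical helper mirroring ONNX DequantizeLinear: integer back to float."""
    return (q - zero_point) * scale


scale = 0.5
values = [0.2, 0.75, 1.25, 100.0, -100.0]
qs = [quantize_linear(v, scale) for v in values]   # 1.5 and 2.5 both round to 2 (ties to even)
restored = [dequantize_linear(q, scale) for q in qs]
print(qs)
print(restored)
```

Chaining the two operators reproduces the float-domain quantize-dequantize that the model was fine-tuned with, which is why QAT accuracy is expected to carry over to the exported graph.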
4. QAT vs Baseline comparison
print("Baseline vs QAT: {:.3f}% vs {:.3f}%".format(baseline_model_accuracy*100, qat_model_accuracy*100))
acc_diff = (qat_model_accuracy - baseline_model_accuracy)*100
acc_diff_sign = "" if acc_diff == 0 else ("-" if acc_diff < 0 else "+")
print("Accuracy difference of {}{:.3f}%".format(acc_diff_sign, abs(acc_diff)))
Baseline vs QAT: 75.044% vs 75.114%
Accuracy difference of +0.070%
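The sign handling above can also be expressed with Python's `+` format flag, which prints the sign for both positive and negative numbers. The sketch below wraps that in a hypothetical helper `format_accuracy_delta`; the difference is in percentage points. One behavioural difference: for a zero difference this prints `+0.000%`, whereas the code above prints `0.000%`.

```python
def format_accuracy_delta(baseline_acc, qat_acc):
    """Hypothetical helper: signed accuracy change in percentage points, e.g. '+0.070%'."""
    diff = (qat_acc - baseline_acc) * 100
    return f"{diff:+.3f}%"   # '+' flag: always show the sign


print("Accuracy difference of " + format_accuracy_delta(0.75044, 0.75114))
```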
Note
For the full workflow, including TensorRT™ deployment, please refer to examples/resnet.