tensorflow_quantization.quantize.quantize_model(model, quantization_mode: str = 'full', quantization_spec: tensorflow_quantization.quantize.QuantizationSpec = None, custom_qdq_cases: List[CustomQDQInsertionCase] = None) → keras.engine.training.Model

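Insert Q/DQ (quantize/dequantize) nodes into a Keras model and return the quantized copy. Unlike the native Keras clone (tf.keras.models.clone_model), which creates freshly initialized weights, this copy preserves the original model's weights.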
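Each Q/DQ pair maps a floating-point tensor onto an integer grid and immediately back to floating point, so the model still computes in float while simulating the rounding and clipping of int8 inference. The sketch below illustrates that arithmetic in plain Python under stated assumptions: symmetric, narrow-range quantization ([-127, 127] for 8 bits) with a single per-tensor scale derived from a given absolute maximum `amax`. The function name `fake_quantize` is hypothetical; the toolkit's actual quantizers (for example, per-channel weight scales or calibrated activation ranges) may differ.

```python
def fake_quantize(values, amax, num_bits=8):
    """Quantize floats to a signed integer grid, then dequantize back to float.

    Assumes symmetric, narrow-range quantization: integers lie in
    [-(2**(num_bits-1) - 1), 2**(num_bits-1) - 1], i.e. [-127, 127] for 8 bits.
    """
    qmax = 2 ** (num_bits - 1) - 1      # 127 for 8-bit
    scale = amax / qmax                  # float value of one integer step
    result = []
    for x in values:
        q = round(x / scale)             # Q: nearest integer (Python rounds half to even)
        q = max(-qmax, min(qmax, q))     # clip anything outside [-amax, amax]
        result.append(q * scale)         # DQ: back to float
    return result


# The last value lies outside [-amax, amax] and is clipped to roughly 1.0.
print(fake_quantize([0.0, 0.3, -1.0, 3.0], amax=1.0))
```

In-range values come back with an error of at most half a step (`scale / 2`), while out-of-range values saturate at `±amax`.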

Parameters:

  • model (tf.keras.Model) -- Keras Functional or Sequential model. Subclassed models are not yet supported.

  • quantization_mode (str) -- Quantization mode, either 'full' (default) or 'partial'. In 'full' mode all supported layers are quantized; in 'partial' mode only the layers listed in quantization_spec are quantized.

  • quantization_spec (QuantizationSpec) -- A QuantizationSpec object (default None). When specific layers or layer classes need to be treated differently, LayerConfig objects for those layers or layer classes are created internally and added to the QuantizationSpec.

  • custom_qdq_cases (List[CustomQDQInsertionCase]) -- List of custom Q/DQ insertion cases. Each element must be an instance of a class derived from CustomQDQInsertionCase; its case method is called with the model and the user-supplied quantization_spec as arguments.

Raises:

  • AssertionError -- When the passed model is subclassed.

  • AssertionError -- When a CustomQDQInsertionCase does not return a QuantizationSpec object.

  • AssertionError -- When quantization_mode is 'partial' but no QuantizationSpec object is passed.

  • AssertionError -- When no quantization wrapper is found for the requested layer class.

  • ExceptionError -- When the internal quantization class ID cannot be determined. This happens when the combination of passed parameters is invalid.
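The parameter and exception rules above can be summarized in a small pure-Python sketch that runs without TensorFlow. `Spec`, `CustomCase`, `QuantizeAddLayers`, and `check_and_merge` are hypothetical stand-ins for QuantizationSpec, CustomQDQInsertionCase, a user-defined case, and the toolkit's internal dispatch; the model is reduced to a list of layer names. Merging each case's returned layers into the user's spec without duplicates is an assumption about how results are combined. The sketch only reproduces the documented checks: 'partial' mode requires a spec, and every custom case must return a spec.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Spec:
    """Hypothetical stand-in for QuantizationSpec: names of layers to treat specially."""
    layers: List[str] = field(default_factory=list)


class CustomCase:
    """Hypothetical stand-in for CustomQDQInsertionCase."""

    def case(self, model_layers: List[str], user_spec: Optional[Spec]) -> Spec:
        raise NotImplementedError


class QuantizeAddLayers(CustomCase):
    """Hypothetical case: select every layer whose name starts with 'add'."""

    def case(self, model_layers, user_spec):
        return Spec([name for name in model_layers if name.startswith("add")])


def check_and_merge(model_layers: List[str],
                    mode: str = "full",
                    spec: Optional[Spec] = None,
                    custom_cases: Optional[List[CustomCase]] = None) -> Spec:
    """Apply the documented argument checks and merge custom-case results."""
    if mode == "partial":
        # 'partial' quantizes only the listed layers, so a spec is mandatory.
        assert spec is not None, "partial mode requires a QuantizationSpec"
    merged = Spec(list(spec.layers) if spec is not None else [])
    for custom in custom_cases or []:
        # Each case receives the model and the user-supplied spec.
        result = custom.case(model_layers, spec)
        assert isinstance(result, Spec), "custom case must return a QuantizationSpec"
        for name in result.layers:
            if name not in merged.layers:  # assumption: merge without duplicates
                merged.layers.append(name)
    return merged


layers = ["conv2d", "conv2d_1", "add", "dense"]
merged = check_and_merge(layers, "partial", Spec(["dense"]), [QuantizeAddLayers()])
print(merged.layers)  # ['dense', 'add']
```

Passing the user spec unchanged to every case, rather than the partially merged result, follows the parameter description, which states that each case receives the user-passed quantization_spec.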


Returns:

tf.keras.Model -- Quantized copy of the model with Q/DQ nodes inserted according to the NVIDIA quantization recipe.


Note: Currently only Functional and Sequential models are supported.


Example:

import tensorflow as tf
from tensorflow_quantization.quantize import quantize_model

# Simple full model quantization.
# 1. Create a simple network
input_img = tf.keras.layers.Input(shape=(28, 28))
r = tf.keras.layers.Reshape(target_shape=(28, 28, 1))(input_img)
x = tf.keras.layers.Conv2D(filters=2, kernel_size=(3, 3))(r)
x = tf.keras.layers.ReLU()(x)
x = tf.keras.layers.Conv2D(filters=2, kernel_size=(3, 3))(x)
x = tf.keras.layers.ReLU()(x)
x = tf.keras.layers.Flatten()(x)
model = tf.keras.Model(input_img, x)


# 2. Quantize the network
q_model = quantize_model(model)
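To make the difference between the two modes concrete, the sketch below applies them to the layer sequence of the network above without running TensorFlow. Two assumptions are involved: the layer names mimic Keras defaults (actual names depend on how many layers were already created in the session), and the set `QUANTIZABLE` is a hypothetical placeholder, not the toolkit's actual list of layer classes with quantization wrappers. In 'full' mode every layer with a wrapper is selected; in 'partial' mode only the named layers are selected, and naming a layer without a wrapper fails, loosely mirroring the documented AssertionError. Per-layer overrides that a spec can carry in 'full' mode are ignored here.

```python
from typing import List, Optional, Tuple

# Hypothetical set of layer classes that have a quantization wrapper.
QUANTIZABLE = {"Conv2D", "Dense", "DepthwiseConv2D"}

# (name, class) pairs mirroring the example network; names are assumed Keras defaults.
EXAMPLE_LAYERS: List[Tuple[str, str]] = [
    ("input_1", "InputLayer"),
    ("reshape", "Reshape"),
    ("conv2d", "Conv2D"),
    ("re_lu", "ReLU"),
    ("conv2d_1", "Conv2D"),
    ("re_lu_1", "ReLU"),
    ("flatten", "Flatten"),
]


def select_layers(layers: List[Tuple[str, str]],
                  mode: str = "full",
                  spec_names: Optional[List[str]] = None) -> List[str]:
    """Return the names of layers that would receive Q/DQ nodes."""
    if mode == "full":
        # Every layer whose class has a wrapper is quantized.
        return [name for name, cls in layers if cls in QUANTIZABLE]
    assert spec_names is not None, "partial mode requires a QuantizationSpec"
    by_name = dict(layers)
    selected = []
    for name in spec_names:
        # A missing layer or a class without a wrapper cannot be quantized.
        assert by_name.get(name) in QUANTIZABLE, f"no quantization wrapper for {name}"
        selected.append(name)
    return selected


print(select_layers(EXAMPLE_LAYERS))                             # ['conv2d', 'conv2d_1']
print(select_layers(EXAMPLE_LAYERS, "partial", ["conv2d_1"]))    # ['conv2d_1']
```

With the real toolkit, the partial case corresponds to calling quantize_model(model, quantization_mode='partial', quantization_spec=...) with a QuantizationSpec that names the chosen layers.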