tensorflow_quantization.quantize_model
- tensorflow_quantization.quantize.quantize_model(model, quantization_mode: str = 'full', quantization_spec: tensorflow_quantization.quantize.QuantizationSpec = None, custom_qdq_cases: List[CustomQDQInsertionCase] = None) -> keras.engine.training.Model
Inserts Q/DQ (quantize/dequantize) nodes into a Keras model and returns a copy. Unlike Keras's native model cloning, the copy keeps the original weights.
- Parameters
model (tf.keras.Model) -- Keras Functional or Sequential model. Subclassed models are not yet supported.
quantization_mode (str) -- Quantization mode, either 'full' or 'partial'. Defaults to 'full'.
quantization_spec (QuantizationSpec) -- QuantizationSpec object. Defaults to None. When specific layers or layer classes need different treatment, a LayerConfig object is created internally for each such layer or layer class and added to the QuantizationSpec.
custom_qdq_cases (List[CustomQDQInsertionCase]) -- List of objects whose classes inherit from CustomQDQInsertionCase. Defaults to None. The case method of each object is called with the model and the user-supplied quantization_spec as arguments.
- Raises
AssertionError -- When the passed model is subclassed.
AssertionError -- When a CustomQDQInsertionCase does not return a QuantizationSpec object.
AssertionError -- When the quantization mode is 'partial' but no QuantizationSpec object is passed.
AssertionError -- When no quantization wrapper is found for the desired layer class.
ExceptionError -- When the internal quantization class ID cannot be detected. This happens when the passed parameters are inconsistent with each other.
- Returns
tf.keras.Model -- Quantized model with Q/DQ nodes inserted according to the NVIDIA quantization recipe.
Note
Currently only Functional and Sequential models are supported.
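The contract described for custom_qdq_cases (each object's case method receives the model and the user-supplied spec, and must return a spec object) can be sketched in plain Python. This is only an illustration of that contract, not the library's implementation: SpecStandIn, CaseStandIn, collect_case_specs, and the layer name "conv2d_1" are hypothetical stand-ins, and the method name case is assumed from the parameter description above.

```python
from typing import Any, List, Optional


class SpecStandIn:
    """Hypothetical stand-in for QuantizationSpec; it only records layer names."""

    def __init__(self) -> None:
        self.layer_names: List[str] = []


class CaseStandIn:
    """Hypothetical stand-in for a class inheriting from CustomQDQInsertionCase."""

    def case(self, model: Any, qspec: Optional[SpecStandIn]) -> SpecStandIn:
        # A real case would inspect `model`; this one just names one layer.
        new_spec = SpecStandIn()
        new_spec.layer_names.append("conv2d_1")  # illustrative layer name
        return new_spec


def collect_case_specs(model: Any, qspec: Optional[SpecStandIn],
                       cases: List[Any]) -> List[SpecStandIn]:
    """Call case(model, qspec) on every object and check the return type."""
    specs = []
    for custom_case in cases:
        spec = custom_case.case(model, qspec)
        # Mirrors the documented AssertionError for a case that does not
        # return a spec object.
        assert isinstance(spec, SpecStandIn), "case() must return a spec object"
        specs.append(spec)
    return specs


specs = collect_case_specs(model=None, qspec=None, cases=[CaseStandIn()])
print(specs[0].layer_names)
```

How the real library merges the returned specs with the user-supplied one is not covered here; the sketch stops at collecting and type-checking them.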
Examples
import tensorflow as tf
from tensorflow_quantization.quantize import quantize_model
# Simple full model quantization.
# 1. Create a simple network
input_img = tf.keras.layers.Input(shape=(28, 28))
r = tf.keras.layers.Reshape(target_shape=(28, 28, 1))(input_img)
x = tf.keras.layers.Conv2D(filters=2, kernel_size=(3, 3))(r)
x = tf.keras.layers.ReLU()(x)
x = tf.keras.layers.Conv2D(filters=2, kernel_size=(3, 3))(x)
x = tf.keras.layers.ReLU()(x)
x = tf.keras.layers.Flatten()(x)
model = tf.keras.Model(input_img, x)
model.summary()  # summary() prints the table itself and returns None
# 2. Quantize the network
q_model = quantize_model(model)
q_model.summary()  # summary() prints the table itself and returns None
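When calling quantize_model with non-default arguments, it can help to check the mode and spec combination up front, since the Raises section states that 'partial' mode without a QuantizationSpec fails. The helper below is a minimal sketch of such a check. check_quantize_args is a hypothetical name and is not part of tensorflow_quantization; it raises ValueError by choice, whereas the library itself is documented to raise AssertionError.

```python
from typing import Any, Optional

VALID_MODES = ("full", "partial")  # the two modes named in the parameter list


def check_quantize_args(quantization_mode: str = "full",
                        quantization_spec: Optional[Any] = None) -> None:
    """Hypothetical pre-flight check; not part of tensorflow_quantization."""
    if quantization_mode not in VALID_MODES:
        raise ValueError(
            f"quantization_mode must be one of {VALID_MODES}, "
            f"got {quantization_mode!r}"
        )
    if quantization_mode == "partial" and quantization_spec is None:
        raise ValueError(
            "quantization_mode='partial' requires a quantization_spec"
        )


# Full mode needs no spec, so this passes silently.
check_quantize_args("full")

# Partial mode without a spec is rejected before any model work is done.
try:
    check_quantize_args("partial")
except ValueError as err:
    print(f"rejected: {err}")
```

A check like this would run before the call to quantize_model, so that a bad argument combination is reported before the model is copied.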