Source code for tensorflow_quantization.quantize_wrapper_base

#
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#    
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
import tensorflow_quantization.quantizers as quantizers
import tensorflow_quantization.global_config as cfg
from abc import abstractmethod

deserialize_keras_object = tf.keras.utils.deserialize_keras_object
serialize_keras_object = tf.keras.utils.serialize_keras_object

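# Keras layer types that carry no quantizable weights; wrappers around these
# layers only ever quantize inputs (see the _configure_* helpers in
# BaseQuantizeWrapper.__init__ below).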
NO_WEIGHT_LAYERS = {
    "Concatenate",
    "Add",
    "AveragePooling2D",
    "GlobalAveragePooling2D",
    "MaxPooling2D",
    "BatchNormalization",
}


class BaseQuantizeWrapper(tf.keras.layers.Wrapper):
    """Base wrapper class which all layer wrappers inherit"""

    CHILD_WRAPPERS = {}

    def __init_subclass__(cls, **kwargs) -> None:
        super().__init_subclass__(**kwargs)
        cls.CHILD_WRAPPERS[cls.__name__] = cls
    def __init__(self, layer: tf.keras.layers.Layer, **kwargs):
        """Create a quantize emulate wrapper for a keras layer.

        This wrapper provides options to quantize inputs and weights of the layer.

        Args:
            layer (tf.keras.layers.Layer): The keras layer to be quantized.
            **kwargs: Additional keyword arguments to be passed to the keras layer.
        """
        if layer is None:
            raise ValueError("`layer` cannot be None.")

        # Check against keras.Model since it is an instance of keras.layers.Layer.
        if not isinstance(layer, tf.keras.layers.Layer) or isinstance(
            layer, tf.keras.Model
        ):
            raise ValueError(
                "`layer` can only be a `tf.keras.layers.Layer` instance. "
                "You passed an instance of type: {input}.".format(
                    input=layer.__class__.__name__
                )
            )

        if "name" not in kwargs:
            kwargs["name"] = self._make_layer_name(layer)

        super(BaseQuantizeWrapper, self).__init__(layer, **kwargs)

        # Get the quantize config object that holds all the information about how
        # quantization should be performed.
        quantize_config_object = cfg.get_config_object()

        # Set all initial quantization parameters to False/None.
        self.quantize_inputs = False
        self.quantize_weights = False
        self.quantize_specific_input_indices = None

        layer_class_name_t = layer.__class__.__name__  # Layer class name
        layer_name_t = layer.name  # Actual layer name

        def _configure_singular_quantize():
            self.quantize_inputs = True
            if layer_class_name_t in NO_WEIGHT_LAYERS:
                self.quantize_weights = False
            else:
                self.quantize_weights = True

        def _configure_special_quantize(
            quantize_bool_list: list, layer_name_t: str, index_list_if_any: list = None
        ):
            assert len(quantize_bool_list) == 2, (
                "Two boolean values (representing whether to quantize [inputs, weights]) must be provided in "
                "quantize_config for layer: {layer_name_t}. If quantization does not apply to a specific part, "
                "pass False. e.g. for a layer with no weights (Concatenate, Add), `qbool_list` can be "
                "[True, False] to quantize only the inputs.".format(layer_name_t=layer_name_t)
            )
            self.quantize_inputs = quantize_bool_list[0]
            if layer_class_name_t in NO_WEIGHT_LAYERS:
                self.quantize_weights = False
            else:
                self.quantize_weights = quantize_bool_list[1]
            if index_list_if_any:
                self.quantize_specific_input_indices = index_list_if_any

        if quantize_config_object.config_class_id == 0:
            # Straightforward full-network quantization.
            _configure_singular_quantize()
        else:
            # Config class id 1 or 2: the user has provided layer-name-specific
            # quantization information.
            quantize_config_dict = quantize_config_object.get_layer_config()
            if layer_name_t in quantize_config_dict:
                # This layer needs to be quantized in a specific way.
                if "qindex_list" in quantize_config_dict[layer_name_t]:
                    _configure_special_quantize(
                        quantize_config_dict[layer_name_t]["qbool_list"],
                        layer_name_t,
                        quantize_config_dict[layer_name_t]["qindex_list"],
                    )
                else:
                    _configure_special_quantize(
                        quantize_config_dict[layer_name_t]["qbool_list"], layer_name_t
                    )
            else:
                _configure_singular_quantize()

        self._track_trackable(layer, name="layer")
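
    # Illustrative note (assumed shape, inferred only from the lookups in
    # __init__ above): a layer-specific entry returned by
    # quantize_config_object.get_layer_config() is expected to look roughly like
    #     {"conv2d_3": {"qbool_list": [True, True], "qindex_list": [1]}}
    # where qbool_list holds [quantize inputs, quantize weights] and the optional
    # qindex_list selects which input indices to quantize. "conv2d_3" is a
    # hypothetical layer name used purely as an example.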
    @staticmethod
    def _make_layer_name(layer):
        return "{}_{}".format("quant", layer.name)

    @staticmethod
    def _weight_name(name):
        """Extracts the weight name from the full TensorFlow variable name.

        For example, returns 'kernel' for 'dense_2/kernel:0'.

        Args:
            name: TensorFlow variable name.

        Returns:
            Extracted weight name.
        """
        return name.split(":")[0].split("/")[-1]

    def build(self, input_shape):
        super(BaseQuantizeWrapper, self).build(input_shape)

        self.optimizer_step = self.add_weight(
            "optimizer_step",
            initializer=tf.keras.initializers.Constant(-1),
            dtype=tf.dtypes.int32,
            trainable=False,
        )

    def compute_output_shape(self, input_shape):
        return self.layer.compute_output_shape(self.layer.input_shape)

    def _last_value_quantizer(
        self, x, training, quantizer_vars, per_channel=False, channel_axis=-1
    ):
        """Use currying to return True/False specialized fns to the cond."""
        from tensorflow_quantization import G_NUM_BITS, G_SYMMETRIC, G_NARROW_RANGE

        return quantizers.LastValueQuantize(
            x,
            quantizer_vars["min_var"],
            quantizer_vars["max_var"],
            per_channel=per_channel,
            channel_axis=channel_axis,
            is_training=training,
            num_bits=G_NUM_BITS,
            narrow_range=G_NARROW_RANGE,
            symmetric=G_SYMMETRIC,
        )

    @abstractmethod
    def call(self, inputs, training=None):
        raise NotImplementedError

    def get_config(self):
        base_config = super(BaseQuantizeWrapper, self).get_config()
        config = {"quantize_config": None}
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        config = config.copy()

        # BaseQuantizeWrapper may be constructed with any QuantizeConfig and the
        # wrapper itself cannot know all the possible config classes.
        # The deserialization code should ensure the QuantizeConfig is in keras
        # serialization scope.
        quantize_config = deserialize_keras_object(
            config.pop("quantize_config"), module_objects=globals(), custom_objects=None
        )

        layer = tf.keras.layers.deserialize(config.pop("layer"))

        return cls(layer=layer, quantize_config=quantize_config, **config)

    @property
    def trainable(self):
        return self.layer.trainable

    @trainable.setter
    def trainable(self, value):
        self.layer.trainable = value

    @property
    def trainable_weights(self):
        return self.layer.trainable_weights + self._trainable_weights

    @property
    def non_trainable_weights(self):
        return self.layer.non_trainable_weights + self._non_trainable_weights

    @property
    def updates(self):
        return self.layer.updates + self._updates

    @property
    def losses(self):
        return self.layer.losses + self._losses
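

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module). It shows one way a
# concrete wrapper *could* subclass BaseQuantizeWrapper: create min/max
# calibration variables in build() and quantize the inputs in call() via
# _last_value_quantizer(). The class name ExampleQuantizeWrapper and the
# variable layout (a single scalar "input_min"/"input_max" pair) are
# assumptions made for demonstration; the toolkit's real wrappers manage their
# quantizer variables differently.
# ---------------------------------------------------------------------------
class ExampleQuantizeWrapper(BaseQuantizeWrapper):
    def build(self, input_shape):
        super().build(input_shape)
        # Scalar calibration range for the layer input (assumed layout).
        self._input_vars = {
            "min_var": self.add_weight(
                "input_min",
                initializer=tf.keras.initializers.Constant(-6.0),
                trainable=False,
            ),
            "max_var": self.add_weight(
                "input_max",
                initializer=tf.keras.initializers.Constant(6.0),
                trainable=False,
            ),
        }

    def call(self, inputs, training=None):
        if training is None:
            training = tf.keras.backend.learning_phase()
        if self.quantize_inputs:
            inputs = self._last_value_quantizer(
                inputs, training, self._input_vars, per_channel=False
            )
        return self.layer.call(inputs)


# Hypothetical use: ExampleQuantizeWrapper(tf.keras.layers.Activation("relu")).
# Note that BaseQuantizeWrapper.__init__ consults the global quantization config
# via cfg.get_config_object(), so constructing a wrapper directly like this
# assumes that global config has already been set up by the toolkit.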