pytorch-quantization’s documentation

Basic Functionalities

Quantization function

tensor_quant and fake_tensor_quant are the two basic functions for quantizing a tensor. fake_tensor_quant returns a fake-quantized tensor (float values), while tensor_quant returns the quantized tensor (integer values) together with its scale.

tensor_quant(inputs, amax, num_bits=8, output_dtype=torch.float, unsigned=False)
fake_tensor_quant(inputs, amax, num_bits=8, output_dtype=torch.float, unsigned=False)

Example:

import torch
from pytorch_quantization import tensor_quant

# Generate random input. With fixed seed 12345, x should be
# tensor([0.9817, 0.8796, 0.9921, 0.4611, 0.0832, 0.1784, 0.3674, 0.5676, 0.3376, 0.2119])
torch.manual_seed(12345)
x = torch.rand(10)

# fake quantize tensor x. fake_quant_x will be
# tensor([0.9843, 0.8828, 0.9921, 0.4609, 0.0859, 0.1797, 0.3672, 0.5703, 0.3359, 0.2109])
fake_quant_x = tensor_quant.fake_tensor_quant(x, x.abs().max())

# quantize tensor x. quant_x will be
# tensor([126., 113., 127.,  59.,  11.,  23.,  47.,  73.,  43.,  27.])
# with scale=128.0057
quant_x, scale = tensor_quant.tensor_quant(x, x.abs().max())
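The values in this example can be reproduced by hand. Below is a plain-Python sketch of the arithmetic, assuming symmetric signed quantization with scale = (2^(num_bits-1) - 1) / amax, round-to-nearest, and clamping to [-127, 127] for 8 bits; this assumption matches the example (127 / 128.0057 ≈ 0.9921 = max|x|). It is an illustration, not the library's tensor implementation. Because the inputs below are the 4-decimal values printed above, the computed scale (about 128.01) and the last digit of some fake-quantized values may differ slightly from the example.

```python
def quantize(values, amax, num_bits=8):
    """Return (integer codes, scale) for symmetric signed quantization (sketch)."""
    max_bound = 2 ** (num_bits - 1) - 1          # 127 for 8 bits
    scale = max_bound / amax
    codes = [max(-max_bound, min(max_bound, round(v * scale))) for v in values]
    return codes, scale


def fake_quantize(values, amax, num_bits=8):
    """Quantize, then immediately dequantize, so the result stays in float."""
    codes, scale = quantize(values, amax, num_bits)
    return [c / scale for c in codes]


x = [0.9817, 0.8796, 0.9921, 0.4611, 0.0832, 0.1784, 0.3674, 0.5676, 0.3376, 0.2119]
amax = max(abs(v) for v in x)
codes, scale = quantize(x, amax)
print(codes)  # [126, 113, 127, 59, 11, 23, 47, 73, 43, 27]
print([round(v, 4) for v in fake_quantize(x, amax)])
```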

The backward pass of both functions is defined as a Straight-Through Estimator (STE).
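Under a straight-through estimator, the rounding step is treated as the identity during the backward pass. The sketch below illustrates a common clipped STE variant in plain Python: gradients pass through unchanged for inputs inside [-amax, amax] and are zeroed for inputs that were clipped. The function names are hypothetical, and whether pytorch_quantization clips gradients exactly this way is an assumption here.

```python
def fake_quant_forward(values, amax, num_bits=8):
    """Forward pass: symmetric fake quantization (round, clamp, rescale)."""
    max_bound = 2 ** (num_bits - 1) - 1
    scale = max_bound / amax
    return [max(-max_bound, min(max_bound, round(v * scale))) / scale for v in values]


def fake_quant_backward(values, amax, grad_outputs):
    """Backward pass under a clipped straight-through estimator (assumed variant).

    round() is treated as identity, so the incoming gradient passes through
    for inputs inside [-amax, amax]; inputs that were clipped get zero gradient.
    """
    return [g if abs(v) <= amax else 0.0 for v, g in zip(values, grad_outputs)]


x = [0.3, -0.7, 1.5, -2.0]
amax = 1.0
y = fake_quant_forward(x, amax)
grads = fake_quant_backward(x, amax, [1.0, 1.0, 1.0, 1.0])
print(grads)  # [1.0, 1.0, 0.0, 0.0]
```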

Descriptor and quantizer

QuantDescriptor defines how a tensor should be quantized. Some predefined QuantDescriptors are also provided, e.g. QUANT_DESC_8BIT_PER_TENSOR and QUANT_DESC_8BIT_CONV2D_WEIGHT_PER_CHANNEL.

TensorQuantizer is the module that quantizes tensors; its behavior is defined by a QuantDescriptor.

import torch
from pytorch_quantization.tensor_quant import QuantDescriptor
from pytorch_quantization.nn.modules.tensor_quantizer import TensorQuantizer

quant_desc = QuantDescriptor(num_bits=4, fake_quant=False, axis=(0), unsigned=True)
quantizer = TensorQuantizer(quant_desc)

torch.manual_seed(12345)
x = torch.rand(10, 9, 8, 7)

quant_x = quantizer(x)

If amax is given in the QuantDescriptor, TensorQuantizer uses it to quantize. Otherwise, TensorQuantizer computes amax first and then quantizes. amax is computed with respect to the specified axis. Note that the axis of QuantDescriptor specifies the remaining (kept) axis, as opposed to the axis argument of max(), which specifies the axis being reduced.
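To illustrate the axis convention, the plain-Python sketch below computes per-channel amax for a QuantDescriptor with axis=(0): axis 0 is kept and every other axis is reduced. For the 4-D tensor in the example above, the equivalent PyTorch reduction would be x.abs().amax(dim=(1, 2, 3)), i.e. the axes passed to the reduction are the complement of the QuantDescriptor axis. The helper name amax_keep_axis0 is hypothetical.

```python
def amax_keep_axis0(tensor):
    """Per-channel amax of a nested list shaped (C, ...), keeping axis 0."""

    def flatten(t):
        # Recursively yield every scalar in a nested list.
        if isinstance(t, list):
            for item in t:
                yield from flatten(item)
        else:
            yield t

    # One amax per entry along axis 0; all other axes are reduced.
    return [max(abs(v) for v in flatten(channel)) for channel in tensor]


x = [
    [[0.1, -0.9], [0.4, 0.2]],    # channel 0
    [[-0.3, 0.5], [0.6, -0.1]],   # channel 1
    [[0.0, 0.05], [-2.0, 0.7]],   # channel 2
]
print(amax_keep_axis0(x))  # [0.9, 0.6, 2.0]
```

A per-tensor descriptor (axis=None) would instead reduce over every axis and return the single value 2.0.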

Quantized module

There are two major types of quantized modules, Conv and Linear. Both can replace their torch.nn counterparts and quantize both weights and activations.

Both take quant_desc_input and quant_desc_weight in addition to the arguments of the original module.

from torch import nn

from pytorch_quantization import tensor_quant
import pytorch_quantization.nn as quant_nn

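# example sizes (arbitrary values for illustration)
in_features, out_features = 64, 32
in_channels, out_channels, kernel_size = 3, 16, 3

# pytorch's module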
fc1 = nn.Linear(in_features, out_features, bias=True)
conv1 = nn.Conv2d(in_channels, out_channels, kernel_size)

# quantized version
quant_fc1 = quant_nn.Linear(
    in_features, out_features, bias=True,
    quant_desc_input=tensor_quant.QUANT_DESC_8BIT_PER_TENSOR,
    quant_desc_weight=tensor_quant.QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW)
quant_conv1 = quant_nn.Conv2d(
    in_channels, out_channels, kernel_size,
    quant_desc_input=tensor_quant.QUANT_DESC_8BIT_PER_TENSOR,
    quant_desc_weight=tensor_quant.QUANT_DESC_8BIT_CONV2D_WEIGHT_PER_CHANNEL)

Post training quantization

A model can be post training quantized simply by calling quant_modules.initialize() before the model is created:

import torchvision
from pytorch_quantization import quant_modules

quant_modules.initialize()
model = torchvision.models.resnet50()

If a model is not entirely defined by modules (for example, if it calls functional operations directly), then TensorQuantizer should be created manually and added at the right places in the model.

Calibration

Calibration is the TensorRT term for passing data samples to the quantizer and deciding the best amax for activations. We support 4 calibration methods:

  • max: Simply use global maximum absolute value

  • entropy: TensorRT’s entropy calibration

  • percentile: Discard outliers based on a given percentile.

  • mse: MSE (mean squared error) based calibration
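The sketch below illustrates the max, percentile, and mse ideas on a flat list of samples in plain Python. It is a simplified stand-in, not the library's implementation: the library's percentile, mse, and entropy methods work on collected histograms (see HistogramCalibrator), percentile uses a nearest-rank rule here, mse searches a fixed grid of candidate amax values, and entropy is omitted. All function names are hypothetical.

```python
import math


def calib_max(samples):
    """'max': global maximum absolute value."""
    return max(abs(v) for v in samples)


def calib_percentile(samples, percentile=99.99):
    """'percentile': amax at the given percentile of |x| (nearest-rank rule)."""
    mags = sorted(abs(v) for v in samples)
    rank = math.ceil(percentile / 100 * len(mags))  # 1-based nearest rank
    return mags[max(rank, 1) - 1]


def calib_mse(samples, num_bits=8, num_candidates=100):
    """'mse': pick the candidate amax that minimizes fake-quantization MSE."""
    max_bound = 2 ** (num_bits - 1) - 1
    top = calib_max(samples)
    if top == 0:
        return 0.0
    best_amax, best_err = top, float("inf")
    for i in range(1, num_candidates + 1):
        amax = top * i / num_candidates
        scale = max_bound / amax
        # Squared error of clipping + rounding, averaged over all samples.
        err = sum(
            (max(-max_bound, min(max_bound, round(v * scale))) / scale - v) ** 2
            for v in samples
        ) / len(samples)
        if err < best_err:
            best_amax, best_err = amax, err
    return best_amax


samples = [0.01 * i for i in range(-100, 101)] + [50.0]  # one large outlier
print(calib_max(samples))             # 50.0
print(calib_percentile(samples, 99))  # 1.0
print(calib_mse(samples))
```

With one large outlier, max keeps it (amax = 50.0) while the 99th percentile discards it (amax = 1.0). MSE weighs clipping error against rounding error; in this data the single outlier is large enough that the chosen amax stays close to 50.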

Using the ResNet50 model created above, calibration can be performed as follows:

# Find the TensorQuantizer and enable calibration
for name, module in model.named_modules():
    if name.endswith('_quantizer'):
        module.enable_calib()
        module.disable_quant()  # Use full precision data to calibrate

# Feeding data samples
model(x)
# ...

# Finalize calibration
for name, module in model.named_modules():
    if name.endswith('_quantizer'):
        module.load_calib_amax()
        module.disable_calib()
        module.enable_quant()

# If running on GPU, it needs to call .cuda() again because new tensors will be created by calibration process
model.cuda()

# Keep running the quantized model
# ...

Note

Calibration needs to be performed before exporting the model to ONNX.

Quantization Aware Training

Quantization Aware Training is based on the Straight Through Estimator (STE) derivative approximation. Although the technique is widely known as “quantization aware training”, that name doesn't reflect the underlying assumption. If anything, the STE approximation makes training “unaware” of quantization.

After calibration is done, Quantization Aware Training simply consists of selecting a training schedule and continuing to train the calibrated model. Fine-tuning usually doesn't need to run very long. We usually use around 10% of the original training schedule, starting at 1% of the initial training learning rate, with a cosine annealing learning rate schedule that follows the decreasing half of a cosine period, down to 1% of the initial fine-tuning learning rate (0.01% of the initial training learning rate).
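The fine-tuning schedule above can be written out as a per-epoch learning-rate list. The plain-Python sketch below uses a hypothetical helper, finetune_lr_schedule, and takes "10% of the schedule" to mean 10% of the original number of epochs; in PyTorch, torch.optim.lr_scheduler.CosineAnnealingLR with eta_min set to the final learning rate gives a similar curve.

```python
import math


def finetune_lr_schedule(train_epochs, train_lr):
    """Per-epoch learning rates for QAT fine-tuning (hypothetical helper).

    - length: about 10% of the original number of epochs
    - start:  1% of the initial training learning rate
    - end:    1% of the fine-tuning start rate (0.01% of the training rate),
              following the decreasing half of a cosine period
    """
    epochs = max(1, round(0.1 * train_epochs))
    lr_start = 0.01 * train_lr
    lr_end = 0.01 * lr_start
    if epochs == 1:
        return [lr_start]
    return [
        lr_end + 0.5 * (lr_start - lr_end) * (1 + math.cos(math.pi * e / (epochs - 1)))
        for e in range(epochs)
    ]


lrs = finetune_lr_schedule(train_epochs=90, train_lr=0.1)
print(len(lrs))         # 9
print(lrs[0], lrs[-1])  # approximately 0.001 and 1e-05
```

For a 90-epoch training run with an initial learning rate of 0.1, this gives 9 fine-tuning epochs decaying from 0.001 to 1e-05.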

Some recommendations

Quantization Aware Training (essentially a discrete numerical optimization problem) is not a mathematically solved problem. Based on our experience, here are some recommendations:

  • For the STE approximation to work well, it is better to use a small learning rate. A large learning rate is more likely to amplify the variance introduced by the STE approximation and destroy the trained network.

  • Do not change the quantization representation (scale) during training, at least not too frequently. Changing the scale every step is effectively like changing the data format (e8m7, e5m10, e3m4, etc.) every step, which can easily hurt convergence.

Export to ONNX

The goal of exporting to ONNX is to deploy to TensorRT, not to ONNX Runtime. So we only export the fake quantized model in a form TensorRT accepts. Fake quantization is broken into a pair of QuantizeLinear/DequantizeLinear ONNX ops. TensorRT takes the generated ONNX graph and executes it in int8 as efficiently as it can.
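For reference, the element-wise semantics of the two ONNX ops can be sketched in plain Python. Assuming symmetric 8-bit quantization with a zero point of 0, the exported y_scale would be amax / 127, the reciprocal of the scale returned by tensor_quant; this mapping is an assumption of the sketch. ONNX saturates int8 outputs to [-128, 127] and rounds halfway cases to even, which Python's round() also does. The function names are hypothetical.

```python
def quantize_linear(x, y_scale, y_zero_point=0):
    """Element-wise ONNX QuantizeLinear semantics for an int8 output."""
    # Python's round() rounds halves to even, matching the ONNX spec.
    return [max(-128, min(127, round(v / y_scale) + y_zero_point)) for v in x]


def dequantize_linear(q, x_scale, x_zero_point=0):
    """Element-wise ONNX DequantizeLinear semantics."""
    return [(v - x_zero_point) * x_scale for v in q]


amax = 2.54
onnx_scale = amax / 127  # assumed: reciprocal of the scale from tensor_quant
x = [0.0, 0.013, 1.0, -2.54, 3.0]
q = quantize_linear(x, onnx_scale)
print(q)  # [0, 1, 50, -127, 127]
print(dequantize_linear(q, onnx_scale))
```

The QuantizeLinear/DequantizeLinear pair therefore reproduces fake quantization: 3.0 lies above amax and saturates, and every other value comes back within half a quantization step.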

Note

Currently, we only support exporting int8 and fp8 fake quantized modules. Additionally, quantized modules need to be calibrated before exporting to ONNX.

A fake quantized model can be exported to ONNX like any other PyTorch model. See the torch.onnx documentation to learn more about exporting PyTorch models to ONNX. For example:

import torch
import torchvision

import pytorch_quantization
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import quant_modules

quant_modules.initialize()
model = torchvision.models.resnet50()

# load the calibrated model
state_dict = torch.load("quant_resnet50-entropy-1024.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.cuda()

dummy_input = torch.randn(128, 3, 224, 224, device='cuda')

input_names = [ "actual_input_1" ]
output_names = [ "output1" ]

with pytorch_quantization.enable_onnx_export():
    # enable_onnx_checker needs to be disabled.
    # Per-channel weight quantization needs the axis attribute, i.e. opset 13.
    torch.onnx.export(
        model, dummy_input, "quant_resnet50.onnx", verbose=True, opset_version=13,
        input_names=input_names, output_names=output_names, enable_onnx_checker=False
        )

Note

Note that the axis attribute of QuantizeLinear and DequantizeLinear, which per-channel quantization needs, was added in opset 13.

Quantizing Resnet50

Create a quantized model

Import the necessary python modules:

import os
import sys

import torch
import torch.utils.data
from torch import nn
from tqdm import tqdm

from pytorch_quantization import nn as quant_nn
from pytorch_quantization import calib
from pytorch_quantization.tensor_quant import QuantDescriptor

from torchvision import models

sys.path.append("path to torchvision/references/classification/")
from train import evaluate, train_one_epoch, load_data

Adding quantized modules

The first step is to add quantizer modules to the neural network graph. This package provides a number of quantized layer modules that contain quantizers for inputs and weights, e.g. quant_nn.QuantLinear, which can be used in place of nn.Linear. These quantized layers can be substituted automatically, via monkey-patching, or by manually modifying the model definition.

Automatic layer substitution is done with quant_modules. quant_modules.initialize() should be called before the model is created.

from pytorch_quantization import quant_modules
quant_modules.initialize()

This will apply to all instances of each module. If you do not want all modules to be quantized you should instead substitute the quantized modules manually. Stand-alone quantizers can also be added to the model with quant_nn.TensorQuantizer.

Post training quantization

For efficient inference, we want to select a fixed range for each quantizer. Starting with a pre-trained model, the simplest way to do this is by calibration.

Calibration

We will use histogram based calibration for activations and the default max calibration for weights.

quant_desc_input = QuantDescriptor(calib_method='histogram')
quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

model = models.resnet50(pretrained=True)
model.cuda()

To collect activation histograms we must feed sample data into the model. First, create ImageNet dataloaders as done in the training script. Then, enable calibration in each quantizer and feed training data into the model. 1024 samples (2 batches of 512) should be sufficient to estimate the distribution of activations. Use training data for calibration so that validation also measures generalization of the selected ranges.

data_path = "PATH to imagenet"
batch_size = 512

traindir = os.path.join(data_path, 'train')
valdir = os.path.join(data_path, 'val')
dataset, dataset_test, train_sampler, test_sampler = load_data(traindir, valdir, False, False)

data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size,
    sampler=train_sampler, num_workers=4, pin_memory=True)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=batch_size,
    sampler=test_sampler, num_workers=4, pin_memory=True)

def collect_stats(model, data_loader, num_batches):
    """Feed data to the network and collect statistics"""

    # Enable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    for i, (image, _) in tqdm(enumerate(data_loader), total=num_batches):
        if i >= num_batches:  # stop after exactly num_batches batches
            break
        model(image.cuda())

    # Disable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()


def compute_amax(model, **kwargs):
    # Load calib result
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    module.load_calib_amax(**kwargs)
            print(F"{name:40}: {module}")
    model.cuda()


# It is a bit slow since we collect histograms on CPU
with torch.no_grad():
    collect_stats(model, data_loader, num_batches=2)
    compute_amax(model, method="percentile", percentile=99.99)

After calibration is done, quantizers will have amax set, which represents the absolute maximum input value representable in the quantized space. By default, weight ranges are per channel while activation ranges are per tensor. We can see the condensed amaxes by printing each TensorQuantizer module.

conv1._input_quantizer                  : TensorQuantizer(8bit fake per-tensor amax=2.6400 calibrator=MaxCalibrator(track_amax=False) quant)
conv1._weight_quantizer                 : TensorQuantizer(8bit fake axis=(0) amax=[0.0000, 0.7817](64) calibrator=MaxCalibrator(track_amax=False) quant)
layer1.0.conv1._input_quantizer         : TensorQuantizer(8bit fake per-tensor amax=6.8645 calibrator=MaxCalibrator(track_amax=False) quant)
layer1.0.conv1._weight_quantizer        : TensorQuantizer(8bit fake axis=(0) amax=[0.0000, 0.7266](64) calibrator=MaxCalibrator(track_amax=False) quant)
...

Evaluate the calibrated model

Next we will evaluate the classification accuracy of our post training quantized model on the ImageNet validation set.

criterion = nn.CrossEntropyLoss()
with torch.no_grad():
    evaluate(model, criterion, data_loader_test, device="cuda", print_freq=20)

# Save the model
torch.save(model.state_dict(), "/tmp/quant_resnet50-calibrated.pth")

This should yield 76.1% top-1 accuracy, which is close to the pre-trained model accuracy of 76.2%.

Use different calibration

We can try different calibrations without recollecting the histograms, and see which one gets the best accuracy.

with torch.no_grad():
    compute_amax(model, method="percentile", percentile=99.9)
    evaluate(model, criterion, data_loader_test, device="cuda", print_freq=20)

with torch.no_grad():
    for method in ["mse", "entropy"]:
        print(F"{method} calibration")
        compute_amax(model, method=method)
        evaluate(model, criterion, data_loader_test, device="cuda", print_freq=20)

MSE and entropy should both get over 76%. The 99.9th percentile clips too many values for ResNet50 and gets slightly lower accuracy.

Quantization Aware Training

Optionally, we can fine-tune the calibrated model to improve accuracy further.

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

# Training takes about one and a half hours per epoch on a single V100
train_one_epoch(model, criterion, optimizer, data_loader, "cuda", 0, 100)

# Save the model
torch.save(model.state_dict(), "/tmp/quant_resnet50-finetuned.pth")

After one epoch of fine-tuning, we can achieve over 76.4% top-1 accuracy. Fine-tuning for more epochs with learning rate annealing can improve accuracy further. For example, fine-tuning for 15 epochs with cosine annealing starting with a learning rate of 0.001 can get over 76.7%. It should be noted that the same fine-tuning schedule will improve the accuracy of the unquantized model as well.

Further optimization

For efficient inference on TensorRT, we need to know more details about its runtime optimizations. TensorRT supports fusing a quantized convolution with a residual add. The fused operator has two inputs; let us call them conv-input and residual-input. The fused operator's output precision must match the residual-input precision. When there is another quantizing node after the fused operator, we can insert a pair of quantizing/dequantizing nodes between the residual-input and the Elementwise-Addition node. This way, the quantizing node after the Convolution node is fused with the Convolution node, and the Convolution node is fully quantized with INT8 input and output. Automatic monkey-patching cannot apply this optimization, so the quantizing/dequantizing nodes must be inserted manually.

First, create a copy of resnet.py from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py and modify the constructors to add an explicit bool flag, quantize:

def resnet50(pretrained: bool = False, progress: bool = True, quantize: bool = False, **kwargs: Any) -> ResNet:
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, quantize, **kwargs)


def _resnet(arch: str, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], pretrained: bool, progress: bool,
            quantize: bool, **kwargs: Any) -> ResNet:
    model = ResNet(block, layers, quantize, **kwargs)
    # other code...


class ResNet(nn.Module):
    def __init__(self,
                 block: Type[Union[BasicBlock, Bottleneck]],
                 layers: List[int],
                 quantize: bool = False,
                 num_classes: int = 1000,
                 zero_init_residual: bool = False,
                 groups: int = 1,
                 width_per_group: int = 64,
                 replace_stride_with_dilation: Optional[List[bool]] = None,
                 norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
        super(ResNet, self).__init__()
        self._quantize = quantize
        # other code...

When the self._quantize flag is set to True, we need to replace all nn.Conv2d layers with quant_nn.QuantConv2d.

def conv3x3(in_planes: int,
            out_planes: int,
            stride: int = 1,
            groups: int = 1,
            dilation: int = 1,
            quantize: bool = False) -> nn.Conv2d:
    """3x3 convolution with padding"""
    if quantize:
        return quant_nn.QuantConv2d(in_planes,
                                    out_planes,
                                    kernel_size=3,
                                    stride=stride,
                                    padding=dilation,
                                    groups=groups,
                                    bias=False,
                                    dilation=dilation)
    else:
        return nn.Conv2d(in_planes,
                         out_planes,
                         kernel_size=3,
                         stride=stride,
                         padding=dilation,
                         groups=groups,
                         bias=False,
                         dilation=dilation)


def conv1x1(in_planes: int, out_planes: int, stride: int = 1, quantize: bool = False) -> nn.Conv2d:
    """1x1 convolution"""
    if quantize:
        return quant_nn.QuantConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
    else:
        return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

The residual add can be found in both BasicBlock and Bottleneck. First, we need to declare the quantization node in the __init__ function.

def __init__(self,
             inplanes: int,
             planes: int,
             stride: int = 1,
             downsample: Optional[nn.Module] = None,
             groups: int = 1,
             base_width: int = 64,
             dilation: int = 1,
             norm_layer: Optional[Callable[..., nn.Module]] = None,
             quantize: bool = False) -> None:
    # other code...
    self._quantize = quantize
    if self._quantize:
        self.residual_quantizer = quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input)

Finally, we need to patch the forward function in both BasicBlock and Bottleneck to insert the extra quantization/dequantization nodes.

def forward(self, x: Tensor) -> Tensor:
    # other code...
    if self._quantize:
        out += self.residual_quantizer(identity)
    else:
        out += identity
    out = self.relu(out)

    return out

The final ResNet code with the residual quantized can be found at https://github.com/NVIDIA/TensorRT/blob/master/tools/pytorch-quantization/examples/torchvision/models/classification/resnet.py

Creating Custom Quantized Modules

There are several quantized modules provided by the quantization tool as follows:

  • QuantConv1d, QuantConv2d, QuantConv3d, QuantConvTranspose1d, QuantConvTranspose2d, QuantConvTranspose3d

  • QuantLinear

  • QuantAvgPool1d, QuantAvgPool2d, QuantAvgPool3d, QuantMaxPool1d, QuantMaxPool2d, QuantMaxPool3d

To quantize a module, we need to quantize its input and, if present, its weights. There are three major use cases:

  1. Create quantized wrapper for modules that have only inputs

  2. Create quantized wrapper for modules that have inputs as well as weights.

  3. Directly add the TensorQuantizer module to the inputs of an operation in the model graph.

The first two methods are useful when the original modules (nodes in the graph) need to be replaced automatically with their quantized versions. The third method is useful when quantization must be added manually at very specific places in the model graph (more manual work, more control).

Let’s see each use-case with examples below.

Quantizing Modules With Only Inputs

A suitable example would be quantizing the pooling module variants.

Essentially, we need to provide a wrapper that takes the original module and adds the TensorQuantizer module around it, so that the input is first quantized and then fed into the original module.

  • Create the wrapper by subclassing the original module (pooling.MaxPool2d) along with the utilities module (_utils.QuantInputMixin).

class QuantMaxPool2d(pooling.MaxPool2d, _utils.QuantInputMixin):

  • The __init__ function calls the original module’s __init__ function with the corresponding arguments. There is just one additional argument, **kwargs, which contains the quantization configuration. The _utils.pop_quant_desc_in_kwargs function extracts this configuration from kwargs, or returns the class default if none is given. Finally, the init_quantizer method (provided by QuantInputMixin) initializes the TensorQuantizer module that quantizes the inputs.

def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
             return_indices=False, ceil_mode=False, **kwargs):
    super(QuantMaxPool2d, self).__init__(kernel_size, stride, padding, dilation,
                                         return_indices, ceil_mode)
    quant_desc_input = _utils.pop_quant_desc_in_kwargs(self.__class__, input_only=True, **kwargs)
    self.init_quantizer(quant_desc_input)

  • After initialization, the wrapper module needs a forward function that quantizes the inputs with the _input_quantizer created in __init__ and then passes the quantized inputs to the base module via a super() call.

def forward(self, input):
    quant_input = self._input_quantizer(input)
    return super(QuantMaxPool2d, self).forward(quant_input)

  • Finally, we need to define a getter property for _input_quantizer. It can be used, for example, to disable quantization for a particular module with module.input_quantizer.disable(), which is helpful when experimenting with different per-layer quantization configurations.

@property
def input_quantizer(self):
    return self._input_quantizer

A complete quantized pooling module looks like the following:

class QuantMaxPool2d(pooling.MaxPool2d, _utils.QuantInputMixin):
    """Quantized 2D maxpool"""
    def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
                return_indices=False, ceil_mode=False, **kwargs):
        super(QuantMaxPool2d, self).__init__(kernel_size, stride, padding, dilation,
                                            return_indices, ceil_mode)
        quant_desc_input = _utils.pop_quant_desc_in_kwargs(self.__class__, input_only=True, **kwargs)
        self.init_quantizer(quant_desc_input)

    def forward(self, input):
        quant_input = self._input_quantizer(input)
        return super(QuantMaxPool2d, self).forward(quant_input)

    @property
    def input_quantizer(self):
        return self._input_quantizer

Quantizing Modules With Weights and Inputs

We give an example of quantizing the torch.nn.Linear module. The only additional change from the previous pooling example is that we need to accommodate the quantization of weights in the Linear module.

  • We create the quantized linear module as follows:

class QuantLinear(nn.Linear, _utils.QuantMixin):

  • In the __init__ function, we first use the pop_quant_desc_in_kwargs function to extract the quantization descriptors for both inputs and weights. Then, we initialize the TensorQuantizer modules for both inputs and weights using these descriptors.

def __init__(self, in_features, out_features, bias=True, **kwargs):
        super(QuantLinear, self).__init__(in_features, out_features, bias)
        quant_desc_input, quant_desc_weight = _utils.pop_quant_desc_in_kwargs(self.__class__, **kwargs)

        self.init_quantizer(quant_desc_input, quant_desc_weight)

  • Also, override the forward function to pass the inputs and weights through _input_quantizer and _weight_quantizer, respectively, before passing the quantized arguments to the actual F.linear call. This step adds the actual input/weight TensorQuantizers to the module and eventually to the model.

def forward(self, input):
    quant_input = self._input_quantizer(input)
    quant_weight = self._weight_quantizer(self.weight)

    output = F.linear(quant_input, quant_weight, bias=self.bias)

    return output

  • Similar to the pooling module, we add getter properties for the TensorQuantizer modules associated with inputs/weights. These can be used, for example, to disable quantization by calling module_obj.weight_quantizer.disable().

@property
def input_quantizer(self):
    return self._input_quantizer

@property
def weight_quantizer(self):
    return self._weight_quantizer

  • With all of the above changes, the quantized Linear module looks like the following:

class QuantLinear(nn.Linear, _utils.QuantMixin):

    def __init__(self, in_features, out_features, bias=True, **kwargs):
        super(QuantLinear, self).__init__(in_features, out_features, bias)
        quant_desc_input, quant_desc_weight = _utils.pop_quant_desc_in_kwargs(self.__class__, **kwargs)

        self.init_quantizer(quant_desc_input, quant_desc_weight)

    def forward(self, input):
        quant_input = self._input_quantizer(input)
        quant_weight = self._weight_quantizer(self.weight)

        output = F.linear(quant_input, quant_weight, bias=self.bias)

        return output

    @property
    def input_quantizer(self):
        return self._input_quantizer

    @property
    def weight_quantizer(self):
        return self._weight_quantizer

Directly Quantizing Inputs In Graph

It is also possible to directly quantize graph inputs without creating wrappers as explained above.

Here’s an example:

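import torch
import torch.nn.functional as F

from pytorch_quantization import nn as quant_nn
from pytorch_quantization.nn import TensorQuantizer

test_input = torch.randn(1, 5, 5, 5, dtype=torch.double)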

quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input)

quant_input = quantizer(test_input)

out = F.adaptive_avg_pool2d(quant_input, 3)

Assume that there is an F.adaptive_avg_pool2d operation in the graph and we’d like to quantize this operation. In the example above, we use TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) to define a quantizer, use it to quantize test_input, and then feed this quantized input to the F.adaptive_avg_pool2d operation. Note that this quantizer is the same as the ones we used earlier when creating quantized versions of torch’s modules.

pytorch_quantization.calib

pytorch_quantization.calib provides Calibrator classes that collect data statistics and determine quantization parameters.

MaxCalibrator

class pytorch_quantization.calib.MaxCalibrator(num_bits, axis, unsigned, track_amax=False)

Max calibrator, tracks the maximum value globally

Parameters:
  • calib_desc – A MaxCalibDescriptor.

  • num_bits – An integer. Number of bits of quantization.

  • axis – A tuple. see QuantDescriptor.

  • unsigned – A boolean. If True, use unsigned quantization.

Readonly Properties:

amaxs: A list of amax values (only tracked when track_amax=True). Numpy arrays are saved since they are likely to be used for plotting.

collect(x)

Tracks the absolute max of all tensors

Parameters:

x – A tensor

Raises:

RuntimeError – If amax shape changes

compute_amax()

Return the absolute max of all tensors collected

reset()

Reset the collected absolute max
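The collect / compute_amax / reset cycle can be sketched with a hypothetical plain-Python class, TinyMaxCalibrator, for the per-tensor case. It is not the library class: it tracks a single running absolute max, optionally records one amax per collect() call (in the spirit of track_amax), and omits the per-axis handling and the shape check that raises RuntimeError.

```python
class TinyMaxCalibrator:
    """Hypothetical per-tensor max calibrator (plain Python, not the library class)."""

    def __init__(self, track_amax=False):
        self._track_amax = track_amax
        self._calib_amax = None
        self.amaxs = [] if track_amax else None

    def collect(self, values):
        """Track the absolute max of everything seen so far."""
        batch_amax = max(abs(v) for v in values)
        if self._track_amax:
            self.amaxs.append(batch_amax)  # one entry per collect() call
        if self._calib_amax is None or batch_amax > self._calib_amax:
            self._calib_amax = batch_amax

    def compute_amax(self):
        """Return the absolute max of all values collected."""
        return self._calib_amax

    def reset(self):
        """Forget everything collected so far."""
        self._calib_amax = None
        if self._track_amax:
            self.amaxs = []


calib = TinyMaxCalibrator(track_amax=True)
calib.collect([0.1, -0.8])
calib.collect([0.5, 0.3])
print(calib.compute_amax())  # 0.8
print(calib.amaxs)           # [0.8, 0.5]
```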

HistogramCalibrator

class pytorch_quantization.calib.HistogramCalibrator(num_bits, axis, unsigned, num_bins=2048, grow_method=None, skip_zeros=False, torch_hist=False)

Unified histogram calibrator

The histogram is collected only once; compute_amax() then performs entropy, percentile, or mse calibration based on its arguments.

Parameters:
  • num_bits – An integer. Number of bits of quantization.

  • axis – A tuple. see QuantDescriptor.

  • unsigned – A boolean. If True, use unsigned quantization.

  • num_bins – An integer. Number of histogram bins. Default 2048.

  • grow_method – A string. DEPRECATED. default None.

  • skip_zeros – A boolean. If True, skips zeros when collecting data for histogram. Default False.

  • torch_hist – A boolean. If True, collect the histogram with torch.histc instead of np.histogram. If the input tensor is on GPU, histc also runs on GPU. Default False.

collect(x)

Collect histogram

compute_amax(method: str, *, stride: int = 1, start_bin: int = 128, percentile: float = 99.99)

Compute the amax from the collected histogram

Parameters:

method – A string. One of [‘entropy’, ‘mse’, ‘percentile’]

Keyword Arguments:
  • stride – An integer. Default 1

  • start_bin – An integer. Default 128

  • percentile – A float number in [0, 100]. Default 99.99.

Returns:

amax – a tensor

reset()

Reset the collected histogram
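The idea of collecting a histogram once and then deriving amax from it can be sketched with a hypothetical plain-Python class, TinyHistogramCalibrator. Unlike the library class, it assumes a fixed histogram range chosen up front and implements only the percentile method: amax is the upper edge of the first bin at which the cumulative count reaches the requested percentile.

```python
class TinyHistogramCalibrator:
    """Hypothetical, simplified histogram calibrator (not the library class)."""

    def __init__(self, range_max, num_bins=2048):
        self.range_max = range_max
        self.num_bins = num_bins
        self.counts = [0] * num_bins

    def collect(self, values):
        """Add |x| of each value to equal-width bins over [0, range_max]."""
        width = self.range_max / self.num_bins
        for v in values:
            idx = min(int(abs(v) / width), self.num_bins - 1)  # clamp to last bin
            self.counts[idx] += 1

    def compute_amax(self, percentile=99.99):
        """Upper edge of the bin where the cumulative count reaches percentile."""
        total = sum(self.counts)
        target = percentile / 100 * total
        width = self.range_max / self.num_bins
        running = 0
        for idx, count in enumerate(self.counts):
            running += count
            if running >= target:
                return (idx + 1) * width
        return self.range_max

    def reset(self):
        """Clear the collected histogram."""
        self.counts = [0] * self.num_bins


calib = TinyHistogramCalibrator(range_max=10.0, num_bins=10)
calib.collect([0.5] * 99 + [9.5])
print(calib.compute_amax(percentile=99))   # 1.0
print(calib.compute_amax(percentile=100))  # 10.0
```

Because the histogram is kept, several percentiles can be tried without collecting data again.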

pytorch_quantization.nn

TensorQuantizer

class pytorch_quantization.nn.TensorQuantizer(quant_desc=<pytorch_quantization.tensor_quant.ScaledQuantDescriptor object>, disabled=False, if_quant=True, if_clip=False, if_calib=False)

Tensor quantizer module

This module uses the tensor_quant or fake_tensor_quant function to quantize a tensor. It also wraps the variables and moving statistics we’d want when training a quantized network.

Experimental features:

The clip stage learns the range before quantization is enabled. The calib stage runs calibration.

Parameters:
  • quant_desc – An instance of QuantDescriptor.

  • disabled – A boolean. If True, bypass the whole module and return the input. Default False.

  • if_quant – A boolean. If True, run main quantization body. Default True.

  • if_clip – A boolean. If True, clip before quantization and learn amax. Default False.

  • if_calib – A boolean. If True, run calibration. Not implemented yet. Settings of calibration will probably go to QuantDescriptor.

Readonly Properties:
  • axis:

  • fake_quant:

  • scale:

  • step_size:

Mutable Properties:
  • num_bits:

  • unsigned:

  • amax:

__init__(quant_desc=<pytorch_quantization.tensor_quant.ScaledQuantDescriptor object>, disabled=False, if_quant=True, if_clip=False, if_calib=False)

Initialize quantizer and set up required variables

disable()

Bypass the module

disable_clip()

Disable clip stage

enable_clip()

Enable clip stage

forward(inputs)

Apply tensor_quant function to inputs

Parameters:

inputs – A Tensor of type float32.

Returns:

outputs – A Tensor of type output_dtype

init_learn_amax()

Initialize learned amax from fixed amax

load_calib_amax(*args, **kwargs)

Load amax from calibrator.

Updates the amax buffer with value computed by the calibrator, creating it if necessary. *args and **kwargs are directly passed to compute_amax, except “strict” in kwargs. Refer to compute_amax for more details.
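The interplay between the disabled, if_quant, and if_calib flags, as used in the calibration loops earlier in this document, can be modelled with the hypothetical plain-Python class below. Method names mirror those used in this document, but the logic is a simplified per-tensor sketch under those assumptions, not the library's forward().

```python
class TinyTensorQuantizer:
    """Hypothetical stand-in showing how TensorQuantizer's flags interact."""

    def __init__(self, amax=None, num_bits=8):
        self.amax = amax
        self.num_bits = num_bits
        self._disabled = False
        self._if_quant = True
        self._if_calib = False
        self._calib_amax = None

    def disable(self):
        self._disabled = True

    def enable(self):
        self._disabled = False

    def enable_quant(self):
        self._if_quant = True

    def disable_quant(self):
        self._if_quant = False

    def enable_calib(self):
        self._if_calib = True

    def disable_calib(self):
        self._if_calib = False

    def load_calib_amax(self):
        # Copy the statistic gathered during calibration into amax.
        self.amax = self._calib_amax

    def forward(self, values):
        if self._disabled:
            return list(values)  # bypass: return the input unchanged
        if self._if_calib:
            # Calibration stage: record the running absolute max.
            batch_amax = max(abs(v) for v in values)
            if self._calib_amax is None or batch_amax > self._calib_amax:
                self._calib_amax = batch_amax
        if self._if_quant:
            bound = 2 ** (self.num_bits - 1) - 1
            scale = bound / self.amax
            return [max(-bound, min(bound, round(v * scale))) / scale for v in values]
        return list(values)


q = TinyTensorQuantizer()
q.disable_quant()
q.enable_calib()
q.forward([0.2, -1.5])  # full-precision pass while collecting statistics
q.forward([0.7])
q.load_calib_amax()
q.disable_calib()
q.enable_quant()
print(q.amax)  # 1.5
```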

Quantized Modules

_QuantConvNd

class pytorch_quantization.nn.modules.quant_conv._QuantConvNd(in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode, quant_desc_input, quant_desc_weight)

Base class of quantized Conv modules, inherited from _ConvNd

Descriptions of the original arguments can be found in torch.nn.modules.conv

Parameters:
  • quant_desc_input – An instance of QuantDescriptor. Quantization descriptor of input.

  • quant_desc_weight – An instance of QuantDescriptor. Quantization descriptor of weight.

Raises:

ValueError – If unsupported arguments are passed in.

Readonly properties:
  • input_quantizer:

  • weight_quantizer:

Static methods:
  • set_default_quant_desc_input: Set default_quant_desc_input

  • set_default_quant_desc_weight: Set default_quant_desc_weight

QuantConv1d

class pytorch_quantization.nn.QuantConv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', **kwargs)

Quantized 1D Conv

QuantConv2d

class pytorch_quantization.nn.QuantConv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', **kwargs)

Quantized 2D Conv

QuantConv3d

class pytorch_quantization.nn.QuantConv3d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', **kwargs)

Quantized 3D Conv

QuantConvTranspose1d

class pytorch_quantization.nn.QuantConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros', **kwargs)

Quantized ConvTranspose1d

QuantConvTranspose2d

class pytorch_quantization.nn.QuantConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros', **kwargs)

Quantized ConvTranspose2d

QuantConvTranspose3d

class pytorch_quantization.nn.QuantConvTranspose3d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros', **kwargs)

Quantized ConvTranspose3d

QuantLinear

class pytorch_quantization.nn.QuantLinear(in_features, out_features, bias=True, **kwargs)

Quantized version of nn.Linear

Apply quantized linear to the incoming data, y = dequant(quant(x)quant(A)^T + b).

The module name “Linear” is kept instead of “QuantLinear” so that it can be easily dropped into a preexisting model and load pretrained weights. An alias “QuantLinear” is also defined. The base code is a copy of nn.Linear; see the detailed comments on the original arguments there.

Quantization descriptors are passed in through kwargs. If they are not present, default_quant_desc_input and default_quant_desc_weight are used.

Keyword Arguments:
  • quant_desc_input – An instance of QuantDescriptor. Quantization descriptor of input.

  • quant_desc_weight – An instance of QuantDescriptor. Quantization descriptor of weight.

Raises:
  • ValueError – If unsupported arguments are passed in.

  • KeyError – If unsupported kwargs are passed in.

Readonly properties:
  • input_quantizer:

  • weight_quantizer:

Static methods:
  • set_default_quant_desc_input: Set default_quant_desc_input

  • set_default_quant_desc_weight: Set default_quant_desc_weight
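The QuantLinear formula y = dequant(quant(x)quant(A)^T + b) is what makes integer inference possible: with per-tensor scales, the dot products can run on integers and be rescaled once at the end. The pure-Python sketch below checks that this integer path agrees with computing on fake-quantized floats. It assumes symmetric 8-bit per-tensor scales for both x and A and an unquantized float bias; all function names are hypothetical.

```python
def quantize(values, amax, num_bits=8):
    """Symmetric narrow-range quantization: returns (integer values, scale)."""
    bound = 2 ** (num_bits - 1) - 1
    scale = bound / amax
    return [max(-bound, min(bound, round(v * scale))) for v in values], scale


def _quantize_operands(x, weight):
    # Per-tensor scales for the input vector and for the whole weight matrix.
    qx, sx = quantize(x, max(abs(v) for v in x))
    flat = [w for row in weight for w in row]
    qflat, sw = quantize(flat, max(abs(w) for w in flat))
    n = len(x)
    qw = [qflat[i * n:(i + 1) * n] for i in range(len(weight))]
    return qx, sx, qw, sw


def quant_linear_integer(x, weight, bias):
    """Integer dot products, rescaled once by 1 / (sx * sw); bias added in float."""
    qx, sx, qw, sw = _quantize_operands(x, weight)
    return [sum(a * b for a, b in zip(qx, row)) / (sx * sw) + b_i
            for row, b_i in zip(qw, bias)]


def quant_linear_fake(x, weight, bias):
    """The same result computed on fake-quantized (dequantized) float operands."""
    qx, sx, qw, sw = _quantize_operands(x, weight)
    fx = [q / sx for q in qx]
    return [sum(a * (q / sw) for a, q in zip(fx, row)) + b_i
            for row, b_i in zip(qw, bias)]
```

Per-channel weight scales (one amax per output row, as with axis=(0)) fit the same pattern, with sw varying per row.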

QuantMaxPool1d

class pytorch_quantization.nn.QuantMaxPool1d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False, **kwargs)

Quantized 1D maxpool

QuantMaxPool2d

class pytorch_quantization.nn.QuantMaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False, **kwargs)

Quantized 2D maxpool

QuantMaxPool3d

class pytorch_quantization.nn.QuantMaxPool3d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False, **kwargs)

Quantized 3D maxpool

QuantAvgPool1d

class pytorch_quantization.nn.QuantAvgPool1d(kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, **kwargs)

Quantized 1D average pool

QuantAvgPool2d

class pytorch_quantization.nn.QuantAvgPool2d(kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None, **kwargs)

Quantized 2D average pool

QuantAvgPool3d

class pytorch_quantization.nn.QuantAvgPool3d(kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None, **kwargs)

Quantized 3D average pool

QuantAdaptiveAvgPool1d

class pytorch_quantization.nn.QuantAdaptiveAvgPool1d(output_size, **kwargs)

Quantized 1D adaptive average pool

QuantAdaptiveAvgPool2d

class pytorch_quantization.nn.QuantAdaptiveAvgPool2d(output_size, **kwargs)

Quantized 2D adaptive average pool

QuantAdaptiveAvgPool3d

class pytorch_quantization.nn.QuantAdaptiveAvgPool3d(output_size, **kwargs)

Quantized 3D adaptive average pool

Clip

class pytorch_quantization.nn.Clip(clip_value_min, clip_value_max, learn_min=False, learn_max=False)

Clip tensor

Parameters:
  • clip_value_min – A number or tensor giving the lower bound to clip to.

  • clip_value_max – A number or tensor giving the upper bound to clip to.

  • learn_min – A boolean. If True, learn the min; clip_value_min is used to initialize it. Default False

  • learn_max – A boolean. Same as learn_min, but for max.

Raises:

ValueError

QuantLSTM

class pytorch_quantization.nn.QuantLSTM(*args, **kwargs)

Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence.

QuantLSTMCell

class pytorch_quantization.nn.QuantLSTMCell(input_size, hidden_size, bias=True, **kwargs)

A long short-term memory (LSTM) cell.

pytorch_quantization.nn.functional

Supporting functions

ClipFunction

class pytorch_quantization.nn.functional.ClipFunction(*args, **kwargs)

A universal tensor clip function

PyTorch’s clamp() only supports a scalar range and doesn’t support broadcasting. This implementation uses min/max, which is more general. The gradient is defined according to IBM’s PACT paper https://arxiv.org/abs/1805.06085, which is also the behavior of TensorFlow’s clip_by_value().

clip is alias of ClipFunction.apply
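The forward and backward behavior described for ClipFunction can be sketched in pure Python for scalar bounds. This is an illustration, not the library's code. It assumes inputs exactly on a bound count as inside the range, and it does not show the broadcasting of tensor-valued bounds that motivates ClipFunction. clip_with_grads is a hypothetical name.

```python
def clip_with_grads(inputs, clip_min, clip_max, grad_output):
    """Forward: elementwise clip. Backward: PACT-style gradients (sketch)."""
    outputs = [min(max(v, clip_min), clip_max) for v in inputs]
    # The gradient passes straight through where the input lies inside the range...
    grad_input = [g if clip_min <= v <= clip_max else 0.0
                  for v, g in zip(inputs, grad_output)]
    # ...and accumulates into a bound wherever that bound was active.
    grad_min = sum(g for v, g in zip(inputs, grad_output) if v < clip_min)
    grad_max = sum(g for v, g in zip(inputs, grad_output) if v > clip_max)
    return outputs, grad_input, grad_min, grad_max
```

Accumulating the gradient of clipped elements into the bound is how a learnable bound (as in Clip with learn_min or learn_max) can receive a training signal in PACT-style training.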

pytorch_quantization.optim.helper

Helper functions for quant optimizer/trainer

pytorch_quantization.optim.helper.freeze_parameters(model, patterns)

Set requires_grad to False for parameters whose names match any of the patterns

Parameters:
  • model – A Module

  • patterns – A list of strings that will be used to match parameter names. If parameter name contains any pattern, it will be frozen.

pytorch_quantization.optim.helper.group_parameters(model, patterns_list, lrs=None, momentums=None, weight_decays=None)

Group parameters to use per-parameter options in an optimizer

Returns a list of dicts in the format PyTorch optimizers expect; see https://pytorch.org/docs/stable/optim.html#per-parameter-options for more details.

Example

[
    {'params': model.base.parameters()},
    {'params': model.classifier.parameters(), 'lr': 1e-3}
]

Parameters are grouped by the first level of patterns_list. For example, patterns_list=[['conv1', 'conv2'], ['conv3']] returns 2 groups: one with parameters whose names contain conv1 or conv2, and the other with parameters whose names contain conv3.

If lrs, momentums, or weight_decays are supplied, the corresponding lr, momentum, or weight_decay is added to each group as well.

Parameters:
  • model – A module

  • patterns_list – A list of lists of strings. WARNING: patterns must be mutually EXCLUSIVE; the function doesn’t check this.

  • lrs – A list of floats with the same length as patterns_list, or None.

  • momentums – A list of floats with the same length as patterns_list, or None.

  • weight_decays – A list of floats with the same length as patterns_list, or None.

Returns:

param_group – A list of dict
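The grouping described above can be sketched in pure Python using substring matching, as stated for patterns. This version works on (name, parameter) pairs such as those returned by model.named_parameters(). group_parameters_sketch is a hypothetical name, not the library function. The returned dicts have the shape that torch.optim optimizers accept as parameter groups.

```python
def group_parameters_sketch(named_parameters, patterns_list,
                            lrs=None, momentums=None, weight_decays=None):
    """Build optimizer parameter groups, one per entry in patterns_list."""
    named_parameters = list(named_parameters)
    groups = []
    for i, patterns in enumerate(patterns_list):
        # A parameter joins this group if its name contains any of the patterns.
        group = {"params": [param for name, param in named_parameters
                            if any(pattern in name for pattern in patterns)]}
        # Optional per-group hyperparameters, indexed like patterns_list.
        if lrs is not None:
            group["lr"] = lrs[i]
        if momentums is not None:
            group["momentum"] = momentums[i]
        if weight_decays is not None:
            group["weight_decay"] = weight_decays[i]
        groups.append(group)
    return groups
```

With real tensors, the result could be passed as torch.optim.SGD(groups, lr=0.1), where the constructor's lr applies to any group that does not set its own.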

pytorch_quantization.optim.helper.match_parameters(model, patterns)

Returns a generator over module parameters whose names match any of the patterns

It is useful for grouping parameters and applying different functions to different groups. This function provides an easy way to group them.

Parameters:
  • model – A Module

  • patterns – A list of strings that will be used to match parameter names. If a parameter name contains any pattern, that parameter will be yielded.

Yields:

param – Module parameters
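A minimal generator with the matching rule described above, written in pure Python over (name, parameter) pairs like those from model.named_parameters(). match_parameters_sketch is a hypothetical name, not the library function.

```python
def match_parameters_sketch(named_parameters, patterns):
    """Yield each parameter whose name contains at least one of the patterns."""
    for name, param in named_parameters:
        if any(pattern in name for pattern in patterns):
            yield param
```

Because it is a generator, it can feed an optimizer or a loop directly; for instance, setting requires_grad = False on every yielded parameter matches the behavior documented for freeze_parameters.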

pytorch_quantization.optim.helper.quant_weight_inplace(model)

Quantize weights in place

Searches for quantized modules, including QuantConvNd and QuantLinear, and quantizes their weights in place using each module’s weight_quantizer.

Most publications on quantization-aware training use STE by default. STE is only an approximation of the derivative of the non-differentiable quantization function; it works to some extent but is by no means the F=ma of the problem. In-place quantization can be used to implement relax-and-round, a common method in discrete optimization and integer programming.
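The "round" half of relax-and-round amounts to overwriting each weight with its quantized value, so that later steps start from points on the quantization grid. The pure-Python sketch below shows only that rounding step for a per-tensor, symmetric 8-bit grid (an assumption); fake_quant_inplace is a hypothetical name. It also shows that quantizing in place is idempotent.

```python
def fake_quant_inplace(weights, amax, num_bits=8):
    """Overwrite each weight with its fake-quantized value (symmetric, narrow range)."""
    bound = 2 ** (num_bits - 1) - 1
    scale = bound / amax
    for i, w in enumerate(weights):
        # Round onto the integer grid, clamp, then map back to float.
        weights[i] = max(-bound, min(bound, round(w * scale))) / scale
```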

pytorch_quantization.tensor_quant

Basic tensor quantization functions

QuantDescriptor

pytorch_quantization.tensor_quant.QuantDescriptor

alias of ScaledQuantDescriptor

ScaledQuantDescriptor

class pytorch_quantization.tensor_quant.ScaledQuantDescriptor(num_bits=8, name=None, **kwargs)

Supportive descriptor of quantization

Describes how a tensor should be quantized. A QuantDescriptor and a tensor define a quantized tensor.

Parameters:
  • num_bits – An integer. Number of bits of quantization. It is used to calculate scaling factor. Default 8.

  • name – An optional name for the descriptor. Default None.

Keyword Arguments:
  • fake_quant – A boolean. If True, use fake quantization mode. Default True.

  • axis – None, int or tuple of int. The axes that each get their own max for computing the scaling factor. If None (the default), a per-tensor scale is used. Must be in the range [-rank(input_tensor), rank(input_tensor)). E.g. for a KCRS weight tensor, axis=(0) yields per-channel scaling. Default None.

  • amax – A float or list/ndarray of floats giving a user-specified absolute max range. If supplied, axis is ignored and this value is used to quantize. If learn_amax is True, it is used to initialize the learnable amax. Default None.

  • learn_amax – A boolean. If True, learn amax. Default False.

  • scale_amax – A float. If supplied, amax is multiplied by scale_amax. Default None. Useful for quick experiments.

  • calib_method – A string. One of [“max”, “histogram”], indicating which calibration to use. Apart from the simple max calibration, all methods are histogram based. Default “max”.

  • unsigned – A boolean. If True, use unsigned. Default False.
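For a KCRS weight, axis=(0) keeps axis 0 and reduces over all remaining axes, so each output channel gets its own amax. This is the opposite of the axis argument to max(), which names the axes to reduce. A pure-Python sketch on nested lists follows; the helper names are hypothetical.

```python
def _flatten(nested):
    """Yield every scalar in an arbitrarily nested list."""
    if isinstance(nested, (list, tuple)):
        for item in nested:
            yield from _flatten(item)
    else:
        yield nested


def amax_for_axis0(weight):
    """axis=(0): one amax per entry of the first axis (per output channel)."""
    return [max(abs(v) for v in _flatten(channel)) for channel in weight]


def amax_per_tensor(weight):
    """axis=None: a single amax for the whole tensor."""
    return max(abs(v) for v in _flatten(weight))
```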

Raises:

TypeError – If unsupported type is passed in.

Read-only properties:
  • fake_quant:

  • name:

  • learn_amax:

  • scale_amax:

  • axis:

  • calib_method:

  • num_bits:

  • amax:

  • unsigned:

dict()

Serialize to dict

The built-in __dict__ attribute returns all attributes, including those that still have their default values and those with the protected prefix “_”. This method returns only the attributes whose values differ from the defaults, with keys that lack the “_” prefix. Constructing an instance from the dict returned by this method should produce exactly the same instance.
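The round-trip property described above can be illustrated with a small hypothetical class. DescriptorSketch and its default table are assumptions loosely modeled on the keyword arguments listed for ScaledQuantDescriptor. This is not the library's implementation.

```python
class DescriptorSketch:
    """Hypothetical descriptor illustrating 'serialize only non-default values'."""

    _DEFAULTS = {"num_bits": 8, "name": None, "fake_quant": True,
                 "axis": None, "amax": None, "unsigned": False}

    def __init__(self, **kwargs):
        unknown = set(kwargs) - set(self._DEFAULTS)
        if unknown:
            raise TypeError(f"unsupported arguments: {sorted(unknown)}")
        # Store every setting under a protected "_" name, as __dict__ would show it.
        for key, default in self._DEFAULTS.items():
            setattr(self, "_" + key, kwargs.get(key, default))

    def dict(self):
        # Public key names only, and only values that differ from the defaults.
        return {key: getattr(self, "_" + key)
                for key, default in self._DEFAULTS.items()
                if getattr(self, "_" + key) != default}

    def __eq__(self, other):
        return isinstance(other, DescriptorSketch) and self.dict() == other.dict()
```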

classmethod from_yaml(yaml_str)

Create descriptor from yaml str

to_yaml()

Create a YAML serialization. Some attributes, including amax and axis, need special treatment to have a human-readable form.

TensorQuantFunction

class pytorch_quantization.tensor_quant.TensorQuantFunction(*args, **kwargs)

A universal tensor quantization function

Takes an input tensor and outputs a quantized tensor. The granularity of the scale can be inferred from the shape of amax. output_dtype indicates whether the quantized value is stored as integer or float. Storing it as float is useful because the PyTorch function that consumes the quantized value may not accept integer input, e.g. Conv2D.

It uses 2^num_bits - 1 values instead of 2^num_bits; e.g., for num_bits=8, it uses [-127, 127] instead of [-128, 127].
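The integer grid implied by num_bits, unsigned and narrow_range (the latter two documented under forward()) can be written down directly. The pure-Python sketch below uses a hypothetical function name.

```python
def int_range(num_bits=8, unsigned=False, narrow_range=True):
    """Return (min_bound, max_bound) of the integer quantization grid."""
    if unsigned:
        # Unsigned uses the full range [0, 2^num_bits - 1].
        return 0, 2 ** num_bits - 1
    max_bound = 2 ** (num_bits - 1) - 1
    # Narrow range drops the most negative value to keep the grid symmetric.
    min_bound = -max_bound if narrow_range else -max_bound - 1
    return min_bound, max_bound
```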

static backward(ctx, grad_outputs, grad_scale)

Implements straight through estimation with clipping. For -amax <= input <= amax the gradient passes straight through, otherwise the gradient is zero.

Parameters:
  • ctx – A Context object with saved tensors from forward.

  • grad_outputs – A tensor of gradient of outputs.

  • grad_scale – A tensor of gradient of scale.

Returns:

grad_inputs – A tensor of gradient.
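The clipped straight-through estimator described above reduces to an elementwise mask. A pure-Python sketch with a scalar amax follows. Inputs exactly at ±amax count as inside the range, as the inequality states. grad_scale is ignored here, and ste_backward is a hypothetical name.

```python
def ste_backward(inputs, amax, grad_outputs):
    """Pass gradients through where -amax <= x <= amax, zero them elsewhere."""
    return [g if -amax <= x <= amax else 0.0 for x, g in zip(inputs, grad_outputs)]
```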

static forward(ctx, inputs, amax, num_bits=8, unsigned=False, narrow_range=True)

Following the TensorFlow convention, the max value is passed in and used to decide the scale, instead of passing the scale directly, even though passing the scale directly may be more natural.

Parameters:
  • ctx – A Context object to store tensors for backward.

  • inputs – A Tensor of type float32.

  • amax – A Tensor of type float32. Inputs are quantized within the range [-amax, amax]; amax is broadcast to the shape of the inputs tensor.

  • num_bits – An integer used to calculate the scaling factor, scale = (2^(num_bits-1) - 1) / amax. Effectively, it indicates how many integer bits are used to represent the value. Default 8.

  • output_dtype – A type of Tensor. torch.int32 or torch.float32.

  • unsigned – A boolean. Use unsigned integer range. E.g. [0, 255] for num_bits=8. Default False.

  • narrow_range – A boolean. Use symmetric integer range for signed quantization E.g. [-127,127] instead of [-128,127] for num_bits=8. Default True.

Returns:

outputs – A Tensor of type output_dtype. scale – A Tensor of type float32. outputs / scale dequantizes the outputs tensor.

Raises:

ValueError
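Putting the pieces together, forward() can be sketched in pure Python. The sketch follows the documented formulas: scale = max_bound / amax, then round, then clamp to the integer range. It is a hypothetical illustration on lists, not the library's implementation. Raising ValueError for negative inputs under unsigned=True is an assumption about what the Raises entry covers. Python's round() rounds halves to even, as torch.round does.

```python
def tensor_quant_sketch(inputs, amax, num_bits=8, unsigned=False, narrow_range=True):
    """Return (quantized values stored as floats, scale) for a per-tensor amax."""
    if unsigned and min(inputs) < 0:
        raise ValueError("negative inputs are not supported with unsigned quantization")
    if unsigned:
        min_bound, max_bound = 0, 2 ** num_bits - 1
    else:
        max_bound = 2 ** (num_bits - 1) - 1
        min_bound = -max_bound if narrow_range else -max_bound - 1
    scale = max_bound / amax
    # Round to the nearest integer, clamp to the grid, keep float storage.
    outputs = [float(min(max(round(v * scale), min_bound), max_bound)) for v in inputs]
    return outputs, scale
```

With the four-decimal inputs printed in the tensor_quant example at the top of this documentation, the sketch reproduces the integer outputs shown there. Its scale (about 128.011) differs slightly from 128.0057 only because those printed inputs are rounded.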

tensor_quant is alias of TensorQuantFunction.apply

fake_tensor_quant is alias of FakeTensorQuantFunction.apply

pytorch_quantization.utils

pytorch_quantization.utils.amp_wrapper

pytorch_quantization.utils.quant_logging

A workaround (WAR) for code that messes up the logging format

pytorch_quantization.utils.quant_logging.reset_logger_handler()

Remove all handlers from the root logger
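Removing root handlers needs only the standard logging module. The sketch below is a minimal stand-in with a hypothetical name, not the library's function.

```python
import logging


def reset_root_handlers():
    """Detach every handler currently attached to the root logger."""
    root = logging.getLogger()
    # Iterate over a copy, since removeHandler mutates root.handlers.
    for handler in list(root.handlers):
        root.removeHandler(handler)
```

This matters because logging.basicConfig() does nothing when the root logger already has handlers (unless force=True is passed, available since Python 3.8). Clearing them first lets a fresh format take effect.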

pytorch_quantization.utils.reduce_amax

Function to get the absolute maximum of a tensor. It follows the NumPy fashion, which is more generic than PyTorch’s.

pytorch_quantization.utils.reduce_amax.reduce_amax(input, axis=None, keepdims=True)

Compute the absolute maximum value of a tensor.

Reduces input along the dimensions given in axis. Unless keepdims is true, the rank of the tensor is reduced by 1 for each entry in axis. If keepdims is true, the reduced dimensions are retained with length 1.

Note

Gradient computation is disabled, as this function is never meant for learning the reduced amax.

Parameters:
  • input – Input tensor

  • axis – The dimensions to reduce. None or int or tuple of ints. If None (the default), reduces all dimensions. Must be in the range [-rank(input_tensor), rank(input_tensor)).

  • keepdims – A boolean. If true, retains reduced dimensions with length 1. Default True

  • granularity – DEPRECATED. Specifies whether the statistic is calculated at tensor or channel granularity.

Returns:

The reduced tensor.

Raises:
  • ValueError – Any axis which doesn’t make sense or is not supported

  • ValueError – If unknown granularity is passed in.
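The axis and keepdims semantics above can be illustrated on a 2D nested list. This pure-Python sketch supports only a single int axis or None, not tuples, and its name is hypothetical. Unlike the axis of QuantDescriptor, which names the axes to keep, axis here names the axes to reduce.

```python
def reduce_amax_2d(matrix, axis=None, keepdims=True):
    """Absolute max of a 2D nested list with NumPy-style axis/keepdims."""
    if axis is None:
        value = max(abs(v) for row in matrix for v in row)
        return [[value]] if keepdims else value
    if axis in (0, -2):
        # Reduce over rows: one value per column.
        cols = [max(abs(row[j]) for row in matrix) for j in range(len(matrix[0]))]
        return [cols] if keepdims else cols
    if axis in (1, -1):
        # Reduce over columns: one value per row.
        rows = [max(abs(v) for v in row) for row in matrix]
        return [[r] for r in rows] if keepdims else rows
    raise ValueError(f"axis {axis} is out of range for a 2D input")
```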

Indices