8. Extending TAO Quant with a custom backend#
This section serves as a guide for advanced TAO users. The steps below show how to add your own quantization backend to TAO Quant: you implement a small adapter class and register it so the framework can discover it.
8.1. What you will build#
- A class that implements the `QuantizerBase` interface:
  - `prepare(model, config) -> model`
  - `quantize(model, config) -> model`
- An optional `Calibratable` interface (if your backend requires a calibration pass):
  - `calibrate(model, dataloader)`
- An optional `save_model(model, path)` implementation to control how artifacts are written.
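The framework drives these hooks in a fixed order. The sketch below illustrates that lifecycle; the driver code is illustrative rather than TAO's actual runner, and `backend`, `model`, `config`, and `dataloader` stand in for objects the framework provides:

```python
from nvidia_tao_pytorch.core.quantization.calibratable import Calibratable

# Illustrative call order only -- TAO's real pipeline may differ in details.
model = backend.prepare(model, config)            # validate, insert observers
if isinstance(backend, Calibratable):             # only calibration-aware backends
    backend.calibrate(model, dataloader)          # collect ranges/scales
model = backend.quantize(model, config)           # convert to a quantized model
backend.save_model(model, "/results/quantize")    # optional custom artifact layout
```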
8.2. Where to look in code#
Registry and entry points: `nvidia_tao_pytorch/core/quantization/{registry.py, quantizer.py, quantizer_base.py}`.

Built-in backends:

- TorchAO: `nvidia_tao_pytorch/core/quantization/backends/torchao/torchao.py`
- ModelOpt: `nvidia_tao_pytorch/core/quantization/backends/modelopt/modelopt.py`
8.3. Minimal backend skeleton#
```python
from nvidia_tao_pytorch.core.quantization.quantizer_base import QuantizerBase
from nvidia_tao_pytorch.core.quantization.calibratable import Calibratable
from nvidia_tao_pytorch.core.quantization.registry import register_backend
from nvidia_tao_core.config.common.quantization.default_config import ModelQuantizationConfig

import torch.nn as nn


@register_backend("mybackend")
class MyBackend(QuantizerBase):  # or (QuantizerBase, Calibratable) if calibration is supported
    def prepare(self, model: nn.Module, config: ModelQuantizationConfig) -> nn.Module:
        # Validate inputs, insert observers or fake-quant if your library
        # needs it, or return the model unchanged.
        return model

    def quantize(self, model: nn.Module, config: ModelQuantizationConfig) -> nn.Module:
        # Convert the prepared model to a quantized model using your library.
        return model

    # Optional: only if you want to control the artifact structure.
    def save_model(self, model: nn.Module, path: str) -> None:
        import os

        import torch

        os.makedirs(path, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(path, "quantized_model_mybackend.pth"))


# Optional: only if your backend requires calibration.
class MyBackendWithCalib(QuantizerBase, Calibratable):
    def calibrate(self, model: nn.Module, dataloader) -> None:
        # Iterate over a representative dataloader to collect ranges/scales.
        pass
```
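To make the skeleton concrete, here is a sketch that fills in `quantize` using PyTorch's stock dynamic quantization rather than a vendor library. The backend name `dynamic_int8` and the `config.mode` attribute access are assumptions for illustration; a real backend would substitute its own library's conversion call.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic

from nvidia_tao_pytorch.core.quantization.quantizer_base import QuantizerBase
from nvidia_tao_pytorch.core.quantization.registry import register_backend


@register_backend("dynamic_int8")  # hypothetical name for this sketch
class DynamicInt8Backend(QuantizerBase):
    """Toy backend that delegates to torch.ao.quantization.quantize_dynamic."""

    def prepare(self, model: nn.Module, config) -> nn.Module:
        # Dynamic quantization needs no observers, so prepare only validates.
        # Reading `config.mode` assumes ModelQuantizationConfig exposes the
        # YAML keys as attributes.
        mode = getattr(config, "mode", None)
        if mode not in (None, "weight_only_ptq"):
            raise ValueError(f"dynamic_int8 does not support mode '{mode}'")
        return model

    def quantize(self, model: nn.Module, config) -> nn.Module:
        # Swap every nn.Linear for a dynamically quantized int8 version.
        return quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
```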
8.4. Configuration integration#
Users select your backend by name in the experiment specification:

```yaml
quantize:
  backend: "mybackend"
  mode: "weight_only_ptq"  # or your supported mode name(s)
  default_layer_dtype: "int8"
  default_activation_dtype: "native"
  layers:
    - module_name: "Linear"
      weights: { dtype: "int8" }
```
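Inside the backend, the per-layer overrides typically resolve to a dtype per module. Below is a minimal resolver sketch; the attribute names mirror the YAML keys above, but the actual `ModelQuantizationConfig` field layout is an assumption:

```python
def resolve_weight_dtype(module_name: str, config) -> str:
    """Return the weight dtype for a module, honoring per-layer overrides."""
    # Assumes config mirrors the YAML keys: `layers`, `module_name`,
    # `weights.dtype`, and `default_layer_dtype`.
    for layer in getattr(config, "layers", None) or []:
        if layer.module_name == module_name:
            return layer.weights.dtype
    return config.default_layer_dtype
```

A matching entry in `layers` wins; anything else falls back to `default_layer_dtype`.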
8.5. Design tips#
- Validate inputs early and provide clear error messages for unsupported modes and dtypes.
- Reuse TAO's layer pattern utility to match modules by name or type (a stand-in sketch follows this list).
- Keep `prepare` a light validation/no-op if your library inserts quantizers during `quantize`.
- Implement `Calibratable` only if calibration is truly required; otherwise keep the API minimal.
- Provide a consistent artifact name like `quantized_model_<backend>.pth` to align with documentation and tooling.
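If you are prototyping before wiring in TAO's own pattern utility, a dependency-free stand-in for name/type matching can look like this (a generic sketch, not TAO's API):

```python
import fnmatch

import torch.nn as nn


def match_modules(model: nn.Module, pattern: str):
    """Yield (name, module) pairs whose qualified name matches a glob
    pattern (e.g. "backbone.*") or whose class name equals the pattern
    (e.g. "Linear")."""
    for name, module in model.named_modules():
        if fnmatch.fnmatch(name, pattern) or type(module).__name__ == pattern:
            yield name, module
```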
8.6. Testing your backend#
Try the end-to-end flow with a small classification model:

1. Update your `PYTHONPATH` so your backend import is discoverable, or place it under `nvidia_tao_pytorch/core/quantization/backends/`.
2. Run `tao classification_pyt quantize -e <specification.yaml>` with `backend: mybackend`.
3. Evaluate the produced checkpoint by setting `evaluate.is_quantized: true` and pointing to the artifact path.
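Before the full CLI run, a unit-level smoke test catches interface mistakes quickly. Here is a minimal pytest-style sketch, assuming the no-op skeleton backend from Section 8.3 (the import path is hypothetical):

```python
import torch
import torch.nn as nn

from mybackend import MyBackend  # hypothetical: wherever your backend lives


def test_backend_roundtrip(tmp_path):
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
    backend = MyBackend()

    # Config construction is backend-specific; None suffices for the
    # no-op skeleton above.
    model = backend.prepare(model, config=None)
    model = backend.quantize(model, config=None)

    # The quantized model should still run a forward pass.
    out = model(torch.randn(2, 3, 32, 32))
    assert out.shape == (2, 10)

    backend.save_model(model, str(tmp_path))
```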
8.7. FAQ#
**How does TAO discover my backend?** Backends self-register via the `@register_backend("name")` decorator at import time. Ensure your module is imported (for example, placed under the built-in backends package, or imported by your application before use).

**Can I add new modes or dtypes?** Yes. Validate them in your backend and document them clearly for users.
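For the first point, triggering registration can be as simple as an import for its side effect (the module path below is hypothetical):

```python
# In your application's entry point, before TAO resolves the backend name.
# Importing the module runs @register_backend("mybackend") as a side effect.
import my_company.quant.mybackend  # noqa: F401
```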