8. Extending TAO Quant with a custom backend#

This section serves as a guide for advanced users of TAO. The steps outlined here show how to add your own quantization backend to TAO Quant: you implement a small adapter class and register it so that the framework can discover it.

8.1. What you will build#

  • A class that implements the QuantizerBase interface:
    - prepare(model, config) -> model
    - quantize(model, config) -> model

  • An optional Calibratable interface (if your backend requires a calibration pass):
    - calibrate(model, dataloader)

  • Optional save_model(model, path) implementation to control how artifacts are written

8.2. Where to look in code#

  • Registry and entry points: nvidia_tao_pytorch/core/quantization/{registry.py, quantizer.py, quantizer_base.py}.

  • Built-in backends:
    - TorchAO: nvidia_tao_pytorch/core/quantization/backends/torchao/torchao.py
    - ModelOpt: nvidia_tao_pytorch/core/quantization/backends/modelopt/modelopt.py

8.3. Minimal backend skeleton#

import os

import torch
import torch.nn as nn

from nvidia_tao_core.config.common.quantization.default_config import ModelQuantizationConfig
from nvidia_tao_pytorch.core.quantization.calibratable import Calibratable
from nvidia_tao_pytorch.core.quantization.quantizer_base import QuantizerBase
from nvidia_tao_pytorch.core.quantization.registry import register_backend

@register_backend("mybackend")
class MyBackend(QuantizerBase):  # or (QuantizerBase, Calibratable) if calibration is supported
    def prepare(self, model: nn.Module, config: ModelQuantizationConfig) -> nn.Module:
        # Validate inputs, insert observers or fake-quant if your library needs it, or return model unchanged
        return model

    def quantize(self, model: nn.Module, config: ModelQuantizationConfig) -> nn.Module:
        # Convert the prepared model to a quantized model using your library
        return model

    # Optional: only if you want to control artifact structure
    def save_model(self, model: nn.Module, path: str) -> None:
        os.makedirs(path, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(path, "quantized_model_mybackend.pth"))

# Optional: only if your backend requires calibration. A real backend would keep
# prepare() and quantize() as well; they are omitted here for brevity.
class MyBackendWithCalib(QuantizerBase, Calibratable):
    def calibrate(self, model: nn.Module, dataloader) -> None:
        # Iterate over a representative dataloader to collect ranges/scales
        pass
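In practice, a calibration-capable backend is usually a single registered class that inherits both QuantizerBase and Calibratable and implements prepare, calibrate, and quantize together; the separate class above only isolates the calibrate method for illustration.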

8.4. Configuration integration#

  • Users select your backend by name in the experiment specification, for example:

quantize:
  backend: "mybackend"
  mode: "weight_only_ptq"   # or your supported mode name(s)
  default_layer_dtype: "int8"
  default_activation_dtype: "native"
  layers:
    - module_name: "Linear"
      weights: { dtype: "int8" }
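Inside the backend, this quantize section arrives as the ModelQuantizationConfig object passed to prepare and quantize. The sketch below shows early validation of the requested mode and dtype; the field names are assumed to mirror the YAML keys above, so confirm them against the actual dataclass in nvidia_tao_core:

from nvidia_tao_core.config.common.quantization.default_config import ModelQuantizationConfig

SUPPORTED_MODES = {"weight_only_ptq"}
SUPPORTED_DTYPES = {"int8", "native"}

def validate_config(config: ModelQuantizationConfig) -> None:
    # NOTE: field names are assumed to mirror the YAML keys; verify them
    # against ModelQuantizationConfig before relying on this sketch.
    if config.mode not in SUPPORTED_MODES:
        raise ValueError(
            f"Backend 'mybackend' does not support mode '{config.mode}'. "
            f"Supported modes: {sorted(SUPPORTED_MODES)}."
        )
    if config.default_layer_dtype not in SUPPORTED_DTYPES:
        raise ValueError(
            f"Unsupported default_layer_dtype '{config.default_layer_dtype}'. "
            f"Supported dtypes: {sorted(SUPPORTED_DTYPES)}."
        )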

8.5. Design tips#

  • Validate inputs early and provide clear error messages for unsupported modes and dtypes.

  • Reuse TAO’s layer pattern utility to match modules by name or type (a simplified stand-in is sketched after this list).

  • Keep prepare a light validation/no-op if your library inserts quantizers during quantize.

  • Implement Calibratable only if calibration is truly required; otherwise keep the API minimal.

  • Provide a consistent artifact name like quantized_model_<backend>.pth to align with documentation and tooling.
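To illustrate the module-matching tip, the snippet below is a simplified stand-in, not TAO's actual layer pattern utility: it walks the model with named_modules() and matches modules by class name, which you would typically replace with the framework's own helper.

import torch.nn as nn

def find_matching_modules(model: nn.Module, module_name: str):
    # Simplified stand-in for TAO's layer pattern utility: yields (name, module)
    # pairs whose class name matches the configured module_name (e.g., "Linear").
    for name, module in model.named_modules():
        if type(module).__name__ == module_name:
            yield name, module

# Example usage inside quantize():
#     targets = dict(find_matching_modules(model, "Linear"))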

8.6. Testing your backend#

  • Try the end-to-end flow with a small classification model:
    1. Update your PYTHONPATH so your backend import is discoverable, or place it under nvidia_tao_pytorch/core/quantization/backends/.
    2. Run tao classification_pyt quantize -e <specification.yaml> with backend: mybackend.
    3. Evaluate the produced checkpoint by setting evaluate.is_quantized: true and pointing to the artifact path (see the example snippet below).
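For step 3, the evaluate section of the specification could look like the sketch below. The checkpoint key and results path are illustrative assumptions; confirm the exact field names against your task's evaluate schema:

evaluate:
  is_quantized: true
  checkpoint: /results/quantize/quantized_model_mybackend.pth  # illustrative path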

8.7. FAQ#

  • How does TAO discover my backend?
    - Backends self-register via the @register_backend("name") decorator at import time. Ensure your module is imported (e.g., placed under the built-in backends package or imported by your app before use); see the import sketch at the end of this section.

  • Can I add new modes or dtypes?
    - Yes. Validate them in your backend and document them clearly for users.
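As a concrete example of import-time registration, assuming your backend module lives in a hypothetical package my_quant_plugins, a plain import in your application is enough to make it selectable by name:

# Importing the module executes @register_backend("mybackend"),
# which makes the backend discoverable before the quantize entry point runs.
import my_quant_plugins.mybackend  # noqa: F401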