nemo_automodel.components.quantization.qat

TorchAO Quantization-Aware Training (QAT) helpers for NeMo-AutoModel.

This module provides:

  • QATConfig: Configuration class for QAT settings
  • Thin wrappers to instantiate torchao QAT quantizers and apply them to models (prepare)
  • Helpers to toggle fake-quant on and off during training (for delayed fake-quant)

Module Contents

Classes

Name         Description
QATConfig    Configuration for Quantization-Aware Training (QAT).

Functions

Name                         Description
get_disable_fake_quant_fn    Return the disable fake-quant function for a given quantizer mode.
get_enable_fake_quant_fn     Return the enable fake-quant function for a given quantizer mode.
get_quantizer_mode           Return a short mode string for a known torchao QAT quantizer.
prepare_qat_model            Apply a torchao QAT quantizer to the given model.

Data

_DISABLE_FN_BY_MODE

_ENABLE_FN_BY_MODE

_QUANTIZER_TO_MODE

__all__

logger

API

class nemo_automodel.components.quantization.qat.QATConfig(
quantizer_type: typing.Literal['int8_dynact_int4weight', 'int4_weight_only'] = 'int8_dynact_int4weight',
quantizer_kwargs = {}
)
Dataclass

Configuration for Quantization-Aware Training (QAT).

This config controls how QAT quantizers are instantiated and applied to models. QAT is enabled when this config is provided to from_pretrained/from_config.

nemo_automodel.components.quantization.qat.QATConfig.create_quantizer()

Create and return the appropriate QAT quantizer based on config.

Returns:

A torchao QAT quantizer instance (Int8DynActInt4WeightQATQuantizer or Int4WeightOnlyQATQuantizer).

Raises:

  • ValueError: If quantizer_type is not recognized.

nemo_automodel.components.quantization.qat.QATConfig.to_dict() -> typing.Dict[str, typing.Any]

Convert config to dictionary.
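
To make the shape of the config concrete, here is a minimal sketch of a dataclass that mirrors the documented fields. SketchQATConfig and check_quantizer_type are hypothetical stand-ins, not the library's class or API. The groupsize keyword and the exact dictionary layout returned by to_dict are assumptions. The real create_quantizer builds a torchao quantizer, which this sketch deliberately does not do.

```python
from dataclasses import asdict, dataclass, field
from typing import Any, Dict

# The two quantizer_type values listed in the documented signature.
_KNOWN_TYPES = ("int8_dynact_int4weight", "int4_weight_only")


@dataclass
class SketchQATConfig:
    """Hypothetical stand-in mirroring the documented QATConfig fields."""

    quantizer_type: str = "int8_dynact_int4weight"
    # default_factory avoids sharing one mutable dict across instances.
    quantizer_kwargs: Dict[str, Any] = field(default_factory=dict)

    def check_quantizer_type(self) -> None:
        # Mirrors the documented ValueError for an unrecognized type.
        if self.quantizer_type not in _KNOWN_TYPES:
            raise ValueError(f"Unknown quantizer_type: {self.quantizer_type!r}")

    def to_dict(self) -> Dict[str, Any]:
        # Assumed layout: a plain dict of the dataclass fields.
        return asdict(self)


# "groupsize" is an assumed keyword forwarded to the quantizer constructor.
cfg = SketchQATConfig(
    quantizer_type="int4_weight_only",
    quantizer_kwargs={"groupsize": 128},
)
cfg.check_quantizer_type()
cfg_dict = cfg.to_dict()
```

With the real class, the equivalent call would presumably be QATConfig(quantizer_type=..., quantizer_kwargs=...), followed by create_quantizer() to obtain the torchao quantizer.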

nemo_automodel.components.quantization.qat.get_disable_fake_quant_fn(
mode: str
) -> typing.Optional[typing.Callable]

Return the disable fake-quant function for a given quantizer mode.

nemo_automodel.components.quantization.qat.get_enable_fake_quant_fn(
mode: str
) -> typing.Optional[typing.Callable]

Return the enable fake-quant function for a given quantizer mode.
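
The enable/disable getters are meant for delayed fake-quant: train for a number of steps with fake quantization switched off, then switch it on. The sketch below illustrates that schedule with pure-Python stand-ins. SketchFakeQuantLinear, the sketch_* functions, and the delay_steps parameter are all hypothetical. In a real setup the returned function would typically be applied to every submodule, for example with model.apply(fn).

```python
from typing import Callable, Dict, List, Optional


class SketchFakeQuantLinear:
    """Hypothetical layer carrying a fake-quant on/off flag."""

    def __init__(self) -> None:
        self.fake_quant_enabled = True


def sketch_disable(mod: SketchFakeQuantLinear) -> None:
    mod.fake_quant_enabled = False


def sketch_enable(mod: SketchFakeQuantLinear) -> None:
    mod.fake_quant_enabled = True


# Mode-keyed tables, analogous in shape to the documented module data.
_SKETCH_DISABLE: Dict[str, Callable] = {"8da4w-qat": sketch_disable}
_SKETCH_ENABLE: Dict[str, Callable] = {"8da4w-qat": sketch_enable}


def run_delayed_fake_quant(
    layers: List[SketchFakeQuantLinear],
    mode: Optional[str],
    total_steps: int,
    delay_steps: int,
) -> List[bool]:
    """Record, per step, whether fake-quant is active on all layers."""
    # .get returns None for an unknown mode, matching the Optional return type.
    disable_fn = _SKETCH_DISABLE.get(mode)
    enable_fn = _SKETCH_ENABLE.get(mode)

    if disable_fn is not None:
        for layer in layers:
            disable_fn(layer)

    history = []
    for step in range(total_steps):
        if step == delay_steps and enable_fn is not None:
            for layer in layers:
                enable_fn(layer)
        # A real loop would run forward, backward, and optimizer steps here.
        history.append(all(layer.fake_quant_enabled for layer in layers))
    return history


history = run_delayed_fake_quant(
    [SketchFakeQuantLinear(), SketchFakeQuantLinear()],
    mode="8da4w-qat",
    total_steps=5,
    delay_steps=2,
)
```

Because both getters return an Optional, callers should check for None before toggling; in this sketch an unrecognized mode leaves fake-quant enabled from the first step.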

nemo_automodel.components.quantization.qat.get_quantizer_mode(
quantizer: object
) -> typing.Optional[str]

Return a short mode string for a known torchao QAT quantizer.

Returns None when the quantizer is unrecognized.
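
One plausible way to implement this lookup is a table keyed by quantizer class. The sketch below uses empty stand-in classes in place of the torchao quantizers, and the use of isinstance (rather than an exact type match) is an assumption. The mode strings are the ones that appear as keys in the module's enable/disable tables.

```python
from typing import Dict, Optional


class StandInInt8DynActInt4WeightQATQuantizer:
    """Stand-in for the torchao class of the similar name."""


class StandInInt4WeightOnlyQATQuantizer:
    """Stand-in for the torchao class of the similar name."""


# Class-to-mode table, analogous in shape to _QUANTIZER_TO_MODE.
_SKETCH_QUANTIZER_TO_MODE: Dict[type, str] = {
    StandInInt8DynActInt4WeightQATQuantizer: "8da4w-qat",
    StandInInt4WeightOnlyQATQuantizer: "4w-qat",
}


def sketch_get_quantizer_mode(quantizer: object) -> Optional[str]:
    """Return the mode string for a known quantizer, else None."""
    for quantizer_cls, mode in _SKETCH_QUANTIZER_TO_MODE.items():
        if isinstance(quantizer, quantizer_cls):
            return mode
    return None


mode_8da4w = sketch_get_quantizer_mode(StandInInt8DynActInt4WeightQATQuantizer())
mode_4w = sketch_get_quantizer_mode(StandInInt4WeightOnlyQATQuantizer())
mode_unknown = sketch_get_quantizer_mode(object())
```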

nemo_automodel.components.quantization.qat.prepare_qat_model(
model,
quantizer
) -> tuple[object, typing.Optional[str]]

Apply a torchao QAT quantizer to the given model.

Returns the (possibly wrapped) model and a mode string if recognized.
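
The two-element return value lets a caller keep the prepared model and decide later whether fake-quant toggling is available. A minimal sketch of that calling pattern follows, assuming the quantizer exposes a prepare(model) method as torchao's two-step quantizers do. SketchQuantizer, sketch_prepare_qat_model, the dict used as a model, and the hard-coded mode are hypothetical.

```python
from typing import Optional, Tuple


class SketchQuantizer:
    """Hypothetical quantizer exposing prepare(), with a known mode."""

    mode = "8da4w-qat"

    def prepare(self, model: dict) -> dict:
        # A real quantizer would swap linear layers for fake-quantized ones.
        prepared = dict(model)
        prepared["fake_quant"] = True
        return prepared


class UnknownQuantizer(SketchQuantizer):
    """Hypothetical quantizer whose mode is not recognized."""

    mode = None


def sketch_prepare_qat_model(model: dict, quantizer) -> Tuple[dict, Optional[str]]:
    """Apply the quantizer and report its mode (None if unrecognized)."""
    prepared = quantizer.prepare(model)
    return prepared, getattr(quantizer, "mode", None)


model, mode = sketch_prepare_qat_model({"name": "toy"}, SketchQuantizer())
# Only attempt delayed fake-quant when the mode was recognized.
can_toggle = mode is not None

_, unknown_mode = sketch_prepare_qat_model({"name": "toy"}, UnknownQuantizer())
```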

nemo_automodel.components.quantization.qat._DISABLE_FN_BY_MODE = {'8da4w-qat': disable_8da4w_fake_quant, '4w-qat': disable_4w_fake_quant}
nemo_automodel.components.quantization.qat._ENABLE_FN_BY_MODE = {'8da4w-qat': enable_8da4w_fake_quant, '4w-qat': enable_4w_fake_quant}
nemo_automodel.components.quantization.qat._QUANTIZER_TO_MODE = {Int8DynActInt4WeightQATQuantizer: '8da4w-qat', Int4WeightOnlyQATQuantizer: '4w-...
nemo_automodel.components.quantization.qat.__all__ = ['QATConfig', 'get_quantizer_mode', 'get_disable_fake_quant_fn', 'get_enable_fak...
nemo_automodel.components.quantization.qat.logger = logging.getLogger(__name__)