# Quantization-Aware Training (QAT) in NeMo Automodel
NeMo Automodel supports Quantization-Aware Training (QAT) for Supervised Fine-Tuning (SFT) using TorchAO. QAT simulates quantization effects during the training process, allowing models to adapt to lower precision representations while learning. This approach produces quantized models that maintain significantly higher accuracy compared to applying quantization after training is complete.
## What is Quantization-Aware Training?
Quantization-Aware Training simulates the effects of quantization during the training process. By introducing fake quantization operations in the forward pass, the model learns to adapt to lower precision representations, maintaining better accuracy when deployed with actual quantization.
## Benefits of QAT

- Better accuracy: Models trained with QAT maintain higher accuracy when quantized compared to post-training quantization
- Efficient deployment: Quantized models require less memory and compute resources
- Edge device support: Enables deployment on resource-constrained devices
- Production optimization: Reduces inference costs while maintaining model quality
## QAT vs. Post-Training Quantization

| Aspect | QAT | Post-Training Quantization |
|---|---|---|
| Accuracy | Higher - model adapts during training | Lower - no adaptation |
| Training time | Longer - requires retraining | None - applied after training |
| Use case | Production deployments requiring best accuracy | Quick prototyping or less critical applications |
| Flexibility | Can fine-tune quantization parameters | Limited to fixed quantization schemes |
## Requirements

To use QAT in NeMo Automodel, you need:

- Software: TorchAO library must be installed
- Hardware: Compatible NVIDIA GPU (recommended: A100 or newer)
- Model: Any supported model architecture for SFT

### Install TorchAO

Make sure you have TorchAO installed. Follow the installation guide for TorchAO.

```bash
pip install torchao
```
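To confirm the installation, you can check the installed TorchAO version from Python. This is a minimal sketch; it only assumes the package was installed under the distribution name `torchao`:

```python
# Quick check that TorchAO is importable, and report its installed version.
import importlib.metadata

import torchao  # raises ImportError if the installation did not succeed

print("TorchAO version:", importlib.metadata.version("torchao"))
```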
## How QAT Works in NeMo Automodel

NeMo Automodel integrates TorchAO’s QAT quantizers into the training pipeline. During training:

1. Model preparation: The quantizer prepares the model by inserting fake quantization operations.
2. Forward pass: Weights and activations are quantized using fake quantization.
3. Backward pass: Gradients flow through the fake quantization operations (typically via a straight-through estimator, so rounding does not block gradient flow).
4. Weight updates: The model learns to minimize loss while accounting for quantization effects.
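The sketch below illustrates this prepare/train/convert flow directly with TorchAO's quantizer API, outside of NeMo Automodel's config-driven integration. It is a minimal example assuming a toy `torch.nn.Linear`-based model and a placeholder training loop; when `qat.enabled` is set, NeMo Automodel drives the equivalent steps for you from your YAML configuration.

```python
import torch
import torch.nn as nn
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

# Toy model standing in for the fine-tuned LLM (illustrative only).
model = nn.Sequential(nn.Linear(256, 512), nn.SiLU(), nn.Linear(512, 256))

# 1. Model preparation: swap linear layers for fake-quantized versions.
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=256)
model = quantizer.prepare(model)

# 2-4. Training loop: forward/backward passes run through fake-quantized ops.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for _ in range(3):  # placeholder for the real SFT loop
    x = torch.randn(8, 256)
    loss = model(x).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# After training: convert fake quantization into actually quantized weights.
model = quantizer.convert(model)
```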
## Supported Quantization Schemes

NeMo Automodel supports two TorchAO QAT quantizers:

### Int8 Dynamic Activation + Int4 Weight (8da4w-qat)

- Quantizer: `Int8DynActInt4WeightQATQuantizer`
- Activations: INT8 with dynamic quantization
- Weights: INT4 quantization
- Use case: Balanced accuracy and efficiency
- Memory savings: ~4x compared to FP16/BF16

### Int4 Weight-Only (4w-qat)

- Quantizer: `Int4WeightOnlyQATQuantizer`
- Activations: Full precision
- Weights: INT4 quantization
- Use case: Maximum memory savings with minimal accuracy loss
- Memory savings: ~4x for weights only
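For reference, the two quantizer classes can also be constructed directly in Python; the class paths match the `_target_` values used in the YAML configurations below. The group size of 256 is simply the value used throughout this page, not a requirement:

```python
from torchao.quantization.qat import (
    Int4WeightOnlyQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)

# 8da4w-qat: INT8 dynamic activations + INT4 weights.
quantizer_8da4w = Int8DynActInt4WeightQATQuantizer(groupsize=256)

# 4w-qat: INT4 weights only; activations stay in full precision.
quantizer_4w = Int4WeightOnlyQATQuantizer(groupsize=256)
```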
## Configuration

To enable QAT in your training configuration, specify the quantizer in your YAML configuration file.

### Basic Configuration

```yaml
# Enable QAT with Int8 Dynamic Activation + Int4 Weight quantization
qat:
  enabled: true
  quantizer:
    _target_: torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer
    groupsize: 256
```

### Int4 Weight-Only Configuration

```yaml
# Enable QAT with Int4 Weight-Only quantization
qat:
  enabled: true
  quantizer:
    _target_: torchao.quantization.qat.Int4WeightOnlyQATQuantizer
    groupsize: 256
```
### Configuration Parameters

| Parameter | Type | Description |
|---|---|---|
| `qat.enabled` | bool | Enable or disable QAT |
| `quantizer._target_` | str | Fully qualified class name of the TorchAO quantizer |
| `quantizer.groupsize` | int | Group size for weight quantization (typically 128 or 256) |
### Delayed Fake Quantization

You can optionally delay the activation of fake quantization so the model trains normally for a number of steps before quantization effects are introduced:

```yaml
qat:
  enabled: true
  quantizer:
    _target_: torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer
    groupsize: 256
  delay_fake_quant_steps: 1000  # Enable fake quant after 1000 steps
```
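Conceptually, delaying fake quantization means the fake-quantization ops are disabled at the start of training and re-enabled once the configured step is reached. The sketch below illustrates that idea with TorchAO's enable/disable helpers for the 8da4w scheme; it is an illustrative sketch, not the NeMo Automodel implementation, and the helper import path may differ across TorchAO versions:

```python
import torch
import torch.nn as nn
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.qat.linear import (  # module path may vary by TorchAO version
    disable_8da4w_fake_quant,
    enable_8da4w_fake_quant,
)

DELAY_FAKE_QUANT_STEPS = 5  # toy stand-in for delay_fake_quant_steps: 1000

model = nn.Sequential(nn.Linear(256, 256))  # toy stand-in for the LLM
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=256)
model = quantizer.prepare(model)
model.apply(disable_8da4w_fake_quant)       # start with fake quant turned off

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for step in range(10):                      # placeholder training loop
    if step == DELAY_FAKE_QUANT_STEPS:
        model.apply(enable_8da4w_fake_quant)  # switch fake quant on after the delay
    loss = model(torch.randn(4, 256)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```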
## Training Workflow

### 1. Prepare Your Configuration

Create a YAML configuration file with QAT enabled:

```yaml
model:
  model_name: meta-llama/Llama-3.2-1B

task:
  type: sft

qat:
  enabled: true
  quantizer:
    _target_: torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer
    groupsize: 256

trainer:
  max_steps: 10000
  val_check_interval: 500
```
### 2. Run Training

Launch training with your QAT-enabled configuration:

```bash
uv run torchrun --nproc-per-node=8 examples/llm_finetune/finetune.py --config your_qat_config.yaml
```
### 3. Monitor Training

During training, the model will:

- Apply fake quantization to weights and activations
- Learn to minimize loss while accounting for quantization effects
- Produce checkpoints that can be converted to actual quantized models
### 4. Deploy Quantized Model

After training, convert the QAT checkpoint to a fully quantized model for deployment:

```python
from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_

# Load your trained model (placeholder for your checkpoint-loading code)
model = load_model_from_checkpoint(checkpoint_path)

# Apply actual quantization (not fake quantization)
quantize_(model, int8_dynamic_activation_int4_weight())

# Deploy the quantized model
model.eval()
```
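Continuing from the snippet above, a quick smoke test can confirm that the quantized forward pass runs before you ship the model. This is illustrative only and assumes a causal-LM-style model that accepts a batch of token IDs; the vocabulary size of 32000 and the sequence length are placeholder values:

```python
import torch

# Illustrative smoke test: one dummy forward pass through the quantized model.
dummy_input_ids = torch.randint(0, 32000, (1, 16))  # placeholder vocab size and length
with torch.no_grad():
    outputs = model(dummy_input_ids)
print("Forward pass OK:", type(outputs).__name__)
```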
## Performance Considerations

### Training Performance

- Training time: QAT adds overhead during training due to fake quantization operations
- Memory usage: Similar to full-precision training during the training phase
- Convergence: May require slightly more training steps to converge than full-precision training

### Inference Performance

After converting to actual quantization:

- Speed: 2-4x faster inference, depending on hardware and model size
- Memory: ~4x reduction in model size
- Accuracy: Minimal degradation compared to full-precision models (typically <1% difference)
## When to Use QAT

QAT is most beneficial when:

- Deploying to production: Where inference efficiency is critical
- Edge devices: Resource-constrained environments
- Large-scale serving: Reducing infrastructure costs
- Accuracy is important: When post-training quantization causes unacceptable accuracy loss

## When Not to Use QAT

Consider alternatives when:

- Quick prototyping: Post-training quantization is faster
- Small models: Quantization overhead may not be worth it
- Limited training resources: QAT requires retraining the model
- Accuracy is not critical: Post-training quantization may be sufficient
## Best Practices

### 1. Start with Post-Training Quantization

Before investing in QAT, try post-training quantization to establish a baseline:

```python
# Quick post-training quantization test
from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_

quantize_(model, int8_dynamic_activation_int4_weight())
```

If accuracy is acceptable, you may not need QAT.
### 2. Choose the Right Quantization Scheme

- 8da4w-qat: Best balance of accuracy and efficiency for most use cases
- 4w-qat: Use when memory is the primary constraint and activations can remain full precision

### 3. Tune Group Size

The `groupsize` parameter affects the granularity of quantization:

- Smaller groups (128): Better accuracy, slightly more memory
- Larger groups (256): More efficient, may have a minor accuracy impact

Start with 256 and reduce to 128 if accuracy is insufficient.
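For example, a minimal sketch of the quantizer block from the Basic Configuration above with the smaller group size; only `groupsize` changes:

```yaml
qat:
  enabled: true
  quantizer:
    _target_: torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer
    groupsize: 128  # finer-grained quantization; trades a little memory for accuracy
```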
### 4. Monitor Validation Metrics

Track validation metrics closely during QAT training:

- Compare against a full-precision baseline
- Watch for convergence issues
- Adjust the learning rate if needed (QAT may benefit from slightly lower learning rates)

### 5. Use Delayed Fake Quantization

For better convergence, consider delaying fake quantization:

```yaml
qat:
  delay_fake_quant_steps: 1000  # Let the model train normally first
```

This allows the model to learn basic patterns before introducing quantization constraints.
## Accuracy vs. Efficiency Trade-offs

### Expected Accuracy Impact

| Quantization Method | Typical Accuracy Loss | Memory Savings |
|---|---|---|
| Full Precision (BF16) | Baseline | Baseline |
| Post-Training Quantization | 1-3% | 4x |
| QAT (8da4w) | <1% | 4x |
| QAT (4w) | <1.5% | 4x (weights only) |
### Optimization Strategies

If accuracy is below expectations:

- Increase training steps: QAT may need more training to converge
- Reduce the learning rate: Lower learning rates can help with quantization constraints
- Use 8da4w instead of 4w: Better accuracy with minimal additional cost
- Reduce the group size: Smaller groups provide finer-grained quantization
- Delay fake quantization: Give the model time to learn before quantizing
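Several of these adjustments can be applied together in the `qat` block of your configuration. The sketch below combines the configuration-level ones (8da4w quantizer, smaller group size, delayed fake quantization) using only the keys introduced earlier on this page; learning-rate and step-count changes belong in your optimizer and trainer sections:

```yaml
qat:
  enabled: true
  quantizer:
    _target_: torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer  # 8da4w instead of 4w
    groupsize: 128                # smaller groups for finer-grained quantization
  delay_fake_quant_steps: 2000    # let the model learn before quantizing
```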
## Limitations and Known Issues

### Current Limitations

- SFT only: QAT is currently supported for Supervised Fine-Tuning tasks only
- Model compatibility: Not all model architectures may be compatible with TorchAO quantizers
- Training overhead: QAT adds computational overhead during training
### Troubleshooting

#### Issue: Training diverges or doesn’t converge

Solution: Try these approaches:

- Reduce the learning rate by 2-5x
- Increase `delay_fake_quant_steps` to 2000-5000
- Use a smaller group size (128 instead of 256)
- Verify that your baseline model trains successfully without QAT

#### Issue: Accuracy is significantly worse than expected

Solution:

- Ensure you’re comparing against the same baseline (same training steps, data, etc.)
- Try 8da4w quantization instead of 4w
- Reduce the group size to 128
- Increase training steps by 20-30%
#### Issue: Out of memory during training

Solution:

- QAT should have memory usage similar to full-precision training
- Reduce the batch size if needed
- Use gradient accumulation to maintain the effective batch size