Quantization-Aware Training (QAT) in NeMo Automodel
NeMo Automodel supports Quantization-Aware Training (QAT) for Supervised Fine-Tuning (SFT) using TorchAO. QAT simulates quantization effects during training so that the model adapts to lower-precision representations while it learns. The resulting quantized models typically retain higher accuracy than models quantized only after training is complete.
What is Quantization-Aware Training?
QAT inserts fake quantization operations into the forward pass: each affected tensor is rounded to the values the target low-precision format can represent, but it stays in its original floating-point dtype. Because the loss is computed on these rounded values, the model learns weights that remain accurate once real quantization is applied at deployment.
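To make "fake quantization" concrete, here is a minimal pure-Python sketch of a quantize-then-dequantize round trip on a few weights. It illustrates the idea only and is not TorchAO's implementation; the function name `fake_quantize`, the symmetric scale choice, and the sample weights are assumptions made for this example.

```python
def fake_quantize(x, scale, qmin, qmax):
    """Snap x to an integer grid, then map it straight back to a float."""
    q = round(x / scale)             # nearest integer level
    q = max(qmin, min(qmax, q))      # clamp to the representable range
    return q * scale                 # dequantize: the result is still a float


# Symmetric INT4 uses the integer levels -8..7 (illustrative weights).
weights = [0.12, -0.53, 0.98, -1.40]
scale = max(abs(w) for w in weights) / 7   # largest magnitude maps to level 7
fake_quantized = [fake_quantize(w, scale, -8, 7) for w in weights]

for w, fq in zip(weights, fake_quantized):
    print(f"{w:+.2f} -> {fq:+.2f} (error {abs(w - fq):.2f})")
```

Every output is a multiple of `scale`, and the rounding error of each in-range value is at most half a step. This is the error the model is exposed to during QAT.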
Benefits of QAT
- Better accuracy: Models trained with QAT retain higher accuracy after quantization than models quantized only after training
- Efficient deployment: Quantized models require less memory and compute resources
- Edge device support: Enables deployment on resource-constrained devices
- Production optimization: Reduces inference costs while maintaining model quality
QAT vs. Post-Training Quantization
Requirements
To use QAT in NeMo Automodel, you need:
- Software: TorchAO library must be installed
- Hardware: Compatible NVIDIA GPU (recommended: A100 or newer)
- Model: Any supported model architecture for SFT
Install TorchAO
Make sure you have TorchAO installed. Follow the installation guide for TorchAO.
How QAT Works in NeMo Automodel
NeMo Automodel integrates TorchAO’s QAT quantizers into the training pipeline. During training:
- Model preparation: The quantizer prepares the model by inserting fake quantization operations
- Forward pass: Weights (and, with 8da4w-qat, activations) pass through fake quantization
- Backward pass: Gradients flow through the fake quantization operations, typically via a straight-through estimate that treats the rounding step as the identity
- Weight updates: The model learns to minimize loss while accounting for quantization effects
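The four steps above can be sketched with a single scalar weight in pure Python. This is a toy illustration under stated assumptions, not NeMo Automodel or TorchAO code: the names `fake_quantize` and `train_with_ste`, the fixed scale of 0.25, the squared-error loss, and the learning rate are all chosen for the example, and the gradient rule ignores the clamping range for simplicity.

```python
def fake_quantize(x, scale=0.25, qmin=-8, qmax=7):
    """Round x to an INT4-style grid and return it as a float."""
    q = max(qmin, min(qmax, round(x / scale)))
    return q * scale


def train_with_ste(target=0.5, lr=0.1, steps=20):
    w = 0.0  # latent full-precision weight that the optimizer updates
    for _ in range(steps):
        w_fq = fake_quantize(w)             # forward pass sees the quantized weight
        grad_w_fq = 2.0 * (w_fq - target)   # d(loss)/d(w_fq) for loss = (w_fq - target)**2
        grad_w = grad_w_fq                  # straight-through: treat d(w_fq)/d(w) as 1
        w -= lr * grad_w                    # update the full-precision copy
    return w


w = train_with_ste()
print(f"latent weight: {w:.3f}, quantized weight: {fake_quantize(w):.3f}")
```

The latent weight stops moving once its quantized value matches the target, so the full-precision copy and the deployed value can differ. What matters is that the loss is computed on the quantized value.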
Supported Quantization Schemes
NeMo Automodel supports two TorchAO QAT quantizers:
Int8 Dynamic Activation + Int4 Weight (8da4w-qat)
- Quantizer: Int8DynActInt4WeightQATQuantizer
- Activations: INT8 with dynamic quantization
- Weights: INT4 quantization
- Use case: Balanced accuracy and efficiency
- Memory savings: ~4x compared to FP16/BF16
Int4 Weight-Only (4w-qat)
- Quantizer: Int4WeightOnlyQATQuantizer
- Activations: Full precision
- Weights: INT4 quantization
- Use case: Maximum memory savings with minimal accuracy loss
- Memory savings: ~4x for weights only
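The "~4x" figure for both schemes follows from storing each weight in 4 bits instead of 16. A rough back-of-the-envelope sketch is shown below. The 8B parameter count is a hypothetical model size, and the assumption that each group stores one 16-bit scale is made only to show why the saving is slightly under 4x; actual storage layouts may differ.

```python
def weight_memory_gib(num_params, bits_per_weight):
    """Storage for the weights alone, in GiB."""
    return num_params * bits_per_weight / 8 / 2**30


NUM_PARAMS = 8_000_000_000   # hypothetical 8B-parameter model
GROUPSIZE = 256              # one scale shared by each group of 256 weights
SCALE_BITS = 16              # assumed storage width of each group scale

bf16_gib = weight_memory_gib(NUM_PARAMS, 16)
int4_gib = weight_memory_gib(NUM_PARAMS, 4)
# Amortize the per-group scale over the weights in that group.
int4_with_scales_gib = weight_memory_gib(NUM_PARAMS, 4 + SCALE_BITS / GROUPSIZE)

print(f"BF16 weights:        {bf16_gib:.2f} GiB")
print(f"INT4 weights:        {int4_gib:.2f} GiB ({bf16_gib / int4_gib:.2f}x smaller)")
print(f"INT4 + group scales: {int4_with_scales_gib:.2f} GiB "
      f"({bf16_gib / int4_with_scales_gib:.2f}x smaller)")
```

The estimate covers weights only. Activations, KV cache, and any layers left unquantized are not included.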
Configuration
To enable QAT in your training configuration, you need to specify the quantizer in your YAML configuration file.
Basic Configuration
Int4 Weight-Only Configuration
Configuration Parameters
Delayed Fake Quantization
You can optionally delay the activation of fake quantization to allow the model to train normally for a few steps before introducing quantization effects:
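Conceptually, delayed fake quantization is a step-based switch on the forward pass. The sketch below shows that gating logic in pure Python. The parameter name `delay_fake_quant_steps` matches the setting mentioned in the troubleshooting section, but the helper functions, the `>=` comparison, and the fixed scale are assumptions for illustration and do not reflect the actual NeMo Automodel implementation.

```python
def fake_quant_enabled(step, delay_fake_quant_steps):
    """True once training has passed the configured delay (assumed semantics)."""
    return step >= delay_fake_quant_steps


def effective_weight(w, step, delay_fake_quant_steps, scale=0.25):
    """Weight value the forward pass would see at a given training step."""
    if not fake_quant_enabled(step, delay_fake_quant_steps):
        return w                                  # plain full-precision training
    q = max(-8, min(7, round(w / scale)))         # INT4-style grid, levels -8..7
    return q * scale


for step in (0, 999, 1000, 5000):
    print(step, effective_weight(0.4, step, delay_fake_quant_steps=1000))
```

Before the delay elapses the weight passes through unchanged. Afterwards the forward pass sees its fake-quantized value.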
Training Workflow
1. Prepare Your Configuration
Create a YAML configuration file with QAT enabled:
2. Run Training
Launch training with your QAT-enabled configuration:
3. Monitor Training
During training, the model will:
- Apply fake quantization to weights and, with 8da4w-qat, to activations
- Learn to minimize loss while accounting for quantization effects
- Produce checkpoints that can be converted to actual quantized models
4. Deploy Quantized Model
After training, convert the QAT checkpoint to a fully quantized model for deployment:
Performance Considerations
Training Performance
- Training time: QAT adds overhead during training due to fake quantization operations
- Memory usage: Similar to full-precision training during the training phase
- Convergence: May require slightly more training steps to converge compared to full-precision training
Inference Performance
After converting to actual quantization:
- Speed: 2-4x faster inference depending on hardware and model size
- Memory: ~4x reduction in model size
- Accuracy: Minimal degradation compared to full-precision models (typically <1% difference)
When to Use QAT
QAT is most beneficial when:
- Deploying to production: Where inference efficiency is critical
- Edge devices: Resource-constrained environments
- Large-scale serving: Reducing infrastructure costs
- Accuracy is important: When post-training quantization causes unacceptable accuracy loss
When Not to Use QAT
Consider alternatives when:
- Quick prototyping: Post-training quantization is faster
- Small models: Quantization overhead may not be worth it
- Limited training resources: QAT requires retraining the model
- Accuracy is not critical: Post-training quantization may be sufficient
Best Practices
1. Start with Post-Training Quantization
Before investing in QAT, try post-training quantization to establish a baseline.
If accuracy is acceptable, you may not need QAT.
2. Choose the Right Quantization Scheme
- 8da4w-qat: Best balance of accuracy and efficiency for most use cases
- 4w-qat: Use when memory is the primary constraint and activations can remain full precision
3. Tune Group Size
The groupsize parameter affects the granularity of quantization:
- Smaller groups (128): Better accuracy, slightly more memory
- Larger groups (256): More efficient, may have minor accuracy impact
Start with 256 and reduce to 128 if accuracy is insufficient.
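The effect of group size can be seen on a toy example: each group shares one scale, so a single large weight forces a coarse grid on every other weight in its group. The sketch below uses group sizes of 4 and 2 as scaled-down stand-ins for 256 and 128. The function names, the symmetric INT4 scheme, and the sample weights are assumptions for illustration, not the TorchAO implementation.

```python
def quantize_groupwise(weights, groupsize, qmax=7):
    """Symmetric per-group fake quantization; returns (dequantized, num_scales)."""
    dequantized, num_scales = [], 0
    for start in range(0, len(weights), groupsize):
        group = weights[start:start + groupsize]
        scale = max(abs(w) for w in group) / qmax or 1.0   # avoid a zero scale
        num_scales += 1
        for w in group:
            q = max(-qmax - 1, min(qmax, round(w / scale)))
            dequantized.append(q * scale)
    return dequantized, num_scales


def mean_abs_error(original, approx):
    return sum(abs(a - b) for a, b in zip(original, approx)) / len(original)


weights = [0.07, -0.05, 0.03, 1.4]   # three small weights and one outlier
large_group, large_scales = quantize_groupwise(weights, groupsize=4)
small_group, small_scales = quantize_groupwise(weights, groupsize=2)
error_large = mean_abs_error(weights, large_group)
error_small = mean_abs_error(weights, small_group)

print(f"groupsize=4: {large_scales} scale(s), mean abs error {error_large:.4f}")
print(f"groupsize=2: {small_scales} scale(s), mean abs error {error_small:.4f}")
```

The smaller group size isolates the outlier so fewer weights share its coarse grid, at the cost of storing twice as many scales.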
4. Monitor Validation Metrics
Track validation metrics closely during QAT training:
- Compare against full-precision baseline
- Watch for convergence issues
- Adjust learning rate if needed (QAT may benefit from slightly lower learning rates)
5. Use Delayed Fake Quantization
For better convergence, consider delaying fake quantization.
This allows the model to learn basic patterns before introducing quantization constraints.
Accuracy vs. Efficiency Trade-offs
Expected Accuracy Impact
Optimization Strategies
If accuracy is below expectations:
- Increase training steps: QAT may need more training to converge
- Reduce learning rate: Lower learning rates can help with quantization constraints
- Use 8da4w instead of 4w: Better accuracy with minimal additional cost
- Reduce group size: Smaller groups provide finer-grained quantization
- Delay fake quantization: Give the model time to learn before quantizing
Limitations and Known Issues
Current Limitations
- SFT only: QAT is currently supported for Supervised Fine-Tuning tasks only
- Model compatibility: Not all model architectures may be compatible with TorchAO quantizers
- Training overhead: QAT adds computational overhead during training
Troubleshooting
Issue: Training diverges or doesn’t converge
Solution: Try these approaches:
- Reduce learning rate by 2-5x
- Increase delay_fake_quant_steps to 2000-5000
- Use a smaller group size (128 instead of 256)
- Verify your baseline model trains successfully without QAT
Issue: Accuracy is significantly worse than expected
Solution:
- Ensure you’re comparing against the same baseline (same training steps, data, etc.)
- Try 8da4w quantization instead of 4w
- Reduce group size to 128
- Increase training steps by 20-30%
Issue: Out of memory during training
Solution:
- QAT should use about as much memory as full-precision training, so first confirm that the same configuration fits in memory with QAT disabled
- Reduce batch size if needed
- Use gradient accumulation to maintain effective batch size
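When reducing the batch size, the number of accumulation steps needed to keep the same effective batch size is simple arithmetic. A small helper, with hypothetical names and example values, might look like this:

```python
def accumulation_steps(effective_batch, micro_batch, num_gpus=1):
    """Gradient accumulation steps needed to reach the effective batch size."""
    per_step = micro_batch * num_gpus          # samples processed per forward/backward
    if effective_batch % per_step != 0:
        raise ValueError("effective batch must be a multiple of micro_batch * num_gpus")
    return effective_batch // per_step


# Halving the per-GPU batch doubles the accumulation steps (illustrative numbers).
before = accumulation_steps(effective_batch=64, micro_batch=8, num_gpus=2)
after = accumulation_steps(effective_batch=64, micro_batch=4, num_gpus=2)
print(before, after)
```

How these values map onto configuration keys depends on the training recipe, so treat the names here as placeholders.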