Important
You are viewing the NeMo 2.0 documentation. This release introduces significant changes to the API and a new library, NeMo Run. We are currently porting all features from NeMo 1.0 to 2.0. For documentation on previous versions or features not yet available in 2.0, please refer to the NeMo 24.07 documentation.
Quantization
NeMo offers two methods for converting an FP16/BF16 model to a lower-precision format: Post-Training Quantization (PTQ) and Quantization-Aware Training (QAT).
The following sections detail how to use quantization in NeMo.
Post-Training Quantization (PTQ)
PTQ enables deploying a model in a low-precision format (FP8, INT4, or INT8) for efficient serving. Available quantization methods include FP8 quantization, INT8 SmoothQuant, and INT4 AWQ.
Model quantization has three primary benefits: reduced model memory requirements, lower memory bandwidth pressure, and higher inference throughput.
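To put the memory benefit in numbers: weight storage scales with the number of bits per weight. The back-of-the-envelope sketch below estimates weight memory for a 70B-parameter model in each format. `weight_gb` is a hypothetical helper, not a NeMo tool, and the estimate counts weights only, ignoring the KV cache, activations, and scaling-factor overhead.

```shell
#!/bin/sh
# Back-of-the-envelope weight memory estimate (hypothetical helper, not a NeMo tool).
# Counts weights only: KV cache, activations, and scaling factors are ignored.

# weight_gb PARAMS_BILLION BITS_PER_WEIGHT -> decimal gigabytes of weights
weight_gb() {
  # (params * 1e9) * bits / 8 bits-per-byte / 1e9 bytes-per-GB = params * bits / 8
  echo $(( $1 * $2 / 8 ))
}

# Print the estimate for a 70B model in each format.
while read -r fmt bits; do
  printf '%-5s %4d GB\n' "$fmt" "$(weight_gb 70 "$bits")"
done <<'EOF'
BF16 16
FP8 8
INT8 8
INT4 4
EOF
```

For a 70B model this gives 140 GB in BF16, 70 GB in FP8 or INT8, and 35 GB in INT4.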
In NeMo, quantization is enabled by the NVIDIA TensorRT Model Optimizer (ModelOpt), a library for quantizing and compressing deep learning models for optimized inference on GPUs.
The quantization process consists of the following steps:
1. Loading a model checkpoint using an appropriate parallelism strategy.
2. Calibrating the model to obtain algorithm-specific scaling factors.
3. Producing a TensorRT-LLM checkpoint with the model config (JSON), quantized weights (safetensors), and tokenizer config (YAML).
Loading models requires a ModelOpt spec defined in the nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec module. The calibration step is typically lightweight and uses a small dataset to collect the statistics needed for scaling tensors. The resulting output directory is ready to be used to build a serving engine with the NVIDIA TensorRT-LLM library (see Deploy NeMo Models by Exporting TensorRT-LLM). From here on, we refer to this checkpoint as a .qnemo checkpoint.
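To make the calibration step more concrete, the sketch below shows one simple scheme, per-tensor max (amax) calibration for FP8: it tracks the largest absolute value observed over the calibration data and divides it by 448, the largest finite magnitude in the FP8 E4M3 format, so the observed range maps onto the representable range. This is a simplified illustration under that assumption, not ModelOpt's implementation; the scaling factors ModelOpt computes are algorithm-specific (SmoothQuant and AWQ use different statistics), and `amax_scale` is a hypothetical helper.

```shell
#!/bin/sh
# Simplified per-tensor max calibration for FP8 E4M3 (illustration only).
# Reads one activation value per line on stdin and prints "amax scale".
# 448 is the largest finite magnitude representable in FP8 E4M3.
amax_scale() {
  awk '
    BEGIN { amax = 0 }
    { v = ($1 < 0) ? -$1 : $1; if (v > amax) amax = v }   # running max of |x|
    END { printf "%.4f %.7f\n", amax, amax / 448 }         # scale = amax / 448
  '
}

# Example: a tiny "calibration batch" of activation values.
printf '%s\n' -3.5 1.2 2.8 0.4 | amax_scale
```

Here the largest magnitude is 3.5, so the sketch reports a scale of 3.5 / 448 = 0.0078125. Real calibration aggregates such statistics over many batches and every quantized tensor.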
The quantization algorithm can also be set to null (quantization.algorithm=null) to perform only the weight export step, using the default precision, for TensorRT-LLM deployment. This is useful for obtaining baseline performance and accuracy results for comparison.
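A baseline export and a quantized export of the same checkpoint differ only in quantization.algorithm. The sketch below reuses the overrides from the Llama 3 70B PTQ example in this section and prints both commands; it only launches them when RUN=1. The `ptq` function, the RUN switch, and the output directory names are hypothetical additions for illustration, not part of NeMo.

```shell
#!/bin/sh
# Print (or, with RUN=1, execute) PTQ commands for a baseline export
# (quantization.algorithm=null) and an FP8 export of the same checkpoint.
# Overrides mirror the Llama 3 70B PTQ example; output paths are hypothetical.
set -eu

CALIB_TP=8
INFER_TP=2
RUN=${RUN:-0}

ptq() {
  algo=$1
  set -- torchrun --nproc-per-node "$CALIB_TP" \
    examples/nlp/language_modeling/megatron_gpt_ptq.py \
    model.restore_from_path=llama3-70b-base-bf16.nemo \
    model.tensor_model_parallel_size="$CALIB_TP" \
    model.pipeline_model_parallel_size=1 \
    trainer.num_nodes=1 \
    trainer.devices="$CALIB_TP" \
    trainer.precision=bf16 \
    quantization.algorithm="$algo" \
    export.decoder_type=llama \
    export.inference_tensor_parallel="$INFER_TP" \
    export.save_path="llama3-70b-base-$algo-qnemo"
  if [ "$RUN" = 1 ]; then
    "$@"        # launch torchrun with the assembled arguments
  else
    echo "$@"   # dry run: show the command that would be launched
  fi
}

for algo in null fp8; do
  ptq "$algo"
done
```

Comparing the engines built from the two output directories gives the baseline against which FP8 accuracy and throughput can be judged.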
Support Matrix
The table below presents a verified support matrix for popular LLM architectures. Support for other model families is experimental.
| Model Name | Model Parameters | Decoder Type | FP8 | INT8 SQ | INT4 AWQ |
|---|---|---|---|---|---|
| GPT | 2B, 8B, 43B | gptnext | ✓ | ✓ | ✓ |
| Nemotron-3 | 8B, 22B | gptnext | ✓ | ✓ | ✓ |
| Nemotron-4 | 15B, 340B | gptnext | ✓ | ✓ | ✓ |
| Llama 2 | 7B, 13B, 70B | llama | ✓ | ✓ | ✓ |
| Llama 3 | 8B, 70B | llama | ✓ | ✓ | ✓ |
| Llama 3.1 | 8B, 70B, 405B | llama | ✓ | ✓ | ✓ |
| Falcon | 7B, 40B | falcon | ✗ | ✗ | ✗ |
| Gemma 1 | 2B, 7B | gemma | ✓ | ✓ | ✓ |
| StarCoder 1 | 15B | gpt2 | ✓ | ✓ | ✓ |
| StarCoder 2 | 3B, 7B, 15B | gptnext | ✓ | ✓ | ✓ |
| Mistral | 7B | llama | ✓ | ✓ | ✓ |
| Mixtral | 8x7B | llama | ✗ | ✗ | ✗ |
Selected models are available to download from Hugging Face Hub for testing purposes:
Example
The example below shows how to quantize the Llama 3 70B model to FP8 precision using tensor parallelism of 8 on a single DGX H100 node. The quantized model is intended for serving on 2 H100 GPUs, as specified with the export.inference_tensor_parallel parameter.
The script must be launched with the number of processes equal to the tensor parallelism size. This is achieved with the torchrun command below:
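CALIB_TP=8  # tensor parallelism (number of GPUs) used for calibration
INFER_TP=2  # tensor parallelism of the exported checkpoint for serving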
torchrun --nproc-per-node $CALIB_TP examples/nlp/language_modeling/megatron_gpt_ptq.py \
model.restore_from_path=llama3-70b-base-bf16.nemo \
model.tensor_model_parallel_size=$CALIB_TP \
model.pipeline_model_parallel_size=1 \
trainer.num_nodes=1 \
trainer.devices=$CALIB_TP \
trainer.precision=bf16 \
quantization.algorithm=fp8 \
export.decoder_type=llama \
export.inference_tensor_parallel=$INFER_TP \
export.save_path=llama3-70b-base-fp8-qnemo
For large models, the command can also be run in a multi-node setting, for example with the NeMo Framework Launcher on a Slurm cluster.
When running PTQ, the decoder type must be specified with the export.decoder_type parameter to produce a correct TensorRT-LLM checkpoint (see the Support Matrix above).
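Because the correct export.decoder_type depends only on the model family, it can be looked up from the Support Matrix. The sketch below encodes the verified rows of that table as a shell `case` statement. `decoder_type_for` and its family keys (for example `llama3.1`, `starcoder2`) are hypothetical names chosen for illustration. Falcon and Mixtral are rejected because the table marks them ✗ for every format.

```shell
#!/bin/sh
# Map a model family from the Support Matrix to its export.decoder_type value.
# Hypothetical helper covering only rows marked as verified in the table.
decoder_type_for() {
  case "$1" in
    gpt|nemotron-3|nemotron-4|starcoder2) echo gptnext ;;
    llama2|llama3|llama3.1|mistral)       echo llama ;;
    gemma)                                echo gemma ;;
    starcoder)                            echo gpt2 ;;   # StarCoder 1
    # Falcon and Mixtral are marked unsupported for all PTQ formats.
    *) echo "unsupported or unknown model family: $1" >&2; return 1 ;;
  esac
}

# Example: build the override for a Llama 3 checkpoint.
echo "export.decoder_type=$(decoder_type_for llama3)"
```

Failing loudly on an unknown family is a deliberate choice here: a wrong decoder type would otherwise surface only later, when the TensorRT-LLM checkpoint fails to build or produces incorrect output.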
The output directory stores the following files:
llama3-70b-base-fp8-qnemo/
├── config.json
├── rank0.safetensors
├── rank1.safetensors
├── tokenizer.model
└── tokenizer_config.yaml
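Before building an engine, it can help to verify that the directory has the expected layout. The sketch below assumes, as in the listing above, one rankN.safetensors file per inference tensor-parallel rank (pipeline parallelism of 1 at inference) plus config.json; tokenizer files are not checked because their names depend on the tokenizer. `check_qnemo` is a hypothetical helper, not part of NeMo.

```shell
#!/bin/sh
# Sanity-check a .qnemo output directory before building an engine.
# Hypothetical helper: expects config.json plus rank0..rank(TP-1).safetensors,
# assuming pipeline parallelism of 1 at inference time.
check_qnemo() {
  dir=$1
  tp=$2
  if [ ! -f "$dir/config.json" ]; then
    echo "missing $dir/config.json" >&2
    return 1
  fi
  rank=0
  while [ "$rank" -lt "$tp" ]; do
    if [ ! -f "$dir/rank$rank.safetensors" ]; then
      echo "missing $dir/rank$rank.safetensors" >&2
      return 1
    fi
    rank=$((rank + 1))
  done
  echo "ok: $dir has config.json and $tp weight shard(s)"
}

# Usage with INFER_TP=2 from the PTQ example:
#   check_qnemo llama3-70b-base-fp8-qnemo 2
```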
The next step is to build a TensorRT-LLM engine for the produced checkpoint. The engine can be conveniently built and run with the TensorRTLLM class from the nemo.export module; see Deploy NeMo Models by Exporting TensorRT-LLM for details.
Quantization-Aware Training (QAT)
QAT is the technique of fine-tuning a quantized model to recover the quality lost to quantization. During QAT, the quantization scaling factors computed during PTQ are frozen and the model weights are fine-tuned. Although QAT requires much more compute than PTQ, it is highly effective at recovering model quality. To perform QAT on a model calibrated with PTQ, further fine-tune it on a downstream task using a small dataset before exporting to TensorRT-LLM; you can reuse your existing training pipeline for this. As a rule of thumb, we recommend running QAT for 1-10% of the original training duration with a small learning rate, for example 1e-5 for the Adam optimizer. If you are doing QAT on an SFT model, where the learning rate and fine-tuning dataset size are already small, you can start from the same learning rate and dataset size used for SFT. Since QAT is performed after PTQ, the supported model families are the same as for PTQ.
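The 1-10% rule of thumb is simple arithmetic on the original step count. The minimal sketch below turns it into a step range; `qat_step_range` is a hypothetical helper, and the 300000-step run is an arbitrary example, not a NeMo default.

```shell
#!/bin/sh
# Rule of thumb from the text: run QAT for 1-10% of the original training steps.
# Hypothetical helper; integer division rounds down, with a floor of 1 step.
qat_step_range() {
  orig_steps=$1
  lo=$((orig_steps / 100))   # 1% of the original run
  hi=$((orig_steps / 10))    # 10% of the original run
  if [ "$lo" -lt 1 ]; then lo=1; fi
  if [ "$hi" -lt "$lo" ]; then hi=$lo; fi
  echo "$lo $hi"
}

# Example: an original run of 300000 steps suggests roughly 3000-30000 QAT steps.
qat_step_range 300000
```

For short SFT runs, the text above suggests reusing the SFT settings instead, which is what the following example does.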
Example
The example below shows how to perform PTQ and QAT on a supervised fine-tuned (SFT) Llama 2 7B model to INT4 precision. The script was tested with tensor parallelism of 8 on 8x RTX 6000 Ada 48 GB GPUs. Alternatively, a single DGX A100 node with 8x 40 GB GPUs can be used. Larger models such as Llama 2 70B may require one or more DGX H100 nodes with 8x 80 GB GPUs each.
This example is a modified version of the SFT with Llama 2 playbook. Refer to the playbook for details on setting up a BF16 NeMo model and the databricks-dolly-15k instruction dataset.
First, run the SFT example command from the playbook to train a Llama 2 7B SFT model for 100 steps. The only change needed is to set trainer.max_steps=100 instead of trainer.max_steps=50 for the megatron_gpt_finetuning.py script. This takes about 2 hours and produces a model checkpoint with a validation loss of approximately 1.15, which we use for PTQ and QAT next.
For quantization, we use a modified version of the SFT script and config file that adds quantization and TensorRT-LLM export support. In addition to the new parameters, pass the same parameters you used for SFT training, except that the model restore path should point to the SFT output .nemo file. The example command below performs PTQ on the SFT checkpoint followed by another round of SFT (that is, QAT); the result can then be exported for TensorRT-LLM inference. The script takes about 2-3 hours to complete.
TP=8
# Placeholder: path to the .nemo checkpoint produced by the SFT run above.
SFT_NEMO_PATH="<llama2-7b-sft-nemo-path>"
# Append the remaining parameters from your SFT command to this invocation.
torchrun --nproc-per-node $TP examples/nlp/language_modeling/tuning/megatron_gpt_qat.py \
trainer.num_nodes=1 \
trainer.devices=$TP \
trainer.precision=bf16 \
trainer.max_steps=100 \
model.restore_from_path="$SFT_NEMO_PATH" \
model.global_batch_size=128 \
quantization.algorithm=int4
In the logs, the INT4 PTQ model has a validation loss of approximately 1.31, while the QAT model reaches approximately 1.17, which is very close to the BF16 model loss of 1.15.
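One way to summarize these losses is the fraction of the PTQ quality gap that QAT recovers: (PTQ loss - QAT loss) / (PTQ loss - BF16 loss). The sketch below computes it with awk. `recovered_pct` is a hypothetical helper, and this is a rough, illustrative metric rather than one NeMo reports.

```shell
#!/bin/sh
# Fraction of the PTQ quality gap recovered by QAT, from validation losses.
# recovered = 100 * (loss_ptq - loss_qat) / (loss_ptq - loss_bf16)
# Usage: recovered_pct BF16_LOSS PTQ_LOSS QAT_LOSS
recovered_pct() {
  awk -v base="$1" -v ptq="$2" -v qat="$3" \
    'BEGIN { printf "%.1f\n", 100 * (ptq - qat) / (ptq - base) }'
}

# Losses quoted above: BF16 1.15, INT4 PTQ 1.31, INT4 QAT 1.17.
recovered_pct 1.15 1.31 1.17
```

With these numbers, QAT closes 0.14 of the 0.16 loss gap, or about 87.5%.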
The script produces a quantized .nemo checkpoint in the experiment manager log directory (set in the config YAML file) that can be used for further training. By setting the export parameters as in the PTQ example, it can optionally also produce an exported TensorRT-LLM engine directory or a .qnemo checkpoint for inference.
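As a sketch, assuming the QAT config exposes the same export.* keys as the PTQ example, the export overrides for this Llama 2 model could be generated as below. `export_overrides` is a hypothetical helper. The decoder type follows the Support Matrix (Llama 2 uses llama). The tensor-parallel size of 1 and the save path are illustrative choices, not values from this guide.

```shell
#!/bin/sh
# Print export.* overrides for the QAT command, one per line.
# Hypothetical helper; assumes the QAT script accepts the PTQ example's export keys.
# Usage: export_overrides DECODER_TYPE INFER_TP SAVE_PATH
export_overrides() {
  printf '%s\n' \
    "export.decoder_type=$1" \
    "export.inference_tensor_parallel=$2" \
    "export.save_path=$3"
}

# Llama 2 maps to the "llama" decoder type; TP=1 and the path are illustrative.
export_overrides llama 1 llama2-7b-sft-int4-qnemo
```

To use it, append `$(export_overrides llama 1 llama2-7b-sft-int4-qnemo)` unquoted to the end of the QAT torchrun command. This relies on the values containing no spaces.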
You may adjust the number of QAT training steps and the learning rate if needed to achieve better model quality.
Known Issues
PTQ support for the updated NeMo 2.0 API is currently under development; see NEMO-10642.
References
Please refer to the following papers for more details on quantization techniques: