Important

NeMo 2.0 is an experimental feature and currently released in the dev container only: nvcr.io/nvidia/nemo:dev. Please refer to the Migration Guide for information on getting started.

NeMo Framework Quantization Aware Training (QAT) for Llama2 SFT Model

Project Description

Learning Goals

Depending on the type of post-training quantization (PTQ) performed on a model, the model quality may degrade due to the loss of precision. Quantization Aware Training (QAT) is the technique of fine-tuning a quantized model to recover the quality lost during quantization. During QAT, the quantization scaling factors computed during PTQ are frozen and the model weights are fine-tuned. While QAT requires considerably more compute than PTQ, it is highly effective at recovering model quality. Since QAT is essentially additional fine-tuning after PTQ, the supported model families are the same as for PTQ.

In this project, you will learn how to perform QAT on a quantized Llama 2 SFT model.

Prerequisites

This QAT playbook is an extension of the Llama2 SFT and Post-training Quantization playbooks. We will first quantize an SFT model (like the one produced in the Llama2 SFT playbook) using PTQ, then run SFT on the quantized model (QAT) before deploying it with TensorRT-LLM. For the best understanding, complete the Llama2 SFT playbook (required) and the Post-training Quantization playbook (recommended) before starting this one.

NeMo Tools and Resources

Software Requirements

  • Use the latest NeMo Framework Training container

  • This playbook has been tested on: nvcr.io/nvidia/nemo:24.07. It is expected to work similarly on other environments.

Hardware Requirements

  • NVIDIA DGX H100 and NVIDIA H100 GPUs based on the NVIDIA Hopper architecture.

Data

We will use the same databricks-dolly-15k dataset as in the Llama2 SFT playbook. You can follow the same steps to download and preprocess the dataset.
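
If you need to fetch the raw dataset again, the snippet below is a minimal sketch of one way to do so, assuming the huggingface_hub CLI is available in the container; the preprocessing commands themselves are described in the Llama2 SFT playbook.

# Download the raw databricks-dolly-15k dataset from Hugging Face.
# Assumes the huggingface_hub CLI is installed; converting the data into the
# .jsonl format expected by the SFT scripts follows the Llama2 SFT playbook.
huggingface-cli download databricks/databricks-dolly-15k \
    --repo-type dataset \
    --local-dir ./databricks-dolly-15k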

Run Quantization Aware Training (QAT)

To perform QAT on a calibrated model from PTQ, you need to further fine-tune the model on a downstream task using a small dataset before exporting to TensorRT-LLM. You can reuse the SFT training pipeline for QAT. As a rule of thumb, we recommend QAT for 1-10% of the original training duration with a small learning rate, e.g., 1e-5 for the Adam optimizer. However, since we are performing QAT on an SFT model whose learning rate and fine-tuning dataset are already small, we can continue using the same SFT learning rate and dataset size as a starting point for QAT.
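
If you do want to deviate from the SFT defaults, the duration and learning rate can be set through Hydra overrides on the QAT command shown in Step 2. The sketch below is illustrative only; the override names (in particular model.optim.lr) assume the standard NeMo SFT config structure, so verify them against the config file used in your run.

# Illustrative only: override the QAT duration and the Adam learning rate on
# the command line (override names assume the standard NeMo SFT config).
torchrun --nproc-per-node 8 examples/nlp/language_modeling/tuning/megatron_gpt_qat.py \
    model.restore_from_path=<llama2-7b-sft-nemo-path> \
    trainer.max_steps=100 \
    model.optim.lr=1e-5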

Step 1: Obtain the Llama2 SFT model

First, you need to obtain the BF16 Llama2 7B SFT model, similar to how it is done in the Llama2 SFT playbook. We will run the SFT trainer for 100 steps by setting trainer.max_steps=100 in the SFT script, along with the other parameters from that playbook. This process takes approximately 2 hours and produces a model checkpoint with a validation loss of approximately 1.15. We will use this checkpoint for PTQ and QAT. For larger models such as Llama2 70B, you can use 50 steps to speed up the process.
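
For reference, a shortened SFT run along those lines might look like the sketch below. The script path and data overrides here are assumptions based on typical NeMo fine-tuning runs; use the exact command from the Llama2 SFT playbook and simply add trainer.max_steps=100.

# Sketch only: the script path and dataset overrides are assumptions; the
# authoritative SFT command is in the Llama2 SFT playbook.
torchrun --nproc-per-node 8 examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py \
    trainer.devices=8 \
    trainer.precision=bf16 \
    trainer.max_steps=100 \
    model.restore_from_path=<llama2-7b-base-nemo-path> \
    model.data.train_ds.file_names=[<path-to-dolly-train.jsonl>] \
    model.data.validation_ds.file_names=[<path-to-dolly-val.jsonl>]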

Step 2: Perform QAT on the SFT model

Next, we will quantize the SFT model to INT4 precision using a modified version of the script from the Llama2 SFT playbook that adds PTQ, QAT, and TensorRT-LLM export steps. Refer to the config file, which contains the additional quantization and export parameters. In addition to these new parameters, make sure you pass the same parameters you used for SFT training, except that the model restore path now points to the SFT output .nemo file. The following example command performs Post-Training Quantization (PTQ) on the SFT model checkpoint, followed by another round of SFT (Quantization-Aware Training, or QAT). The resulting model can then be exported for TensorRT-LLM inference. The entire process takes approximately 2 to 3 hours for the Llama2 7B model.

torchrun --nproc-per-node 8 examples/nlp/language_modeling/tuning/megatron_gpt_qat.py \
    trainer.num_nodes=1 \
    trainer.devices=8 \
    trainer.precision=bf16 \
    trainer.max_steps=100 \
    model.restore_from_path=<llama2-7b-sft-nemo-path> \
    model.global_batch_size=128 \
    quantization.algorithm=int4 \
    # other parameters from sft training, e.g. dataset paths

As you can see from the logs, the INT4 PTQ model has a validation loss of approximately 1.31, while the QAT model has a validation loss of approximately 1.17, which is very close to the BF16 model loss of 1.15. This script produces a quantized .nemo checkpoint in the experiment manager log directory (set in the config file) that can be used for further training.

The script can also produce an exported TensorRT-LLM compatible .qnemo file for inference if you set the export parameters as in the Post-training Quantization playbook. Please refer to that playbook for more details on exporting to TensorRT-LLM and deploying to the NVIDIA Triton Inference Server. Note that exporting an INT4-quantized model is currently not supported, but you can try other quantization formats with this example as well. You can adjust the QAT trainer steps and learning rate as needed to improve model quality.
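
As a rough sketch, enabling export might look like the following when added to the Step 2 command. The export override names here are assumptions carried over from the Post-training Quantization playbook's config, and FP8 is used only as an example of a non-INT4 format; check the QAT config file for the exact parameter names and supported values.

# Sketch only: export-related overrides (names assumed from the PTQ playbook
# config) added to the Step 2 command, using a non-INT4 format since INT4
# export is not supported.
torchrun --nproc-per-node 8 examples/nlp/language_modeling/tuning/megatron_gpt_qat.py \
    model.restore_from_path=<llama2-7b-sft-nemo-path> \
    quantization.algorithm=fp8 \
    export.decoder_type=llama \
    export.inference_tensor_parallel=1 \
    export.save_path=<path-to-exported-qnemo>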