NeMo Framework Post-Training Quantization (PTQ) with Nemotron4 and Llama3

Project Description

Learning Goals

Post-training quantization (PTQ) is a machine learning technique that reduces a trained model's memory and computational footprint. In this playbook, you'll learn how to apply PTQ in FP8 precision to two large language models (LLMs), Nemotron4-340B and Llama3-70B, export them to TensorRT-LLM, and deploy them via PyTriton for efficient serving.
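
As a rough illustration of the savings, the weights alone shrink by about half when moving from 16-bit to 8-bit storage. The sketch below is a back-of-envelope estimate for weights only, ignoring activations and KV cache:

# Back-of-envelope estimate of weight memory for a 340B-parameter model
params = 340e9
bf16_gib = params * 2 / 2**30  # 2 bytes per parameter in BF16 (~633 GiB)
fp8_gib = params * 1 / 2**30   # 1 byte per parameter in FP8 (~317 GiB)
print(f"BF16: {bf16_gib:.0f} GiB, FP8: {fp8_gib:.0f} GiB")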

NeMo Tools and Resources

Software Requirements

  • Use the latest NeMo Framework Training container

  • This playbook has been tested with nvcr.io/nvidia/nemo:24.05 and is expected to work similarly in other environments.

Hardware Requirements

  • NVIDIA DGX H100 and NVIDIA H100 GPUs based on the NVIDIA Hopper architecture.

Preparing NeMo checkpoint for Nemotron4-340B and Llama3-70B

Nemotron4-340B Checkpoint Preparation

Nemotron4-340B can be downloaded from the Hugging Face nvidia organization:

from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="nvidia/Nemotron-4-340B-Base",
    local_dir="nemotron4-340b-base",
    local_dir_use_symlinks=False
)

Llama3-70B Checkpoint Preparation

Llama3-70B can be downloaded from the Hugging Face meta-llama organization:

You must first be approved under the Meta Llama 3 Community License Agreement in order to download the checkpoint. Use your Hugging Face API token to download the model.

from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="meta-llama/Meta-Llama-3-70B",
    local_dir="llama3-70b-base",
    local_dir_use_symlinks=False,
    token="<YOUR HF TOKEN>"  # your Hugging Face access token
)
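
Optionally, instead of passing the token to every call, you can authenticate once per environment with huggingface_hub.login; a minimal sketch with a placeholder token:

from huggingface_hub import login

# Log in once; subsequent huggingface_hub calls pick up the stored token
login(token="<YOUR HF TOKEN>")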

Convert the Llama3-70B HF checkpoint into .nemo format

Run the container using the following command:

docker run --gpus device=1 --shm-size=2g --net=host --ulimit memlock=-1 --rm -it -v ${PWD}:/workspace -w /workspace -v ${PWD}/results:/results nvcr.io/nvidia/nemo:24.05 bash

Convert the Hugging Face model to .nemo model:

python /opt/NeMo/scripts/checkpoint_converters/convert_llama_hf_to_nemo.py --input_name_or_path=./llama3-70b-base/ --output_path=llama3-70b-base.nemo

Extract the .nemo file into a folder to avoid memory issues while loading the model:

mkdir -p llama3-70b-base-nemo && tar -xvf llama3-70b-base.nemo -C llama3-70b-base-nemo/

Convert NeMo Checkpoint to qnemo format

“.nemo” versus “.qnemo”

NeMo also offers a Post-Training Quantization workflow that converts regular .nemo models into TensorRT-LLM checkpoints, conventionally referred to as .qnemo checkpoints in NeMo. Such a checkpoint can be used with the NVIDIA TensorRT-LLM library for efficient inference.

As with .nemo checkpoints, a .qnemo checkpoint is a tar file that bundles the model configuration in a config.json file together with rank{i}.safetensors files storing the model weights for each rank separately. In addition, a tokenizer_config.yaml file is saved; it is simply the tokenizer section of the model_config.yaml file from the original NeMo model and defines the tokenizer for the given model.

For large quantized LLMs, using a directory rather than a tar file is recommended. This can be controlled with the compress flag when exporting quantized models in the PTQ configuration file.
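
For illustration, a minimal sketch (with a hypothetical file name) that lists what a .qnemo tar bundles:

import tarfile

# Hypothetical path to a .qnemo checkpoint exported as a single tar file
with tarfile.open("model.qnemo") as archive:
    # Expect config.json, per-rank rank{i}.safetensors files, and tokenizer_config.yaml
    print(archive.getnames())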

Running Calibration to Generate qnemo Model

Calibrating the Nemotron4-340B model requires at least two DGX H100 nodes (8x H100 GPUs each). The job can be submitted through NeMo-Framework-Launcher:

cd NeMo-Framework-Launcher/launcher_scripts

CALIB_PP=2
CALIB_TP=8
INFER_TP=8

python3 main.py \
   ptq=model/quantization \
   stages=["ptq"] \
   launcher_scripts_path=$(pwd) \
   base_results_dir=/results/base \
   "container='${CONTAINER}'" \
   container_mounts=[/models,/results] \
   cluster.partition=batch \
   cluster.job_name_prefix="${SLURM_ACCOUNT}-nemotron_340b_fp8:" \
   cluster.gpus_per_task=null \
   cluster.gpus_per_node=null \
   cluster.srun_args='["--no-container-mount-home", "--mpi=pmix"]' \
   ptq.run.model_train_name=nemotron_340b \
   ptq.run.time_limit=45 \
   ptq.run.results_dir=/results \
   ptq.quantization.algorithm=fp8 \
   ptq.export.decoder_type=gptnext \
   ptq.export.inference_tensor_parallel=${INFER_TP} \
   ptq.export.inference_pipeline_parallel=1 \
   ptq.trainer.precision=bf16 \
   ptq.model.restore_from_path=/models/nemotron4-340b-base \
   ptq.export.save_path=/results/nemotron4-340B-base-fp8-qnemo \
   ptq.model.tensor_model_parallel_size=${CALIB_TP} \
   ptq.model.pipeline_model_parallel_size=${CALIB_PP}

Note

Cluster settings might differ depending on your hardware environment; consult the NeMo-Framework-Launcher documentation for cluster-related settings.

Calibrating the Llama3-70B model requires at least eight H100 GPUs. The job can be launched directly through NeMo:

python examples/nlp/language_modeling/megatron_gpt_quantization.py \
    model.restore_from_path=llama3-70b-base-nemo \
    model.tensor_model_parallel_size=2 \
    model.pipeline_model_parallel_size=1 \
    trainer.num_nodes=1 \
    trainer.precision=bf16 \
    trainer.devices=8 \
    quantization.algorithm=fp8 \
    export.decoder_type=llama \
    export.inference_tensor_parallel=1 \
    export.model_save=llama3-70b-base-fp8-qnemo

Note

The above scripts should be run within the NeMo Docker Container nvcr.io/nvidia/nemo:24.05.

The output directory stores the following files in the case of Nemotron4 340B:

nemotron4-340B-base-fp8-qnemo
├── config.json
├── rank0.safetensors
├── rank1.safetensors
├── rank2.safetensors
├── rank3.safetensors
├── rank4.safetensors
├── rank5.safetensors
├── rank6.safetensors
├── rank7.safetensors
├── tokenizer.model
└── tokenizer_config.yaml

The output in the case of Llama3-70B is analogous.
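
As a quick sanity check, the exported checkpoint can be inspected directly; a minimal sketch assuming the Nemotron4 output directory shown above:

import json
from pathlib import Path

# Path to the Nemotron4 output directory produced by the calibration step
qnemo_dir = Path("nemotron4-340B-base-fp8-qnemo")

# Print the TensorRT-LLM checkpoint configuration and list the per-rank weight files
config = json.loads((qnemo_dir / "config.json").read_text())
print(json.dumps(config, indent=2))
print(sorted(p.name for p in qnemo_dir.glob("rank*.safetensors")))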

Exporting to TensorRT-LLM

Option 1: Through nemo.export

The TensorRT-LLM engine can be conveniently built and run using the TensorRTLLM class available in the nemo.export submodule:

from nemo.export import TensorRTLLM

# Export Nemotron4-340B model
trt_llm_exporter = TensorRTLLM(model_dir="NEMOTRON4-340B-base-fp8-trt-llm-engine")
trt_llm_exporter.export(
    nemo_checkpoint_path="nemotron4-340B-base-fp8-qnemo",
    model_type="gptnext",
)
trt_llm_exporter.forward(["Hi, how are you?", "I am good, thanks, how about you?"])

# Export Llama3-70B model
trt_llm_exporter = TensorRTLLM(model_dir="llama3-70b-base-fp8-trt-llm-engine")
trt_llm_exporter.export(
    nemo_checkpoint_path="llama3-70b-base-fp8-qnemo",
    model_type="llama",
)
trt_llm_exporter.forward(["Hi, how are you?", "I am good, thanks, how about you?"])

Option 2: Through trtllm-build

Alternatively, the engine can be built directly using the trtllm-build command; see the TensorRT-LLM documentation:

# Build Nemotron4-340B TRTLLM Engine
trtllm-build \
    --checkpoint_dir nemotron4-340B-base-fp8-qnemo \
    --output_dir NEMOTRON4-340B-base-fp8-trt-llm-engine \
    --max_batch_size 8 \
    --max_input_len 2048 \
    --max_output_len 512 \
    --strongly_typed

The command for Llama3-70B is analogous.

In this example, the TensorRT-LLM engine files are stored in NEMOTRON4-340B-base-fp8-trt-llm-engine and llama3-70b-base-fp8-trt-llm-engine, respectively.

Deploy Nemotron/Llama TensorRT-LLM to Triton

You can use the APIs in the deploy module to deploy a TensorRT-LLM model to Triton.

from nemo.export import TensorRTLLM
from nemo.deploy import DeployPyTriton

# Deploy Nemotron Model
trt_llm_exporter = TensorRTLLM(model_dir="NEMOTRON4-340B-base-fp8-trt-llm-engine")

nm = DeployPyTriton(model=trt_llm_exporter, triton_model_name="nemotron4-340b", port=8000)
nm.deploy()
nm.serve()

# Deploy Llama3 Model
trt_llm_exporter = TensorRTLLM(model_dir="llama3-70b-base-fp8-trt-llm-engine")

lm = DeployPyTriton(model=trt_llm_exporter, triton_model_name="llama3-70b", port=8008)
lm.deploy()
lm.serve()

For convenience, the NeMo Framework provides NemoQueryLLM APIs to send queries to the Triton server. These APIs are only accessible from the NeMo Framework container.

from nemo.deploy.nlp import NemoQueryLLM

# Send a Query to nemotron server
nq = NemoQueryLLM(url="localhost:8000", model_name="nemotron4-340b")
output = nq.query_llm(prompts=["What is the capital of United States?"], max_output_token=10, top_k=1, top_p=0.0, temperature=1.0)
print(output)

# Send a query to llama server
nq = NemoQueryLLM(url="localhost:8008", model_name="llama3-70b")
output = nq.query_llm(prompts=["What is the capital of United States?"], max_output_token=10, top_k=1, top_p=0.0, temperature=1.0)
print(output)

For more information on the various deployment options, refer to Deploy NeMo Framework Models.