Distillation#
NeMo 2.0 offers an easy-to-enable Knowledge Distillation (KD) training setup. The following section explains how to use it.
Knowledge Distillation#
KD involves using information from an existing trained model to train a second (usually smaller, faster) model, thereby “distilling” knowledge from one to the other.
Distillation has two primary benefits: faster convergence and higher final accuracy than traditional training.
In NeMo, distillation is enabled by NVIDIA TensorRT Model Optimizer (ModelOpt), a library for optimizing deep learning models for inference on GPUs.
Logits-Distillation Process#
The logits-distillation process involves these steps:
Loads Checkpoints: Loads both the student and teacher model checkpoints. They must both support the same parallelism strategy.
Replaces Loss Function: Replaces the standard loss function with the KL-Divergence between the output logits of the two models (see the sketch after this list).
Trains Models: Runs forward passes on both models, but executes the backward pass only on the student model.
Saves Checkpoints: Saves checkpoints of the student model only, so the student can be used later in the same manner as any other NeMo 2.0 checkpoint.
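Conceptually, the loss replacement and training step are equivalent to the plain PyTorch sketch below. This is illustrative only, not NeMo's or ModelOpt's actual implementation; student, teacher, tokens, and optimizer are placeholder objects.
import torch
import torch.nn.functional as F

def logits_kd_loss(student_logits, teacher_logits, temperature=1.0):
    # KL-Divergence between the softened teacher and student output distributions.
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature**2

def distillation_step(student, teacher, tokens, optimizer):
    # Forward pass on both models; the teacher only provides targets.
    with torch.no_grad():
        teacher_logits = teacher(tokens)
    student_logits = student(tokens)
    loss = logits_kd_loss(student_logits, teacher_logits)
    loss.backward()  # Backward pass runs only on the student.
    optimizer.step()
    optimizer.zero_grad()
    return loss.detach()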
Limitations#
Only GPT-based NeMo 2.0 checkpoints are supported.
Only logit-pair distillation is enabled for now.
Example#
The examples below show how to run distillation with existing NeMo 2.0 checkpoints.
Use NeMo-Run Recipes#
Note
Prerequisite: Before proceeding, please follow the example in Quickstart with NeMo-Run to familiarize yourself with NeMo-Run.
import nemo_run as run
from nemo.collections.llm.distillation.recipe import distillation_recipe

recipe = distillation_recipe(
    student_model_path="path/to/student/nemo2-checkpoint/",
    teacher_model_path="path/to/teacher/nemo2-checkpoint/",
    dir="./distill_logs",  # Path to store logs and checkpoints
    name="distill_testrun",
    num_nodes=1,
    num_gpus_per_node=8,
)

# Override the configuration with desired components:
# recipe.data = run.Config(...)
# recipe.trainer = run.Config(...)
...

run.run(recipe)
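How the recipe is launched is controlled by the executor passed to run.run (see the NeMo-Run quickstart). As a minimal sketch, assuming a single node with eight local GPUs, you could launch with torchrun via a local executor:
# Assumption: single node with 8 GPUs; see the NeMo-Run quickstart for other executors (e.g. Slurm).
executor = run.LocalExecutor(ntasks_per_node=8, launcher="torchrun")
run.run(recipe, executor=executor)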
Use the Distillation Script with torchrun or Slurm#
Alternatively, you can launch the distillation script directly, which offers a finer degree of customization.
STUDENT_CKPT="path/to/student/nemo2-checkpoint/"
TEACHER_CKPT="path/to/teacher/nemo2-checkpoint/"
DATA_PATHS="1.0 path/to/tokenized/data"
SEQUENCE_LEN=8192
MICRO_BATCHSIZE=1
GLOBAL_BATCHSIZE=4
STEPS=100
TP=8
CP=1
PP=1
DP=1
NUM_NODES=1
DEVICES_PER_NODE=8
NAME="distill_testrun"
LOG_DIR="./distill_logs/"
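# One process per GPU: torchrun launches TP * CP * PP * DP processes in total
# (8 here, matching NUM_NODES * DEVICES_PER_NODE).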
launch_cmd="torchrun --nproc_per_node=$(($TP * $CP * $PP * $DP))"
${launch_cmd} scripts/llm/gpt_distillation.py \
    --name ${NAME} \
    --student_path ${STUDENT_CKPT} \
    --teacher_path ${TEACHER_CKPT} \
    --tp_size ${TP} \
    --cp_size ${CP} \
    --pp_size ${PP} \
    --devices ${DEVICES_PER_NODE} \
    --num_nodes ${NUM_NODES} \
    --log_dir ${LOG_DIR} \
    --max_steps ${STEPS} \
    --gbs ${GLOBAL_BATCHSIZE} \
    --mbs ${MICRO_BATCHSIZE} \
    --data_paths ${DATA_PATHS} \
    --seq_length ${SEQUENCE_LEN}
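On a Slurm cluster, the same script can be launched with sbatch/srun instead of torchrun. The batch script below is a minimal sketch only: the partition and resource settings are placeholders, the same shell variables as above are assumed to be defined inside the script, and it relies on NeMo's PyTorch Lightning-based trainer detecting the Slurm environment to set up distributed training.
#!/bin/bash
#SBATCH --nodes=1                   # match NUM_NODES
#SBATCH --ntasks-per-node=8         # one task per GPU; match DEVICES_PER_NODE
#SBATCH --gpus-per-node=8
#SBATCH --partition=your_partition  # placeholder: set for your cluster
#SBATCH --time=04:00:00

# Define STUDENT_CKPT, TEACHER_CKPT, DATA_PATHS, and the other variables
# exactly as in the torchrun example above, then launch one rank per GPU:
srun python scripts/llm/gpt_distillation.py \
    --name ${NAME} \
    --student_path ${STUDENT_CKPT} \
    --teacher_path ${TEACHER_CKPT} \
    --tp_size ${TP} \
    --cp_size ${CP} \
    --pp_size ${PP} \
    --devices ${DEVICES_PER_NODE} \
    --num_nodes ${NUM_NODES} \
    --log_dir ${LOG_DIR} \
    --max_steps ${STEPS} \
    --gbs ${GLOBAL_BATCHSIZE} \
    --mbs ${MICRO_BATCHSIZE} \
    --data_paths ${DATA_PATHS} \
    --seq_length ${SEQUENCE_LEN}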