Distillation
NeMo 2.0 provides a streamlined setup for Knowledge Distillation (KD) training, making it easy to enable and integrate into your workflow. This section explains how to use this feature effectively.
KD is a technique where a pre-trained model (the “teacher”) transfers its learned knowledge to a second model (the “student”), which is typically smaller and faster. This process helps the student model learn more efficiently by mimicking the behavior of the teacher. KD offers two key advantages over traditional training: faster convergence and higher final accuracy.
In NeMo, KD is implemented with NVIDIA TensorRT Model Optimizer (ModelOpt), a library for optimizing deep-learning models for inference on GPUs.
Knowledge Distillation Process
The KD process involves these steps:
Loads Checkpoints: Loads both the student and teacher model checkpoints. They must both support the same parallelism strategy.
Replaces Loss Function: Replaces the standard loss function with the KL-Divergence between the output logits (and potentially additional losses between pairs of intermediate model states).
Trains Models: Runs forward passes on both models, but executes the backward pass only on the student model.
Saves Checkpoints: Saves only the student model checkpoint, allowing it to be used later in the same manner as before.
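The loss replacement in the second step can be illustrated with a minimal pure-Python sketch of a KL-divergence loss between teacher and student logits. This is not ModelOpt's implementation. The function names, the `temperature` parameter, the direction KL(teacher || student), and the mean over tokens are all assumptions of the sketch. In real training the logits would be GPU tensors, the teacher's forward pass would run without gradients, and only the student would be updated. Plain Python floats have no autograd, so that part is described here rather than shown.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable log-softmax over one token's vocabulary logits."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max before exp() to avoid overflow
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_z for x in scaled]


def logits_kl_loss(
    student_logits: Sequence[Sequence[float]],
    teacher_logits: Sequence[Sequence[float]],
    temperature: float = 1.0,  # assumption of this sketch, not a documented option
) -> float:
    """Mean over tokens of KL(teacher || student) between softmax distributions."""
    per_token = []
    for s_tok, t_tok in zip(student_logits, teacher_logits):
        log_p_t = log_softmax(t_tok, temperature)
        log_p_s = log_softmax(s_tok, temperature)
        # KL(P_t || P_s) = sum_v P_t(v) * (log P_t(v) - log P_s(v))
        kl = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))
        per_token.append(kl)
    return sum(per_token) / len(per_token)


# Two tokens, vocabulary of size 3.
teacher = [[2.0, 0.5, -1.0], [0.0, 3.0, 0.0]]
student_close = [[1.9, 0.6, -1.0], [0.1, 2.8, 0.0]]
student_far = [[-1.0, 0.5, 2.0], [3.0, 0.0, 0.0]]
print(logits_kl_loss(student_close, teacher))  # small
print(logits_kl_loss(student_far, teacher))    # much larger
```

A student whose logits already resemble the teacher's gets a near-zero loss. Minimizing this quantity pushes the student's output distribution toward the teacher's.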
Limitations
Only GPT-based NeMo 2.0 checkpoints are supported.
If Pipeline Parallelism is enabled, KD losses based on intermediate states are only supported on the final pipeline stage.
Configuration
You can configure the KD process via a YAML file. An example configuration file is shown below:
logit_layers: ["output_layer", "output_layer"]
intermediate_layer_pairs:
- ["decoder.final_layernorm", "decoder.final_layernorm"]
skip_lm_loss: true
kd_loss_scale: 1.0
The components of the configuration file are as follows:
logit_layers: The names of the student and teacher logit layers. These names correspond to the PyTorch submodule attributes of the Megatron Core model. (For GPT-based models, this is "output_layer".)
intermediate_layer_pairs: A list of pairs of intermediate layer names. By default, each pair is compared with a cosine-similarity loss. If tensor parallelism is enabled, these layers must have sequence-parallel outputs (such as LayerNorms), because the cosine loss cannot be computed on a hidden dimension that is split across ranks.
skip_lm_loss: Whether to skip the default language modeling (LM) loss. If false, the LM loss is added to the distillation loss. (Note that this consumes more memory.)
kd_loss_scale: Relative scale factor for the distillation loss. The cumulative logits-and-intermediate loss is scaled to kd_loss_scale times the magnitude of the LM loss. Not used if skip_lm_loss is true.
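The tensor-parallelism restriction on intermediate_layer_pairs follows from how cosine similarity works: it needs the whole hidden vector. The sketch below assumes the loss has the form 1 - cosine similarity. That form is an assumption, and ModelOpt's exact formula may differ. It compares the loss computed on a full hidden vector with the average of the losses computed on two halves, which is what a hypothetical split of the hidden dimension across two ranks would force.

```python
import math
from typing import Sequence


def cosine_loss(a: Sequence[float], b: Sequence[float]) -> float:
    """1 - cosine similarity between two non-zero hidden-state vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


student_hidden = [1.0, 0.0, 2.0, 1.0]
teacher_hidden = [1.0, 1.0, 0.0, 1.0]

# Full hidden dimension: with sequence-parallel outputs, each rank holds
# complete hidden vectors for its slice of tokens, so this is computable.
full = cosine_loss(student_hidden, teacher_hidden)

# Hidden dimension split in two (a hypothetical tensor-parallel shard per rank),
# with each rank computing its own loss and the results averaged.
shard_avg = (
    cosine_loss(student_hidden[:2], teacher_hidden[:2])
    + cosine_loss(student_hidden[2:], teacher_hidden[2:])
) / 2

print(full, shard_avg)  # the two values differ
```

The norms and the dot product are not separable across shards, so per-shard losses do not recombine into the true loss. Splitting along the sequence dimension avoids that problem.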
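The interaction between skip_lm_loss and kd_loss_scale can be summarized as a small function over scalar loss values. This is a hypothetical sketch of the described semantics, not ModelOpt code. Here `kd_loss` stands for the already-combined logits-and-intermediate loss, and how the individual KD terms are weighted against each other is not shown. The rescaling ratio would normally be computed from detached values so that it acts as a constant during backpropagation. That detail is an assumption, and plain floats cannot express it.

```python
def combine_losses(
    lm_loss: float,
    kd_loss: float,
    skip_lm_loss: bool,
    kd_loss_scale: float = 1.0,
) -> float:
    """Total loss under the skip_lm_loss / kd_loss_scale semantics (sketch)."""
    if skip_lm_loss:
        # Only the distillation loss is used, so kd_loss_scale has no effect.
        return kd_loss
    if kd_loss == 0.0:
        return lm_loss
    # Rescale the distillation loss so its magnitude equals
    # kd_loss_scale * lm_loss, then add it to the LM loss.
    ratio = kd_loss_scale * lm_loss / kd_loss
    return lm_loss + kd_loss * ratio


print(combine_losses(lm_loss=2.0, kd_loss=0.5, skip_lm_loss=False))  # 4.0
print(combine_losses(lm_loss=2.0, kd_loss=0.5, skip_lm_loss=False, kd_loss_scale=0.5))  # 3.0
print(combine_losses(lm_loss=2.0, kd_loss=0.5, skip_lm_loss=True))  # 0.5
```

Because the distillation term is rescaled relative to the LM loss, kd_loss_scale controls the balance between the two objectives independently of their raw magnitudes.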
Use NeMo-Run Recipes
Note
Prerequisite: Before proceeding, follow the example in Quickstart with NeMo-Run to familiarize yourself with NeMo-Run.
import nemo_run as run
from nemo.collections import llm
from nemo.collections.llm.modelopt.recipes import distillation_recipe
recipe = distillation_recipe(
student_model_path="path/to/student/nemo2-checkpoint/",
teacher_model_path="path/to/teacher/nemo2-checkpoint/",
distillation_config_path="path/to/distill-config.yaml",
dir="./distill_logs", # Path to store logs and checkpoints
name="distill_testrun",
num_nodes=1,
num_gpus_per_node=8,
)
# Override the configuration with desired components:
recipe.data = run.Config(llm.PreTrainingDataModule, ...)
recipe.trainer.strategy.tensor_model_parallel_size = 8
...
run.run(recipe)
Use with torchrun or Slurm
Alternatively, you can run a traditional training script directly for more transparency and control:
STUDENT_CKPT="path/to/student/nemo2-checkpoint/"
TEACHER_CKPT="path/to/teacher/nemo2-checkpoint/"
DISTILLATION_CONFIG="path/to/distill-config.yaml"
DATA_PATHS="1.0 path/to/tokenized/data"
SEQUENCE_LEN=8192
MICRO_BATCHSIZE=1
GLOBAL_BATCHSIZE=4
STEPS=100
TP=8
CP=1
PP=1
DP=1
NUM_NODES=1
DEVICES_PER_NODE=8
NAME="distill_testrun"
LOG_DIR="./distill_logs/"
launch_cmd="torchrun --nproc_per_node=$(($TP * $CP * $PP * $DP))"
${launch_cmd} scripts/llm/gpt_train.py \
--name ${NAME} \
--model_path ${STUDENT_CKPT} \
--teacher_path ${TEACHER_CKPT} \
--kd_config ${DISTILLATION_CONFIG} \
--tp_size ${TP} \
--cp_size ${CP} \
--pp_size ${PP} \
--devices ${DEVICES_PER_NODE} \
--num_nodes ${NUM_NODES} \
--log_dir ${LOG_DIR} \
--max_steps ${STEPS} \
--gbs ${GLOBAL_BATCHSIZE} \
--mbs ${MICRO_BATCHSIZE} \
--data_paths ${DATA_PATHS} \
--seq_length ${SEQUENCE_LEN}
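The parallelism and batch-size variables in the script above must be mutually consistent. The sketch below checks two relationships under stated assumptions. First, in the single-node launch, `--nproc_per_node` is TP * CP * PP * DP, so that product should equal the number of GPUs (expert parallelism is ignored). Second, the global batch size should be divisible by micro batch size * DP. That second rule is the usual Megatron-style constraint and is treated here as an assumption. The function name is hypothetical.

```python
def check_parallel_layout(
    tp: int, cp: int, pp: int, dp: int,
    num_nodes: int, devices_per_node: int,
    global_batch_size: int, micro_batch_size: int,
) -> int:
    """Validate a layout; return micro-batches accumulated per optimizer step."""
    world_size = tp * cp * pp * dp
    total_devices = num_nodes * devices_per_node
    if world_size != total_devices:
        raise ValueError(f"TP*CP*PP*DP = {world_size}, but {total_devices} GPUs are available")
    samples_per_micro_step = micro_batch_size * dp
    if global_batch_size % samples_per_micro_step != 0:
        raise ValueError(
            f"global batch {global_batch_size} is not divisible by MBS*DP = {samples_per_micro_step}"
        )
    return global_batch_size // samples_per_micro_step


# Values from the script: TP=8, CP=1, PP=1, DP=1 on 1 node x 8 GPUs, GBS=4, MBS=1.
print(check_parallel_layout(8, 1, 1, 1, 1, 8, 4, 1))  # 4
```

With these settings each optimizer step accumulates gradients over 4 micro-batches. Raising DP would reduce that count for the same global batch size.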
Perform SFT Knowledge Distillation
To perform SFT Knowledge Distillation on a chat dataset, use the script above and also pass the --tokenizer and --use-chat-data arguments. See scripts/llm/gpt_train.py for full argument descriptions.