Knowledge Distillation with NeMo AutoModel
This guide walks through fine-tuning a student LLM with the help of a
larger teacher model using the kd (knowledge distillation) recipe.
In particular, we will show how to distill a 3B (meta-llama/Llama-3.2-3B) model into a 1B (meta-llama/Llama-3.2-1B) model.
What is Knowledge Distillation?
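Knowledge distillation (KD) transfers the "dark knowledge" of a high-capacity teacher model to a smaller student by minimizing the divergence between their predicted distributions. The student learns from both the ground-truth labels (Cross-Entropy loss, CE) and the soft targets of the teacher (Kullback-Leibler loss, KD):
\[
\mathcal{L} = (1-\alpha)\,\mathcal{L}_{\mathrm{CE}}(p^{s}, y) + \alpha\,\mathcal{L}_{\mathrm{KL}}(p^{t}, p^{s})
\]
where \(\alpha\) is the kd_ratio, \(T\) is the softmax temperature, \(y\) are the ground-truth labels, and \(\mathcal{L}_{\mathrm{KL}}(p^{t}, p^{s}) = \sum_{v} p^{t}_{v} \log (p^{t}_{v} / p^{s}_{v})\). The arguments \(p\) are the student and teacher logits \(z^{s}\), \(z^{t}\) softened by the temperature:
\[
p^{s} = \mathrm{softmax}(z^{s}/T), \qquad p^{t} = \mathrm{softmax}(z^{t}/T).
\]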
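The objective described above (CE on the labels blended with the KL divergence to the teacher through kd_ratio, with both distributions softened by the temperature) can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the kd recipe's implementation: it assumes per-token logits of shape [num_tokens, vocab_size], ignores padding/masked tokens, and the function names log_softmax and kd_objective are hypothetical.

```python
import numpy as np


def log_softmax(logits: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """Temperature-scaled log-softmax over the last (vocabulary) axis."""
    z = logits / temperature
    z = z - z.max(axis=-1, keepdims=True)  # shift by the max for numerical stability
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def kd_objective(student_logits, teacher_logits, labels, kd_ratio=0.5, temperature=1.0):
    """Return (total, ce, kl) for (1 - kd_ratio) * CE(p_s, y) + kd_ratio * KL(p_t || p_s).

    student_logits, teacher_logits: float arrays of shape [num_tokens, vocab_size]
    labels: int array of shape [num_tokens] with ground-truth token ids
    (hypothetical helper; padding/ignore-index handling is omitted)
    """
    log_p_s = log_softmax(student_logits, temperature)
    log_p_t = log_softmax(teacher_logits, temperature)
    p_t = np.exp(log_p_t)

    # Cross-entropy: negative log-probability the student assigns to each label, averaged over tokens.
    ce = -log_p_s[np.arange(labels.shape[0]), labels].mean()
    # KL(p_t || p_s): summed over the vocabulary, then averaged over tokens.
    kl = (p_t * (log_p_t - log_p_s)).sum(axis=-1).mean()

    total = (1.0 - kd_ratio) * ce + kd_ratio * kl
    return total, ce, kl
```

The sketch applies the temperature to the CE term too, following the formula as written. Many KD implementations instead compute CE at \(T = 1\) and multiply the KL term by \(T^2\) (as in Hinton et al.'s original formulation) to keep gradient magnitudes comparable across temperatures; check the recipe source to see which convention it follows.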
Prepare the YAML Config
A ready-to-use example is provided at
examples/llm_kd/llama3_2/llama3_2_1b_kd.yaml. Important sections:
- model: the student to be fine-tuned (1B parameters in the example)
- teacher_model: a larger, frozen model used for supervision (3B in the example)
- kd_ratio: blend between the CE and KD losses
- temperature: softens the probability distributions before the KL divergence is computed
- peft: optional LoRA config (commented out). Uncomment it to train only a small fraction of the parameters.
Feel free to tweak these values as required.
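To build intuition for the temperature value, the sketch below softens a made-up vector of teacher logits at several temperatures and prints the resulting probabilities and entropy. Higher temperatures flatten the distribution, exposing more of the teacher's relative preferences among non-top tokens. The logits and helper names are hypothetical and chosen only for illustration.

```python
import numpy as np


def softmax(logits, temperature=1.0):
    """Softmax of logits / temperature, max-shifted for numerical stability."""
    z = np.asarray(logits, dtype=float) / temperature
    e = np.exp(z - z.max())
    return e / e.sum()


def entropy(p):
    """Shannon entropy in nats; larger values mean a flatter distribution."""
    return float(-(p * np.log(p)).sum())


# Toy teacher logits over a 4-token vocabulary (made-up values).
teacher_logits = [4.0, 2.0, 1.0, 0.5]
for T in (0.5, 1.0, 2.0, 4.0):
    p = softmax(teacher_logits, T)
    print(f"T={T}: probs={np.round(p, 3)}, entropy={entropy(p):.3f}")
```

As \(T\) grows the entropy rises toward \(\log 4\) (the uniform distribution), so very high temperatures eventually wash out the teacher's signal; this is one reason temperature is usually tuned together with kd_ratio.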
Example YAML
Current Limitations
- Pipeline parallelism (pp_size > 1) is not yet supported; support is planned for a future release.
- Distilling Vision-Language models (vlm recipe) is currently not supported.
- Student and teacher models must share the same tokenizer for now; support for different tokenizers will be added in the future.
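The shared-tokenizer requirement exists because the KL term compares student and teacher probabilities position by position and token id by token id, so both vocabularies must map every token string to the same id. A quick pre-flight check can be sketched as follows, assuming only that each tokenizer exposes get_vocab() returning a dict of token string to id (Hugging Face tokenizers do); the helper names are hypothetical.

```python
def tokenizers_match(student_tok, teacher_tok) -> bool:
    """True if both tokenizers map every token string to the same id."""
    return student_tok.get_vocab() == teacher_tok.get_vocab()


def vocab_mismatches(student_tok, teacher_tok, limit=10):
    """Return up to `limit` (token, student_id, teacher_id) tuples that disagree.

    A None id means the token exists in only one of the two vocabularies.
    """
    s, t = student_tok.get_vocab(), teacher_tok.get_vocab()
    diffs = []
    for token in sorted(set(s) | set(t)):
        if s.get(token) != t.get(token):
            diffs.append((token, s.get(token), t.get(token)))
            if len(diffs) >= limit:
                break
    return diffs
```

For the pair in this guide you would pass AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") and AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B") (both gated on the Hugging Face Hub, so an authenticated session is assumed). Note that a model's embedding size can exceed the tokenizer's vocabulary because of padding, so this check covers the tokenizers, not the logit dimensions.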
Launch Training
Single-GPU Quick Run
Multi-GPU (Single Node)
Slurm Cluster
The CLI submits a Slurm job automatically when a slurm section is added to the
YAML. Refer to docs/guides/installation.md for cluster instructions.
Monitoring
Metrics such as train_loss, kd_loss, learning_rate and tokens/sec are logged to Weights & Biases (WandB) when the WandB section of the YAML config is enabled.
Checkpoints and Inference
- Checkpoints are written to the directory set in the checkpoint.checkpoint_dir field every ckpt_every_steps steps.
- The final student model is saved according to the checkpoint section (e.g., model_save_format: safetensors, with consolidated weights if save_consolidated: final or save_consolidated: every).
Load the distilled model:
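A minimal sketch with Hugging Face transformers, assuming the consolidated safetensors checkpoint is in Hugging Face format (a config.json plus weight shards) so that AutoModelForCausalLM can read it. Because student and teacher share a tokenizer, the sketch loads the tokenizer from the student base model; if your checkpoint directory also contains tokenizer files, pass that directory as tokenizer_id instead. The function name and the checkpoint path you pass are placeholders.

```python
def generate_with_distilled(ckpt_dir, prompt, tokenizer_id="meta-llama/Llama-3.2-1B", max_new_tokens=32):
    """Load a consolidated student checkpoint and return a greedy completion of `prompt`.

    ckpt_dir: placeholder for the consolidated weights under checkpoint.checkpoint_dir
    tokenizer_id: student base model, since student and teacher share a tokenizer
    """
    # Imported lazily so the sketch can be defined even where transformers is not installed.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    model = AutoModelForCausalLM.from_pretrained(ckpt_dir)
    model.eval()

    inputs = tokenizer(prompt, return_tensors="pt")  # PyTorch tensors; requires torch
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```

Usage: generate_with_distilled("path/to/consolidated/checkpoint", "Knowledge distillation is"), with the path replaced by the consolidated output written under your checkpoint.checkpoint_dir.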