Knowledge Distillation with NeMo AutoModel

This guide walks through fine-tuning a student LLM with the help of a larger teacher model using the kd (knowledge distillation) recipe.

In particular, we will show how to distill a 3B (meta-llama/Llama-3.2-3B) model into a 1B (meta-llama/Llama-3.2-1B) model.

What is Knowledge Distillation?

Knowledge distillation (KD) transfers the dark knowledge of a high-capacity teacher model to a smaller student by minimizing the divergence between their predicted distributions. The student learns from both the ground-truth labels (Cross-Entropy loss, CE) and the soft targets of the teacher (Kullback-Leibler loss, KD):

\[ \mathcal{L} = (1-\alpha) \cdot \mathcal{L}_{\textrm{CE}}(p^{s}, y) + \alpha \cdot \mathcal{KL}(p^{s}, p^{t}) \]

where \(\alpha\) is the kd_ratio, \(T\) is the softmax temperature, and \(y\) are the labels. The distributions \(p\) are temperature-scaled softmaxes of the logits \(z\), e.g. \(p^{s} = \mathrm{softmax}(z^{s}, T)\).
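To make the blended objective concrete, here is a minimal pure-Python sketch for a single token position. It is not the recipe's KDLoss implementation: the function names are hypothetical, the KL term is assumed to be taken from the teacher distribution to the student distribution, and the CE term is assumed to use the unscaled (T = 1) student distribution.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(student_logits, label):
    """CE between the student's T=1 distribution and a hard label index."""
    return -math.log(softmax(student_logits)[label])

def kl_divergence(p, q):
    """KL(p || q) = sum_i p_i * log(p_i / q_i)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def kd_loss(student_logits, teacher_logits, label, kd_ratio=0.5, temperature=1.0):
    """(1 - kd_ratio) * CE + kd_ratio * KL, mirroring the formula above."""
    ce = cross_entropy(student_logits, label)
    p_s = softmax(student_logits, temperature)
    p_t = softmax(teacher_logits, temperature)
    # Assumption: KL is measured from the teacher to the student distribution.
    kl = kl_divergence(p_t, p_s)
    return (1.0 - kd_ratio) * ce + kd_ratio * kl

student = [2.0, 0.5, -1.0]   # toy logits over a 3-token vocabulary
teacher = [3.0, 0.0, -2.0]
print(kd_loss(student, teacher, label=0, kd_ratio=0.5, temperature=1.0))
```

With kd_ratio set to 0 the function reduces to plain cross-entropy, and with kd_ratio set to 1 it depends only on the teacher's soft targets.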

Prepare the YAML Config

A ready-to-use example is provided at examples/llm_kd/llama3_2/llama3_2_1b_kd.yaml. Important sections:

  • model – the student to be fine-tuned (1 B parameters in the example)
  • teacher_model – a larger frozen model used for supervision (3 B parameters in the example)
  • kd_ratio – blend between CE and KD loss
  • temperature – softens probability distributions before KL-divergence
  • peft – optional LoRA config (commented out). Uncomment to train only a small set of adapter parameters.

Feel free to tweak these values as required.
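When choosing a temperature, it can help to see how it reshapes a distribution. The sketch below uses made-up logits and a hypothetical softmax helper (not part of the recipe) to print the probabilities at several temperatures; higher values spread probability mass more evenly across tokens.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [4.0, 2.0, 1.0]  # illustrative values only
for temperature in (0.5, 1.0, 2.0, 4.0):
    probs = softmax(logits, temperature)
    formatted = ", ".join(f"{p:.3f}" for p in probs)
    print(f"T={temperature}: {formatted}")
```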

Example YAML

# Example config for knowledge distillation fine-tuning
# Run with:
# automodel examples/llm_kd/llama3_2/llama3_2_1b_kd.yaml

step_scheduler:
  global_batch_size: 32
  local_batch_size: 1
  ckpt_every_steps: 200
  val_every_steps: 100 # will run every x number of gradient steps
  num_epochs: 2

dist_env:
  backend: nccl
  timeout_minutes: 1

rng:
  _target_: nemo_automodel.components.training.rng.StatefulRNG
  seed: 1111
  ranked: true

# Student
model:
  _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
  pretrained_model_name_or_path: meta-llama/Llama-3.2-1B
  torch_dtype: bf16

# Teacher
teacher_model:
  _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
  pretrained_model_name_or_path: meta-llama/Llama-3.2-3B
  torch_dtype: bf16

checkpoint:
  enabled: true
  checkpoint_dir: checkpoints/
  model_save_format: safetensors
  save_consolidated: false

distributed:
  strategy: fsdp2
  dp_size: null
  tp_size: 1
  cp_size: 1
  pp_size: 1
  sequence_parallel: false

# PEFT can be enabled by uncommenting below – student weights will remain small
# peft:
#   _target_: nemo_automodel.components._peft.lora.PeftConfig
#   target_modules: '*_proj'
#   dim: 16
#   alpha: 32
#   use_triton: true

loss_fn:
  _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy

# Knowledge-distillation hyper-params
kd_ratio: 0.5 # 0 → pure CE, 1 → pure KD
kd_loss_fn:
  _target_: nemo_automodel.components.loss.kd_loss.KDLoss
  ignore_index: -100
  temperature: 1.0
  fp32_upcast: true

# Optimizer
optimizer:
  _target_: torch.optim.Adam
  betas: [0.9, 0.999]
  eps: 1e-8
  lr: 1.0e-5
  weight_decay: 0

# Dataset / Dataloader
dataset:
  _target_: nemo_automodel.components.datasets.llm.squad.make_squad_dataset
  dataset_name: rajpurkar/squad
  split: train

dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn: nemo_automodel.components.datasets.utils.default_collater
  shuffle: false

validation_dataset:
  _target_: nemo_automodel.components.datasets.llm.hellaswag.HellaSwag
  path_or_dataset: rowan/hellaswag
  split: validation
  num_samples_limit: 64

validation_dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn: nemo_automodel.components.datasets.utils.default_collater

Current Limitations

  • Pipeline parallelism (pp_size > 1) is not yet supported – planned for a future release.
  • Distilling Vision-Language models (vlm recipe) is currently not supported.
  • Student and teacher models must share the same tokenizer for now; support for different tokenizers will be added in the future.
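Because the KD loss compares student and teacher distributions token by token, a quick vocabulary comparison before launching a run can catch tokenizer mismatches early. The helper below is a hypothetical sketch that operates on plain token-to-id dictionaries; the toy vocabularies are made up for illustration.

```python
def vocab_mismatches(student_vocab, teacher_vocab):
    """Return tokens whose ids differ or that exist in only one vocabulary."""
    tokens = set(student_vocab) | set(teacher_vocab)
    return sorted(
        token for token in tokens
        if student_vocab.get(token) != teacher_vocab.get(token)
    )

# Toy vocabularies (token -> id), not real tokenizer output.
student_vocab = {"<s>": 0, "hello": 1, "world": 2}
teacher_vocab = {"<s>": 0, "hello": 1, "world": 3, "!": 4}

mismatches = vocab_mismatches(student_vocab, teacher_vocab)
if mismatches:
    print("Tokenizers differ on:", mismatches)
else:
    print("Vocabularies match")
```

With Hugging Face tokenizers, the get_vocab() method returns a mapping of this shape, so the same helper could be applied to the student and teacher tokenizers.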

Launch Training

Single-GPU Quick Run

# Runs on a single device of the current host
automodel examples/llm_kd/llama3_2/llama3_2_1b_kd.yaml

Multi-GPU (Single Node)

# Leverage all GPUs on the local machine
torchrun --nproc-per-node $(nvidia-smi -L | wc -l) \
    nemo_automodel/recipes/llm/kd.py \
    -c examples/llm_kd/llama3_2/llama3_2_1b_kd.yaml
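The number of GPUs interacts with the step_scheduler batch sizes. The sketch below assumes the usual relationship global_batch_size = local_batch_size × data-parallel ranks × gradient-accumulation steps, and that dp_size: null resolves to the number of launched processes. Neither assumption is confirmed by this guide, and grad_accumulation_steps is a hypothetical helper.

```python
def grad_accumulation_steps(global_batch_size, local_batch_size, dp_size):
    """Micro-batches each rank accumulates before one optimizer step."""
    samples_per_micro_step = local_batch_size * dp_size
    if global_batch_size % samples_per_micro_step != 0:
        raise ValueError(
            "global_batch_size must be divisible by local_batch_size * dp_size"
        )
    return global_batch_size // samples_per_micro_step

# Values from the example config: global_batch_size=32, local_batch_size=1.
for gpus in (1, 2, 8):
    steps = grad_accumulation_steps(32, 1, gpus)
    print(f"{gpus} GPU(s): {steps} accumulation step(s) per optimizer step")
```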

Slurm Cluster

The CLI seamlessly submits Slurm jobs when a slurm section is added to the YAML. Refer to docs/guides/installation.md for cluster instructions.

Monitoring

Metrics such as train_loss, kd_loss, learning_rate and tokens/sec are logged to WandB when the corresponding section is enabled.

Checkpoints and Inference

  • Checkpoints are written to the directory set in checkpoint.checkpoint_dir every ckpt_every_steps training steps.
  • The final student model is saved according to the checkpoint section (e.g., model_save_format: safetensors, consolidated weights if save_consolidated: true).

Load the distilled model:

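import nemo_automodel as am
from transformers import AutoTokenizer

# Assumes a Hugging Face-style model; the student keeps the base model's tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
student = am.NeMoAutoModelForCausalLM.from_pretrained("checkpoints/final")
inputs = tokenizer("Translate to French: I love coding!", return_tensors="pt").to(student.device)
print(tokenizer.decode(student.generate(**inputs, max_new_tokens=32)[0], skip_special_tokens=True))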