Start a Knowledge Distillation (KD) Customization Job#

Learn how to use the NeMo Microservices Platform to create a Knowledge Distillation (KD) job, transferring knowledge from a large teacher model to a smaller student model using your own dataset.

About Knowledge Distillation#

Knowledge distillation is a technique for transferring knowledge from a large, high-capacity teacher model to a smaller student model. The distilled model (student) often achieves higher accuracy than models trained using standard language modeling loss alone.

KD is useful when you want to deploy smaller models without losing much accuracy compared to a large model.
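NeMo Customizer implements the distillation objective internally, so you never write this code yourself, but a conceptual sketch can make the mechanism concrete: logit-pair distillation typically trains the student to match the teacher's per-token output distribution, for example through a temperature-scaled KL divergence. The snippet below (plain PyTorch) is illustrative only and is not the service's actual implementation.

import torch
import torch.nn.functional as F

def logit_distillation_loss(student_logits, teacher_logits, temperature=1.0):
    # student_logits, teacher_logits: [batch, seq_len, vocab_size]
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    # KL(teacher || student), averaged over the batch; the T^2 factor keeps
    # gradient magnitudes comparable across temperatures.
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * temperature ** 2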

Prerequisites#

Before starting, make sure you have:

  • Access to the NeMo Customizer service, with its base URL set as the CUSTOMIZER_SERVICE_URL environment variable (used throughout the examples below).

  • A teacher model and a student model registered as customization targets (see Select Teacher and Student Models).

  • A training dataset prepared and uploaded in the SFT format (see Create Datasets).

Notes and Limitations#

  • Only logit-pair distillation is currently supported.

  • LoRA adapters can’t be used as teacher models.


Select Teacher and Student Models#

You need two models available as customization targets:

  • Teacher model: A large, fine-tuned model

  • Student model: A smaller model you want to distill knowledge into

Both models must use the same tokenizer. Only GPT-based NeMo 2.0 checkpoints are supported for now.
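Both models must already be registered as customization targets before you create the job. The sketch below lists the registered targets so you can confirm the teacher and student are both present. The /v1/customization/targets path and the response field names are assumptions modeled on the configs listing shown later in this tutorial; check the API reference for your deployment if they differ.

import os
import requests

CUSTOMIZER_SERVICE_URL = os.environ["CUSTOMIZER_SERVICE_URL"]

response = requests.get(
    f"{CUSTOMIZER_SERVICE_URL}/v1/customization/targets",
    headers={"Accept": "application/json"},
)
response.raise_for_status()

# Print the registered targets so you can confirm the teacher and student
# appear. The "data"/"namespace"/"name" field names are assumed here.
for target in response.json().get("data", []):
    print(target.get("namespace"), target.get("name"))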


Select Customization Config#

You can either find an existing customization config to use or create a new one.

Find Available Configs#

Identify which model customization configurations support distillation training. KD customization jobs require a configuration whose training_options include a training_type of distillation with a finetuning_type of all_weights.

  1. Get all customization configurations.

    curl -X GET "${CUSTOMIZER_SERVICE_URL}/v1/customization/configs" \
      -H 'Accept: application/json' | jq
    
    # Python alternative to the curl request above.
    import os
    import requests

    # Assumes the Customizer base URL is exported as CUSTOMIZER_SERVICE_URL.
    CUSTOMIZER_SERVICE_URL = os.environ["CUSTOMIZER_SERVICE_URL"]

    response = requests.get(
        f"{CUSTOMIZER_SERVICE_URL}/v1/customization/configs",
        headers={"Accept": "application/json"}
    )
    print(response.json())
    
  2. Review the response to find a model configuration that includes distillation in its training_options. A sketch that filters the listing programmatically follows this list.

    Example Response

    {
      "object": "list",
      "data": [
        {
          "name": "meta/llama-3.2-1b-instruct@v1.0.0+A100",
          "namespace": "default",
          "training_options": [
            {
              "training_type": "sft",
              "finetuning_type": "lora",
              "num_gpus": 1,
              "num_nodes": 1,
              "tensor_parallel_size": 1,
              "use_sequence_parallel": false
            },
            {
              "training_type": "distillation",
              "finetuning_type": "all_weights",
              "num_gpus": 1,
              "num_nodes": 1,
              "tensor_parallel_size": 1,
              "use_sequence_parallel": false
            }
          ]
        }
      ]
    }
    
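If you prefer to filter the listing programmatically, the following sketch selects configurations that advertise a distillation training option with all_weights fine-tuning. It assumes the response shape shown above (a data list whose entries carry training_options) and that CUSTOMIZER_SERVICE_URL is set in your environment.

import os
import requests

CUSTOMIZER_SERVICE_URL = os.environ["CUSTOMIZER_SERVICE_URL"]

response = requests.get(
    f"{CUSTOMIZER_SERVICE_URL}/v1/customization/configs",
    headers={"Accept": "application/json"},
)
response.raise_for_status()

# Keep only configs that offer distillation with all_weights fine-tuning.
distillation_configs = [
    cfg
    for cfg in response.json()["data"]
    if any(
        opt.get("training_type") == "distillation"
        and opt.get("finetuning_type") == "all_weights"
        for opt in cfg.get("training_options", [])
    )
]

for cfg in distillation_configs:
    # The "name" field (for example meta/llama-3.2-1b-instruct@v1.0.0+A100)
    # is the value you pass as "config" when creating the job later on.
    print(cfg["name"])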

Create Config#

If no appropriate configuration is available, create one that includes a distillation training option:

curl -X POST \
  "${CUSTOMIZER_SERVICE_URL}/v1/customization/configs" \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
    "name": "llama-3.2-1b-instruct@v1.0.0+A100",
    "namespace": "default",
    "description": "Configuration for Llama 3.2 1B with distillation support",
    "target": "meta/llama-3.2-1b-instruct",
    "training_options": [
       {
          "training_type": "sft",
          "finetuning_type": "lora",
          "num_gpus": 1,
          "tensor_parallel_size": 1,
          "pipeline_parallel_size": 1,
          "use_sequence_parallel": false,
          "micro_batch_size": 1
      },
       {
          "training_type": "distillation",
          "finetuning_type": "all_weights",
          "num_gpus": 1,
          "tensor_parallel_size": 1,
          "pipeline_parallel_size": 1,
          "use_sequence_parallel": false,
          "micro_batch_size": 1
      }
    ],
    "training_precision": "bf16",
    "max_seq_length": 2048
  }' | jq
import os
import requests

# Assumes the Customizer base URL is exported as CUSTOMIZER_SERVICE_URL.
CUSTOMIZER_SERVICE_URL = os.environ["CUSTOMIZER_SERVICE_URL"]

url = f"{CUSTOMIZER_SERVICE_URL}/v1/customization/configs"
payload = {
    "name": "llama-3.2-1b-instruct@v1.0.0+A100",
    "namespace": "default",
    "description": "Configuration for Llama 3.2 1B with distillation support",
    "target": "meta/llama-3.2-1b-instruct",
    "training_options": [
        {
            "training_type": "sft",
            "finetuning_type": "lora",
            "num_gpus": 1,
            "tensor_parallel_size": 1,
            "pipeline_parallel_size": 1,
            "use_sequence_parallel": False,
            "micro_batch_size": 1
        },
        {
            "training_type": "distillation",
            "finetuning_type": "all_weights",
            "num_gpus": 1,
            "tensor_parallel_size": 1,
            "pipeline_parallel_size": 1,
            "use_sequence_parallel": False,
            "micro_batch_size": 1
        }
    ],
    "training_precision": "bf16",
    "max_seq_length": 2048
}
headers = {
    "accept": "application/json",
    "Content-Type": "application/json"
}
response = requests.post(url, json=payload, headers=headers)
print(response.json())

For detailed information about creating configs, see Create Customization Config.


Create Datasets#

Prepare your training and validation datasets in the same format required for SFT jobs. The dataset should be the same as (or similar to) the one used to fine-tune the teacher model.

Refer to the Format Training Datasets tutorial for details on dataset structure and upload instructions.
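As a quick illustration only (the Format Training Datasets tutorial is the authoritative reference for field names and upload steps), SFT-style training data is commonly newline-delimited JSON with prompt/completion pairs. The sketch below writes such a file locally; the record contents are hypothetical, and you should adjust the schema to whatever the tutorial specifies.

import json

# Hypothetical records; the prompt/completion field names follow a common
# SFT convention and should be checked against the Format Training Datasets tutorial.
records = [
    {"prompt": "Summarize: NeMo Customizer supports distillation jobs.", "completion": "It supports KD jobs."},
    {"prompt": "Translate to French: good morning", "completion": "bonjour"},
]

with open("training.jsonl", "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")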


Start Model Customization Job#

Set Hyperparameters#

When creating a KD job, set the following in your job configuration:

  • training_type: distillation

  • finetuning_type: all_weights (the only supported option)

  • distillation.teacher: The name of the teacher Target (must already exist)

Example hyperparameters section:

"hyperparameters": {
    "training_type": "distillation",
    "finetuning_type": "all_weights",
    "epochs": 2,
    "batch_size": 16,
    "learning_rate": 0.0001,
    "distillation": {
        "teacher": "meta/finetuned-llama-3_1-8b@v1"
    }
}

Create and Submit Customization Job#

curl -X POST "${CUSTOMIZER_SERVICE_URL}/v1/customization/jobs" \
  -H "Content-Type: application/json" \
  -d '{
    "target": "meta/llama-3.2-1b-instruct",
    "config": "meta/llama-3.2-1b-instruct@v1.0.0+A100",
    "hyperparameters": {
      "training_type": "distillation",
      "finetuning_type": "all_weights",
      "epochs": 2,
      "batch_size": 16,
      "learning_rate": 0.0001,
      "distillation": {
        "teacher": "<namespace>/<finetuned_model_name>"
      }
    },
    "dataset": {
      "name": "<dataset_name>",
      "namespace": "<namespace>"
    },
    "output_model": "default/my-distilled-3.2-1b@v1"
  }'
import os
import requests

# Assumes the Customizer base URL is exported as CUSTOMIZER_SERVICE_URL.
CUSTOMIZER_SERVICE_URL = os.environ["CUSTOMIZER_SERVICE_URL"]

response = requests.post(
    f"{CUSTOMIZER_SERVICE_URL}/v1/customization/jobs",
    json={
        "target": "meta/llama-3.2-1b-instruct",  # student
        "config": "meta/llama-3.2-1b-instruct@v1.0.0+A100",  # must support distillation
        "hyperparameters": {
            "training_type": "distillation",
            "finetuning_type": "all_weights",
            "epochs": 2,
            "batch_size": 16,
            "learning_rate": 0.0001,
            "distillation": {
                "teacher": "<namespace>/<finetuned_model_name>"  # teacher
            },
        },
        "dataset": {"name": "<dataset_name>", "namespace": "<namespace>"},
        "output_model": "default/my-distilled-3.2-1b@v1",
    }
)
print(response.json())

Important

The config field must include a version, for example: meta/llama-3.2-1b-instruct@v1.0.0+A100. Omitting the version will result in an error like:

{ "detail": "Version is not specified in the config URN: meta/llama-3.2-1b-instruct" }

You can find the correct config URN (with version) by inspecting the output of the /v1/customization/configs endpoint. Use the name and version fields to construct the URN as name@version.
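After the job is accepted, you will typically want to track its progress. The sketch below polls the job until it reaches a terminal state. It assumes the creation response returns the job identifier in an id field and that the job can be read back from /v1/customization/jobs/{id}; the exact status endpoint, field names, and status values may differ, so confirm them against the API reference for your deployment.

import os
import time
import requests

CUSTOMIZER_SERVICE_URL = os.environ["CUSTOMIZER_SERVICE_URL"]

# "response" is the job-creation response from the previous step;
# the "id" field name is an assumption about that response shape.
job_id = response.json()["id"]

while True:
    job = requests.get(f"{CUSTOMIZER_SERVICE_URL}/v1/customization/jobs/{job_id}").json()
    status = job.get("status")
    print("status:", status)
    # Terminal status values are assumed; adjust to the actual API values.
    if status in ("completed", "failed", "cancelled"):
        break
    time.sleep(30)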