Train an EAGLE Drafter for Speculative Decoding — End-to-End Guide


A step-by-step guide for training an EAGLE speculative decoding drafter to accelerate LLM inference using NeMo AutoModel.


What is EAGLE Speculative Decoding?

Large language models generate text one token at a time, and each token requires a full forward pass through the entire model. Speculative decoding speeds this up by pairing the large target model with a small, fast drafter model. The drafter guesses several tokens ahead; the target model then verifies all of them in a single forward pass, accepting guesses up to the first wrong one and rejecting the rest. Because every kept token is checked by the target, the output matches what the target model would produce on its own (token for token under greedy decoding, and in distribution when sampling), while generation is typically 2-3x faster.
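The verify step can be made concrete with a small sketch. It covers greedy decoding only (sampling-based verification uses a rejection-sampling rule instead), and `verify_greedy` is an illustrative name, not an API of NeMo AutoModel or SGLang. Every emitted token equals the target's own greedy choice given the tokens before it, which is why greedy output matches the target model alone.

```python
from typing import List, Tuple


def verify_greedy(draft_tokens: List[int], target_argmax: List[int]) -> Tuple[List[int], int]:
    """Greedy verification of one speculative step (illustrative sketch).

    draft_tokens:  k token ids proposed by the drafter.
    target_argmax: k + 1 token ids; target_argmax[i] is the target's greedy
                   choice after the prefix plus draft_tokens[:i], all obtained
                   from ONE target forward pass over prefix + drafts.
    Returns (tokens emitted this step, number of drafts accepted).
    """
    if len(target_argmax) != len(draft_tokens) + 1:
        raise ValueError("target must score every draft position plus one extra")
    accepted = 0
    for draft, target in zip(draft_tokens, target_argmax):
        if draft != target:
            break  # first disagreement: reject this draft and everything after it
        accepted += 1
    # Keep the agreed prefix, then take the target's own token at the first
    # mismatch (or its "bonus" token after the last draft if all matched).
    emitted = draft_tokens[:accepted] + [target_argmax[accepted]]
    return emitted, accepted


if __name__ == "__main__":
    # The drafter guessed 4 tokens; the target agrees with the first 2.
    print(verify_greedy([11, 22, 33, 44], [11, 22, 99, 44, 55]))  # ([11, 22, 99], 2)
    # All 4 accepted: the target also contributes a bonus 5th token.
    print(verify_greedy([11, 22, 33, 44], [11, 22, 33, 44, 55]))  # ([11, 22, 33, 44, 55], 4)
```

Each step always emits at least one token (the target's correction), so a poor drafter never yields fewer tokens per target pass than ordinary decoding, while a good one yields up to k + 1.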

EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency) is a family of speculative decoding methods. NeMo AutoModel supports three variants:

| Variant | Recipe | Description |
|---|---|---|
| EAGLE-1 | train_eagle1 | Lightweight 1-layer draft transformer; learns to predict target hidden states + next token |
| EAGLE-2 | train_eagle2 | Same architecture as EAGLE-1 (alias recipe); EAGLE-2's dynamic draft tree is an inference-time change, so training is identical |
| EAGLE-3 | train_eagle3 | Advanced drafter with test-time training (TTT) unroll and vocabulary mapping; highest speedup |

The Task

We train an EAGLE-3 drafter for Llama 3.1 8B Instruct on the PerfectBlend dataset — a chat corpus whose assistant turns were generated by the same Llama 3.1 8B model, ensuring distribution alignment between training data and target.

After training, we serve the target + drafter together via SGLang for accelerated inference.

Guide Overview

| Step | Description |
|---|---|
| Step 0 | Environment setup |
| Step 1 | Understand EAGLE architecture |
| Step 2 | Prepare the training dataset |
| Step 3 | Configure and launch EAGLE-3 training |
| Step 4 | Monitor training and inspect checkpoints |
| Step 5 | Serve with SGLang |
| Step 6 | (Bonus) Train an EAGLE-1 drafter |

Hardware Requirements

| Setup | Target model | GPUs | Training time |
|---|---|---|---|
| MVP (quick test) | Llama 3.2 1B | 1x A100 80 GB | ~10 min (1 epoch, 1k samples) |
| Production | Llama 3.1 8B Instruct | 8x A100 80 GB | ~2 h (1 epoch, 200k samples) |

The target model is loaded in full precision and frozen during training. Only the small drafter model (a few transformer layers) has trainable parameters, so GPU memory is dominated by the target model size.


Step 0 — Environment Setup

This guide runs inside the NeMo AutoModel Docker container:

$ docker run -it --rm --gpus all --ipc=host --network host \
    -v $(pwd):/workspace \
    nvcr.io/nvidia/nemo-automodel:26.06.00
$ huggingface-cli login
$ cd /opt/Automodel

For SGLang serving (Step 5), install it in the same environment:

$ uv pip install "sglang>=0.5.9"

Step 1 — Understand EAGLE Architecture

How EAGLE-3 Works

EAGLE-3 pairs a frozen target LLM with a small trainable drafter. During training, the drafter learns to predict what the target model would produce next, using a technique called test-time training (TTT) unroll:

   Target (frozen)            Drafter (trainable)
  ┌──────────────┐              ┌──────────────┐
  │ Llama 3.1 8B │  ──────────> │   2-layer    │
  │              │    hidden    │ transformer  │
  │  Full model  │    states    │ + fc fusion  │
  │   (frozen)   │              │ + lm_head    │
  └──────────────┘              └──────────────┘
                               predict next token
                                 + hidden states

Key components:

  • Target model: The full LLM (e.g., Llama 3.1 8B), completely frozen. Provides hidden states from selected intermediate layers as auxiliary inputs to the drafter.
  • Draft model: A shallow transformer (typically 2 layers) with:
    • A fusion layer (fc) that combines auxiliary hidden states from 3 target layers
    • Its own attention layers, MLP, and layer norm
    • A smaller output vocabulary (e.g., 8192 or 32000 tokens instead of 128k) to reduce compute
  • TTT unroll: The drafter runs multiple autoregressive steps (default 4) during training, with exponentially decaying loss weights (0.8^i). This teaches the drafter to make multi-step predictions — exactly what speculative decoding needs.
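The decaying weights can be written out directly. This sketch assumes the per-step losses are combined as a plain weighted sum (the recipe may normalize differently), and the function names are hypothetical.

```python
from typing import List


def ttt_loss_weights(ttt_steps: int = 4, decay: float = 0.8) -> List[float]:
    """Per-step weights decay**i for i = 0 .. ttt_steps - 1."""
    return [decay ** i for i in range(ttt_steps)]


def combine_ttt_losses(step_losses: List[float], decay: float = 0.8) -> float:
    """Weighted sum of per-step losses; deeper unroll steps count less.

    Assumption: a plain weighted sum with no normalization by sum(weights).
    """
    weights = ttt_loss_weights(len(step_losses), decay)
    return sum(w * loss for w, loss in zip(weights, step_losses))


if __name__ == "__main__":
    print([round(w, 3) for w in ttt_loss_weights()])           # [1.0, 0.8, 0.64, 0.512]
    print(round(combine_ttt_losses([2.0, 2.5, 3.0, 3.5]), 4))  # 7.712
```

With the default depth of 4, the last unroll step counts about half as much as the first (0.8^3 = 0.512).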

EAGLE-3.1 Drafter Toggles

The same train_eagle3 recipe supports the EAGLE-3.1 drafter variant via two optional flags in recipe_args. Both default to false, so existing EAGLE-3 configs and checkpoints behave identically. Setting them applies the EAGLE-3.1 architectural changes from vllm-project/vllm#42764 to the Llama-style draft. The MLA-backbone community release lightseekorg/kimi-k2.6-eagle3.1-mla is a separate architecture (Eagle3DeepseekV2ForCausalLM) and is not produced by this recipe.

| Flag | Effect |
|---|---|
| fc_norm | Apply an independent RMSNorm to each of the num_aux_hidden_states (3 by default) auxiliary hidden-state chunks before they are concatenated and projected by model.fc. Stored as an nn.ModuleList with on-disk keys model.fc_norm.0.weight, model.fc_norm.1.weight, … matching vLLM’s layout so checkpoints load directly. |
| norm_output | Route the existing final RMSNorm (model.norm) over the per-step hidden state returned by the drafter so the next TTT step (and lm_head) consume the post-norm state instead of the raw decoder output. Adds no parameters. |

Together they remove the “attention drift” pattern (loss of focus on sink tokens at deeper speculation depths) reported by the EAGLE-3.1 paper and let the drafter behave more like a recurrently applied module than a stack of extra layers bolted onto the target.

recipe_args:
  # ... standard EAGLE-3 fields ...
  fc_norm: true
  norm_output: true
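On toy vectors, the effect of fc_norm is easy to see. This pure-Python sketch assumes a standard RMSNorm (eps 1e-6) with one weight vector per chunk and an fc projection without bias; it illustrates the data flow only and is not the recipe's module code.

```python
import math
from typing import List, Sequence

Vector = List[float]


def rms_norm(x: Sequence[float], weight: Sequence[float], eps: float = 1e-6) -> Vector:
    """RMSNorm: x / sqrt(mean(x^2) + eps), scaled elementwise by a learned weight."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def fuse_aux_hidden_states(chunks: List[Vector], norm_weights: List[Vector],
                           fc_weight: List[Vector], fc_norm: bool = True) -> Vector:
    """Optionally normalize each aux chunk with its OWN weight, concat, then project.

    chunks:       num_aux_hidden_states vectors, each of size hidden.
    norm_weights: one weight vector per chunk (mirrors an nn.ModuleList).
    fc_weight:    rows of length num_chunks * hidden, like an nn.Linear weight.
    """
    if fc_norm:
        chunks = [rms_norm(c, w) for c, w in zip(chunks, norm_weights)]
    concat = [v for c in chunks for v in c]
    return [sum(a * b for a, b in zip(row, concat)) for row in fc_weight]


if __name__ == "__main__":
    chunks = [[3.0, 4.0], [300.0, 400.0], [0.03, 0.04]]  # toy aux states at very different scales
    ones = [[1.0, 1.0]] * 3
    fc = [[1.0, 0.0, 1.0, 0.0, 1.0, 0.0]]  # sums the first feature of each chunk
    print(round(fuse_aux_hidden_states(chunks, ones, fc, fc_norm=False)[0], 2))  # 303.03
    print(round(fuse_aux_hidden_states(chunks, ones, fc, fc_norm=True)[0], 2))   # 2.55
```

Without fc_norm the largest-scale chunk dominates the projection; with it, each chunk is rescaled to roughly unit RMS first, so all three contribute comparably.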

How EAGLE-1 Differs

EAGLE-1 is simpler: it uses a single transformer layer, predicts the full vocabulary, and trains with a combined loss of MSE on hidden states (hidden_loss_weight) and cross-entropy on tokens (token_loss_weight). No TTT unroll, no vocabulary mapping.
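For one position, the combined EAGLE-1 objective looks like the sketch below, using the default weights from the Step 6 config (hidden_loss_weight 1.0, token_loss_weight 0.1). How the recipe averages over positions and batches, and which hidden state serves as the regression target, are left out; the function names are hypothetical.

```python
import math
from typing import Sequence


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error between predicted and target hidden states."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def cross_entropy(logits: Sequence[float], label: int) -> float:
    """-log softmax(logits)[label], computed stably via log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[label]


def eagle1_loss(pred_hidden: Sequence[float], target_hidden: Sequence[float],
                logits: Sequence[float], label: int,
                hidden_loss_weight: float = 1.0, token_loss_weight: float = 0.1) -> float:
    """Weighted sum of hidden-state MSE and next-token cross-entropy."""
    return (hidden_loss_weight * mse(pred_hidden, target_hidden)
            + token_loss_weight * cross_entropy(logits, label))


if __name__ == "__main__":
    print(round(eagle1_loss([1.0, 2.0], [1.0, 4.0], [0.0, 0.0], 0), 4))  # 2.0693
```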


Step 2 — Prepare the Training Dataset

Data format

EAGLE training expects chat data in the OpenAI messages format — either JSONL files or HuggingFace datasets with a messages column:

{"messages": [
  {"role": "system", "content": "You are a helpful assistant."},
  {"role": "user", "content": "What is the capital of France?"},
  {"role": "assistant", "content": "The capital of France is Paris."}
]}
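Before launching a long run it can be worth checking that every record has this shape. In a JSONL file each record must sit on a single line; the example above is spread over several lines only for readability. The sketch below is a standalone checker, not part of ChatDataset: the accepted role set and the "at least one assistant turn" rule are assumptions you may need to relax (for example, for tool-calling data).

```python
import json
from typing import Any, Dict, List

VALID_ROLES = {"system", "user", "assistant"}  # assumption; extend for other roles


def chat_record_errors(record: Dict[str, Any]) -> List[str]:
    """Return the problems found in one record (an empty list means it looks valid)."""
    messages = record.get("messages")
    if not isinstance(messages, list) or not messages:
        return ["missing or empty 'messages' list (is the column still 'conversations'?)"]
    errors = []
    for i, msg in enumerate(messages):
        if not isinstance(msg, dict):
            errors.append(f"message {i}: not an object")
            continue
        if msg.get("role") not in VALID_ROLES:
            errors.append(f"message {i}: unexpected role {msg.get('role')!r}")
        if not isinstance(msg.get("content"), str):
            errors.append(f"message {i}: 'content' must be a string")
    if not any(isinstance(m, dict) and m.get("role") == "assistant" for m in messages):
        errors.append("no assistant turn to learn from")
    return errors


def check_jsonl(path: str) -> int:
    """Validate a JSONL file (one record per line); return the record count."""
    count = 0
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            errors = chat_record_errors(json.loads(line))
            if errors:
                raise ValueError(f"{path}:{line_no}: {'; '.join(errors)}")
            count += 1
    return count
```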

For best results, the assistant turns in your training data should come from the same model you’ll use as the target at inference time.

Option A: Use a dataset already regenerated by the target

The PerfectBlend dataset already has answers regenerated by Llama 3.1 8B Instruct:

$ python -c "
from datasets import load_dataset
ds = load_dataset(
    'frankleeeee/PerfectBlend-Regenerated-Llama-3.1-8B-Instruct',
    split='train[:5]'
)
print(f'Columns: {ds.column_names}')
print('Sample conversation:')
for msg in ds[0]['conversations'][:3]:
    role = msg['role']
    text = msg['content'][:80]
    print(f' [{role}] {text}...')
"

Expected output:

Columns: ['conversations']
Sample conversation:
[system] You are a helpful assistant....
[user] What are the main differences between Python 2 and Python 3?...
[assistant] Here are the key differences between Python 2 and Python 3:
1. **P...

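PerfectBlend stores its chats in a conversations column, but ChatDataset expects messages. Download the dataset's parquet shards locally, then rename the column before training (replace <download_dir> with the directory that contains the train-*.parquet files):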

$ python -c "
import pandas as pd
from pathlib import Path
src = Path('<download_dir>')
dst = Path('./data/perfectblend_renamed')
dst.mkdir(parents=True, exist_ok=True)
for f in sorted(src.glob('train-*.parquet')):
    df = pd.read_parquet(f)
    df = df.rename(columns={'conversations': 'messages'})
    df.to_parquet(dst / f.name, index=False)
print('Done. Point train_data_path at:', dst)
"

Option B: Regenerate answers from your target model

If you have a chat dataset whose answers were generated by a different model, you can regenerate them using your target. This aligns the training data distribution with the model the drafter will actually assist at inference time.

Step B.1 — Start the target server (in one shell):

$ python -m sglang.launch_server \
    --model-path meta-llama/Llama-3.1-8B-Instruct \
    --port 30000

Wait for Uvicorn running on http://0.0.0.0:30000 before proceeding.

Step B.2 — Regenerate (in another shell):

$ python -m nemo_automodel.components.speculative.regenerate \
    --input-data Aeala/ShareGPT_Vicuna_unfiltered \
    --output-dir ./regenerated/sharegpt_llama31_8b \
    --target-server http://localhost:30000/v1 \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --concurrency 64 \
    --shard-size 1000

For each sample, the script:

  1. Loads the conversation from the input dataset
  2. Drops the trailing assistant turn, keeping the user prompt context
  3. Calls the target server to generate a new assistant response
  4. Saves the rebuilt conversation to parquet shards
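The four steps amount to a small transformation per conversation. The sketch below is not the regenerate module's implementation: `generate` is a hypothetical stand-in for the request to the target server, and only a single trailing assistant turn is removed, as described above.

```python
from typing import Callable, Dict, List

Message = Dict[str, str]


def rebuild_conversation(messages: List[Message],
                         generate: Callable[[List[Message]], str]) -> List[Message]:
    """Regenerate the final assistant turn of one conversation.

    `generate` is a hypothetical stand-in for the call to the target server;
    it receives the prompt messages and returns the new assistant text.
    """
    prompt = list(messages)  # copy so the input row is left untouched
    # Drop the trailing assistant turn, keeping the user prompt context.
    if prompt and prompt[-1].get("role") == "assistant":
        prompt.pop()
    if not prompt:
        raise ValueError("no prompt left to answer")
    # Ask the target for a fresh answer and rebuild the conversation.
    answer = generate(prompt)
    return prompt + [{"role": "assistant", "content": answer}]
```

Passing `generate` in as a function keeps the logic testable offline. In this sketch, earlier assistant turns of a multi-turn conversation are left as they were.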

The output directory contains parquet files with a messages column — ready for EAGLE training. The script is resumable: rerun with --resume to skip completed shards.

| Flag | Default | Notes |
|---|---|---|
| --concurrency | 32 | In-flight requests; raise to saturate the server |
| --shard-size | 1000 | Rows per parquet file |
| --temperature | 0.0 | Greedy by default (recommended for EAGLE) |
| --max-new-tokens | 1024 | Cap per-answer length |
| --split | train | Supports HF slice syntax, e.g., train[:10000] |

Step 3 — Configure and Launch EAGLE-3 Training

YAML config

Save the following as eagle3_llama8b.yaml:

recipe: TrainEagle3Recipe

dist_env:
  backend: nccl
  timeout_minutes: 60

recipe_args:
  target_model_name_or_path: meta-llama/Llama-3.1-8B-Instruct

  # Point to your training data (HF dataset id, local parquet dir, or JSONL)
  train_data_path: ./data/perfectblend_renamed
  val_data_path: null

  # Slice to 200k samples for a ~2h training run
  train_split: "train[:200000]"
  val_split: null

  output_dir: ./outputs/eagle3_llama8b
  seq_length: 2048
  micro_batch_size: 1
  grad_accumulation_steps: 4  # effective batch = 8 GPUs * 1 * 4 = 32
  num_workers: 4
  num_epochs: 1

  # EAGLE-3 specific
  ttt_steps: 4  # TTT unroll depth (higher = better but slower)
  draft_vocab_size: 32000  # smaller vocab = faster drafter

  # The drafter copies the target's embedding table at init; this flag
  # freezes those copied weights so only the remaining draft parameters
  # (decoder layers, fc fusion, lm_head) are trained.
  freeze_embeddings: true
  shuffle_seed: 42
  log_every_steps: 20
  max_grad_norm: 1.0

optimizer:
  lr: 2.0e-4
  betas: [0.9, 0.95]
  weight_decay: 0.0
  warmup_ratio: 0.05  # 5% warmup
  min_lr_ratio: 0.1

checkpoint:
  enabled: true
  checkpoint_dir: ./outputs/eagle3_llama8b/checkpoints
  # The recipe defaults to safetensors + consolidated; these lines are
  # shown explicitly for clarity but can be omitted.
  model_save_format: safetensors
  save_consolidated: true

Config field reference

| Field | What it does |
|---|---|
| target_model_name_or_path | HuggingFace model ID for the frozen target LLM |
| train_data_path | Path to chat data (HF dataset id, parquet dir, or JSONL) |
| train_split | Optional HF slice syntax to limit data size |
| seq_length | Maximum sequence length per training sample (1024 for quick tests, 2048 for production) |
| micro_batch_size | Per-GPU batch size |
| grad_accumulation_steps | Gradient accumulation for larger effective batches |
| ttt_steps | TTT unroll depth; defaults to 4, and drafter training cost grows roughly linearly with depth |
| draft_vocab_size | Draft output vocabulary size; smaller = faster inference |
| freeze_embeddings | Freeze the embedding table copied from the target so only draft layers train (recommended true) |
| target_attn_implementation | Optional attention backend for the frozen target (e.g. sdpa); defaults to HF auto-select. Set sdpa if the target’s FlashAttention path is broken on your build (e.g. the Qwen3 FA2 s_aux crash) |
| fc_norm | EAGLE-3.1: per-chunk independent RMSNorm (ModuleList) on auxiliary hidden states before the fc projection (default false) |
| norm_output | EAGLE-3.1: feed the post-model.norm hidden state into the next TTT step and lm_head (default false) |
| warmup_ratio | Fraction of total steps for LR warmup |
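The step count implied by these fields is simple arithmetic. This helper is a back-of-envelope sketch, not the recipe's scheduler: it assumes every GPU is a data-parallel rank, that a trailing partial batch is dropped, and that warmup lasts warmup_ratio times the total optimizer steps, rounded up.

```python
import math


def schedule_summary(num_samples: int, num_gpus: int, micro_batch_size: int,
                     grad_accumulation_steps: int, num_epochs: int = 1,
                     warmup_ratio: float = 0.05) -> dict:
    """Estimate optimizer-step counts for a data-parallel run (assumptions above)."""
    effective_batch = num_gpus * micro_batch_size * grad_accumulation_steps
    steps_per_epoch = num_samples // effective_batch  # drops a partial last batch
    total_steps = steps_per_epoch * num_epochs
    return {
        "effective_batch": effective_batch,
        "total_steps": total_steps,
        "warmup_steps": math.ceil(warmup_ratio * total_steps),
    }


if __name__ == "__main__":
    print(schedule_summary(200_000, 8, 1, 4))
    # {'effective_batch': 32, 'total_steps': 6250, 'warmup_steps': 313}
```

For the production config this gives 6250 optimizer steps, consistent with the epoch_0_step_6250 checkpoint shown in Step 4.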

Launch training

Multi-GPU (8x A100, production):

$ torchrun --nproc-per-node=8 \
    -m nemo_automodel.recipes.llm.train_eagle3 \
    eagle3_llama8b.yaml

Single-GPU (quick test with Llama 3.2 1B):

Use the MVP config, which pairs Llama 3.2 1B with a small dataset:

$ python -m nemo_automodel.recipes.llm.train_eagle3 \
    examples/speculative/eagle3/llama_eagle3_mvp.yaml

For GPUs with FlashAttention support, add draft_attn_implementation: flash_attention_2 to recipe_args for faster training. See llama_eagle3_mvp_flash_attn.yaml for a complete example.


Step 4 — Monitor Training and Inspect Checkpoints

What to watch

Training loss should drop steadily while grad_norm shrinks, and lr should ramp up during warmup before decaying. Here is a sample log from the PerfectBlend 200k run on 8x A100:

[epoch 0] step 20 | loss 3.4521 | grad_norm 12.84 | lr 8.00e-05 | tokens/s 2841
[epoch 0] step 40 | loss 2.8973 | grad_norm 8.31 | lr 1.60e-04 | tokens/s 3102
[epoch 0] step 100 | loss 2.1245 | grad_norm 5.62 | lr 2.00e-04 | tokens/s 3254
[epoch 0] step 500 | loss 1.5832 | grad_norm 3.17 | lr 1.98e-04 | tokens/s 3198
[epoch 0] step 1000 | loss 1.3401 | grad_norm 2.45 | lr 1.92e-04 | tokens/s 3221
[epoch 0] step 3000 | loss 1.1056 | grad_norm 1.89 | lr 1.58e-04 | tokens/s 3245
[epoch 0] step 6000 | loss 0.9823 | grad_norm 1.52 | lr 4.20e-05 | tokens/s 3230

Checkpoint layout

Each checkpoint is saved under <checkpoint_dir>/epoch_<E>_step_<S>/:

outputs/eagle3_llama8b/checkpoints/
  epoch_0_step_1000/
    config.json        # draft model config
    draft_model.pt     # draft model weights
    eagle3_meta.pt     # token mapping (selected_token_ids + mask)
    optimizer.pt       # Adam state (for resume)
    scheduler.pt       # LR scheduler state
    rng/               # distributed RNG state
  epoch_0_step_2000/
    ...
  LATEST -> epoch_0_step_6250
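eagle3_meta.pt holds the mapping between the drafter's reduced vocabulary and the target's. A common way to build such a mapping is to keep the draft_vocab_size most frequent target tokens in the training data; the sketch below assumes that rule, and its function names and padding behavior are hypothetical, so it does not read or reproduce the recipe's actual selected_token_ids + mask.

```python
from collections import Counter
from typing import Iterable, List, Tuple


def build_draft_vocab(token_ids: Iterable[int], target_vocab_size: int,
                      draft_vocab_size: int) -> Tuple[List[int], List[bool]]:
    """Choose the draft vocabulary as the most frequent target tokens (assumed rule).

    Returns (selected_token_ids, mask): selected_token_ids[d] is the target id
    of draft id d, and mask[t] is True when target id t is in the draft vocab.
    """
    if draft_vocab_size > target_vocab_size:
        raise ValueError("draft vocab cannot be larger than the target vocab")
    chosen = [tok for tok, _ in Counter(token_ids).most_common(draft_vocab_size)]
    # If the data has fewer distinct tokens than requested, pad with unseen ids.
    seen = set(chosen)
    for tok in range(target_vocab_size):
        if len(chosen) >= draft_vocab_size:
            break
        if tok not in seen:
            chosen.append(tok)
    selected = sorted(chosen)  # keep draft ids in target-id order
    mask = [False] * target_vocab_size
    for tok in selected:
        mask[tok] = True
    return selected, mask


def draft_to_target(draft_ids: List[int], selected_token_ids: List[int]) -> List[int]:
    """Translate drafter outputs (draft-vocab ids) back into target-vocab ids."""
    return [selected_token_ids[d] for d in draft_ids]


if __name__ == "__main__":
    selected, mask = build_draft_vocab([5, 5, 5, 2, 2, 9, 7, 7, 7, 7], 10, 3)
    print(selected)                             # [2, 5, 7]
    print(draft_to_target([0, 2], selected))    # [2, 7]
```

Whatever the selection rule, a reduced vocabulary means the drafter cannot propose a token outside it, so positions where the target's next token is missing are always rejected. A larger draft_vocab_size raises coverage at the cost of a bigger lm_head.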

Resume from checkpoint

If training is interrupted, resume from the latest checkpoint:

checkpoint:
  restore_from: LATEST

Or point to a specific checkpoint:

checkpoint:
  restore_from: epoch_0_step_3000
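If you script around checkpoints (for example to pick the --draft path in Step 5), compare step numbers numerically: a plain string sort ranks epoch_0_step_999 above epoch_0_step_1000. The helpers below are a hypothetical sketch that relies only on the epoch_<E>_step_<S> naming and the LATEST link shown above; they are not how the recipe itself resolves restore_from.

```python
import re
from pathlib import Path
from typing import Iterable, Optional

# Matches directory names like epoch_0_step_6250 (naming from the layout above).
CKPT_RE = re.compile(r"^epoch_(\d+)_step_(\d+)$")


def pick_latest(names: Iterable[str]) -> Optional[str]:
    """Return the name with the highest (epoch, step), compared as integers."""
    best_key, best_name = None, None
    for name in names:
        match = CKPT_RE.match(name)
        if not match:
            continue  # skip LATEST, stray files, etc.
        key = (int(match.group(1)), int(match.group(2)))
        if best_key is None or key > best_key:
            best_key, best_name = key, name
    return best_name


def resolve_checkpoint(checkpoint_dir: str, restore_from: str = "LATEST") -> Optional[Path]:
    """Hypothetical helper: map a restore_from value to a checkpoint directory."""
    root = Path(checkpoint_dir)
    if restore_from != "LATEST":
        path = root / restore_from
        return path if path.is_dir() else None
    latest = root / "LATEST"
    if latest.exists():
        return latest.resolve()  # follow the LATEST -> epoch_..._step_... link
    name = pick_latest(p.name for p in root.iterdir() if p.is_dir())
    return root / name if name else None


if __name__ == "__main__":
    print(pick_latest(["epoch_0_step_999", "epoch_0_step_1000", "LATEST"]))  # epoch_0_step_1000
```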

Step 5 — Serve with SGLang

The serve_sglang helper converts the training checkpoint into an HF/SGLang-compatible format and launches the server in one command.

Launch the server

$ python -m nemo_automodel.components.speculative.serve_sglang \
    --target meta-llama/Llama-3.1-8B-Instruct \
    --draft ./outputs/eagle3_llama8b/checkpoints/LATEST \
    --algorithm EAGLE3 \
    --num-steps 3 \
    --topk 1 \
    --num-draft-tokens 4 \
    --port 30000

On first launch, the helper:

  1. Loads draft_model.pt and eagle3_meta.pt from the checkpoint
  2. Rewrites the architecture name for SGLang compatibility (LlamaEagle3DraftModel -> LlamaForCausalLMEagle3)
  3. Exports model.safetensors and speculative_token_map.pt into a model/ subdirectory
  4. Launches SGLang with the correct speculative decoding flags

Serving parameters

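| Flag | Default | Notes |
|---|---|---|
| --algorithm | | EAGLE3 for EAGLE-3 drafters, EAGLE for EAGLE-1/2 |
| --num-steps | 3 | Drafting steps (draft depth) per speculative iteration |
| --topk | 1 | Branching factor of the draft tree at each step (1 = a single chain) |
| --num-draft-tokens | 4 | Total draft tokens sent to the target for verification per iteration; with --topk 1 this is normally --num-steps + 1 |
| --dtype | auto | Must match training dtype (e.g., bfloat16) |
| --tp-size | 1 | Tensor parallelism (shards the target model only) |
| --print-only | | Print the resolved command without launching |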

Pass extra SGLang flags after --:

$ python -m nemo_automodel.components.speculative.serve_sglang \
    --target meta-llama/Llama-3.1-8B-Instruct \
    --draft ./outputs/eagle3_llama8b/checkpoints/LATEST \
    --algorithm EAGLE3 \
    -- --enable-torch-compile --schedule-conservativeness 1.2

Smoke-test the server

Once you see Uvicorn running on http://0.0.0.0:30000, test it:

$ curl http://localhost:30000/generate \
    -H "Content-Type: application/json" \
    -d '{
      "text": "Hello, my name is",
      "sampling_params": {"max_new_tokens": 64}
    }'

Expected output:

{
  "text": "Hello, my name is Sarah and I am a 25-year-old software engineer...",
  "meta_info": {
    "prompt_tokens": 6,
    "completion_tokens": 64,
    "accept_length_per_step": 3.2
  }
}

The accept_length_per_step metric shows how many tokens the target model accepts per speculative step on average. Higher is better — a value of 3.0+ indicates the drafter is accurately predicting the target’s behavior.
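accept_length_per_step translates into speedup only after accounting for the drafter's own cost. A common back-of-envelope model: each iteration costs one target pass plus num_steps drafter passes and yields accept_length tokens. The sketch assumes verification costs about one ordinary decoding step and ignores other overhead; the 0.05 cost ratio in the example is illustrative, not measured.

```python
def estimated_speedup(accept_length: float, num_steps: int, draft_cost_ratio: float) -> float:
    """Rough decoding speedup of speculative decoding over the target alone.

    accept_length:    average tokens emitted per target verification pass.
    num_steps:        drafter forward passes per iteration (--num-steps).
    draft_cost_ratio: cost of one drafter pass relative to one target pass.
    """
    cost_per_iteration = 1.0 + num_steps * draft_cost_ratio  # in target-pass units
    return accept_length / cost_per_iteration


if __name__ == "__main__":
    print(round(estimated_speedup(3.2, 3, 0.05), 2))  # 2.78
```

Under this model, raising accept length helps only as long as the extra drafting it requires stays cheap relative to a target pass.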

OpenAI-compatible endpoint

SGLang also exposes an OpenAI-compatible API:

$ curl http://localhost:30000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
      "model": "meta-llama/Llama-3.1-8B-Instruct",
      "messages": [
        {"role": "user", "content": "Explain speculative decoding in one paragraph."}
      ],
      "max_tokens": 256
    }'

Expected output:

{
  "choices": [{
    "message": {
      "role": "assistant",
      "content": "Speculative decoding is a technique for accelerating autoregressive language model inference. It works by using a small, fast \"draft\" model to predict multiple future tokens, which are then verified in parallel by the larger \"target\" model in a single forward pass. Tokens that match the target model's predictions are accepted, while incorrect tokens are rejected and regenerated. Because verification is cheaper than sequential generation (it processes all candidate tokens simultaneously), the overall throughput increases significantly, typically 2-3x, while producing output that is mathematically identical to running the target model alone."
    }
  }],
  "usage": {
    "prompt_tokens": 14,
    "completion_tokens": 112
  }
}
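The same chat request can be sent from Python with only the standard library, so no OpenAI client package is assumed. The helper names are ours; the URL, model name, and body mirror the curl example above.

```python
import json
import urllib.request


def build_chat_request(prompt: str, base_url: str = "http://localhost:30000",
                       model: str = "meta-llama/Llama-3.1-8B-Instruct",
                       max_tokens: int = 256) -> urllib.request.Request:
    """Build the same POST /v1/chat/completions request as the curl example."""
    body = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }
    return urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def chat(prompt: str, **kwargs) -> str:
    """Send the request and return the assistant's reply text."""
    with urllib.request.urlopen(build_chat_request(prompt, **kwargs), timeout=120) as resp:
        payload = json.loads(resp.read().decode("utf-8"))
    return payload["choices"][0]["message"]["content"]
```

With the server from this step running, chat("Explain speculative decoding in one paragraph.") returns the text found under choices[0].message.content in the response shown above.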

Step 6 — (Bonus) Train an EAGLE-1 Drafter

EAGLE-1 is simpler and faster to train, making it a good starting point for experimentation. It uses a single transformer layer and trains with a combined hidden-state MSE + token cross-entropy loss.

YAML config

Save as eagle1_llama8b.yaml:

recipe: TrainEagle1Recipe

dist_env:
  backend: nccl
  timeout_minutes: 30

recipe_args:
  target_model_name_or_path: meta-llama/Llama-3.1-8B-Instruct
  train_data_path: ./data/perfectblend_renamed
  val_data_path: null
  train_split: "train[:200000]"
  val_split: null
  output_dir: ./outputs/eagle1_llama8b
  seq_length: 2048
  micro_batch_size: 1
  grad_accumulation_steps: 4
  num_workers: 4
  num_epochs: 1

  # EAGLE-1 specific
  draft_num_hidden_layers: 1  # number of draft transformer layers
  hidden_loss_weight: 1.0     # MSE loss on hidden states
  token_loss_weight: 0.1      # cross-entropy loss on tokens

  freeze_embeddings: true
  shuffle_seed: 42
  log_every_steps: 10
  max_grad_norm: 1.0

optimizer:
  lr: 1.0e-4
  betas: [0.9, 0.95]
  weight_decay: 0.0

checkpoint:
  enabled: true
  checkpoint_dir: ./outputs/eagle1_llama8b/checkpoints
  # Defaults to safetensors + consolidated; can be omitted.
  model_save_format: safetensors
  save_consolidated: true

Launch

$ torchrun --nproc-per-node=8 \
    -m nemo_automodel.recipes.llm.train_eagle1 \
    eagle1_llama8b.yaml

Serve

Use --algorithm EAGLE (not EAGLE3) for EAGLE-1/2 drafters:

$ python -m nemo_automodel.components.speculative.serve_sglang \
    --target meta-llama/Llama-3.1-8B-Instruct \
    --draft ./outputs/eagle1_llama8b/checkpoints/LATEST \
    --algorithm EAGLE \
    --num-steps 3 --topk 1 --num-draft-tokens 4 \
    --port 30000

EAGLE-1 vs EAGLE-3

| | EAGLE-1 | EAGLE-3 |
|---|---|---|
| Draft layers | 1 (configurable) | 2 (with aux fusion) |
| Training objective | Hidden MSE + token CE | TTT unroll with decay |
| Vocabulary | Full target vocab | Reduced (e.g., 8k-32k) |
| Training speed | Faster | Slower (due to TTT unroll) |
| Inference speedup | Good (2-2.5x) | Better (2.5-3x) |
| Best for | Quick experiments | Production deployment |

Example Configs Reference

| Config | Target | Variant | Notes |
|---|---|---|---|
| llama_eagle3_mvp.yaml | Llama 3.2 1B | EAGLE-3 | Quick test, single GPU |
| llama_eagle3_mvp_flash_attn.yaml | Llama 3.2 1B | EAGLE-3 | With FlashAttention-2 |
| llama_eagle3_perfectblend.yaml | Llama 3.1 8B | EAGLE-3 | Production config, 200k samples |
| llama_eagle3_1_perfectblend.yaml | Llama 3.1 8B | EAGLE-3.1 | Production config with fc_norm + norm_output enabled |
| llama_eagle1_mvp.yaml | Llama 3.2 1B | EAGLE-1 | Quick test, single GPU |
| llama_eagle2_mvp.yaml | Llama 3.2 1B | EAGLE-2 | Same as EAGLE-1 (alias) |

Troubleshooting

| Symptom | Fix |
|---|---|
| OutOfMemoryError during training | Reduce seq_length (1024 instead of 2048) or micro_batch_size |
| Loss stays flat or NaN | Check max_grad_norm (default 1.0), reduce lr |
| SGLang model not found error | Ensure --algorithm matches the recipe (EAGLE3 for train_eagle3, EAGLE for train_eagle1/2) |
| dtype mismatch at serving | Pass --dtype bfloat16 to match training precision |
| conversations vs messages column error | Rename the column in your dataset (see Step 2 warning) |
| Checkpoint resume fails | Use restore_from: LATEST or the exact subdirectory name like epoch_0_step_1000 |
| Low accept_length_per_step at serving | Train longer, use more data, or try regenerating answers with the target model (Option B in Step 2) |