Train a DSpark Drafter for Speculative Decoding

A guide for training a DSpark speculative-decoding drafter to accelerate LLM inference with NeMo AutoModel.


What is DSpark?

DSpark is a semi-autoregressive parallel drafter. A parallel backbone proposes every position of a block in a single forward pass, a lightweight serial Markov head injects intra-block token dependency (mitigating the acceptance decay of purely parallel drafters), and a confidence head predicts per-position acceptance probability for scheduled verification. The draft shares and freezes the target’s embed_tokens and lm_head, training only the backbone, the feature projection, the Markov head, and the confidence head.

It follows the same scaffolding as the EAGLE and DFlash recipes: online target hidden-state capture, gradient accumulation, and consolidated-safetensors checkpointing.

Objective

The draft is trained with a three-term, position-decay-weighted objective:

| Term | Meaning |
| --- | --- |
| L_ce (ce_loss_alpha, default 0.1) | cross-entropy against the next target token |
| L_tv (l1_loss_alpha, default 0.9) | total-variation distance to the target distribution (a direct acceptance proxy) |
| L_conf (confidence_head_alpha, default 1.0) | BCE training the confidence head against measured acceptance |

Positions are weighted by exp(-(k-1)/loss_decay_gamma).
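As a rough illustration of how the three terms and the position decay might combine, here is a minimal pure-Python sketch. It is not the recipe's implementation. It assumes k is the 1-indexed position inside a block, that the total-variation distance uses the usual 0.5 * L1 form, and that per-position losses are averaged with normalized decay weights. The helper names and the loss_decay_gamma value of 4.0 are hypothetical; only the alpha defaults come from the table above.

```python
import math


def position_weights(block_size, loss_decay_gamma):
    """Decay weights exp(-(k-1)/gamma) for k = 1..block_size (k assumed 1-indexed)."""
    return [math.exp(-(k - 1) / loss_decay_gamma) for k in range(1, block_size + 1)]


def total_variation(p, q):
    """Total-variation distance between two discrete distributions (0.5 * L1)."""
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))


def weighted_mean(per_position, weights):
    """Average per-position values with normalized weights (normalization is an assumption)."""
    return sum(w * x for w, x in zip(weights, per_position)) / sum(weights)


def combined_objective(ce, tv, conf, weights,
                       ce_loss_alpha=0.1, l1_loss_alpha=0.9, confidence_head_alpha=1.0):
    """Weighted sum of the three position-weighted terms (hypothetical helper)."""
    return (ce_loss_alpha * weighted_mean(ce, weights)
            + l1_loss_alpha * weighted_mean(tv, weights)
            + confidence_head_alpha * weighted_mean(conf, weights))


# Illustrative values only: block_size 7 follows the paper setting, gamma is made up.
weights = position_weights(block_size=7, loss_decay_gamma=4.0)
```

Under these assumptions the first position in a block always has weight 1, and later positions contribute progressively less to each term.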

Data

Use a chat dataset whose rows contain OpenAI-format messages. As in the DSpark paper, use Open-PerfectBlend prompts with responses regenerated by the target model: training is teacher-forced, so regenerate before training to avoid a train/inference distribution mismatch. Point recipe_args.train_data_path at the regenerated JSONL file or a Hugging Face dataset ID.
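For reference, a row in this format might look like the one below. This is a sketch for sanity-checking a JSONL file before training, not part of the recipe: the helper names are hypothetical, and the requirement that each row contain at least one assistant turn is an assumption based on the responses being the training signal.

```python
import json


def validate_row(row):
    """Return True if the row looks like an OpenAI-format chat example."""
    messages = row.get("messages")
    if not isinstance(messages, list) or not messages:
        return False
    for message in messages:
        if not isinstance(message, dict):
            return False
        if not isinstance(message.get("role"), str):
            return False
        if not isinstance(message.get("content"), str):
            return False
    # Assumption: a trainable row needs at least one assistant response.
    return any(m["role"] == "assistant" for m in messages)


def to_jsonl_line(row):
    """Serialize one row as a single JSONL line (no trailing newline)."""
    return json.dumps(row, ensure_ascii=False)


example_row = {
    "messages": [
        {"role": "user", "content": "What is speculative decoding?"},
        # In practice this response would be regenerated by the target model.
        {"role": "assistant", "content": "It drafts tokens cheaply and verifies them with the target model."},
    ]
}
```

Running validate_row over every line of the regenerated file is a cheap way to catch malformed rows before launching a multi-GPU job.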

Run it

Example configs live under examples/speculative/dspark/ (qwen3_0.6b_dspark.yaml, gemma4_12b_dspark.yaml). Multi-GPU defaults to FSDP2 (distributed.strategy: fsdp2); set it to ddp for simple data parallelism.

torchrun --standalone --nproc_per_node=2 \
  -m nemo_automodel.recipes.llm.train_dspark \
  -c examples/speculative/dspark/qwen3_0.6b_dspark.yaml

Per-step metrics (loss, ce_loss, l1_loss, confidence_loss, lr, mem) are reduced across data-parallel ranks and written to <output_dir>/dspark_train_metrics.jsonl.
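The metrics file can be inspected with a few lines of Python. The sketch below assumes each line is one JSON object whose keys match the metric names listed above; the exact schema is not documented here, so treat the key names as an assumption and adjust them to what the file actually contains. Both helpers are hypothetical.

```python
import json
from pathlib import Path


def load_metrics(path):
    """Read a JSONL metrics file into a list of dicts, skipping blank lines."""
    rows = []
    with Path(path).open() as handle:
        for line in handle:
            line = line.strip()
            if line:
                rows.append(json.loads(line))
    return rows


def mean_metric(rows, key):
    """Mean of one metric over the rows that report it, or None if absent."""
    values = [row[key] for row in rows if key in row]
    return sum(values) / len(values) if values else None
```

For example, mean_metric(load_metrics("<output_dir>/dspark_train_metrics.jsonl"), "l1_loss") would give the average l1_loss over the logged steps, with <output_dir> replaced by the configured output directory.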

Key config fields

| Field | Meaning |
| --- | --- |
| target_model_name_or_path | frozen target (e.g. Qwen/Qwen3-4B) |
| draft_num_hidden_layers | draft backbone depth (paper: 5) |
| block_size | tokens drafted per block (paper: 7) |
| num_anchors | blocks sampled per sequence per step |
| target_layer_ids | target feature layers fed to the draft (defaults to an even spread) |
| mask_token_id | reserved token id filling non-anchor block positions |
| markov_rank / markov_head_type | serial head size and variant (vanilla / gated / rnn) |
| confidence_head_alpha / confidence_head_with_markov | confidence-head weight and conditioning |
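To make block_size, num_anchors, and mask_token_id concrete, here is a toy sketch of how anchor blocks could be laid out. It is an illustration under stated assumptions, not the recipe's sampling code: it assumes each block starts with the real token at a randomly chosen anchor, fills the remaining positions with mask_token_id, and is supervised by the block_size tokens that follow the anchor. The function name and the dictionary layout are hypothetical.

```python
import random


def build_anchor_blocks(token_ids, block_size, num_anchors, mask_token_id, seed=0):
    """Sample anchors and build masked draft blocks (hypothetical layout)."""
    rng = random.Random(seed)
    # The last usable anchor must leave block_size target tokens after it.
    last_anchor = len(token_ids) - block_size - 1
    if last_anchor < 0:
        raise ValueError("sequence too short for one block")
    candidates = range(last_anchor + 1)
    anchors = sorted(rng.sample(candidates, min(num_anchors, len(candidates))))
    blocks = []
    for anchor in anchors:
        # Position 0 holds the real anchor token; the rest are mask placeholders.
        inputs = [token_ids[anchor]] + [mask_token_id] * (block_size - 1)
        # Assumed supervision: the block_size tokens following the anchor.
        targets = token_ids[anchor + 1 : anchor + 1 + block_size]
        blocks.append({"anchor": anchor, "inputs": inputs, "targets": targets})
    return blocks
```

With block_size 7 and num_anchors 3, a 20-token sequence would yield three blocks of seven positions each, every block containing one real token and six mask tokens.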

Supported targets: Qwen3 (dense and MoE) and Gemma4.