Cross-Tokenizer (X-Token) Off-Policy Distillation#
NeMo RL supports off-policy distillation between a student and a teacher that do not share a tokenizer, for example distilling a Qwen3-4B teacher into a Llama-3.2-1B student. Cross-tokenizer (“x-token”) distillation handles the vocabulary mismatch by routing student logits through a precomputed projection matrix. The matrix maps each student token to the teacher tokens it most plausibly corresponds to, which projects the student into the teacher’s vocab space so the two distributions can be compared.
This guide explains how to:
Produce the projection matrix from a (student, teacher) tokenizer pair
Launch a distillation run that consumes it
How it works#
A full run has two phases. The three prep steps are offline data prep —
small CLI tools you run once per (student, teacher) pair — and the result is a
single .pt file. The final step is the actual distillation training loop.
                        ┌──────────────────────────────────────────────┐
                        │    Offline projection-matrix preparation     │
                        │                                              │
                        │  ┌────────────────────────────────────┐      │
    (student, teacher)  │  │ 1. minimal_projection_via_         │      │
    tokenizers     ────▶│  │    multitoken.py                   │      │
                        │  │    - multi-token mappings          │      │
                        │  └─────────────────┬──────────────────┘      │
                        │                    │                         │
                        │  ┌─────────────────▼──────────────────┐      │
                        │  │ 2. (optional) reapply_exact_map.py │      │
                        │  │    - pin exact 1-to-1 matches      │      │
                        │  └─────────────────┬──────────────────┘      │
                        │                    │                         │
                        │  ┌─────────────────▼──────────────────┐      │
                        │  │ 3. sort_and_cut_projection_matrix  │      │
                        │  │    .py - trim to runtime top_k     │      │
                        │  └─────────────────┬──────────────────┘      │
                        └────────────────────│─────────────────────────┘
                                             │
                                             ▼ projection_matrix.pt
                     ┌────────────────────────────────────────────────────┐
                     │  4. examples/                                      │
                     │     run_xtoken_off_policy_distillation.py          │
                     │     - align student & teacher tokens, then         │
                     │       teacher forward + student forward,           │
                     │       then x-token KD loss                         │
                     └────────────────────────────────────────────────────┘
The projection matrix is a sparse [V_student, top_k] tensor that the
training-time loss multiplies against the student logits to project them into
the teacher’s vocab space.
Each row of the matrix holds the weights W_{s,t} that distribute a student
token s ∈ V_S over the teacher tokens t ∈ V_T it corresponds to. Tokens
shared by both vocabularies map 1-to-1 (e.g., _the, _cat, _run), while a
student token that the teacher splits into pieces spreads its weight across
those pieces (e.g., 201 → 2, 0, 1). Rows are trimmed to the runtime
top_k in Step 3, so low-weight tail entries are dropped.
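To make the row layout concrete, here is a minimal pure-Python sketch of the projection. The toy vocabularies, the made-up weights, the `indices`/`weights` row layout, and the function name `project_student_probs` are all illustrative assumptions, not the NeMo RL implementation. For readability the sketch projects probabilities rather than raw logits.

```python
# Toy sketch: project a student distribution into teacher vocab space.
# ASSUMED layout (hypothetical): each student row stores up to top_k
# (teacher_id, weight) pairs, i.e. a sparse [V_student, top_k] matrix.

def project_student_probs(student_probs, indices, weights, teacher_vocab_size):
    """Return teacher-space mass: out[t] = sum over s of p_s * W[s, t]."""
    projected = [0.0] * teacher_vocab_size
    for s, p_s in enumerate(student_probs):
        for t, w in zip(indices[s], weights[s]):
            projected[t] += p_s * w
    return projected

# Student vocab: 0="_the", 1="_cat", 2="201"
# Teacher vocab: 0="_the", 1="_cat", 2="2", 3="0", 4="1"
indices = [[0], [1], [2, 3, 4]]               # shared tokens map 1-to-1
weights = [[1.0], [1.0], [0.5, 0.25, 0.25]]   # made-up weights for "201"
student_probs = [0.5, 0.25, 0.25]

projected = project_student_probs(student_probs, indices, weights, 5)
print(projected)  # [0.5, 0.25, 0.125, 0.0625, 0.0625]
```

Because every row's weights sum to 1 in this toy example, the projected vector is still a probability distribution over the teacher vocabulary.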
Which prep steps are essential?#
Of the three prep steps, Step 1 (multi-token mappings) and
Step 3 (sort and trim) are required — Step 1 builds the cross-vocab
mapping itself, and Step 3 produces the runtime-format .pt the training
loss expects. Step 2 (reapply exact map) is optional and pins exact
1-to-1 token mappings on top of Step 1, but we found the best results
on this branch by running Steps 1 → 2 → 3.
Quickstart — single command#
For the typical case, tools/x_token/build_projection_matrix.sh chains
the prep steps with auto-derived intermediate paths:
./tools/x_token/build_projection_matrix.sh \
--student-model meta-llama/Llama-3.2-1B \
--teacher-model Qwen/Qwen3-4B \
--runtime-top-k 4
The wrapper writes the final matrix to
cross_tokenizer_data/projection_matrix_<student>_<teacher>_top<N>.pt
(override with --final-output). Pass --skip-exact-map to skip the
optional Step 2, or --no-{scale-trick,reverse-pass,special-token-mapping}
to tweak Step 1 defaults. Run ./tools/x_token/build_projection_matrix.sh --help for the full flag list.
The per-step recipes below are for advanced customization (non-default weight thresholds, hand-picked intermediate filenames, etc.).
Backend and scope#
DTensor V2 only. Set policy.dtensor_cfg.enabled=true and policy.dtensor_cfg._v2=true. The Megatron policy worker is not wired for cross-tokenizer distillation.
Teacher logits travel via CUDA IPC, so student and teacher policies must be colocated on the same node. There is no remote-Ray transport for x-token logits.
Future work will ease these requirements. A transport such as TransferQueue, for instance, would carry teacher logits across nodes — removing the colocation requirement and the dependence on CUDA IPC.
Step 1 — Build multi-token mappings#
Many student tokens (e.g., "12") tokenize into multiple teacher tokens
(e.g., "1", "2"). minimal_projection_via_multitoken.py walks the
student vocab, re-tokenizes each token with the teacher tokenizer, and adds
weighted entries to the projection. With --enable-reverse-pass it also
does the symmetric teacher → student walk.
uv run python -m tools.x_token.minimal_projection_via_multitoken \
--student-model "meta-llama/Llama-3.2-1B" \
--teacher-model "Qwen/Qwen3-4B" \
--top-k 32 \
--enable-scale-trick \
--enable-reverse-pass \
--enable-special-token-mapping
Output: cross_tokenizer_data/projection_map_Llama-3.2_to_Qwen3_multitoken_top_32_double_special.pt.
Pass --num-examples 50 to print a sample of student→teacher mappings after
the matrix is built — useful for spot-checking that special tokens, numerals,
and punctuation map to sensible teacher tokens.
When --enable-scale-trick is set, the script records enable_scale_trick=True
in the saved .pt so Step 3 can auto-enable --preserve_last.
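The forward walk described in this step can be sketched as follows. This is a toy illustration under stated assumptions: `greedy_tokenize` is a hypothetical stand-in for a real teacher tokenizer, `build_multitoken_rows` is an invented name, and the uniform weights are a placeholder. The real script's weighting, scale trick, and reverse pass are not reproduced here.

```python
# Toy sketch of the student -> teacher walk (not the real script).

def greedy_tokenize(text, vocab):
    """Hypothetical teacher tokenizer: greedy longest-match over `vocab`."""
    ids, i = [], 0
    while i < len(text):
        for end in range(len(text), i, -1):
            piece = text[i:end]
            if piece in vocab:
                ids.append(vocab[piece])
                i = end
                break
        else:
            raise ValueError(f"cannot tokenize {text[i:]!r}")
    return ids

def build_multitoken_rows(student_vocab, teacher_vocab, top_k):
    """Re-tokenize each student token with the teacher tokenizer."""
    rows = {}
    for token, s_id in student_vocab.items():
        t_ids = greedy_tokenize(token, teacher_vocab)[:top_k]
        weight = 1.0 / len(t_ids)  # ASSUMPTION: uniform split across pieces
        rows[s_id] = [(t_id, weight) for t_id in t_ids]
    return rows

student_vocab = {"_the": 0, "12": 1}
teacher_vocab = {"_the": 0, "1": 1, "2": 2}
rows = build_multitoken_rows(student_vocab, teacher_vocab, top_k=32)
print(rows)  # {0: [(0, 1.0)], 1: [(1, 0.5), (2, 0.5)]}
```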
Step 2 (optional) — Reapply exact-token map#
Some token pairs are literally identical (e.g., common punctuation, single
ASCII characters). reapply_exact_map.py pins those to 1-to-1 mappings with
weight 1.0, overwriting whatever Step 1 produced for them.
uv run python -m tools.x_token.reapply_exact_map \
--student-model "meta-llama/Llama-3.2-1B" \
--teacher-model "Qwen/Qwen3-4B" \
--initial-projection-path cross_tokenizer_data/projection_map_Llama-3.2_to_Qwen3_multitoken_top_32_double_special.pt
Output is written next to the input as <basename>_exact_map_remapped.pt.
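As a rough illustration of what pinning means, the sketch below overwrites the row of every student token whose string also exists in the teacher vocabulary. The dictionary-based row format and the function name `pin_exact_matches` are assumptions made for this example; the real tool works on the saved .pt matrix.

```python
# Toy sketch of exact-match pinning (not the real reapply_exact_map.py).

def pin_exact_matches(rows, student_vocab, teacher_vocab):
    """Replace rows of string-identical tokens with a single 1.0 entry."""
    pinned = dict(rows)  # shallow copy; untouched rows are shared
    for token, s_id in student_vocab.items():
        t_id = teacher_vocab.get(token)
        if t_id is not None:
            pinned[s_id] = [(t_id, 1.0)]
    return pinned

student_vocab = {",": 0, "ab": 1}
teacher_vocab = {",": 5, "a": 6, "b": 7}
rows = {
    0: [(5, 0.9), (6, 0.1)],   # "," exists in both vocabularies
    1: [(6, 0.5), (7, 0.5)],   # "ab" has no exact teacher match
}

pinned = pin_exact_matches(rows, student_vocab, teacher_vocab)
print(pinned)  # {0: [(5, 1.0)], 1: [(6, 0.5), (7, 0.5)]}
```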
Step 3 — Sort and trim to runtime top_k#
The training loss only needs a small top_k per row (typical: 4–8). This
step sorts each row by weight and trims to the chosen runtime cap.
uv run python -m tools.x_token.sort_and_cut_projection_matrix \
--initial-projection-path cross_tokenizer_data/projection_map_Llama-3.2_to_Qwen3_multitoken_top_32_double_special_exact_map_remapped.pt \
--top_k 4 \
--output_path cross_tokenizer_data/projection_matrix_llama_qwen_top4.pt
--preserve_last is declared with argparse.BooleanOptionalAction and defaults to None. When
unspecified, the script reads enable_scale_trick from the input matrix’s
metadata (set in Step 1) and auto-enables preservation of the last column
slot when that flag is true. Pass --preserve_last or --no-preserve_last to override.
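The sort-and-trim operation itself is simple; a minimal sketch follows. The row format and the name `sort_and_cut` are assumptions for illustration. The sketch does not renormalize the surviving weights and does not model --preserve_last, since neither behavior is specified in this guide.

```python
# Toy sketch of sort-and-trim (not the real sort_and_cut_projection_matrix.py).

def sort_and_cut(rows, top_k):
    """Sort each row by descending weight and keep the first top_k entries."""
    trimmed = {}
    for s_id, entries in rows.items():
        ranked = sorted(entries, key=lambda entry: entry[1], reverse=True)
        trimmed[s_id] = ranked[:top_k]
    return trimmed

rows = {0: [(7, 0.125), (3, 0.5), (9, 0.25), (4, 0.125)]}
trimmed = sort_and_cut(rows, top_k=2)
print(trimmed)  # {0: [(3, 0.5), (9, 0.25)]}
```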
Step 4 — Launch x-token distillation#
The training entrypoint is examples/run_xtoken_off_policy_distillation.py with the
exemplar config at examples/configs/xtoken_off_policy_distillation.yaml. The exemplar
defaults to Llama-3.2-1B (student) ← Qwen3-4B (teacher) and the P-KL loss
mode. For data it points data.train.data_files at the ungated, CC-BY-4.0
NVIDIA Nemotron-Pretraining-Specialized-v1.1 corpus
(Nemotron-Pretraining-Formal-Logic subset) over hf://, so the recipe runs
out of the box with no auth or extra setup. The projection_matrix_path below
points at the cross_tokenizer_data/ directory that Steps 1–3 create, so run
those first (or the build_projection_matrix.sh wrapper). Override paths via
Hydra CLI:
uv run python examples/run_xtoken_off_policy_distillation.py \
--config examples/configs/xtoken_off_policy_distillation.yaml \
loss_fn.projection_matrix_path=cross_tokenizer_data/projection_matrix_llama_qwen_top4.pt \
cluster.gpus_per_node=8 \
cluster.num_nodes=1
The exemplar config leaves only loss_fn.projection_matrix_path set to null, so
the projection matrix must always be supplied at the CLI. This keeps the
config reusable across (student, teacher) pairs. data.train.data_files
already points at the default NVIDIA corpus described above; override it only
to train on your own .arrow/.parquet/.json/.txt corpus.
Loss-mode knobs#
loss_fn has two flags that pick between three behaviors:
| | | Behavior |
|---|---|---|
| | (inert) | P-KL: full-vocab teacher logits via CUDA IPC; the loss derives a microbatch-global top-k inside, projects the student into teacher vocab via the projection matrix, and chunk-averages KL on the top-k subset. CE term is added. |
| | | Gold loss: split the vocab into an exact-token-mapped common set (KL) and an uncommon tail (sorted L1). |
| | | H-KL (gold + xtoken): same as gold, but relax the exact-map threshold to |
Other relevant fields:
loss_fn.temperature: softmax temperature applied symmetrically to student and teacher logits before KL.
loss_fn.vocab_topk: microbatch-global top-k size for the P-KL path (inert when gold_loss=true).
loss_fn.uncommon_topk: cap on the L1 uncommon-tail sort in the gold path (defaults to 8192).
loss_fn.reverse_kl: compute KL(student || teacher) instead of KL(teacher || student).
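The effect of the temperature and reverse_kl fields can be illustrated with a small pure-Python sketch. It assumes both logit vectors already live in the same vocabulary space (that is, after projection), and the function names are hypothetical. It omits the top-k selection, chunk averaging, and CE term of the real loss.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax, shifted by the max for stability."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl(p, q):
    """KL(p || q) = sum_i p_i * log(p_i / q_i)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)

def distill_kl(student_logits, teacher_logits, temperature=1.0, reverse_kl=False):
    # The same temperature is applied to both sides before the KL.
    p_student = softmax(student_logits, temperature)
    p_teacher = softmax(teacher_logits, temperature)
    if reverse_kl:
        return kl(p_student, p_teacher)   # KL(student || teacher)
    return kl(p_teacher, p_student)       # KL(teacher || student)

student_logits = [1.0, 0.0, 0.0]
teacher_logits = [3.0, 1.0, 0.0]
forward = distill_kl(student_logits, teacher_logits)
reverse = distill_kl(student_logits, teacher_logits, reverse_kl=True)
print(forward, reverse)
```

The two directions generally give different values because KL divergence is not symmetric, and raising the temperature flattens both distributions, which shrinks the divergence.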
Results — 100-step P-KL run#
A P-KL run (Llama-3.2-1B student ← Qwen3-4B teacher; default config — global batch 96, micro-batch 1, sequence length 2048, 100 steps, 2 nodes — on the default Nemotron-Pretraining-Specialized-v1.1 / Formal-Logic corpus) shows the distillation objective converging and the student tracking the teacher more closely over training:
Loss falls from ≈1.51 to ≈0.78.
KL loss falls from ≈2.67 to ≈0.78 — the projected student distribution moves toward the teacher’s.
CE loss falls from ≈0.75 to ≈0.39.
Top-1 accuracy rises from ≈0.82 to ≈0.88.
Throughput and memory#
Measured on the same run (per training step, P-KL, micro-batch 1, sequence length 2048):
| Metric | Value |
|---|---|
| Mean step time | 4.07 s (min 3.66 s) |
| Training throughput | ≈48k valid tokens/s (global batch ÷ mean step time) |
| Peak GPU memory | 29.5 GB per GPU |
| Teacher-logit IPC tray | ≈0.6 GB per sample-step |
The full-vocab teacher logits never cross the network: the producer publishes
a single rank-level [B_r, T_t, V_t] bf16 tray and hands the student a CUDA
IPC handle to it (same node), so the per-step transport cost is the ≈0.6 GB
tray allocation rather than a host round-trip.
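The throughput and tray figures can be sanity-checked with a little arithmetic. The sketch below makes three assumptions that are not stated in this guide: every position in each 2048-token sequence counts as valid, the teacher-side sequence length T_t equals 2048, and the teacher's logit dimension V_t is 151,936 (the vocabulary size commonly reported for Qwen3 configs).

```python
# Back-of-the-envelope check of the reported throughput and tray size.
global_batch = 96
seq_len = 2048
mean_step_time_s = 4.07

tokens_per_step = global_batch * seq_len        # assumes all tokens are valid
tokens_per_s = tokens_per_step / mean_step_time_s

teacher_vocab_size = 151_936   # ASSUMPTION: Qwen3 logit dimension V_t
bytes_per_bf16 = 2
# One sample's [T_t, V_t] slice of the bf16 tray, assuming T_t == seq_len.
tray_bytes_per_sample = seq_len * teacher_vocab_size * bytes_per_bf16
tray_gb_per_sample = tray_bytes_per_sample / 1e9

print(f"{tokens_per_s / 1e3:.1f}k tokens/s, {tray_gb_per_sample:.2f} GB per sample")
# 48.3k tokens/s, 0.62 GB per sample
```

Under these assumptions both numbers line up with the table: roughly 48k tokens/s and about 0.6 GB of teacher logits per sample.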
Where files live#
| Stage | Tool | Default output |
|---|---|---|
| Build multi-token | | |
| Reapply exact map | | |
| Sort and trim | | |
| Train | | per the run’s |