nemo_automodel.components.speculative.precompute_eagle3
Precompute the EAGLE-3 offline target-output cache (SpecForge “offline” path).
The frozen target model’s per-token supervision is the same on every epoch and
every run, yet the online recipe recomputes it each step. This script runs the
target once over a dataset and writes the supervision to disk; training then
reads it back via cached_target_path instead of loading or running the
target at all.
============================== READ THIS =================================
This is the SpecForge offline training path. It is extremely disk
intensive: aux_hidden_states alone is 3 * target_hidden_size values wide per
token, and a full cached sample for an 8B target at seq-len 2048 costs ~80 MB (tens of TB for a
large corpus). Modern practice trains online, where the target forward is
cheap next to that I/O, so this path is largely deprecated. It is provided
for completeness / reproducing the SpecForge offline recipe and for the niche
case of repeatedly re-training a draft on a fixed, bounded dataset. Prefer the
online recipe (no cached_target_path) unless you specifically need this.
Only EAGLE-3 is supported. EAGLE-1/2 supervise on the full-vocab target distribution (no draft-vocab compression), which is ~0.5 GB per sample — not worth caching — so they keep the online path only.
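The per-sample figures above follow from simple arithmetic. The sketch below is a back-of-the-envelope estimate, assuming bf16 storage (2 bytes per value), Llama-3.1-8B's hidden size of 4096 and 128,256-entry vocabulary, and that the two dominant cached fields are aux_hidden_states and target_probs. The helper name cache_bytes_per_sample is hypothetical, and it ignores smaller fields such as position_mask.

```python
def cache_bytes_per_sample(
    seq_length: int,
    hidden_size: int,
    vocab_size: int,
    bytes_per_elem: int = 2,  # bf16 / fp16
) -> dict:
    """Rough on-disk size of the two largest per-sample fields (hypothetical breakdown)."""
    # aux_hidden_states: three target layers concatenated -> 3 * hidden_size values per token
    aux = seq_length * 3 * hidden_size * bytes_per_elem
    # target distribution: one probability per (token, vocab entry)
    probs = seq_length * vocab_size * bytes_per_elem
    return {"aux_hidden_states": aux, "target_probs": probs, "total": aux + probs}


# EAGLE-3: Llama-3.1-8B (hidden size 4096), seq-len 2048, 8192-entry draft vocab, bf16
eagle3 = cache_bytes_per_sample(2048, 4096, 8192)
# EAGLE-1/2 style: the full 128,256-entry Llama-3 vocabulary instead of a draft vocab
full_vocab_probs = cache_bytes_per_sample(2048, 4096, 128_256)["target_probs"]

for name, n in eagle3.items():
    print(f"{name:>18}: {n / 1e6:6.1f} MB")
print(f"{'full-vocab probs':>18}: {full_vocab_probs / 1e9:6.2f} GB")
```

With these inputs aux_hidden_states comes to about 50 MB and target_probs over the 8192-entry draft vocabulary to about 34 MB, roughly 84 MB together, in line with the ~80 MB figure. A full-vocabulary distribution at the same length is about 0.53 GB, which is the EAGLE-1/2 case.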
Typical usage (single device; large MoE targets that need sharding must use the online path instead):
python -m nemo_automodel.components.speculative.precompute_eagle3 \
    --target-model meta-llama/Llama-3.1-8B-Instruct \
    --input-data Aeala/ShareGPT_Vicuna_unfiltered \
    --output-dir /data/eagle3_cache/sharegpt_llama31 \
    --seq-length 2048 --draft-vocab-size 8192 \
    --batch-size 4 --shard-size 256 --dtype bf16
Then point the recipe at the --output-dir from that run: recipe_args.cached_target_path: /data/eagle3_cache/...
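Reading the cache back amounts to mapping a dataset-order sample index to a shard and an offset. A minimal sketch, assuming shards are pickled lists of per-sample records named shard_00000.pkl, shard_00001.pkl, ... (a hypothetical layout; the real on-disk format and loader are not described here) and that every shard except the last holds exactly --shard-size samples:

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Tuple


def locate(index: int, shard_size: int) -> Tuple[int, int]:
    """Map a dataset-order sample index to (shard number, offset inside that shard)."""
    if index < 0 or shard_size <= 0:
        raise ValueError("index must be >= 0 and shard_size > 0")
    return divmod(index, shard_size)


def load_sample(cache_dir: Path, index: int, shard_size: int) -> Any:
    """Read one cached sample back; the shard file naming is hypothetical."""
    shard, offset = locate(index, shard_size)
    with open(Path(cache_dir) / f"shard_{shard:05d}.pkl", "rb") as f:
        samples = pickle.load(f)
    return samples[offset]


# Demo: two shards of two samples each, then fetch dataset sample 3.
with tempfile.TemporaryDirectory() as tmp:
    for k in range(2):
        with open(Path(tmp) / f"shard_{k:05d}.pkl", "wb") as f:
            pickle.dump([{"sample": 2 * k + j} for j in range(2)], f)
    fetched = load_sample(Path(tmp), 3, shard_size=2)
print(fetched)
```

Reopening a shard on every lookup keeps the sketch short; a practical loader would keep recently used shards in memory, since consecutive indices usually hit the same file.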
API
Turn one target-model batch into the per-sample tensors the trainer caches.
Reuses _compute_target_distribution — the exact function the online
trainer calls — so the cached target_probs / position_mask are
numerically identical to the live path. Float fields are downcast to
cache_dtype; everything is moved to CPU for writing.
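The batch-to-sample step can be pictured as splitting each batched field into per-sample rows and downcasting only the float fields. This sketch uses plain Python lists instead of tensors, and IEEE float16 (via the struct module's "e" format) as a stand-in for cache_dtype because the standard library has no bfloat16. The function names are hypothetical, and the call to _compute_target_distribution that produces target_probs and position_mask is not reproduced.

```python
import struct
from typing import Dict, List, Set


def to_float16(values: List[float]) -> List[float]:
    """Round each value to IEEE half precision, a stand-in for a real cache_dtype cast."""
    return [struct.unpack("<e", struct.pack("<e", v))[0] for v in values]


def batch_to_samples(batch: Dict[str, List[list]], float_fields: Set[str]) -> List[Dict[str, list]]:
    """Split a {field: [row per sample]} batch into one dict per sample, downcasting float fields."""
    batch_size = len(next(iter(batch.values())))
    samples = []
    for i in range(batch_size):
        sample = {}
        for name, rows in batch.items():
            row = rows[i]
            # Float supervision is downcast; integer fields (e.g. masks) are copied unchanged.
            sample[name] = to_float16(row) if name in float_fields else list(row)
        samples.append(sample)
    return samples


batch = {
    "target_probs": [[0.1, 0.9], [0.5, 0.5]],
    "position_mask": [[1, 0], [1, 1]],
}
samples = batch_to_samples(batch, float_fields={"target_probs"})
print(samples[0])
```

In a tensor implementation the same step would typically be a dtype cast followed by a move to CPU for each float field.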
Refuse to --resume into a cache produced with a different configuration.
Every manifest field shapes the shard contents or their addressing: the
target_probs columns follow selected_token_ids (which moves with the
dataset, shuffle seed, and draft vocab size), sample order follows the
dataset, and tensor shapes follow seq_length / dtype. Shards from a
mismatched run are indistinguishable by shape alone, so without this check a
resume after changing e.g. --input-data or --shuffle-seed would keep
the old shards and bless them with the new manifest, silently corrupting the
supervision that training reads back.
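A resume guard of this kind can be as simple as comparing a stored manifest against the requested configuration field by field and refusing on any difference. The sketch assumes a JSON manifest named manifest.json and the field names in MANIFEST_KEYS, both hypothetical stand-ins for whatever the real manifest records; the shuffle seed value is also illustrative.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical manifest fields; the real manifest's keys may differ.
MANIFEST_KEYS = ("target_model", "input_data", "shuffle_seed", "draft_vocab_size", "seq_length", "dtype")


def check_resume_compatible(output_dir: Path, requested: dict) -> None:
    """Raise if an existing manifest disagrees with the requested configuration."""
    manifest_path = Path(output_dir) / "manifest.json"
    if not manifest_path.exists():
        return  # nothing cached yet, so there is nothing to conflict with
    cached = json.loads(manifest_path.read_text())
    diffs = [
        f"{key}: cached={cached.get(key)!r} requested={requested.get(key)!r}"
        for key in MANIFEST_KEYS
        if cached.get(key) != requested.get(key)
    ]
    if diffs:
        raise ValueError("refusing to --resume into an incompatible cache; " + "; ".join(diffs))


config = {
    "target_model": "meta-llama/Llama-3.1-8B-Instruct",
    "input_data": "Aeala/ShareGPT_Vicuna_unfiltered",
    "shuffle_seed": 0,
    "draft_vocab_size": 8192,
    "seq_length": 2048,
    "dtype": "bf16",
}

with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "manifest.json").write_text(json.dumps(config))
    check_resume_compatible(Path(tmp), config)  # same configuration: passes silently
    try:
        check_resume_compatible(Path(tmp), {**config, "shuffle_seed": 1})
        error = None
    except ValueError as exc:
        error = str(exc)
print(error)
```

Reporting every mismatched field, not just the first, tells the user exactly which flags to restore before resuming.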
Load the target, scan the dataset once, and write the sharded cache. Returns an exit code.
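The single-pass, sharded write can be sketched as a buffer that is flushed to a new file every shard_size samples, with a final partial shard for the remainder. This is an illustration under assumptions: samples stands in for the per-sample records produced from each target batch, and the pickle shard format, file names, manifest fields, and the function name write_sharded_cache are hypothetical, not the module's actual on-disk layout.

```python
import json
import pickle
import tempfile
from pathlib import Path
from typing import Iterable, List


def write_sharded_cache(samples: Iterable[dict], output_dir: Path, shard_size: int) -> int:
    """Stream samples into fixed-size shard files and return an exit code (0 on success)."""
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    buffer: List[dict] = []
    shard_idx = 0
    total = 0

    def flush() -> None:
        nonlocal shard_idx
        with open(output_dir / f"shard_{shard_idx:05d}.pkl", "wb") as f:
            pickle.dump(buffer, f)
        shard_idx += 1
        buffer.clear()

    for sample in samples:  # single pass over the dataset, in dataset order
        buffer.append(sample)
        total += 1
        if len(buffer) == shard_size:
            flush()
    if buffer:  # final, possibly partial shard
        flush()
    manifest = {"num_samples": total, "num_shards": shard_idx, "shard_size": shard_size}
    (output_dir / "manifest.json").write_text(json.dumps(manifest))
    return 0


with tempfile.TemporaryDirectory() as tmp:
    code = write_sharded_cache(({"sample": i} for i in range(5)), Path(tmp), shard_size=2)
    shard_files = sorted(p.name for p in Path(tmp).glob("shard_*.pkl"))
    written = json.loads((Path(tmp) / "manifest.json").read_text())
    with open(Path(tmp) / "shard_00002.pkl", "rb") as f:
        last_shard = pickle.load(f)
print(code, shard_files, written)
```

Because samples are appended in dataset order, sample i always lands in shard i // shard_size, which is why order-sensitive settings such as the shuffle seed belong to the cache's identity.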
Reject invalid CLI values before loading any model.
CLI entry point. Parses argv and returns the process exit code.
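Validating before loading the model means a bad numeric flag fails immediately rather than after a multi-gigabyte checkpoint load. A sketch of the parse, validate, exit-code flow using argparse, restricted to the flags that appear on this page. The defaults (taken from the usage example), the store_true form of --resume, the extra dtype spellings, and all function names are assumptions, not the module's real parser.

```python
import argparse
import sys
from typing import List, Optional

# Hypothetical dtype spellings; only "bf16" appears in the usage example above.
DTYPES = ("bf16", "fp16", "fp32")


def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(prog="precompute_eagle3")
    p.add_argument("--target-model", required=True)
    p.add_argument("--input-data", required=True)
    p.add_argument("--output-dir", required=True)
    # Illustrative defaults copied from the usage example, not the module's own.
    p.add_argument("--seq-length", type=int, default=2048)
    p.add_argument("--draft-vocab-size", type=int, default=8192)
    p.add_argument("--batch-size", type=int, default=4)
    p.add_argument("--shard-size", type=int, default=256)
    p.add_argument("--dtype", default="bf16")
    p.add_argument("--shuffle-seed", type=int, default=0)
    p.add_argument("--resume", action="store_true")
    return p


def validate_args(args: argparse.Namespace) -> List[str]:
    """Collect every invalid value so the user sees them all before any model is loaded."""
    errors = []
    for dest in ("seq_length", "draft_vocab_size", "batch_size", "shard_size"):
        value = getattr(args, dest)
        if value <= 0:
            errors.append(f"--{dest.replace('_', '-')} must be positive, got {value}")
    if args.dtype not in DTYPES:
        errors.append(f"--dtype must be one of {DTYPES}, got {args.dtype!r}")
    return errors


def main(argv: Optional[List[str]] = None) -> int:
    args = build_parser().parse_args(argv)
    errors = validate_args(args)
    if errors:
        for err in errors:
            print(f"error: {err}", file=sys.stderr)
        return 2
    # A real entry point would now load the target and write the cache.
    return 0
```

Returning 2 mirrors argparse's own exit status for usage errors, and collecting all errors before returning reports every bad flag in one run.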