nemo_automodel.components.speculative.precompute_eagle3


Precompute the EAGLE-3 offline target-output cache (SpecForge “offline” path).

The frozen target model’s per-token supervision is the same on every epoch and every run, yet the online recipe recomputes it each step. This script runs the target once over a dataset and writes the supervision to disk; training then reads it back via cached_target_path instead of loading or running the target at all.

============================== READ THIS =================================
This is the SpecForge offline training path, and it is extremely disk intensive: aux_hidden_states alone is 3 * target_hidden_size wide, so an 8B target at seq-len 2048 costs ~80 MB per sample (tens of TB for a large corpus). Modern practice trains online, where the target forward is cheap next to that I/O, so this path is largely deprecated. It is provided for completeness, for reproducing the SpecForge offline recipe, and for the niche case of repeatedly re-training a draft on a fixed, bounded dataset. Prefer the online recipe (no cached_target_path) unless you specifically need this.
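The ~80 MB figure can be reproduced with a back-of-the-envelope calculation. The sketch below is not part of the module. It assumes bf16 storage (2 bytes per element), a target hidden size of 4096 (Llama-3.1-8B), a draft vocab of 8192 as in the usage example, and that the bulk of each sample is aux_hidden_states plus target_probs. The corpus size is a purely hypothetical number chosen for illustration.

```python
# Rough per-sample cache size for the EAGLE-3 offline path (illustrative only).
BYTES_PER_ELEMENT = 2      # assumption: bf16 storage
HIDDEN_SIZE = 4096         # assumption: Llama-3.1-8B hidden size
SEQ_LEN = 2048
DRAFT_VOCAB_SIZE = 8192    # assumption: value taken from the usage example
NUM_SAMPLES = 500_000      # hypothetical corpus size


def tensor_bytes(*shape: int) -> int:
    """Bytes needed to store a dense tensor of the given shape."""
    count = 1
    for dim in shape:
        count *= dim
    return count * BYTES_PER_ELEMENT


# aux_hidden_states: three concatenated hidden states per token.
aux_bytes = tensor_bytes(SEQ_LEN, 3 * HIDDEN_SIZE)
# target_probs: one distribution over the draft vocab per token.
probs_bytes = tensor_bytes(SEQ_LEN, DRAFT_VOCAB_SIZE)
total_bytes = aux_bytes + probs_bytes

print(f"aux_hidden_states: {aux_bytes / 2**20:.0f} MiB")
print(f"target_probs:      {probs_bytes / 2**20:.0f} MiB")
print(f"per sample:        {total_bytes / 2**20:.0f} MiB")
print(f"corpus:            {total_bytes * NUM_SAMPLES / 1e12:.1f} TB")
```

Under these assumptions the two tensors come to 48 MiB and 32 MiB, so the per-sample cost scales linearly with both sequence length and hidden size.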

Only EAGLE-3 is supported. EAGLE-1/2 supervise on the full-vocab target distribution (no draft-vocab compression), which is ~0.5 GB per sample — not worth caching — so they keep the online path only.
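To see where the ~0.5 GB figure comes from, here is a small sketch of the arithmetic. It assumes bf16 storage, seq-len 2048, and the Llama 3.1 vocabulary size of 128256; none of these constants are read from the module itself.

```python
# Per-sample size of a full-vocab target distribution (illustrative only).
BYTES_PER_ELEMENT = 2       # assumption: bf16 storage
SEQ_LEN = 2048
FULL_VOCAB_SIZE = 128_256   # assumption: Llama 3.1 tokenizer vocabulary
DRAFT_VOCAB_SIZE = 8192     # draft vocab from the usage example, for comparison

full_vocab_bytes = SEQ_LEN * FULL_VOCAB_SIZE * BYTES_PER_ELEMENT
draft_vocab_bytes = SEQ_LEN * DRAFT_VOCAB_SIZE * BYTES_PER_ELEMENT

print(f"full vocab:  {full_vocab_bytes / 1e9:.2f} GB per sample")
print(f"draft vocab: {draft_vocab_bytes / 1e6:.1f} MB per sample")
print(f"ratio:       {full_vocab_bytes / draft_vocab_bytes:.1f}x")
```

The ratio is simply the full vocab size divided by the draft vocab size, which is why draft-vocab compression is what makes caching feasible at all.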

Typical usage (single device; large MoE targets that need sharding must use the online path instead):

python -m nemo_automodel.components.speculative.precompute_eagle3 \
    --target-model meta-llama/Llama-3.1-8B-Instruct \
    --input-data Aeala/ShareGPT_Vicuna_unfiltered \
    --output-dir /data/eagle3_cache/sharegpt_llama31 \
    --seq-length 2048 --draft-vocab-size 8192 \
    --batch-size 4 --shard-size 256 --dtype bf16

Then point the recipe at it: recipe_args.cached_target_path: /data/eagle3_cache/....

Module Contents

Functions

| Name | Description |
| --- | --- |
| _build_parser | - |
| _compute_batch_cache | Turn one target-model batch into the per-sample tensors the trainer caches. |
| _ensure_resume_compatible | Refuse to --resume into a cache produced with a different configuration. |
| _run | Load the target, scan the dataset once, and write the sharded cache. Returns an exit code. |
| _validate_args | Reject invalid CLI values before loading any model. |
| main | CLI entry point. Parses argv and returns the process exit code. |

Data

logger

API

nemo_automodel.components.speculative.precompute_eagle3._build_parser() -> argparse.ArgumentParser
nemo_automodel.components.speculative.precompute_eagle3._compute_batch_cache(
target_batch,
selected_token_ids: torch.Tensor,
selected_token_mask: torch.Tensor,
cache_dtype: torch.dtype
) -> dict[str, torch.Tensor]

Turn one target-model batch into the per-sample tensors the trainer caches.

Reuses _compute_target_distribution — the exact function the online trainer calls — so the cached target_probs / position_mask are numerically identical to the live path. Float fields are downcast to cache_dtype; everything is moved to CPU for writing.
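The downcast-and-move step described above might look roughly like the following. This is a minimal sketch and not the module's code: to_cache_tensors is a hypothetical helper, and the boolean dtype of position_mask is an assumption made for the demo. It only illustrates that float fields change dtype while other fields pass through unchanged.

```python
import torch


def to_cache_tensors(
    tensors: dict[str, torch.Tensor], cache_dtype: torch.dtype
) -> dict[str, torch.Tensor]:
    """Hypothetical helper: downcast float fields and move everything to CPU."""
    cached = {}
    for name, tensor in tensors.items():
        tensor = tensor.detach()
        if tensor.is_floating_point():
            # Only float fields are downcast; masks and ids keep their dtype.
            tensor = tensor.to(cache_dtype)
        cached[name] = tensor.cpu()
    return cached


# Tiny stand-in batch: 2 samples, 4 positions, draft vocab of 8.
batch = {
    "target_probs": torch.rand(2, 4, 8),                   # float32
    "position_mask": torch.ones(2, 4, dtype=torch.bool),   # assumed dtype
}
cached = to_cache_tensors(batch, torch.bfloat16)
for name, tensor in cached.items():
    print(name, tensor.dtype, tensor.device)
```

Leaving non-float fields untouched matters because casting a mask or an index tensor to a low-precision float would change its meaning rather than just its precision.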

nemo_automodel.components.speculative.precompute_eagle3._ensure_resume_compatible(
cache_dir: str,
manifest: dict[str, typing.Any],
existing_shards: set[int]
) -> None

Refuse to --resume into a cache produced with a different configuration.

Every manifest field shapes the shard contents or their addressing: the target_probs columns follow selected_token_ids (which moves with the dataset, shuffle seed, and draft vocab size), sample order follows the dataset, and tensor shapes follow seq_length / dtype. Shards from a mismatched run are indistinguishable by shape alone, so without this check a resume after changing e.g. --input-data or --shuffle-seed would keep the old shards and bless them with the new manifest, silently corrupting the supervision that training reads back.
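One way such a check could work is to compare the manifest already on disk with the one the current invocation would write. The sketch below is an assumption-laden illustration rather than the module's implementation: the manifest.json file name, the manifest field names, and the decision to raise ValueError are all hypothetical.

```python
import json
import tempfile
from pathlib import Path
from typing import Any

MANIFEST_NAME = "manifest.json"  # hypothetical on-disk name


def ensure_resume_compatible(
    cache_dir: str, manifest: dict[str, Any], existing_shards: set[int]
) -> None:
    """Raise ValueError if cache_dir holds shards from a different configuration."""
    if not existing_shards:
        return  # nothing on disk that could be inconsistent
    manifest_path = Path(cache_dir) / MANIFEST_NAME
    if not manifest_path.exists():
        # Sketch-level choice: shards without a manifest cannot be verified.
        raise ValueError(f"{cache_dir} has shards but no {MANIFEST_NAME}")
    previous = json.loads(manifest_path.read_text())
    mismatched = sorted(
        key
        for key in previous.keys() | manifest.keys()
        if previous.get(key) != manifest.get(key)
    )
    if mismatched:
        raise ValueError(f"refusing to resume; manifest fields differ: {mismatched}")


# Hypothetical manifest fields, named after the CLI flags mentioned above.
old = {"input_data": "corpus-a", "shuffle_seed": 0, "seq_length": 2048,
       "draft_vocab_size": 8192, "dtype": "bf16"}

with tempfile.TemporaryDirectory() as cache_dir:
    (Path(cache_dir) / MANIFEST_NAME).write_text(json.dumps(old))
    # Same configuration: the check passes silently.
    ensure_resume_compatible(cache_dir, dict(old), existing_shards={0, 1})
    # Changed shuffle seed: the check refuses to resume.
    try:
        ensure_resume_compatible(cache_dir, {**old, "shuffle_seed": 1}, {0, 1})
        error_message = ""
    except ValueError as exc:
        error_message = str(exc)

print(error_message)
```

Comparing every field, rather than a hand-picked subset, follows the reasoning above: any field that differs could change shard contents in a way that tensor shapes alone would not reveal.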

nemo_automodel.components.speculative.precompute_eagle3._run(
args: argparse.Namespace
) -> int

Load the target, scan the dataset once, and write the sharded cache. Returns an exit code.

nemo_automodel.components.speculative.precompute_eagle3._validate_args(
args: argparse.Namespace
) -> None

Reject invalid CLI values before loading any model.
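The specific checks are not listed here, so the following is only a guess at the general shape: a fail-fast pass over the parsed namespace that runs before any expensive model load. The attribute names mirror the flags in the usage example, while the positivity rule and the error type are assumptions.

```python
import argparse

# Assumption: these integer flags must all be strictly positive.
POSITIVE_INT_FIELDS = ("seq_length", "draft_vocab_size", "batch_size", "shard_size")


def validate_args(args: argparse.Namespace) -> None:
    """Hypothetical validator: raise ValueError on the first invalid value."""
    for name in POSITIVE_INT_FIELDS:
        value = getattr(args, name)
        if value <= 0:
            flag = "--" + name.replace("_", "-")
            raise ValueError(f"{flag} must be positive, got {value}")


good = argparse.Namespace(seq_length=2048, draft_vocab_size=8192,
                          batch_size=4, shard_size=256)
validate_args(good)  # passes without raising

bad = argparse.Namespace(seq_length=2048, draft_vocab_size=8192,
                         batch_size=0, shard_size=256)
try:
    validate_args(bad)
    message = ""
except ValueError as exc:
    message = str(exc)

print(message)
```

Validating first keeps a typo in a flag from costing the minutes it takes to load an 8B target before the error surfaces.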

nemo_automodel.components.speculative.precompute_eagle3.main(
argv: list[str] | None = None
) -> int

CLI entry point. Parses argv and returns the process exit code.

nemo_automodel.components.speculative.precompute_eagle3.logger = logging.getLogger(__name__)