nemo_automodel.components.datasets.llm.eagle3_cache
On-disk format + reader for the EAGLE-3 offline target-output cache.
This is the SpecForge “offline” training data path: the frozen target model’s per-token supervision (auxiliary hidden states + the draft-vocab target distribution) is precomputed once and stored on disk, so draft training reads it back instead of re-running the (large, frozen) target every step.
It is extremely disk-intensive: on the order of tens of MB per sample for
an 8B target (aux_hidden_states is 3 * target_hidden_size wide), i.e.
multiple TB for a large dataset. It is also largely superseded by online
training, where the target forward is cheap relative to the cache I/O. It is
kept for completeness and reproducibility of the SpecForge offline recipe;
prefer the online path unless you are re-training repeatedly on a fixed,
bounded dataset.
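To make the sizing claim concrete, the hypothetical helper below tallies per-sample bytes for each cached field. The dtypes are assumptions, not read from the module: int64 token tensors, bf16 aux_hidden_states, float32 target_probs, and one byte per bool. The usage numbers also assume an 8B Llama-style target with hidden size 4096, S = 2048 and draft_vocab = 32000.

```python
def estimate_sample_bytes(seq_len: int, hidden_size: int, draft_vocab: int,
                          act_bytes: int = 2, prob_bytes: int = 4) -> dict:
    """Approximate on-disk bytes for ONE cached sample, per field.

    Hypothetical helper. Assumed dtypes: int64 token tensors (8 bytes),
    aux_hidden_states in bf16 (act_bytes=2), target_probs in float32
    (prob_bytes=4), position_mask as 1-byte bool.
    """
    int64 = 8
    sizes = {
        "input_ids": seq_len * int64,
        "attention_mask": seq_len * int64,
        "loss_mask": seq_len * int64,
        # Three target layers concatenated along the feature dim: [S, 3H]
        "aux_hidden_states": seq_len * 3 * hidden_size * act_bytes,
        # Full draft-vocab distribution at every position: [S, draft_vocab]
        "target_probs": seq_len * draft_vocab * prob_bytes,
        # [S, 1] bool, one byte per element
        "position_mask": seq_len * 1,
    }
    sizes["total"] = sum(sizes.values())
    return sizes


est = estimate_sample_bytes(seq_len=2048, hidden_size=4096, draft_vocab=32000)
print(est["aux_hidden_states"] / 2**20, "MiB of aux_hidden_states per sample")
```

Under these assumptions aux_hidden_states alone is 48 MiB per sample (2048 * 12288 * 2 bytes), and target_probs grows linearly with draft_vocab; both dwarf the int64 and bool tensors, which is where the multi-TB totals come from.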
This module owns the format (so the producer in
components/speculative/precompute_eagle3.py and the training-time reader
agree on one schema):
- <cache_dir>/manifest.json: run config plus the selected_token_ids used to build the draft vocabulary (the recipe reuses these instead of rescanning).
- <cache_dir>/shard-000000.safetensors: one shard holds a contiguous block of samples, each field stacked along dim 0: input_ids [n, S], attention_mask [n, S], loss_mask [n, S] (int64), aux_hidden_states [n, S, 3H], target_probs [n, S, draft_vocab] (float), position_mask [n, S, 1] (bool). Here n is the number of samples in the shard, S the sequence length, and H the target hidden size.
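A quick way to sanity-check a shard against this layout is to compare field shapes. The sketch below is hypothetical (neither function exists in the module) and works on plain shape tuples, e.g. tuple(tensor.shape), so it needs no tensor library.

```python
def expected_shapes(n, seq_len, hidden_size, draft_vocab):
    """Per-field shapes for a shard of n samples (hypothetical helper)."""
    return {
        "input_ids": (n, seq_len),
        "attention_mask": (n, seq_len),
        "loss_mask": (n, seq_len),
        "aux_hidden_states": (n, seq_len, 3 * hidden_size),
        "target_probs": (n, seq_len, draft_vocab),
        "position_mask": (n, seq_len, 1),
    }


def check_shard_shapes(shapes, hidden_size, draft_vocab):
    """Return a list of problems; an empty list means the shard matches."""
    ids = shapes.get("input_ids")
    if ids is None or len(ids) != 2:
        return ["input_ids missing or not 2-D"]
    n, seq_len = ids  # n and S are taken from input_ids
    problems = []
    for key, want in expected_shapes(n, seq_len, hidden_size, draft_vocab).items():
        got = shapes.get(key)
        if got is None:
            problems.append(f"missing {key}")
        elif tuple(got) != want:
            problems.append(f"{key}: expected {want}, got {tuple(got)}")
    return problems
```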
Each CachedEagle3Dataset item is exactly the keyword arguments
Eagle3TrainerModule.forward consumes on its precomputed-distribution path.
API
class CachedEagle3Dataset
Bases: Dataset
Reads the EAGLE-3 offline cache; each item is one sample’s trainer inputs.
Shards are opened lazily with safetensors.safe_open (memory-mapped) and
sliced per sample, so the full cache is never loaded into memory at once.
Handles are reopened per worker after a DataLoader fork.
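The per-worker reopen can be sketched as a PID check: handles opened in one process are discarded the first time a different process touches them. The class and its opener argument are hypothetical; in the real reader the opener would be something like safetensors.safe_open, and how it is called there is an assumption.

```python
import os


class LazyShardHandles:
    """Hypothetical per-process cache of lazily opened shard handles."""

    def __init__(self, paths, opener):
        self._paths = list(paths)
        self._opener = opener      # callable: path -> handle
        self._handles = {}         # shard index -> open handle
        self._pid = os.getpid()    # process that owns self._handles

    def get(self, shard_index):
        pid = os.getpid()
        if pid != self._pid:
            # Handles inherited across a fork belong to the parent:
            # drop them and reopen in this worker.
            self._handles = {}
            self._pid = pid
        if shard_index not in self._handles:
            # Open on first use only, so unused shards are never touched.
            self._handles[shard_index] = self._opener(self._paths[shard_index])
        return self._handles[shard_index]
```

Keying on os.getpid() means the check costs one syscall per lookup and needs no cooperation from the DataLoader; a worker_init_fn reset is the usual alternative hook.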
Run write_fn against a sibling .tmp path, then os.replace it into place.
A crash mid-write never leaves a half-written file a later run would load.
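A minimal sketch of the .tmp plus os.replace pattern, assuming write_fn takes the temporary path as its only argument (the real helper's name and signature are not shown here):

```python
import os


def atomic_write(path, write_fn):
    """Call write_fn(tmp_path), then move tmp_path onto path in one rename.

    Hypothetical sketch: if write_fn raises, the partial .tmp file is
    deleted and any existing file at `path` is left untouched.
    """
    tmp_path = f"{path}.tmp"  # sibling file, same directory and filesystem
    try:
        write_fn(tmp_path)
        os.replace(tmp_path, path)  # replaces `path` if it already exists
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
```

os.replace is an atomic rename only when both paths are on the same filesystem, which is why the temporary file sits next to the target instead of in a system temp directory.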
Stack per-sample cache dicts into a batch.
Return (save_file, safe_open) or raise a clear error if safetensors is missing.
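The import guard might look like the following; the function name and error text are illustrative. Importing inside the function keeps safetensors an optional dependency until the cache is actually used. Note that safetensors.torch also needs PyTorch installed.

```python
def require_safetensors():
    """Return (save_file, safe_open), or raise an ImportError naming the fix."""
    try:
        from safetensors import safe_open
        from safetensors.torch import save_file  # also requires torch
    except ImportError as exc:
        raise ImportError(
            "The EAGLE-3 offline cache needs the 'safetensors' package "
            "(with PyTorch); install it with `pip install safetensors`."
        ) from exc
    return save_file, safe_open
```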
Build a dataloader over a precomputed EAGLE-3 cache directory.
Return the set of shard indices already present in cache_dir.
Return the manifest path inside cache_dir.
Load the cache manifest, raising if it is missing or the wrong format version.
Load the target input-embedding table written by write_target_embeddings.
Return the path of shard shard_index inside cache_dir.
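Shard naming and discovery can be sketched with the standard library alone. The six-digit zero padding is inferred from the example name shard-000000.safetensors; both function names are hypothetical. Anchoring the pattern at the end of the file name means a leftover shard-*.safetensors.tmp from an interrupted write is not counted as present.

```python
import os
import re

# Matches finished shards only; "shard-000003.safetensors.tmp" does not match.
_SHARD_RE = re.compile(r"^shard-(\d+)\.safetensors$")


def shard_file(cache_dir, shard_index):
    """Path of shard `shard_index`, zero-padded to six digits (assumed)."""
    return os.path.join(cache_dir, f"shard-{shard_index:06d}.safetensors")


def present_shards(cache_dir):
    """Set of shard indices whose final files already exist in cache_dir."""
    if not os.path.isdir(cache_dir):
        return set()
    found = set()
    for name in os.listdir(cache_dir):
        match = _SHARD_RE.match(name)
        if match:
            found.add(int(match.group(1)))
    return found
```

A producer could presumably use the returned set to skip shards a previous run already finished.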
Persist the cache manifest atomically (.tmp + os.replace).
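A sketch of the manifest round trip, using json and the same .tmp plus os.replace step. The format_version and config keys and the version value 1 are assumptions; only manifest.json and selected_token_ids come from the format description above.

```python
import json
import os

MANIFEST_NAME = "manifest.json"
FORMAT_VERSION = 1  # hypothetical; the module's real constant is not shown


def save_manifest(cache_dir, config, selected_token_ids):
    path = os.path.join(cache_dir, MANIFEST_NAME)
    tmp = path + ".tmp"
    payload = {
        "format_version": FORMAT_VERSION,
        "config": config,
        "selected_token_ids": list(selected_token_ids),
    }
    with open(tmp, "w") as f:
        json.dump(payload, f)
    os.replace(tmp, path)  # readers never see a half-written manifest


def load_manifest(cache_dir):
    path = os.path.join(cache_dir, MANIFEST_NAME)
    if not os.path.exists(path):
        raise FileNotFoundError(f"no EAGLE-3 cache manifest at {path}")
    with open(path) as f:
        manifest = json.load(f)
    version = manifest.get("format_version")
    if version != FORMAT_VERSION:
        raise ValueError(
            f"unsupported cache format_version {version!r}; "
            f"expected {FORMAT_VERSION}"
        )
    return manifest
```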
Write one shard atomically. samples maps each CACHE_KEYS field to a stacked tensor.
Persist the target input-embedding table the draft initializes from.
The offline training path never loads the target model, but the draft’s
embed_tokens must still be seeded from the target’s embeddings (EAGLE-3
concatenates token embeddings with the carried hidden state), so the
producer stores them once alongside the cache.