nemo_automodel.components.datasets.llm.eagle3_cache


On-disk format + reader for the EAGLE-3 offline target-output cache.

This is the SpecForge “offline” training data path: the frozen target model’s per-token supervision (auxiliary hidden states + the draft-vocab target distribution) is precomputed once and stored on disk, so draft training reads it back instead of re-running the (large, frozen) target every step.

It is extremely disk-intensive. It takes on the order of tens of MB per sample for an 8B target, because aux_hidden_states alone is 3 * target_hidden_size wide per token, and that adds up to multiple TB for a large dataset. It is therefore largely superseded by online training, where the target forward pass is cheap relative to the cache I/O. The offline path is kept for completeness and for reproducing the SpecForge offline recipe. Prefer the online path unless you retrain repeatedly on a fixed, bounded dataset.
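The disk cost can be estimated from the shard layout described below. The following is a rough sketch, not a function from this module. The helper name and the example sizes are hypothetical. It assumes that both float fields are stored in a single dtype from DTYPE_MAP, that bool takes one byte, and that safetensors header overhead is negligible.

```python
# Hypothetical helper: rough on-disk bytes for one cached sample.
# Assumes aux_hidden_states and target_probs share one float dtype from DTYPE_MAP,
# bool is stored as 1 byte, and safetensors header overhead is ignored.
FLOAT_BYTES = {"bf16": 2, "fp16": 2, "fp32": 4}
INT64_BYTES = 8


def bytes_per_sample(seq_len: int, hidden_size: int, draft_vocab: int, float_dtype: str = "bf16") -> int:
    fb = FLOAT_BYTES[float_dtype]
    aux = seq_len * 3 * hidden_size * fb  # aux_hidden_states: [S, 3H]
    probs = seq_len * draft_vocab * fb  # target_probs: [S, draft_vocab]
    ints = 3 * seq_len * INT64_BYTES  # input_ids, attention_mask, loss_mask: [S] each, int64
    mask = seq_len * 1  # position_mask: [S, 1], bool
    return aux + probs + ints + mask


if __name__ == "__main__":
    # Hypothetical sizes; substitute your target's hidden size, sequence length and draft vocab.
    size = bytes_per_sample(seq_len=1024, hidden_size=4096, draft_vocab=32000)
    print(f"{size / 1e6:.1f} MB per sample")
```

The two float fields account for almost all of the total. Shrinking the draft vocabulary or the sequence length is therefore the main lever on cache size.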

This module owns the on-disk format, so that the producer (components/speculative/precompute_eagle3.py) and the training-time reader agree on a single schema. A cache directory contains:

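  • <cache_dir>/manifest.json: the run config plus the selected_token_ids used to build the draft vocabulary (the recipe reuses these instead of rescanning the data).
  • <cache_dir>/shard-NNNNNN.safetensors (for example shard-000000.safetensors): each shard holds a contiguous block of samples, with every field stacked along dim 0: input_ids[n,S], attention_mask[n,S], loss_mask[n,S] (int64); aux_hidden_states[n,S,3H], target_probs[n,S,draft_vocab] (float); position_mask[n,S,1] (bool). Here n is the number of samples in the shard, S is the sequence length, and H is the target hidden size.
  • <cache_dir>/target_embeddings.safetensors: the target's input-embedding table, written once by write_target_embeddings. It lets the draft's embed_tokens be initialized without loading the target.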
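The shard layout above can be written as a shape check. This is a minimal sketch, not the module's own validation. The helper names are hypothetical. The key groups mirror _INT_KEYS, _FLOAT_KEYS and _BOOL_KEYS from the Data section. Shapes are plain tuples so that the sketch runs without torch; with real tensors, pass tuple(t.shape).

```python
# Hypothetical validator for the shard schema. Shapes are plain tuples so the
# sketch runs without torch; with real tensors pass tuple(t.shape).
INT_KEYS = ("input_ids", "attention_mask", "loss_mask")
FLOAT_KEYS = ("aux_hidden_states", "target_probs")
BOOL_KEYS = ("position_mask",)


def expected_shapes(n: int, seq_len: int, hidden_size: int, draft_vocab: int) -> dict[str, tuple[int, ...]]:
    """Per-field shapes for a shard of n samples of length seq_len."""
    shapes: dict[str, tuple[int, ...]] = {key: (n, seq_len) for key in INT_KEYS}
    shapes["aux_hidden_states"] = (n, seq_len, 3 * hidden_size)
    shapes["target_probs"] = (n, seq_len, draft_vocab)
    shapes["position_mask"] = (n, seq_len, 1)
    return shapes


def check_shard_shapes(shapes: dict[str, tuple[int, ...]], hidden_size: int, draft_vocab: int) -> tuple[int, int]:
    """Raise ValueError unless `shapes` matches the layout; return (n, S)."""
    missing = set(INT_KEYS + FLOAT_KEYS + BOOL_KEYS) - set(shapes)
    if missing:
        raise ValueError(f"missing fields: {sorted(missing)}")
    # n and S are read from input_ids; every other field must agree with them.
    n, seq_len = shapes["input_ids"][:2]
    for key, want in expected_shapes(n, seq_len, hidden_size, draft_vocab).items():
        if tuple(shapes[key]) != want:
            raise ValueError(f"{key}: expected {want}, got {tuple(shapes[key])}")
    return n, seq_len
```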

Each CachedEagle3Dataset item is a dict that holds exactly the keyword arguments Eagle3TrainerModule.forward consumes on its precomputed-distribution path.

Module Contents

Classes

| Name | Description |
| --- | --- |
| CachedEagle3Dataset | Reads the EAGLE-3 offline cache; each item is one sample’s trainer inputs. |

Functions

| Name | Description |
| --- | --- |
| _atomic_write | Run write_fn against a sibling .tmp path, then os.replace it into place. |
| _collate_cached | Stack per-sample cache dicts into a batch. |
| _load_safetensors | Return (save_file, safe_open) or raise a clear error if safetensors is missing. |
| build_cached_eagle3_dataloader | Build a dataloader over a precomputed EAGLE-3 cache directory. |
| existing_shard_indices | Return the set of shard indices already present in cache_dir. |
| manifest_path | Return the manifest path inside cache_dir. |
| read_manifest | Load the cache manifest, raising if it is missing or the wrong format version. |
| read_target_embeddings | Load the target input-embedding table written by write_target_embeddings. |
| shard_path | Return the path of shard shard_index inside cache_dir. |
| write_manifest | Persist the cache manifest atomically (.tmp + os.replace). |
| write_shard | Write one shard atomically. samples maps each CACHE_KEYS field to a stacked tensor. |
| write_target_embeddings | Persist the target input-embedding table the draft initializes from. |

Data

CACHE_KEYS

DTYPE_MAP

_BOOL_KEYS

_EMBEDDINGS_NAME

_FLOAT_KEYS

_FORMAT_VERSION

_INT_KEYS

_MANIFEST_NAME

_SHARD_RE

API

class nemo_automodel.components.datasets.llm.eagle3_cache.CachedEagle3Dataset(
cache_dir: str
)

Bases: Dataset

Reads the EAGLE-3 offline cache; each item is one sample’s trainer inputs.

Shards are opened lazily with safetensors.safe_open (memory-mapped) and sliced per sample, so the full cache is never loaded into memory at once. Handles are reopened per worker after a DataLoader fork.
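The docstring says that handles are reopened per worker after a fork, but it does not say how. One common approach keys the handle cache on the process id, so that a forked worker opens its own handles instead of reusing the parent's. The sketch below shows that approach. It is an assumption, not this class's actual code, and the class name is hypothetical. `opener` is a stand-in for something like safetensors.safe_open.

```python
import os
from typing import Any, Callable


class PerProcessHandles:
    """Hypothetical sketch: cache one open handle per shard, reset after a fork.

    `opener` stands in for something like safetensors.safe_open(path, framework="pt");
    handles opened in the parent are dropped when the current pid changes.
    """

    def __init__(self, opener: Callable[[int], Any]):
        self._opener = opener
        self._pid = os.getpid()
        self._open_handles: dict[int, Any] = {}

    def handle(self, shard_index: int) -> Any:
        if os.getpid() != self._pid:
            # Running in a forked DataLoader worker: discard handles inherited from the parent.
            self._open_handles = {}
            self._pid = os.getpid()
        if shard_index not in self._open_handles:
            # Open lazily, only when a sample from this shard is first requested.
            self._open_handles[shard_index] = self._opener(shard_index)
        return self._open_handles[shard_index]
```

Opening lazily also means that each worker maps only the shards its samples actually touch.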

_open_handles: dict[int, Any] = {}

manifest = read_manifest(cache_dir)

num_samples = int(self.manifest['num_samples'])

shard_size = int(self.manifest['shard_size'])

nemo_automodel.components.datasets.llm.eagle3_cache.CachedEagle3Dataset.__getitem__(
index: int
) -> dict[str, torch.Tensor]
nemo_automodel.components.datasets.llm.eagle3_cache.CachedEagle3Dataset.__len__() -> int
nemo_automodel.components.datasets.llm.eagle3_cache.CachedEagle3Dataset._handle(
shard_index: int
)
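Each shard holds a contiguous block of samples, and the manifest records shard_size. A global index can therefore be resolved to a shard and a row with integer division. The sketch below assumes that every shard except possibly the last holds exactly shard_size samples. The function name is hypothetical.

```python
def locate(index: int, shard_size: int, num_samples: int) -> tuple[int, int]:
    """Hypothetical: map a global sample index to (shard_index, row_within_shard).

    Assumes every shard except possibly the last holds exactly `shard_size`
    contiguous samples, as recorded in the manifest.
    """
    if not 0 <= index < num_samples:
        raise IndexError(f"index {index} out of range for {num_samples} samples")
    return divmod(index, shard_size)
```

The row could then be sliced out of the memory-mapped shard, for example through the safe_open handle's get_slice. That way only the requested sample's bytes are read.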
nemo_automodel.components.datasets.llm.eagle3_cache._atomic_write(
path: str,
write_fn: typing.Callable[[str], None]
) -> str

Run write_fn against a sibling .tmp path, then os.replace it into place.

As a result, a crash mid-write never leaves behind a half-written file that a later run would load.
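The .tmp-plus-rename pattern can be sketched in a few lines. The function name is hypothetical. Removing the partial .tmp file when write_fn raises is an added assumption; the docstring does not say whether _atomic_write does this.

```python
import os
from typing import Callable


def atomic_write(path: str, write_fn: Callable[[str], None]) -> str:
    """Sketch of the .tmp + os.replace pattern: write a sibling file, then rename it."""
    tmp_path = path + ".tmp"  # sibling file in the same directory, so os.replace is a rename
    try:
        write_fn(tmp_path)
        os.replace(tmp_path, path)  # atomic on POSIX when both paths are on the same filesystem
    except BaseException:
        # Assumption: clean up the partial file so it is not mistaken for output later.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return path
```

Readers only ever see the final name, so they observe either the old file or the complete new one, never a partial write.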

nemo_automodel.components.datasets.llm.eagle3_cache._collate_cached(
features: list[dict[str, torch.Tensor]]
) -> dict[str, torch.Tensor]

Stack per-sample cache dicts into a batch.
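Collating turns a list of per-sample dicts into one dict of batched fields. In this sketch the stacking function is a parameter. In a torch pipeline it would be torch.stack; the default `list` is an assumption that keeps the sketch dependency-free. The function name is hypothetical.

```python
from typing import Any, Callable, Sequence


def collate_cached(features: Sequence[dict[str, Any]], stack: Callable[[list], Any] = list) -> dict[str, Any]:
    """Hypothetical sketch: turn a list of per-sample dicts into one dict of batches.

    In a torch pipeline `stack` would be torch.stack; the default `list` keeps the
    sketch dependency-free.
    """
    if not features:
        raise ValueError("cannot collate an empty batch")
    keys = features[0].keys()
    for feature in features[1:]:
        # Every sample must carry the same fields for per-key stacking to make sense.
        if feature.keys() != keys:
            raise ValueError("all samples must carry the same fields")
    return {key: stack([feature[key] for feature in features]) for key in keys}
```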

nemo_automodel.components.datasets.llm.eagle3_cache._load_safetensors()

Return (save_file, safe_open) or raise a clear error if safetensors is missing.
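Deferring the import keeps safetensors optional until the cache is actually used. The sketch below shows that pattern. It uses the real safetensors entry points safetensors.safe_open and safetensors.torch.save_file. The function name and error wording are illustrative.

```python
def load_safetensors():
    """Sketch: import safetensors lazily and fail with an actionable message."""
    try:
        from safetensors import safe_open
        from safetensors.torch import save_file  # also needs torch installed
    except ImportError as exc:
        raise ImportError(
            "The EAGLE-3 offline cache requires the 'safetensors' package (with torch); "
            "install it with `pip install safetensors`."
        ) from exc
    return save_file, safe_open
```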

nemo_automodel.components.datasets.llm.eagle3_cache.build_cached_eagle3_dataloader(
cache_dir: str,
batch_size: int,
shuffle: bool,
num_workers: int = 0,
distributed: bool = False
) -> torch.utils.data.DataLoader

Build a dataloader over a precomputed EAGLE-3 cache directory.
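With distributed=True, each rank should read its own slice of the cache. This page does not document which sampler the function builds. The sketch below assumes a torch.utils.data.DistributedSampler-style split with drop_last=False and reproduces its padding and strided slicing in plain Python. Its shuffled order will not match torch's exactly. The function name and the seed/epoch parameters are hypothetical.

```python
import math
import random


def rank_indices(num_samples: int, rank: int, world_size: int, shuffle: bool, seed: int = 0, epoch: int = 0) -> list[int]:
    """Sketch of a DistributedSampler-style split (drop_last=False).

    Pads the index list by wrapping around so every rank gets the same count,
    then gives rank r every world_size-th index starting at r.
    """
    indices = list(range(num_samples))
    if shuffle:
        # Same seed and epoch on every rank so all ranks agree on the permutation.
        random.Random(seed + epoch).shuffle(indices)
    per_rank = math.ceil(num_samples / world_size)
    total = per_rank * world_size
    while len(indices) < total:
        indices += indices[: total - len(indices)]
    return indices[rank:total:world_size]
```

Padding by wrap-around gives every rank the same number of batches, so collective ops stay in lockstep. The cost is that a few samples are seen twice per epoch.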

nemo_automodel.components.datasets.llm.eagle3_cache.existing_shard_indices(
cache_dir: str
) -> set[int]

Return the set of shard indices already present in cache_dir.
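Shard indices can be recovered by matching directory entries against the same pattern as _SHARD_RE. This is a sketch with a hypothetical name, not the module's implementation. Returning an empty set for a missing directory is an assumption.

```python
import os
import re

SHARD_RE = re.compile(r"^shard-(\d{6})\.safetensors$")  # same pattern as _SHARD_RE


def scan_shard_indices(cache_dir: str) -> set[int]:
    """Sketch: collect shard indices from completed shard filenames in cache_dir."""
    if not os.path.isdir(cache_dir):
        return set()
    indices = set()
    for name in os.listdir(cache_dir):
        match = SHARD_RE.match(name)
        if match:
            indices.add(int(match.group(1)))
    return indices
```

The pattern is anchored at both ends. An in-progress shard-NNNNNN.safetensors.tmp file left by an interrupted atomic write therefore does not match and is not counted as present.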

nemo_automodel.components.datasets.llm.eagle3_cache.manifest_path(
cache_dir: str
) -> str

Return the manifest path inside cache_dir.

nemo_automodel.components.datasets.llm.eagle3_cache.read_manifest(
cache_dir: str
) -> dict[str, typing.Any]

Load the cache manifest, raising if it is missing or the wrong format version.

nemo_automodel.components.datasets.llm.eagle3_cache.read_target_embeddings(
cache_dir: str
) -> torch.Tensor

Load the target input-embedding table written by write_target_embeddings.

nemo_automodel.components.datasets.llm.eagle3_cache.shard_path(
cache_dir: str,
shard_index: int
) -> str

Return the path of shard shard_index inside cache_dir.
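The example name shard-000000.safetensors and the six-digit group in _SHARD_RE suggest zero-padded shard names. A sketch of both path helpers under that assumption follows. The function names are hypothetical, and the manifest file name mirrors _MANIFEST_NAME.

```python
import os

MANIFEST_NAME = "manifest.json"  # mirrors _MANIFEST_NAME


def manifest_file(cache_dir: str) -> str:
    """Sketch: location of the manifest inside cache_dir."""
    return os.path.join(cache_dir, MANIFEST_NAME)


def shard_file(cache_dir: str, shard_index: int) -> str:
    """Sketch: location of shard `shard_index` inside cache_dir."""
    # Zero-pad to six digits so names match _SHARD_RE (shard-000000.safetensors, ...).
    return os.path.join(cache_dir, f"shard-{shard_index:06d}.safetensors")
```

With this formatting, an index of 1,000,000 or more would produce seven digits and no longer match the six-digit pattern.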

nemo_automodel.components.datasets.llm.eagle3_cache.write_manifest(
cache_dir: str,
manifest: dict[str, typing.Any]
) -> str

Persist the cache manifest atomically (.tmp + os.replace).
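A manifest round trip can be sketched with json and the same .tmp-plus-rename step. The function names are hypothetical. Storing the version under a "format_version" key, and having the writer stamp it, are assumptions. num_samples, shard_size and selected_token_ids are the manifest fields mentioned on this page.

```python
import json
import os

FORMAT_VERSION = 1  # mirrors _FORMAT_VERSION


def save_manifest(cache_dir: str, manifest: dict) -> str:
    """Sketch: write manifest.json via a sibling .tmp file and os.replace."""
    path = os.path.join(cache_dir, "manifest.json")
    tmp_path = path + ".tmp"
    with open(tmp_path, "w") as f:
        # Assumption: the writer stamps the format version into the manifest.
        json.dump({**manifest, "format_version": FORMAT_VERSION}, f, indent=2)
    os.replace(tmp_path, path)
    return path


def load_manifest(cache_dir: str) -> dict:
    """Sketch: read manifest.json and reject a missing file or a different format version."""
    path = os.path.join(cache_dir, "manifest.json")
    if not os.path.exists(path):
        raise FileNotFoundError(f"no EAGLE-3 cache manifest at {path}")
    with open(path) as f:
        manifest = json.load(f)
    if manifest.get("format_version") != FORMAT_VERSION:
        raise ValueError(f"unsupported cache format version {manifest.get('format_version')!r}")
    return manifest
```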

nemo_automodel.components.datasets.llm.eagle3_cache.write_shard(
cache_dir: str,
shard_index: int,
samples: dict[str, torch.Tensor]
) -> str

Write one shard atomically. The samples argument maps each CACHE_KEYS field to a tensor stacked along dim 0.
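A producer that writes shards one at a time could resume after an interruption by skipping the shard indices already on disk, for example those returned by existing_shard_indices. The sketch below is a hypothetical illustration of that bookkeeping. It assumes fixed-size contiguous shards and is not taken from precompute_eagle3.py. Each planned range would then be stacked into a samples dict and passed to write_shard.

```python
def plan_shards(num_samples: int, shard_size: int, done: set[int]) -> list[tuple[int, int, int]]:
    """Hypothetical: list (shard_index, start, stop) sample ranges still to write.

    Shard k covers samples [k * shard_size, min((k + 1) * shard_size, num_samples)).
    Shards whose index is in `done` (e.g. from existing_shard_indices) are skipped.
    """
    plan = []
    for shard_index, start in enumerate(range(0, num_samples, shard_size)):
        if shard_index in done:
            continue
        plan.append((shard_index, start, min(start + shard_size, num_samples)))
    return plan
```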

nemo_automodel.components.datasets.llm.eagle3_cache.write_target_embeddings(
cache_dir: str,
weight: torch.Tensor
) -> str

Persist the target input-embedding table the draft initializes from.

The offline training path never loads the target model, but the draft’s embed_tokens must still be seeded from the target’s embeddings (EAGLE-3 concatenates token embeddings with the carried hidden state), so the producer stores them once alongside the cache.

nemo_automodel.components.datasets.llm.eagle3_cache.CACHE_KEYS = _FLOAT_KEYS + _INT_KEYS + _BOOL_KEYS
nemo_automodel.components.datasets.llm.eagle3_cache.DTYPE_MAP = {'bf16': torch.bfloat16, 'fp16': torch.float16, 'fp32': torch.float32}
nemo_automodel.components.datasets.llm.eagle3_cache._BOOL_KEYS = ('position_mask',)
nemo_automodel.components.datasets.llm.eagle3_cache._EMBEDDINGS_NAME = 'target_embeddings.safetensors'
nemo_automodel.components.datasets.llm.eagle3_cache._FLOAT_KEYS = ('aux_hidden_states', 'target_probs')
nemo_automodel.components.datasets.llm.eagle3_cache._FORMAT_VERSION = 1
nemo_automodel.components.datasets.llm.eagle3_cache._INT_KEYS = ('input_ids', 'attention_mask', 'loss_mask')
nemo_automodel.components.datasets.llm.eagle3_cache._MANIFEST_NAME = 'manifest.json'
nemo_automodel.components.datasets.llm.eagle3_cache._SHARD_RE = re.compile('^shard-(\\d{6})\\.safetensors$')