nemo_automodel.components.speculative.regenerate

Regenerate dataset answers with the EAGLE target model.

EAGLE drafters learn best when the supervised assistant turn is produced by the same model that will serve as the inference target. Many public chat datasets were generated by other models, so the assistant tokens they contain are off-distribution for the drafter. This script takes such a dataset and, for each sample, strips the trailing assistant turn and replays the remaining [system, user, ...] context against a target model running behind an OpenAI-compatible SGLang server. It then writes a new dataset whose messages column ends with a freshly generated assistant turn.

The output parquet files have the same messages column shape that ChatDataset (used by build_eagle3_dataloader) consumes, so the regenerated directory can be plugged directly into train_data_path in the EAGLE-3 recipe.

Typical usage:

1. Spin up SGLang serving the target model in another shell:

    python -m sglang.launch_server \
        --model-path meta-llama/Llama-3.1-8B-Instruct --port 30000

2. Regenerate answers:

    python -m nemo_automodel.components.speculative.regenerate \
        --input-data Aeala/ShareGPT_Vicuna_unfiltered \
        --output-dir ./regenerated/sharegpt_llama31_8b \
        --target-server http://localhost:30000/v1 \
        --model meta-llama/Llama-3.1-8B-Instruct \
        --concurrency 64 --shard-size 1000

The script is resumable: re-running with the same --output-dir plus --resume skips any shards that are already on disk and verifies, via a manifest, that the input, model, and sharding configuration match the earlier run.
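The shard-skipping part of a resume can be pictured with a short sketch. This is an illustration only, not the module's code: `plan_shards` is a hypothetical helper, and it assumes that shards are consecutive fixed-size slices of the input and that shard indices start at zero.

```python
from typing import Iterator


def plan_shards(
    num_samples: int, shard_size: int, existing: set[int]
) -> Iterator[tuple[int, range]]:
    """Yield (shard_index, sample_range) for every shard that still needs work."""
    num_shards = (num_samples + shard_size - 1) // shard_size  # ceiling division
    for shard_index in range(num_shards):
        if shard_index in existing:
            continue  # shard already on disk: skipped on resume
        start = shard_index * shard_size
        stop = min(start + shard_size, num_samples)
        yield shard_index, range(start, stop)


# 2500 samples in shards of 1000, with shard 0 already written by an earlier run.
pending = list(plan_shards(num_samples=2500, shard_size=1000, existing={0}))
print(pending)
```

Only shards 1 and 2 remain, and the last shard is shorter than `shard_size` because the input does not divide evenly.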

Module Contents

Classes

| Name | Description |
| --- | --- |
| GenerationConfig | Sampling parameters forwarded to the SGLang chat completion endpoint. |

Functions

| Name | Description |
| --- | --- |
| _build_manifest | Return the regeneration settings that must stay stable across resume. |
| _build_parser | - |
| _chat_completion | POST payload to url and return the assistant message dict, with bounded retries. |
| _ensure_manifest_compatible | Guard --resume against silently mixing shards from different runs. |
| _existing_shard_indices | Return the set of shard indices already present in output_dir. |
| _extract_prompt_messages | Return messages truncated so its tail is not an assistant turn. |
| _import_aiohttp | - |
| _import_pyarrow_table | - |
| _iter_samples | Yield rows’ messages_column from an HF dataset or a list of dicts. |
| _manifest_path | Return the manifest path inside output_dir. |
| _process_shard | Run a single shard’s prompts through the target server with bounded concurrency. |
| _regenerate_one | Call the target server once and return prompt + [assistant]. |
| _run | Async driver: load dataset, regenerate, write shards. Returns a process exit code. |
| _validate_args | Reject invalid CLI values before any network or disk work starts. |
| _write_manifest | Persist the current regeneration config for future --resume checks. |
| _write_shard | Write a shard atomically (.tmp then os.replace) so partial writes never linger. |
| main | CLI entry point. Parses argv and returns the process exit code. |

Data

_AIOHTTP_INSTALL_HINT

_MANIFEST_NAME

_PYARROW_INSTALL_HINT

_SHARD_NAME_RE

logger

API

class nemo_automodel.components.speculative.regenerate.GenerationConfig(
model: str,
max_new_tokens: int,
temperature: float,
top_p: float,
reasoning: str = 'none'
)
Dataclass

Sampling parameters forwarded to the SGLang chat completion endpoint.

max_new_tokens: int
model: str
reasoning: str = 'none'
temperature: float
top_p: float

nemo_automodel.components.speculative.regenerate._build_manifest(
args: argparse.Namespace
) -> dict[str, typing.Any]

Return the regeneration settings that must stay stable across resume.

Fields that change the content of the output dataset are included. Fields that only affect throughput / reliability (concurrency, timeout_s, max_retries) are intentionally omitted so a user can resume with different operational knobs. output_dir is also omitted: the manifest lives inside output_dir, so encoding it here would only break resume after a directory rename.
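The include/omit split described above might look roughly like the sketch below. It is a guess at the shape, not the actual implementation: `build_manifest_sketch` is a hypothetical name, the attribute names are inferred from the CLI flags in the usage example, excluding `resume` is an assumption, and the real manifest may hold more or differently named fields.

```python
import argparse
from typing import Any

# Operational knobs named in the docstring, plus output_dir. Treating `resume`
# the same way is an assumption made for this sketch.
_OPERATIONAL_KEYS = {"concurrency", "timeout_s", "max_retries", "output_dir", "resume"}


def build_manifest_sketch(args: argparse.Namespace) -> dict[str, Any]:
    """Keep only the settings that influence the content of the output dataset."""
    return {k: v for k, v in sorted(vars(args).items()) if k not in _OPERATIONAL_KEYS}


args = argparse.Namespace(
    input_data="Aeala/ShareGPT_Vicuna_unfiltered",
    model="meta-llama/Llama-3.1-8B-Instruct",
    shard_size=1000,
    concurrency=64,       # throughput only: left out of the manifest
    timeout_s=600.0,      # illustrative value
    max_retries=3,        # illustrative value
    output_dir="./regenerated/sharegpt_llama31_8b",
    resume=False,
)
manifest = build_manifest_sketch(args)
print(manifest)
```

Because the operational keys never reach the manifest, a second run with, say, a different concurrency produces an identical manifest and therefore passes the resume check.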

nemo_automodel.components.speculative.regenerate._build_parser() -> argparse.ArgumentParser

nemo_automodel.components.speculative.regenerate._chat_completion(
session,
url: str,
payload: dict[str, typing.Any],
timeout_s: float,
max_retries: int
) -> dict[str, typing.Any]
async

POST payload to url and return the assistant message dict, with bounded retries.
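"Bounded retries" can be sketched independently of the HTTP client. In the example below the actual POST is replaced by an injected `send` coroutine so the block stays runnable without aiohttp. The name `post_with_retries`, the exponential backoff, and the reading of `max_retries` as "retries after the first attempt" are all assumptions of this sketch, not documented behavior.

```python
import asyncio
from typing import Any, Awaitable, Callable, Optional


async def post_with_retries(
    send: Callable[[], Awaitable[dict[str, Any]]],
    max_retries: int,
    base_delay_s: float = 0.5,
) -> dict[str, Any]:
    """Await send() up to max_retries + 1 times, backing off between failures."""
    last_error: Optional[Exception] = None
    for attempt in range(max_retries + 1):
        try:
            return await send()
        except Exception as exc:  # a real client would catch narrower error types
            last_error = exc
            if attempt < max_retries:
                await asyncio.sleep(base_delay_s * 2**attempt)  # assumed backoff
    raise RuntimeError(f"request failed after {max_retries + 1} attempts") from last_error


calls = {"n": 0}


async def flaky() -> dict[str, Any]:
    """Stand-in for the POST: fails twice, then returns an assistant message."""
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("transient failure")
    return {"role": "assistant", "content": "hello"}


message = asyncio.run(post_with_retries(flaky, max_retries=3, base_delay_s=0.0))
print(message, calls["n"])
```

The loop is bounded, so a permanently failing server surfaces as an exception instead of stalling the shard indefinitely.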

nemo_automodel.components.speculative.regenerate._ensure_manifest_compatible(
output_dir: pathlib.Path,
manifest: dict[str, typing.Any],
resume: bool,
existing_shards: set[int]
) -> None

Guard --resume against silently mixing shards from different runs.

It also refuses to start a fresh run that would silently clobber existing shards: if the output directory already contains shard files and the user did not pass --resume, it raises so that the user must make an explicit choice (either delete the directory or pass --resume).
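The two guards described here can be sketched as follows. This is a simplified illustration: `ensure_compatible_sketch` is a hypothetical name, the exception type is an assumption (the docstring only says the function raises), and tolerating a missing manifest on resume is likewise assumed.

```python
import json
import tempfile
from pathlib import Path
from typing import Any


def ensure_compatible_sketch(
    output_dir: Path, manifest: dict[str, Any], resume: bool, existing_shards: set[int]
) -> None:
    manifest_file = output_dir / "manifest.json"  # matches _MANIFEST_NAME
    if not resume:
        if existing_shards:
            # Fresh run over existing shards: force an explicit choice.
            raise RuntimeError("output_dir has shards; delete it or pass --resume")
        return
    if manifest_file.exists():
        previous = json.loads(manifest_file.read_text())
        if previous != manifest:
            raise RuntimeError("manifest mismatch: shards come from a different run")


def raises(fn) -> bool:
    """Return True if calling fn() raises RuntimeError."""
    try:
        fn()
    except RuntimeError:
        return True
    return False


with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp)
    (out / "manifest.json").write_text(json.dumps({"model": "a"}))
    same_ok = not raises(lambda: ensure_compatible_sketch(out, {"model": "a"}, True, {0}))
    mismatch = raises(lambda: ensure_compatible_sketch(out, {"model": "b"}, True, {0}))
    clobber = raises(lambda: ensure_compatible_sketch(out, {"model": "a"}, False, {0}))

print(same_ok, mismatch, clobber)
```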

nemo_automodel.components.speculative.regenerate._existing_shard_indices(
output_dir: pathlib.Path
) -> set[int]

Return the set of shard indices already present in output_dir.
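A minimal sketch of such a scan, reusing the pattern documented for `_SHARD_NAME_RE`. The function name `existing_shard_indices_sketch` and the handling of a missing directory are assumptions.

```python
import re
import tempfile
from pathlib import Path

# Same pattern as the documented _SHARD_NAME_RE constant.
SHARD_NAME_RE = re.compile(r"^shard-(\d{6})\.parquet$")


def existing_shard_indices_sketch(output_dir: Path) -> set[int]:
    """Collect the integer index of every file named shard-NNNNNN.parquet."""
    if not output_dir.is_dir():
        return set()  # assumption: a missing directory means no shards yet
    indices: set[int] = set()
    for entry in output_dir.iterdir():
        match = SHARD_NAME_RE.match(entry.name)
        if match:
            indices.add(int(match.group(1)))
    return indices


with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp)
    for name in (
        "shard-000000.parquet",
        "shard-000002.parquet",
        "shard-000003.parquet.tmp",  # leftover temp file: must not count
        "manifest.json",
    ):
        (out / name).touch()
    found = existing_shard_indices_sketch(out)

print(sorted(found))
```

Because the pattern is anchored at both ends, an unfinished temporary file or the manifest is never mistaken for a completed shard.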

nemo_automodel.components.speculative.regenerate._extract_prompt_messages(
messages: list[dict[str, typing.Any]]
) -> list[dict[str, typing.Any]] | None

Return messages truncated so its tail is not an assistant turn.

EAGLE-3 supervision needs an assistant turn produced by the target model. The strategy here mirrors SpecForge’s offline regeneration: keep every leading system / user / tool turn (including any intermediate user<->assistant rounds), but drop the trailing assistant turn so the target can produce a fresh one.

Returns None if the sample has no valid prompt context (e.g. it is empty, or starts with an assistant turn that gets dropped, leaving nothing). Callers should skip such samples.
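The truncation rule can be illustrated with a few lines of Python. This is a simplified sketch under stated assumptions: `extract_prompt_sketch` is a hypothetical name, it pops every consecutive trailing assistant turn, and it performs none of the additional validation the real function may apply.

```python
from typing import Any, Optional


def extract_prompt_sketch(
    messages: list[dict[str, Any]],
) -> Optional[list[dict[str, Any]]]:
    """Drop trailing assistant turns; return None if nothing is left."""
    prompt = list(messages)  # copy so the caller's list is not mutated
    while prompt and prompt[-1].get("role") == "assistant":
        prompt.pop()
    return prompt or None


conversation = [
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},   # intermediate round: kept
    {"role": "user", "content": "What is EAGLE?"},
    {"role": "assistant", "content": "(answer from some other model)"},  # dropped
]
prompt = extract_prompt_sketch(conversation)
print([m["role"] for m in prompt])
print(extract_prompt_sketch([{"role": "assistant", "content": "orphan"}]))
```

The intermediate assistant turn survives because it is part of the context the target model needs; only the tail is regenerated.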

nemo_automodel.components.speculative.regenerate._import_aiohttp()

nemo_automodel.components.speculative.regenerate._import_pyarrow_table()

nemo_automodel.components.speculative.regenerate._iter_samples(
dataset: typing.Any,
messages_column: str
) -> typing.Any

Yield rows’ messages_column from an HF dataset or a list of dicts.

nemo_automodel.components.speculative.regenerate._manifest_path(
output_dir: pathlib.Path
) -> pathlib.Path

Return the manifest path inside output_dir.

nemo_automodel.components.speculative.regenerate._process_shard(
session,
url: str,
shard_samples: list[tuple[int, list[dict[str, typing.Any]], list[dict[str, typing.Any]]]],
gen_cfg: nemo_automodel.components.speculative.regenerate.GenerationConfig,
concurrency: int,
timeout_s: float,
max_retries: int
) -> list[dict[str, typing.Any]]
async

Run a single shard’s prompts through the target server with bounded concurrency.

shard_samples items are (global_index, original_messages, prompt_messages); only the prompt is sent to the server, but both are kept around so the written rows can preserve the original for traceability.
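One common way to bound concurrency in asyncio is a semaphore around each request, sketched below. Whether the module uses a semaphore is not stated in this reference; `process_shard_sketch`, the injected `regenerate` callable, and the output column names `index` and `original_messages` are assumptions (only `messages` is named in the documentation).

```python
import asyncio
from typing import Any, Awaitable, Callable

Messages = list[dict[str, Any]]


async def process_shard_sketch(
    shard_samples: list[tuple[int, Messages, Messages]],
    regenerate: Callable[[Messages], Awaitable[Messages]],
    concurrency: int,
) -> list[dict[str, Any]]:
    semaphore = asyncio.Semaphore(concurrency)

    async def worker(index: int, original: Messages, prompt: Messages) -> dict[str, Any]:
        async with semaphore:  # at most `concurrency` requests in flight
            regenerated = await regenerate(prompt)  # only the prompt is sent
        return {"index": index, "messages": regenerated, "original_messages": original}

    # gather() returns results in input order, regardless of completion order.
    return list(await asyncio.gather(*(worker(*s) for s in shard_samples)))


stats = {"in_flight": 0, "peak": 0}


async def fake_regenerate(prompt: Messages) -> Messages:
    """Stand-in for the server call that records peak parallelism."""
    stats["in_flight"] += 1
    stats["peak"] = max(stats["peak"], stats["in_flight"])
    await asyncio.sleep(0.01)
    stats["in_flight"] -= 1
    return prompt + [{"role": "assistant", "content": "ok"}]


user_turn = [{"role": "user", "content": "hi"}]
old_answer = [{"role": "assistant", "content": "old"}]
samples = [(i, user_turn + old_answer, user_turn) for i in range(5)]
rows = asyncio.run(process_shard_sketch(samples, fake_regenerate, concurrency=2))
print(len(rows), stats["peak"])
```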

nemo_automodel.components.speculative.regenerate._regenerate_one(
session,
url: str,
prompt: list[dict[str, typing.Any]],
gen_cfg: nemo_automodel.components.speculative.regenerate.GenerationConfig,
timeout_s: float,
max_retries: int
) -> list[dict[str, typing.Any]]
async

Call the target server once and return prompt + [assistant].
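A sketch of how the prompt and a `GenerationConfig` could be combined into a request. The payload keys are the standard OpenAI chat-completion fields; how the module maps `reasoning` onto the request is not documented here, so the sketch leaves it out. `build_payload`, `regenerate_one_sketch`, and the injected `complete` callable are hypothetical names.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, Awaitable, Callable

Messages = list[dict[str, Any]]


@dataclass
class GenerationConfig:
    """Local copy of the documented fields, so the sketch is self-contained."""
    model: str
    max_new_tokens: int
    temperature: float
    top_p: float
    reasoning: str = "none"


def build_payload(prompt: Messages, gen_cfg: GenerationConfig) -> dict[str, Any]:
    # Assumption: standard OpenAI-style keys; `reasoning` is deliberately omitted.
    return {
        "model": gen_cfg.model,
        "messages": prompt,
        "max_tokens": gen_cfg.max_new_tokens,
        "temperature": gen_cfg.temperature,
        "top_p": gen_cfg.top_p,
    }


async def regenerate_one_sketch(
    complete: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]],
    prompt: Messages,
    gen_cfg: GenerationConfig,
) -> Messages:
    assistant = await complete(build_payload(prompt, gen_cfg))
    return prompt + [{"role": "assistant", "content": assistant["content"]}]


async def fake_complete(payload: dict[str, Any]) -> dict[str, Any]:
    """Stand-in for the server: echoes the model name from the payload."""
    return {"role": "assistant", "content": f"reply from {payload['model']}"}


cfg = GenerationConfig(model="demo-model", max_new_tokens=64, temperature=0.0, top_p=1.0)
result = asyncio.run(
    regenerate_one_sketch(fake_complete, [{"role": "user", "content": "hi"}], cfg)
)
print(result)
```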

nemo_automodel.components.speculative.regenerate._run(
args: argparse.Namespace
) -> int
async

Async driver: load dataset, regenerate, write shards. Returns a process exit code.

nemo_automodel.components.speculative.regenerate._validate_args(
args: argparse.Namespace
) -> None

Reject invalid CLI values before any network or disk work starts.
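The exact checks are not listed in this reference, so the following is only a plausible sketch of fail-fast validation. The specific bounds, the use of `ValueError`, and the name `validate_args_sketch` are assumptions; the attribute names are taken from elsewhere on this page.

```python
import argparse


def validate_args_sketch(args: argparse.Namespace) -> None:
    """Raise before any network or disk work if a numeric option is unusable."""
    if args.concurrency < 1:
        raise ValueError("concurrency must be >= 1")
    if args.shard_size < 1:
        raise ValueError("shard_size must be >= 1")
    if args.timeout_s <= 0:
        raise ValueError("timeout_s must be > 0")
    if args.max_retries < 0:
        raise ValueError("max_retries must be >= 0")


good = argparse.Namespace(concurrency=64, shard_size=1000, timeout_s=600.0, max_retries=3)
validate_args_sketch(good)  # passes silently

bad = argparse.Namespace(concurrency=0, shard_size=1000, timeout_s=600.0, max_retries=3)
try:
    validate_args_sketch(bad)
    rejected, reason = False, ""
except ValueError as exc:
    rejected, reason = True, str(exc)

print(rejected, reason)
```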

nemo_automodel.components.speculative.regenerate._write_manifest(
output_dir: pathlib.Path,
manifest: dict[str, typing.Any]
) -> pathlib.Path

Persist the current regeneration config for future --resume checks.

nemo_automodel.components.speculative.regenerate._write_shard(
output_dir: pathlib.Path,
shard_index: int,
rows: list[dict[str, typing.Any]]
) -> pathlib.Path

Write a shard atomically (.tmp then os.replace) so partial writes never linger.
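The write-then-rename pattern can be sketched with the standard library alone. To avoid a pyarrow dependency, the sketch serializes rows as JSON where the real function writes parquet, so the file contents are a stand-in; the `.tmp` suffix placement and the name `write_shard_sketch` are assumptions, while the `shard-NNNNNN.parquet` naming follows `_SHARD_NAME_RE`.

```python
import json
import os
import tempfile
from pathlib import Path
from typing import Any


def write_shard_sketch(
    output_dir: Path, shard_index: int, rows: list[dict[str, Any]]
) -> Path:
    output_dir.mkdir(parents=True, exist_ok=True)
    final_path = output_dir / f"shard-{shard_index:06d}.parquet"
    tmp_path = final_path.with_name(final_path.name + ".tmp")
    # Stand-in for the parquet writer: JSON keeps the sketch stdlib-only.
    tmp_path.write_text(json.dumps(rows))
    # os.replace renames atomically on the same filesystem, so readers see
    # either no shard or a complete one, never a half-written file.
    os.replace(tmp_path, final_path)
    return final_path


with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "regenerated"
    rows = [{"messages": [{"role": "user", "content": "hi"}]}]
    path = write_shard_sketch(out, 7, rows)
    shard_name = path.name
    leftovers = [p.name for p in out.iterdir() if p.name.endswith(".tmp")]
    round_trip = json.loads(path.read_text())

print(shard_name, leftovers)
```

If the process dies mid-write, only the `.tmp` file is left behind, and it does not match the shard name pattern, so a later resume would regenerate that shard.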

nemo_automodel.components.speculative.regenerate.main(
argv: list[str] | None = None
) -> int

CLI entry point. Parses argv and returns the process exit code.

nemo_automodel.components.speculative.regenerate._AIOHTTP_INSTALL_HINT = "aiohttp is required for regenerate.py. It is normally pulled in via the project...
nemo_automodel.components.speculative.regenerate._MANIFEST_NAME = 'manifest.json'
nemo_automodel.components.speculative.regenerate._PYARROW_INSTALL_HINT = 'pyarrow is required to write regenerated shards as parquet. Install it with `uv...
nemo_automodel.components.speculative.regenerate._SHARD_NAME_RE = re.compile('^shard-(\\d{6})\\.parquet$')
nemo_automodel.components.speculative.regenerate.logger = logging.getLogger(__name__)