nemo_automodel.components.speculative.eagle.sglang_runner
SGLang ModelRunner forward for the EAGLE-3 target (server-side, GPU only).
This module owns every SGLang-internal touch point so the rest of the
speculative stack stays SGLang-agnostic and importable without SGLang. It is
imported lazily (only from :meth:SGLangEagle3TargetModel.from_pretrained).
Mechanism (mirrors SpecForge’s SGLang backend, which is the only path that returns the supervision tensors directly without a Mooncake transfer layer):
- Build an SGLang ModelRunner with enable_return_hidden_states=True.
- Wrap the model’s LogitsProcessor so a single extend forward returns all-position full-vocab logits (stock SGLang only keeps the last position) alongside the three concatenated EAGLE-3 auxiliary hidden states.
- Run one extend per request and stack the per-row results into batched [batch, seq, *] tensors for :class:SGLangEagle3TargetModel.
Unlike SpecForge (which embeds the target inside the training job and reuses the
trainer’s TP process group), this runs in a standalone server process, so it
performs SGLang’s own single-process distributed init rather than reusing an
external group, and it drops SpecForge’s shard_returns / VLM paths that only
matter inside the training loop.
SGLang’s private LogitsProcessor helpers are version-coupled: the calls here
track sglang==0.5.9 (SpecForge’s pin). This forward path requires a GPU and
SGLang, so it is validated on the training server, not in CPU unit tests; the
CPU tests exercise the contract layer in
:mod:nemo_automodel.components.speculative.eagle.sglang_target against a fake
runner instead.
API
Standalone SGLang ModelRunner that returns EAGLE-3 supervision tensors.
Built via :meth:build; consumed through the engine-agnostic
:class:~nemo_automodel.components.speculative.eagle.target_runner.TargetRunner
surface (model / set_aux_layers / forward_eagle3 /
input_embedding_weight).
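The four-member surface named above can be sketched as a structural type with a CPU-only fake behind it, similar in spirit to the fake runner the CPU tests are said to use. This is a minimal sketch, not the project's code: the names TargetRunnerLike and FakeRunner, every signature, and the choice of model as an attribute and input_embedding_weight as a method are assumptions, and nested lists stand in for torch tensors.

```python
from typing import Any, List, Protocol, Sequence, Tuple, runtime_checkable

Tensor3D = List[List[List[float]]]  # nested lists stand in for torch tensors


@runtime_checkable
class TargetRunnerLike(Protocol):
    """Hypothetical mirror of the runner surface; signatures are assumed."""

    model: Any

    def set_aux_layers(self, layer_ids: Sequence[int]) -> None: ...

    def forward_eagle3(
        self, input_ids: Sequence[Sequence[int]]
    ) -> Tuple[Tensor3D, Tensor3D]: ...

    def input_embedding_weight(self) -> List[List[float]]: ...


class FakeRunner:
    """CPU-only stand-in returning zero-filled outputs of the documented shapes."""

    def __init__(self, vocab: int = 5, hidden: int = 2) -> None:
        self.model = object()  # placeholder for the loaded nn.Module
        self.vocab = vocab
        self.hidden = hidden
        self.aux_layers: List[int] = []

    def set_aux_layers(self, layer_ids: Sequence[int]) -> None:
        # The docs describe capturing three decoder layers.
        if len(layer_ids) != 3:
            raise ValueError("expected exactly three decoder layer indices")
        self.aux_layers = list(layer_ids)

    def forward_eagle3(self, input_ids):
        # logits: [batch, seq, vocab]; aux: [batch, seq, 3 * hidden]
        logits = [[[0.0] * self.vocab for _ in row] for row in input_ids]
        aux = [[[0.0] * (3 * self.hidden) for _ in row] for row in input_ids]
        return logits, aux

    def input_embedding_weight(self) -> List[List[float]]:
        # [vocab, hidden]
        return [[0.0] * self.hidden for _ in range(self.vocab)]
```

A fake like this lets shape-level contract checks run without a GPU or SGLang installed; it says nothing about the values a real forward would produce.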
The loaded nn.Module (exposes .config and .parameters()).
Construct the SGLang ModelRunner for a standalone target server.
sglang_kwargs are forwarded to ServerArgs (e.g. page_size,
mem_fraction_static, attention_backend). The constructor mirrors
SpecForge’s sglang==0.5.9 usage and is GPU/SGLang-only.
Release the SGLang model runner (best effort).
Run one extend per row and stack the per-position logits and aux states.
Returns (logits[batch, seq, vocab], aux[batch, seq, 3 * hidden]),
both unshifted; the contract layer applies the EAGLE-3 shift. Sequences
must share a length (training batches are padded), so the per-row
results stack cleanly.
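The per-row loop and the equal-length requirement can be illustrated with a small pure-Python sketch. It is an assumption-laden stand-in rather than the module's implementation: forward_rows and fake_extend are hypothetical names, extend_one replaces a real SGLang extend forward, and nested lists replace the tensors that the real code would stack.

```python
from typing import Callable, List, Sequence, Tuple

# One row's output: (logits[seq, vocab], aux[seq, 3 * hidden])
RowOut = Tuple[List[List[float]], List[List[float]]]


def forward_rows(
    batch_ids: Sequence[Sequence[int]],
    extend_one: Callable[[Sequence[int]], RowOut],
):
    """Run extend_one once per row and collect the results batch-first."""
    lengths = {len(ids) for ids in batch_ids}
    if len(lengths) > 1:
        # Ragged rows cannot form a rectangular [batch, seq, *] result.
        raise ValueError(f"rows must share one length, got {sorted(lengths)}")
    logits_rows, aux_rows = [], []
    for ids in batch_ids:
        logits, aux = extend_one(ids)
        logits_rows.append(logits)
        aux_rows.append(aux)
    # Both results are unshifted; any EAGLE-3 shift is left to the caller.
    return logits_rows, aux_rows


def fake_extend(ids: Sequence[int], vocab: int = 4, hidden: int = 2) -> RowOut:
    """Hypothetical stand-in for one extend forward; returns zeros."""
    logits = [[0.0] * vocab for _ in ids]
    aux = [[0.0] * (3 * hidden) for _ in ids]
    return logits, aux
```

Checking lengths before the loop fails fast on an unpadded batch instead of surfacing as a stacking error after the forwards have already run.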
Return the target input-embedding weight [vocab, hidden].
Tell the SGLang model which 3 decoder layers to capture.
Carries the all-position logits + aux hidden states out of the wrapper.
Replace every SGLang LogitsProcessor in model with an EAGLE-3 wrapper.
The wrapper makes one extend forward return all-position full-vocab logits
plus the concatenated auxiliary hidden states, instead of stock SGLang’s
last-position-only logits. Ported (simplified, no tensor-parallel sharding)
from SpecForge sglang_backend/utils.py for sglang==0.5.9.
Map a torch dtype to the string form SGLang’s ServerArgs.dtype expects.
SGLang compares ServerArgs.dtype against string literals ("auto",
"bfloat16", …), so passing a raw torch.dtype silently misses every
branch. None means “let SGLang pick” ("auto").
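A minimal sketch of such a mapping follows. The function name dtype_to_sglang_str is hypothetical and the module's actual helper may validate its input differently; the sketch relies only on the fact that a torch dtype prints as, for example, torch.bfloat16.

```python
from typing import Optional


def dtype_to_sglang_str(dtype: Optional[object]) -> str:
    """Return a string form such as "bfloat16", or "auto" for None."""
    if dtype is None:
        return "auto"  # let SGLang pick
    if isinstance(dtype, str):
        return dtype  # already in string form
    # str(torch.bfloat16) is "torch.bfloat16"; keep the part after the dot.
    return str(dtype).rsplit(".", 1)[-1]
```

Converting up front avoids the failure mode described above, where a raw torch.dtype never equals any of the string literals it is compared against.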