nemo_automodel.components.speculative.eagle.sglang_runner

SGLang ModelRunner forward for the EAGLE-3 target (server-side, GPU only).

This module owns every SGLang-internal touch point so the rest of the speculative stack stays SGLang-agnostic and importable without SGLang. It is imported lazily (only from :meth:SGLangEagle3TargetModel.from_pretrained).
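The lazy-import arrangement described above can be sketched as follows. This is a minimal illustration, not the project's code: `LazyBackendModel` and `_backend_module_name` are hypothetical names, and the stdlib `json` module stands in for the SGLang-dependent runner module so the sketch runs anywhere.

```python
import importlib


class LazyBackendModel:
    """Hypothetical stand-in for a class that defers a heavy backend import."""

    # Stand-in module name (assumption); the real code would name the
    # SGLang-dependent runner module here instead of a stdlib module.
    _backend_module_name = "json"

    def __init__(self, backend):
        self.backend = backend

    @classmethod
    def from_pretrained(cls):
        # The import runs only when this method is called, so importing the
        # enclosing module never requires the backend to be installed.
        backend = importlib.import_module(cls._backend_module_name)
        return cls(backend)


instance = LazyBackendModel.from_pretrained()
```

Keeping the import inside the factory method means a missing backend surfaces as an `ImportError` at construction time rather than at package import time.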

Mechanism (mirrors SpecForge’s SGLang backend, which is the only path that returns the supervision tensors directly without a Mooncake transfer layer):

  1. Build an SGLang ModelRunner with enable_return_hidden_states=True.
  2. Wrap the model’s LogitsProcessor so a single extend forward returns all-position full-vocab logits (stock SGLang only keeps the last position) alongside the three concatenated EAGLE-3 auxiliary hidden states.
  3. Run one extend per request and stack the per-row results into batched [batch, seq, *] tensors for :class:SGLangEagle3TargetModel.

Unlike SpecForge, which embeds the target inside the training job and reuses the trainer’s TP process group, this module runs in a standalone server process. It therefore performs SGLang’s own single-process distributed init rather than reusing an external group, and it drops SpecForge’s shard_returns / VLM paths, which only matter inside the training loop.

SGLang’s private LogitsProcessor helpers are version-coupled: the calls here track sglang==0.5.9 (SpecForge’s pin). This forward path requires a GPU and SGLang, so it is validated on the training server, not in CPU unit tests; the CPU tests exercise the contract layer in :mod:nemo_automodel.components.speculative.eagle.sglang_target against a fake runner instead.
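To illustrate what a CPU-side fake runner might look like, here is a minimal sketch that exposes the same four members as the runner surface (model, set_aux_layers, forward_eagle3, input_embedding_weight). `FakeTargetRunner`, its constructor arguments, and the zero-filled outputs are assumptions made for illustration; this is not the fake used by the project's tests.

```python
from typing import Sequence, Tuple

import torch
from torch import nn


class FakeTargetRunner:
    """Hypothetical CPU stand-in exposing the four-member runner surface."""

    def __init__(self, vocab: int = 32, hidden: int = 8):
        self.vocab = vocab
        self.hidden = hidden
        # `model` only has to be an nn.Module here; a bare embedding suffices.
        self.model = nn.Embedding(vocab, hidden)
        self.aux_layer_ids = []

    def set_aux_layers(self, aux_layer_ids: Sequence[int]) -> None:
        # Record the requested layers; a fake has no real layers to hook.
        self.aux_layer_ids = list(aux_layer_ids)

    def forward_eagle3(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch, seq = input_ids.shape
        # Zero tensors with the documented shapes:
        # logits [batch, seq, vocab] and aux [batch, seq, 3 * hidden].
        logits = torch.zeros(batch, seq, self.vocab)
        aux = torch.zeros(batch, seq, 3 * self.hidden)
        return logits, aux

    def input_embedding_weight(self) -> torch.Tensor:
        return self.model.weight  # [vocab, hidden]
```

A fake of this kind lets shape and plumbing checks run without a GPU or SGLang, which is the split between contract-layer tests and server-side validation described above.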

Module Contents

Classes

| Name | Description |
| --- | --- |
| SGLangTargetRunner | Standalone SGLang ModelRunner that returns EAGLE-3 supervision tensors. |
| _Eagle3LogitsOutput | Carries the all-position logits + aux hidden states out of the wrapper. |

Functions

| Name | Description |
| --- | --- |
| _wrap_logits_processors_for_eagle3 | Replace every SGLang LogitsProcessor in model with an EAGLE-3 wrapper. |
| sglang_dtype_str | Map a torch dtype to the string form SGLang’s ServerArgs.dtype expects. |

Data

_SGLANG_DTYPE_STRINGS

logger

API

class nemo_automodel.components.speculative.eagle.sglang_runner.SGLangTargetRunner(
model_runner
)

Standalone SGLang ModelRunner that returns EAGLE-3 supervision tensors.

Built via :meth:build; consumed through the engine-agnostic :class:~nemo_automodel.components.speculative.eagle.target_runner.TargetRunner surface (model / set_aux_layers / forward_eagle3 / input_embedding_weight).

model

The loaded nn.Module (exposes .config and .parameters()).

nemo_automodel.components.speculative.eagle.sglang_runner.SGLangTargetRunner._extend(
input_ids: torch.Tensor
) -> tuple[list, list]

nemo_automodel.components.speculative.eagle.sglang_runner.SGLangTargetRunner.build(
model_path: str,
dtype: typing.Optional[torch.dtype] = None,
tp_size: int = 1,
trust_remote_code: bool = False,
sglang_kwargs = {}
) -> 'SGLangTargetRunner'
classmethod

Construct the SGLang ModelRunner for a standalone target server.

sglang_kwargs are forwarded to ServerArgs (e.g. page_size, mem_fraction_static, attention_backend). The constructor mirrors SpecForge’s sglang==0.5.9 usage and is GPU/SGLang-only.
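A possible end-to-end call sequence, using only the methods documented on this page, might look like the sketch below. The helper name `collect_supervision` is hypothetical, the choice of bfloat16 is illustrative, and the caller is assumed to supply valid aux layer ids and to place `input_ids` on the right device. The function is defined but not invoked here, since running it needs a GPU and SGLang.

```python
import torch


def collect_supervision(model_path: str, input_ids: torch.Tensor, aux_layer_ids):
    """Hypothetical helper: build a runner, run one batch, release it."""
    # Imported inside the function because the module requires SGLang and a GPU.
    from nemo_automodel.components.speculative.eagle.sglang_runner import (
        SGLangTargetRunner,
    )

    runner = SGLangTargetRunner.build(model_path, dtype=torch.bfloat16, tp_size=1)
    try:
        # Choose the decoder layers to capture before the forward pass.
        runner.set_aux_layers(aux_layer_ids)
        # Assumption: every position is a real token, so the mask is all ones.
        attention_mask = torch.ones_like(input_ids)
        return runner.forward_eagle3(input_ids, attention_mask)
    finally:
        # close() is documented as best effort; call it even on failure.
        runner.close()
```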

nemo_automodel.components.speculative.eagle.sglang_runner.SGLangTargetRunner.close() -> None

Release the SGLang model runner (best effort).

nemo_automodel.components.speculative.eagle.sglang_runner.SGLangTargetRunner.forward_eagle3(
input_ids: torch.Tensor,
attention_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]

Run one extend per row and stack the per-position logits and aux states.

Returns (logits[batch, seq, vocab], aux[batch, seq, 3 * hidden]), both unshifted; the contract layer applies the EAGLE-3 shift. Sequences must share a length (training batches are padded), so the per-row results stack cleanly.
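The stacking step can be sketched as below, assuming each per-row result is a `[seq, vocab]` logits tensor and a `[seq, 3 * hidden]` aux tensor. The helper `stack_rows` and its explicit length check are illustrative assumptions, not the module's actual code.

```python
import torch


def stack_rows(per_row_logits, per_row_aux):
    """Stack per-request [seq, *] tensors into batched [batch, seq, *] tensors."""
    lengths = {t.shape[0] for t in per_row_logits} | {t.shape[0] for t in per_row_aux}
    if len(lengths) != 1:
        # torch.stack needs identical shapes, hence the equal-length requirement.
        raise ValueError(f"rows must share one sequence length, got {sorted(lengths)}")
    return torch.stack(per_row_logits, dim=0), torch.stack(per_row_aux, dim=0)


# Two rows, seq=4, vocab=10, hidden=2 (so the aux width is 3 * 2 = 6).
logits, aux = stack_rows(
    [torch.zeros(4, 10), torch.zeros(4, 10)],
    [torch.zeros(4, 6), torch.zeros(4, 6)],
)
```

Here `logits` has shape `[2, 4, 10]` and `aux` has shape `[2, 4, 6]`; rows of differing length raise instead of stacking.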

nemo_automodel.components.speculative.eagle.sglang_runner.SGLangTargetRunner.input_embedding_weight() -> torch.Tensor

Return the target input-embedding weight [vocab, hidden].

nemo_automodel.components.speculative.eagle.sglang_runner.SGLangTargetRunner.set_aux_layers(
aux_layer_ids: typing.Sequence[int]
) -> None

Tell the SGLang model which 3 decoder layers to capture.

class nemo_automodel.components.speculative.eagle.sglang_runner._Eagle3LogitsOutput(
logits: torch.Tensor,
aux_hidden_states: torch.Tensor
)

Carries the all-position logits + aux hidden states out of the wrapper.

nemo_automodel.components.speculative.eagle.sglang_runner._wrap_logits_processors_for_eagle3(
model
) -> None

Replace every SGLang LogitsProcessor in model with an EAGLE-3 wrapper.

The wrapper makes one extend forward return all-position full-vocab logits plus the concatenated auxiliary hidden states, instead of stock SGLang’s last-position-only logits. Ported (simplified, no tensor-parallel sharding) from SpecForge sglang_backend/utils.py for sglang==0.5.9.
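The general replace-by-type pattern behind such a wrapper can be sketched with plain PyTorch modules. Everything here is a stand-in: `StandInLogitsProcessor`, `AllPositionWrapper`, `wrap_all`, and `Toy` are hypothetical, the forward signatures are not SGLang's, and the sketch returns only logits (the real wrapper also carries the auxiliary hidden states).

```python
import torch
from torch import nn


class StandInLogitsProcessor(nn.Module):
    """Hypothetical stand-in that keeps only the last position."""

    def forward(self, hidden, lm_head_weight):
        return hidden[-1:] @ lm_head_weight.T  # [1, vocab]


class AllPositionWrapper(nn.Module):
    """Hypothetical wrapper that returns logits for every position."""

    def __init__(self, inner):
        super().__init__()
        self.inner = inner  # kept so the original processor stays reachable

    def forward(self, hidden, lm_head_weight):
        return hidden @ lm_head_weight.T  # [seq, vocab]


def wrap_all(model: nn.Module) -> None:
    # Collect (parent, attribute name) pairs first, then mutate, so the
    # module tree is not modified while it is being traversed.
    targets = [
        (parent, name)
        for parent in model.modules()
        for name, child in parent.named_children()
        if isinstance(child, StandInLogitsProcessor)
    ]
    for parent, name in targets:
        setattr(parent, name, AllPositionWrapper(getattr(parent, name)))


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.logits_processor = StandInLogitsProcessor()


toy = Toy()
wrap_all(toy)
out = toy.logits_processor(torch.ones(5, 4), torch.ones(7, 4))
```

After wrapping, `out` covers all five positions (shape `[5, 7]`) where the stand-in alone would have returned a single row.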

nemo_automodel.components.speculative.eagle.sglang_runner.sglang_dtype_str(
dtype: typing.Optional[torch.dtype]
) -> str

Map a torch dtype to the string form SGLang’s ServerArgs.dtype expects.

SGLang compares ServerArgs.dtype against string literals ("auto", "bfloat16", …), so passing a raw torch.dtype silently misses every branch. None means “let SGLang pick” ("auto").
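A minimal sketch of such a mapping is shown below. The function name `dtype_to_sglang_str` and the table name `_DTYPE_STRINGS` are illustrative, and raising `ValueError` for an unlisted dtype is an assumption; this page does not state how the real function handles that case.

```python
from typing import Optional

import torch

# Same three entries as the module-level mapping documented on this page.
_DTYPE_STRINGS = {
    torch.float32: "float32",
    torch.float16: "float16",
    torch.bfloat16: "bfloat16",
}


def dtype_to_sglang_str(dtype: Optional[torch.dtype]) -> str:
    """Illustrative mapping from a torch dtype to a dtype string."""
    if dtype is None:
        return "auto"  # None means "let SGLang pick"
    try:
        return _DTYPE_STRINGS[dtype]
    except KeyError:
        # Assumption: fail loudly rather than pass an unknown value through.
        raise ValueError(f"unsupported dtype: {dtype}") from None
```

Converting once at the boundary keeps a raw `torch.dtype` from reaching string comparisons, where it would never equal any of the literals.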

nemo_automodel.components.speculative.eagle.sglang_runner._SGLANG_DTYPE_STRINGS = {torch.float32: 'float32', torch.float16: 'float16', torch.bfloat16: 'bfloat16'}

nemo_automodel.components.speculative.eagle.sglang_runner.logger = logging.getLogger(__name__)