nemo_automodel.components.speculative.eagle.backend
Backend abstraction for the EAGLE-3 target model.
The frozen target model is a supervision provider: for every training batch
it produces the auxiliary hidden states and the per-token target distribution
the draft model is trained against. EAGLE-3 training never updates the target,
so it does not have to share the training GPU. This interface lets the recipe
consume the target uniformly whether it runs co-located in-process
(HFEagle3TargetModel) or, in a later change, as a remote inference service
on separate GPUs.
Module Contents
Classes
API
Abstract contract every EAGLE-3 target-model backend implements.
Two supervision encodings are allowed, both consumed directly by
:meth:Eagle3TrainerModule.forward:
- full logits: :attr:Eagle3TargetBatch.logits carries the target’s full-vocab logits, and the draft-vocab projection happens trainer-side. Cheap when co-located (the tensor never leaves the GPU); impractical to ship over a wire because it is full-vocab sized.
- precomputed: :attr:Eagle3TargetBatch.target_probs and :attr:Eagle3TargetBatch.position_mask carry the already-projected draft-vocab distribution, so a backend that computes them itself (e.g. a remote server) only has to transfer draft-vocab-sized tensors.
A backend returns exactly one of the two encodings from
:meth:generate_batch; the recipe forwards whichever is present.
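The "exactly one encoding" rule can be illustrated with a small check. This is a minimal sketch, not the library's code: TargetBatchSketch and supervision_encoding are hypothetical names, and only the three field names (logits, target_probs, position_mask) come from the text above. Whether the real backend or trainer validates batches this way is an assumption.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class TargetBatchSketch:
    """Hypothetical stand-in for Eagle3TargetBatch (supervision fields only)."""

    logits: Optional[Any] = None          # full-vocab logits (co-located case)
    target_probs: Optional[Any] = None    # draft-vocab distribution (precomputed case)
    position_mask: Optional[Any] = None   # accompanies target_probs


def supervision_encoding(batch: TargetBatchSketch) -> str:
    """Return which encoding a batch carries, rejecting ambiguous batches."""
    # The precomputed encoding needs both of its tensors.
    if (batch.target_probs is None) != (batch.position_mask is None):
        raise ValueError("target_probs and position_mask must be set together")
    has_logits = batch.logits is not None
    has_precomputed = batch.target_probs is not None
    # Both set, or neither set, violates the "exactly one" rule.
    if has_logits == has_precomputed:
        raise ValueError("expected exactly one supervision encoding")
    return "full_logits" if has_logits else "precomputed"
```

A consumer could branch on the returned string to decide whether the draft-vocab projection still has to run trainer-side.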
Whether :meth:generate_batch_async is implemented (prefetch-capable).
Release backend resources (remote connections, server handles).
Run the target and return the supervision for one training batch.
Submit an asynchronous :meth:generate_batch for prefetch pipelining.
Only backends that overlap target inference with draft training
implement this; the default signals a synchronous backend so callers
fall back to :meth:generate_batch.
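The fallback described above might look like the following on the caller side. This is a sketch under stated assumptions: the method names generate_batch and generate_batch_async come from this page, while the supports_async flag, the close method, the two backend classes, and the submit helper are hypothetical, and the use of concurrent.futures.Future as the handle type is an assumption about the real return type.

```python
from concurrent.futures import Future, ThreadPoolExecutor


class SyncBackendSketch:
    """Hypothetical synchronous backend: no prefetch support."""

    supports_async = False  # hypothetical name for the capability flag

    def generate_batch(self, batch):
        return {"source": "sync", "batch": batch}

    def generate_batch_async(self, batch):
        # Assumed default: signal that the backend is synchronous.
        raise NotImplementedError("synchronous backend")


class AsyncBackendSketch(SyncBackendSketch):
    """Hypothetical backend that overlaps target inference with training."""

    supports_async = True

    def __init__(self):
        self._pool = ThreadPoolExecutor(max_workers=1)

    def generate_batch(self, batch):
        return {"source": "async", "batch": batch}

    def generate_batch_async(self, batch) -> Future:
        return self._pool.submit(self.generate_batch, batch)

    def close(self):
        # Hypothetical resource-release hook.
        self._pool.shutdown(wait=True)


def submit(backend, batch) -> Future:
    """Return a Future either way so the training loop has one code path."""
    if backend.supports_async:
        return backend.generate_batch_async(batch)
    done = Future()  # already-resolved handle wrapping the synchronous result
    done.set_result(backend.generate_batch(batch))
    return done
```

Wrapping the synchronous result in an already-resolved Future is one design choice among several; a caller could equally branch on the flag and call generate_batch directly.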
Return the target input-embedding module (used to seed the draft).
Provide the draft-vocab mapping needed to precompute supervision.
Co-located backends keep the mapping on the trainer module and derive
the distribution there, so the default is a no-op. A backend that
computes target_probs itself overrides this to receive the mapping.