nemo_automodel.components.speculative.dflash.target
Target-model wrapper for DFlash training.
Unlike EAGLE-3 (which captures exactly three aux layers and left-shifts the supervision), DFlash captures an arbitrary set of decoder layers, concatenates their hidden states along the feature dimension, and feeds the result to the draft model as context. No shifting is applied; the DFlash block attention mask handles anchor alignment instead.
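The concatenation step described above can be sketched as follows. This is a minimal illustration with random tensors, not the module's actual code: the shapes, the layer indices, and the names `captured`, `capture_layers`, and `context` are all assumptions chosen for the example.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
batch, seq_len, hidden = 2, 5, 8
capture_layers = [1, 3, 6]  # any set of decoder-layer indices, not limited to three

# Stand-ins for the hidden states captured from each selected layer.
captured = {i: torch.randn(batch, seq_len, hidden) for i in capture_layers}

# Concatenate along the feature (last) dimension, in layer order.
# The sequence dimension is left untouched: no shifting is applied.
context = torch.cat([captured[i] for i in capture_layers], dim=-1)

print(context.shape)  # torch.Size([2, 5, 24])
```

The context width grows linearly with the number of captured layers (here 3 x 8 = 24), while the sequence length stays the same as the target model's input.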
Module Contents
Classes
API
Target-model context features needed by the DFlash trainer.
Capture a set of decoder-layer hidden states from a frozen HF causal LM.
A forward hook on decoder layer i captures that layer’s output, which in HuggingFace’s output_hidden_states convention is hidden_states[i + 1] (offset 1), matching SpecForge’s extract_context_feature.
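The offset-by-one relationship can be checked on a toy stack. The sketch below is not the wrapper's implementation: `ToyDecoder` is a hypothetical stand-in for a causal LM, and its returned list only mimics the output_hidden_states layout (index 0 is the embedding output, index i + 1 is the output of layer i). The hooks use PyTorch's `register_forward_hook`; in a real model a decoder layer may return a tuple whose first element is the hidden state, so the hook handles both cases.

```python
import torch
from torch import nn

torch.manual_seed(0)


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for a decoder stack (not a HuggingFace model)."""

    def __init__(self, vocab=16, hidden=8, num_layers=4):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_layers)])

    def forward(self, input_ids):
        h = self.embed(input_ids)
        hidden_states = [h]  # index 0: embedding output
        for layer in self.layers:
            h = layer(h)
            hidden_states.append(h)  # index i + 1: output of layer i
        return hidden_states


model = ToyDecoder().eval()
capture_layers = [0, 2]
captured = {}


def make_hook(idx):
    def hook(module, inputs, output):
        # Some decoder layers return a tuple; the hidden state is then element 0.
        out = output[0] if isinstance(output, tuple) else output
        captured[idx] = out.detach()
    return hook


handles = [model.layers[i].register_forward_hook(make_hook(i)) for i in capture_layers]
with torch.no_grad():
    hidden_states = model(torch.tensor([[1, 2, 3]]))
for handle in handles:
    handle.remove()  # always remove hooks once the capture is done

for i in capture_layers:
    assert torch.equal(captured[i], hidden_states[i + 1])
```

Capturing through hooks avoids materializing every layer's hidden state when only a few are needed, which is one plausible reason to prefer hooks over requesting all hidden states.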
Return decoder layers as an ordered, integer-indexable list.
Run the target model and capture the selected layers’ hidden states as context.
Return the target model input embeddings.