nemo_automodel.components.speculative.dspark.target
Target-model wrapper for DSpark training (online hidden-state capture).
DSpark feeds the draft model two tensors from the frozen target. The first is the concatenation of hidden states from a configured set of decoder layers (the draft's fc context). The second is the final post-norm hidden state, which is the input the target's lm_head consumes and is used by the TV / confidence losses. Both are captured in a single forward pass via forward hooks, mirroring the DFlash target wrapper.
Module Contents
API
Target-model features needed by the DSpark trainer.
Capture intermediate and final hidden states from a frozen HF causal LM.
A forward hook on decoder layer i captures hidden_states[i + 1]. The offset of one follows the HuggingFace output_hidden_states convention, in which index 0 holds the embedding output. A hook on the final norm captures the post-norm last hidden state.
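The offset-1 convention can be checked with a pure-Python sketch that mimics the Llama-style HuggingFace loop for `output_hidden_states=True`. This is an assumption: other architectures may build the tuple differently. Each hidden state is appended before its layer runs, and the final entry is the post-norm state. `forward_with_hidden_states`, the integer "layers", and `negate` (standing in for the final norm) are hypothetical.

```python
def forward_with_hidden_states(x, layers, norm):
    """Mimic an HF-style hidden_states tuple plus what per-layer hooks would record."""
    hidden_states = []
    layer_outputs = []  # what a forward hook on each layer would capture
    h = x  # embedding output, hidden_states[0]
    for layer in layers:
        hidden_states.append(h)  # state entering this layer
        h = layer(h)
        layer_outputs.append(h)
    h = norm(h)
    hidden_states.append(h)  # last entry is post-norm, not the raw last-layer output
    return tuple(hidden_states), layer_outputs


def negate(v):
    return -v  # toy "final norm"


layers = [lambda v, k=k: 2 * v + k for k in range(4)]
hidden_states, layer_outputs = forward_with_hidden_states(1, layers, negate)
print(hidden_states)  # (1, 2, 5, 12, -27)
print(layer_outputs)  # [2, 5, 12, 27]
```

For every layer except the last, the hook output equals `hidden_states[i + 1]`. The last entry differs from the last layer's raw output, which is why the post-norm state is taken from a hook on the final norm.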
Return the final norm module whose output feeds lm_head.
Return decoder layers as an ordered, integer-indexable list.
Return the base transformer module that owns the decoder layers and the final norm.
Handles the common nestings: a plain causal LM (model.model), a bare decoder-only base model (model itself), and multimodal targets whose text stack lives under language_model (e.g. Gemma4: model.model.language_model).
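A minimal sketch of the base-module lookup for the three nestings listed above. It assumes the text stack exposes attributes named `layers` and `norm`, and that candidates are tried from the deepest nesting outward. `resolve_base`, `decoder_layers`, and `final_norm` are hypothetical names, and `SimpleNamespace` objects stand in for real models.

```python
from types import SimpleNamespace


def resolve_base(model):
    """Return the module that owns `layers` and `norm` (assumed attribute names)."""
    inner = getattr(model, "model", None)
    candidates = [
        getattr(inner, "language_model", None),  # multimodal, e.g. model.model.language_model
        inner,                                   # plain causal LM: model.model
        model,                                   # bare decoder-only base: model itself
    ]
    for cand in candidates:
        if cand is not None and hasattr(cand, "layers") and hasattr(cand, "norm"):
            return cand
    raise AttributeError("could not locate decoder layers and final norm")


def decoder_layers(model):
    # Ordered, integer-indexable list of decoder layers.
    return list(resolve_base(model).layers)


def final_norm(model):
    # The norm whose output feeds lm_head.
    return resolve_base(model).norm


stack = SimpleNamespace(layers=["layer0", "layer1"], norm="final_norm")
plain = SimpleNamespace(model=stack)
multimodal = SimpleNamespace(model=SimpleNamespace(language_model=stack, vision_tower="vit"))
print(decoder_layers(multimodal), final_norm(plain))  # ['layer0', 'layer1'] final_norm
```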
Run the target model once and capture the DSpark context and the last hidden state.
Feature ids follow common.extract_context_feature exactly: -1 selects the embedding output, the final layer's id selects the post-norm hidden state (HF output_hidden_states[num_layers]) rather than that layer's raw output, and any other id selects that decoder layer's output. The final-norm output is also returned as the last hidden state for the TV / confidence losses.
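The feature-id rules above can be summarized as a lookup from an id to its capture source and its HF `output_hidden_states` index. This sketch assumes the final layer's id is `num_layers - 1`, as implied by the offset-1 convention. `resolve_feature` and the source labels are hypothetical and are not `common.extract_context_feature` itself.

```python
def resolve_feature(layer_id, num_layers):
    """Map a DSpark feature id to (capture source, HF output_hidden_states index)."""
    if layer_id == -1:
        return ("embedding", 0)  # embedding output
    if layer_id == num_layers - 1:
        return ("final_norm", num_layers)  # post-norm state, from the final-norm hook
    if 0 <= layer_id < num_layers - 1:
        return (f"layer{layer_id}", layer_id + 1)  # decoder layer output, offset by one
    raise ValueError(f"feature id {layer_id} out of range for {num_layers} layers")


print([resolve_feature(i, 4) for i in (-1, 0, 2, 3)])
# [('embedding', 0), ('layer0', 1), ('layer2', 3), ('final_norm', 4)]
```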
Return the target model's input embeddings.
Return the target model's output embeddings (lm_head).
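For context on how the captured last hidden state and the output embeddings fit together, here is a minimal sketch, assuming PyTorch. In `transformers`, `get_input_embeddings()` and `get_output_embeddings()` are standard `PreTrainedModel` methods. A plain `nn.Linear` stands in for the lm_head here so the example runs without `transformers`. The TV term shown is the textbook total variation distance, 0.5 times the sum of absolute probability differences. The exact DSpark TV / confidence loss is not described on this page, and `target_probs` and `total_variation` are hypothetical helpers.

```python
import torch
import torch.nn.functional as F
from torch import nn


def target_probs(lm_head, last_hidden):
    """Project post-norm hidden states through lm_head and normalize over the vocab."""
    return F.softmax(lm_head(last_hidden), dim=-1)


def total_variation(p, q):
    """Per-position TV distance: 0.5 * sum_v |p_v - q_v|."""
    return 0.5 * (p - q).abs().sum(dim=-1)


torch.manual_seed(0)
hidden, vocab = 8, 16
lm_head = nn.Linear(hidden, vocab, bias=False)  # stands in for get_output_embeddings()
last_hidden = torch.randn(2, 5, hidden)  # stands in for the captured final-norm output
draft_logits = torch.randn(2, 5, vocab)  # hypothetical draft predictions

with torch.no_grad():  # target side is frozen
    p = target_probs(lm_head, last_hidden)
q = F.softmax(draft_logits, dim=-1)
tv = total_variation(p, q)
print(tuple(tv.shape))  # (2, 5)
```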