nemo_automodel.components.speculative.dspark.target

Target-model wrapper for DSpark training (online hidden-state capture).

DSpark feeds the draft two things from the frozen target: the concatenation of a configured set of decoder-layer hidden states (the draft fc context), and the final post-norm hidden state (the input the target’s lm_head consumes, used by the TV / confidence losses). Both are captured in a single forward pass via forward hooks, mirroring the DFlash target wrapper.

Module Contents

Classes

| Name | Description |
| --- | --- |
| DSparkTargetBatch | Target-model features needed by the DSpark trainer. |
| HFDSparkTargetModel | Capture intermediate + final hidden states from a frozen HF causal LM. |

Data

__all__

API

class nemo_automodel.components.speculative.dspark.target.DSparkTargetBatch(
target_hidden_states: torch.Tensor,
target_last_hidden_states: torch.Tensor,
input_ids: torch.Tensor,
loss_mask: torch.Tensor
)
Dataclass

Target-model features needed by the DSpark trainer.

input_ids: Tensor
loss_mask: Tensor
target_hidden_states: Tensor
target_last_hidden_states: Tensor

class nemo_automodel.components.speculative.dspark.target.HFDSparkTargetModel(
model: torch.nn.Module,
target_layer_ids: typing.Sequence[int]
)

Capture intermediate + final hidden states from a frozen HF causal LM.

A forward hook on decoder layer i captures hidden_states[i + 1] (the HuggingFace output_hidden_states offset-1 convention); a hook on the final norm captures the post-norm last hidden state.
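The hook-based capture described above can be sketched on a toy module. This is an illustration only, not the library's implementation: `ToyDecoder`, `capture`, and `save_as` are hypothetical names, the toy layers are plain `nn.Linear` modules, and concatenating the captured states along the last (feature) dimension is an assumption about how the draft context is assembled.

```python
import torch
from torch import nn


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for a decoder stack: embedding -> layers -> final norm."""

    def __init__(self, vocab=32, dim=8, num_layers=3):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(dim)

    def forward(self, input_ids):
        h = self.embed_tokens(input_ids)
        for layer in self.layers:
            h = layer(h)
        return self.norm(h)


def capture(model, input_ids, layer_ids):
    """Run one forward pass; return (concatenated layer outputs, final-norm output)."""
    captured = {}

    def save_as(key):
        def hook(module, args, output):
            # Real decoder layers may return a tuple; keep the hidden state only.
            hidden = output[0] if isinstance(output, tuple) else output
            captured[key] = hidden.detach()
        return hook

    handles = [model.layers[i].register_forward_hook(save_as(i)) for i in layer_ids]
    handles.append(model.norm.register_forward_hook(save_as("final_norm")))
    try:
        with torch.no_grad():  # the target is frozen, so no graph is needed
            model(input_ids)
    finally:
        for handle in handles:  # always detach hooks, even if the forward fails
            handle.remove()

    # Assumption: per-layer states are joined along the feature dimension.
    context = torch.cat([captured[i] for i in layer_ids], dim=-1)
    return context, captured["final_norm"]


torch.manual_seed(0)
toy = ToyDecoder().eval()
ids = torch.randint(0, 32, (2, 5))
context, last_hidden = capture(toy, ids, layer_ids=[0, 2])
print(context.shape, last_hidden.shape)
```

Removing the hook handles in a `finally` block keeps repeated calls from accumulating hooks on the wrapped model.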

_num_layers = len(self._get_transformer_layers())
model = model.eval()
target_layer_ids

nemo_automodel.components.speculative.dspark.target.HFDSparkTargetModel._get_final_norm() -> torch.nn.Module

Return the final norm module whose output feeds lm_head.

nemo_automodel.components.speculative.dspark.target.HFDSparkTargetModel._get_transformer_layers() -> list[torch.nn.Module]

Return decoder layers as an ordered, integer-indexable list.

nemo_automodel.components.speculative.dspark.target.HFDSparkTargetModel._inner_model() -> torch.nn.Module

Return the base transformer module that owns layers and norm.

Handles the common nestings: a plain causal LM (model.model), a decoder-only base (model), and multimodal targets whose text stack is under language_model (e.g. Gemma4: model.model.language_model).
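One way to picture the nesting lookup is the sketch below. `resolve_inner_model` is a hypothetical helper, and both the probing order and the use of the `layers` and `norm` attributes as the test are assumptions; the actual method may resolve the nestings differently.

```python
from types import SimpleNamespace


def resolve_inner_model(model):
    """Hypothetical helper: find the module that owns `layers` and `norm`."""
    wrapped = getattr(model, "model", None)
    candidates = [
        getattr(wrapped, "language_model", None),  # multimodal: model.model.language_model
        wrapped,                                   # plain causal LM: model.model
        model,                                     # decoder-only base: model
    ]
    for candidate in candidates:
        if candidate is not None and hasattr(candidate, "layers") and hasattr(candidate, "norm"):
            return candidate
    raise ValueError("could not locate a module exposing `layers` and `norm`")


# Stand-in objects that mimic the three nestings named above.
base = SimpleNamespace(layers=["layer0", "layer1"], norm="final_norm")
causal_lm = SimpleNamespace(model=base)
multimodal = SimpleNamespace(model=SimpleNamespace(language_model=base))

for target in (base, causal_lm, multimodal):
    print(resolve_inner_model(target) is base)
```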

nemo_automodel.components.speculative.dspark.target.HFDSparkTargetModel.generate_batch(
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
loss_mask: torch.Tensor
) -> nemo_automodel.components.speculative.dspark.target.DSparkTargetBatch

Run the target model once and capture the DSpark context + last hidden state.

Features follow common.extract_context_feature exactly: an id of -1 selects the embedding output, the final decoder layer’s id selects the post-norm hidden state (HF output_hidden_states[num_layers]), and any other id selects that decoder layer’s output. The final-norm output is also returned as the last hidden state for the TV / confidence losses.
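The id-to-feature rule can be written out as a small lookup. This is a reading of the description above rather than the code of common.extract_context_feature: `feature_source` is a hypothetical name, and it assumes "the final layer" means id `num_layers - 1`, which is the id that maps to `output_hidden_states[num_layers]` under the offset-1 convention.

```python
def feature_source(layer_id: int, num_layers: int) -> tuple[str, int]:
    """Return (where the feature comes from, matching HF hidden_states index)."""
    if layer_id == -1:
        # Embedding output, i.e. HF hidden_states[0].
        return ("embedding_output", 0)
    if layer_id == num_layers - 1:
        # Assumed meaning of "the final layer": the post-norm state is used.
        return ("final_norm_output", num_layers)
    if 0 <= layer_id < num_layers - 1:
        # Any other decoder layer: its raw output, HF hidden_states[layer_id + 1].
        return (f"decoder_layer_{layer_id}_output", layer_id + 1)
    raise ValueError(f"layer id {layer_id} is out of range for {num_layers} layers")


for layer_id in (-1, 0, 2, 3):
    print(layer_id, feature_source(layer_id, num_layers=4))
```

In every branch the HF index equals `layer_id + 1`; only the capture point differs for the final layer, where the final-norm output is taken instead of the raw decoder-layer output.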

nemo_automodel.components.speculative.dspark.target.HFDSparkTargetModel.get_input_embeddings() -> torch.nn.Embedding

Return the target model input embeddings.

nemo_automodel.components.speculative.dspark.target.HFDSparkTargetModel.get_output_embeddings() -> torch.nn.Module

Return the target model output embeddings (lm_head).

nemo_automodel.components.speculative.dspark.target.__all__ = ['HFDSparkTargetModel', 'DSparkTargetBatch']