nemo_automodel.components.speculative.dflash.target


Target-model wrapper for DFlash training.

Unlike EAGLE-3 (which captures exactly three aux layers and left-shifts the supervision), DFlash captures an arbitrary set of decoder layers, concatenates them along the feature dim, and feeds the result to the draft as context. No shifting is applied — the DFlash block attention mask handles anchor alignment.
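The concatenation step described above can be sketched as follows. This is a minimal illustration rather than the library's implementation: the layer ids, tensor shapes, and the `captured` dictionary are assumptions made up for the example.

```python
import torch

# Hypothetical sizes and layer selection, chosen only for illustration.
batch, seq_len, hidden = 2, 5, 8
target_layer_ids = [1, 3, 4]

# Stand-ins for the captured per-layer outputs, each (batch, seq_len, hidden).
captured = {i: torch.randn(batch, seq_len, hidden) for i in target_layer_ids}

# Concatenate along the feature (last) dimension, in layer-id order.
# The sequence dimension is left untouched: no left or right shift is applied.
context = torch.cat([captured[i] for i in target_layer_ids], dim=-1)

print(context.shape)  # torch.Size([2, 5, 24])
```

With three captured layers of width 8, the context feature width is 3 × 8 = 24, while the sequence length stays at 5.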

Module Contents

Classes

Name | Description
DFlashTargetBatch | Target-model context features needed by the DFlash trainer.
HFDFlashTargetModel | Capture a set of decoder-layer hidden states from a frozen HF causal LM.

API

class nemo_automodel.components.speculative.dflash.target.DFlashTargetBatch(
hidden_states: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
loss_mask: torch.Tensor
)
Dataclass

Target-model context features needed by the DFlash trainer.

attention_mask: Tensor
hidden_states: Tensor
input_ids: Tensor
loss_mask: Tensor

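To show how the four fields fit together, here is a hypothetical stand-in dataclass with the same field names. `ToyTargetBatch` is not part of the library, and the shapes and dtypes below are assumptions; the reference above only states that each field is a tensor.

```python
from dataclasses import dataclass

import torch


@dataclass
class ToyTargetBatch:
    """Hypothetical mirror of DFlashTargetBatch's fields (illustration only)."""

    hidden_states: torch.Tensor
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    loss_mask: torch.Tensor


# Assumed shapes: features are (batch, seq_len, feat); the rest are (batch, seq_len).
batch, seq_len, feat = 2, 6, 24
toy = ToyTargetBatch(
    hidden_states=torch.zeros(batch, seq_len, feat),
    input_ids=torch.zeros(batch, seq_len, dtype=torch.long),
    attention_mask=torch.ones(batch, seq_len, dtype=torch.long),
    loss_mask=torch.ones(batch, seq_len, dtype=torch.long),
)
```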
class nemo_automodel.components.speculative.dflash.target.HFDFlashTargetModel(
model: torch.nn.Module,
target_layer_ids: typing.Sequence[int]
)

Capture a set of decoder-layer hidden states from a frozen HF causal LM.

A forward hook on decoder layer i captures that layer’s output, which in HuggingFace’s output_hidden_states convention is hidden_states[i + 1] — matching SpecForge’s extract_context_feature (offset 1).
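The offset-by-one relationship can be checked on a toy model. The sketch below uses a hypothetical `ToyDecoder` built from `nn.Linear` layers instead of a real HuggingFace model, and it mimics the `output_hidden_states` convention by putting the embedding output at index 0. Only `register_forward_hook` is a real PyTorch API here; everything else is an assumption for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for a decoder stack (not a real HF model)."""

    def __init__(self, vocab=16, hidden=4, num_layers=3):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_layers)])

    def forward(self, input_ids):
        h = self.embed(input_ids)
        hidden_states = [h]  # index 0: embedding output
        for layer in self.layers:
            h = layer(h)
            hidden_states.append(h)  # index i + 1: output of layer i
        return hidden_states


model = ToyDecoder().eval()
captured = {}


def make_hook(idx):
    # Forward hooks receive (module, inputs, output); store the layer output.
    def hook(module, inputs, output):
        captured[idx] = output.detach()

    return hook


target_layer_ids = [0, 2]
handles = [model.layers[i].register_forward_hook(make_hook(i)) for i in target_layer_ids]

with torch.no_grad():
    hidden_states = model(torch.tensor([[1, 2, 3]]))

for handle in handles:
    handle.remove()  # detach the hooks once the features are captured

# The hook on layer i sees the same tensor as hidden_states[i + 1].
for i in target_layer_ids:
    assert torch.equal(captured[i], hidden_states[i + 1])
```

In some HuggingFace architectures the last `hidden_states` entry is taken after the final normalization layer, so this equality may not hold for the last decoder layer. The toy model has no such layer.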

model = model.eval()
target_layer_ids = self._validate_layer_ids(target_layer_ids)

nemo_automodel.components.speculative.dflash.target.HFDFlashTargetModel._get_transformer_layers() -> list[torch.nn.Module]

Return decoder layers as an ordered, integer-indexable list.

nemo_automodel.components.speculative.dflash.target.HFDFlashTargetModel._validate_layer_ids(
target_layer_ids: typing.Sequence[int]
) -> list[int]
nemo_automodel.components.speculative.dflash.target.HFDFlashTargetModel.generate_batch(
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
loss_mask: torch.Tensor
) -> nemo_automodel.components.speculative.dflash.target.DFlashTargetBatch

Run the target model and capture the selected layers’ hidden states as context.

nemo_automodel.components.speculative.dflash.target.HFDFlashTargetModel.get_input_embeddings() -> torch.nn.Embedding

Return the target model input embeddings.