nemo_automodel.components.speculative.dspark.common

View as Markdown

Module Contents

Classes

Name | Description
AcceptRatePredictor | -
DSparkForwardOutput | Outputs for one DSpark training forward.

Functions

Data

__all__

API

class nemo_automodel.components.speculative.dspark.common.AcceptRatePredictor(
input_dim: int
)

Bases: Module

proj
= nn.Linear(int(input_dim), 1)
nemo_automodel.components.speculative.dspark.common.AcceptRatePredictor.forward(
features
)
class nemo_automodel.components.speculative.dspark.common.DSparkForwardOutput(
draft_logits: torch.Tensor,
target_ids: torch.Tensor,
eval_mask: torch.Tensor,
block_keep_mask: torch.Tensor,
confidence_pred: typing.Optional[torch.Tensor] = None,
aligned_target_logits: typing.Optional[torch.Tensor] = None
)
Dataclass

Outputs for one DSpark training forward.

The sampler keeps anchors whose first draft target is enabled by `loss_mask`. Later slots are supervised only while they remain inside `seq_len` and form a contiguous enabled prefix. Dummy anchors can still appear when a sample has too few valid anchors; they are masked out by `block_keep_mask` and `eval_mask`.

aligned_target_logits
Optional[Tensor] = None
block_keep_mask
Tensor
confidence_pred
Optional[Tensor] = None
draft_logits
Tensor
eval_mask
Tensor
target_ids
Tensor
nemo_automodel.components.speculative.dspark.common.build_anchor_candidate_mask(
seq_len: int,
loss_mask: torch.Tensor
) -> torch.Tensor
nemo_automodel.components.speculative.dspark.common.build_eval_mask(
seq_len: int,
loss_mask: torch.Tensor,
label_indices: torch.Tensor,
safe_label_indices: torch.Tensor,
block_keep_mask: torch.Tensor
) -> torch.Tensor
nemo_automodel.components.speculative.dspark.common.create_noise_embed(
embed_tokens: torch.nn.Module,
input_ids: torch.Tensor,
anchor_positions: torch.Tensor,
block_keep_mask: torch.Tensor,
mask_token_id: int,
block_size: int
) -> torch.Tensor
nemo_automodel.components.speculative.dspark.common.create_position_ids(
anchor_positions: torch.Tensor,
block_size: int
) -> torch.Tensor
nemo_automodel.components.speculative.dspark.common.extract_context_feature(
hidden_states,
layer_ids
)
nemo_automodel.components.speculative.dspark.common.sample_anchor_positions(
seq_len: int,
loss_mask: torch.Tensor,
num_anchors: int,
device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]
nemo_automodel.components.speculative.dspark.common.validate_target_layer_ids(
layer_ids,
num_target_layers: int
)
nemo_automodel.components.speculative.dspark.common.__all__ = ['DSparkForwardOutput', 'AcceptRatePredictor', 'extract_context_feature', 'valid...