nemo_automodel.components.speculative.dspark.markov_head

Module Contents

Classes

Name              Description
GatedMarkovHead   -
RNNHead           RNN-based head that maintains recurrent state across positions within a block.
VanillaMarkov     -

Functions

Name               Description
build_markov_head  -

Data

__all__

API

class nemo_automodel.components.speculative.dspark.markov_head.GatedMarkovHead(
vocab_size: int,
markov_rank: int,
hidden_size: int
)

Bases: VanillaMarkov

gate_proj = nn.Linear(hidden_size + markov_rank, markov_rank)
markov_head_type = 'gated'

nemo_automodel.components.speculative.dspark.markov_head.GatedMarkovHead.compute_gate(
token_ids: torch.Tensor,
hidden_states: typing.Optional[torch.Tensor]
) -> torch.Tensor
nemo_automodel.components.speculative.dspark.markov_head.GatedMarkovHead.compute_step_bias(
token_ids: torch.Tensor,
hidden_states: typing.Optional[torch.Tensor]
) -> torch.Tensor
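The page lists `gate_proj` as `nn.Linear(hidden_size + markov_rank, markov_rank)` but does not say how the gate enters the bias. The following is a minimal sketch under stated assumptions: the gate is `sigmoid(gate_proj([h_k; W1[x_{k-1}]]))`, it scales the rank-`r` previous-token embedding elementwise, and the result is projected to the vocabulary by a bias-free `nn.Linear` standing in for `markov_w2`, whose definition is not shown here. `ToyGatedMarkov` is a hypothetical name, and the sketch requires `hidden_states` even though the documented signature makes it optional.

```python
import torch
import torch.nn as nn


class ToyGatedMarkov(nn.Module):
    """Hypothetical gated low-rank Markov bias; not the NeMo implementation."""

    def __init__(self, vocab_size: int, markov_rank: int, hidden_size: int):
        super().__init__()
        # Previous-token embedding W1: vocab -> rank r (mirrors markov_w1).
        self.markov_w1 = nn.Embedding(vocab_size, markov_rank)
        # Projection W2: rank r -> vocab (assumed form of markov_w2).
        self.markov_w2 = nn.Linear(markov_rank, vocab_size, bias=False)
        # Same shape as the documented gate_proj.
        self.gate_proj = nn.Linear(hidden_size + markov_rank, markov_rank)

    def compute_gate(self, token_ids, hidden_states):
        emb = self.markov_w1(token_ids)                  # [..., r]
        joint = torch.cat([hidden_states, emb], dim=-1)  # [..., d + r]
        return torch.sigmoid(self.gate_proj(joint))      # [..., r], values in [0, 1]

    def compute_step_bias(self, token_ids, hidden_states):
        emb = self.markov_w1(token_ids)
        gate = self.compute_gate(token_ids, hidden_states)
        return self.markov_w2(gate * emb)                # [..., V]


head = ToyGatedMarkov(vocab_size=32, markov_rank=4, hidden_size=8)
prev_tokens = torch.tensor([3, 7])  # [batch]
hidden = torch.randn(2, 8)          # [batch, d]
bias = head.compute_step_bias(prev_tokens, hidden)
print(bias.shape)  # torch.Size([2, 32])
```

Because the gate is computed from the backbone hidden state, the same previous token can contribute a different bias at different positions, which a plain `W2(W1[x])` bias cannot do.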
class nemo_automodel.components.speculative.dspark.markov_head.RNNHead(
vocab_size: int,
markov_rank: int,
hidden_size: int
)

Bases: VanillaMarkov

RNN-based head that maintains recurrent state across positions within a block.

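Unlike the memoryless Markov heads, whose bias at position k is computed only from the current inputs (the previous token x_{k-1} and, for the gated head, the hidden state), this head lets position k access the full prefix history x_{<k} through a GRU-like recurrent state.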

joint_proj
markov_head_type = 'rnn'

nemo_automodel.components.speculative.dspark.markov_head.RNNHead._rnn_step(
state: torch.Tensor,
prev_embeddings: torch.Tensor,
hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]

Single RNN step.

Parameters:

state
torch.Tensor

[*, r] previous recurrent state

prev_embeddings
torch.Tensor

[*, r] embedding W1[x_{k-1}] of the previous token

hidden_states
torch.Tensor

[*, d] backbone hidden state at step k

Returns: torch.Tensor

[*, r]
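As an illustration of what a "GRU-like" step over these shapes could look like, here is a simplified gated update (an update gate only, no reset gate). The layout of the stand-in `joint_proj`, the update formula, and the class name `ToyRNNCell` are assumptions; the documented `_rnn_step` returns a pair, while this sketch returns only the new `[*, r]` state.

```python
import torch
import torch.nn as nn


class ToyRNNCell(nn.Module):
    """Hypothetical simplified GRU-style cell; not the NeMo RNNHead."""

    def __init__(self, markov_rank: int, hidden_size: int):
        super().__init__()
        self.rank = markov_rank
        in_dim = 2 * markov_rank + hidden_size  # [state; prev_emb; hidden]
        # Assumed stand-in for joint_proj: emits update-gate and candidate pre-activations.
        self.joint_proj = nn.Linear(in_dim, 2 * markov_rank)

    def step(self, state, prev_embeddings, hidden_states):
        x = torch.cat([state, prev_embeddings, hidden_states], dim=-1)
        gate_pre, cand_pre = self.joint_proj(x).split(self.rank, dim=-1)
        update = torch.sigmoid(gate_pre)  # [*, r], how much of the state to overwrite
        candidate = torch.tanh(cand_pre)  # [*, r], proposed new content
        # Convex mix of the old state and the candidate, as in a GRU update.
        return (1 - update) * state + update * candidate


cell = ToyRNNCell(markov_rank=4, hidden_size=8)
state = torch.zeros(2, 4)     # [*, r]
prev_emb = torch.randn(2, 4)  # [*, r], e.g. W1[x_{k-1}]
hidden = torch.randn(2, 8)    # [*, d]
new_state = cell.step(state, prev_emb, hidden)
print(new_state.shape)  # torch.Size([2, 4])
```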

nemo_automodel.components.speculative.dspark.markov_head.RNNHead.apply_block_logits(
base_logits: torch.Tensor,
token_ids: torch.Tensor,
hidden_states: typing.Optional[torch.Tensor]
) -> torch.Tensor

Apply RNN bias during training (teacher-forced, unrolled over block_size).

Parameters:

base_logits
torch.Tensor

[B, num_blocks, block_size, V]

token_ids
torch.Tensor

[B, num_blocks, block_size] - previous-token IDs for each position

hidden_states
Optional[torch.Tensor]

[B, num_blocks, block_size, d]
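A sketch of a teacher-forced unroll over these shapes, assuming the recurrent state starts at zero at the beginning of every block and that the bias at each position is `W2(state)`. It uses `torch.nn.GRUCell` as the recurrence and plain `nn.Embedding`/`nn.Linear` modules for W1/W2; the helper name `toy_apply_block_logits` and the choice of `GRUCell` are assumptions, not the library's implementation.

```python
import torch
import torch.nn as nn


def toy_apply_block_logits(base_logits, token_ids, hidden_states, w1, w2, cell):
    """Hypothetical teacher-forced unroll over block_size.

    base_logits: [B, N, K, V]; token_ids: [B, N, K]; hidden_states: [B, N, K, d]
    """
    B, N, K, V = base_logits.shape
    r = w1.embedding_dim
    # Fold blocks into the batch so each block carries its own recurrent state.
    prev_emb = w1(token_ids).reshape(B * N, K, r)  # W1[x_{k-1}]
    hid = hidden_states.reshape(B * N, K, -1)
    state = base_logits.new_zeros((B * N, r))     # assumed zero init per block
    biases = []
    for k in range(K):
        # Teacher forcing: the input at step k uses the ground-truth previous token.
        step_in = torch.cat([prev_emb[:, k], hid[:, k]], dim=-1)
        state = cell(step_in, state)
        biases.append(w2(state))                   # [B*N, V]
    bias = torch.stack(biases, dim=1).reshape(B, N, K, V)
    return base_logits + bias


B, N, K, V, r, d = 2, 3, 4, 16, 5, 8
w1 = nn.Embedding(V, r)
w2 = nn.Linear(r, V, bias=False)
cell = nn.GRUCell(input_size=r + d, hidden_size=r)
base = torch.randn(B, N, K, V)
tokens = torch.randint(0, V, (B, N, K))
hidden = torch.randn(B, N, K, d)
out = toy_apply_block_logits(base, tokens, hidden, w1, w2, cell)
print(out.shape)  # torch.Size([2, 3, 4, 16])
```

Since all previous tokens are known during training, every position can be fed its ground-truth predecessor, so the loop runs only over `block_size` and is batched over all blocks at once.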

nemo_automodel.components.speculative.dspark.markov_head.RNNHead.compute_step_bias(
token_ids: torch.Tensor,
hidden_states: typing.Optional[torch.Tensor]
) -> torch.Tensor

Stateless single-step bias (the recurrent state is initialized to zero).

Provided for interface compatibility with the other heads; it does NOT carry state across steps. For full RNN behavior, use apply_block_logits or sample_block_tokens.
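The difference between the stateless step and a stateful unroll can be seen with a toy recurrence. This sketch uses `torch.nn.GRUCell` as a hypothetical stand-in for the head's recurrence and `nn.Embedding`/`nn.Linear` for W1/W2: both paths agree at the first position, where the carried state is still zero, and diverge afterwards.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
V, r, d, K = 16, 4, 8, 3
w1, w2 = nn.Embedding(V, r), nn.Linear(r, V, bias=False)
cell = nn.GRUCell(r + d, r)                # hypothetical recurrence
prev_tokens = torch.randint(0, V, (1, K))  # [batch, K]
hidden = torch.randn(1, K, d)              # [batch, K, d]


def stateless_bias(k):
    # Every call restarts from a zero state, like compute_step_bias.
    zero = torch.zeros(1, r)
    return w2(cell(torch.cat([w1(prev_tokens[:, k]), hidden[:, k]], -1), zero))


def stateful_biases():
    # Carries the state across positions, like the block-level methods.
    state, out = torch.zeros(1, r), []
    for k in range(K):
        state = cell(torch.cat([w1(prev_tokens[:, k]), hidden[:, k]], -1), state)
        out.append(w2(state))
    return out


with torch.no_grad():
    full = stateful_biases()
    print(torch.allclose(stateless_bias(0), full[0]))  # True: both start from zero
    print(torch.allclose(stateless_bias(2), full[2]))
```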

nemo_automodel.components.speculative.dspark.markov_head.RNNHead.sample_block_tokens(
base_logits: torch.Tensor,
first_prev_token_ids: torch.Tensor,
hidden_states: typing.Optional[torch.Tensor],
temperature: float = 0.0
) -> tuple[torch.Tensor, torch.Tensor]

Autoregressive sampling with RNN state.

Parameters:

base_logits
torch.Tensor

[batch, proposal_len, vocab]

first_prev_token_ids
torch.Tensor

[batch] - token preceding the first draft position

hidden_states
Optional[torch.Tensor]

[batch, proposal_len, d]

temperature
float, defaults to 0.0

sampling temperature

Returns: torch.Tensor

[batch, proposal_len]
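A sketch of the autoregressive loop, again with a hypothetical `nn.GRUCell` recurrence and zero initial state. Unlike the teacher-forced training path, the input at step k is the token chosen at step k-1, starting from `first_prev_token_ids`. Treating `temperature=0.0` as greedy argmax and returning the biased logits as the second element of the pair are assumptions, not documented behavior; `toy_sample_block_tokens` is a hypothetical name.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def toy_sample_block_tokens(base_logits, first_prev_token_ids, hidden_states,
                            w1, w2, cell, temperature=0.0):
    batch, P, V = base_logits.shape
    state = base_logits.new_zeros((batch, w1.embedding_dim))
    prev = first_prev_token_ids                 # [batch]
    tokens, logits_out = [], []
    for k in range(P):
        # The recurrence consumes the token chosen at the previous step.
        step_in = torch.cat([w1(prev), hidden_states[:, k]], dim=-1)
        state = cell(step_in, state)
        logits = base_logits[:, k] + w2(state)  # [batch, V]
        if temperature <= 0.0:
            prev = logits.argmax(dim=-1)        # greedy (assumed meaning of 0.0)
        else:
            probs = torch.softmax(logits / temperature, dim=-1)
            prev = torch.multinomial(probs, num_samples=1).squeeze(-1)
        tokens.append(prev)
        logits_out.append(logits)
    return torch.stack(tokens, dim=1), torch.stack(logits_out, dim=1)


batch, P, V, r, d = 2, 5, 16, 4, 8
torch.manual_seed(0)
w1, w2 = nn.Embedding(V, r), nn.Linear(r, V, bias=False)
cell = nn.GRUCell(r + d, r)
tokens, logits = toy_sample_block_tokens(
    torch.randn(batch, P, V), torch.tensor([1, 2]), torch.randn(batch, P, d), w1, w2, cell)
print(tokens.shape, logits.shape)  # torch.Size([2, 5]) torch.Size([2, 5, 16])
```

The loop is inherently sequential over `proposal_len` because each step's input depends on the previous step's sample.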

class nemo_automodel.components.speculative.dspark.markov_head.VanillaMarkov(
vocab_size: int,
markov_rank: int
)

Bases: Module

markov_head_type = 'vanilla'
markov_rank = int(markov_rank)
markov_w1 = nn.Embedding(self.vocab_size, self.markov_rank)
markov_w2
vocab_size = int(vocab_size)

nemo_automodel.components.speculative.dspark.markov_head.VanillaMarkov.apply_block_logits(
base_logits: torch.Tensor,
token_ids: torch.Tensor,
hidden_states: typing.Optional[torch.Tensor]
) -> torch.Tensor
nemo_automodel.components.speculative.dspark.markov_head.VanillaMarkov.apply_step_logits(
logits: torch.Tensor,
token_ids: torch.Tensor,
hidden_states: typing.Optional[torch.Tensor]
) -> torch.Tensor
nemo_automodel.components.speculative.dspark.markov_head.VanillaMarkov.compute_step_bias(
token_ids: torch.Tensor,
hidden_states: typing.Optional[torch.Tensor]
) -> torch.Tensor
nemo_automodel.components.speculative.dspark.markov_head.VanillaMarkov.get_prev_embeddings(
token_ids: torch.Tensor
) -> torch.Tensor
nemo_automodel.components.speculative.dspark.markov_head.VanillaMarkov.project_bias(
latent_states: torch.Tensor
) -> torch.Tensor
nemo_automodel.components.speculative.dspark.markov_head.VanillaMarkov.sample_block_tokens(
base_logits: torch.Tensor,
first_prev_token_ids: torch.Tensor,
hidden_states: typing.Optional[torch.Tensor],
temperature: float = 0.0
) -> tuple[torch.Tensor, torch.Tensor]
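A minimal sketch of what the `VanillaMarkov` step methods could compute, assuming the bias for position k is `W2(W1[x_{k-1}])`, which amounts to a rank-`r` factorization of a `V x V` bigram logit table. `markov_w1` matches the documented `nn.Embedding`; representing `markov_w2` as a bias-free `nn.Linear(markov_rank, vocab_size)` and the class name `ToyVanillaMarkov` are assumptions.

```python
import torch
import torch.nn as nn


class ToyVanillaMarkov(nn.Module):
    """Hypothetical low-rank bigram bias; markov_w2's real form is not documented."""

    def __init__(self, vocab_size: int, markov_rank: int):
        super().__init__()
        self.markov_w1 = nn.Embedding(vocab_size, markov_rank)           # W1: [V, r]
        self.markov_w2 = nn.Linear(markov_rank, vocab_size, bias=False)  # W2: r -> V

    def get_prev_embeddings(self, token_ids):
        return self.markov_w1(token_ids)      # [..., r]

    def project_bias(self, latent_states):
        return self.markov_w2(latent_states)  # [..., V]

    def compute_step_bias(self, token_ids, hidden_states=None):
        # Memoryless: depends only on the previous token; hidden_states is unused.
        return self.project_bias(self.get_prev_embeddings(token_ids))

    def apply_step_logits(self, logits, token_ids, hidden_states=None):
        return logits + self.compute_step_bias(token_ids, hidden_states)


head = ToyVanillaMarkov(vocab_size=100, markov_rank=8)
# The bias is row x of the rank-8 matrix W1 @ W2^T, i.e. a factorized bigram table.
table = head.markov_w1.weight @ head.markov_w2.weight.T  # [V, V]
prev = torch.tensor([5, 42])
bias = head.compute_step_bias(prev)
print(torch.allclose(bias, table[prev], atol=1e-5))  # True
```

Under this factorization the head stores `2 * V * r` parameters instead of `V * V`; for V = 32,000 and r = 64 that is 4,096,000 instead of 1,024,000,000.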
nemo_automodel.components.speculative.dspark.markov_head.build_markov_head(
config
) -> torch.nn.Module | None
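`build_markov_head` returns either a head module or `None`. The following is a hypothetical sketch of such a factory that dispatches on the documented `markov_head_type` values ('vanilla', 'gated', 'rnn'), passing `hidden_size` only to the heads whose constructors take it. The config attribute names, the condition for returning `None`, and the `StubHead` placeholder are all assumptions.

```python
from types import SimpleNamespace
from typing import Optional

import torch.nn as nn


class StubHead(nn.Module):
    """Placeholder standing in for VanillaMarkov / GatedMarkovHead / RNNHead."""

    def __init__(self, kind: str, **sizes):
        super().__init__()
        self.markov_head_type = kind
        self.sizes = sizes


def toy_build_markov_head(config) -> Optional[nn.Module]:
    # Hypothetical config fields; the real attribute names are not documented here.
    kind = getattr(config, "markov_head_type", None)
    if kind is None or getattr(config, "markov_rank", 0) <= 0:
        return None  # head disabled
    if kind == "vanilla":
        return StubHead(kind, vocab_size=config.vocab_size, markov_rank=config.markov_rank)
    if kind in ("gated", "rnn"):
        return StubHead(kind, vocab_size=config.vocab_size,
                        markov_rank=config.markov_rank, hidden_size=config.hidden_size)
    raise ValueError(f"unknown markov_head_type: {kind!r}")


cfg = SimpleNamespace(markov_head_type="rnn", vocab_size=32000, markov_rank=64, hidden_size=4096)
print(toy_build_markov_head(cfg).markov_head_type)  # rnn
print(toy_build_markov_head(SimpleNamespace()))     # None
```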
nemo_automodel.components.speculative.dspark.markov_head.__all__ = ['VanillaMarkov', 'GatedMarkovHead', 'RNNHead', 'build_markov_head']