nemo_automodel.components.speculative.dspark.draft_gemma4

View as Markdown

Module Contents

Classes

- Gemma4DSparkAttention
- Gemma4DSparkDecoderLayer
- Gemma4DSparkModel

Data

- __all__

API

class nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkAttention(
config,
layer_idx: int
)

Bases: Module

attention_dropout = float(config.attention_dropout)
head_dim = int(config.global_head_dim)
k_norm
k_proj
layer_idx = int(layer_idx)
num_attention_heads = int(config.num_attention_heads)
num_key_value_groups
num_key_value_heads = int(config.num_global_key_value_heads)
o_proj
q_norm
q_proj
scaling = 1.0
use_alternative_attention = bool(config.attention_k_eq_v)
v_norm
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkAttention._repeat_kv(
hidden_states: torch.Tensor
) -> torch.Tensor
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkAttention.forward(
hidden_states: torch.Tensor,
target_hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: typing.Optional[torch.Tensor],
past_key_values: typing.Optional[transformers.cache_utils.Cache] = None,
cache_position: typing.Optional[torch.LongTensor] = None,
**kwargs
) -> tuple[torch.Tensor, typing.Optional[torch.Tensor]]
class nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkDecoderLayer(
config,
layer_idx: int
)

Bases: GradientCheckpointingLayer

hidden_size = config.hidden_size
input_layernorm
mlp = Gemma4TextMLP(config, layer_idx)
post_attention_layernorm
post_feedforward_layernorm
pre_feedforward_layernorm
self_attn
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkDecoderLayer.forward(
target_hidden_states: typing.Optional[torch.Tensor] = None,
hidden_states: typing.Optional[torch.Tensor] = None,
attention_mask: typing.Optional[torch.Tensor] = None,
position_ids: typing.Optional[torch.LongTensor] = None,
past_key_value: typing.Optional[transformers.cache_utils.Cache] = None,
output_attentions: typing.Optional[bool] = False,
use_cache: typing.Optional[bool] = False,
cache_position: typing.Optional[torch.LongTensor] = None,
position_embeddings: typing.Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs
) -> torch.Tensor
class nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel(
config
)

Bases: Gemma4PreTrainedModel

_no_split_modules = ['Gemma4DSparkDecoderLayer']
base_model_prefix = 'model'
block_size = int(config.block_size)
embed_tokens
enable_confidence_head = bool(config.enable_confidence_head)
fc
hidden_norm
layers
lm_head
markov_head = build_markov_head(config)
mask_token_id = config.mask_token_id
norm
num_anchors = int(config.num_anchors)
rotary_emb
target_layer_ids = config.target_layer_ids
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel._forward_backbone(
position_ids: torch.LongTensor,
attention_mask: typing.Optional[torch.Tensor] = None,
noise_embedding: typing.Optional[torch.Tensor] = None,
target_hidden_states: typing.Optional[torch.Tensor] = None,
past_key_values: typing.Optional[transformers.cache_utils.Cache] = None,
use_cache: bool = False,
**kwargs
) -> torch.Tensor
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel.compute_logits(
hidden_states: torch.Tensor
) -> torch.Tensor
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel.forward(
input_ids: torch.Tensor,
target_hidden_states: torch.Tensor,
loss_mask: torch.Tensor,
target_last_hidden_states: typing.Optional[torch.Tensor] = None
) -> nemo_automodel.components.speculative.dspark.common.DSparkForwardOutput
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel.initialize_embeddings_and_head(
embed_tokens: torch.nn.Module,
lm_head: torch.nn.Module,
freeze: bool = True
)
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel.predict_confidence_step(
hidden_states: torch.Tensor,
prev_token_ids: typing.Optional[torch.Tensor] = None
) -> typing.Optional[torch.Tensor]
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel.sample_draft_token_step(
base_logits: torch.Tensor,
prev_token_ids: torch.Tensor,
temperature: float = 0.0,
hidden_states: typing.Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel.sample_draft_tokens(
base_logits: torch.Tensor,
first_prev_token_ids: torch.Tensor,
temperature: float = 0.0,
hidden_states: typing.Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel.set_embedding_head_trainable(
trainable: bool
)
nemo_automodel.components.speculative.dspark.draft_gemma4.__all__ = ['Gemma4DSparkModel']