Development › API Reference › Full Library Reference › Nemo Automodel › Components › Speculative › Dspark
nemo_automodel.components.speculative.dspark.draft_gemma4
nemo_automodel.components.speculative.dspark.draft_gemma4
Module Contents
Classes
| Name | Description |
|---|---|
| Gemma4DSparkAttention | - |
| Gemma4DSparkDecoderLayer | - |
| Gemma4DSparkModel | - |
Data
API
class nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkAttention( config, layer_idx: int )
Bases: Module
attention_dropout = float(config.attention_dropout)
head_dim = int(config.global_head_dim)
k_norm
k_proj
layer_idx = int(layer_idx)
num_attention_heads = int(config.num_attention_heads)
num_key_value_groups
num_key_value_heads = int(config.num_global_key_value_heads)
o_proj
q_norm
q_proj
scaling = 1.0
use_alternative_attention = bool(config.attention_k_eq_v)
v_norm
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkAttention._repeat_kv( hidden_states: torch.Tensor ) -> torch.Tensor
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkAttention.forward( hidden_states: torch.Tensor, target_hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: typing.Optional[torch.Tensor], past_key_values: typing.Optional[transformers.cache_utils.Cache] = None, cache_position: typing.Optional[torch.LongTensor] = None, **kwargs ) -> tuple[torch.Tensor, typing.Optional[torch.Tensor]]
class nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkDecoderLayer( config, layer_idx: int )
Bases: GradientCheckpointingLayer
hidden_size = config.hidden_size
input_layernorm
mlp = Gemma4TextMLP(config, layer_idx)
post_attention_layernorm
post_feedforward_layernorm
pre_feedforward_layernorm
self_attn
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkDecoderLayer.forward( target_hidden_states: typing.Optional[torch.Tensor] = None, hidden_states: typing.Optional[torch.Tensor] = None, attention_mask: typing.Optional[torch.Tensor] = None, position_ids: typing.Optional[torch.LongTensor] = None, past_key_value: typing.Optional[transformers.cache_utils.Cache] = None, output_attentions: typing.Optional[bool] = False, use_cache: typing.Optional[bool] = False, cache_position: typing.Optional[torch.LongTensor] = None, position_embeddings: typing.Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs ) -> torch.Tensor
class nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel( config )
Bases: Gemma4PreTrainedModel
_no_split_modules = ['Gemma4DSparkDecoderLayer']
base_model_prefix = 'model'
block_size = int(config.block_size)
embed_tokens
enable_confidence_head = bool(config.enable_confidence_head)
fc
hidden_norm
layers
lm_head
markov_head = build_markov_head(config)
mask_token_id = config.mask_token_id
norm
num_anchors = int(config.num_anchors)
rotary_emb
target_layer_ids = config.target_layer_ids
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel._forward_backbone( position_ids: torch.LongTensor, attention_mask: typing.Optional[torch.Tensor] = None, noise_embedding: typing.Optional[torch.Tensor] = None, target_hidden_states: typing.Optional[torch.Tensor] = None, past_key_values: typing.Optional[transformers.cache_utils.Cache] = None, use_cache: bool = False, **kwargs ) -> torch.Tensor
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel.compute_logits( hidden_states: torch.Tensor ) -> torch.Tensor
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel.forward( input_ids: torch.Tensor, target_hidden_states: torch.Tensor, loss_mask: torch.Tensor, target_last_hidden_states: typing.Optional[torch.Tensor] = None ) -> nemo_automodel.components.speculative.dspark.common.DSparkForwardOutput
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel.initialize_embeddings_and_head( embed_tokens: torch.nn.Module, lm_head: torch.nn.Module, freeze: bool = True )
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel.predict_confidence_step( hidden_states: torch.Tensor, prev_token_ids: typing.Optional[torch.Tensor] = None ) -> typing.Optional[torch.Tensor]
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel.sample_draft_token_step( base_logits: torch.Tensor, prev_token_ids: torch.Tensor, temperature: float = 0.0, hidden_states: typing.Optional[torch.Tensor] = None ) -> tuple[torch.Tensor, torch.Tensor]
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel.sample_draft_tokens( base_logits: torch.Tensor, first_prev_token_ids: torch.Tensor, temperature: float = 0.0, hidden_states: typing.Optional[torch.Tensor] = None ) -> tuple[torch.Tensor, torch.Tensor]
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel.set_embedding_head_trainable( trainable: bool )
nemo_automodel.components.speculative.dspark.draft_gemma4.__all__ = ['Gemma4DSparkModel']