# nemo_automodel.components.speculative.dspark.draft_gemma4

## Module Contents

### Classes

| Name                                                                                                              | Description |
| ----------------------------------------------------------------------------------------------------------------- | ----------- |
| [`Gemma4DSparkAttention`](#nemo_automodel-components-speculative-dspark-draft_gemma4-Gemma4DSparkAttention)       | -           |
| [`Gemma4DSparkDecoderLayer`](#nemo_automodel-components-speculative-dspark-draft_gemma4-Gemma4DSparkDecoderLayer) | -           |
| [`Gemma4DSparkModel`](#nemo_automodel-components-speculative-dspark-draft_gemma4-Gemma4DSparkModel)               | -           |

### Data

[`__all__`](#nemo_automodel-components-speculative-dspark-draft_gemma4-__all__)

### API

```python
class nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkAttention(
    config,
    layer_idx: int
)
```

**Bases:** `Module`

```python
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkAttention._repeat_kv(
    hidden_states: torch.Tensor
) -> torch.Tensor
```
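In the Hugging Face attention modules this class appears to follow, `_repeat_kv` conventionally expands grouped-query key/value heads so that each query head has a matching key/value head. The documented method takes only `hidden_states`, so the repeat factor presumably comes from the layer's config. The standalone `repeat_kv` function and its `n_rep` argument below are hypothetical and only sketch the usual pattern, assuming a `(batch, num_kv_heads, seq_len, head_dim)` layout:

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Hypothetical helper: repeat each K/V head n_rep times along the head axis.

    Assumes the (batch, num_kv_heads, seq_len, head_dim) layout used by
    Hugging Face transformers attention modules.
    """
    if n_rep == 1:
        return hidden_states
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    # Add a repeat axis, broadcast it without copying, then fold it into the head axis.
    expanded = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


kv = torch.arange(2 * 2 * 3 * 4, dtype=torch.float32).reshape(2, 2, 3, 4)
out = repeat_kv(kv, n_rep=3)
print(out.shape)  # torch.Size([2, 6, 3, 4])
```

The expand-then-reshape form is equivalent to `torch.repeat_interleave(kv, n_rep, dim=1)`; the expand step avoids materializing an intermediate copy before the final reshape.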

```python
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkAttention.forward(
    hidden_states: torch.Tensor,
    target_hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: typing.Optional[torch.Tensor],
    past_key_values: typing.Optional[transformers.cache_utils.Cache] = None,
    cache_position: typing.Optional[torch.LongTensor] = None,
    **kwargs
) -> tuple[torch.Tensor, typing.Optional[torch.Tensor]]
```
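The signature takes both draft-side `hidden_states` and `target_hidden_states`. A minimal sketch of one plausible way such a layer could combine them, assuming (not confirmed by this page) that queries come from the draft states while keys and values cover the target features followed by the draft states. All sizes and projection names are hypothetical, and rotary `position_embeddings`, masking, and KV caching are omitted:

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes; a real layer would read these from the Gemma 4 draft config.
batch, ctx_len, draft_len, num_heads, head_dim = 2, 5, 4, 4, 8
hidden = num_heads * head_dim

q_proj = torch.nn.Linear(hidden, hidden, bias=False)
k_proj = torch.nn.Linear(hidden, hidden, bias=False)
v_proj = torch.nn.Linear(hidden, hidden, bias=False)

hidden_states = torch.randn(batch, draft_len, hidden)       # draft-side states
target_hidden_states = torch.randn(batch, ctx_len, hidden)  # features from the target model


def split_heads(x: torch.Tensor) -> torch.Tensor:
    b, s, _ = x.shape
    return x.view(b, s, num_heads, head_dim).transpose(1, 2)  # (b, heads, s, head_dim)


# Assumption: keys/values see the target context followed by the draft states.
kv_input = torch.cat([target_hidden_states, hidden_states], dim=1)
q = split_heads(q_proj(hidden_states))
k = split_heads(k_proj(kv_input))
v = split_heads(v_proj(kv_input))

# No attention mask here: every draft position attends to every key.
attn = F.scaled_dot_product_attention(q, k, v)
attn_output = attn.transpose(1, 2).reshape(batch, draft_len, hidden)
print(attn_output.shape)  # torch.Size([2, 4, 32])
```

The real `forward` returns a tuple whose second element is an optional tensor (likely attention weights); this sketch returns only the attended output.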

```python
class nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkDecoderLayer(
    config,
    layer_idx: int
)
```

**Bases:** `GradientCheckpointingLayer`

```python
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkDecoderLayer.forward(
    target_hidden_states: typing.Optional[torch.Tensor] = None,
    hidden_states: typing.Optional[torch.Tensor] = None,
    attention_mask: typing.Optional[torch.Tensor] = None,
    position_ids: typing.Optional[torch.LongTensor] = None,
    past_key_value: typing.Optional[transformers.cache_utils.Cache] = None,
    output_attentions: typing.Optional[bool] = False,
    use_cache: typing.Optional[bool] = False,
    cache_position: typing.Optional[torch.LongTensor] = None,
    position_embeddings: typing.Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    **kwargs
) -> torch.Tensor
```

```python
class nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel(
    config
)
```

**Bases:** `Gemma4PreTrainedModel`

```python
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel._forward_backbone(
    position_ids: torch.LongTensor,
    attention_mask: typing.Optional[torch.Tensor] = None,
    noise_embedding: typing.Optional[torch.Tensor] = None,
    target_hidden_states: typing.Optional[torch.Tensor] = None,
    past_key_values: typing.Optional[transformers.cache_utils.Cache] = None,
    use_cache: bool = False,
    **kwargs
) -> torch.Tensor
```

```python
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel.compute_logits(
    hidden_states: torch.Tensor
) -> torch.Tensor
```
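`compute_logits` maps hidden states to vocabulary logits. A minimal sketch, assuming it applies the LM head and, if configured, a Gemma-style final logit soft-cap (Gemma 2 configs expose `final_logit_softcapping`; whether this draft model uses one is not stated here). The function name and `softcap` argument are hypothetical:

```python
from typing import Optional

import torch


def compute_logits_sketch(
    hidden_states: torch.Tensor,
    lm_head: torch.nn.Module,
    softcap: Optional[float] = None,
) -> torch.Tensor:
    """Hypothetical: project to the vocabulary, then optionally soft-cap."""
    logits = lm_head(hidden_states)
    if softcap is not None:
        # tanh smoothly bounds every logit to the interval [-softcap, softcap].
        logits = softcap * torch.tanh(logits / softcap)
    return logits


torch.manual_seed(0)
head = torch.nn.Linear(8, 50, bias=False)
h = torch.randn(3, 8) * 100  # large activations to make the cap visible
capped = compute_logits_sketch(h, head, softcap=30.0)
print(capped.abs().max() <= 30.0)
```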

```python
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel.forward(
    input_ids: torch.Tensor,
    target_hidden_states: torch.Tensor,
    loss_mask: torch.Tensor,
    target_last_hidden_states: typing.Optional[torch.Tensor] = None
) -> nemo_automodel.components.speculative.dspark.common.DSparkForwardOutput
```
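`forward` takes a `loss_mask` alongside `input_ids`. A common way a per-token mask enters a training loss is sketched below, assuming a 0/1 mask aligned with the label positions; `masked_token_loss` is a hypothetical helper, and the actual loss computed inside `DSparkForwardOutput` may differ:

```python
import torch
import torch.nn.functional as F


def masked_token_loss(
    logits: torch.Tensor, labels: torch.Tensor, loss_mask: torch.Tensor
) -> torch.Tensor:
    """Hypothetical: mean cross-entropy over positions where loss_mask == 1.

    logits: (batch, seq, vocab); labels and loss_mask: (batch, seq).
    """
    per_token = F.cross_entropy(
        logits.flatten(0, 1), labels.flatten(), reduction="none"
    )
    mask = loss_mask.flatten().to(per_token.dtype)
    # clamp_min avoids division by zero when every position is masked out.
    return (per_token * mask).sum() / mask.sum().clamp_min(1.0)


torch.manual_seed(0)
logits = torch.randn(2, 3, 5)
labels = torch.randint(0, 5, (2, 3))
mask = torch.tensor([[1, 1, 0], [1, 0, 0]])
loss = masked_token_loss(logits, labels, mask)
```

Masked positions contribute nothing to the numerator or the denominator, so changing the logits at those positions leaves the loss unchanged.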

```python
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel.initialize_embeddings_and_head(
    embed_tokens: torch.nn.Module,
    lm_head: torch.nn.Module,
    freeze: bool = True
)
```

```python
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel.predict_confidence_step(
    hidden_states: torch.Tensor,
    prev_token_ids: typing.Optional[torch.Tensor] = None
) -> typing.Optional[torch.Tensor]
```

```python
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel.sample_draft_token_step(
    base_logits: torch.Tensor,
    prev_token_ids: torch.Tensor,
    temperature: float = 0.0,
    hidden_states: typing.Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]
```
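`temperature: float = 0.0` follows the usual convention that zero means greedy decoding. A minimal sketch of that convention, assuming `(batch, vocab)` logits; the helper `sample_token` and its `(tokens, probs)` return pair are hypothetical, and unlike the real method it ignores `prev_token_ids` and `hidden_states`:

```python
from typing import Optional, Tuple

import torch
import torch.nn.functional as F


def sample_token(
    logits: torch.Tensor,
    temperature: float = 0.0,
    generator: Optional[torch.Generator] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical: return sampled token ids and the distribution they came from."""
    if temperature <= 0.0:
        # Greedy: pick the argmax; the implied distribution is one-hot.
        tokens = logits.argmax(dim=-1)
        probs = F.one_hot(tokens, num_classes=logits.size(-1)).to(logits.dtype)
    else:
        # Temperature sampling: flatten (T > 1) or sharpen (T < 1) the softmax.
        probs = torch.softmax(logits / temperature, dim=-1)
        tokens = torch.multinomial(probs, num_samples=1, generator=generator).squeeze(-1)
    return tokens, probs


logits = torch.tensor([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]])
greedy_tokens, _ = sample_token(logits)
print(greedy_tokens.tolist())  # [1, 0]
```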

```python
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel.sample_draft_tokens(
    base_logits: torch.Tensor,
    first_prev_token_ids: torch.Tensor,
    temperature: float = 0.0,
    hidden_states: typing.Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]
```

```python
nemo_automodel.components.speculative.dspark.draft_gemma4.Gemma4DSparkModel.set_embedding_head_trainable(
    trainable: bool
)
```
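The names `initialize_embeddings_and_head(..., freeze=True)` and `set_embedding_head_trainable(trainable)` suggest that the draft reuses the target model's embedding and LM head and toggles whether they receive gradients. A minimal sketch of that pattern, assuming sharing by reference and freezing through `requires_grad`; `TinyDraft` is a hypothetical stand-in, not the real class:

```python
import torch
from torch import nn


class TinyDraft(nn.Module):
    """Hypothetical stand-in for a draft model that borrows target modules."""

    def __init__(self) -> None:
        super().__init__()
        self.embed_tokens = None
        self.lm_head = None

    def initialize_embeddings_and_head(
        self, embed_tokens: nn.Module, lm_head: nn.Module, freeze: bool = True
    ) -> None:
        # Register the target's modules directly, so parameters are shared, not copied.
        self.embed_tokens = embed_tokens
        self.lm_head = lm_head
        self.set_embedding_head_trainable(not freeze)

    def set_embedding_head_trainable(self, trainable: bool) -> None:
        for module in (self.embed_tokens, self.lm_head):
            for param in module.parameters():
                param.requires_grad_(trainable)


embed = nn.Embedding(10, 4)
head = nn.Linear(4, 10, bias=False)
draft = TinyDraft()
draft.initialize_embeddings_and_head(embed, head)  # frozen by default
print(any(p.requires_grad for p in draft.parameters()))  # False
```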

```python
nemo_automodel.components.speculative.dspark.draft_gemma4.__all__ = ['Gemma4DSparkModel']
```