# nemo_automodel.components.models.gemma4_drafter.composite

Composite model for joint fine-tuning of a Gemma 4 base + its drafter.

The composite orchestrates a forward pass that:

1. Runs the base `Gemma4ForConditionalGeneration` with
   `return_shared_kv_states=True` and `output_hidden_states=True`.
2. Builds the drafter's `inputs_embeds` by concatenating the (already
   `sqrt(H_b)`-scaled) base token embeddings with the base's final hidden
   state along the feature axis.
3. Runs the drafter `Gemma4AssistantForCausalLM` with the captured
   `shared_kv_states` and the concatenated embeddings.
4. Returns a `Gemma4JointOutput` that exposes both base logits and a
   per-step list of drafter logits so the training recipe can compute
   `L = L_base + drafter_loss_weight * sum_k L_drafter_k`.
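Steps 2 and 4 can be illustrated with a toy sketch in plain Python. It is not the composite's implementation. Vectors are lists instead of tensors, and `build_drafter_inputs`, `cross_entropy` and `joint_loss` are hypothetical names. The `sqrt(H_b)` scale is applied explicitly here only to make the factor visible. In the composite, the base embeddings arrive already scaled.

```python
import math
from typing import List, Sequence

Vector = List[float]


def build_drafter_inputs(
    raw_embeds: Sequence[Vector], final_hidden: Sequence[Vector]
) -> List[Vector]:
    """Step 2: [sqrt(H_b) * embedding ; final hidden state] per position (width 2 * H_b)."""
    if len(raw_embeds) != len(final_hidden):
        raise ValueError("need one embedding and one hidden state per position")
    features = []
    for emb, hid in zip(raw_embeds, final_hidden):
        h_b = len(emb)
        if len(hid) != h_b:
            raise ValueError("embedding and hidden state must both have width H_b")
        # Shown explicitly for illustration; the base's embeddings are already scaled.
        scale = math.sqrt(h_b)
        features.append([x * scale for x in emb] + list(hid))
    return features


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Cross-entropy at one position, using a numerically stable log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def joint_loss(
    base_loss: float, drafter_losses: Sequence[float], drafter_loss_weight: float = 1.0
) -> float:
    """Step 4: L = L_base + drafter_loss_weight * sum_k L_drafter_k."""
    return base_loss + drafter_loss_weight * sum(drafter_losses)


print(build_drafter_inputs([[1.0, 0.0, 0.0, 0.0]], [[0.5, 0.5, 0.5, 0.5]]))
# [[2.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5]]
print(joint_loss(2.0, [1.0, 0.5], drafter_loss_weight=0.5))  # 2.75
```

In a real recipe, each entry of `drafter_losses` would be a per-step loss computed from one element of the drafter's per-step logits list.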

Both sub-models are trainable. Gradients from the drafter loss flow back into
the base through:

* the "store" KV layers (last non-shared layer of each `layer_type`) via
  `shared_kv_states`;
* the base's input embedding (consumed by the drafter's first projection);
* the base's final hidden state.

This is the EAGLE-2 / Medusa-2 style co-training pattern: the drafter stays
aligned with a base that is itself moving.

## Module Contents

### Classes

| Name                                                                                                | Description                                                       |
| --------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------- |
| [`Gemma4JointOutput`](#nemo_automodel-components-models-gemma4_drafter-composite-Gemma4JointOutput) | Output of `Gemma4WithDrafter`.                                    |
| [`Gemma4WithDrafter`](#nemo_automodel-components-models-gemma4_drafter-composite-Gemma4WithDrafter) | Composite model that wraps a Gemma 4 base + its released drafter. |

### Data

[`__all__`](#nemo_automodel-components-models-gemma4_drafter-composite-__all__)

[`logger`](#nemo_automodel-components-models-gemma4_drafter-composite-logger)

### API

```python
class nemo_automodel.components.models.gemma4_drafter.composite.Gemma4JointOutput(
    logits: torch.Tensor,
    drafter_logits: list[torch.Tensor] = list(),
    drafter_loss_weight: float = 1.0,
    hidden_states: typing.Optional[tuple] = None,
    loss: typing.Optional[torch.Tensor] = None
)
```

Dataclass

Output of `Gemma4WithDrafter`. `logits` holds the base logits, and `drafter_logits` holds one drafter logits tensor per drafter step.
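The `list()` default in the rendered signature most likely corresponds to a per-instance empty list, because Python dataclasses reject a shared mutable `[]` default. Below is a minimal stand-in that mirrors the field names above. `ToyJointOutput` is a hypothetical name, not the library class, and tensor fields are typed `Any` so the sketch runs without PyTorch.

```python
from dataclasses import dataclass, field
from typing import Any, List, Optional


@dataclass
class ToyJointOutput:
    """Hypothetical mirror of Gemma4JointOutput's fields (tensors typed as Any)."""

    logits: Any
    # default_factory gives every instance its own empty list; a bare `= []`
    # would be rejected by dataclasses as a mutable default.
    drafter_logits: List[Any] = field(default_factory=list)
    drafter_loss_weight: float = 1.0
    hidden_states: Optional[tuple] = None
    loss: Optional[Any] = None


a = ToyJointOutput(logits="base")
b = ToyJointOutput(logits="base")
a.drafter_logits.append("step-0 logits")
print(len(a.drafter_logits), len(b.drafter_logits))  # 1 0
```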

```python
class nemo_automodel.components.models.gemma4_drafter.composite.Gemma4WithDrafter(
    base: torch.nn.Module,
    drafter: torch.nn.Module,
    drafter_loss_weight: float = 1.0,
    drafter_num_steps: int = 1,
    freeze_base_for_drafter: bool = False,
    share_embedding_with_base: bool = False,
    base_activation_checkpointing: bool = False
)
```

**Bases:** `Module`, [HFCheckpointingMixin](/nemo-automodel/nemo_automodel/components/models/common/hf_checkpointing_mixin#nemo_automodel-components-models-common-hf_checkpointing_mixin-HFCheckpointingMixin)

Composite model that wraps a Gemma 4 base + its released drafter.

Both sub-modules are loaded via NeMo's `NeMoAutoModel*` paths so they
receive the standard distributed infrastructure (FSDP2 sharding, freeze
config, checkpoint loading, kernel patches, ...) independently. The
composite is a thin `nn.Module` that owns both and exposes a joint
forward and a `save_pretrained` that writes the pair as two HF-format
sub-directories (`base/` and `drafter/`).

**Parameters:**

* `base`: Loaded base model (typically a `Gemma4ForConditionalGeneration`
  instance returned by `NeMoAutoModelForImageTextToText.from_pretrained`).
* `drafter`: Loaded drafter (a `Gemma4DrafterForCausalLM` instance
  returned by `NeMoAutoModelForCausalLM.from_pretrained`).
* `drafter_loss_weight`: Multiplier `lambda` applied to the drafter loss
  in the recipe.
* `drafter_num_steps`: Number of recurrent drafter steps K to run per
  training batch. With K = 1 the composite is the EAGLE-1-style
  single-step setup. With K > 1 the drafter runs autoregressively for K
  rounds, feeding its previous-round `last_hidden_state` (already
  post-projected to `H_b`) and a teacher-forced shifted token id back into
  itself, matching the recipe in the Gemma 4 drafter blog.
  `shared_kv_states` is captured from a single base forward and reused at
  every round.
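The K-round recurrence can be sketched as a plain loop. Everything here is a toy assumption. `drafter_step` stands in for one drafter forward, hidden states are lists, and round `k` is assumed to see the input ids shifted left by `k` positions with a padded tail. The composite's real offsets and padding may differ. The sketch only shows the control flow: the base runs once, its `shared_kv_states` are reused unchanged, and each round feeds back the previous round's hidden state.

```python
from typing import Any, Callable, List, Sequence, Tuple

Hidden = List[List[float]]
# (shifted ids, hidden state from previous round, shared KV) -> (logits, new hidden)
DrafterStep = Callable[[List[int], Hidden, Any], Tuple[Any, Hidden]]


def shift_left(ids: Sequence[int], k: int, pad_id: int = 0) -> List[int]:
    """Teacher forcing (assumed offset): position t sees ids[t + k]; the tail is padded."""
    k = min(k, len(ids))
    return list(ids[k:]) + [pad_id] * k


def run_drafter_rounds(
    input_ids: Sequence[int],
    base_hidden: Hidden,
    shared_kv_states: Any,
    drafter_step: DrafterStep,
    num_steps: int,
    pad_id: int = 0,
) -> List[Any]:
    """Run K = num_steps drafter rounds and return one logits entry per round."""
    if num_steps < 1:
        raise ValueError("num_steps (K) must be >= 1")
    hidden = base_hidden  # round 0 consumes the base's final hidden state
    per_step_logits = []
    for k in range(num_steps):
        ids_k = shift_left(input_ids, k, pad_id)
        # shared_kv_states comes from the single base forward and is reused as-is.
        logits, hidden = drafter_step(ids_k, hidden, shared_kv_states)
        per_step_logits.append(logits)
    return per_step_logits


def toy_step(ids, hidden, kv):
    # Echo the ids as "logits" and nudge the hidden state so rounds are visible.
    return list(ids), [[x + 1.0 for x in row] for row in hidden]


steps = run_drafter_rounds([5, 6, 7, 8], [[0.0]] * 4, None, toy_step, num_steps=3)
print(steps)  # [[5, 6, 7, 8], [6, 7, 8, 0], [7, 8, 0, 0]]
```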

```python
nemo_automodel.components.models.gemma4_drafter.composite.Gemma4WithDrafter._get_base_text_config(
    base: torch.nn.Module
)
```

staticmethod

```python
nemo_automodel.components.models.gemma4_drafter.composite.Gemma4WithDrafter.forward(
    input_ids: typing.Optional[torch.Tensor] = None,
    attention_mask: typing.Optional[torch.Tensor] = None,
    position_ids: typing.Optional[torch.Tensor] = None,
    **kwargs: typing.Any
) -> nemo_automodel.components.models.gemma4_drafter.composite.Gemma4JointOutput
```

Joint forward: base first, then drafter consuming the base's outputs.

Any extra kwargs (`pixel_values`, `mm_token_type_ids`,
`pixel_values_videos`, `input_features`, ...) are passed straight
through to the base. Multimodal kwargs are *not* forwarded to the
drafter (the drafter is text-only).
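The routing rule can be sketched as a filter over the extra kwargs. `MULTIMODAL_KEYS` contains only the examples named above. The docstring's list ends with "...", so the real set is longer. `split_forward_kwargs` is a hypothetical helper. It only shows that the base sees every extra kwarg while multimodal keys are stripped from anything destined for the drafter. It does not specify which non-multimodal extras the drafter actually consumes.

```python
from typing import Any, Dict, FrozenSet, Tuple

# Only the keys named in the docstring; the real list is longer ("...").
MULTIMODAL_KEYS: FrozenSet[str] = frozenset(
    {"pixel_values", "mm_token_type_ids", "pixel_values_videos", "input_features"}
)


def split_forward_kwargs(extra: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Return (kwargs for the base, the same kwargs with multimodal inputs removed)."""
    base_kwargs = dict(extra)  # the base receives everything unchanged
    text_only = {k: v for k, v in extra.items() if k not in MULTIMODAL_KEYS}
    return base_kwargs, text_only


base_kw, text_kw = split_forward_kwargs({"pixel_values": "img", "example_flag": False})
print(sorted(base_kw), sorted(text_kw))  # ['example_flag', 'pixel_values'] ['example_flag']
```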

```python
nemo_automodel.components.models.gemma4_drafter.composite.Gemma4WithDrafter.from_pretrained(
    base_path: typing.Optional[str] = None,
    drafter_path: typing.Optional[str] = None,
    pretrained_model_name_or_path: typing.Optional[str] = None,
    drafter_loss_weight: float = 1.0,
    drafter_num_steps: int = 1,
    freeze_base_for_drafter: bool = False,
    share_embedding_with_base: bool = False,
    base_activation_checkpointing: bool = False,
    torch_dtype: typing.Any = None,
    attn_implementation: typing.Optional[str] = None,
    use_liger_kernel: typing.Optional[bool] = None,
    use_sdpa_patching: typing.Optional[bool] = None,
    text_config: typing.Optional[dict] = None,
    peft_config: typing.Any = None,
    device_mesh: typing.Any = None,
    moe_mesh: typing.Any = None,
    distributed_config: typing.Any = None,
    pipeline_config: typing.Any = None,
    distributed_setup: typing.Any = None,
    freeze_config: typing.Any = None,
    cache_dir: typing.Optional[str] = None,
    **kwargs
) -> 'Gemma4WithDrafter'
```

classmethod

Build the composite by loading base and drafter via the NeMoAuto paths.

**Parameters:**

* `base_path`: HF repo id or local path of the Gemma 4 base model.
* `drafter_path`: HF repo id or local path of the released drafter.
* `pretrained_model_name_or_path`: Alias for `base_path`. Kept so that
  YAML configs can set `pretrained_model_name_or_path` and the recipe's
  processor and checkpoint-config helpers (which read this key from the
  model config) keep working.
* `drafter_loss_weight`: `lambda` multiplier on the drafter loss.
* `drafter_num_steps`: Number of recurrent drafter steps K per batch.
  `K = 1` is EAGLE-1-style single-step training. `K > 1` matches the
  Gemma 4 drafter blog's multi-token-prediction (MTP) training recipe, in
  which the drafter consumes its previous round's post-projected hidden
  state plus a teacher-forced shifted token id at every subsequent round.
* `freeze_base_for_drafter`: If True, freeze all base parameters so only
  the drafter is trained (the drafter-only sub-case). Defaults to False
  (joint training).
* `share_embedding_with_base`: If True, copy the base's input embedding
  into the drafter's `embed_tokens` once at init. Because the drafter's
  `lm_head` is tied to its own `embed_tokens`, its rows also start
  aligned with the base. After init, the two embeddings evolve as
  independent parameters during training.
* `base_activation_checkpointing`: If True, enable HF gradient
  checkpointing on the base to reduce activation memory. This is
  important for the 4B + drafter + long-context setting.
* `torch_dtype`: dtype to use for both sub-models. Must be
  `torch.bfloat16`, because the drafter is bf16-only.
* `attn_implementation`: Forwarded to both sub-loads.
* `use_liger_kernel`: Forwarded to both sub-loads.
* `use_sdpa_patching`: Forwarded to both sub-loads.
* `text_config`: Optional overrides forwarded to the base load.
* `peft_config`: PEFT config. Currently expected to be `None`, because
  joint drafter PEFT is out of scope for the initial recipe.
* `device_mesh`: Distributed device mesh shared by base and drafter.
* `moe_mesh`: MoE mesh shared by base and drafter (the drafter is dense).
* `distributed_config`: FSDP2 / Megatron-FSDP / DDP config object.
* `pipeline_config`: Must be `None`; pipeline parallelism is not
  supported when the drafter is attached.
* `distributed_setup`: Resolved `DistributedSetup` (topology + policy)
  shared by base and drafter. This is the path used by the VLM finetune
  recipe; its `pp_size` and `cp_size` must be `1`.
* `freeze_config`: Forwarded to the base only (the drafter is trained
  end-to-end). To customize the drafter's freezing, make explicit
  `requires_grad_` calls on the returned composite if needed.
* `cache_dir`: HuggingFace cache directory.
* `kwargs`: Additional kwargs forwarded to both sub-loads.

**Returns:** `Gemma4WithDrafter`

An instantiated `Gemma4WithDrafter`.
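The constraints documented for `from_pretrained` can be collected into a pre-flight check on a plain dict, such as a parsed YAML model section. `check_drafter_config` is a hypothetical helper, not part of NeMo AutoModel. Dtypes are compared as strings so the sketch runs without PyTorch, and `distributed_setup` is assumed to be a dict. The helper treats `base_path` and `pretrained_model_name_or_path` as interchangeable and rejects conflicting values. That conflict handling is an assumption; the real method may resolve conflicts differently.

```python
from typing import Any, Dict, List


def check_drafter_config(cfg: Dict[str, Any]) -> List[str]:
    """Return a list of problems; an empty list means the config passes."""
    problems = []
    base = cfg.get("base_path")
    alias = cfg.get("pretrained_model_name_or_path")
    if base is None and alias is None:
        problems.append("set base_path or its alias pretrained_model_name_or_path")
    elif base is not None and alias is not None and base != alias:
        problems.append("base_path and pretrained_model_name_or_path disagree")
    dtype = cfg.get("torch_dtype")
    # str(torch.bfloat16) == "torch.bfloat16", so strip the prefix before comparing.
    if dtype is not None and str(dtype).replace("torch.", "") != "bfloat16":
        problems.append("torch_dtype must be bfloat16 (the drafter is bf16-only)")
    if cfg.get("drafter_num_steps", 1) < 1:
        problems.append("drafter_num_steps (K) must be >= 1")
    if cfg.get("peft_config") is not None:
        problems.append("peft_config is expected to be None")
    if cfg.get("pipeline_config") is not None:
        problems.append("pipeline parallelism is not supported with the drafter")
    setup = cfg.get("distributed_setup") or {}
    for key in ("pp_size", "cp_size"):
        if setup.get(key, 1) != 1:
            problems.append(f"distributed_setup.{key} must be 1")
    return problems


cfg = {
    "pretrained_model_name_or_path": "path/to/gemma4-base",  # placeholder path
    "drafter_path": "path/to/drafter",  # placeholder path
    "torch_dtype": "torch.bfloat16",
    "drafter_num_steps": 2,
}
print(check_drafter_config(cfg))  # []
```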

```python
nemo_automodel.components.models.gemma4_drafter.composite.Gemma4WithDrafter.get_input_embeddings() -> torch.nn.Module
```

```python
nemo_automodel.components.models.gemma4_drafter.composite.Gemma4WithDrafter.get_output_embeddings() -> torch.nn.Module
```

```python
nemo_automodel.components.models.gemma4_drafter.composite.Gemma4WithDrafter.get_rope_index(
    *args,
    **kwargs
)
```

```python
nemo_automodel.components.models.gemma4_drafter.composite.Gemma4WithDrafter.load_pretrained(
    load_directory: str,
    checkpointer: typing.Optional['Checkpointer'] = None,
    **kwargs
) -> None
```

Load weights from the two-subdir layout written by `save_pretrained`.

Mirrors the save side: reads `<load_directory>/base/model` and
`<load_directory>/drafter/model` (the standard `Checkpointer.save_model`
output layout) and routes them to `self.base` and `self.drafter`
respectively. Used by the recipe's resume path when a checkpoint
directory was produced by this composite.

**Parameters:**

* `load_directory`: A checkpoint directory containing `base/` and
  `drafter/` sub-directories (e.g. `<ckpt_dir>/epoch_X_step_Y`).
* `checkpointer`: The recipe's `Checkpointer` instance.
* `kwargs`: Reserved; ignored.
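Before resuming, you can check the layout that `load_pretrained` expects. The minimal sketch below uses only `pathlib`. `composite_model_dirs` is a hypothetical helper. It checks only that the two `model` sub-directories named above exist, not what they contain.

```python
import tempfile
from pathlib import Path
from typing import Dict, Union


def composite_model_dirs(load_directory: Union[str, Path]) -> Dict[str, Path]:
    """Map each sub-model to <load_directory>/<name>/model, failing if one is missing."""
    root = Path(load_directory)
    dirs = {name: root / name / "model" for name in ("base", "drafter")}
    missing = [str(p) for p in dirs.values() if not p.is_dir()]
    if missing:
        raise FileNotFoundError(f"not a composite checkpoint, missing: {missing}")
    return dirs


with tempfile.TemporaryDirectory() as ckpt:
    step_dir = Path(ckpt) / "epoch_0_step_100"  # mimics <ckpt_dir>/epoch_X_step_Y
    for name in ("base", "drafter"):
        (step_dir / name / "model").mkdir(parents=True)
    print(sorted(composite_model_dirs(step_dir)))  # ['base', 'drafter']
```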

```python
nemo_automodel.components.models.gemma4_drafter.composite.Gemma4WithDrafter.save_pretrained(
    save_directory: str,
    checkpointer: typing.Optional['Checkpointer'] = None,
    tokenizer: typing.Any = None,
    **kwargs
) -> None
```

Save base and drafter as two HF-format sub-directories.

Produces `<save_directory>/base/` and `<save_directory>/drafter/`
with HF-compatible artifacts. Each side can later be loaded back
independently with HF `from_pretrained`, which keeps the output compatible with vLLM.

```python
nemo_automodel.components.models.gemma4_drafter.composite.__all__ = ['Gemma4JointOutput', 'Gemma4WithDrafter']
```

```python
nemo_automodel.components.models.gemma4_drafter.composite.logger = logging.getLogger(__name__)
```