nemo_automodel.components.models.gemma4_drafter.composite


Composite model for joint fine-tuning of a Gemma 4 base + its drafter.

The composite orchestrates a forward pass that:

  1. Runs the base Gemma4ForConditionalGeneration with return_shared_kv_states=True and output_hidden_states=True.
  2. Builds the drafter’s inputs_embeds by concatenating the (already sqrt(H_b)-scaled) base token embeddings with the base’s final hidden state along the feature axis.
  3. Runs the drafter Gemma4AssistantForCausalLM with the captured shared_kv_states and the concatenated embeddings.
  4. Returns a :class:Gemma4JointOutput that exposes both base logits and a per-step list of drafter logits so the training recipe can compute L = L_base + drafter_loss_weight * sum_k L_drafter_k.
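The loss in step 4 can be sketched in a few lines. This is a minimal illustration using plain floats, not the recipe's actual loss code; the function name `joint_loss` is hypothetical, and the per-step drafter losses are assumed to have been computed already.

```python
def joint_loss(base_loss, drafter_losses, drafter_loss_weight=1.0):
    """Combine a base loss with per-step drafter losses.

    Implements L = L_base + drafter_loss_weight * sum_k L_drafter_k.
    `drafter_losses` holds one value per drafter step k.
    """
    return base_loss + drafter_loss_weight * sum(drafter_losses)


# Toy numbers: one base loss and K = 2 drafter-step losses.
total = joint_loss(base_loss=1.0, drafter_losses=[0.5, 0.25], drafter_loss_weight=2.0)
print(total)  # 2.5
```

With an empty list of drafter losses the expression reduces to the base loss alone.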

Both sub-models are trainable. Gradients from the drafter loss flow back into the base through:

  • the “store” KV layers (last non-shared layer of each layer_type) via shared_kv_states;
  • the base’s input embedding (consumed by the drafter’s first projection);
  • the base’s final hidden state.
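The last two paths exist because the drafter's input is built directly from base tensors. As a shape-only sketch of that concatenation, using nested Python lists in place of tensors: the helper name `build_drafter_inputs` and the toy values are assumptions, and the scaling by sqrt(H_b) is applied by hand here only to mimic what the base embedding is described as already doing.

```python
import math


def build_drafter_inputs(scaled_token_embeds, final_hidden):
    """Concatenate base embeddings and base hidden states on the feature axis.

    Both arguments are lists of per-token vectors of length H_b.
    The result has one vector of length 2 * H_b per token.
    """
    if len(scaled_token_embeds) != len(final_hidden):
        raise ValueError("sequence lengths must match")
    return [list(emb) + list(hid) for emb, hid in zip(scaled_token_embeds, final_hidden)]


# Toy setup with H_b = 4 and two tokens (hypothetical values).
h_b = 4
raw_embeds = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
scaled_embeds = [[math.sqrt(h_b) * x for x in row] for row in raw_embeds]
hidden = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]

drafter_inputs = build_drafter_inputs(scaled_embeds, hidden)
print(len(drafter_inputs), len(drafter_inputs[0]))  # 2 8
```

Since both halves of each concatenated vector come from the base, a loss computed on the drafter's output has a gradient path into the base through each half.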

This is the EAGLE-2 / Medusa-2 style co-training pattern: the drafter stays aligned with a base that is itself moving.

Module Contents

Classes

| Name | Description |
| --- | --- |
| Gemma4JointOutput | Output of :class:Gemma4WithDrafter. |
| Gemma4WithDrafter | Composite model that wraps a Gemma 4 base + its released drafter. |

Data

__all__

logger

API

class nemo_automodel.components.models.gemma4_drafter.composite.Gemma4JointOutput(
logits: torch.Tensor,
drafter_logits: list[torch.Tensor] = list(),
drafter_loss_weight: float = 1.0,
hidden_states: typing.Optional[tuple] = None,
loss: typing.Optional[torch.Tensor] = None
)
Dataclass

Output of :class:Gemma4WithDrafter.

drafter_logits
list[Tensor] = field(default_factory=list)
drafter_loss_weight
float = 1.0
hidden_states
Optional[tuple] = None
logits
Tensor
loss
Optional[Tensor] = None
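The field listing above gives `drafter_logits` a `default_factory`, which means each output object gets its own fresh list. A stand-in dataclass with the same field names illustrates this; `JointOutputSketch` is hypothetical and uses `Any` in place of `torch.Tensor` so the sketch runs without PyTorch.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class JointOutputSketch:
    """Stand-in mirroring the documented fields of Gemma4JointOutput."""

    logits: Any
    drafter_logits: list = field(default_factory=list)  # one entry per drafter step
    drafter_loss_weight: float = 1.0
    hidden_states: Optional[tuple] = None
    loss: Optional[Any] = None


a = JointOutputSketch(logits="base-logits-a")
b = JointOutputSketch(logits="base-logits-b")
a.drafter_logits.append("step-1-logits")

# Each instance owns its own list, so `b` is unaffected by the append on `a`.
print(a.drafter_logits, b.drafter_logits)
```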
class nemo_automodel.components.models.gemma4_drafter.composite.Gemma4WithDrafter(
base: torch.nn.Module,
drafter: torch.nn.Module,
drafter_loss_weight: float = 1.0,
drafter_num_steps: int = 1,
freeze_base_for_drafter: bool = False,
share_embedding_with_base: bool = False,
base_activation_checkpointing: bool = False
)

Bases: Module, HFCheckpointingMixin

Composite model that wraps a Gemma 4 base + its released drafter.

Both sub-modules are loaded via NeMo’s NeMoAutoModel* paths so they receive the standard distributed infrastructure (FSDP2 sharding, freeze config, checkpoint loading, kernel patches, …) independently. The composite is a thin :class:nn.Module that owns both and exposes a joint forward and a save_pretrained that writes the pair as two HF-format sub-directories (base/ and drafter/).

Parameters:

base
nn.Module

Loaded base model (typically a Gemma4ForConditionalGeneration instance returned by NeMoAutoModelForImageTextToText.from_pretrained).

drafter
nn.Module

Loaded drafter (a Gemma4DrafterForCausalLM instance returned by NeMoAutoModelForCausalLM.from_pretrained).

drafter_loss_weight
float, defaults to 1.0

Multiplier lambda applied to the drafter loss in the recipe.

drafter_num_steps
int, defaults to 1

Number of recurrent drafter steps K to run per training batch. With K = 1 the composite is the EAGLE-1-style single-step setup; with K > 1 the drafter runs autoregressively for K rounds, feeding its previous-round last_hidden_state (already post-projected to H_b) and a teacher-forced shifted token id back into itself, matching the Gemma 4 drafter blog’s recipe. shared_kv_states is captured from a single base forward and reused at every round.
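The control flow for K > 1 can be sketched with a stand-in drafter step. This is an illustration of the loop shape only, not the composite's implementation: the names `run_drafter_rounds` and `toy_step`, the left shift by one position per round, and the padding id are all assumptions.

```python
def run_drafter_rounds(drafter_step, first_hidden, input_ids, shared_kv_states,
                       num_steps, pad_id=0):
    """Run `num_steps` drafter rounds and collect one logits entry per round.

    `drafter_step(token_ids, hidden, shared_kv_states)` must return
    `(logits, last_hidden_state)`. The same `shared_kv_states` object,
    captured from a single base forward, is passed to every round.
    """
    logits_per_step = []
    hidden = first_hidden
    token_ids = list(input_ids)
    for k in range(num_steps):
        if k > 0:
            # Assumption: teacher forcing shifts the ids left by one per round.
            token_ids = token_ids[1:] + [pad_id]
        logits, hidden = drafter_step(token_ids, hidden, shared_kv_states)
        logits_per_step.append(logits)
    return logits_per_step


calls = []


def toy_step(token_ids, hidden, kv):
    """Record the inputs and return placeholder logits plus a new hidden value."""
    calls.append((tuple(token_ids), hidden, kv))
    return ("logits", len(calls)), hidden + 1


outputs = run_drafter_rounds(toy_step, first_hidden=0, input_ids=[11, 12, 13],
                             shared_kv_states="kv", num_steps=3)
print(len(outputs))  # 3
```

In this toy run the second round sees the ids shifted by one position and the hidden value produced by the first round, while the `"kv"` placeholder is identical across all three rounds.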

base_activation_checkpointing
= bool(base_activation_checkpointing)
drafter_loss_weight
= float(drafter_loss_weight)
drafter_num_steps
= int(drafter_num_steps)
freeze_base_for_drafter
= bool(freeze_base_for_drafter)
share_embedding_with_base
= bool(share_embedding_with_base)
nemo_automodel.components.models.gemma4_drafter.composite.Gemma4WithDrafter._get_base_text_config(
base: torch.nn.Module
)
staticmethod
nemo_automodel.components.models.gemma4_drafter.composite.Gemma4WithDrafter.forward(
input_ids: typing.Optional[torch.Tensor] = None,
attention_mask: typing.Optional[torch.Tensor] = None,
position_ids: typing.Optional[torch.Tensor] = None,
**kwargs: typing.Any
) -> nemo_automodel.components.models.gemma4_drafter.composite.Gemma4JointOutput

Joint forward: base first, then drafter consuming the base’s outputs.

Any extra kwargs (pixel_values, mm_token_type_ids, pixel_values_videos, input_features, …) are passed straight through to the base. Multimodal kwargs are not forwarded to the drafter (the drafter is text-only).
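The routing described above can be pictured as a small filter. This is a hedged sketch, not the composite's code: `split_forward_kwargs` is a hypothetical helper, the key set contains only the names listed in the paragraph (the real set may be larger), and whether non-multimodal extras reach the drafter is an assumption made here for illustration.

```python
# Only the keys named in the text above; the actual list is elided there.
MULTIMODAL_KEYS = {
    "pixel_values",
    "mm_token_type_ids",
    "pixel_values_videos",
    "input_features",
}


def split_forward_kwargs(kwargs):
    """Return (base_kwargs, drafter_kwargs) for a joint forward call.

    The base receives every extra kwarg unchanged; the text-only drafter
    receives none of the multimodal ones.
    """
    base_kwargs = dict(kwargs)
    drafter_kwargs = {k: v for k, v in kwargs.items() if k not in MULTIMODAL_KEYS}
    return base_kwargs, drafter_kwargs


# "some_flag" is a made-up, non-multimodal key used only for the demo.
base_kw, drafter_kw = split_forward_kwargs(
    {"pixel_values": "image-batch", "some_flag": True}
)
print(sorted(base_kw), sorted(drafter_kw))
```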

nemo_automodel.components.models.gemma4_drafter.composite.Gemma4WithDrafter.from_pretrained(
base_path: typing.Optional[str] = None,
drafter_path: typing.Optional[str] = None,
pretrained_model_name_or_path: typing.Optional[str] = None,
drafter_loss_weight: float = 1.0,
drafter_num_steps: int = 1,
freeze_base_for_drafter: bool = False,
share_embedding_with_base: bool = False,
base_activation_checkpointing: bool = False,
torch_dtype: typing.Any = None,
attn_implementation: typing.Optional[str] = None,
use_liger_kernel: typing.Optional[bool] = None,
use_sdpa_patching: typing.Optional[bool] = None,
text_config: typing.Optional[dict] = None,
peft_config: typing.Any = None,
device_mesh: typing.Any = None,
moe_mesh: typing.Any = None,
distributed_config: typing.Any = None,
pipeline_config: typing.Any = None,
distributed_setup: typing.Any = None,
freeze_config: typing.Any = None,
cache_dir: typing.Optional[str] = None,
**kwargs
) -> 'Gemma4WithDrafter'
classmethod

Build the composite by loading base and drafter via the NeMoAuto paths.

Parameters:

base_path
Optional[str], defaults to None

HF repo id or local path of the Gemma 4 base model.

drafter_path
Optional[str], defaults to None

HF repo id or local path of the released drafter.

pretrained_model_name_or_path
Optional[str], defaults to None

Alias for base_path. Kept so that YAML configs can set pretrained_model_name_or_path and have the recipe’s processor / checkpoint-config helpers (which read this key from the model config) keep working.

drafter_loss_weight
float, defaults to 1.0

Multiplier lambda applied to the drafter loss.

drafter_num_steps
int, defaults to 1

Number of recurrent drafter steps K per batch. K = 1 is EAGLE-1-style single-step; K > 1 matches the Gemma 4 drafter blog’s multi-token-prediction (MTP) training recipe — the drafter consumes its previous round’s post-projected hidden state plus a teacher-forced shifted token id at every subsequent round.

freeze_base_for_drafter
bool, defaults to False

If True, freeze all base parameters so only the drafter is trained (drafter-only sub-case). Default False (joint training).

share_embedding_with_base
bool, defaults to False

If True, copy the base’s input embedding into the drafter’s embed_tokens once at init. The drafter’s lm_head is tied to its own embed_tokens so the row weights start aligned with the base too. The two embeddings then evolve as independent parameters during training.

base_activation_checkpointing
bool, defaults to False

If True, enable HF gradient checkpointing on the base to reduce activation memory. Important for the 4B + drafter + long-context setting.

torch_dtype
Any, defaults to None

dtype to use for both sub-models. Must be torch.bfloat16 — the drafter is bf16-only.

attn_implementation
Optional[str], defaults to None

Forwarded to both sub-loads.

use_liger_kernel
Optional[bool], defaults to None

Forwarded to both sub-loads.

use_sdpa_patching
Optional[bool], defaults to None

Forwarded to both sub-loads.

text_config
Optional[dict], defaults to None

Optional overrides forwarded to the base load.

peft_config
Any, defaults to None

PEFT config (currently expected to be None — joint drafter PEFT is out of scope for the initial recipe).

device_mesh
Any, defaults to None

Distributed device mesh shared by base and drafter.

moe_mesh
Any, defaults to None

MoE mesh shared by base and drafter (drafter is dense).

distributed_config
Any, defaults to None

FSDP2 / Megatron-FSDP / DDP config object.

pipeline_config
Any, defaults to None

Must be None — pipeline parallelism is not supported when the drafter is attached.

distributed_setup
Any, defaults to None

Resolved DistributedSetup (topology + policy) shared by base and drafter. This is the path used by the VLM finetune recipe; its pp_size and cp_size must be 1.

freeze_config
Any, defaults to None

Forwarded to the base only (the drafter is trained end-to-end). Customize the drafter’s freezing with explicit requires_grad_ calls on the returned composite if needed.

cache_dir
Optional[str], defaults to None

HuggingFace cache directory.

**kwargs
Defaults to {}

Additional kwargs forwarded to both sub-loads.

Returns: 'Gemma4WithDrafter'

An instantiated :class:Gemma4WithDrafter.
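Several of the parameters above carry hard constraints (bf16 only, no pipeline parallelism, no PEFT, pp_size and cp_size of 1). The following sketch collects them in a pre-flight check and shows one way a call might look. Both function names are hypothetical, the check is not part of the library, and `load_composite` is defined but not executed here because it needs `torch`, `nemo_automodel`, and real model paths.

```python
def check_composite_args(torch_dtype, pipeline_config=None, peft_config=None,
                         pp_size=1, cp_size=1):
    """Return a list of problems based on the documented constraints."""
    problems = []
    # str(torch.bfloat16) is "torch.bfloat16"; a bare string is accepted for the demo.
    if str(torch_dtype) not in ("torch.bfloat16", "bfloat16"):
        problems.append("torch_dtype must be torch.bfloat16")
    if pipeline_config is not None:
        problems.append("pipeline_config must be None")
    if peft_config is not None:
        problems.append("peft_config is expected to be None")
    if pp_size != 1 or cp_size != 1:
        problems.append("pp_size and cp_size must be 1")
    return problems


def load_composite(base_path, drafter_path):
    """Illustrative call only; imports are deferred so the sketch runs anywhere."""
    import torch
    from nemo_automodel.components.models.gemma4_drafter.composite import (
        Gemma4WithDrafter,
    )

    return Gemma4WithDrafter.from_pretrained(
        base_path=base_path,        # HF repo id or local path of the base
        drafter_path=drafter_path,  # HF repo id or local path of the drafter
        torch_dtype=torch.bfloat16,
        drafter_loss_weight=1.0,
        drafter_num_steps=1,
    )


print(check_composite_args("bfloat16"))  # []
```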

nemo_automodel.components.models.gemma4_drafter.composite.Gemma4WithDrafter.get_input_embeddings() -> torch.nn.Module
nemo_automodel.components.models.gemma4_drafter.composite.Gemma4WithDrafter.get_output_embeddings() -> torch.nn.Module
nemo_automodel.components.models.gemma4_drafter.composite.Gemma4WithDrafter.get_rope_index(
*args,
**kwargs
)
nemo_automodel.components.models.gemma4_drafter.composite.Gemma4WithDrafter.load_pretrained(
load_directory: str,
checkpointer: typing.Optional['Checkpointer'] = None,
**kwargs
) -> None

Load weights from the two-subdir layout written by save_pretrained.

Mirrors the save side: reads <load_directory>/base/model and <load_directory>/drafter/model (the standard Checkpointer.save_model output layout) and routes them to self.base and self.drafter respectively. Used by the recipe’s resume path when a checkpoint directory was produced by this composite.

Parameters:

load_directory
str

A checkpoint directory containing base/ and drafter/ sub-directories (e.g. <ckpt_dir>/epoch_X_step_Y).

checkpointer
Optional['Checkpointer'], defaults to None

The recipe’s :class:Checkpointer instance.

**kwargs
Defaults to {}

Reserved; ignored.

nemo_automodel.components.models.gemma4_drafter.composite.Gemma4WithDrafter.save_pretrained(
save_directory: str,
checkpointer: typing.Optional['Checkpointer'] = None,
tokenizer: typing.Any = None,
**kwargs
) -> None

Save base and drafter as two HF-format sub-directories.

Produces <save_directory>/base/ and <save_directory>/drafter/ with HF-compatible artifacts. Each side can later be loaded back by HF from_pretrained independently (vLLM compatibility).
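The directory layout shared by `save_pretrained` and `load_pretrained` can be written down as a small path helper. `composite_layout` is a hypothetical function for illustration, and the checkpoint directory name in the example is made up.

```python
from pathlib import PurePosixPath


def composite_layout(directory):
    """Map a checkpoint directory to the sub-paths described in this section."""
    root = PurePosixPath(directory)
    return {
        "base": root / "base",                          # written by save_pretrained
        "drafter": root / "drafter",                    # written by save_pretrained
        "base_weights": root / "base" / "model",        # read by load_pretrained
        "drafter_weights": root / "drafter" / "model",  # read by load_pretrained
    }


layout = composite_layout("ckpt/epoch_0_step_100")
for name, path in layout.items():
    print(f"{name}: {path}")
```

Keeping the two sides in separate sub-directories is what lets each one be loaded on its own later.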

nemo_automodel.components.models.gemma4_drafter.composite.__all__ = ['Gemma4JointOutput', 'Gemma4WithDrafter']
nemo_automodel.components.models.gemma4_drafter.composite.logger = logging.getLogger(__name__)