nemo_automodel.components.models.gemma4_drafter.composite#

Composite model for joint fine-tuning of a Gemma 4 base + its drafter.

The composite orchestrates a forward pass that:

  1. Runs the base Gemma4ForConditionalGeneration with return_shared_kv_states=True and output_hidden_states=True.

  2. Builds the drafter’s inputs_embeds by concatenating the (already sqrt(H_b)-scaled) base token embeddings with the base’s final hidden state along the feature axis.

  3. Runs the drafter Gemma4AssistantForCausalLM with the captured shared_kv_states and the concatenated embeddings.

  4. Returns a Gemma4JointOutput that exposes both the base logits and a per-step list of drafter logits, so the training recipe can compute L = L_base + drafter_loss_weight * sum_k L_drafter_k.
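Step 2 of the list above can be illustrated with a shape-only sketch. It uses nested Python lists in place of tensors, and the helper names `concat_features` and `shape` are hypothetical, not part of this module. With real tensors the same operation would typically be a `torch.cat([...], dim=-1)`; whether the composite does it exactly that way is an assumption.

```python
from typing import List

# Nested lists standing in for a [B, S, H] tensor (illustration only).
Tensor3D = List[List[List[float]]]


def concat_features(token_embeds: Tensor3D, final_hidden: Tensor3D) -> Tensor3D:
    """Concatenate two [B, S, H_b] arrays along the last (feature) axis."""
    if len(token_embeds) != len(final_hidden):
        raise ValueError("batch sizes differ")
    out: Tensor3D = []
    for emb_seq, hid_seq in zip(token_embeds, final_hidden):
        if len(emb_seq) != len(hid_seq):
            raise ValueError("sequence lengths differ")
        # List addition appends the hidden-state features after the embedding features.
        out.append([emb + hid for emb, hid in zip(emb_seq, hid_seq)])
    return out


def shape(x: Tensor3D) -> tuple:
    """Return (B, S, H) for a non-empty nested list."""
    return (len(x), len(x[0]), len(x[0][0]))


B, S, H_b = 2, 3, 4
token_embeds = [[[1.0] * H_b for _ in range(S)] for _ in range(B)]
final_hidden = [[[0.5] * H_b for _ in range(S)] for _ in range(B)]

inputs_embeds = concat_features(token_embeds, final_hidden)
print(shape(inputs_embeds))  # (2, 3, 8): the feature width doubles to 2 * H_b
```

Because both inputs are H_b wide, the drafter's first projection sees 2 * H_b features per position.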

Both sub-models are trainable. Gradients from the drafter loss flow back into the base through:

  • the “store” KV layers (last non-shared layer of each layer_type) via shared_kv_states;

  • the base’s input embedding (consumed by the drafter’s first projection);

  • the base’s final hidden state.

This is the EAGLE-2 / Medusa-2 style co-training pattern: the drafter stays aligned with a base that is itself moving.

Module Contents#

Classes#

Gemma4JointOutput

Output of Gemma4WithDrafter.

Gemma4WithDrafter

Composite model that wraps a Gemma 4 base + its released drafter.

Data#

API#

nemo_automodel.components.models.gemma4_drafter.composite.logger#

‘getLogger(…)’

class nemo_automodel.components.models.gemma4_drafter.composite.Gemma4JointOutput#

Output of Gemma4WithDrafter.

Attributes:
  • logits – Base model logits [B, S, V].

  • drafter_logits – Per-step list of drafter logits, each [B, S, V]. For the default single-step recurrent cell this list has length 1.

  • drafter_loss_weight – lambda multiplier the recipe applies to the drafter loss when summing it with the base loss.

  • hidden_states – Optional list of base hidden states (mirrors HF).

  • loss – Placeholder, populated by the recipe if needed.

logits: torch.Tensor#

None

drafter_logits: list[torch.Tensor]#

‘field(…)’

drafter_loss_weight: float#

1.0

hidden_states: Optional[tuple]#

None

loss: Optional[torch.Tensor]#

None
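As a minimal sketch of how a recipe might fill the loss placeholder, the function below applies the formula L = L_base + drafter_loss_weight * sum_k L_drafter_k to plain Python floats. The name `combine_losses` is hypothetical. In a real recipe the per-step losses would be computed from logits and from each entry of drafter_logits, and the exact loss function used there is not specified here.

```python
from typing import Sequence


def combine_losses(
    base_loss: float,
    drafter_losses: Sequence[float],
    drafter_loss_weight: float = 1.0,
) -> float:
    """Return L_base + drafter_loss_weight * sum_k L_drafter_k."""
    return base_loss + drafter_loss_weight * sum(drafter_losses)


# One drafter loss per recurrent step; with drafter_num_steps = 1 the list has one entry.
total = combine_losses(base_loss=2.0, drafter_losses=[1.0, 0.5], drafter_loss_weight=0.5)
print(total)  # 2.75
```

With drafter_loss_weight set to 0 the joint loss reduces to the base loss alone.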

class nemo_automodel.components.models.gemma4_drafter.composite.Gemma4WithDrafter(
base: torch.nn.Module,
drafter: torch.nn.Module,
*,
drafter_loss_weight: float = 1.0,
drafter_num_steps: int = 1,
freeze_base_for_drafter: bool = False,
share_embedding_with_base: bool = False,
base_activation_checkpointing: bool = False,
)#

Bases: torch.nn.Module, nemo_automodel.components.models.common.hf_checkpointing_mixin.HFCheckpointingMixin

Composite model that wraps a Gemma 4 base + its released drafter.

Both sub-modules are loaded via NeMo’s NeMoAutoModel* paths so they receive the standard distributed infrastructure (FSDP2 sharding, freeze config, checkpoint loading, kernel patches, …) independently. The composite is a thin nn.Module that owns both and exposes a joint forward and a save_pretrained that writes the pair as two HF-format sub-directories (base/ and drafter/).

Parameters:
  • base – Loaded base model (typically a Gemma4ForConditionalGeneration instance returned by NeMoAutoModelForImageTextToText.from_pretrained).

  • drafter – Loaded drafter (a Gemma4DrafterForCausalLM instance returned by NeMoAutoModelForCausalLM.from_pretrained).

  • drafter_loss_weight – Multiplier lambda applied to the drafter loss in the recipe.

  • drafter_num_steps – Number of recurrent drafter steps K to run per training batch. With K = 1 the composite is the EAGLE-1-style single-step setup; with K > 1 the drafter runs autoregressively for K rounds, feeding its previous-round last_hidden_state (already post-projected to H_b) and a teacher-forced shifted token id back into itself, matching the Gemma 4 drafter blog’s recipe. shared_kv_states is captured from a single base forward and reused at every round.

Initialization

supports_gradient_checkpointing#

True

static _get_base_text_config(base: torch.nn.Module)#
classmethod from_pretrained(
base_path: Optional[str] = None,
drafter_path: Optional[str] = None,
*,
pretrained_model_name_or_path: Optional[str] = None,
drafter_loss_weight: float = 1.0,
drafter_num_steps: int = 1,
freeze_base_for_drafter: bool = False,
share_embedding_with_base: bool = False,
base_activation_checkpointing: bool = False,
torch_dtype: Any = None,
attn_implementation: Optional[str] = None,
use_liger_kernel: Optional[bool] = None,
use_sdpa_patching: Optional[bool] = None,
text_config: Optional[dict] = None,
peft_config: Any = None,
device_mesh: Any = None,
moe_mesh: Any = None,
distributed_config: Any = None,
pipeline_config: Any = None,
freeze_config: Any = None,
cache_dir: Optional[str] = None,
**kwargs,
) nemo_automodel.components.models.gemma4_drafter.composite.Gemma4WithDrafter#

Build the composite by loading base and drafter via the NeMoAuto paths.

Parameters:
  • base_path – HF repo id or local path of the Gemma 4 base model.

  • drafter_path – HF repo id or local path of the released drafter.

  • pretrained_model_name_or_path – Alias for base_path. Kept so that YAML configs can set pretrained_model_name_or_path and have the recipe’s processor / checkpoint-config helpers (which read this key from the model config) keep working.

  • drafter_loss_weight – lambda multiplier on the drafter loss.

  • drafter_num_steps – Number of recurrent drafter steps K per batch. K = 1 is EAGLE-1-style single-step; K > 1 matches the Gemma 4 drafter blog’s multi-token-prediction (MTP) training recipe – the drafter consumes its previous round’s post-projected hidden state plus a teacher-forced shifted token id at every subsequent round.

  • freeze_base_for_drafter – If True, freeze all base parameters so only the drafter is trained (drafter-only sub-case). Default False (joint training).

  • share_embedding_with_base – If True, copy the base’s input embedding into the drafter’s embed_tokens once at init. The drafter’s lm_head is tied to its own embed_tokens so the row weights start aligned with the base too. The two embeddings then evolve as independent parameters during training.

  • base_activation_checkpointing – If True, enable HF gradient checkpointing on the base to reduce activation memory. This matters most when training the 4B base together with the drafter on long contexts.

  • torch_dtype – dtype to use for both sub-models. Must be torch.bfloat16 – the drafter is bf16-only.

  • attn_implementation – Forwarded to both sub-loads.

  • use_liger_kernel – Forwarded to both sub-loads.

  • use_sdpa_patching – Forwarded to both sub-loads.

  • text_config – Optional overrides forwarded to the base load.

  • peft_config – PEFT config (currently expected to be None – joint drafter PEFT is out of scope for the initial recipe).

  • device_mesh – Distributed device mesh shared by base and drafter.

  • moe_mesh – MoE mesh shared by base and drafter (drafter is dense).

  • distributed_config – FSDP2 / Megatron-FSDP / DDP config object.

  • pipeline_config – Must be None – pipeline parallelism is not supported when the drafter is attached.

  • freeze_config – Forwarded to the base only (the drafter is trained end-to-end). Customize the drafter’s freezing with explicit requires_grad_ calls on the returned composite if needed.

  • cache_dir – HuggingFace cache directory.

  • **kwargs – Additional kwargs forwarded to both sub-loads.

Returns:

An instantiated Gemma4WithDrafter.
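A call to this classmethod might look like the sketch below. It only uses keyword arguments from the signature above. The helper names `composite_options` and `build_composite` are hypothetical, and the two paths are left as arguments because no specific repo ids are given here. Running `build_composite` assumes an environment with `torch` and `nemo_automodel` installed, so the imports are deferred into the function body.

```python
def composite_options() -> dict:
    """Keyword arguments for from_pretrained, following the documented constraints."""
    return {
        "drafter_loss_weight": 1.0,
        "drafter_num_steps": 1,                 # K = 1: single-step setup
        "freeze_base_for_drafter": False,       # joint training (the default)
        "base_activation_checkpointing": True,  # trade compute for activation memory
        "peft_config": None,                    # documented as currently expected to be None
        "pipeline_config": None,                # documented as required to be None
    }


def build_composite(base_path: str, drafter_path: str):
    """Load base and drafter into one composite (needs torch and nemo_automodel)."""
    import torch
    from nemo_automodel.components.models.gemma4_drafter.composite import (
        Gemma4WithDrafter,
    )

    return Gemma4WithDrafter.from_pretrained(
        base_path=base_path,
        drafter_path=drafter_path,
        torch_dtype=torch.bfloat16,  # documented as the only supported dtype
        **composite_options(),
    )
```

Distributed arguments such as device_mesh and distributed_config are omitted here; a multi-GPU recipe would pass them through the same call.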

forward(
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
**kwargs: Any,
) nemo_automodel.components.models.gemma4_drafter.composite.Gemma4JointOutput#

Joint forward: base first, then drafter consuming the base’s outputs.

Any extra kwargs (pixel_values, mm_token_type_ids, pixel_values_videos, input_features, …) are passed straight through to the base. Multimodal kwargs are not forwarded to the drafter (the drafter is text-only).

get_input_embeddings() torch.nn.Module#
get_output_embeddings() torch.nn.Module#
property config#
property vision_tower#
property audio_tower#
property language_model#
get_rope_index(*args, **kwargs)#
save_pretrained(
save_directory: str,
checkpointer: Optional[nemo_automodel.components.checkpoint.checkpointing.Checkpointer] = None,
tokenizer: Any = None,
**kwargs,
) None#

Save base and drafter as two HF-format sub-directories.

Produces <save_directory>/base/ and <save_directory>/drafter/ with HF-compatible artifacts. Each side can later be loaded back by HF from_pretrained independently (vLLM compatibility).

load_pretrained(
load_directory: str,
checkpointer: Optional[nemo_automodel.components.checkpoint.checkpointing.Checkpointer] = None,
**kwargs,
) None#

Load weights from the two-subdir layout written by save_pretrained.

Mirrors the save side: reads <load_directory>/base/model and <load_directory>/drafter/model (the standard Checkpointer.save_model output layout) and routes them to self.base and self.drafter respectively. Used by the recipe’s resume path when a checkpoint directory was produced by this composite.

Parameters:
  • load_directory – A checkpoint directory containing base/ and drafter/ sub-directories (e.g. <ckpt_dir>/epoch_X_step_Y).

  • checkpointer – The recipe’s Checkpointer instance.

  • **kwargs – Reserved; ignored.
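Before resuming, it can be handy to check that a directory actually has the two-subdir layout described above. The sketch below is a hypothetical helper (`has_composite_layout` is not part of this module) that tests for `base/model` and `drafter/model` using only the standard library. It says nothing about the files inside those directories.

```python
import tempfile
from pathlib import Path


def has_composite_layout(load_directory: str) -> bool:
    """Return True if <load_directory>/base/model and <load_directory>/drafter/model exist."""
    root = Path(load_directory)
    return all((root / side / "model").is_dir() for side in ("base", "drafter"))


# Demonstrate on a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    print(has_composite_layout(tmp))  # False: nothing written yet
    for side in ("base", "drafter"):
        (Path(tmp) / side / "model").mkdir(parents=True)
    print(has_composite_layout(tmp))  # True: both sub-directories are present
```

A recipe could run a check like this before calling load_pretrained and fall back to a different resume path when it fails.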

nemo_automodel.components.models.gemma4_drafter.composite.__all__#

[‘Gemma4JointOutput’, ‘Gemma4WithDrafter’]