nemo_automodel.components.models.gemma4_drafter.composite#
Composite model for joint fine-tuning of a Gemma 4 base + its drafter.
The composite orchestrates a forward pass that:
1. Runs the base Gemma4ForConditionalGeneration with return_shared_kv_states=True and output_hidden_states=True.
2. Builds the drafter’s inputs_embeds by concatenating the (already sqrt(H_b)-scaled) base token embeddings with the base’s final hidden state along the feature axis.
3. Runs the drafter Gemma4AssistantForCausalLM with the captured shared_kv_states and the concatenated embeddings.
4. Returns a Gemma4JointOutput that exposes both base logits and a per-step list of drafter logits so the training recipe can compute L = L_base + drafter_loss_weight * sum_k L_drafter_k.
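The loss combination named in the forward-pass description can be sketched in a few lines. This is a minimal illustration with plain floats, not the recipe's code: the function name `joint_loss` and its argument names are hypothetical. In a real recipe the per-step losses would presumably be scalar tensors, to which the same arithmetic applies.

```python
from typing import Sequence


def joint_loss(
    base_loss: float,
    drafter_losses: Sequence[float],
    drafter_loss_weight: float = 1.0,
) -> float:
    """Return L = L_base + drafter_loss_weight * sum_k L_drafter_k.

    `drafter_losses` holds one entry per drafter step k (length 1 for the
    default single-step setup). All names here are illustrative assumptions.
    """
    return base_loss + drafter_loss_weight * sum(drafter_losses)


# Two drafter steps, drafter loss weighted twice as heavily as the base loss.
total = joint_loss(2.0, [0.5, 0.25], drafter_loss_weight=2.0)
```

With `drafter_loss_weight = 0` the expression reduces to the base loss alone, which is a convenient sanity check when wiring up a recipe.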
Both sub-models are trainable. Gradients from the drafter loss flow back into the base through:
- the “store” KV layers (last non-shared layer of each layer_type) via shared_kv_states;
- the base’s input embedding (consumed by the drafter’s first projection);
- the base’s final hidden state.
This is the EAGLE-2 / Medusa-2 style co-training pattern: the drafter stays aligned with a base that is itself moving.
Module Contents#
Classes#
Gemma4JointOutput | Output of Gemma4WithDrafter. |
Gemma4WithDrafter | Composite model that wraps a Gemma 4 base + its released drafter. |
Data#
API#
- nemo_automodel.components.models.gemma4_drafter.composite.logger#
‘getLogger(…)’
- class nemo_automodel.components.models.gemma4_drafter.composite.Gemma4JointOutput#
Output of Gemma4WithDrafter.
.. attribute:: logits
Base model logits, [B, S, V].
.. attribute:: drafter_logits
Per-step list of drafter logits, each [B, S, V]. For the default single-step recurrent cell this list has length 1.
.. attribute:: drafter_loss_weight
lambda multiplier the recipe applies to the drafter loss when summing it with the base loss.
.. attribute:: hidden_states
Optional list of base hidden states (mirrors HF).
.. attribute:: loss
Placeholder, populated by the recipe if needed.
- logits: torch.Tensor#
None
- drafter_logits: list[torch.Tensor]#
‘field(…)’
- drafter_loss_weight: float#
1.0
None
- loss: Optional[torch.Tensor]#
None
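To make the shape of this output container concrete, here is a hypothetical stand-in that mirrors the documented fields. It is a sketch only: the class name `JointOutputSketch` is invented, tensor types are replaced by `Any` so the snippet runs without PyTorch, and the empty-list default for `drafter_logits` is an assumption, since the reference above only shows `field(…)`.

```python
from dataclasses import dataclass, field
from typing import Any, List, Optional


@dataclass
class JointOutputSketch:
    """Illustrative stand-in for Gemma4JointOutput (not the real class)."""

    logits: Any = None  # base model logits, [B, S, V]
    # One [B, S, V] entry per drafter step; default_factory is an assumption.
    drafter_logits: List[Any] = field(default_factory=list)
    drafter_loss_weight: float = 1.0  # lambda applied by the recipe
    hidden_states: Optional[List[Any]] = None  # optional base hidden states
    loss: Optional[Any] = None  # placeholder, filled in by the recipe if needed


# Single-step case: the per-step list has exactly one entry.
out = JointOutputSketch(logits="base-logits", drafter_logits=["step-0-logits"])
```

A recipe would read `out.logits` for the base loss and iterate over `out.drafter_logits` for the per-step drafter losses.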
- class nemo_automodel.components.models.gemma4_drafter.composite.Gemma4WithDrafter(
- base: torch.nn.Module,
- drafter: torch.nn.Module,
- *,
- drafter_loss_weight: float = 1.0,
- drafter_num_steps: int = 1,
- freeze_base_for_drafter: bool = False,
- share_embedding_with_base: bool = False,
- base_activation_checkpointing: bool = False,
Bases: torch.nn.Module, nemo_automodel.components.models.common.hf_checkpointing_mixin.HFCheckpointingMixin
Composite model that wraps a Gemma 4 base + its released drafter.
Both sub-modules are loaded via NeMo’s NeMoAutoModel* paths so they receive the standard distributed infrastructure (FSDP2 sharding, freeze config, checkpoint loading, kernel patches, …) independently. The composite is a thin nn.Module that owns both and exposes a joint forward and a save_pretrained that writes the pair as two HF-format sub-directories (base/ and drafter/).
- Parameters:
base – Loaded base model (typically a Gemma4ForConditionalGeneration instance returned by NeMoAutoModelForImageTextToText.from_pretrained).
drafter – Loaded drafter (a Gemma4DrafterForCausalLM instance returned by NeMoAutoModelForCausalLM.from_pretrained).
drafter_loss_weight – Multiplier lambda applied to the drafter loss in the recipe.
drafter_num_steps – Number of recurrent drafter steps K to run per training batch. With K = 1 the composite is the EAGLE-1-style single-step setup; with K > 1 the drafter runs autoregressively for K rounds, feeding its previous-round last_hidden_state (already post-projected to H_b) and a teacher-forced shifted token id back into itself, matching the Gemma 4 drafter blog’s recipe. shared_kv_states is captured from a single base forward and reused at every round.
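The K-round behaviour described for drafter_num_steps can be pictured with a small control-flow sketch. Everything here is illustrative: `run_drafter_rounds`, `drafter_step`, and `pad_id` are hypothetical names, the drafter is replaced by a toy callable, and the "shift left by one position per round" rule is an assumption about what "teacher-forced shifted token id" means. Only the overall pattern (K rounds, hidden state fed back, one shared KV capture reused) comes from the description above.

```python
from typing import Any, Callable, List, Sequence, Tuple


def run_drafter_rounds(
    drafter_step: Callable[[List[int], Any, Any], Tuple[Any, Any]],
    input_ids: Sequence[int],
    base_hidden: Any,
    shared_kv_states: Any,
    num_steps: int,
    pad_id: int = 0,
) -> List[Any]:
    """Run `num_steps` drafter rounds and collect one logits entry per round."""
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    logits_per_step: List[Any] = []
    hidden = base_hidden  # round 0 consumes the base's final hidden state
    tokens = list(input_ids)
    for _ in range(num_steps):
        # The same shared_kv_states object is reused at every round.
        logits, hidden = drafter_step(tokens, hidden, shared_kv_states)
        logits_per_step.append(logits)
        # Teacher forcing (assumed form): the next round sees the ground-truth
        # ids shifted one position to the left, padded at the end.
        tokens = tokens[1:] + [pad_id]
    return logits_per_step


# Toy drafter that records its inputs so the data flow is visible.
calls = []


def toy_drafter_step(token_ids, hidden, kv):
    calls.append((list(token_ids), hidden, kv))
    return f"logits@{len(calls)}", hidden + 1


step_logits = run_drafter_rounds(toy_drafter_step, [11, 12, 13], 0, "kv", num_steps=2)
```

With `num_steps=1` the loop body runs once, which corresponds to the single-step setup in which the per-step logits list has length 1.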
Initialization
- supports_gradient_checkpointing#
True
- static _get_base_text_config(base: torch.nn.Module)#
- classmethod from_pretrained(
- base_path: Optional[str] = None,
- drafter_path: Optional[str] = None,
- *,
- pretrained_model_name_or_path: Optional[str] = None,
- drafter_loss_weight: float = 1.0,
- drafter_num_steps: int = 1,
- freeze_base_for_drafter: bool = False,
- share_embedding_with_base: bool = False,
- base_activation_checkpointing: bool = False,
- torch_dtype: Any = None,
- attn_implementation: Optional[str] = None,
- use_liger_kernel: Optional[bool] = None,
- use_sdpa_patching: Optional[bool] = None,
- text_config: Optional[dict] = None,
- peft_config: Any = None,
- device_mesh: Any = None,
- moe_mesh: Any = None,
- distributed_config: Any = None,
- pipeline_config: Any = None,
- freeze_config: Any = None,
- cache_dir: Optional[str] = None,
- **kwargs,
Build the composite by loading base and drafter via the NeMoAuto paths.
- Parameters:
base_path – HF repo id or local path of the Gemma 4 base model.
drafter_path – HF repo id or local path of the released drafter.
pretrained_model_name_or_path – Alias for base_path. Kept so that YAML configs can set pretrained_model_name_or_path and have the recipe’s processor / checkpoint-config helpers (which read this key from the model config) keep working.
drafter_loss_weight – lambda multiplier on the drafter loss.
drafter_num_steps – Number of recurrent drafter steps K per batch. K = 1 is EAGLE-1-style single-step; K > 1 matches the Gemma 4 drafter blog’s multi-token-prediction (MTP) training recipe: the drafter consumes its previous round’s post-projected hidden state plus a teacher-forced shifted token id at every subsequent round.
freeze_base_for_drafter – If True, freeze all base parameters so only the drafter is trained (drafter-only sub-case). Default False (joint training).
share_embedding_with_base – If True, copy the base’s input embedding into the drafter’s embed_tokens once at init. The drafter’s lm_head is tied to its own embed_tokens so the row weights start aligned with the base too. The two embeddings then evolve as independent parameters during training.
base_activation_checkpointing – If True, enable HF gradient checkpointing on the base to reduce activation memory. Important for the 4B + drafter + long-context setting.
torch_dtype – dtype to use for both sub-models. Must be torch.bfloat16; the drafter is bf16-only.
attn_implementation – Forwarded to both sub-loads.
use_liger_kernel – Forwarded to both sub-loads.
use_sdpa_patching – Forwarded to both sub-loads.
text_config – Optional overrides forwarded to the base load.
peft_config – PEFT config (currently expected to be None; joint drafter PEFT is out of scope for the initial recipe).
device_mesh – Distributed device mesh shared by base and drafter.
moe_mesh – MoE mesh shared by base and drafter (drafter is dense).
distributed_config – FSDP2 / Megatron-FSDP / DDP config object.
pipeline_config – Must be None; pipeline parallelism is not supported when the drafter is attached.
freeze_config – Forwarded to the base only (the drafter is trained end-to-end). Customize the drafter’s freezing with explicit requires_grad_ calls on the returned composite if needed.
cache_dir – HuggingFace cache directory.
**kwargs – Additional kwargs forwarded to both sub-loads.
- Returns:
An instantiated Gemma4WithDrafter.
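A call to from_pretrained might look like the sketch below. The wrapper function `build_composite` is hypothetical, and the imports are deferred into the function body so the snippet can be defined without NeMo or PyTorch installed. The keyword names are taken from the signature above; the paths a caller passes in are up to them, and nothing here has been verified against a specific release.

```python
def build_composite(base_path: str, drafter_path: str, num_steps: int = 1):
    """Load a Gemma 4 base + drafter pair for joint training (sketch only)."""
    # Deferred imports: only needed when the function is actually called.
    import torch
    from nemo_automodel.components.models.gemma4_drafter.composite import (
        Gemma4WithDrafter,
    )

    return Gemma4WithDrafter.from_pretrained(
        base_path=base_path,
        drafter_path=drafter_path,
        drafter_loss_weight=1.0,
        drafter_num_steps=num_steps,
        freeze_base_for_drafter=False,  # joint training, the documented default
        torch_dtype=torch.bfloat16,  # documented as the only supported dtype
        peft_config=None,  # joint drafter PEFT is documented as out of scope
        pipeline_config=None,  # pipeline parallelism is documented as unsupported
    )
```

Setting `freeze_base_for_drafter=True` instead would select the drafter-only sub-case described in the parameter list.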
- forward(
- input_ids: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.Tensor] = None,
- **kwargs: Any,
Joint forward: base first, then drafter consuming the base’s outputs.
Any extra kwargs (pixel_values, mm_token_type_ids, pixel_values_videos, input_features, …) are passed straight through to the base. Multimodal kwargs are not forwarded to the drafter (the drafter is text-only).
- get_input_embeddings() torch.nn.Module#
- get_output_embeddings() torch.nn.Module#
- property config#
- property vision_tower#
- property audio_tower#
- property language_model#
- get_rope_index(*args, **kwargs)#
- save_pretrained(
- save_directory: str,
- checkpointer: Optional[nemo_automodel.components.checkpoint.checkpointing.Checkpointer] = None,
- tokenizer: Any = None,
- **kwargs,
Save base and drafter as two HF-format sub-directories.
Produces <save_directory>/base/ and <save_directory>/drafter/ with HF-compatible artifacts. Each side can later be loaded back by HF from_pretrained independently (vLLM compatibility).
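The two-directory layout can be expressed as a small path helper. This is a sketch for illustration: `composite_subdirs` is a hypothetical name and the example save directory is made up. Only the `base/` and `drafter/` sub-directory names come from the description above.

```python
from pathlib import Path
from typing import Dict, Union


def composite_subdirs(save_directory: Union[str, Path]) -> Dict[str, Path]:
    """Return the per-model sub-directories under a composite save directory."""
    root = Path(save_directory)
    return {"base": root / "base", "drafter": root / "drafter"}


# Hypothetical checkpoint location, used only to show the resulting paths.
paths = composite_subdirs("ckpt/epoch_0_step_100")
base_dir = paths["base"].as_posix()
drafter_dir = paths["drafter"].as_posix()
```

Each of the two resulting directories is what one would hand to a Hugging Face `from_pretrained` call when loading the base or the drafter on its own.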
- load_pretrained(
- load_directory: str,
- checkpointer: Optional[nemo_automodel.components.checkpoint.checkpointing.Checkpointer] = None,
- **kwargs,
Load weights from the two-subdir layout written by save_pretrained.
Mirrors the save side: reads <load_directory>/base/model and <load_directory>/drafter/model (the standard Checkpointer.save_model output layout) and routes them to self.base and self.drafter respectively. Used by the recipe’s resume path when a checkpoint directory was produced by this composite.
- Parameters:
load_directory – A checkpoint directory containing base/ and drafter/ sub-directories (e.g. <ckpt_dir>/epoch_X_step_Y).
checkpointer – The recipe’s Checkpointer instance.
**kwargs – Reserved; ignored.
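Before resuming, it can be useful to check that a directory actually has the layout load_pretrained expects. The helper below is a hypothetical pre-flight check, not part of the library: the name `missing_composite_parts` is invented, and it only tests that the two documented paths exist, without assuming anything about their contents.

```python
import tempfile
from pathlib import Path
from typing import List, Union


def missing_composite_parts(load_directory: Union[str, Path]) -> List[Path]:
    """Return the expected checkpoint paths that are absent (empty if complete)."""
    root = Path(load_directory)
    expected = [root / "base" / "model", root / "drafter" / "model"]
    return [path for path in expected if not path.exists()]


# Demo on a throwaway directory that contains only the base half.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "base" / "model").mkdir(parents=True)
    missing = [p.relative_to(tmp).as_posix() for p in missing_composite_parts(tmp)]
```

An empty result suggests the directory was written by the composite's save path; a non-empty one points at which half is missing.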
- nemo_automodel.components.models.gemma4_drafter.composite.__all__#
[‘Gemma4JointOutput’, ‘Gemma4WithDrafter’]