nemo_automodel.components.models.gemma4_drafter.composite
Composite model for joint fine-tuning of a Gemma 4 base + its drafter.
The composite orchestrates a forward pass that:
- Runs the base Gemma4ForConditionalGeneration with return_shared_kv_states=True and output_hidden_states=True.
- Builds the drafter’s inputs_embeds by concatenating the (already sqrt(H_b)-scaled) base token embeddings with the base’s final hidden state along the feature axis.
- Runs the drafter Gemma4AssistantForCausalLM with the captured shared_kv_states and the concatenated embeddings.
- Returns a Gemma4JointOutput that exposes both the base logits and a per-step list of drafter logits, so the training recipe can compute L = L_base + drafter_loss_weight * sum_k L_drafter_k.
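The input construction and loss combination listed above can be sketched on plain Python lists. This is a minimal illustration under stated assumptions, not the composite's code: scale_embeddings, build_drafter_inputs and joint_loss are hypothetical helpers, lists stand in for tensors, and scale_embeddings only emulates the sqrt(H_b) scaling that the base embedding layer has already applied (H_b is taken to be the vector length).

```python
import math
from typing import List, Sequence

Vector = List[float]


def scale_embeddings(raw_embeds: Sequence[Vector]) -> List[Vector]:
    """Hypothetical stand-in for the base embedding layer's sqrt(H_b) scaling."""
    hidden_size = len(raw_embeds[0])
    scale = math.sqrt(hidden_size)
    return [[x * scale for x in row] for row in raw_embeds]


def build_drafter_inputs(
    scaled_embeds: Sequence[Vector], final_hidden: Sequence[Vector]
) -> List[Vector]:
    """Concatenate per-position embeddings and final hidden states on the feature axis."""
    if len(scaled_embeds) != len(final_hidden):
        raise ValueError("embeddings and hidden states must cover the same positions")
    # Each position becomes a 2 * H_b feature vector: [embedding | hidden state].
    return [list(emb) + list(hid) for emb, hid in zip(scaled_embeds, final_hidden)]


def joint_loss(
    base_loss: float, drafter_losses: Sequence[float], drafter_loss_weight: float
) -> float:
    """L = L_base + drafter_loss_weight * sum_k L_drafter_k."""
    return base_loss + drafter_loss_weight * sum(drafter_losses)


embeds = scale_embeddings([[1.0, 0.0, 0.0, 0.0]])  # H_b = 4, so the scale is 2
drafter_in = build_drafter_inputs(embeds, [[0.1, 0.2, 0.3, 0.4]])
print(len(drafter_in[0]))  # 8
print(joint_loss(2.0, [1.0, 0.5], drafter_loss_weight=0.5))  # 2.75
```

Passing the per-step drafter losses as a list mirrors the per-step list of drafter logits in the joint output: each step contributes one term to the weighted sum.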
Both sub-models are trainable. Gradients from the drafter loss flow back into the base through:
- the “store” KV layers (the last non-shared layer of each layer_type), via shared_kv_states;
- the base’s input embedding (consumed by the drafter’s first projection);
- the base’s final hidden state.
This is the EAGLE-2 / Medusa-2 style co-training pattern: the drafter stays aligned with a base that is itself moving.
Module Contents
API
Output of Gemma4WithDrafter.
Bases: Module, HFCheckpointingMixin
Composite model that wraps a Gemma 4 base + its released drafter.
Both sub-modules are loaded via NeMo’s NeMoAutoModel* paths so they
receive the standard distributed infrastructure (FSDP2 sharding, freeze
config, checkpoint loading, kernel patches, …) independently. The
composite is a thin nn.Module that owns both and exposes a joint
forward and a save_pretrained that writes the pair as two HF-format
sub-directories (base/ and drafter/).
Parameters:
Loaded base model (typically a Gemma4ForConditionalGeneration
instance returned by NeMoAutoModelForImageTextToText.from_pretrained).
Loaded drafter (a Gemma4DrafterForCausalLM instance
returned by NeMoAutoModelForCausalLM.from_pretrained).
Multiplier lambda applied to the drafter loss
in the recipe.
Number of recurrent drafter steps K to run per
training batch. With K = 1 the composite is the EAGLE-1-style
single-step setup; with K > 1 the drafter runs autoregressively
for K rounds, feeding its previous-round last_hidden_state
(already post-projected to H_b) and a teacher-forced shifted
token id back into itself, matching the Gemma 4 drafter blog’s
recipe. shared_kv_states is captured from a single base
forward and reused at every round.
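The recurrent schedule for K > 1 can be sketched as a loop. All names and details here are hypothetical: drafter_step is a callable standing in for one drafter forward, hidden states are reduced to one float per position, the projection back to H_b is omitted, and round k (counting from 0) is assumed to receive input_ids shifted left by k and right-padded with 0. What the sketch does illustrate is that shared_kv_states comes from a single base forward and the same object is reused at every round.

```python
from typing import Any, Callable, List, Sequence, Tuple

# Hypothetical signature: (token_ids, prev_hidden, shared_kv) -> (logits, new_hidden)
DrafterStep = Callable[[List[int], List[float], Any], Tuple[List[float], List[float]]]


def shift_ids(input_ids: Sequence[int], k: int, pad_id: int = 0) -> List[int]:
    """Teacher-forced ids shifted left by k positions and right-padded."""
    k = min(k, len(input_ids))
    return list(input_ids[k:]) + [pad_id] * k


def run_drafter_rounds(
    input_ids: Sequence[int],
    base_hidden: List[float],
    shared_kv: Any,
    drafter_step: DrafterStep,
    num_steps: int,
) -> List[List[float]]:
    """Run K drafter rounds and collect one logits entry per round."""
    hidden = base_hidden  # round 0 consumes the base's final hidden state
    per_step_logits = []
    for k in range(num_steps):
        ids_k = shift_ids(input_ids, k)
        # The same shared_kv object, captured from one base forward, every round.
        logits, hidden = drafter_step(ids_k, hidden, shared_kv)
        per_step_logits.append(logits)
    return per_step_logits


seen_kv = []


def toy_step(ids, hidden, shared_kv):
    """Toy drafter: echoes the ids as 'logits' and bumps the hidden state."""
    seen_kv.append(shared_kv)
    return [float(t) for t in ids], [h + 1.0 for h in hidden]


kv = {"captured": "once"}
print(run_drafter_rounds([5, 6, 7], [0.0, 0.0, 0.0], kv, toy_step, num_steps=3))
# [[5.0, 6.0, 7.0], [6.0, 7.0, 0.0], [7.0, 0.0, 0.0]]
```

With num_steps=1 the loop runs once on the unshifted ids, which corresponds to the single-step K = 1 case.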
Joint forward: base first, then drafter consuming the base’s outputs.
Any extra kwargs (pixel_values, mm_token_type_ids,
pixel_values_videos, input_features, …) are passed straight
through to the base. Multimodal kwargs are not forwarded to the
drafter (the drafter is text-only).
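The kwarg routing described above can be sketched as a simple key filter. route_forward_kwargs and MULTIMODAL_KEYS are hypothetical names. The key set contains only the four names listed above (the documented list continues with "…"), and the assumption that any remaining non-multimodal extras may reach the drafter is made purely for illustration.

```python
from typing import Any, Dict, Tuple

# Only the keys named in the documentation; the documented list is longer.
MULTIMODAL_KEYS = frozenset(
    {"pixel_values", "mm_token_type_ids", "pixel_values_videos", "input_features"}
)


def route_forward_kwargs(
    extra: Dict[str, Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Return (base_kwargs, drafter_kwargs) for the extra forward kwargs."""
    base_kwargs = dict(extra)  # the base receives every extra kwarg unchanged
    # Assumption: the drafter sees the remaining extras, minus multimodal inputs.
    drafter_kwargs = {k: v for k, v in extra.items() if k not in MULTIMODAL_KEYS}
    return base_kwargs, drafter_kwargs


# "some_text_kwarg" is a hypothetical non-multimodal extra.
base_kw, drafter_kw = route_forward_kwargs({"pixel_values": "img", "some_text_kwarg": True})
print(sorted(base_kw))  # ['pixel_values', 'some_text_kwarg']
print(sorted(drafter_kw))  # ['some_text_kwarg']
```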
Build the composite by loading base and drafter via the NeMoAuto paths.
Parameters:
HF repo id or local path of the Gemma 4 base model.
HF repo id or local path of the released drafter.
Alias for base_path. Kept so that
YAML configs can set pretrained_model_name_or_path and have
the recipe’s processor / checkpoint-config helpers (which read
this key from the model config) keep working.
Multiplier lambda applied to the drafter loss.
Number of recurrent drafter steps K per batch.
K = 1 is EAGLE-1-style single-step; K > 1 matches the
Gemma 4 drafter blog’s multi-token-prediction (MTP) training
recipe — the drafter consumes its previous round’s
post-projected hidden state plus a teacher-forced shifted
token id at every subsequent round.
If True, freeze all base parameters so only the drafter is trained (drafter-only sub-case). Default False (joint training).
If True, copy the base’s input
embedding into the drafter’s embed_tokens once at init.
The drafter’s lm_head is tied to its own embed_tokens
so the row weights start aligned with the base too. The two
embeddings then evolve as independent parameters during
training.
If True, enable HF gradient checkpointing on the base to reduce activation memory. Important for the 4B + drafter + long-context setting.
dtype to use for both sub-models. Must be
torch.bfloat16 — the drafter is bf16-only.
Forwarded to both sub-loads.
Forwarded to both sub-loads.
Forwarded to both sub-loads.
Optional overrides forwarded to the base load.
PEFT config (currently expected to be None —
joint drafter PEFT is out of scope for the initial recipe).
Distributed device mesh shared by base and drafter.
MoE mesh shared by base and drafter (drafter is dense).
FSDP2 / Megatron-FSDP / DDP config object.
Must be None — pipeline parallelism is not
supported when the drafter is attached.
Resolved DistributedSetup (topology + policy)
shared by base and drafter. This is the path used by the VLM
finetune recipe; its pp_size and cp_size must be 1.
Forwarded to the base only (the drafter is trained
end-to-end). Customize the drafter’s freezing with explicit
requires_grad_ calls on the returned composite if needed.
HuggingFace cache directory.
Additional kwargs forwarded to both sub-loads.
Returns: 'Gemma4WithDrafter'
An instantiated Gemma4WithDrafter.
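Three build-time behaviors documented above can be sketched in plain Python: the base_path alias, the documented argument constraints, and the one-time embedding copy into a drafter whose lm_head is tied to its embed_tokens. Everything here is hypothetical. resolve_base_path, check_build_args and ToyDrafter are illustrative names, the dtype is passed as a string to avoid a torch dependency, lists stand in for embedding tables, and raising on conflicting aliases or on a non-None PEFT config is an assumption.

```python
from typing import Any, List, Optional

Table = List[List[float]]


def resolve_base_path(
    base_path: Optional[str] = None,
    pretrained_model_name_or_path: Optional[str] = None,
) -> str:
    """Accept pretrained_model_name_or_path as an alias for base_path."""
    if base_path and pretrained_model_name_or_path and base_path != pretrained_model_name_or_path:
        raise ValueError("base_path and pretrained_model_name_or_path disagree")  # assumption
    path = base_path or pretrained_model_name_or_path
    if not path:
        raise ValueError("a base model path is required")
    return path


def check_build_args(dtype_name: str, peft_config: Any = None, pp_size: int = 1, cp_size: int = 1) -> None:
    """Documented constraints: bf16 only, no drafter PEFT, pp_size == cp_size == 1."""
    if dtype_name != "bfloat16":
        raise ValueError("the drafter is bf16-only; pass torch.bfloat16")
    if peft_config is not None:
        raise NotImplementedError("joint drafter PEFT is out of scope")
    if pp_size != 1 or cp_size != 1:
        raise ValueError("pp_size and cp_size must both be 1 with the drafter attached")


class ToyDrafter:
    """List-based stand-in whose lm_head is tied to embed_tokens."""

    def __init__(self, vocab_size: int, hidden_size: int) -> None:
        self.embed_tokens: Table = [[0.0] * hidden_size for _ in range(vocab_size)]
        self.lm_head = self.embed_tokens  # tying: both names point at one table

    def copy_base_embeddings(self, base_embed: Table) -> None:
        if len(base_embed) != len(self.embed_tokens):
            raise ValueError("vocab sizes differ")
        # Copy values in place: the tie survives, and no reference to the base
        # table is kept, so the two embeddings evolve independently afterwards.
        for row, base_row in zip(self.embed_tokens, base_embed):
            if len(row) != len(base_row):
                raise ValueError("hidden sizes differ")
            row[:] = base_row


path = resolve_base_path(pretrained_model_name_or_path="path/to/base")
check_build_args("bfloat16")
base_embed = [[1.0, 2.0], [3.0, 4.0]]
drafter = ToyDrafter(vocab_size=2, hidden_size=2)
drafter.copy_base_embeddings(base_embed)
base_embed[0][0] = 9.9  # a later base update...
print(drafter.lm_head[0])  # ...does not reach the drafter: [1.0, 2.0]
```

Copying values in place, rather than rebinding embed_tokens to a new table, is what keeps lm_head tied after the copy.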
Load weights from the two-subdir layout written by save_pretrained.
Mirrors the save side: reads <load_directory>/base/model and
<load_directory>/drafter/model (the standard Checkpointer.save_model
output layout) and routes them to self.base and self.drafter
respectively. Used by the recipe’s resume path when a checkpoint
directory was produced by this composite.
Parameters:
A checkpoint directory containing base/ and
drafter/ sub-directories (e.g. <ckpt_dir>/epoch_X_step_Y).
The recipe’s Checkpointer instance.
Reserved; ignored.
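On the load side, the directory routing amounts to a fixed path mapping. composite_checkpoint_paths is a hypothetical helper written for illustration. The real method hands these directories to the recipe’s Checkpointer rather than returning them.

```python
from pathlib import Path
from typing import Dict, Union


def composite_checkpoint_paths(load_directory: Union[str, Path]) -> Dict[str, Path]:
    """Map each sub-model name to <load_directory>/<name>/model."""
    root = Path(load_directory)
    return {name: root / name / "model" for name in ("base", "drafter")}


paths = composite_checkpoint_paths("ckpt/epoch_0_step_100")
print(paths["drafter"].as_posix())  # ckpt/epoch_0_step_100/drafter/model
```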
Save base and drafter as two HF-format sub-directories.
Produces <save_directory>/base/ and <save_directory>/drafter/
with HF-compatible artifacts. Each side can later be loaded back independently
with HF from_pretrained (for vLLM compatibility).
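The two-subdirectory save layout can be sketched as follows, assuming each sub-model exposes some save callable that writes into a given directory. save_pair and write_placeholder are hypothetical. The placeholder config.json stands in for the HF artifacts (config, weights) that each sub-model’s own save would write.

```python
import tempfile
from pathlib import Path
from typing import Callable, Dict, Union

SaveFn = Callable[[Path], None]


def save_pair(save_directory: Union[str, Path], save_fns: Dict[str, SaveFn]) -> Dict[str, Path]:
    """Write 'base' and 'drafter' into their own sub-directories of save_directory."""
    root = Path(save_directory)
    written = {}
    for name in ("base", "drafter"):
        sub_dir = root / name
        sub_dir.mkdir(parents=True, exist_ok=True)
        save_fns[name](sub_dir)  # stand-in for that sub-model's save_pretrained(sub_dir)
        written[name] = sub_dir
    return written


def write_placeholder(directory: Path) -> None:
    """Hypothetical stand-in for the HF artifacts a real save would produce."""
    (directory / "config.json").write_text("{}")


with tempfile.TemporaryDirectory() as tmp:
    save_pair(tmp, {"base": write_placeholder, "drafter": write_placeholder})
    print(sorted(p.name for p in Path(tmp).iterdir()))  # ['base', 'drafter']
```

Keeping the two sides in separate directories lets each one be pointed at by a loader that knows nothing about the composite.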