nemo_automodel.components.models.gemma4_moe.model
Gemma4 MoE NeMo Automodel support.
Replaces the HF-native Gemma4 MoE (dense matmul over all experts) with NeMo’s GroupedExperts backend, enabling Expert Parallelism (EP) via the standard MoE parallelizer.
Module Contents
Classes
Functions
Data
API
Bases: HFCheckpointingMixin, HFGemma4ForConditionalGeneration, MoEFSDPSyncMixin
Gemma4 VL conditional generation model with NeMo MoE backend.
When the checkpoint has enable_moe_block=True in its text config,
replaces the HF-native language model with Gemma4MoETextModelBackend
(NeMo GroupedExperts + Gemma4Gate). Otherwise falls through to vanilla HF.
Gemma4-owned CP batch sharder that also self-installs the ring.
Attached to the batch as _cp_make_batch_fn by
prepare_model_inputs_for_cp. cp_utils.make_cp_batch_and_ctx calls it
with the CP submesh, which is the one place Gemma4 receives cp_mesh on a
model-owned path — so install the ring here (idempotent) before sharding,
rather than depending on the framework to call setup_cp_attention.
Return the capabilities for a specific config (no model instance needed).
Dispatches in two layers so the same class can serve every Gemma4 checkpoint honestly:
- If config.text_config.enable_moe_block is True → MoE variant (e.g. google/gemma-4-26B-A4B-it).
- Else if config.audio_config is not None → dense + audio variant (e.g. google/gemma-4-E2B-it, google/gemma-4-E4B-it).
- Else → plain dense variant (e.g. google/gemma-4-31B-it).
Parameters:
The model’s Gemma4Config (or anything exposing a
text_config with enable_moe_block and an
audio_config attribute).
Returns: ModelCapabilities
A populated ModelCapabilities for this specific config.
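The two-layer dispatch described above can be sketched in a few lines. This is an illustration only: pick_variant and the string labels are hypothetical stand-ins, and the real method returns a ModelCapabilities object whose fields are not shown in this page.

```python
from types import SimpleNamespace


def pick_variant(config):
    """Return a label for the Gemma4 variant implied by `config` (hypothetical helper)."""
    text_config = getattr(config, "text_config", None)
    # Layer 1: MoE checkpoints are identified by the text-config flag.
    if getattr(text_config, "enable_moe_block", False):
        return "moe"
    # Layer 2: among dense checkpoints, an audio config marks the audio variant.
    if getattr(config, "audio_config", None) is not None:
        return "dense+audio"
    return "dense"


# Stand-in configs; real code would pass a Gemma4Config.
moe_cfg = SimpleNamespace(text_config=SimpleNamespace(enable_moe_block=True), audio_config=None)
audio_cfg = SimpleNamespace(text_config=SimpleNamespace(enable_moe_block=False), audio_config=object())
dense_cfg = SimpleNamespace(text_config=SimpleNamespace(enable_moe_block=False), audio_config=None)

print(pick_variant(moe_cfg), pick_variant(audio_cfg), pick_variant(dense_cfg))
```

Because the check only needs attribute access, anything exposing text_config and audio_config works, which matches the parameter description above.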
Prepare Gemma4 embeddings on the full sequence before CP sharding.
Install Gemma4’s model-owned p2p ring CP attention (dense path).
Idempotent: flips the _cp_enabled flag the forward reads and installs
the ring on every self-attn module (each was given a per-module
setup_cp_attention by attach_gemma4_cp_ring_attention at
construction). Invoked from Gemma4’s own batch-sharding callable
(_cp_shard_batch) the first time the recipe hands it the CP submesh, so
the install is fully model-owned — no framework dispatch is required.
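A minimal sketch of the idempotent, model-owned install pattern follows. The stub classes are hypothetical, and guarding with an early return on _cp_enabled is an assumption about how idempotence could be achieved, not a description of the actual implementation.

```python
class SelfAttnStub:
    """Stand-in for a self-attention module that was given a per-module setup hook."""

    def __init__(self):
        self.install_calls = 0
        self.cp_mesh = None

    def setup_cp_attention(self, cp_mesh):
        self.install_calls += 1
        self.cp_mesh = cp_mesh


class ModelStub:
    def __init__(self, num_layers):
        self._cp_enabled = False  # flag the forward pass would read
        self.attn_modules = [SelfAttnStub() for _ in range(num_layers)]

    def setup_cp_attention(self, cp_mesh):
        # Assumed guard: a second call is a no-op.
        if self._cp_enabled:
            return
        for module in self.attn_modules:
            module.setup_cp_attention(cp_mesh)
        self._cp_enabled = True

    def _cp_shard_batch(self, cp_mesh, batch):
        # Install the ring before sharding; the sharding itself is elided here.
        self.setup_cp_attention(cp_mesh)
        return batch


model = ModelStub(num_layers=3)
model._cp_shard_batch("cp_mesh_placeholder", {"input_ids": [1, 2, 3]})
model._cp_shard_batch("cp_mesh_placeholder", {"input_ids": [4, 5, 6]})
print([m.install_calls for m in model.attn_modules])
```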
Tie lm_head to the active text embed_tokens when requested.
Overrides HF’s generic tying so that any caller after the MoE
language_model swap (construction, AutoModel, and checkpoint load
via ensure_tied_lm_head) re-points lm_head to the active
embedding rather than whatever HF’s get_input_embeddings()
indirection resolves to. No-op when the config requests untied
embeddings.
Accepts and ignores positional/keyword arguments (e.g. HF v5’s
recompute_mapping) so it stays drop-in compatible with the HF
init_weights() -> tie_weights(...) call path.
The controlling flag is the top-level Gemma4Config.tie_word_embeddings
(verified against HF: the top-level flag decides tying regardless of the
nested text_config value), so read it first and only fall back to
text_config for configs that don’t expose a top-level flag.
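The flag lookup order described above can be illustrated as follows. resolve_tie_flag, the default argument, and the ModelStub attributes are hypothetical; only the "top-level first, then text_config" order comes from the text.

```python
from types import SimpleNamespace


def resolve_tie_flag(config, default=False):
    """Read tie_word_embeddings from the top-level config, else from text_config."""
    top_level = getattr(config, "tie_word_embeddings", None)
    if top_level is not None:
        return bool(top_level)
    text_config = getattr(config, "text_config", None)
    # `default` is an assumption for configs that expose the flag nowhere.
    return bool(getattr(text_config, "tie_word_embeddings", default))


class ModelStub:
    def __init__(self, config):
        self.config = config
        self.embed_weight = [0.1, 0.2]    # stands in for embed_tokens.weight
        self.lm_head_weight = [9.0, 9.0]  # stands in for lm_head.weight

    def tie_weights(self, *args, **kwargs):
        # Extra arguments are accepted and ignored, as the docstring describes.
        if resolve_tie_flag(self.config):
            self.lm_head_weight = self.embed_weight


tied = ModelStub(SimpleNamespace(tie_word_embeddings=True,
                                 text_config=SimpleNamespace(tie_word_embeddings=False)))
tied.tie_weights(recompute_mapping=True)

untied = ModelStub(SimpleNamespace(tie_word_embeddings=False,
                                   text_config=SimpleNamespace(tie_word_embeddings=True)))
untied.tie_weights()
print(tied.lm_head_weight is tied.embed_weight, untied.lm_head_weight is untied.embed_weight)
```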
Bases: Module
Gemma4 Router reimplemented to output NeMo Gate format.
HF Gemma4Router applies RMSNorm(no_scale) → root_size scaling → learnable scale → Linear → softmax → top-k → renormalize, which differs from the standard Gate class in layer.py. This class reproduces that logic but returns (weights, indices, aux_loss), as expected by GroupedExperts.
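To make the routing pipeline concrete, here is a pure-Python sketch for a single token. It is not the Gemma4Gate code: the function name is hypothetical, "root_size scaling" is assumed to mean multiplying by hidden_size ** -0.5, and aux_loss is returned as a None placeholder.

```python
import math


def gemma4_style_route(x, scale, weight, top_k, eps=1e-6):
    """Route one token.

    x: hidden vector, scale: per-dimension learnable scale,
    weight: router matrix of shape [num_experts][hidden].
    """
    hidden = len(x)
    # RMSNorm without a learnable scale.
    rms = math.sqrt(sum(v * v for v in x) / hidden + eps)
    normed = [v / rms for v in x]
    # Root-size scaling (assumed to be hidden ** -0.5), then the learnable scale.
    root = hidden ** -0.5
    scaled = [v * root * s for v, s in zip(normed, scale)]
    # Linear projection to one logit per expert.
    logits = [sum(w * v for w, v in zip(row, scaled)) for row in weight]
    # Numerically stable softmax over experts.
    peak = max(logits)
    exps = [math.exp(logit - peak) for logit in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Top-k selection, then renormalize the kept probabilities to sum to 1.
    indices = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    kept = [probs[i] for i in indices]
    kept_sum = sum(kept)
    weights = [p / kept_sum for p in kept]
    aux_loss = None  # placeholder; this sketch computes no auxiliary loss
    return weights, indices, aux_loss


weights, indices, aux_loss = gemma4_style_route(
    x=[1.0, 2.0, 3.0, 4.0],
    scale=[1.0, 1.0, 1.0, 1.0],
    weight=[[1, 0, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0]],
    top_k=2,
)
print(indices, weights)
```

The return tuple mirrors the (weights, indices, aux_loss) shape mentioned above, so a consumer written for that format can be exercised against the sketch.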
Bases: MoE
NeMo MoE that uses Gemma4Gate (with pre-norm routing) instead of
the standard Gate. Subclasses MoE so that isinstance(m, MoE) is True,
which the EP parallelizer relies on.
Forward with optional separate gate input.
HF Gemma4 passes unnormalized residual to the router and normalized
input to the experts. The decoder layer calls this with
gate_input=x (raw residual) so the gate receives unnormalized
input while experts receive pre_feedforward_layernorm_2(x).
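The split between gate input and expert input can be sketched with stub callables. Everything here (moe_forward, the recording stubs, the toy normalization) is hypothetical and only illustrates which tensor goes where.

```python
def moe_forward(x, gate, experts, gate_input=None):
    """Route on gate_input when given, otherwise on x; experts always see x."""
    routing_source = x if gate_input is None else gate_input
    weights, indices = gate(routing_source)
    return experts(x, weights, indices)


seen = {}


def gate_stub(t):
    seen["gate"] = t
    return [1.0], [0]  # dummy weights and expert indices


def experts_stub(t, weights, indices):
    seen["experts"] = t
    return t


def pre_norm_stub(t):
    # Toy stand-in for pre_feedforward_layernorm_2.
    peak = max(abs(v) for v in t)
    return [v / peak for v in t]


residual = [2.0, 4.0]
out = moe_forward(pre_norm_stub(residual), gate_stub, experts_stub, gate_input=residual)
print(seen)
```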
Bases: Module
Gemma4 decoder layer with NeMo MoE backend.
Reuses HF attention and dense MLP, replaces HF Router+MoEBlock with NeMo Gemma4MoE (Gemma4Gate + GroupedExperts).
Bases: HFGemma4Model
Thin wrapper that exposes language_model internals as properties
expected by the NeMo training loop.
Bases: Module
Gemma4 text decoder rebuilt with NeMo MoE blocks.
Bases: MutableMapping
A dict-like store for Gemma4 key/value sharing that is safe to pass through FSDP2.
Why a plain dict breaks under FSDP2:
With FSDP2 each decoder layer is wrapped as its own unit, and the default
mixed-precision setting (cast_forward_inputs=True) makes FSDP2 look at
every argument passed to a layer and cast its float tensors to bf16. It
does this with torch’s _apply_to_tensors, which, whenever it sees a
dict (or list/tuple/set/…), builds a brand-new copy of
it. So if the shared store is a plain dict, each layer receives its
own private copy: the earlier layer’s writes land in a copy that is thrown
away, and the later layers read from an empty copy — which raises
KeyError: 'sliding_attention'.
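The failure mode and the fix can be reproduced without FSDP2. In the sketch below, rebuild_containers is a simplified stand-in for a traversal that copies dicts, lists, and tuples (it is not torch's _apply_to_tensors), and SharedKVStore is a hypothetical name. The point is that a MutableMapping subclass is not a dict instance, so such a traversal is assumed to pass it through by reference.

```python
from collections.abc import MutableMapping


class SharedKVStore(MutableMapping):
    """Dict-like store that is deliberately not a dict subclass."""

    def __init__(self):
        self._data = {}

    def __getitem__(self, key):
        return self._data[key]

    def __setitem__(self, key, value):
        self._data[key] = value

    def __delitem__(self, key):
        del self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)


def rebuild_containers(obj):
    """Simplified stand-in: copies known containers, passes other objects through."""
    if isinstance(obj, dict):
        return {k: rebuild_containers(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(rebuild_containers(v) for v in obj)
    return obj


def source_layer(store):
    store["sliding_attention"] = ("k", "v")


plain = {}
source_layer(rebuild_containers(plain))  # the write lands in a throwaway copy

safe = SharedKVStore()
source_layer(rebuild_containers(safe))   # same object, so the write survives

print("sliding_attention" in plain, "sliding_attention" in safe)
```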
Cache-free holder that lets HF gemma4 kv-sharing fire under use_cache=False.
E2B/E4B share K/V across the trailing num_kv_shared_layers layers: each shared
layer reads its source layer’s K/V from past_key_values.shared_layers (see HF
Gemma4Attention.forward). HF only performs that read when past_key_values is not None,
but past_key_values is None whenever use_cache=False, and use_cache is forced off by
activation checkpointing and the model-owned CP path. The shared layers then fall
back to their (frozen, unused) K/V projections and produce garbage, inflating the
loss ~4x.
Passing this lightweight object as past_key_values satisfies the gate so HF’s
own kv-sharing logic runs: source layers populate shared_layers and shared
layers read it. update is a pass-through (no per-token accumulation, so no cache
memory growth), and get_seq_length returns 0 so the causal mask is built with a
zero cache offset (correct for a training forward).
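A minimal sketch of such a holder, assuming the three members named above are all that the attention code touches. KVShareHolder, the update signature, and the toy attention function are hypothetical, and keying shared_layers by layer index is an assumption of this sketch.

```python
class KVShareHolder:
    """Cache-free object that only carries shared K/V between layers."""

    def __init__(self):
        self.shared_layers = {}

    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        # Pass-through: nothing is accumulated, so memory does not grow.
        return key_states, value_states

    def get_seq_length(self, layer_idx=0):
        # Zero cache offset, which suits a training forward pass.
        return 0


def toy_attention(layer_idx, source_idx, past_key_values):
    """Toy stand-in for the gate on past_key_values being present."""
    if source_idx is not None and past_key_values is not None:
        return past_key_values.shared_layers[source_idx]  # shared layer reads
    kv = (f"k{layer_idx}", f"v{layer_idx}")
    if past_key_values is not None:
        kv = past_key_values.update(kv[0], kv[1], layer_idx)
        past_key_values.shared_layers[layer_idx] = kv     # source layer writes
    return kv


holder = KVShareHolder()
source_kv = toy_attention(layer_idx=0, source_idx=None, past_key_values=holder)
shared_kv = toy_attention(layer_idx=1, source_idx=0, past_key_values=holder)
fallback_kv = toy_attention(layer_idx=1, source_idx=0, past_key_values=None)
print(source_kv, shared_kv, fallback_kv)
```

With past_key_values=None the shared layer falls back to its own projections, which is the failure the holder is meant to avoid.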
Build Gemma4 full/sliding masks for packed VLM sequences.
packed_seq_ids contains 1-based document ids and 0 for padding.
Full-attention layers remain plain packed causal attention. Sliding layers
also include Gemma4’s same-image-token bidirectional edges.
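The mask rules can be sketched on nested lists. This is illustrative only: image_group_ids (0 for non-image tokens, a positive id per image) and the window convention q - k < sliding_window are assumptions, and the real function works on tensors.

```python
def build_packed_masks(packed_seq_ids, image_group_ids, sliding_window):
    """Return (full, sliding) boolean allowed-masks indexed as [query][key]."""
    n = len(packed_seq_ids)
    full = [[False] * n for _ in range(n)]
    sliding = [[False] * n for _ in range(n)]
    for q in range(n):
        for k in range(n):
            # Tokens interact only inside the same document; id 0 is padding.
            same_doc = packed_seq_ids[q] != 0 and packed_seq_ids[q] == packed_seq_ids[k]
            causal = same_doc and k <= q
            full[q][k] = causal
            in_window = causal and (q - k) < sliding_window
            # Bidirectional edges between tokens of the same image (assumed ids).
            same_image = (
                same_doc
                and image_group_ids[q] != 0
                and image_group_ids[q] == image_group_ids[k]
            )
            sliding[q][k] = in_window or same_image
    return full, sliding


full, sliding = build_packed_masks(
    packed_seq_ids=[1, 1, 1, 2, 2, 0],
    image_group_ids=[0, 1, 1, 0, 0, 0],
    sliding_window=2,
)
print(full[1][2], sliding[1][2])
```

In the example, token 1 cannot see token 2 in the full mask because of causality, but can in the sliding mask because both belong to the same image.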
Convert a 4D bool allowed-mask to HF additive format (0.0 allowed, -inf masked).
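A dependency-free sketch of the conversion on nested lists, assuming the mask is a nested structure of Python booleans. The actual function presumably operates on tensors and a target dtype, which this sketch does not model.

```python
def to_additive(mask):
    """Map True -> 0.0 (allowed) and False -> -inf (masked), at any nesting depth."""
    if isinstance(mask, bool):
        return 0.0 if mask else float("-inf")
    return [to_additive(item) for item in mask]


# Shape [batch=1][heads=1][query=2][key=2].
allowed = [[[[True, False], [True, True]]]]
additive = to_additive(allowed)
print(additive)
```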
Derive 2D padding mask (True = pad) from 1D, 2D, or 4D attention mask.
True if the (dense) text config uses gemma4 kv-sharing (E2B/E4B).
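One plausible shape for this check is sketched below. uses_kv_sharing is a hypothetical name, and the exact condition (a positive num_kv_shared_layers on a non-MoE text config) is an assumption inferred from the descriptions above, not the confirmed implementation.

```python
from types import SimpleNamespace


def uses_kv_sharing(text_config):
    """Return True when a dense text config declares shared K/V layers (assumed rule)."""
    if getattr(text_config, "enable_moe_block", False):
        return False
    shared = getattr(text_config, "num_kv_shared_layers", 0) or 0
    return shared > 0


# Illustrative values only; these are not taken from real checkpoints.
sharing_cfg = SimpleNamespace(enable_moe_block=False, num_kv_shared_layers=4)
plain_cfg = SimpleNamespace(enable_moe_block=False, num_kv_shared_layers=0)
print(uses_kv_sharing(sharing_cfg), uses_kv_sharing(plain_cfg))
```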