nemo_automodel.components.models.gemma4_moe.model


Gemma4 MoE NeMo Automodel support.

Replaces the HF-native Gemma4 MoE (dense matmul over all experts) with NeMo’s GroupedExperts backend, enabling Expert Parallelism (EP) via the standard MoE parallelizer.

Module Contents

Classes

| Name | Description |
| --- | --- |
| Gemma4ForConditionalGeneration | Gemma4 VL conditional generation model with NeMo MoE backend. |
| Gemma4Gate | Gemma4 Router reimplemented to output NeMo Gate format. |
| Gemma4MoE | NeMo MoE that uses Gemma4Gate (with pre-norm routing) instead of the standard Gate. |
| Gemma4MoEDecoderLayer | Gemma4 decoder layer with NeMo MoE backend. |
| Gemma4MoEModel | Thin wrapper that exposes language_model internals as properties expected by the NeMo training loop. |
| Gemma4MoETextModelBackend | Gemma4 text decoder rebuilt with NeMo MoE blocks. |
| _FSDPSafeSharedKVStates | A dict-like store for Gemma4 key/value sharing that is safe to pass through FSDP2. |
| _Gemma4KVShareHolder | Cache-free holder that lets HF gemma4 kv-sharing fire under use_cache=False. |

Functions

| Name | Description |
| --- | --- |
| _build_packed_gemma4_causal_mask_mapping | Build Gemma4 full/sliding masks for packed VLM sequences. |
| _convert_bool_4d_mask_to_additive | Convert a 4D bool allowed-mask to HF additive format (0.0 allowed, -inf masked). |
| _derive_padding_mask | Derive 2D padding mask (True = pad) from 1D, 2D, or 4D attention mask. |
| _kv_sharing_active | True if the (dense) text config uses gemma4 kv-sharing (E2B/E4B). |
| _make_missing | - |

Data

Gemma4Attention

Gemma4CausalLMOutputWithPast

Gemma4DecoderLayer

Gemma4MLP

Gemma4RMSNorm

Gemma4RotaryEmbedding

Gemma4TextModel

Gemma4TextScaledWordEmbedding

ModelClass

_GEMMA4_HF_AVAILABLE

API

class nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration(
config: transformers.models.gemma4.configuration_gemma4.Gemma4Config,
moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
text_config: dict | None = None,
kwargs = {}
)

Bases: HFCheckpointingMixin, HFGemma4ForConditionalGeneration, MoEFSDPSyncMixin

_keep_in_fp32_modules
= ['rotary_emb']

Gemma4 VL conditional generation model with NeMo MoE backend.

When the checkpoint has enable_moe_block=True in its text config, replaces the HF-native language model with Gemma4MoETextModelBackend (NeMo GroupedExperts + Gemma4Gate). Otherwise falls through to vanilla HF.

pad_token_id
= pad_token_id if pad_token_id is not None else -1
state_dict_adapter
vocab_size
= text_config.vocab_size
nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration._cp_shard_batch(
cp_mesh,
tp_mesh,
batch,
loss_mask = None,
padding_token_id = 0
)

Gemma4-owned CP batch sharder that also self-installs the ring.

Attached to the batch as _cp_make_batch_fn by prepare_model_inputs_for_cp. cp_utils.make_cp_batch_and_ctx calls it with the CP submesh; this is the only place Gemma4 receives cp_mesh on a model-owned path, so the ring is installed here (idempotently) before sharding instead of relying on the framework to call setup_cp_attention.

nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration._get_special_image_mask(
input_ids: torch.Tensor,
mm_token_type_ids: torch.Tensor | None = None
) -> torch.Tensor
nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration._get_text_pad_token_id() -> int
nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration._prepare_per_layer_inputs_for_cp(
input_ids: torch.Tensor,
special_image_mask: torch.Tensor
) -> torch.Tensor | None
nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration.forward(
input_ids: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
cache_position: torch.Tensor | None = None,
pixel_values: torch.Tensor | None = None,
image_position_ids: torch.Tensor | None = None,
mm_token_type_ids: torch.Tensor | None = None,
_pre_embed_only: bool = False,
logits_to_keep: typing.Union[int, torch.Tensor] = 0,
output_hidden_states: typing.Optional[bool] = None,
kwargs: typing.Any = {}
)
nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration.from_config(
config: transformers.models.gemma4.configuration_gemma4.Gemma4Config,
moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
kwargs = {}
)
classmethod
nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration.from_pretrained(
pretrained_model_name_or_path: str,
model_args = (),
kwargs = {}
)
classmethod
nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration.get_capabilities(
config: transformers.models.gemma4.configuration_gemma4.Gemma4Config
) -> nemo_automodel._transformers.model_capabilities.ModelCapabilities
classmethod

Return the capabilities for a specific config (no model instance needed).

Dispatches on two config checks, giving three variants, so the same class reports accurate capabilities for every Gemma4 checkpoint:

  1. If config.text_config.enable_moe_block is True → MoE variant (e.g. google/gemma-4-26B-A4B-it).
  2. Else if config.audio_config is not None → dense + audio variant (e.g. google/gemma-4-E2B-it, google/gemma-4-E4B-it).
  3. Else → plain dense variant (e.g. google/gemma-4-31B-it).
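The dispatch order above can be sketched as a plain function. This is a hypothetical helper (gemma4_variant) that returns a string label instead of a ModelCapabilities object; the getattr defaults for missing fields are an assumption, not the library's behavior.

```python
from types import SimpleNamespace
from typing import Any


def gemma4_variant(config: Any) -> str:
    """Hypothetical mirror of the three-way dispatch (labels are illustrative)."""
    text_config = config.text_config
    # 1. MoE checkpoints set enable_moe_block on the text config.
    if getattr(text_config, "enable_moe_block", False):
        return "moe"
    # 2. Dense checkpoints that ship an audio tower.
    if getattr(config, "audio_config", None) is not None:
        return "dense+audio"
    # 3. Everything else is a plain dense model.
    return "dense"


moe_cfg = SimpleNamespace(text_config=SimpleNamespace(enable_moe_block=True), audio_config=None)
print(gemma4_variant(moe_cfg))  # moe
```

Checking enable_moe_block first means an MoE checkpoint is classified as MoE even if it also carries an audio_config.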

Parameters:

config
Gemma4Config

The model’s Gemma4Config (or anything exposing a text_config with enable_moe_block and an audio_config attribute).

Returns: ModelCapabilities

A populated ModelCapabilities instance for this specific config.

nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration.initialize_weights(
buffer_device: torch.device | None = None,
dtype: torch.dtype = torch.bfloat16
) -> None
nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration.prepare_inputs_embeds_for_cp(
input_ids: torch.Tensor,
pixel_values: torch.Tensor | None = None,
image_position_ids: torch.Tensor | None = None,
mm_token_type_ids: torch.Tensor | None = None
) -> torch.Tensor
nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration.prepare_model_inputs_for_cp(
input_ids: torch.Tensor,
pixel_values: torch.Tensor | None = None,
image_position_ids: torch.Tensor | None = None,
mm_token_type_ids: torch.Tensor | None = None
) -> dict[str, typing.Any]

Prepare Gemma4 embeddings on the full sequence before CP sharding.

nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration.setup_cp_attention(
cp_mesh
) -> None

Install Gemma4’s model-owned p2p ring CP attention (dense path).

Idempotent: sets the _cp_enabled flag that forward reads and installs the ring on every self-attn module (each module was given its own setup_cp_attention by attach_gemma4_cp_ring_attention at construction). It is invoked from Gemma4’s own batch-sharding callable (_cp_shard_batch) the first time the recipe passes it the CP submesh, so the install is fully model-owned and needs no framework dispatch.
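A minimal sketch of an idempotent install, assuming a guard flag is how repeated calls become no-ops. ToyAttention and ToyModel are hypothetical stand-ins; the real method may achieve idempotency differently (for example, inside each per-module setup_cp_attention).

```python
from typing import Any, List


class ToyAttention:
    """Hypothetical stand-in for a self-attn module with a per-module installer."""

    def __init__(self) -> None:
        self.ring_installs = 0
        self.cp_mesh: Any = None

    def setup_cp_attention(self, cp_mesh: Any) -> None:
        self.ring_installs += 1
        self.cp_mesh = cp_mesh


class ToyModel:
    def __init__(self, num_layers: int) -> None:
        self.attn_modules: List[ToyAttention] = [ToyAttention() for _ in range(num_layers)]
        self._cp_enabled = False  # the flag forward would read

    def setup_cp_attention(self, cp_mesh: Any) -> None:
        # Guard flag: once the ring is installed, later calls do nothing.
        if self._cp_enabled:
            return
        for attn in self.attn_modules:
            attn.setup_cp_attention(cp_mesh)
        self._cp_enabled = True


model = ToyModel(num_layers=3)
model.setup_cp_attention("cp-submesh")  # first batch: installs the ring
model.setup_cp_attention("cp-submesh")  # later batches: no-op
```

This is why calling the installer from every _cp_shard_batch invocation is safe: only the first call does any work.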

nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration.tie_weights(
_args: object = (),
_kwargs: object = {}
) -> None

Tie lm_head to the active text embed_tokens when requested.

Overrides HF’s generic tying so that any caller after the MoE language_model swap (construction, AutoModel, and checkpoint load via ensure_tied_lm_head) re-points lm_head to the active embedding rather than whatever HF’s get_input_embeddings() indirection resolves to. No-op when the config requests untied embeddings.

Accepts and ignores positional/keyword arguments (e.g. HF v5’s recompute_mapping) so it stays drop-in compatible with the HF init_weights() -> tie_weights(...) call path.

The controlling flag is the top-level Gemma4Config.tie_word_embeddings (verified against HF: the top-level flag decides tying regardless of the nested text_config value), so it is read first; text_config is consulted only for configs that don’t expose a top-level flag.
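The flag precedence described above can be sketched as a small resolver. should_tie_embeddings is a hypothetical helper, and the final default of True (used when neither level sets the flag) is an assumption, not documented behavior.

```python
from types import SimpleNamespace
from typing import Any


def should_tie_embeddings(config: Any, default: bool = True) -> bool:
    """Hypothetical resolver: top-level flag first, nested text_config second."""
    top = getattr(config, "tie_word_embeddings", None)
    if top is not None:
        # The top-level flag wins regardless of the nested value.
        return bool(top)
    text_config = getattr(config, "text_config", None)
    nested = getattr(text_config, "tie_word_embeddings", None)
    # ASSUMPTION: fall back to `default` when neither level sets the flag.
    return default if nested is None else bool(nested)


cfg = SimpleNamespace(tie_word_embeddings=False, text_config=SimpleNamespace(tie_word_embeddings=True))
print(should_tie_embeddings(cfg))  # False
```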

class nemo_automodel.components.models.gemma4_moe.model.Gemma4Gate(
config: transformers.models.gemma4.configuration_gemma4.Gemma4TextConfig
)

Bases: Module

Gemma4 Router reimplemented to output NeMo Gate format.

HF Gemma4Router applies RMSNorm (no scale) → root_size scaling → learnable scale → Linear → softmax → top-k → renormalize, which differs from the standard Gate class in layer.py. This class reproduces that logic but returns (weights, indices, aux_loss), the format GroupedExperts expects.
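The routing pipeline above can be traced for a single token in plain Python. This is a sketch, not the NeMo or HF code: gemma4_style_route is a hypothetical name, the root_size step is assumed to be a multiply by hidden_size ** -0.5, and the aux loss is returned as None.

```python
import math
from typing import List, Optional, Tuple


def gemma4_style_route(
    x: List[float],
    scale: List[float],
    proj: List[List[float]],
    top_k: int,
    eps: float = 1e-6,
) -> Tuple[List[float], List[int], Optional[float]]:
    hidden = len(x)
    # 1. RMSNorm without a learnable weight.
    rms = math.sqrt(sum(v * v for v in x) / hidden + eps)
    h = [v / rms for v in x]
    # 2. root_size scaling, ASSUMED here to be a multiply by hidden_size ** -0.5.
    h = [v * hidden ** -0.5 for v in h]
    # 3. Learnable per-channel scale (the `scale` parameter, initialized to ones).
    h = [v * s for v, s in zip(h, scale)]
    # 4. Linear projection: one logit per expert (proj[e] is expert e's weight row).
    logits = [sum(w * v for w, v in zip(row, h)) for row in proj]
    # 5. Softmax over experts, max-subtracted for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # 6. Top-k experts by probability.
    indices = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    # 7. Renormalize the selected probabilities so they sum to 1.
    picked = [probs[i] for i in indices]
    denom = sum(picked)
    weights = [p / denom for p in picked]
    # This sketch computes no auxiliary load-balancing loss.
    return weights, indices, None


weights, indices, aux = gemma4_style_route(
    [1.0, 2.0, 3.0, 4.0], [1.0] * 4, [[1, 0, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0]], top_k=2
)
print(indices)  # [1, 2]
```

Renormalizing after top-k keeps the combined expert output on the same scale no matter how much probability mass the unselected experts held.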

norm
proj
scale
= nn.Parameter(torch.ones(hidden_size, dtype=dtype))
topk
= config.top_k_experts
nemo_automodel.components.models.gemma4_moe.model.Gemma4Gate.forward(
x,
token_mask = None,
cp_mesh = None
)
nemo_automodel.components.models.gemma4_moe.model.Gemma4Gate.init_weights(
buffer_device: torch.device,
init_std: float = 0.02
) -> None
class nemo_automodel.components.models.gemma4_moe.model.Gemma4MoE(
moe_config: nemo_automodel.components.moe.layers.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
text_config: transformers.models.gemma4.configuration_gemma4.Gemma4TextConfig
)

Bases: MoE

NeMo MoE that uses Gemma4Gate (with pre-norm routing) instead of the standard Gate. Subclasses MoE so that isinstance(m, MoE) is True, which the EP parallelizer relies on.

gate
= Gemma4Gate(text_config)
nemo_automodel.components.models.gemma4_moe.model.Gemma4MoE.forward(
x,
padding_mask = None,
cp_mesh = None,
gate_input = None
)

Forward with optional separate gate input.

HF Gemma4 passes the unnormalized residual to the router and the normalized input to the experts. The decoder layer therefore calls this method with gate_input=x (the raw residual), so the gate receives the unnormalized input while the experts receive pre_feedforward_layernorm_2(x).
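The split between gate input and expert input can be illustrated with a toy MoE over Python lists. moe_forward, rms_norm, and the recording router are hypothetical; rms_norm only stands in for pre_feedforward_layernorm_2, and no real GroupedExperts kernel is involved.

```python
import math
from typing import Callable, List, Optional, Sequence, Tuple

Router = Callable[[List[float]], Tuple[List[float], List[int]]]
Expert = Callable[[List[float]], List[float]]


def rms_norm(x: List[float], eps: float = 1e-6) -> List[float]:
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def moe_forward(
    x: List[float], router: Router, experts: Sequence[Expert], gate_input: Optional[List[float]] = None
) -> List[float]:
    # Route on gate_input when it is given, otherwise on x itself.
    weights, indices = router(x if gate_input is None else gate_input)
    out = [0.0] * len(x)
    for w, idx in zip(weights, indices):
        y = experts[idx](x)  # experts always receive x (the normalized input)
        out = [o + w * v for o, v in zip(out, y)]
    return out


seen: List[List[float]] = []


def recording_router(h: List[float]) -> Tuple[List[float], List[int]]:
    seen.append(h)  # remember what the gate was given
    return [1.0], [0]


residual = [3.0, -1.0, 2.0]
normed = rms_norm(residual)  # stands in for pre_feedforward_layernorm_2(x)
out = moe_forward(normed, recording_router, [lambda h: [2.0 * v for v in h]], gate_input=residual)
```

Here the router sees the raw residual while the single expert doubles the normalized input.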

class nemo_automodel.components.models.gemma4_moe.model.Gemma4MoEDecoderLayer(
config: transformers.models.gemma4.configuration_gemma4.Gemma4TextConfig,
layer_idx: int,
moe_config: nemo_automodel.components.moe.layers.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig
)

Bases: Module

Gemma4 decoder layer with NeMo MoE backend.

Reuses the HF attention and dense MLP, and replaces the HF Router + MoEBlock pair with NeMo Gemma4MoE (Gemma4Gate + GroupedExperts).

attention_type
= config.layer_types[layer_idx]
hidden_size
= config.hidden_size
input_layernorm
mlp
= Gemma4MLP(config, layer_idx)
moe
= Gemma4MoE(moe_config, backend, config)
post_attention_layernorm
post_feedforward_layernorm
post_feedforward_layernorm_1
post_feedforward_layernorm_2
pre_feedforward_layernorm
pre_feedforward_layernorm_2
self_attn
nemo_automodel.components.models.gemma4_moe.model.Gemma4MoEDecoderLayer.forward(
x: torch.Tensor,
position_embeddings: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
padding_mask: torch.Tensor | None = None,
past_key_values = None,
use_cache: bool | None = False,
cache_position: torch.LongTensor | None = None,
mm_token_type_ids: torch.Tensor | None = None,
shared_kv_states: dict[str, tuple[torch.Tensor, torch.Tensor]] | None = None,
kwargs: typing.Any = {}
) -> torch.Tensor
class nemo_automodel.components.models.gemma4_moe.model.Gemma4MoEModel()

Bases: HFGemma4Model

Thin wrapper that exposes language_model internals as properties expected by the NeMo training loop.

class nemo_automodel.components.models.gemma4_moe.model.Gemma4MoETextModelBackend(
config: transformers.models.gemma4.configuration_gemma4.Gemma4TextConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None,
moe_overrides: dict | None = None
)

Bases: Module

Gemma4 text decoder rebuilt with NeMo MoE blocks.

embed_tokens
layers
moe_config
= moe_config or MoEConfig(**moe_defaults)
norm
padding_idx
= getattr(config, 'pad_token_id', None)
rotary_emb
= Gemma4RotaryEmbedding(config)
vocab_size
= config.vocab_size
nemo_automodel.components.models.gemma4_moe.model.Gemma4MoETextModelBackend.forward(
input_ids: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
cache_position: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
mm_token_type_ids: torch.Tensor | None = None,
pixel_values: torch.Tensor | None = None,
past_key_values = None,
use_cache: bool | None = None,
cp_enabled: bool = False,
kwargs: typing.Any = {}
) -> transformers.modeling_outputs.BaseModelOutputWithPast
nemo_automodel.components.models.gemma4_moe.model.Gemma4MoETextModelBackend.get_input_embeddings() -> torch.nn.Module
nemo_automodel.components.models.gemma4_moe.model.Gemma4MoETextModelBackend.set_input_embeddings(
value: torch.nn.Module
) -> None
class nemo_automodel.components.models.gemma4_moe.model._FSDPSafeSharedKVStates()

Bases: MutableMapping

A dict-like store for Gemma4 key/value sharing that is safe to pass through FSDP2.

Why a plain dict breaks under FSDP2: with FSDP2, each decoder layer is wrapped as its own unit, and the default mixed-precision setting (cast_forward_inputs=True) makes FSDP2 inspect every argument passed to a layer and cast its float tensors to bf16. It does this with torch’s _apply_to_tensors, which builds a brand-new copy of every dict (or list/tuple/set/…) it encounters. If the shared store were a plain dict, each layer would receive its own private copy: the earlier layer’s writes would land in a copy that is thrown away, and the later layers would read from an empty copy, raising KeyError: 'sliding_attention'. Because this class subclasses MutableMapping rather than dict, the traversal does not treat it as a container and passes the same object to every layer.
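The failure mode can be reproduced without FSDP. rebuild_containers below is a simplified, hypothetical stand-in for the traversal described above (it is not torch’s _apply_to_tensors): it rebuilds dict/list/tuple/set containers and passes every other object through. SharedStore is a hypothetical MutableMapping wrapper in the spirit of this class.

```python
from collections.abc import MutableMapping
from typing import Any, Dict, Iterator


def rebuild_containers(obj: Any) -> Any:
    """Hypothetical traversal: copies dict/list/tuple/set, passes anything else through."""
    if isinstance(obj, dict):
        return {k: rebuild_containers(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, set)):
        return type(obj)(rebuild_containers(v) for v in obj)
    return obj


class SharedStore(MutableMapping):
    """Dict-like wrapper that is NOT a dict subclass, so it is passed by reference."""

    def __init__(self) -> None:
        self._store: Dict[str, Any] = {}

    def __getitem__(self, key: str) -> Any:
        return self._store[key]

    def __setitem__(self, key: str, value: Any) -> None:
        self._store[key] = value

    def __delitem__(self, key: str) -> None:
        del self._store[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._store)

    def __len__(self) -> int:
        return len(self._store)


def run_two_layers(shared: Any) -> bool:
    # Layer 0 writes into the view it receives; layer 1 checks its own view.
    layer0_view = rebuild_containers(shared)
    layer0_view["sliding_attention"] = ("k", "v")
    layer1_view = rebuild_containers(shared)
    return "sliding_attention" in layer1_view


print(run_two_layers({}))             # False: layer 0 wrote into a discarded copy
print(run_two_layers(SharedStore()))  # True: both layers see the same object
```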

_store
dict[str, tuple[Tensor, Tensor]] = {}
nemo_automodel.components.models.gemma4_moe.model._FSDPSafeSharedKVStates.__delitem__(
key: str
) -> None
nemo_automodel.components.models.gemma4_moe.model._FSDPSafeSharedKVStates.__getitem__(
key: str
) -> tuple[torch.Tensor, torch.Tensor]
nemo_automodel.components.models.gemma4_moe.model._FSDPSafeSharedKVStates.__iter__() -> typing.Iterator[str]
nemo_automodel.components.models.gemma4_moe.model._FSDPSafeSharedKVStates.__len__() -> int
nemo_automodel.components.models.gemma4_moe.model._FSDPSafeSharedKVStates.__setitem__(
key: str,
value: tuple[torch.Tensor, torch.Tensor]
) -> None
class nemo_automodel.components.models.gemma4_moe.model._Gemma4KVShareHolder()

Cache-free holder that lets HF gemma4 kv-sharing fire under use_cache=False.

E2B/E4B share K/V across the trailing num_kv_shared_layers layers: each shared layer reads its source layer’s K/V from past_key_values.shared_layers (see HF Gemma4Attention.forward). HF performs that read only when past_key_values is not None, but past_key_values is None whenever use_cache=False, and use_cache is forced off by activation checkpointing and by the model-owned CP path. The shared layers then fall back to their own (frozen, unused) K/V projections and produce garbage, inflating the loss by roughly 4x.

Passing this lightweight object as past_key_values satisfies the gate so HF’s own kv-sharing logic runs: source layers populate shared_layers and shared layers read it. update is a pass-through (no per-token accumulation, so no cache memory growth), and get_seq_length returns 0 so the causal mask is built with a zero cache offset (correct for a training forward).
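The mechanism can be sketched with a toy holder and a toy gated read. KVShareHolderSketch and toy_layer_kv are hypothetical; keying shared_layers by layer index and the source_of mapping are illustrative assumptions, not HF’s actual data layout.

```python
from typing import Any, Dict, Optional, Tuple

KV = Tuple[Any, Any]


class KVShareHolderSketch:
    """Hypothetical cache-free holder mirroring the behavior described above."""

    def __init__(self) -> None:
        self.shared_layers: Dict[int, KV] = {}

    def update(self, key_states: Any, value_states: Any, layer_idx: int, *args: Any, **kwargs: Any) -> KV:
        # Pass-through: nothing accumulates, so memory does not grow per token.
        return key_states, value_states

    def get_seq_length(self, *args: Any, **kwargs: Any) -> int:
        # No cached prefix, so masks are built with a zero cache offset.
        return 0


def toy_layer_kv(layer_idx: int, own_kv: KV, past_key_values: Optional[Any], source_of: Dict[int, int]) -> KV:
    """Toy stand-in for the gated kv-sharing read (not HF's code)."""
    if past_key_values is not None:
        if layer_idx in source_of:  # shared layer: read the source layer's K/V
            return past_key_values.shared_layers[source_of[layer_idx]]
        k, v = past_key_values.update(own_kv[0], own_kv[1], layer_idx)
        past_key_values.shared_layers[layer_idx] = (k, v)  # source layer publishes its K/V
        return k, v
    # Gate not satisfied: the layer silently uses its own projections.
    return own_kv


holder = KVShareHolderSketch()
source_of = {1: 0}  # hypothetical: layer 1 reuses layer 0's K/V
toy_layer_kv(0, ("k0", "v0"), holder, source_of)
print(toy_layer_kv(1, ("k1", "v1"), holder, source_of))  # ('k0', 'v0')
print(toy_layer_kv(1, ("k1", "v1"), None, source_of))    # ('k1', 'v1')
```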

shared_layers
dict = {}
nemo_automodel.components.models.gemma4_moe.model._Gemma4KVShareHolder.get_mask_sizes(
query_length: int,
layer_idx = None
) -> tuple[int, int]
nemo_automodel.components.models.gemma4_moe.model._Gemma4KVShareHolder.get_seq_length(
args = (),
kwargs = {}
) -> int
nemo_automodel.components.models.gemma4_moe.model._Gemma4KVShareHolder.update(
key_states,
value_states,
layer_idx,
args = (),
kwargs = {}
)
nemo_automodel.components.models.gemma4_moe.model._build_packed_gemma4_causal_mask_mapping(
packed_seq_ids: torch.Tensor,
mm_token_type_ids: torch.Tensor,
dtype: torch.dtype,
sliding_window: int | None,
as_additive: bool = False,
as_block_mask: bool = False,
flex_block_size: int | tuple[int, int] = 128
) -> dict[str, torch.Tensor]

Build Gemma4 full/sliding masks for packed VLM sequences.

packed_seq_ids contains 1-based document ids and 0 for padding. Full-attention layers remain plain packed causal attention. Sliding layers also include Gemma4’s same-image-token bidirectional edges.
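The mask rules above can be sketched over plain Python lists, producing boolean allowed-masks (True = may attend). The following are assumptions, not the library’s definitions: image tokens have mm_token_type_ids == 1, one contiguous run of image tokens within a document is one image, the sliding window keeps keys with q - k < sliding_window, and the output keys are "full_attention" and "sliding_attention".

```python
from typing import Dict, List, Optional


def packed_masks_sketch(
    packed_seq_ids: List[int],
    mm_token_type_ids: List[int],
    sliding_window: Optional[int],
    image_type: int = 1,  # ASSUMPTION: image tokens are marked with type 1
) -> Dict[str, List[List[bool]]]:
    n = len(packed_seq_ids)
    # Label contiguous image-token runs per document (ASSUMPTION: one run == one image).
    image_run = [0] * n
    run = 0
    for i in range(n):
        if mm_token_type_ids[i] == image_type:
            new_run = i == 0 or mm_token_type_ids[i - 1] != image_type or packed_seq_ids[i - 1] != packed_seq_ids[i]
            if new_run:
                run += 1
            image_run[i] = run
    full: List[List[bool]] = []
    sliding: List[List[bool]] = []
    for q in range(n):
        full_row, sliding_row = [], []
        for k in range(n):
            # Padding (id 0) attends nowhere; documents never see each other.
            same_doc = packed_seq_ids[q] != 0 and packed_seq_ids[q] == packed_seq_ids[k]
            causal = same_doc and k <= q
            in_window = sliding_window is None or q - k < sliding_window
            same_image = same_doc and image_run[q] != 0 and image_run[q] == image_run[k]
            full_row.append(causal)
            sliding_row.append((causal and in_window) or same_image)
        full.append(full_row)
        sliding.append(sliding_row)
    return {"full_attention": full, "sliding_attention": sliding}


masks = packed_masks_sketch([1, 1, 1, 2, 2, 0], [0, 1, 1, 0, 0, 0], sliding_window=2)
```

In this example, tokens 1 and 2 form one image in document 1, so the sliding mask lets token 1 see token 2 even though that edge is not causal.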

nemo_automodel.components.models.gemma4_moe.model._convert_bool_4d_mask_to_additive(
attention_mask: torch.Tensor,
dtype: torch.dtype
) -> torch.Tensor

Convert a 4D bool allowed-mask to HF additive format (0.0 allowed, -inf masked).
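The conversion is an elementwise select, which can be sketched over nested lists of any depth. bool_mask_to_additive is a hypothetical name; on a real tensor the same mapping would be a single vectorized operation.

```python
from typing import Any


def bool_mask_to_additive(mask: Any) -> Any:
    """Map True (allowed) -> 0.0 and False (masked) -> -inf, recursing through nested lists."""
    if isinstance(mask, bool):
        return 0.0 if mask else float("-inf")
    return [bool_mask_to_additive(m) for m in mask]


print(bool_mask_to_additive([[[[True, False]]]]))  # [[[[0.0, -inf]]]]
```

Adding 0.0 leaves an attention score unchanged, while adding -inf drives its softmax weight to exactly zero.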

nemo_automodel.components.models.gemma4_moe.model._derive_padding_mask(
attention_mask: torch.Tensor
) -> torch.Tensor

Derive 2D padding mask (True = pad) from 1D, 2D, or 4D attention mask.
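A sketch of the three input shapes, using nested lists. derive_padding_mask_sketch is hypothetical, and two conventions are assumptions: 1D/2D masks use 1/True for real tokens, and a 4D mask is a [batch][1][q][k] bool allowed-mask in which every real token attends to itself, so a False diagonal entry marks padding.

```python
from typing import Any, List


def derive_padding_mask_sketch(attention_mask: Any) -> List[List[bool]]:
    """Return a [batch][seq] mask where True marks padding."""
    depth, probe = 0, attention_mask
    while isinstance(probe, list):
        depth += 1
        probe = probe[0] if probe else None
    if depth == 1:  # [seq]: treat as a batch of one
        attention_mask, depth = [attention_mask], 2
    if depth == 2:  # [batch][seq]: 1/True marks a real token
        return [[not bool(v) for v in row] for row in attention_mask]
    if depth == 4:  # [batch][1][q][k] bool allowed-mask
        # ASSUMPTION: a real token always attends to itself, so a False diagonal means padding.
        return [[not b[0][i][i] for i in range(len(b[0]))] for b in attention_mask]
    raise ValueError(f"unsupported attention_mask depth: {depth}")


print(derive_padding_mask_sketch([1, 1, 0]))  # [[False, False, True]]
```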

nemo_automodel.components.models.gemma4_moe.model._kv_sharing_active(
text_config
) -> bool

True if the (dense) text config uses gemma4 kv-sharing (E2B/E4B).
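A minimal sketch of such a check. It assumes, based on the num_kv_shared_layers field mentioned for _Gemma4KVShareHolder, that sharing is active when that field is a positive integer; the helper name is hypothetical.

```python
from types import SimpleNamespace
from typing import Any


def kv_sharing_active_sketch(text_config: Any) -> bool:
    # ASSUMPTION: kv-sharing is signalled by a positive num_kv_shared_layers.
    return (getattr(text_config, "num_kv_shared_layers", 0) or 0) > 0


print(kv_sharing_active_sketch(SimpleNamespace(num_kv_shared_layers=4)))  # True
```

The "or 0" guard also treats an explicit None as "no sharing".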

nemo_automodel.components.models.gemma4_moe.model._make_missing(
name: str
)
nemo_automodel.components.models.gemma4_moe.model.Gemma4Attention = getattr(_g4, 'Gemma4TextAttention', None) or _g4.Gemma4Attention
nemo_automodel.components.models.gemma4_moe.model.Gemma4CausalLMOutputWithPast = _g4.Gemma4CausalLMOutputWithPast
nemo_automodel.components.models.gemma4_moe.model.Gemma4DecoderLayer = getattr(_g4, 'Gemma4TextDecoderLayer', None) or _g4.Gemma4DecoderLayer
nemo_automodel.components.models.gemma4_moe.model.Gemma4MLP = getattr(_g4, 'Gemma4TextMLP', None) or _g4.Gemma4MLP
nemo_automodel.components.models.gemma4_moe.model.Gemma4RMSNorm = _g4.Gemma4RMSNorm
nemo_automodel.components.models.gemma4_moe.model.Gemma4RotaryEmbedding = getattr(_g4, 'Gemma4TextRotaryEmbedding', None) or _g4.Gemma4RotaryEmbedding
nemo_automodel.components.models.gemma4_moe.model.Gemma4TextModel = _g4.Gemma4TextModel
nemo_automodel.components.models.gemma4_moe.model.Gemma4TextScaledWordEmbedding = _g4.Gemma4TextScaledWordEmbedding
nemo_automodel.components.models.gemma4_moe.model.ModelClass = Gemma4ForConditionalGeneration
nemo_automodel.components.models.gemma4_moe.model._GEMMA4_HF_AVAILABLE = True
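Several aliases above prefer a Gemma4Text* name and fall back to the shorter name, presumably to cover transformers releases that name these classes differently. The pattern is sketched below with hypothetical SimpleNamespace stand-ins for the two module layouts.

```python
from types import SimpleNamespace
from typing import Any

# Hypothetical stand-ins for two layouts of the gemma4 modeling module.
new_layout = SimpleNamespace(Gemma4TextAttention="text-attn", Gemma4Attention="legacy-attn")
old_layout = SimpleNamespace(Gemma4Attention="legacy-attn")


def resolve_attention(g4: Any) -> Any:
    # Prefer the newer Gemma4Text* name; fall back to the older one when it is missing.
    return getattr(g4, "Gemma4TextAttention", None) or g4.Gemma4Attention


print(resolve_attention(new_layout))  # text-attn
print(resolve_attention(old_layout))  # legacy-attn
```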