# nemo_automodel.components.models.gemma4_moe.model

Gemma4 MoE NeMo Automodel support.

Replaces the HF-native Gemma4 MoE (dense matmul over all experts) with NeMo's
GroupedExperts backend, enabling Expert Parallelism (EP) via the standard
MoE parallelizer.

## Module Contents

### Classes

| Name                                                                                                                  | Description                                                                        |
| --------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------- |
| [`Gemma4ForConditionalGeneration`](#nemo_automodel-components-models-gemma4_moe-model-Gemma4ForConditionalGeneration) | Gemma4 VL conditional generation model with NeMo MoE backend.                      |
| [`Gemma4Gate`](#nemo_automodel-components-models-gemma4_moe-model-Gemma4Gate)                                         | Gemma4 Router reimplemented to output NeMo Gate format.                            |
| [`Gemma4MoE`](#nemo_automodel-components-models-gemma4_moe-model-Gemma4MoE)                                           | NeMo MoE that uses Gemma4Gate (with pre-norm routing) instead of the standard Gate. |
| [`Gemma4MoEDecoderLayer`](#nemo_automodel-components-models-gemma4_moe-model-Gemma4MoEDecoderLayer)                   | Gemma4 decoder layer with NeMo MoE backend.                                        |
| [`Gemma4MoEModel`](#nemo_automodel-components-models-gemma4_moe-model-Gemma4MoEModel)                                 | Thin wrapper that exposes `language_model` internals as properties expected by the NeMo training loop. |
| [`Gemma4MoETextModelBackend`](#nemo_automodel-components-models-gemma4_moe-model-Gemma4MoETextModelBackend)           | Gemma4 text decoder rebuilt with NeMo MoE blocks.                                  |
| [`_FSDPSafeSharedKVStates`](#nemo_automodel-components-models-gemma4_moe-model-_FSDPSafeSharedKVStates)               | A dict-like store for Gemma4 key/value sharing that is safe to pass through FSDP2. |
| [`_Gemma4KVShareHolder`](#nemo_automodel-components-models-gemma4_moe-model-_Gemma4KVShareHolder)                     | Cache-free holder that lets HF gemma4 kv-sharing fire under `use_cache=False`.     |

### Functions

| Name                                                                                                                                      | Description                                                                      |
| ----------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------- |
| [`_build_packed_gemma4_causal_mask_mapping`](#nemo_automodel-components-models-gemma4_moe-model-_build_packed_gemma4_causal_mask_mapping) | Build Gemma4 full/sliding masks for packed VLM sequences.                        |
| [`_convert_bool_4d_mask_to_additive`](#nemo_automodel-components-models-gemma4_moe-model-_convert_bool_4d_mask_to_additive)               | Convert a 4D bool allowed-mask to HF additive format (0.0 allowed, -inf masked). |
| [`_derive_padding_mask`](#nemo_automodel-components-models-gemma4_moe-model-_derive_padding_mask)                                         | Derive 2D padding mask (True = pad) from 1D, 2D, or 4D attention mask.           |
| [`_kv_sharing_active`](#nemo_automodel-components-models-gemma4_moe-model-_kv_sharing_active)                                             | True if the (dense) text config uses gemma4 kv-sharing (E2B/E4B).                |
| [`_make_missing`](#nemo_automodel-components-models-gemma4_moe-model-_make_missing)                                                       | -                                                                                |

### Data

[`Gemma4Attention`](#nemo_automodel-components-models-gemma4_moe-model-Gemma4Attention)

[`Gemma4CausalLMOutputWithPast`](#nemo_automodel-components-models-gemma4_moe-model-Gemma4CausalLMOutputWithPast)

[`Gemma4DecoderLayer`](#nemo_automodel-components-models-gemma4_moe-model-Gemma4DecoderLayer)

[`Gemma4MLP`](#nemo_automodel-components-models-gemma4_moe-model-Gemma4MLP)

[`Gemma4RMSNorm`](#nemo_automodel-components-models-gemma4_moe-model-Gemma4RMSNorm)

[`Gemma4RotaryEmbedding`](#nemo_automodel-components-models-gemma4_moe-model-Gemma4RotaryEmbedding)

[`Gemma4TextModel`](#nemo_automodel-components-models-gemma4_moe-model-Gemma4TextModel)

[`Gemma4TextScaledWordEmbedding`](#nemo_automodel-components-models-gemma4_moe-model-Gemma4TextScaledWordEmbedding)

[`ModelClass`](#nemo_automodel-components-models-gemma4_moe-model-ModelClass)

[`_GEMMA4_HF_AVAILABLE`](#nemo_automodel-components-models-gemma4_moe-model-_GEMMA4_HF_AVAILABLE)

### API

```python
class nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration(
    config: transformers.models.gemma4.configuration_gemma4.Gemma4Config,
    moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None,
    backend: nemo_automodel.components.models.common.BackendConfig | None = None,
    text_config: dict | None = None,
    kwargs = {}
)
```

**Bases:** [HFCheckpointingMixin](/nemo-automodel/nemo_automodel/components/models/common/hf_checkpointing_mixin#nemo_automodel-components-models-common-hf_checkpointing_mixin-HFCheckpointingMixin), `HFGemma4ForConditionalGeneration`, [MoEFSDPSyncMixin](/nemo-automodel/nemo_automodel/components/moe/fsdp_mixin#nemo_automodel-components-moe-fsdp_mixin-MoEFSDPSyncMixin)

Gemma4 VL conditional generation model with NeMo MoE backend.

When the checkpoint has `enable_moe_block=True` in its text config,
replaces the HF-native language model with `Gemma4MoETextModelBackend`
(NeMo GroupedExperts + Gemma4Gate).  Otherwise falls through to vanilla HF.

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration._cp_shard_batch(
    cp_mesh,
    tp_mesh,
    batch,
    loss_mask = None,
    padding_token_id = 0
)
```

Gemma4-owned CP batch sharder that also self-installs the ring.

Attached to the batch as `_cp_make_batch_fn` by
`prepare_model_inputs_for_cp`. `cp_utils.make_cp_batch_and_ctx` calls it
with the CP submesh, which is the only point where Gemma4 receives `cp_mesh`
on a model-owned path. The ring is therefore installed here (idempotently)
before sharding, rather than relying on the framework to call `setup_cp_attention`.

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration._get_special_image_mask(
    input_ids: torch.Tensor,
    mm_token_type_ids: torch.Tensor | None = None
) -> torch.Tensor
```

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration._get_text_pad_token_id() -> int
```

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration._prepare_per_layer_inputs_for_cp(
    input_ids: torch.Tensor,
    special_image_mask: torch.Tensor
) -> torch.Tensor | None
```

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration.forward(
    input_ids: torch.Tensor | None = None,
    position_ids: torch.Tensor | None = None,
    attention_mask: torch.Tensor | None = None,
    padding_mask: torch.Tensor | None = None,
    inputs_embeds: torch.Tensor | None = None,
    cache_position: torch.Tensor | None = None,
    pixel_values: torch.Tensor | None = None,
    image_position_ids: torch.Tensor | None = None,
    mm_token_type_ids: torch.Tensor | None = None,
    _pre_embed_only: bool = False,
    logits_to_keep: typing.Union[int, torch.Tensor] = 0,
    output_hidden_states: typing.Optional[bool] = None,
    kwargs: typing.Any = {}
)
```

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration.from_config(
    config: transformers.models.gemma4.configuration_gemma4.Gemma4Config,
    moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None,
    backend: nemo_automodel.components.models.common.BackendConfig | None = None,
    kwargs = {}
)
```

classmethod

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path: str,
    model_args = (),
    kwargs = {}
)
```

classmethod

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration.get_capabilities(
    config: transformers.models.gemma4.configuration_gemma4.Gemma4Config
) -> nemo_automodel._transformers.model_capabilities.ModelCapabilities
```

classmethod

Return the capabilities for a specific config (no model instance needed).

Dispatches on two checks, applied in order, so the same class can report
accurate capabilities for every Gemma4 checkpoint:

1. If `config.text_config.enable_moe_block` is True → MoE variant
   (e.g. `google/gemma-4-26B-A4B-it`).
2. Else if `config.audio_config` is not `None` → dense + audio
   variant (e.g. `google/gemma-4-E2B-it`, `google/gemma-4-E4B-it`).
3. Else → plain dense variant (e.g. `google/gemma-4-31B-it`).

**Parameters:**

`config`: The model's `Gemma4Config` (or any object exposing a
`text_config` with `enable_moe_block` and an
`audio_config` attribute).

**Returns:** `ModelCapabilities`

A populated `ModelCapabilities` for this specific config.
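
The dispatch order above can be sketched as follows. This is a minimal illustration, not the library code: `classify_gemma4_variant` is a hypothetical helper, and it returns a string label rather than a `ModelCapabilities` instance, whose fields are not shown on this page.

```python
from types import SimpleNamespace


def classify_gemma4_variant(config) -> str:
    """Hypothetical helper mirroring the documented dispatch order."""
    text_config = getattr(config, "text_config", None)
    # 1. MoE wins whenever the text config enables the MoE block.
    if getattr(text_config, "enable_moe_block", False):
        return "moe"
    # 2. Otherwise, a non-None audio config marks the dense + audio variant.
    if getattr(config, "audio_config", None) is not None:
        return "dense_audio"
    # 3. Everything else is the plain dense variant.
    return "dense"


# Stand-in configs (illustrative only, not real Gemma4Config objects).
moe_cfg = SimpleNamespace(
    text_config=SimpleNamespace(enable_moe_block=True), audio_config=None
)
audio_cfg = SimpleNamespace(
    text_config=SimpleNamespace(enable_moe_block=False),
    audio_config=SimpleNamespace(),
)
dense_cfg = SimpleNamespace(
    text_config=SimpleNamespace(enable_moe_block=False), audio_config=None
)
```

Because the MoE check runs first, a config that sets both `enable_moe_block` and `audio_config` would be classified as MoE in this sketch.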

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration.initialize_weights(
    buffer_device: torch.device | None = None,
    dtype: torch.dtype = torch.bfloat16
) -> None
```

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration.prepare_inputs_embeds_for_cp(
    input_ids: torch.Tensor,
    pixel_values: torch.Tensor | None = None,
    image_position_ids: torch.Tensor | None = None,
    mm_token_type_ids: torch.Tensor | None = None
) -> torch.Tensor
```

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration.prepare_model_inputs_for_cp(
    input_ids: torch.Tensor,
    pixel_values: torch.Tensor | None = None,
    image_position_ids: torch.Tensor | None = None,
    mm_token_type_ids: torch.Tensor | None = None
) -> dict[str, typing.Any]
```

Prepare Gemma4 embeddings on the full sequence before CP sharding.

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration.setup_cp_attention(
    cp_mesh
) -> None
```

Install Gemma4's model-owned p2p ring CP attention (dense path).

Idempotent: flips the `_cp_enabled` flag the forward reads and installs
the ring on every self-attn module (each was given a per-module
`setup_cp_attention` by `attach_gemma4_cp_ring_attention` at
construction). Invoked from Gemma4's own batch-sharding callable
(`_cp_shard_batch`) the first time the recipe hands it the CP submesh, so
the install is fully model-owned -- no framework dispatch is required.
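
The idempotent, model-owned install described above might look like the toy below. It is a sketch under assumptions: `RingInstallSketch` is a hypothetical class, the early return is one possible way to get idempotency, and the per-module ring setup and the actual batch sharding are replaced by placeholders.

```python
class RingInstallSketch:
    """Toy model showing an idempotent, model-owned CP install."""

    def __init__(self, num_attention_modules: int):
        self._cp_enabled = False
        # One counter per self-attention module, to observe repeated installs.
        self.installs = [0] * num_attention_modules

    def setup_cp_attention(self, cp_mesh) -> None:
        if self._cp_enabled:
            return  # already installed; repeated calls are no-ops
        for i in range(len(self.installs)):
            self.installs[i] += 1  # placeholder for per-module ring setup
        self._cp_enabled = True

    def _cp_shard_batch(self, cp_mesh, batch):
        # Install before sharding; safe to call on every batch.
        self.setup_cp_attention(cp_mesh)
        return batch  # real sharding is omitted in this sketch


model = RingInstallSketch(num_attention_modules=3)
model._cp_shard_batch("cp_mesh_placeholder", {})
model._cp_shard_batch("cp_mesh_placeholder", {})
```

After two batches each module has been set up exactly once, which is the property that lets the batch sharder trigger the install without framework involvement.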

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration.tie_weights(
    _args: object = (),
    _kwargs: object = {}
) -> None
```

Tie `lm_head` to the active text `embed_tokens` when requested.

Overrides HF's generic tying so that any caller after the MoE
`language_model` swap (construction, AutoModel, and checkpoint load
via `ensure_tied_lm_head`) re-points `lm_head` to the *active*
embedding rather than whatever HF's `get_input_embeddings()`
indirection resolves to. No-op when the config requests untied
embeddings.

Accepts and ignores positional/keyword arguments (e.g. HF v5's
`recompute_mapping`) so it stays drop-in compatible with the HF
`init_weights() -> tie_weights(...)` call path.

The controlling flag is the top-level `Gemma4Config.tie_word_embeddings`
(verified against HF: the top-level flag decides tying regardless of the
nested `text_config` value). It is therefore read first, and `text_config`
is consulted only for configs that don't expose a top-level flag.
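
The lookup order for the tying flag can be sketched as below. `resolve_tie_word_embeddings` is a hypothetical helper, and the `default` used when neither level exposes the flag is an assumption of this sketch, not documented behavior.

```python
from types import SimpleNamespace


def resolve_tie_word_embeddings(config, default: bool = False) -> bool:
    """Hypothetical helper: top-level flag first, then text_config."""
    top_level = getattr(config, "tie_word_embeddings", None)
    if top_level is not None:
        return bool(top_level)
    # Fallback for configs that do not expose a top-level flag.
    text_config = getattr(config, "text_config", None)
    nested = getattr(text_config, "tie_word_embeddings", None)
    return default if nested is None else bool(nested)


# The top-level value wins even when the nested value disagrees.
conflicting = SimpleNamespace(
    tie_word_embeddings=False,
    text_config=SimpleNamespace(tie_word_embeddings=True),
)
nested_only = SimpleNamespace(
    text_config=SimpleNamespace(tie_word_embeddings=True)
)
```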

```python
class nemo_automodel.components.models.gemma4_moe.model.Gemma4Gate(
    config: transformers.models.gemma4.configuration_gemma4.Gemma4TextConfig
)
```

**Bases:** `Module`

Gemma4 Router reimplemented to output NeMo Gate format.

HF Gemma4Router applies `RMSNorm(no_scale)` → `root_size` scaling → learnable
scale → Linear → softmax → top-k → renormalize, which differs from the standard Gate class in layer.py.
This class reproduces that logic but returns `(weights, indices, aux_loss)`, as expected by GroupedExperts.
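
The routing pipeline listed above can be sketched in a few lines of PyTorch. This is an illustration under assumptions, not the `Gemma4Gate` implementation: `gemma4_style_route` and its parameter names are hypothetical, the `root_size` factor is assumed to be `hidden_size ** -0.5`, and `aux_loss` is simply returned as `None`.

```python
import torch
import torch.nn.functional as F


def gemma4_style_route(x, router_scale, proj_weight, top_k, eps=1e-6):
    """Sketch of the documented routing steps for x of shape [tokens, hidden]."""
    hidden = x.shape[-1]
    # RMSNorm without a learnable scale.
    inv_rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    normed = x * inv_rms.to(x.dtype)
    # "root_size" scaling: assumed here to be hidden ** -0.5.
    normed = normed * hidden ** -0.5
    # Learnable per-channel scale, then the router projection.
    logits = F.linear(normed * router_scale, proj_weight)
    probs = torch.softmax(logits.float(), dim=-1)
    weights, indices = torch.topk(probs, k=top_k, dim=-1)
    # Renormalize so the selected experts' weights sum to 1 per token.
    weights = weights / weights.sum(dim=-1, keepdim=True)
    aux_loss = None  # no auxiliary loss in this sketch
    return weights.to(x.dtype), indices, aux_loss


torch.manual_seed(0)
x = torch.randn(5, 16)  # 5 tokens, hidden size 16
weights, indices, aux_loss = gemma4_style_route(
    x, router_scale=torch.ones(16), proj_weight=torch.randn(8, 16), top_k=2
)
```

The returned triple has the `(weights, indices, aux_loss)` shape of output that the description above says GroupedExperts expects.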

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4Gate.forward(
    x,
    token_mask = None,
    cp_mesh = None
)
```

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4Gate.init_weights(
    buffer_device: torch.device,
    init_std: float = 0.02
) -> None
```

```python
class nemo_automodel.components.models.gemma4_moe.model.Gemma4MoE(
    moe_config: nemo_automodel.components.moe.layers.MoEConfig,
    backend: nemo_automodel.components.models.common.BackendConfig,
    text_config: transformers.models.gemma4.configuration_gemma4.Gemma4TextConfig
)
```

**Bases:** [MoE](/nemo-automodel/nemo_automodel/components/moe/layers#nemo_automodel-components-moe-layers-MoE)

NeMo MoE that uses Gemma4Gate (with pre-norm routing) instead of
the standard Gate. Subclasses MoE so that `isinstance(m, MoE)` is True,
which the EP parallelizer relies on.

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4MoE.forward(
    x,
    padding_mask = None,
    cp_mesh = None,
    gate_input = None
)
```

Forward with optional separate gate input.

HF Gemma4 passes unnormalized residual to the router and normalized
input to the experts.  The decoder layer calls this with
`gate_input=x` (raw residual) so the gate receives unnormalized
input while experts receive `pre_feedforward_layernorm_2(x)`.
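
A toy illustration of the split between gate input and expert input follows. `RecordingMoE` and `decoder_ffn_step` are hypothetical stand-ins, and the fallback to `x` when `gate_input` is omitted is an assumption of this sketch.

```python
class RecordingMoE:
    """Toy stand-in that records what the gate and the experts receive."""

    def forward(self, x, gate_input=None):
        # Assumed fallback: route from x when no separate gate input is given.
        routed_from = x if gate_input is None else gate_input
        return {"gate_saw": routed_from, "experts_saw": x}


def decoder_ffn_step(moe, residual, norm):
    # Experts get the normalized activations; the gate gets the raw residual.
    return moe.forward(norm(residual), gate_input=residual)


out = decoder_ffn_step(
    RecordingMoE(),
    residual=[2.0, 4.0],
    norm=lambda values: [v / 2 for v in values],  # placeholder for the layer norm
)
```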

```python
class nemo_automodel.components.models.gemma4_moe.model.Gemma4MoEDecoderLayer(
    config: transformers.models.gemma4.configuration_gemma4.Gemma4TextConfig,
    layer_idx: int,
    moe_config: nemo_automodel.components.moe.layers.MoEConfig,
    backend: nemo_automodel.components.models.common.BackendConfig
)
```

**Bases:** `Module`

Gemma4 decoder layer with NeMo MoE backend.

Reuses HF attention and dense MLP, replaces HF Router+MoEBlock with
NeMo Gemma4MoE (Gemma4Gate + GroupedExperts).

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4MoEDecoderLayer.forward(
    x: torch.Tensor,
    position_embeddings: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.LongTensor | None = None,
    padding_mask: torch.Tensor | None = None,
    past_key_values = None,
    use_cache: bool | None = False,
    cache_position: torch.LongTensor | None = None,
    mm_token_type_ids: torch.Tensor | None = None,
    shared_kv_states: dict[str, tuple[torch.Tensor, torch.Tensor]] | None = None,
    kwargs: typing.Any = {}
) -> torch.Tensor
```

```python
class nemo_automodel.components.models.gemma4_moe.model.Gemma4MoEModel()
```

**Bases:** `HFGemma4Model`

Thin wrapper that exposes `language_model` internals as properties
expected by the NeMo training loop.

```python
class nemo_automodel.components.models.gemma4_moe.model.Gemma4MoETextModelBackend(
    config: transformers.models.gemma4.configuration_gemma4.Gemma4TextConfig,
    backend: nemo_automodel.components.models.common.BackendConfig,
    moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None,
    moe_overrides: dict | None = None
)
```

**Bases:** `Module`

Gemma4 text decoder rebuilt with NeMo MoE blocks.

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4MoETextModelBackend.forward(
    input_ids: torch.Tensor | None = None,
    inputs_embeds: torch.Tensor | None = None,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.Tensor | None = None,
    cache_position: torch.Tensor | None = None,
    padding_mask: torch.Tensor | None = None,
    mm_token_type_ids: torch.Tensor | None = None,
    pixel_values: torch.Tensor | None = None,
    past_key_values = None,
    use_cache: bool | None = None,
    cp_enabled: bool = False,
    kwargs: typing.Any = {}
) -> transformers.modeling_outputs.BaseModelOutputWithPast
```

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4MoETextModelBackend.get_input_embeddings() -> torch.nn.Module
```

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4MoETextModelBackend.set_input_embeddings(
    value: torch.nn.Module
) -> None
```

```python
class nemo_automodel.components.models.gemma4_moe.model._FSDPSafeSharedKVStates()
```

**Bases:** `MutableMapping`

A dict-like store for Gemma4 key/value sharing that is safe to pass through FSDP2.

Why a plain `dict` breaks under FSDP2:
With FSDP2 each decoder layer is wrapped as its own unit, and the default
mixed-precision setting (`cast_forward_inputs=True`) makes FSDP2 look at
every argument passed to a layer and cast its float tensors to bf16. It
does this with torch's `_apply_to_tensors`, which, whenever it sees a
`dict` (or `list`/`tuple`/`set`/...), builds a brand-new copy of
it. So if the shared store is a plain `dict`, each layer receives its
own private copy: the earlier layer's writes land in a copy that is thrown
away, and the later layers read from an empty copy -- which raises
`KeyError: 'sliding_attention'`.
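
The failure mode can be reproduced without FSDP2. In the sketch below, `rebuild_containers` is a simplified stand-in for the container-copying walk described above (it is not torch's `_apply_to_tensors`), and `SharedStore` is a hypothetical `MutableMapping` that is not a `dict` subclass. The assumption is that the walk leaves such objects untouched; the real class may work differently.

```python
from collections.abc import MutableMapping


def rebuild_containers(obj):
    """Simplified stand-in for a recursive walk that copies builtin containers."""
    if isinstance(obj, dict):
        return {k: rebuild_containers(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, set)):
        return type(obj)(rebuild_containers(v) for v in obj)
    return obj  # anything else passes through by reference


class SharedStore(MutableMapping):
    """Dict-like, but not a dict subclass, so the walk above leaves it alone."""

    def __init__(self):
        self._data = {}

    def __getitem__(self, key):
        return self._data[key]

    def __setitem__(self, key, value):
        self._data[key] = value

    def __delitem__(self, key):
        del self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)


def run_two_layers(store) -> bool:
    # Each "layer" sees its arguments only after the container walk.
    rebuild_containers(store)["sliding_attention"] = ("k", "v")  # writer layer
    return "sliding_attention" in rebuild_containers(store)  # reader layer
```

With a plain `dict`, the writer's entry lands in a discarded copy and the reader sees nothing. With `SharedStore`, both layers operate on the same object.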

```python
nemo_automodel.components.models.gemma4_moe.model._FSDPSafeSharedKVStates.__delitem__(
    key: str
) -> None
```

```python
nemo_automodel.components.models.gemma4_moe.model._FSDPSafeSharedKVStates.__getitem__(
    key: str
) -> tuple[torch.Tensor, torch.Tensor]
```

```python
nemo_automodel.components.models.gemma4_moe.model._FSDPSafeSharedKVStates.__iter__() -> typing.Iterator[str]
```

```python
nemo_automodel.components.models.gemma4_moe.model._FSDPSafeSharedKVStates.__len__() -> int
```

```python
nemo_automodel.components.models.gemma4_moe.model._FSDPSafeSharedKVStates.__setitem__(
    key: str,
    value: tuple[torch.Tensor, torch.Tensor]
) -> None
```

```python
class nemo_automodel.components.models.gemma4_moe.model._Gemma4KVShareHolder()
```

Cache-free holder that lets HF gemma4 kv-sharing fire under `use_cache=False`.

E2B/E4B share K/V across the trailing `num_kv_shared_layers` layers: each shared
layer reads its source layer's K/V from `past_key_values.shared_layers` (see HF
`Gemma4Attention.forward`). HF gates that read on `past_key_values is not None`,
which is `None` whenever `use_cache=False` -- and `use_cache` is forced off by
activation checkpointing and the model-owned CP path. The shared layers then fall
back to their (frozen, unused) K/V projections and produce garbage, inflating the
loss roughly 4x.

Passing this lightweight object as `past_key_values` satisfies the gate so HF's
own kv-sharing logic runs: source layers populate `shared_layers` and shared
layers read it. `update` is a pass-through (no per-token accumulation, so no cache
memory growth), and `get_seq_length` returns 0 so the causal mask is built with a
zero cache offset (correct for a training forward).
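
A minimal sketch of such a holder is shown below. `KVShareHolderSketch` is illustrative, not the library class: storing `shared_layers` as a plain dict is an assumption, and `get_mask_sizes` is omitted.

```python
class KVShareHolderSketch:
    """Illustrative cache-free holder; not the library class."""

    def __init__(self):
        # Source layers write here; shared layers read from it.
        self.shared_layers = {}

    def update(self, key_states, value_states, layer_idx, *args, **kwargs):
        # Pass-through: nothing is accumulated across calls.
        return key_states, value_states

    def get_seq_length(self, *args, **kwargs) -> int:
        # Zero cache offset, as in a training forward.
        return 0


holder = KVShareHolderSketch()
keys, values = holder.update("k_states", "v_states", layer_idx=0)
```

Because `holder` is not `None`, a gate of the form `past_key_values is not None` passes, while memory use stays flat since `update` stores nothing.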

```python
nemo_automodel.components.models.gemma4_moe.model._Gemma4KVShareHolder.get_mask_sizes(
    query_length: int,
    layer_idx = None
) -> tuple[int, int]
```

```python
nemo_automodel.components.models.gemma4_moe.model._Gemma4KVShareHolder.get_seq_length(
    args = (),
    kwargs = {}
) -> int
```

```python
nemo_automodel.components.models.gemma4_moe.model._Gemma4KVShareHolder.update(
    key_states,
    value_states,
    layer_idx,
    args = (),
    kwargs = {}
)
```

```python
nemo_automodel.components.models.gemma4_moe.model._build_packed_gemma4_causal_mask_mapping(
    packed_seq_ids: torch.Tensor,
    mm_token_type_ids: torch.Tensor,
    dtype: torch.dtype,
    sliding_window: int | None,
    as_additive: bool = False,
    as_block_mask: bool = False,
    flex_block_size: int | tuple[int, int] = 128
) -> dict[str, torch.Tensor]
```

Build Gemma4 full/sliding masks for packed VLM sequences.

`packed_seq_ids` contains 1-based document ids and 0 for padding.
Full-attention layers remain plain packed causal attention. Sliding layers
also include Gemma4's same-image-token bidirectional edges.
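
The packed causal part of these masks can be sketched as follows. This is a simplified illustration: `packed_causal_masks` is hypothetical, the dictionary keys are assumed, the sliding window is assumed to cover the last `sliding_window` positions including the query, and the same-image-token bidirectional edges are left out.

```python
import torch


def packed_causal_masks(packed_seq_ids, sliding_window=None):
    """Sketch: boolean allowed-masks of shape [batch, 1, seq, seq]."""
    ids_q = packed_seq_ids[:, :, None]  # query document ids
    ids_k = packed_seq_ids[:, None, :]  # key document ids
    # Same document, and the position is not padding (id 0).
    same_doc = (ids_q == ids_k) & (ids_q != 0)
    pos = torch.arange(packed_seq_ids.shape[1], device=packed_seq_ids.device)
    causal = pos[None, :] <= pos[:, None]  # key index <= query index
    full = same_doc & causal
    sliding = full
    if sliding_window is not None:
        # Keep only keys within the last `sliding_window` positions.
        sliding = full & ((pos[:, None] - pos[None, :]) < sliding_window)
    return {"full_attention": full[:, None], "sliding_attention": sliding[:, None]}


# Two documents (ids 1 and 2) followed by one padding position.
ids = torch.tensor([[1, 1, 1, 2, 2, 0]])
masks = packed_causal_masks(ids, sliding_window=2)
```

In this example, tokens of document 2 cannot attend to document 1, and the padding row is fully masked.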

```python
nemo_automodel.components.models.gemma4_moe.model._convert_bool_4d_mask_to_additive(
    attention_mask: torch.Tensor,
    dtype: torch.dtype
) -> torch.Tensor
```

Convert a 4D bool allowed-mask to HF additive format (0.0 allowed, -inf masked).
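
A minimal sketch of this conversion, assuming `True` means "allowed" as stated above. `bool_mask_to_additive` is a hypothetical name and may differ from the library function in details such as dtype handling.

```python
import torch


def bool_mask_to_additive(allowed: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Sketch: True (allowed) -> 0.0, False (masked) -> -inf."""
    additive = torch.zeros(allowed.shape, dtype=dtype, device=allowed.device)
    return additive.masked_fill(~allowed, float("-inf"))


allowed = torch.tensor([[[[True, False], [True, True]]]])  # [1, 1, 2, 2]
additive = bool_mask_to_additive(allowed, torch.float32)
```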

```python
nemo_automodel.components.models.gemma4_moe.model._derive_padding_mask(
    attention_mask: torch.Tensor
) -> torch.Tensor
```

Derive 2D padding mask (True = pad) from 1D, 2D, or 4D attention mask.

```python
nemo_automodel.components.models.gemma4_moe.model._kv_sharing_active(
    text_config
) -> bool
```

True if the (dense) text config uses gemma4 kv-sharing (E2B/E4B).
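
One plausible shape for this check is sketched below. It is a guess at the logic, not the library code: `kv_sharing_active_sketch` is hypothetical and assumes that "dense" means `enable_moe_block` is unset or false, and that sharing is signalled by a positive `num_kv_shared_layers`.

```python
from types import SimpleNamespace


def kv_sharing_active_sketch(text_config) -> bool:
    """Hypothetical check: dense config with at least one kv-shared layer."""
    if getattr(text_config, "enable_moe_block", False):
        return False
    # Treat a missing or None value as zero shared layers.
    return (getattr(text_config, "num_kv_shared_layers", 0) or 0) > 0


# Stand-in text configs (illustrative values only).
sharing_cfg = SimpleNamespace(enable_moe_block=False, num_kv_shared_layers=20)
plain_cfg = SimpleNamespace(enable_moe_block=False, num_kv_shared_layers=0)
```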

```python
nemo_automodel.components.models.gemma4_moe.model._make_missing(
    name: str
)
```

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4Attention = getattr(_g4, 'Gemma4TextAttention', None) or _g4.Gemma4Attention
```

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4CausalLMOutputWithPast = _g4.Gemma4CausalLMOutputWithPast
```

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4DecoderLayer = getattr(_g4, 'Gemma4TextDecoderLayer', None) or _g4.Gemma4DecoderLayer
```

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4MLP = getattr(_g4, 'Gemma4TextMLP', None) or _g4.Gemma4MLP
```

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4RMSNorm = _g4.Gemma4RMSNorm
```

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4RotaryEmbedding = getattr(_g4, 'Gemma4TextRotaryEmbedding', None) or _g4.Gemma4RotaryEmbedding
```

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4TextModel = _g4.Gemma4TextModel
```

```python
nemo_automodel.components.models.gemma4_moe.model.Gemma4TextScaledWordEmbedding = _g4.Gemma4TextScaledWordEmbedding
```

```python
nemo_automodel.components.models.gemma4_moe.model.ModelClass = Gemma4ForConditionalGeneration
```

```python
nemo_automodel.components.models.gemma4_moe.model._GEMMA4_HF_AVAILABLE = True
```