nemo_automodel.components.models.gemma4_moe.model#

Gemma4 MoE NeMo Automodel support.

Replaces the HF-native Gemma4 MoE (dense matmul over all experts) with NeMo's GroupedExperts backend, enabling Expert Parallelism (EP) via the standard MoE parallelizer.
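To make the swap concrete, the toy sketch below contrasts the two dispatch styles: a dense matmul over all experts versus a grouped, top-k dispatch that only runs the selected experts. All shapes and names are illustrative assumptions, not the actual HF or NeMo implementations.

```python
import torch

tokens, hidden, n_experts, top_k = 8, 16, 4, 2
x = torch.randn(tokens, hidden)
w_up = torch.randn(n_experts, hidden, hidden)      # one weight matrix per expert

# Router: softmax over experts, then keep the top-k per token.
router_logits = torch.randn(tokens, n_experts)
weights = torch.softmax(router_logits, dim=-1)
topk_w, topk_idx = weights.topk(top_k, dim=-1)

# Dense style (HF-native): every token multiplies every expert's weights,
# and the router weights zero out the non-selected experts afterwards.
dense_out = torch.einsum("th,ehd->ted", x, w_up)   # (tokens, experts, hidden)
mask = torch.zeros_like(weights).scatter_(1, topk_idx, topk_w)
dense_result = torch.einsum("ted,te->td", dense_out, mask)

# Grouped-experts style: tokens are gathered per selected expert and only those
# experts' matmuls run, which is what makes Expert Parallelism practical.
grouped_result = torch.zeros_like(x)
for e in range(n_experts):
    sel_tok, sel_slot = (topk_idx == e).nonzero(as_tuple=True)
    if sel_tok.numel():
        grouped_result[sel_tok] += topk_w[sel_tok, sel_slot, None] * (x[sel_tok] @ w_up[e])

assert torch.allclose(dense_result, grouped_result, atol=1e-4)
```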

Module Contents#

Classes#

Gemma4Gate

Gemma4 Router reimplemented to output NeMo Gate format.

Gemma4MoE

NeMo MoE that uses Gemma4Gate (with pre-norm routing) instead of the standard Gate. Subclasses MoE so that isinstance(m, MoE) is True, which the EP parallelizer relies on.

Gemma4MoEDecoderLayer

Gemma4 decoder layer with NeMo MoE backend.

Gemma4MoETextModelBackend

Gemma4 text decoder rebuilt with NeMo MoE blocks.

Gemma4MoEModel

Thin wrapper that exposes language_model internals as properties expected by the NeMo training loop.

Gemma4ForConditionalGeneration

Gemma4 VL conditional generation model with NeMo MoE backend.

Functions#

API#

nemo_automodel.components.models.gemma4_moe.model._make_missing(name: str)#
class nemo_automodel.components.models.gemma4_moe.model.Gemma4Gate(
config: transformers.models.gemma4.configuration_gemma4.Gemma4TextConfig,
)#

Bases: torch.nn.Module

Gemma4 Router reimplemented to output NeMo Gate format.

HF Gemma4Router applies: RMSNorm(no_scale) → root_size scaling → learnable scale → Linear → softmax → top-k → renormalize, which differs from the standard Gate class in layer.py. This class reproduces that logic but returns (weights, indices, aux_loss) as expected by GroupedExperts.
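The sketch below illustrates that routing contract, assuming the "root_size scaling" step is a multiplication by sqrt(hidden_size) and returning the auxiliary loss as None for brevity; the tensor names and sizes are placeholders, not the real Gemma4Gate parameters.

```python
import torch
import torch.nn.functional as F

hidden_size, num_experts, top_k = 64, 8, 2
router_weight = torch.randn(num_experts, hidden_size)      # stands in for the Linear
route_scale = torch.ones(hidden_size)                      # stands in for the learnable scale

def route(x: torch.Tensor):
    # RMSNorm without a learnable scale
    h = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
    h = h * (hidden_size ** 0.5)                            # assumed "root_size" scaling
    h = h * route_scale                                     # learnable scale
    logits = F.linear(h, router_weight)                     # Linear
    probs = logits.softmax(dim=-1)                          # softmax
    weights, indices = probs.topk(top_k, dim=-1)            # top-k
    weights = weights / weights.sum(dim=-1, keepdim=True)   # renormalize
    return weights, indices, None                           # (weights, indices, aux_loss)

w, idx, aux = route(torch.randn(4, hidden_size))
```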

Initialization

forward(x, token_mask=None, cp_mesh=None)#
init_weights(
buffer_device: torch.device,
init_std: float = 0.02,
) → None#
class nemo_automodel.components.models.gemma4_moe.model.Gemma4MoE(
moe_config: nemo_automodel.components.moe.layers.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
text_config: transformers.models.gemma4.configuration_gemma4.Gemma4TextConfig,
)#

Bases: nemo_automodel.components.moe.layers.MoE

NeMo MoE that uses Gemma4Gate (with pre-norm routing) instead of the standard Gate. Subclasses MoE so that isinstance(m, MoE) is True, which the EP parallelizer relies on.
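The sketch below shows the kind of isinstance-based walk an Expert-Parallel parallelizer can perform over the module tree, which is why Gemma4MoE must subclass MoE; the traversal and the shard step are illustrative, not the actual NeMo parallelizer.

```python
import torch.nn as nn

from nemo_automodel.components.moe.layers import MoE  # base class named above

def apply_expert_parallelism(model: nn.Module, ep_mesh) -> None:
    # Illustrative dispatch only: find every MoE block and hand it to EP sharding.
    for name, module in model.named_modules():
        if isinstance(module, MoE):  # Gemma4MoE passes this check by subclassing MoE
            print(f"would shard the experts of {name} across EP mesh {ep_mesh}")
```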

Initialization

Initializes the MoE module.

Parameters:

moe_config (MoEConfig) – MoE configuration.

backend (BackendConfig) – Backend configuration.

text_config (Gemma4TextConfig) – Gemma4 text configuration.

class nemo_automodel.components.models.gemma4_moe.model.Gemma4MoEDecoderLayer(
config: transformers.models.gemma4.configuration_gemma4.Gemma4TextConfig,
layer_idx: int,
moe_config: nemo_automodel.components.moe.layers.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
)#

Bases: torch.nn.Module

Gemma4 decoder layer with NeMo MoE backend.

Reuses the HF attention and dense MLP, and replaces the HF Router+MoEBlock with NeMo's Gemma4MoE (Gemma4Gate + GroupedExperts).

Initialization

forward(
x: torch.Tensor,
*,
position_embeddings: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
padding_mask: torch.Tensor | None = None,
past_key_values=None,
use_cache: bool | None = False,
cache_position: torch.LongTensor | None = None,
**kwargs: Any,
) → torch.Tensor#
class nemo_automodel.components.models.gemma4_moe.model.Gemma4MoETextModelBackend(
config: transformers.models.gemma4.configuration_gemma4.Gemma4TextConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
*,
moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None,
moe_overrides: dict | None = None,
)#

Bases: torch.nn.Module

Gemma4 text decoder rebuilt with NeMo MoE blocks.

Initialization

forward(
input_ids: torch.Tensor | None = None,
*,
inputs_embeds: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
cache_position: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
past_key_values=None,
use_cache: bool | None = None,
**kwargs: Any,
) → transformers.modeling_outputs.BaseModelOutputWithPast#
get_input_embeddings() → torch.nn.Module#
set_input_embeddings(value: torch.nn.Module) → None#
class nemo_automodel.components.models.gemma4_moe.model.Gemma4MoEModel#

Bases: transformers.models.gemma4.modeling_gemma4.Gemma4Model

Thin wrapper that exposes language_model internals as properties expected by the NeMo training loop.

property layers#
property embed_tokens#
property norm#
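A minimal sketch of the thin-wrapper pattern these properties suggest, assuming each property simply forwards to the wrapped language_model; the attribute names mirror the listing above, but the forwarding itself is an assumption.

```python
class _LanguageModelWrapperSketch:
    """Illustrative only: exposes the attributes the NeMo training loop expects."""

    def __init__(self, language_model):
        self.language_model = language_model

    @property
    def layers(self):
        return self.language_model.layers

    @property
    def embed_tokens(self):
        return self.language_model.embed_tokens

    @property
    def norm(self):
        return self.language_model.norm
```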
class nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration(
config: transformers.models.gemma4.configuration_gemma4.Gemma4Config,
moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
text_config: dict | None = None,
**kwargs,
)#

Bases: nemo_automodel.components.models.common.hf_checkpointing_mixin.HFCheckpointingMixin, transformers.models.gemma4.modeling_gemma4.Gemma4ForConditionalGeneration, nemo_automodel.components.moe.fsdp_mixin.MoEFSDPSyncMixin

Gemma4 VL conditional generation model with NeMo MoE backend.

When the checkpoint's text config sets enable_moe_block=True, the HF-native language model is replaced with Gemma4MoETextModelBackend (NeMo GroupedExperts + Gemma4Gate); otherwise the vanilla HF implementation is used.
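A hedged usage sketch follows: the checkpoint identifier is a placeholder, and only the flag-driven behaviour described above is illustrated, not the internal attribute layout of the resulting model.

```python
from transformers import AutoConfig

from nemo_automodel.components.models.gemma4_moe.model import Gemma4ForConditionalGeneration

config = AutoConfig.from_pretrained("org/gemma4-moe-checkpoint")  # placeholder checkpoint id
model = Gemma4ForConditionalGeneration.from_config(config)
# If config.text_config.enable_moe_block is True, the language model inside is a
# Gemma4MoETextModelBackend (NeMo GroupedExperts + Gemma4Gate); otherwise the
# vanilla HF Gemma4 language model is kept.
```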

Initialization

classmethod from_config(
config: transformers.models.gemma4.configuration_gemma4.Gemma4Config,
moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
**kwargs,
)#
classmethod from_pretrained(
pretrained_model_name_or_path: str,
*model_args,
**kwargs,
)#
forward(
input_ids: torch.Tensor | None = None,
*,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
cache_position: torch.Tensor | None = None,
pixel_values: torch.Tensor | None = None,
image_position_ids: torch.Tensor | None = None,
mm_token_type_ids: torch.Tensor | None = None,
**kwargs: Any,
)#
initialize_weights(
buffer_device: torch.device | None = None,
dtype: torch.dtype = torch.bfloat16,
) → None#