nemo_automodel.components.models.gemma4_moe.model#
Gemma4 MoE NeMo Automodel support.
Replaces the HF-native Gemma4 MoE (dense matmul over all experts) with NeMo's GroupedExperts backend, enabling Expert Parallelism (EP) via the standard MoE parallelizer.
Module Contents#
Classes#

| Class | Description |
|---|---|
| Gemma4Gate | Gemma4 Router reimplemented to output NeMo Gate format. |
| Gemma4MoE | NeMo MoE that uses Gemma4Gate (with pre-norm routing) instead of the standard Gate. Subclasses MoE so that isinstance(m, MoE) is True. |
| Gemma4MoEDecoderLayer | Gemma4 decoder layer with NeMo MoE backend. |
| Gemma4MoETextModelBackend | Gemma4 text decoder rebuilt with NeMo MoE blocks. |
| Gemma4MoEModel | Thin wrapper that exposes language_model internals as properties expected by the NeMo training loop. |
| Gemma4ForConditionalGeneration | Gemma4 VL conditional generation model with NeMo MoE backend. |
Functions#
API#
- nemo_automodel.components.models.gemma4_moe.model._make_missing(name: str)#
- class nemo_automodel.components.models.gemma4_moe.model.Gemma4Gate(
- config: transformers.models.gemma4.configuration_gemma4.Gemma4TextConfig,
)#
Bases:
torch.nn.Module

Gemma4 Router reimplemented to output NeMo Gate format.
HF Gemma4Router applies: RMSNorm (no scale) → root_size scaling → learnable scale → Linear → softmax → top-k → renormalize, which differs from the standard Gate class in layer.py. This class reproduces that logic but returns (weights, indices, aux_loss) as expected by GroupedExperts.
Initialization
- forward(x, token_mask=None, cp_mesh=None)#
- init_weights(
- buffer_device: torch.device,
- init_std: float = 0.02,
)#
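The routing chain described above can be sketched in plain PyTorch. This is illustrative only: the function name, the epsilon, and the interpretation of "root_size scaling" as a sqrt-of-hidden-size factor are assumptions, and the real class additionally returns an aux_loss and handles token_mask/cp_mesh.

```python
import torch
import torch.nn.functional as F

def gemma4_route_sketch(x, router_weight, router_scale, top_k):
    """Hypothetical sketch of the routing chain:
    RMSNorm (no scale) -> sqrt(hidden) scaling -> learnable scale
    -> Linear -> softmax -> top-k -> renormalize.
    Names and constants are assumptions, not the library API."""
    hidden = x.shape[-1]
    # RMSNorm without a learnable weight
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
    # sqrt(hidden_size) scaling (assumed reading of "root_size"),
    # then a learnable per-channel scale
    x = x * (hidden ** 0.5) * router_scale
    # linear projection to per-expert logits, softmax over experts
    probs = F.softmax(x @ router_weight.t(), dim=-1)
    # keep the top-k experts and renormalize their weights to sum to 1
    weights, indices = probs.topk(top_k, dim=-1)
    weights = weights / weights.sum(-1, keepdim=True)
    return weights, indices
```

The (weights, indices) pair matches the shape GroupedExperts-style backends consume: one renormalized weight per selected expert per token.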
- class nemo_automodel.components.models.gemma4_moe.model.Gemma4MoE(
- moe_config: nemo_automodel.components.moe.layers.MoEConfig,
- backend: nemo_automodel.components.models.common.BackendConfig,
- text_config: transformers.models.gemma4.configuration_gemma4.Gemma4TextConfig,
)#
Bases:
nemo_automodel.components.moe.layers.MoE

NeMo MoE that uses Gemma4Gate (with pre-norm routing) instead of the standard Gate. Subclasses MoE so that isinstance(m, MoE) is True, which the EP parallelizer relies on.

Initialization
Initializes the MoE module.
- Parameters:
args (MoEArgs) – Model arguments containing MoE parameters.
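The subclassing note above is what makes Expert Parallelism apply automatically: the EP parallelizer selects MoE blocks by an isinstance check, so deriving from MoE (rather than copying it) keeps Gemma4MoE eligible. A self-contained sketch with dummy stand-in classes (not the real nemo_automodel types):

```python
import torch.nn as nn

class MoE(nn.Module):
    """Dummy stand-in for nemo_automodel.components.moe.layers.MoE."""

class Gemma4Gate(nn.Module):
    """Dummy stand-in for the pre-norm router."""

class Gemma4MoE(MoE):
    """Subclasses MoE so isinstance-based dispatch still matches."""
    def __init__(self):
        super().__init__()
        self.gate = Gemma4Gate()

def ep_parallelize_targets(model: nn.Module):
    # Mimics the EP parallelizer's selection: walk the module tree and
    # pick every submodule that is an instance of MoE.
    return [name for name, m in model.named_modules() if isinstance(m, MoE)]
```

Had Gemma4MoE re-implemented the block from scratch on nn.Module, the isinstance check would miss it and the block would never be expert-parallelized.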
- class nemo_automodel.components.models.gemma4_moe.model.Gemma4MoEDecoderLayer(
- config: transformers.models.gemma4.configuration_gemma4.Gemma4TextConfig,
- layer_idx: int,
- moe_config: nemo_automodel.components.moe.layers.MoEConfig,
- backend: nemo_automodel.components.models.common.BackendConfig,
)#
Bases:
torch.nn.Module

Gemma4 decoder layer with NeMo MoE backend.
Reuses the HF attention and dense MLP, and replaces the HF Router + MoEBlock with the NeMo Gemma4MoE (Gemma4Gate + GroupedExperts).
Initialization
- forward(
- x: torch.Tensor,
- *,
- position_embeddings: torch.Tensor,
- attention_mask: torch.Tensor | None = None,
- position_ids: torch.LongTensor | None = None,
- padding_mask: torch.Tensor | None = None,
- past_key_values=None,
- use_cache: bool | None = False,
- cache_position: torch.LongTensor | None = None,
- **kwargs: Any,
)#
- class nemo_automodel.components.models.gemma4_moe.model.Gemma4MoETextModelBackend(
- config: transformers.models.gemma4.configuration_gemma4.Gemma4TextConfig,
- backend: nemo_automodel.components.models.common.BackendConfig,
- *,
- moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None,
- moe_overrides: dict | None = None,
)#
Bases:
torch.nn.Module

Gemma4 text decoder rebuilt with NeMo MoE blocks.
Initialization
- forward(
- input_ids: torch.Tensor | None = None,
- *,
- inputs_embeds: torch.Tensor | None = None,
- attention_mask: torch.Tensor | None = None,
- position_ids: torch.Tensor | None = None,
- cache_position: torch.Tensor | None = None,
- padding_mask: torch.Tensor | None = None,
- past_key_values=None,
- use_cache: bool | None = None,
- **kwargs: Any,
)#
- get_input_embeddings() torch.nn.Module#
- set_input_embeddings(value: torch.nn.Module) None#
- class nemo_automodel.components.models.gemma4_moe.model.Gemma4MoEModel#
Bases:
transformers.models.gemma4.modeling_gemma4.Gemma4Model

Thin wrapper that exposes language_model internals as properties expected by the NeMo training loop.

- property layers#
- property embed_tokens#
- property norm#
- class nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration(
- config: transformers.models.gemma4.configuration_gemma4.Gemma4Config,
- moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None,
- backend: nemo_automodel.components.models.common.BackendConfig | None = None,
- text_config: dict | None = None,
- **kwargs,
)#
Bases:
nemo_automodel.components.models.common.hf_checkpointing_mixin.HFCheckpointingMixin, transformers.models.gemma4.modeling_gemma4.Gemma4ForConditionalGeneration, nemo_automodel.components.moe.fsdp_mixin.MoEFSDPSyncMixin

Gemma4 VL conditional generation model with NeMo MoE backend.

When the checkpoint has enable_moe_block=True in its text config, replaces the HF-native language model with Gemma4MoETextModelBackend (NeMo GroupedExperts + Gemma4Gate). Otherwise falls through to the vanilla HF implementation.

Initialization
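The backend selection described above amounts to an opt-in flag check on the checkpoint's text config. A minimal sketch (the key name comes from the docstring; the helper name is hypothetical, and the real code reads an HF config object rather than a plain dict):

```python
def use_nemo_moe_backend(text_config: dict) -> bool:
    # Swap in the NeMo GroupedExperts backend only when the checkpoint's
    # text config sets enable_moe_block=True; otherwise fall through to
    # the vanilla HF language model.
    return bool(text_config.get("enable_moe_block", False))
```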
- classmethod from_config(
- config: transformers.models.gemma4.configuration_gemma4.Gemma4Config,
- moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None,
- backend: nemo_automodel.components.models.common.BackendConfig | None = None,
- **kwargs,
)#
- classmethod from_pretrained(
- pretrained_model_name_or_path: str,
- *model_args,
- **kwargs,
)#
- forward(
- input_ids: torch.Tensor | None = None,
- *,
- position_ids: torch.Tensor | None = None,
- attention_mask: torch.Tensor | None = None,
- padding_mask: torch.Tensor | None = None,
- inputs_embeds: torch.Tensor | None = None,
- cache_position: torch.Tensor | None = None,
- pixel_values: torch.Tensor | None = None,
- image_position_ids: torch.Tensor | None = None,
- mm_token_type_ids: torch.Tensor | None = None,
- **kwargs: Any,
)#
- initialize_weights(
- buffer_device: torch.device | None = None,
- dtype: torch.dtype = torch.bfloat16,
)#