bridge.models.gemma.gemma4_provider

Gemma 4 Model Provider for Megatron-Core.

Gemma 4 is a Mixture-of-Experts (MoE) model with hybrid sliding/global attention. Key differences from Gemma 3:

  • MoE: 128 experts, top-k=8, plus a dense MLP path (mapped to shared experts)

  • Heterogeneous attention: sliding layers use head_dim=256 / 8 KV heads, global layers use global_head_dim=512 / 2 KV heads with partial rotary (0.25)

  • K=V sharing on global attention layers (V projection may be omitted)

  • Per-layer scaling via layer_scalar buffer

  • Dual pre/post layernorms for dense MLP vs MoE paths

Module Contents

Classes

  • Gemma4ModelProvider: Configuration and provider for Megatron Core Gemma 4 models.

  • Gemma4TransformerLayer: Gemma 4 transformer layer with per-layer output scaling and extra post-norms.

  • Gemma4TopKRouter: Gemma 4 MoE router with per-expert scaling.

  • Gemma4MoELayer: Gemma 4 MoE layer with post-routed-expert and post-shared-expert normalization.

  • Gemma4OutputLayer: Mixin that applies final_logit_softcapping after the output linear layer.

  • Gemma4SelfAttention: Gemma 4 self attention with heterogeneous sliding/global layers.

  • Gemma4TEDotProductAttention: Gemma 4 core attention.

  • Gemma4RotaryEmbedding: Gemma 4 position RoPE embedding.

Functions

  • _logit_softcapping: Prevents logits from growing excessively (scale * tanh(logits / scale)).

  • _install_tied_kv: Mark global attention layers that require K=V weight tying.

  • _gemma4_block_spec: Build Gemma 4 block spec: MoE or dense layer specs with patched attention.

Data

API

bridge.models.gemma.gemma4_provider.HAVE_TE

None

class bridge.models.gemma.gemma4_provider.Gemma4ModelProvider

Bases: megatron.bridge.models.gpt_provider.GPTModelProvider

Configuration and provider for Megatron Core Gemma 4 models.

Gemma 4 is a MoE model with hybrid sliding/global attention. The dense MLP path is mapped to Megatron-Core’s shared expert mechanism.

seq_length: int = 262144

position_embedding_type: str = 'rope'

rotary_base: tuple = (10000, 1000000)

share_embeddings_and_output_weights: bool = True

normalization: str = 'RMSNorm'

layernorm_zero_centered_gamma: bool = False

layernorm_epsilon: float = 1e-06

kv_channels: int = 256

num_query_groups: int = 8

window_size: int = 1024

interleaved_attn_pattern: tuple = (5, 1)

attention_dropout: float = 0.0

hidden_dropout: float = 0.0

attention_backend: megatron.core.transformer.enums.AttnBackend = None

softmax_scale: float = 1.0

qk_layernorm: bool = True

attention_k_eq_v: bool = False

global_head_dim: int = 512

num_global_key_value_heads: int = 2

global_rotary_percent: float = 0.25

gated_linear_unit: bool = True

add_bias_linear: bool = False

activation_func: Callable = None

num_moe_experts: Optional[int] = 128

moe_router_topk: int = 8

moe_ffn_hidden_size: int = 704

moe_shared_expert_intermediate_size: int = 2112

moe_shared_expert_overlap: bool = False

moe_shared_expert_gate: bool = False

moe_grouped_gemm: bool = True

moe_token_dispatcher_type: str = 'alltoall'

moe_router_load_balancing_type: str = 'aux_loss'

moe_router_pre_softmax: bool = True

moe_router_dtype: str = 'fp32'

moe_aux_loss_coeff: float = 0.001

moe_permute_fusion: bool = True

moe_layer_freq: int = 1

final_logit_softcapping: float = 30.0

flash_decode: bool = False

transformer_layer_spec: Union[Callable, object] = field(…)

scatter_embedding_sequence_parallel: bool = True

bf16: bool = True

fp16: bool = False

params_dtype: torch.dtype = None

autocast_dtype: torch.dtype = None

provide(
pre_process=None,
post_process=None,
vp_stage=None,
) → megatron.core.models.gpt.GPTModel

Configure and instantiate a Megatron Core Gemma 4 model.

Replaces the model’s embedding and RoPE with customized Gemma 4 variants that handle embedding scaling and dual local/global RoPE.

class bridge.models.gemma.gemma4_provider.Gemma4TransformerLayer(config, submodules, layer_number=1, **kwargs)

Bases: megatron.core.transformer.transformer_layer.TransformerLayer

Gemma 4 transformer layer with per-layer output scaling and extra post-norms.

Gemma 4 has architectural features not present in standard MCore:

  • layer_scalar: per-layer scaling applied to the full hidden state after residual add.

  • post_ffn_layernorm: norm applied to the combined dense+MoE output before residual add (HF’s post_feedforward_layernorm).

  • post_moe_layernorm: norm applied to routed expert output before combining with dense (HF’s post_feedforward_layernorm_2). Applied via a forward hook on the MoE layer.

Initialization

_forward_post_mlp(mlp_output_with_bias, residual)

Override to apply post_ffn_layernorm before residual add, then layer_scalar.
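The order of operations in this override (norm, then residual add, then layer_scalar) can be sketched on plain Python lists standing in for tensors. The helper names, the RMSNorm epsilon, and treating layer_scalar as a single float are assumptions for illustration, not the actual implementation.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """RMSNorm over one hidden vector: x / rms(x) * weight."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]


def post_mlp_sketch(mlp_output, residual, norm_weight, layer_scalar):
    """Hypothetical helper mirroring the described post-MLP ordering."""
    # 1. post_ffn_layernorm on the combined dense + MoE output
    normed = rms_norm(mlp_output, norm_weight)
    # 2. residual add
    hidden = [r + n for r, n in zip(residual, normed)]
    # 3. per-layer scaling of the full hidden state
    return [h * layer_scalar for h in hidden]


out = post_mlp_sketch([3.0, -4.0], [1.0, 1.0], [1.0, 1.0], layer_scalar=0.5)
```

Because layer_scalar is applied after the residual add, it rescales the residual stream itself, not only the MLP branch.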

class bridge.models.gemma.gemma4_provider.Gemma4TopKRouter(config, **kwargs)

Bases: megatron.core.transformer.moe.router.TopKRouter

Gemma 4 MoE router with per-expert scaling.

Applies per_expert_scale to the routing probabilities after standard routing. It also renormalizes the top-k weights before scaling, matching HF behavior.

The router’s input preprocessing (parameter-free RMSNorm + scale * scalar_root_size) is fused into the router weight at load time in the bridge.

Initialization

routing(logits, padding_mask=None, input_ids=None)

Apply standard routing, then renormalize and scale by per_expert_scale.
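For a single token, the routing order described above can be sketched as follows: softmax over all router logits (the provider default moe_router_pre_softmax=True means softmax is taken before top-k), top-k selection, renormalization of the selected weights to sum to 1, then multiplication by each expert's scale. The function name and list-based inputs are hypothetical; the real router works on batched tensors and also handles load balancing and padding.

```python
import math


def gemma4_style_routing(logits, per_expert_scale, top_k):
    """Hypothetical single-token router returning {expert_index: weight}."""
    # Softmax over all experts, shifted by the max for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the top-k experts by probability.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    # Renormalize the selected probabilities so they sum to 1.
    top_sum = sum(probs[i] for i in top)
    # Multiply each renormalized weight by its expert's learned scale.
    return {i: probs[i] / top_sum * per_expert_scale[i] for i in top}


weights = gemma4_style_routing([2.0, 1.0, 0.5, -1.0], [1.0, 2.0, 1.0, 1.0], top_k=2)
# Experts 0 and 1 are selected; expert 1's weight is doubled by its scale.
```

Note that after per-expert scaling the weights no longer sum to 1, which is why renormalization has to happen before the scale is applied.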

class bridge.models.gemma.gemma4_provider.Gemma4MoELayer(config, submodules, **kwargs)

Bases: megatron.core.transformer.moe.moe_layer.MoELayer

Gemma 4 MoE layer with post-routed-expert and post-shared-expert normalization.

Applies post_feedforward_layernorm_2 (pffl_ln2) to routed expert output and post_feedforward_layernorm_1 (pffl_ln1) to shared expert output before combining. Standard MCore MoELayer simply sums routed + shared outputs without any intermediate norms.

Initialization

postprocess(output, shared_expert_output)

Apply post-MoE norms to routed and shared expert outputs, then combine.
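Assuming both post-MoE norms are RMSNorms with learned weights, the combine step amounts to normalizing each branch separately before summing. The sketch below uses plain lists and hypothetical function names; it is meant to show why the result differs from MCore's plain routed + shared sum, not to reproduce the actual tensor code.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """RMSNorm over one hidden vector: x / rms(x) * weight."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]


def combine_moe_outputs(routed, shared, ln2_weight, ln1_weight):
    """Hypothetical: pffl_ln2 on routed output, pffl_ln1 on shared output, then sum."""
    routed_normed = rms_norm(routed, ln2_weight)
    shared_normed = rms_norm(shared, ln1_weight)
    return [r + s for r, s in zip(routed_normed, shared_normed)]


# A plain sum would give [100.5, 201.0], dominated by the routed branch.
combined = combine_moe_outputs([100.0, 200.0], [0.5, 1.0], [1.0, 1.0], [1.0, 1.0])
```

With the per-branch norms, both branches contribute on the same scale regardless of their raw magnitudes.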

bridge.models.gemma.gemma4_provider._logit_softcapping(
logits: torch.Tensor,
scale: float | None,
) → torch.Tensor

Prevents logits from growing excessively: scale * tanh(logits / scale).
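A scalar sketch of the formula, using the provider default final_logit_softcapping = 30.0. Returning the logit unchanged when scale is None is an assumption inferred from the float | None type, not something the source states.

```python
import math
from typing import Optional


def logit_softcap(logit: float, scale: Optional[float]) -> float:
    """scale * tanh(logit / scale); passes the logit through if scale is None (assumed)."""
    if scale is None:
        return logit
    return scale * math.tanh(logit / scale)


capped = [logit_softcap(x, 30.0) for x in (0.0, 1.0, 100.0, -1000.0)]
```

Small logits pass through almost unchanged (tanh is close to the identity near 0), while large ones saturate toward ±30.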

class bridge.models.gemma.gemma4_provider.Gemma4OutputLayer

Bases: torch.nn.Module

Mixin that applies final_logit_softcapping after the output linear layer.

forward(*args, **kwargs)

bridge.models.gemma.gemma4_provider._install_tied_kv(
model: torch.nn.Module,
provider: bridge.models.gemma.gemma4_provider.Gemma4ModelProvider,
) → None

Mark global attention layers that require K=V weight tying.

In Gemma 4, global attention layers share their K and V projections (v_proj is absent from the HF checkpoint). At import time, the bridge copies the K rows into the V rows of linear_qkv.weight. This function marks each global Gemma4SelfAttention module with _tied_kv = True so that Gemma4SelfAttention.get_query_key_value_tensors can enforce V=K in the forward pass.

Whether K/V sharing is applied is controlled by the attention_k_eq_v field. This function must be called after model construction, so that the attention modules already exist.

Note on gradient routing for LoRA: since V-rows = K-rows in the loaded checkpoint, the forward pass is numerically correct without any further modification. Full gradient routing (accumulating dL/dV into K-rows) is left as a future improvement.
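The import-time row copy can be pictured with a small list-of-rows model. This assumes that Megatron-Core's fused linear_qkv weight is laid out per query group, with the group's query heads first, then one key head, then one value head, each head_dim rows tall. The helper name and toy sizes are hypothetical.

```python
def tie_kv_rows(qkv_rows, num_query_groups, heads_per_group, head_dim):
    """Return a copy of a fused QKV weight (list of rows) with V rows set equal to K rows.

    Assumed per-group row layout: [query heads * head_dim | K: head_dim | V: head_dim].
    """
    group_rows = (heads_per_group + 2) * head_dim
    assert len(qkv_rows) == num_query_groups * group_rows
    tied = [list(row) for row in qkv_rows]
    for g in range(num_query_groups):
        k_start = g * group_rows + heads_per_group * head_dim
        v_start = k_start + head_dim
        for i in range(head_dim):
            # Overwrite each V row with the matching K row.
            tied[v_start + i] = list(tied[k_start + i])
    return tied


# Toy weight: 2 groups, 2 query heads per group, head_dim=2, hidden size 3.
rows = [[float(r), float(r) + 0.5, -float(r)] for r in range(2 * (2 + 2) * 2)]
tied = tie_kv_rows(rows, num_query_groups=2, heads_per_group=2, head_dim=2)
```

Since only the loaded values are copied, K and V remain separate parameters afterwards, which is why the note above describes full gradient routing into the K rows as future work.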

bridge.models.gemma.gemma4_provider._gemma4_block_spec(config, use_transformer_engine=True, **kwargs)

Build Gemma 4 block spec: MoE or dense layer specs with patched attention.

Uses get_gpt_decoder_block_spec to build standard specs, then patches each layer spec:

  • Attention module → Gemma4SelfAttention (heterogeneous head dims)

  • Core attention → Gemma4TEDotProductAttention (sliding/global window)

  • linear_proj → TERowParallelLinearLayerNorm (post-attention RMSNorm)

  • MoE models only: MoE layer → Gemma4MoELayer, router → Gemma4TopKRouter

class bridge.models.gemma.gemma4_provider.Gemma4SelfAttention(
config: megatron.core.transformer.TransformerConfig,
layer_number: int,
**kwargs,
)

Bases: megatron.core.transformer.attention.SelfAttention

Gemma 4 self attention with heterogeneous sliding/global layers.

  • Sliding layers: head_dim=256, num_kv_heads=8, full rotary, local window

  • Global layers: head_dim=512, num_kv_heads=2, partial rotary (0.25), full attention

  • Value normalization: parameter-free RMSNorm applied to V after projection
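The value normalization is v / rms(v) applied per head. A minimal sketch on one head vector follows; the epsilon (borrowed from the provider's layernorm_epsilon default of 1e-6) is an assumption, and the function name is hypothetical.

```python
import math


def value_rms_norm(v, eps=1e-6):
    """Parameter-free RMSNorm over one head's value vector: v / rms(v)."""
    rms = math.sqrt(sum(x * x for x in v) / len(v) + eps)
    return [x / rms for x in v]


a = value_rms_norm([2.0, -2.0, 2.0, -2.0])
b = value_rms_norm([4.0, -4.0, 4.0, -4.0])
```

Having no learnable scale means the output depends only on the direction of v: a and b are (up to epsilon) identical.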

The config is deep-copied and overridden per-layer so that the QKV linear is constructed with the correct dimensions.

Initialization

sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)

Override to separate sliding and global layers in the checkpoint.

Sliding layers (head_dim=256) and global layers (head_dim=512) produce linear_qkv, linear_proj, q_layernorm, k_layernorm tensors with different shapes. dist_checkpointing validates two things per key group:

  1. Uniform global_shape — fails because sliding/global shapes differ.

  2. Full coverage of the global tensor — fails if only a subset of layers fill the group (e.g. 25 sliding layers can’t cover a 30-slot group).

Fix: append a '_sliding' or '_global' suffix to create per-type groups, and remap the prepended layer axis in each ShardedTensor so that global_shape[0], global_offset[0], and axis_fragmentations[0] reflect per-type layer counts rather than the total layer count.

Example:

'decoder.layers.0.self_attention.' → 'decoder.layers.0.self_attention_sliding.' (or '_global')

Loading works automatically because the same class produces the same suffixed keys on load.
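The per-type bookkeeping can be sketched for the 30-layer case above (25 sliding, 5 global). It assumes that every sixth layer (1-indexed) is global, consistent with interleaved_attn_pattern = (5, 1); the helper name is hypothetical, and the real method also rewrites axis_fragmentations and the ShardedTensor objects themselves.

```python
def remap_layer_axis(layer_types):
    """Hypothetical: for each 0-indexed layer, return (suffix, per-type offset, per-type count).

    The per-type offset and count play the roles of global_offset[0] and
    global_shape[0] after remapping.
    """
    totals = {}
    offsets = []
    for kind in layer_types:
        offsets.append(totals.get(kind, 0))
        totals[kind] = totals.get(kind, 0) + 1
    return [(f"_{kind}", off, totals[kind]) for kind, off in zip(layer_types, offsets)]


# Assumed layout: 30 layers, every sixth layer (1-indexed) is global.
layer_types = ["global" if (i + 1) % 6 == 0 else "sliding" for i in range(30)]
remapped = remap_layer_axis(layer_types)
# remapped[0] == ('_sliding', 0, 25); remapped[5] == ('_global', 0, 5); remapped[29] == ('_global', 4, 5)
```

Each group now has a uniform shape and is fully covered by its own layers, which satisfies both checks listed above.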

get_query_key_value_tensors(
hidden_states,
key_value_states=None,
**kwargs,
)

Override to apply parameter-free RMSNorm to V after QKV split.

HF Gemma4 applies v_norm = Gemma4RMSNorm(head_dim, with_scale=False) to the value states. This is a parameter-free normalization: v / rms(v).

For global attention layers (self._tied_kv = True), K=V tying is enforced here after super() has completed the all-gather for KV-replicated TP layouts. This ensures V=K throughout training for all tensor-parallel configs.

forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
rotary_pos_emb: Optional[torch.Tensor] = None,
rotary_pos_cos: Optional[torch.Tensor] = None,
rotary_pos_sin: Optional[torch.Tensor] = None,
rotary_pos_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_bias: Optional[torch.Tensor] = None,
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
sequence_len_offset: Optional[int] = None,
*,
inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
) → Tuple[torch.Tensor, torch.Tensor]

Switch to either local or global RoPE embedding before forward.

class bridge.models.gemma.gemma4_provider.Gemma4TEDotProductAttention(
config: megatron.core.transformer.TransformerConfig,
layer_number: int,
attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
attention_type: str,
attention_dropout: Optional[float] = None,
**kwargs,
)

Bases: bridge.models.gemma.gemma4_provider.TEDotProductAttention

Gemma 4 core attention.

Switches between global and local sliding window attention based on the layer_number and pre-defined layer pattern.
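The per-layer decision can be sketched as follows. This assumes a 1-indexed layer_number (as in Megatron-Core) and that interleaved_attn_pattern = (5, 1) means five sliding layers followed by one global layer. Returning None for global layers is a hypothetical stand-in for "no window"; how the window is actually passed to Transformer Engine is not shown here.

```python
def attention_kind(layer_number, pattern=(5, 1)):
    """Return 'sliding' or 'global' for a 1-indexed layer number."""
    return "global" if layer_number % sum(pattern) == 0 else "sliding"


def sliding_window_for(layer_number, window_size=1024, pattern=(5, 1)):
    """Hypothetical: window length for sliding layers, None (full attention) for global layers."""
    return window_size if attention_kind(layer_number, pattern) == "sliding" else None


global_layers = [n for n in range(1, 31) if attention_kind(n) == "global"]
# → [6, 12, 18, 24, 30]
```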

Initialization

class bridge.models.gemma.gemma4_provider.Gemma4RotaryEmbedding(
rotary_base: int = 1000000,
rotary_base_local: int = 10000,
global_kv_channels: int = 512,
global_rotary_percent: float = 0.25,
**kwargs,
)

Bases: megatron.core.models.common.embeddings.rotary_pos_embedding.RotaryEmbedding

Gemma 4 position RoPE embedding.

Computes RoPE embeddings for both local (sliding) and global (full) attention layers. Local layers use full rotary with theta=10000. Global layers use proportional partial rotary (0.25) with theta=1000000.

HF’s proportional RoPE formula differs from standard partial rotary:

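  • Standard: inv_freq = 1 / base^(arange(0, dim, 2) / dim), where dim = head_dim * percent

  • Proportional: inv_freq = 1 / base^(arange(0, dim, 2) / head_dim), i.e. the denominator is the full head_dim

Because each exponent is divided by head_dim rather than dim, the frequencies decay more slowly: the rotated dimensions use the same frequencies as the first dim / 2 frequencies of a full-rotary RoPE over head_dim.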
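The two formulas can be compared directly for the global-layer settings (head_dim = 512, percent = 0.25, base = 1,000,000). This is a list-based sketch of the formulas as written above, with hypothetical function names, not the class's tensor implementation.

```python
def inv_freq_standard(base, head_dim, percent):
    """Standard partial rotary: exponents divided by the rotary dim."""
    dim = int(head_dim * percent)
    return [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]


def inv_freq_proportional(base, head_dim, percent):
    """Proportional rotary: same exponents, divided by the full head_dim."""
    dim = int(head_dim * percent)
    return [1.0 / (base ** (i / head_dim)) for i in range(0, dim, 2)]


std = inv_freq_standard(1_000_000, 512, 0.25)
prop = inv_freq_proportional(1_000_000, 512, 0.25)
full = inv_freq_standard(1_000_000, 512, 1.0)
```

Both variants yield 64 frequencies, but the lowest standard frequency is about 1e-6 while the lowest proportional one stays above 0.03. The proportional list is exactly the first 64 entries of the full-rotary list.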

Initialization

forward(
max_seq_len: int,
offset: int = 0,
packed_seq: bool = False,
cp_group: torch.distributed.ProcessGroup | None = None,
) → tuple[torch.Tensor, torch.Tensor]

Get (local_rope, global_rope) tuple.

Local and global RoPE have different dimensions (e.g. 256 vs 64), so they cannot be stacked into a single tensor.

_forward_cached(
max_seq_len: int,
offset: int = 0,
packed_seq: bool = False,
) → tuple[torch.Tensor, torch.Tensor]

Cached forward for hashable parameters only.