bridge.models.gemma.gemma4_provider#
Gemma 4 Model Provider for Megatron-Core.
Gemma 4 is a Mixture-of-Experts (MoE) model with hybrid sliding/global attention. Key differences from Gemma 3:
MoE: 128 experts, top-k=8, plus a dense MLP path (mapped to shared experts)
Heterogeneous attention: sliding layers use head_dim=256 / 8 KV heads, global layers use global_head_dim=512 / 2 KV heads with partial rotary (0.25)
K=V sharing on global attention layers (V projection may be omitted)
Per-layer scaling via layer_scalar buffer
Dual pre/post layernorms for dense MLP vs MoE paths
Module Contents#
Classes#
Gemma4ModelProvider | Configuration and provider for Megatron Core Gemma 4 models.
Gemma4TransformerLayer | Gemma 4 transformer layer with per-layer output scaling and extra post-norms.
Gemma4TopKRouter | Gemma 4 MoE router with per-expert scaling.
Gemma4MoELayer | Gemma 4 MoE layer with post-routed-expert and post-shared-expert normalization.
Gemma4OutputLayer | Mixin that applies final_logit_softcapping after the output linear layer.
Gemma4SelfAttention | Gemma 4 self attention with heterogeneous sliding/global layers.
Gemma4TEDotProductAttention | Gemma 4 core attention.
Gemma4RotaryEmbedding | Gemma 4 position RoPE embedding.
Functions#
_logit_softcapping | Prevents logits from growing excessively: scale * tanh(logits / scale).
_install_tied_kv | Mark global attention layers that require K=V weight tying.
_gemma4_block_spec | Build Gemma 4 block spec: MoE or dense layer specs with patched attention.
Data#
API#
- bridge.models.gemma.gemma4_provider.HAVE_TE#
None
- class bridge.models.gemma.gemma4_provider.Gemma4ModelProvider#
Bases: megatron.bridge.models.gpt_provider.GPTModelProvider
Configuration and provider for Megatron Core Gemma 4 models.
Gemma 4 is a MoE model with hybrid sliding/global attention. The dense MLP path is mapped to Megatron-Core’s shared expert mechanism.
- seq_length: int#
262144
- position_embedding_type: str#
‘rope’
- rotary_base: tuple#
(10000, 1000000)
True
- normalization: str#
‘RMSNorm’
- layernorm_zero_centered_gamma: bool#
False
- layernorm_epsilon: float#
1e-06
- kv_channels: int#
256
- num_query_groups: int#
8
- window_size: int#
1024
- interleaved_attn_pattern: tuple#
(5, 1)
- attention_dropout: float#
0.0
0.0
- attention_backend: megatron.core.transformer.enums.AttnBackend#
None
- softmax_scale: float#
1.0
- qk_layernorm: bool#
True
- attention_k_eq_v: bool#
False
- global_head_dim: int#
512
- num_global_key_value_heads: int#
2
- global_rotary_percent: float#
0.25
- gated_linear_unit: bool#
True
- add_bias_linear: bool#
False
- activation_func: Callable#
None
- num_moe_experts: Optional[int]#
128
- moe_router_topk: int#
8
704
2112
False
False
- moe_grouped_gemm: bool#
True
- moe_token_dispatcher_type: str#
‘alltoall’
- moe_router_load_balancing_type: str#
‘aux_loss’
- moe_router_pre_softmax: bool#
True
- moe_router_dtype: str#
‘fp32’
- moe_aux_loss_coeff: float#
0.001
- moe_permute_fusion: bool#
True
- moe_layer_freq: int#
1
- final_logit_softcapping: float#
30.0
- flash_decode: bool#
False
- transformer_layer_spec: Union[Callable, object]#
‘field(…)’
- scatter_embedding_sequence_parallel: bool#
True
- bf16: bool#
True
- fp16: bool#
False
- params_dtype: torch.dtype#
None
- autocast_dtype: torch.dtype#
None
- provide(
- pre_process=None,
- post_process=None,
- vp_stage=None,
Configure and instantiate a Megatron Core Gemma 4 model.
Replaces the model’s embedding and RoPE with customized Gemma 4 variants that handle embedding scaling and dual local/global RoPE.
- class bridge.models.gemma.gemma4_provider.Gemma4TransformerLayer(config, submodules, layer_number=1, **kwargs)#
Bases: megatron.core.transformer.transformer_layer.TransformerLayer
Gemma 4 transformer layer with per-layer output scaling and extra post-norms.
Gemma 4 has architectural features not present in standard MCore:
layer_scalar: per-layer scaling applied to the full hidden state after residual add.
post_ffn_layernorm: norm applied to the combined dense+MoE output before residual add (HF’s post_feedforward_layernorm).
post_moe_layernorm: norm applied to routed expert output before combining with dense (HF’s post_feedforward_layernorm_2). Applied via a forward hook on the MoE layer.
Initialization
- _forward_post_mlp(mlp_output_with_bias, residual)#
Override to apply post_ffn_layernorm before residual add, then layer_scalar.
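The ordering described for this override can be illustrated with a minimal plain-Python sketch. It assumes a standard RMSNorm with a learned weight and the 1e-06 epsilon listed for layernorm_epsilon. The function names, the list-based "tensors", and the omission of bias handling are illustrative assumptions and not the module's actual code.

```python
import math
from typing import List


def rms_norm(x: List[float], weight: List[float], eps: float = 1e-6) -> List[float]:
    """RMSNorm over one hidden vector (assumed form: weight * x / rms(x))."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for v, w in zip(x, weight)]


def forward_post_mlp(
    mlp_output: List[float],
    residual: List[float],
    norm_weight: List[float],
    layer_scalar: float,
) -> List[float]:
    """Hypothetical stand-in for the override: norm, residual add, then scale."""
    normed = rms_norm(mlp_output, norm_weight)           # post_ffn_layernorm
    hidden = [r + n for r, n in zip(residual, normed)]   # residual add
    return [layer_scalar * h for h in hidden]            # layer_scalar on the full state


out = forward_post_mlp([3.0, 4.0], [1.0, 1.0], [1.0, 1.0], layer_scalar=0.5)
```

Note that layer_scalar multiplies the sum of residual and normalized MLP output, so it rescales the residual stream itself rather than only the MLP branch.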
- class bridge.models.gemma.gemma4_provider.Gemma4TopKRouter(config, **kwargs)#
Bases: megatron.core.transformer.moe.router.TopKRouter
Gemma 4 MoE router with per-expert scaling.
Applies per_expert_scale to the routing probs after standard routing. Also renormalizes top-k weights before scaling (matching HF behavior).
The router’s input preprocessing (parameter-free RMSNorm + scale * scalar_root_size) is fused into the router weight at load time in the bridge.
Initialization
- routing(logits, padding_mask=None, input_ids=None)#
Apply standard routing, then renormalize and scale by per_expert_scale.
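As a rough illustration of the renormalize-then-scale step, the sketch below routes a single token in plain Python. It assumes softmax is taken over all experts before top-k selection (consistent with moe_router_pre_softmax being True) and that per_expert_scale holds one multiplier per expert. The function name and the dict return type are hypothetical.

```python
import math
from typing import Dict, List


def route_token(logits: List[float], per_expert_scale: List[float], topk: int) -> Dict[int, float]:
    """Return {expert_index: weight} for one token (illustrative only)."""
    # Softmax over all experts, shifted by the max for numerical stability.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Select the top-k experts by probability.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:topk]

    # Renormalize the selected weights so they sum to 1, then apply the
    # per-expert scale (in that order, as described above).
    norm = sum(probs[i] for i in top)
    return {i: (probs[i] / norm) * per_expert_scale[i] for i in top}


weights = route_token([2.0, 1.0, 0.0, -1.0], [1.0, 2.0, 1.0, 1.0], topk=2)
```

Because scaling happens after renormalization, the final weights for a token need not sum to 1.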
- class bridge.models.gemma.gemma4_provider.Gemma4MoELayer(config, submodules, **kwargs)#
Bases: megatron.core.transformer.moe.moe_layer.MoELayer
Gemma 4 MoE layer with post-routed-expert and post-shared-expert normalization.
Applies post_feedforward_layernorm_2 (pffl_ln2) to routed expert output and post_feedforward_layernorm_1 (pffl_ln1) to shared expert output before combining. Standard MCore MoELayer simply sums routed + shared outputs without any intermediate norms.
Initialization
- postprocess(output, shared_expert_output)#
Apply post-MoE norms to routed and shared expert outputs, then combine.
- bridge.models.gemma.gemma4_provider._logit_softcapping(
- logits: torch.Tensor,
- scale: float | None,
Prevents logits from growing excessively: scale * tanh(logits / scale).
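The formula can be checked numerically with a small plain-Python sketch. The real function operates on torch tensors; lists are used here only to keep the example dependency-free. Treating scale=None as "no capping" is an assumption drawn from the float | None annotation, and the name softcap is hypothetical.

```python
import math
from typing import List, Optional


def softcap(logits: List[float], scale: Optional[float]) -> List[float]:
    """Apply scale * tanh(x / scale) element-wise."""
    if scale is None:  # assumption: None disables softcapping
        return list(logits)
    return [scale * math.tanh(x / scale) for x in logits]


capped = softcap([-1000.0, -1.0, 0.0, 1.0, 1000.0], scale=30.0)
```

Values much smaller than the scale pass through almost unchanged, while large values saturate at plus or minus the scale (30.0 for the default final_logit_softcapping).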
- class bridge.models.gemma.gemma4_provider.Gemma4OutputLayer#
Bases: torch.nn.Module
Mixin that applies final_logit_softcapping after the output linear layer.
- forward(*args, **kwargs)#
- bridge.models.gemma.gemma4_provider._install_tied_kv(
- model: torch.nn.Module,
- provider: bridge.models.gemma.gemma4_provider.Gemma4ModelProvider,
Mark global attention layers that require K=V weight tying.
In Gemma4, global attention layers share K and V projections (v_proj is absent in the HF checkpoint). At import time the bridge copies K rows into the V rows of linear_qkv.weight. This function marks each global Gemma4SelfAttention module with _tied_kv = True so that Gemma4SelfAttention.get_query_key_value_tensors can enforce V=K in the forward pass.
K=V sharing is decided by the attention_k_eq_v field. Must be called after model construction so that the attention modules are already built.
Note on gradient routing for LoRA: since V-rows = K-rows in the loaded checkpoint, the forward pass is numerically correct without any further modification. Full gradient routing (accumulating dL/dV into K-rows) is left as a future improvement.
- bridge.models.gemma.gemma4_provider._gemma4_block_spec(config, use_transformer_engine=True, **kwargs)#
Build Gemma 4 block spec: MoE or dense layer specs with patched attention.
Uses get_gpt_decoder_block_spec to build standard specs, then patches each layer spec:
Attention module → Gemma4SelfAttention (heterogeneous head dims)
Core attention → Gemma4TEDotProductAttention (sliding/global window)
linear_proj → TERowParallelLinearLayerNorm (post-attention RMSNorm)
MoE models only: MoE layer → Gemma4MoELayer, router → Gemma4TopKRouter
- class bridge.models.gemma.gemma4_provider.Gemma4SelfAttention(
- config: megatron.core.transformer.TransformerConfig,
- layer_number: int,
- **kwargs,
Bases: megatron.core.transformer.attention.SelfAttention
Gemma 4 self attention with heterogeneous sliding/global layers.
Sliding layers: head_dim=256, num_kv_heads=8, full rotary, local window
Global layers: head_dim=512, num_kv_heads=2, partial rotary (0.25), full attention
Value normalization: parameter-free RMSNorm applied to V after projection
The config is deep-copied and overridden per-layer so that the QKV linear is constructed with the correct dimensions.
Initialization
- sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#
Override to separate sliding and global layers in the checkpoint.
Sliding layers (head_dim=256) and global layers (head_dim=512) produce linear_qkv, linear_proj, q_layernorm, k_layernorm tensors with different shapes. dist_checkpointing validates two things per key group:
Uniform global_shape — fails because sliding/global shapes differ.
Full coverage of the global tensor — fails if only a subset of layers fill the group (e.g. 25 sliding layers can’t cover a 30-slot group).
Fix: append a ‘_sliding’ or ‘_global’ suffix to create per-type groups, and remap the prepended layer axis in ShardedTensors so that global_shape[0], global_offset[0], and axis_fragmentations[0] reflect per-type layer counts rather than the total layer count.
Example:
‘decoder.layers.0.self_attention.’ → ‘decoder.layers.0.self_attention_sliding.’ (or _global)
Loading works automatically because the same class produces the same suffixed keys on load.
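The two parts of the fix (suffixing the key prefix and re-indexing the layer axis per attention type) can be sketched in plain Python. Both helper names are hypothetical, and the set of global layer numbers is passed in explicitly as an assumption; the real override works on ShardedTensor objects rather than on tuples.

```python
from typing import Set, Tuple


def remap_prefix(prefix: str, is_global: bool) -> str:
    """Append '_global' or '_sliding' to a '...self_attention.' prefix."""
    if not prefix.endswith("self_attention."):
        raise ValueError(f"unexpected prefix: {prefix!r}")
    suffix = "_global" if is_global else "_sliding"
    return prefix[:-1] + suffix + "."


def per_type_position(layer_number: int, global_layers: Set[int], num_layers: int) -> Tuple[int, int]:
    """Return (offset, count) of a 1-based layer within its own attention type."""
    is_global = layer_number in global_layers
    same_type = [n for n in range(1, num_layers + 1) if (n in global_layers) == is_global]
    return same_type.index(layer_number), len(same_type)


# Assumed layout: 30 layers with every sixth layer global.
global_layers = {6, 12, 18, 24, 30}
new_prefix = remap_prefix("decoder.layers.0.self_attention.", is_global=False)
sliding_pos = per_type_position(7, global_layers, num_layers=30)
global_pos = per_type_position(12, global_layers, num_layers=30)
```

Under these assumptions layer 7 becomes slot 5 of a 25-slot sliding group and layer 12 becomes slot 1 of a 5-slot global group, so each group is uniformly shaped and fully covered.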
- get_query_key_value_tensors(
- hidden_states,
- key_value_states=None,
- **kwargs,
Override to apply parameter-free RMSNorm to V after QKV split.
HF Gemma4 applies v_norm = Gemma4RMSNorm(head_dim, with_scale=False) to the value states. This is a parameter-free normalization: v / rms(v).
For global attention layers (self._tied_kv = True), K=V tying is enforced here after super() has completed the all-gather for KV-replicated TP layouts. This ensures V=K throughout training for all tensor-parallel configs.
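A minimal sketch of the parameter-free normalization v / rms(v) for a single head vector follows. The epsilon value and its placement inside the square root are assumptions, and the function name is hypothetical; the actual override operates on batched torch tensors.

```python
import math
from typing import List


def parameter_free_rms_norm(v: List[float], eps: float = 1e-6) -> List[float]:
    """Divide a vector by its root-mean-square; no learned scale is applied."""
    rms = math.sqrt(sum(x * x for x in v) / len(v) + eps)
    return [x / rms for x in v]


v_normed = parameter_free_rms_norm([3.0, 4.0])
mean_square = sum(x * x for x in v_normed) / len(v_normed)
```

After normalization the mean square of the vector is approximately 1, regardless of the magnitude of the input.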
- forward(
- hidden_states: torch.Tensor,
- attention_mask: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- rotary_pos_emb: Optional[torch.Tensor] = None,
- rotary_pos_cos: Optional[torch.Tensor] = None,
- rotary_pos_sin: Optional[torch.Tensor] = None,
- rotary_pos_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- attention_bias: Optional[torch.Tensor] = None,
- packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
- sequence_len_offset: Optional[int] = None,
- *,
- inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
Switch to either local or global RoPE embedding before forward.
- class bridge.models.gemma.gemma4_provider.Gemma4TEDotProductAttention(
- config: megatron.core.transformer.TransformerConfig,
- layer_number: int,
- attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
- attention_type: str,
- attention_dropout: Optional[float] = None,
- **kwargs,
Bases: bridge.models.gemma.gemma4_provider.TEDotProductAttention
Gemma 4 core attention.
Switches between global and local sliding window attention based on the layer_number and pre-defined layer pattern.
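One plausible reading of the layer pattern, given interleaved_attn_pattern = (5, 1), is that every sixth layer is global and the rest use the sliding window. The sketch below encodes that reading; the 1-based layer numbering, the modulo rule, and the function name are assumptions rather than the class's confirmed logic.

```python
from typing import Tuple


def is_sliding_layer(layer_number: int, pattern: Tuple[int, int] = (5, 1)) -> bool:
    """True if a 1-based layer uses sliding-window attention (assumed rule)."""
    period = sum(pattern)  # 5 sliding layers followed by 1 global layer
    return layer_number % period != 0


num_layers = 30  # illustrative layer count
global_layers = [n for n in range(1, num_layers + 1) if not is_sliding_layer(n)]
sliding_count = sum(is_sliding_layer(n) for n in range(1, num_layers + 1))
```

With 30 layers this yields 25 sliding and 5 global layers, which agrees with the 25-of-30 example given for sharded_state_dict.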
Initialization
- class bridge.models.gemma.gemma4_provider.Gemma4RotaryEmbedding(
- rotary_base: int = 1000000,
- rotary_base_local: int = 10000,
- global_kv_channels: int = 512,
- global_rotary_percent: float = 0.25,
- **kwargs,
Bases: megatron.core.models.common.embeddings.rotary_pos_embedding.RotaryEmbedding
Gemma 4 position RoPE embedding.
Computes RoPE embeddings for both local (sliding) and global (full) attention layers. Local layers use full rotary with theta=10000. Global layers use proportional partial rotary (0.25) with theta=1000000.
HF’s proportional RoPE formula differs from standard partial rotary:
Standard: inv_freq = 1/(base^(arange(0, dim, 2) / dim)) where dim = head_dim * percent
Proportional: inv_freq = 1/(base^(arange(0, dim, 2) / head_dim)), where dim = head_dim * percent but the denominator is the full head_dim
This gives slower-decaying frequencies (spread across the full head_dim range).
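The difference between the two formulas can be seen by computing both frequency tables for the global-layer defaults (base 1000000, head_dim 512, percent 0.25). This is a plain-Python sketch with hypothetical function names; the class itself works on torch tensors.

```python
from typing import List


def inv_freq_standard(base: float, head_dim: int, percent: float) -> List[float]:
    """Standard partial rotary: exponent denominator is the rotary dim."""
    dim = int(head_dim * percent)
    return [1.0 / base ** (i / dim) for i in range(0, dim, 2)]


def inv_freq_proportional(base: float, head_dim: int, percent: float) -> List[float]:
    """Proportional rotary: same index range, but divided by the full head_dim."""
    dim = int(head_dim * percent)
    return [1.0 / base ** (i / head_dim) for i in range(0, dim, 2)]


std = inv_freq_standard(1_000_000, 512, 0.25)
prop = inv_freq_proportional(1_000_000, 512, 0.25)
```

Both tables contain 64 frequencies and start at 1.0, but the proportional exponent only reaches 126/512 instead of 126/128, so its frequencies fall off far more slowly.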
Initialization
- forward(
- max_seq_len: int,
- offset: int = 0,
- packed_seq: bool = False,
- cp_group: torch.distributed.ProcessGroup | None = None,
Get (local_rope, global_rope) tuple.
Local and global RoPE have different dimensions (e.g. 256 vs 64), so they cannot be stacked into a single tensor.
- _forward_cached(
- max_seq_len: int,
- offset: int = 0,
- packed_seq: bool = False,
Cached forward for hashable parameters only.