bridge.models.gemma.gemma4_bridge

Megatron Bridge for Gemma 4 text-only (CausalLM).

Gemma 4 is an MoE model with hybrid sliding/global attention. The dense MLP is mapped to Megatron-Core’s shared expert mechanism, and routed experts use a fused tensor format of shape [num_experts, 2*intermediate, hidden].

Key architecture-specific handling:

  • K=V on global attention layers: v_proj is absent; K weights are copied to V.

  • Dual pre-norms: separate norms for dense MLP vs routed experts.

  • Router scale/per_expert_scale: loaded as replicated buffers.

  • layer_scalar: per-layer scaling buffer.

Supported models

  • google/gemma-4-26B-A4B (MoE, enable_moe_block=True) — fully supported.

NOT supported

  • Dense Gemma 4 models (enable_moe_block=False, e.g. google/gemma-4-e2b-it). gemma4_vl_bridge.py raises ValueError for non-MoE models. Supporting dense models would require per-layer ffn_hidden_size and Per-Layer Embeddings (PLE) in MCore.

Module Contents

Classes

_Gemma4QKVMapping

QKV mapping that tolerates missing v_proj in the HF checkpoint.

Gemma4Bridge

Megatron Bridge for Gemma 4 text-only (CausalLM).

Functions

_infer_attn_pattern

Infer (sliding, global) interleaved attention pattern from layer_types list.

API

class bridge.models.gemma.gemma4_bridge._Gemma4QKVMapping(*args, **kwargs)

Bases: megatron.bridge.models.conversion.param_mapping.QKVMapping

QKV mapping that tolerates missing v_proj in the HF checkpoint.

Gemma 4 global attention layers share K=V, so v_proj is absent. allow_hf_name_mismatch = True prevents the weight loader from skipping the entire QKV mapping; the V weights are synthesized from K in Gemma4Bridge.maybe_modify_loaded_hf_weight.
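The K=V fallback can be illustrated with a small sketch. This is not the bridge's code: the helper name synthesize_v_from_k and the key layout model.layers.N.self_attn.{q,k,v}_proj.weight are assumptions made for illustration, and nested lists stand in for tensors.

```python
from typing import Dict, List

Matrix = List[List[float]]  # stand-in for a 2-D weight tensor


def synthesize_v_from_k(state: Dict[str, Matrix], layer_prefix: str) -> Dict[str, Matrix]:
    """Return q/k/v weights for one layer, copying K into V when v_proj is absent.

    Hypothetical helper: the key names below are an assumed HF-style layout.
    """
    q = state[f"{layer_prefix}.self_attn.q_proj.weight"]
    k = state[f"{layer_prefix}.self_attn.k_proj.weight"]
    v_key = f"{layer_prefix}.self_attn.v_proj.weight"
    # Global-attention layers tie K and V, so a missing v_proj falls back to a copy of K.
    v = state[v_key] if v_key in state else [row[:] for row in k]
    return {"q": q, "k": k, "v": v}


# Toy global-attention layer: q_proj and k_proj are present, v_proj is not.
global_layer = {
    "model.layers.5.self_attn.q_proj.weight": [[1.0, 0.0], [0.0, 1.0]],
    "model.layers.5.self_attn.k_proj.weight": [[0.5, 0.25]],
}
qkv = synthesize_v_from_k(global_layer, "model.layers.5")
```

The copy is deliberate in this sketch: V gets its own storage, so later in-place edits to one tensor do not silently change the other.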

class bridge.models.gemma.gemma4_bridge.Gemma4Bridge

Bases: megatron.bridge.models.conversion.model_bridge.MegatronModelBridge

Megatron Bridge for Gemma 4 text-only (CausalLM).

Handles conversion between HuggingFace Gemma4ForCausalLM and Megatron-Core GPTModel with MoE + shared experts.

Architecture mapping:

  • Dense MLP → Megatron shared experts (moe_shared_expert_overlap=False)

  • Routed MoE → Megatron routed experts (fused expert format)

  • Sliding attention → standard kv_channels/num_query_groups

  • Global attention → overridden kv_channels/num_query_groups per layer

Example

from megatron.bridge import AutoBridge
bridge = AutoBridge.from_hf_pretrained("google/gemma-4-26B-A4B")
provider = bridge.to_megatron_provider()

provider_bridge(
hf_pretrained: megatron.bridge.models.hf_pretrained.causal_lm.PreTrainedCausalLM,
) → megatron.bridge.models.gemma.gemma4_provider.Gemma4ModelProvider

Convert HuggingFace config to Gemma4ModelProvider.

maybe_modify_converted_hf_weight(
task,
converted_weights_dict,
hf_state_dict,
)

Un-fuse fused weights and drop synthesized keys on export.

On import, two non-trivial fusions are applied to the MoE layers:

  1. Router fusion: mg = hf * (scale * hidden^-0.5 / pffl2)

  2. Shared-expert gate/up fusion: mg = hf * (pffl / pffl2)

This method inverts both fusions on export so the resulting HF weights exactly match the original checkpoint. It also drops the synthesized v_proj key produced for K=V global-attention layers where v_proj is absent in HF.
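The import/export round trip can be sketched in plain Python. The factors follow the two formulas listed above; everything else is assumed for illustration: the toy values, the helper names scale_columns and invert, and treating router.scale as a per-channel vector. Nested lists stand in for tensors.

```python
import math
from typing import List

Matrix = List[List[float]]


def scale_columns(weight: Matrix, factor: List[float]) -> Matrix:
    """Multiply column j of `weight` by factor[j]."""
    return [[w * f for w, f in zip(row, factor)] for row in weight]


def invert(factor: List[float]) -> List[float]:
    """Element-wise reciprocal, used to undo a column scaling."""
    return [1.0 / f for f in factor]


def close(a: Matrix, b: Matrix) -> bool:
    """True when two matrices agree element-wise within a tight tolerance."""
    return all(math.isclose(x, y, rel_tol=1e-12) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


hidden = 4
scale = [1.5, 0.5, 2.0, 1.0]   # toy router.scale (assumed per-channel)
pffl = [1.1, 0.7, 0.9, 1.3]    # toy pre_feedforward_layernorm weight
pffl2 = [0.8, 1.2, 1.0, 0.9]   # toy pre_feedforward_layernorm_2 weight

# Import-time factors, matching the two fusions listed above.
router_factor = [s * hidden ** -0.5 / p2 for s, p2 in zip(scale, pffl2)]
shared_factor = [p / p2 for p, p2 in zip(pffl, pffl2)]

hf_router = [[0.2, -0.1, 0.4, 0.3], [0.05, 0.6, -0.7, 0.1]]
hf_gate = [[0.3, 0.3, -0.2, 0.9], [-0.4, 0.1, 0.8, 0.2]]

# Import: fuse. Export: multiply by the reciprocal factors to un-fuse.
mg_router = scale_columns(hf_router, router_factor)
mg_gate = scale_columns(hf_gate, shared_factor)
exported_router = scale_columns(mg_router, invert(router_factor))
exported_gate = scale_columns(mg_gate, invert(shared_factor))

round_trip_ok = close(exported_router, hf_router) and close(exported_gate, hf_gate)
```

This sketch compares with a floating-point tolerance rather than asserting bitwise equality.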

maybe_modify_loaded_hf_weight(
hf_param: str | dict[str, str],
hf_state_dict: Mapping[str, torch.Tensor],
) → torch.Tensor

Handle special weight loading for Gemma 4.

  1. K=V on global attention layers: synthesize v_proj from k_proj.

  2. Router weight fusion: absorb router.scale * scalar_root_size / ln2_weight into router.proj.weight, where ln2_weight is the pre_feedforward_layernorm_2 weight, so MCore’s router produces correct logits when receiving pre_feedforward_layernorm_2-normed input.

  3. Shared expert pre-norm fusion: absorb the ratio pre_feedforward_layernorm / pre_feedforward_layernorm_2 (of the norm weights) into shared expert gate/up weights so the shared expert effectively receives pre_feedforward_layernorm-normed input even though MCore feeds it pre_feedforward_layernorm_2-normed input.

_fuse_router_weight(
hf_param: str,
hf_state_dict: Mapping[str, torch.Tensor],
) → torch.Tensor

Fuse router preprocessing into projection weight.

HF router: logits = proj(rms_norm(x) * scale * scalar_root_size)
MCore router: logits = weight @ pre_feedforward_layernorm_2(x)

Since rms_norm(x) = pre_feedforward_layernorm_2(x) / ln2_weight (Gemma 4 uses standard gamma: x * w / rms(x)), we fuse: new_weight = proj.weight * (scale * scalar_root_size / ln2_weight)
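A small numerical check of this identity, written as a pure-Python sketch with lists standing in for tensors. The toy values, the eps term, and the helper names rms_norm and matvec are assumptions. scalar_root_size is taken as hidden^-0.5, and scale is assumed to be a per-channel vector.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def rms_norm(x: Vector, eps: float = 1e-6) -> Vector:
    """Weight-free RMS normalization: x / sqrt(mean(x^2) + eps)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Plain matrix-vector product."""
    return [sum(a * b for a, b in zip(row, x)) for row in w]


x = [0.3, -1.2, 0.7, 2.0]                                 # one token, hidden = 4
proj = [[0.2, -0.1, 0.4, 0.3], [0.05, 0.6, -0.7, 0.1]]    # [num_experts, hidden]
scale = [1.5, 0.5, 2.0, 1.0]                              # toy router.scale
ln2_weight = [0.8, 1.2, 1.0, 0.9]                         # toy pre_feedforward_layernorm_2 weight
scalar_root_size = len(x) ** -0.5

# HF path: the router normalizes and scales its own input.
hf_in = [n * s * scalar_root_size for n, s in zip(rms_norm(x), scale)]
hf_logits = matvec(proj, hf_in)

# MCore path: the input is already multiplied by ln2_weight, so the fused weight divides it out.
fused = [
    [w * s * scalar_root_size / g for w, s, g in zip(row, scale, ln2_weight)]
    for row in proj
]
mcore_in = [n * g for n, g in zip(rms_norm(x), ln2_weight)]
mcore_logits = matvec(fused, mcore_in)

logits_match = all(
    math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
    for a, b in zip(hf_logits, mcore_logits)
)
```

The fusion only rescales columns of the projection weight, so it adds no runtime cost in the router.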

_fuse_shared_expert_prenorm(
hf_param: dict[str, str],
hf_state_dict: Mapping[str, torch.Tensor],
) → dict[str, torch.Tensor]

Fuse pre-norm correction into shared expert gate/up weights.

MCore feeds shared experts pre_feedforward_layernorm_2(x) but HF feeds them pre_feedforward_layernorm(x). Since both norms are standard RMSNorm (x * w / rms(x)), the correction is element-wise:

correction[j] = w_pffl[j] / w_pffl2[j]
new_weight[i, j] = weight[i, j] * correction[j]
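
The column-wise correction can be checked numerically with a pure-Python sketch. The toy values and the helper names matvec and fuse_prenorm_correction are assumptions for illustration; the vector unit stands in for x / rms(x), and nested lists stand in for tensors.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Plain matrix-vector product."""
    return [sum(a * b for a, b in zip(row, x)) for row in w]


def fuse_prenorm_correction(weight: Matrix, w_pffl: Vector, w_pffl2: Vector) -> Matrix:
    """Scale column j of `weight` by w_pffl[j] / w_pffl2[j]."""
    correction = [a / b for a, b in zip(w_pffl, w_pffl2)]
    return [[w * c for w, c in zip(row, correction)] for row in weight]


unit = [0.5, -1.0, 1.5]                          # stands in for x / rms(x)
w_pffl = [1.1, 0.7, 0.9]                         # toy pre_feedforward_layernorm weight
w_pffl2 = [0.8, 1.2, 1.0]                        # toy pre_feedforward_layernorm_2 weight
gate = [[0.3, -0.2, 0.9], [-0.4, 0.1, 0.8]]      # toy gate weight, [out, hidden]

# HF: original weight applied to pffl-normed input.
hf_out = matvec(gate, [u * w for u, w in zip(unit, w_pffl)])

# MCore: corrected weight applied to pffl2-normed input.
fused_gate = fuse_prenorm_correction(gate, w_pffl, w_pffl2)
mcore_out = matvec(fused_gate, [u * w for u, w in zip(unit, w_pffl2)])

outputs_match = all(
    math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
    for a, b in zip(hf_out, mcore_out)
)
```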

mapping_registry() → megatron.bridge.models.conversion.mapping_registry.MegatronMappingRegistry

Define parameter mappings between Megatron and HF formats.

HF param names use model.layers.* prefix (text-only CausalLM). The VLM bridge overrides this with model.language_model.layers.*.

_split_qkv_linear_out_weight(megatron_model, linear_out_weight)

Override for Gemma 4 dual-attention: detect global vs sliding layers by tensor size.

Gemma 4 interleaves sliding-window and full (global) attention layers with different head configurations:

  • Sliding: kv_channels=256, num_query_groups=num_key_value_heads

  • Global: global_head_dim=512, num_global_key_value_heads=2, K=V tying

For global layers, the linear_qkv LoRA output tensor is larger than the size expected for a sliding layer. The override detects this and re-splits the tensor using the global head dimensions. Because HF global attention has no v_proj weight (K=V tying), v_proj is set to ABSENT_PROJECTION for global layers, and the export loop skips it.
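Size-based detection can be sketched as follows. This is an illustration under stated assumptions, not the bridge's implementation. The head dimensions (256 and 512) and the 2 global KV heads come from the list above. The 16 query heads, the 8 sliding KV heads, and the names AttnShape and classify_by_rows are made up for the example. The sketch only counts rows and does not model how Megatron actually lays out Q, K and V inside the fused tensor.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class AttnShape:
    num_heads: int
    num_kv_heads: int
    head_dim: int

    def qkv_rows(self) -> Tuple[int, int, int]:
        """Row counts of the Q, K and V parts of a fused QKV output."""
        kv = self.num_kv_heads * self.head_dim
        return self.num_heads * self.head_dim, kv, kv

    def fused_rows(self) -> int:
        """Total rows of the fused QKV output dimension."""
        return sum(self.qkv_rows())


def classify_by_rows(rows: int, sliding: AttnShape, global_: AttnShape) -> str:
    """Decide which head configuration a fused QKV tensor belongs to."""
    if rows == sliding.fused_rows():
        return "sliding"
    if rows == global_.fused_rows():
        return "global"
    raise ValueError(f"unexpected fused QKV row count: {rows}")


# Hypothetical head counts (16 query heads, 8 sliding KV heads); see the lead-in.
SLIDING = AttnShape(num_heads=16, num_kv_heads=8, head_dim=256)
GLOBAL = AttnShape(num_heads=16, num_kv_heads=2, head_dim=512)

kind = classify_by_rows(GLOBAL.fused_rows(), SLIDING, GLOBAL)
q_rows, k_rows, v_rows = GLOBAL.qkv_rows()
```

With these toy numbers the global tensor has 10240 rows against 8192 for sliding, so comparing the row count is enough to tell the two apart.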

bridge.models.gemma.gemma4_bridge._infer_attn_pattern(layer_types: list[str]) → tuple[int, int]

Infer (sliding, global) interleaved attention pattern from layer_types list.

E.g., ["sliding", "sliding", …, "full", "sliding", …] with 5 sliding + 1 full returns (5, 1).
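
One way such an inference could work is sketched below. The function name infer_attn_pattern_sketch is hypothetical and the logic is an assumption, not the module's code. It uses the "sliding" and "full" labels from the example above, and it assumes the list starts with a sliding run; real configs may use different label strings.

```python
from typing import List, Tuple


def infer_attn_pattern_sketch(layer_types: List[str]) -> Tuple[int, int]:
    """Count the leading run of sliding layers and the run of full layers after it."""
    sliding = 0
    while sliding < len(layer_types) and layer_types[sliding] == "sliding":
        sliding += 1
    full = 0
    while sliding + full < len(layer_types) and layer_types[sliding + full] == "full":
        full += 1
    return sliding, full


# Three repetitions of 5 sliding layers followed by 1 full layer.
layer_types = (["sliding"] * 5 + ["full"]) * 3
pattern = infer_attn_pattern_sketch(layer_types)  # (5, 1)
```

Only the first period of the list is inspected. A stricter version could also verify that the rest of the list repeats the same pattern.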