bridge.models.gemma.gemma4_bridge
Megatron Bridge for Gemma 4 text-only (CausalLM).
Gemma 4 is a MoE model with hybrid sliding/global attention. The dense MLP
is mapped to Megatron-Core’s shared expert mechanism, and routed experts
use fused tensor format [num_experts, 2*intermediate, hidden].
Key architecture-specific handling:
K=V on global attention layers: v_proj is absent; K weights are copied to V.
Dual pre-norms: separate norms for dense MLP vs routed experts.
Router scale/per_expert_scale: loaded as replicated buffers.
layer_scalar: per-layer scaling buffer.
Supported models
google/gemma-4-26B-A4B (MoE, enable_moe_block=True): fully supported.
NOT supported
Dense Gemma 4 models (enable_moe_block=False, e.g. google/gemma-4-e2b-it).
gemma4_vl_bridge.py raises ValueError for non-MoE models. Dense support requires per-layer ffn_hidden_size and Per-Layer Embeddings (PLE) in MCore.
Module Contents
Classes
| _Gemma4QKVMapping | QKV mapping that tolerates missing v_proj in the HF checkpoint. |
| Gemma4Bridge | Megatron Bridge for Gemma 4 text-only (CausalLM). |
Functions
| _infer_attn_pattern | Infer (sliding, global) interleaved attention pattern from layer_types list. |
API
- class bridge.models.gemma.gemma4_bridge._Gemma4QKVMapping(*args, **kwargs)
Bases: megatron.bridge.models.conversion.param_mapping.QKVMapping
QKV mapping that tolerates missing v_proj in the HF checkpoint.
Gemma 4 global attention layers share K=V, so v_proj is absent.
allow_hf_name_mismatch = True prevents the weight loader from skipping the entire QKV mapping; the V weights are synthesized from K in Gemma4Bridge.maybe_modify_loaded_hf_weight.
Initialization
- class bridge.models.gemma.gemma4_bridge.Gemma4Bridge
Bases: megatron.bridge.models.conversion.model_bridge.MegatronModelBridge
Megatron Bridge for Gemma 4 text-only (CausalLM).
Handles conversion between HuggingFace Gemma4ForCausalLM and Megatron-Core GPTModel with MoE + shared experts.
Architecture mapping:
Dense MLP → Megatron shared experts (moe_shared_expert_overlap=False)
Routed MoE → Megatron routed experts (fused expert format)
Sliding attention → standard kv_channels/num_query_groups
Global attention → overridden kv_channels/num_query_groups per layer
Example:
from megatron.bridge import AutoBridge
bridge = AutoBridge.from_hf_pretrained("google/gemma-4-12B-A2B")
provider = bridge.to_megatron_provider()
- provider_bridge(hf_pretrained: megatron.bridge.models.hf_pretrained.causal_lm.PreTrainedCausalLM)
Convert HuggingFace config to Gemma4ModelProvider.
- maybe_modify_converted_hf_weight(task, converted_weights_dict, hf_state_dict)
Un-fuse fused weights and drop synthesized keys on export.
On import, two non-trivial fusions are applied to the MoE layers (pffl and pffl2 denote the pre_feedforward_layernorm and pre_feedforward_layernorm_2 weights):
Router fusion: mg = hf * (scale * hidden^-0.5 / pffl2)
Shared-expert gate/up fusion: mg = hf * (pffl / pffl2)
This method inverts both fusions on export so the resulting HF weights exactly match the original checkpoint. It also drops the synthesized v_proj key produced for K=V global-attention layers where v_proj is absent in HF.
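The import/export round trip described above can be illustrated with a minimal sketch. It is not the bridge's implementation: the helper names `fuse_columns` and `unfuse_columns`, the list-of-lists weights, and all numeric values are hypothetical, and only the shared-expert factor pffl / pffl2 is shown. In floating point the restored weights agree with the originals up to rounding.

```python
import math

def fuse_columns(weight, factor):
    """Import direction: scale column j of `weight` by factor[j]."""
    return [[w * f for w, f in zip(row, factor)] for row in weight]

def unfuse_columns(weight, factor):
    """Export direction: divide column j of `weight` by the same factor[j]."""
    return [[w / f for w, f in zip(row, factor)] for row in weight]

# Hypothetical per-channel norm weights for a hidden size of 3.
pffl = [1.2, 0.8, 1.5]    # pre_feedforward_layernorm weight
pffl2 = [0.6, 1.6, 0.5]   # pre_feedforward_layernorm_2 weight
factor = [a / b for a, b in zip(pffl, pffl2)]

hf_weight = [[0.1, -0.2, 0.3], [0.4, 0.5, -0.6]]  # [out, hidden]
mg_weight = fuse_columns(hf_weight, factor)        # stored on import
restored = unfuse_columns(mg_weight, factor)       # emitted on export

round_trip_ok = all(
    math.isclose(r, h, rel_tol=1e-12)
    for r_row, h_row in zip(restored, hf_weight)
    for r, h in zip(r_row, h_row)
)
print(round_trip_ok)
```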
- maybe_modify_loaded_hf_weight(hf_param: str | dict[str, str], hf_state_dict: Mapping[str, torch.Tensor])
Handle special weight loading for Gemma 4.
K=V on global attention layers: synthesize v_proj from k_proj.
Router weight fusion: absorb router.scale * scalar_root_size / ln2_weight into router.proj.weight so MCore’s router produces correct logits when receiving pre_feedforward_layernorm_2-normed input.
Shared expert pre-norm fusion: absorb the ratio pre_feedforward_layernorm / pre_feedforward_layernorm_2 into shared expert gate/up weights so the shared expert effectively receives pre_feedforward_layernorm-normed input even though MCore feeds it pre_feedforward_layernorm_2-normed input.
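The K=V synthesis step can be sketched as follows. This is an illustration only: `load_v_proj` is a hypothetical helper, the `self_attn.k_proj.weight` / `self_attn.v_proj.weight` key suffixes are assumed from the usual HF naming, and plain nested lists stand in for tensors.

```python
def load_v_proj(hf_state_dict, layer_prefix):
    """Return the V projection weight, falling back to a copy of K."""
    k_key = f"{layer_prefix}.self_attn.k_proj.weight"  # assumed key suffix
    v_key = f"{layer_prefix}.self_attn.v_proj.weight"  # assumed key suffix
    if v_key in hf_state_dict:
        # Layer stores its own V weights: use them unchanged.
        return hf_state_dict[v_key]
    # K=V tying: copy K row by row so the synthesized V does not alias K.
    return [row[:] for row in hf_state_dict[k_key]]

# Hypothetical checkpoint: layer 0 has v_proj, layer 5 (global) does not.
state = {
    "model.layers.0.self_attn.k_proj.weight": [[1.0, 2.0], [3.0, 4.0]],
    "model.layers.0.self_attn.v_proj.weight": [[5.0, 6.0], [7.0, 8.0]],
    "model.layers.5.self_attn.k_proj.weight": [[9.0, 1.0], [2.0, 3.0]],
}
v_sliding = load_v_proj(state, "model.layers.0")
v_global = load_v_proj(state, "model.layers.5")
```

Returning a copy rather than the K object itself keeps later in-place edits to one projection from silently changing the other.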
- _fuse_router_weight(hf_param: str, hf_state_dict: Mapping[str, torch.Tensor])
Fuse router preprocessing into projection weight.
HF router: logits = proj(rms_norm(x) * scale * scalar_root_size)
MCore router: logits = weight @ pre_feedforward_layernorm_2(x)
Since rms_norm(x) = pre_feedforward_layernorm_2(x) / ln2_weight (Gemma 4 uses standard gamma: x * w / rms(x)), we fuse:
new_weight = proj.weight * (scale * scalar_root_size / ln2_weight)
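A small numeric check of the router identity above, as a sketch under stated assumptions: `scale` is treated as a per-channel vector of length hidden, `scalar_root_size` is taken as hidden^-0.5, the RMSNorm epsilon is omitted, and every value and helper name is hypothetical.

```python
import math

def rms_norm(x, gamma=None):
    """RMSNorm without epsilon; gamma=None means no learned weight."""
    r = math.sqrt(sum(v * v for v in x) / len(x))
    if gamma is None:
        return [v / r for v in x]
    return [v * g / r for v, g in zip(x, gamma)]

def linear(weight, x):
    """Matrix-vector product for a [out, in] weight stored as nested lists."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]

hidden = 4
x = [0.5, -1.0, 2.0, 0.25]
ln2_weight = [1.5, 0.5, 2.0, 1.0]   # pre_feedforward_layernorm_2 gamma
scale = [0.9, 1.1, 1.0, 1.2]        # router.scale (assumed per-channel)
scalar_root_size = hidden ** -0.5   # assumed to equal hidden^-0.5
proj_weight = [[0.1, 0.2, -0.3, 0.4],
               [0.5, -0.6, 0.7, 0.8]]  # [num_experts, hidden]

# HF path: logits = proj(rms_norm(x) * scale * scalar_root_size)
hf_input = [v * s * scalar_root_size for v, s in zip(rms_norm(x), scale)]
hf_logits = linear(proj_weight, hf_input)

# MCore path: fused weight applied to pre_feedforward_layernorm_2(x)
fused_weight = [
    [w * s * scalar_root_size / g for w, s, g in zip(row, scale, ln2_weight)]
    for row in proj_weight
]
mcore_logits = linear(fused_weight, rms_norm(x, ln2_weight))

logits_match = all(
    math.isclose(a, b, rel_tol=1e-9) for a, b in zip(hf_logits, mcore_logits)
)
print(logits_match)
```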
- hf_param: dict[str, str],
- hf_state_dict: Mapping[str, torch.Tensor],
Fuse pre-norm correction into shared expert gate/up weights.
MCore feeds shared experts pre_feedforward_layernorm_2(x) but HF feeds them pre_feedforward_layernorm(x). Since both norms are standard RMSNorm (x * w / rms(x)), the correction is element-wise:
correction[j] = w_pffl[j] / w_pffl2[j]
new_weight[i, j] = weight[i, j] * correction[j]
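The correction can be checked numerically with a short sketch. The weights, input, and helper names below are hypothetical, epsilon is omitted from the norm, and a single small matrix stands in for the shared expert's gate/up projection.

```python
import math

def rms_norm(x, gamma):
    """Standard RMSNorm without epsilon: x * gamma / rms(x)."""
    r = math.sqrt(sum(v * v for v in x) / len(x))
    return [v * g / r for v, g in zip(x, gamma)]

def linear(weight, x):
    """Matrix-vector product for a [out, in] weight stored as nested lists."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]

w_pffl = [1.2, 0.8, 1.5]    # pre_feedforward_layernorm weight
w_pffl2 = [0.6, 1.6, 0.5]   # pre_feedforward_layernorm_2 weight
gate_up = [[0.1, -0.2, 0.3], [0.4, 0.5, -0.6]]  # [out, hidden]
x = [1.0, -2.0, 0.5]

# correction[j] scales input column j of the weight.
correction = [a / b for a, b in zip(w_pffl, w_pffl2)]
new_weight = [[w * c for w, c in zip(row, correction)] for row in gate_up]

hf_out = linear(gate_up, rms_norm(x, w_pffl))          # what HF computes
mcore_out = linear(new_weight, rms_norm(x, w_pffl2))   # what MCore computes

outputs_match = all(
    math.isclose(a, b, rel_tol=1e-9) for a, b in zip(hf_out, mcore_out)
)
print(outputs_match)
```

Because the correction acts on the input (column) index j, it commutes with the matrix product and needs no change to the norm layer itself.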
- mapping_registry() → megatron.bridge.models.conversion.mapping_registry.MegatronMappingRegistry
Define parameter mappings between Megatron and HF formats.
HF param names use the model.layers.* prefix (text-only CausalLM). The VLM bridge overrides this with model.language_model.layers.*.
- _split_qkv_linear_out_weight(megatron_model, linear_out_weight)
Override for Gemma4 dual-attention: detect global vs sliding layers by tensor size.
Gemma4 interleaves sliding-window and full (global) attention layers with different head configurations:
Sliding: kv_channels=256, num_query_groups=num_key_value_heads
Global: global_head_dim=512, num_global_key_value_heads=2, K=V tying
For global layers the linear_qkv LoRA output tensor is larger than the sliding expectation. We detect this and re-split using the global head dimensions. For global layers v_proj is set to ABSENT_PROJECTION because HF global attention has no v_proj weight (K=V tying); the export loop skips it.
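The size-based detection can be sketched as below. This is a hypothetical illustration, not the override's actual code: it assumes the fused QKV output has (num_heads + 2 * num_query_groups) * head_dim rows, and `NUM_HEADS = 16` and the sliding group count of 8 are made-up values. Only kv_channels=256, global_head_dim=512, and num_global_key_value_heads=2 come from the description above.

```python
def qkv_rows(num_heads, num_query_groups, head_dim):
    """Output rows of a fused QKV projection: Q heads plus K and V groups."""
    return (num_heads + 2 * num_query_groups) * head_dim

def detect_layer_kind(rows, num_heads, sliding, global_):
    """Classify a fused QKV tensor as 'sliding' or 'global' by row count."""
    if rows == qkv_rows(num_heads, *sliding):
        return "sliding"
    if rows == qkv_rows(num_heads, *global_):
        return "global"
    raise ValueError(f"unexpected fused QKV row count: {rows}")

NUM_HEADS = 16       # hypothetical attention head count
SLIDING = (8, 256)   # (num_query_groups [hypothetical], kv_channels)
GLOBAL = (2, 512)    # (num_global_key_value_heads, global_head_dim)

sliding_rows = qkv_rows(NUM_HEADS, *SLIDING)  # (16 + 16) * 256 = 8192
global_rows = qkv_rows(NUM_HEADS, *GLOBAL)    # (16 + 4) * 512 = 10240

print(detect_layer_kind(sliding_rows, NUM_HEADS, SLIDING, GLOBAL))
print(detect_layer_kind(global_rows, NUM_HEADS, SLIDING, GLOBAL))
```

With these illustrative numbers the global tensor has more rows than the sliding one, which is the condition the description relies on.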
- bridge.models.gemma.gemma4_bridge._infer_attn_pattern(layer_types: list[str]) → tuple[int, int]
Infer (sliding, global) interleaved attention pattern from layer_types list.
E.g., ["sliding", "sliding", …, "full", "sliding", …] with 5 sliding + 1 full returns (5, 1).
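One way such an inference could work is sketched below. It is not the library's implementation: `infer_attn_pattern` is a hypothetical stand-in that reads only the first repeating block, and the "sliding" / "full" labels follow the example above (a real config may use different strings).

```python
def infer_attn_pattern(layer_types):
    """Return (n_sliding, n_global) for the first repeating block."""
    # Sliding layers that precede the first full-attention layer.
    # list.index raises ValueError if there is no "full" layer at all.
    first_full = layer_types.index("full")
    n_sliding = first_full
    # Consecutive full-attention layers that follow.
    n_global = 0
    for layer_type in layer_types[first_full:]:
        if layer_type != "full":
            break
        n_global += 1
    return n_sliding, n_global

# Two repetitions of a 5-sliding + 1-full block.
layers = (["sliding"] * 5 + ["full"]) * 2
print(infer_attn_pattern(layers))  # (5, 1)
```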