bridge.models.gemma_vl.gemma4_vl_bridge

Megatron Bridge for Gemma 4 VL (Vision-Language).

Extends the Gemma 4 text bridge to handle the full VLM checkpoint, including the vision tower, multimodal embedder, and language model.

Weight prefixes in the HF VLM checkpoint (after stripping the outer model. prefix):

  • language_model.layers.* → language model decoder layers

  • language_model.embed_tokens → language model embedding

  • language_model.norm → final layernorm

  • vision_tower.* → HF vision encoder (replicated)

  • embed_vision.* → multimodal projector (replicated)
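The prefix table above can be illustrated with a small routing helper. This is a minimal sketch, not the bridge's code: classify_hf_key, strip_outer_model, and COMPONENT_PREFIXES are hypothetical names, and the example keys are illustrative rather than taken from a real checkpoint.

```python
# Hypothetical prefix table mirroring the list above (keys after stripping "model.").
COMPONENT_PREFIXES = (
    ("language_model.layers.", "language model decoder layers"),
    ("language_model.embed_tokens", "language model embedding"),
    ("language_model.norm", "final layernorm"),
    ("vision_tower.", "HF vision encoder (replicated)"),
    ("embed_vision.", "multimodal projector (replicated)"),
)


def strip_outer_model(key: str) -> str:
    """Drop the outer "model." prefix that wraps every raw VLM key."""
    return key[len("model."):] if key.startswith("model.") else key


def classify_hf_key(key: str) -> str:
    """Return the VLM component a raw HF key belongs to, or raise KeyError."""
    stripped = strip_outer_model(key)
    for prefix, component in COMPONENT_PREFIXES:
        if stripped.startswith(prefix):
            return component
    raise KeyError(f"unrecognized Gemma 4 VLM key: {key}")
```

For example, classify_hf_key("model.language_model.layers.0.self_attn.q_proj.weight") returns "language model decoder layers".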

Module Contents

Classes

Gemma4VLBridge

Megatron Bridge for Gemma 4 Vision-Language models.

API

class bridge.models.gemma_vl.gemma4_vl_bridge.Gemma4VLBridge

Bases: megatron.bridge.models.conversion.model_bridge.MegatronModelBridge

Megatron Bridge for Gemma 4 Vision-Language models.

Handles conversion between HuggingFace Gemma4ForConditionalGeneration and Megatron-Core Gemma4VLModel.

Example

```python
from megatron.bridge import AutoBridge

bridge = AutoBridge.from_hf_pretrained("google/gemma-4-26B-A4B")
provider = bridge.to_megatron_provider()
```

provider_bridge(hf_pretrained: megatron.bridge.models.hf_pretrained.vlm.PreTrainedVLM) → megatron.bridge.models.gemma_vl.gemma4_vl_provider.Gemma4VLModelProvider

maybe_modify_converted_hf_weight(task, converted_weights_dict, hf_state_dict)

Un-fuse fused weights and drop synthesized keys on export.

On import, maybe_modify_loaded_hf_weight applies two non-trivial fusions to the MoE layers to simplify the MCore forward pass:

  1. Router fusion: mg = hf * (scale * sqrt_hidden⁻¹ / pffl2)

  2. Shared-expert gate/up fusion: mg = hf * (pffl / pffl2)
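Both import-time fusions are per-input-channel rescalings, and a per-channel rescaling of a linear layer's input can always be folded into the columns of its weight: W (c ⊙ x) = (W diag(c)) x. The sketch below applies the two formulas to plain Python lists. It assumes weights are stored as (out, in) matrices, that scale, pffl, and pffl2 are per-input-channel vectors, and that sqrt_hidden is the square root of the input width; the helper names are hypothetical.

```python
import math

Matrix = list[list[float]]


def scale_columns(weight: Matrix, col_factors: list[float]) -> Matrix:
    """Multiply column j of an (out, in) weight by col_factors[j]."""
    return [[w * f for w, f in zip(row, col_factors)] for row in weight]


def fuse_router(hf_weight: Matrix, scale: list[float], pffl2: list[float]) -> Matrix:
    """Router fusion: mg = hf * (scale * sqrt_hidden^-1 / pffl2), per input channel."""
    # Assumption: sqrt_hidden is sqrt of the router's input width.
    inv_sqrt_hidden = 1.0 / math.sqrt(len(scale))
    factors = [s * inv_sqrt_hidden / p for s, p in zip(scale, pffl2)]
    return scale_columns(hf_weight, factors)


def fuse_shared_expert(hf_weight: Matrix, pffl: list[float], pffl2: list[float]) -> Matrix:
    """Shared-expert gate/up fusion: mg = hf * (pffl / pffl2), per input channel."""
    return scale_columns(hf_weight, [a / b for a, b in zip(pffl, pffl2)])
```

Folding the factors into the weight once at load time means the forward pass never has to apply them separately.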

On export (Megatron → HF), this method inverts both fusions so the resulting HF weights exactly match the original checkpoint. It also drops the synthesized v_proj key produced by QKVMapping.megatron_to_hf for K=V global-attention layers where v_proj is absent in HF.
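On export, each fusion is undone by multiplying by the reciprocal of its factor, and synthesized v_proj entries are filtered out. This minimal sketch assumes scale, pffl, and pffl2 are per-input-channel vectors and that sqrt_hidden is the square root of the input width. The helper names, and the rule of dropping any v_proj key that is missing from hf_state_dict, are hypothetical simplifications rather than the bridge's actual logic.

```python
import math

Matrix = list[list[float]]


def _scale_columns(weight: Matrix, factors: list[float]) -> Matrix:
    """Multiply column j of an (out, in) weight by factors[j]."""
    return [[w * f for w, f in zip(row, factors)] for row in weight]


def unfuse_router(mg_weight: Matrix, scale: list[float], pffl2: list[float]) -> Matrix:
    """Invert mg = hf * (scale * sqrt_hidden^-1 / pffl2)."""
    inv_sqrt_hidden = 1.0 / math.sqrt(len(scale))
    return _scale_columns(mg_weight, [p / (s * inv_sqrt_hidden) for s, p in zip(scale, pffl2)])


def unfuse_shared_expert(mg_weight: Matrix, pffl: list[float], pffl2: list[float]) -> Matrix:
    """Invert mg = hf * (pffl / pffl2), i.e. hf = mg * (pffl2 / pffl)."""
    return _scale_columns(mg_weight, [b / a for a, b in zip(pffl, pffl2)])


def drop_synthesized_v_proj(converted: dict, hf_state_dict: dict) -> dict:
    """Drop v_proj entries that the original HF checkpoint never had (K=V layers)."""
    return {
        name: tensor
        for name, tensor in converted.items()
        if not (name.endswith("v_proj.weight") and name not in hf_state_dict)
    }
```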

maybe_modify_loaded_hf_weight(hf_param: str | dict[str, str], hf_state_dict: Mapping[str, torch.Tensor]) → torch.Tensor

Handle special weight loading for Gemma 4 VLM.

Covers three cases: K=V synthesis for global-attention layers, router weight fusion, and shared-expert pre-norm fusion.

HF parameter names carry the model.language_model. prefix (raw safetensors keys include the outer model. prefix from Gemma4ForConditionalGeneration).
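K=V synthesis can be pictured as falling back to the k_proj tensor whenever a layer's v_proj key is missing. This minimal sketch assumes hf_param is a dict with hypothetical "q", "k", and "v" entries naming one layer's HF keys; the function name and key layout are illustrative, not the bridge's API.

```python
from typing import Mapping


def load_qkv_with_kv_synthesis(
    hf_param: dict[str, str], hf_state_dict: Mapping[str, object]
) -> dict[str, object]:
    """Return q/k/v tensors for one layer; reuse k_proj as v_proj when V is absent."""
    q = hf_state_dict[hf_param["q"]]
    k = hf_state_dict[hf_param["k"]]
    v_name = hf_param["v"]
    # Global-attention layers tie K=V and ship no v_proj: synthesize V from K.
    v = hf_state_dict[v_name] if v_name in hf_state_dict else k
    return {"q": q, "k": k, "v": v}
```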

_fuse_router_weight(hf_param: str, hf_state_dict: Mapping[str, torch.Tensor]) → torch.Tensor

Fuse router preprocessing into the router projection weight (VLM version).

_fuse_shared_expert_prenorm(hf_param: dict[str, str], hf_state_dict: Mapping[str, torch.Tensor]) → dict[str, torch.Tensor]

Fuse pre-norm correction into shared expert gate/up weights (VLM version).

mapping_registry() → megatron.bridge.models.conversion.mapping_registry.MegatronMappingRegistry

Define parameter mappings for Gemma 4 VLM.

HF VLM parameter names (raw safetensors keys include the outer model. prefix):

  • model.language_model.layers.* → language model

  • model.vision_tower.* → vision encoder (replicated)

  • model.embed_vision.* → multimodal projector (replicated)
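A mapping registry of this kind can be approximated as a list of wildcard name patterns. The sketch below translates HF names to Megatron names using "*" wildcards and carries a replicated flag for the vision tower and projector. MAPPINGS, to_megatron_name, and every Megatron-side name in it are hypothetical placeholders; the real registry and its mapping classes are not shown here.

```python
import re

# Hypothetical (HF pattern, Megatron pattern, replicated) entries; "*" is a wildcard.
MAPPINGS = [
    (
        "model.language_model.layers.*.self_attn.o_proj.weight",
        "language_model.decoder.layers.*.self_attention.linear_proj.weight",
        False,
    ),
    ("model.vision_tower.*", "vision_tower.*", True),
    ("model.embed_vision.*", "embed_vision.*", True),
]


def _pattern_to_regex(pattern: str) -> str:
    """Escape the literal parts of a pattern and turn each "*" into a capture group."""
    return "(.+)".join(re.escape(part) for part in pattern.split("*"))


def to_megatron_name(hf_name: str) -> tuple[str, bool]:
    """Translate an HF parameter name; return (megatron_name, replicated)."""
    for hf_pattern, mg_pattern, replicated in MAPPINGS:
        match = re.fullmatch(_pattern_to_regex(hf_pattern), hf_name)
        if match is None:
            continue
        mg_name = mg_pattern
        # Substitute captured wildcard text into the Megatron pattern, left to right.
        for captured in match.groups():
            mg_name = mg_name.replace("*", captured, 1)
        return mg_name, replicated
    raise KeyError(f"no mapping for {hf_name}")
```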

_split_qkv_linear_out_weight(megatron_model, linear_out_weight)

Override for Gemma 4 dual attention: detects global vs. sliding layers by tensor size.

Gemma4 interleaves sliding-window and full (global) attention layers with different head configurations:

  • Sliding: kv_channels=256, num_query_groups=num_key_value_heads

  • Global: global_head_dim=512, num_global_key_value_heads=2, K=V tying

For global layers, the linear_qkv LoRA output tensor is larger than the size expected for a sliding layer. We detect this and re-split it using the global head dimensions. For global layers, v_proj is set to ABSENT_PROJECTION because HF global attention has no v_proj weight (K=V tying); the export loop skips it.
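The size-based detection and re-split can be sketched on a list of output rows. This is a minimal illustration, assuming Megatron's grouped fused-QKV layout (for each query group: that group's query heads, then one K head, then one V head) and assuming the global-layer tensor still contains V rows synthesized from K on import. The function and parameter names are hypothetical, and None stands in for ABSENT_PROJECTION.

```python
def split_qkv_rows(rows, num_heads, kv_channels, num_kv_heads,
                   global_head_dim, num_global_kv_heads):
    """Split fused linear_qkv output rows into (q, k, v); v is None for global layers."""
    sliding_size = (num_heads + 2 * num_kv_heads) * kv_channels
    global_size = (num_heads + 2 * num_global_kv_heads) * global_head_dim
    if sliding_size == global_size:
        raise ValueError("sliding and global layouts have the same size; cannot tell them apart")
    if len(rows) == sliding_size:
        head_dim, groups, is_global = kv_channels, num_kv_heads, False
    elif len(rows) == global_size:
        head_dim, groups, is_global = global_head_dim, num_global_kv_heads, True
    else:
        raise ValueError(f"unexpected linear_qkv output size {len(rows)}")

    heads_per_group = num_heads // groups
    group_size = (heads_per_group + 2) * head_dim
    q_end = heads_per_group * head_dim
    q, k, v = [], [], []
    for g in range(groups):
        # Each query group holds its query heads, then one K head, then one V head.
        block = rows[g * group_size:(g + 1) * group_size]
        q.extend(block[:q_end])
        k.extend(block[q_end:q_end + head_dim])
        v.extend(block[q_end + head_dim:])
    # HF global attention ties K=V and has no v_proj, so V is reported as absent.
    return q, k, (None if is_global else v)
```

Size-based detection only works when the two layouts differ in size, which is why the sketch refuses to guess when they coincide.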