bridge.models.gemma_vl.modeling_gemma4_vl

Gemma 4 Vision-Language (VL) model wrapper for Megatron.

Combines a HuggingFace Gemma4 vision tower and multimodal embedder with a Megatron-Core GPT language model (Gemma 4 MoE).

Module Contents

Classes

Gemma4VLModel

Gemma 4 Vision-Language model wrapping an HF vision tower and a Megatron language model.

API

class bridge.models.gemma_vl.modeling_gemma4_vl.Gemma4VLModel(
config: megatron.bridge.models.gpt_provider.GPTModelProvider,
pre_process: bool = True,
post_process: bool = True,
vp_stage: Optional[int] = None,
)

Bases: megatron.core.transformer.module.MegatronModule

Gemma 4 Vision-Language model wrapping an HF vision tower and a Megatron language model.

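The vision tower and multimodal embedder (projector) are HF modules: the vision tower is loaded via AutoModel.from_config, and the embedder is built from the HF Gemma4MultimodalEmbedder class (see _init_embed_vision). The language model is a Megatron-Core GPTModel constructed by the provider.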

Forward flow:

1. Embed text tokens via the language model embedding.
2. If pixel_values is provided: run the vision tower → embed_vision → scatter the image features into the embeddings.
3. Forward through the language model decoder.
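Steps 1 and 2 of the forward flow can be sketched as below. This is a minimal illustration, not the wrapper's code: `IMAGE_TOKEN_ID` and `merge_image_features` are hypothetical names, an `nn.Embedding` stands in for the Megatron embedding, and a tensor of ones stands in for the embed_vision output. The sketch uses a batch-first `[batch, seq, hidden]` layout; Megatron-Core embeddings are typically sequence-first, so real code would need to account for that.

```python
import torch

# Hypothetical placeholder token ID marking image positions in input_ids;
# the real value comes from the Gemma 4 tokenizer/config.
IMAGE_TOKEN_ID = 99


def merge_image_features(
    text_embeds: torch.Tensor,     # [batch, seq, hidden] (batch-first in this sketch)
    input_ids: torch.Tensor,       # [batch, seq]
    image_features: torch.Tensor,  # [num_image_tokens, hidden]
) -> torch.Tensor:
    """Scatter projected image features into the image-token slots (step 2)."""
    mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(text_embeds)
    if mask.sum() != image_features.numel():
        raise ValueError("image token count does not match image feature count")
    # masked_scatter fills the True positions in row-major order from image_features.
    return text_embeds.masked_scatter(mask, image_features.to(text_embeds.dtype))


hidden = 4
input_ids = torch.tensor([[1, 99, 99, 2]])
embed = torch.nn.Embedding(128, hidden)   # stand-in for the LM embedding (step 1)
text_embeds = embed(input_ids)
image_features = torch.ones(2, hidden)    # stand-in for the embed_vision output
merged = merge_image_features(text_embeds, input_ids, image_features)
```

Checking the token count before scattering turns a silent misalignment between placeholder tokens and vision features into an explicit error.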

Initialization

_init_embed_vision(config)

Initialize the multimodal embedder (vision → language projection).

Gemma4’s embed_vision is a parameter-free RMSNorm followed by Linear(vision_hidden, text_hidden). It is constructed from the HF Gemma4MultimodalEmbedder class.
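The structure of embed_vision can be illustrated with a small PyTorch module. This is a hypothetical stand-in (`SketchMultimodalEmbedder`), not the HF Gemma4MultimodalEmbedder class. The epsilon value and `bias=False` are assumptions; the real values come from the HF implementation and config.

```python
import torch
from torch import nn


class SketchMultimodalEmbedder(nn.Module):
    """Hypothetical stand-in: parameter-free RMSNorm, then Linear(vision -> text)."""

    def __init__(self, vision_hidden: int, text_hidden: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps  # assumed epsilon; not taken from the HF config
        # bias=False is an assumption for this sketch
        self.projection = nn.Linear(vision_hidden, text_hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMSNorm without a learnable scale: x / sqrt(mean(x^2) + eps)
        normed = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.projection(normed)


embedder = SketchMultimodalEmbedder(vision_hidden=8, text_hidden=16)
out = embedder(torch.randn(3, 8))
```

Because the norm has no scale parameter, the only trainable weights are in the projection, which is what a freeze_vision_projection-style flag would control.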

set_input_tensor(input_tensor) → None

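Set the input tensor for this model chunk. Megatron-Core pipeline-parallel schedules use this to hand a chunk the hidden states received from the previous pipeline stage.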
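A common pattern for a wrapper's set_input_tensor is sketched below, assuming the tensor is simply forwarded to the wrapped language model. `StubLanguageModel`, `SketchVLWrapper`, and the `language_model` attribute name are hypothetical. Accepting either a tensor or a one-element list mirrors how pipeline schedules may call this hook, but the exact behavior of Gemma4VLModel is not stated here.

```python
from typing import List, Optional, Union

import torch
from torch import nn


class StubLanguageModel(nn.Module):
    """Hypothetical stand-in that records the tensor passed from the previous stage."""

    def __init__(self) -> None:
        super().__init__()
        self.input_tensor: Optional[torch.Tensor] = None

    def set_input_tensor(self, input_tensor: Optional[torch.Tensor]) -> None:
        self.input_tensor = input_tensor


class SketchVLWrapper(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.language_model = StubLanguageModel()

    def set_input_tensor(
        self, input_tensor: Union[torch.Tensor, List[torch.Tensor]]
    ) -> None:
        # Normalize a single tensor to a one-element list.
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]
        if len(input_tensor) != 1:
            raise ValueError("expected exactly one input tensor per model chunk")
        # Assumption: only the language model consumes the previous stage's output.
        self.language_model.set_input_tensor(input_tensor[0])


model = SketchVLWrapper()
h = torch.zeros(5, 2, 4)  # [seq, batch, hidden] hidden states from the previous stage
model.set_input_tensor([h])
```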

get_image_features(pixel_values, image_position_ids=None, **kwargs)

Extract and project image features using HF vision tower + embedder.

Matches HF’s Gemma4Model.get_image_features: vision_tower returns last_hidden_state (already pooled and standardized), and embed_vision then projects it to the language model’s hidden dimension.
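The two-stage contract of get_image_features can be sketched with stand-in modules. `StubVisionTower` is hypothetical and only mimics the output shape of an HF vision model (an object with `.last_hidden_state`). It does not reproduce the real tower's pooling or standardization. A plain `nn.Linear` stands in for embed_vision, and image_position_ids is left out because how it is consumed is not described here.

```python
from types import SimpleNamespace

import torch
from torch import nn


class StubVisionTower(nn.Module):
    """Hypothetical vision tower returning an HF-style output object.

    The real Gemma 4 tower pools and standardizes internally; this stub only
    mimics the output contract (an object exposing .last_hidden_state).
    """

    def __init__(self, patch_dim: int, vision_hidden: int) -> None:
        super().__init__()
        self.proj = nn.Linear(patch_dim, vision_hidden)

    def forward(self, pixel_values: torch.Tensor, **kwargs) -> SimpleNamespace:
        return SimpleNamespace(last_hidden_state=self.proj(pixel_values))


def get_image_features(vision_tower: nn.Module, embed_vision: nn.Module,
                       pixel_values: torch.Tensor, **kwargs) -> torch.Tensor:
    # last_hidden_state is used as-is: no extra pooling or normalization here.
    hidden = vision_tower(pixel_values, **kwargs).last_hidden_state
    # Project vision_hidden -> text_hidden so features match the LM embeddings.
    return embed_vision(hidden)


tower = StubVisionTower(patch_dim=12, vision_hidden=8)
embed_vision = nn.Linear(8, 16)  # stand-in for the multimodal embedder
features = get_image_features(tower, embed_vision, torch.randn(2, 5, 12))
```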

forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
pixel_values: Optional[torch.Tensor] = None,
image_position_ids: Optional[torch.LongTensor] = None,
labels: Optional[torch.Tensor] = None,
runtime_gather_output: Optional[bool] = None,
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
*,
loss_mask: Optional[torch.Tensor] = None,
) → tuple[torch.Tensor, torch.Tensor | None]

Forward pass combining the HF vision encoder with the Megatron language model.

freeze(
freeze_language_model: bool,
freeze_vision_model: bool,
freeze_vision_projection: bool,
)

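Freeze model modules for fine-tuning. Each flag selects whether the parameters of the language model, the vision tower, or the vision projection (embed_vision) are frozen.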
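Freezing is typically done by turning off `requires_grad` on the selected submodules, as in the minimal sketch below. The container class `SketchVL` and its attribute names (`language_model`, `vision_tower`, `embed_vision`) are assumptions for illustration, with small `nn.Linear` layers standing in for the real modules.

```python
import torch
from torch import nn


class SketchVL(nn.Module):
    """Hypothetical container; attribute names are assumptions."""

    def __init__(self) -> None:
        super().__init__()
        self.language_model = nn.Linear(4, 4)
        self.vision_tower = nn.Linear(4, 4)
        self.embed_vision = nn.Linear(4, 4)

    def freeze(self, freeze_language_model: bool, freeze_vision_model: bool,
               freeze_vision_projection: bool) -> None:
        # Collect the submodules selected by the flags.
        modules = []
        if freeze_language_model:
            modules.append(self.language_model)
        if freeze_vision_model:
            modules.append(self.vision_tower)
        if freeze_vision_projection:
            modules.append(self.embed_vision)
        # Frozen parameters receive no gradients, so optimizers leave them unchanged.
        for module in modules:
            for param in module.parameters():
                param.requires_grad = False


model = SketchVL()
model.freeze(freeze_language_model=False, freeze_vision_model=True,
             freeze_vision_projection=False)
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
```

A typical use is to freeze only the vision tower while training the projection and the language model, which is the configuration shown above.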

_compute_attention_mask(
input_ids: torch.Tensor,
) → Optional[torch.Tensor]

Compute attention mask with bidirectional attention for image regions.
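One way to build such a mask is to start from a causal mask and additionally let tokens within the same contiguous run of image tokens attend to each other in both directions. The sketch below assumes exactly that rule, plus a hypothetical `IMAGE_TOKEN_ID`. It returns True where attention is allowed. Megatron-Core attention masks typically use the opposite convention (True = masked out), so `~mask` may be needed. Whether the real method uses this rule or convention is not stated here.

```python
import torch

IMAGE_TOKEN_ID = 99  # hypothetical placeholder ID for image tokens


def compute_attention_mask(input_ids: torch.Tensor) -> torch.Tensor:
    """Return a [batch, seq, seq] bool mask, True where query i may attend key j.

    Causal everywhere, plus full (bidirectional) attention among tokens of the
    same contiguous image span.
    """
    batch, seq = input_ids.shape
    causal = torch.tril(torch.ones(seq, seq, dtype=torch.bool, device=input_ids.device))
    is_image = input_ids == IMAGE_TOKEN_ID  # [batch, seq]
    # A span starts at an image token whose predecessor is not an image token.
    first_col = torch.zeros(batch, 1, dtype=torch.bool, device=input_ids.device)
    prev_is_image = torch.cat([first_col, is_image[:, :-1]], dim=1)
    span_start = is_image & ~prev_is_image
    # Number spans 1, 2, ...; text tokens get 0.
    span_id = torch.cumsum(span_start.long(), dim=1) * is_image.long()
    same_span = (
        (span_id.unsqueeze(2) == span_id.unsqueeze(1))
        & is_image.unsqueeze(2)
        & is_image.unsqueeze(1)
    )
    return causal.unsqueeze(0) | same_span


ids = torch.tensor([[1, 99, 99, 2, 99]])  # text, image span 1, text, image span 2
mask = compute_attention_mask(ids)
```

Tokens of the first image span can see each other (including "future" positions), while text tokens and tokens in a different image span remain strictly causal.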