# VLM Bridge Patterns
Reference implementations:

- Megatron vision encoder: Qwen3.5-VL (`src/megatron/bridge/models/qwen_vl/`)
- HF vision encoder: Gemma3-VL (`src/megatron/bridge/models/gemma_vl/`)
## Provider Pattern

Subclass `GPTModelProvider`. VLM providers add vision-specific fields on top of the standard LLM fields.
```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class MyVLModelProvider(GPTModelProvider):
    # Vision config (passed as an HF config object)
    vision_config: Optional[Any] = None

    # VLM-specific token IDs
    image_token_id: Optional[int] = None
    video_token_id: Optional[int] = None

    # Freeze options
    freeze_language_model: bool = False
    freeze_vision_model: bool = False
    freeze_vision_projection: bool = False

    # Whether to use an HF vision model (vs. Megatron)
    use_hf_vision_model: bool = False

    def provide(self, pre_process=None, post_process=None, vp_stage=None) -> "MyVLModel":
        # Build the language layer spec
        language_transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(...)
        # Build the vision config if needed, then instantiate the combined model
        model = MyVLModel(config=self, ...)
        if self.freeze_language_model or self.freeze_vision_model or self.freeze_vision_projection:
            model.freeze(self.freeze_language_model, self.freeze_vision_model, self.freeze_vision_projection)
        return model

    def provide_language_model(self, pre_process=None, post_process=None, vp_stage=None):
        """Return a language-only model (for text-only inference)."""
        return GPTModel(config=self, ...)

    def validate_parallelism(self):
        if self.num_query_groups < self.tensor_model_parallel_size:
            raise ValueError(
                f"TP ({self.tensor_model_parallel_size}) must be <= "
                f"num_query_groups ({self.num_query_groups})"
            )
```
### Key provider fields by source

Read these from the correct config level:

| Field | Source (VLM) | Notes |
|---|---|---|
| `num_hidden_layers`, `hidden_size`, `intermediate_size` | `text_config` | Core architecture |
| `num_attention_heads`, `num_key_value_heads` | `text_config` | Attention config |
| `vocab_size`, `max_position_embeddings` | `text_config` | Tokenizer/position |
| `rope_theta` | `text_config` | RoPE |
| `tie_word_embeddings` | top-level | CRITICAL: not `text_config` |
| `vision_config` | top-level | Vision encoder config |
| `image_token_id`, `video_token_id` | top-level | Special token IDs |
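A hedged sketch of reading both levels (the checkpoint name is a placeholder; attribute names follow common HF config conventions):

```python
from transformers import AutoConfig

hf_config = AutoConfig.from_pretrained("my-org/my-vl-model")  # placeholder checkpoint
text_config = hf_config.text_config

num_layers = text_config.num_hidden_layers   # text-level: core architecture
tie = hf_config.tie_word_embeddings          # top-level: NOT on text_config
vision_config = hf_config.vision_config      # top-level: vision encoder config
```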
## Bridge Pattern
```python
@MegatronModelBridge.register_bridge(
    source="MyModelForConditionalGeneration",  # HF class name (string if not importable)
    target=MyVLModel,                          # Megatron model class
    provider=MyVLModelProvider,                # Provider class
    model_type="my_model",                     # HF model_type for export
)
class MyVLBridge(MegatronModelBridge):
    def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> MyVLModelProvider:
        hf_config = hf_pretrained.config
        text_config = hf_config.text_config

        # Map the text config to provider kwargs using the base-class helper
        provider_kwargs = self.hf_config_to_provider_kwargs(text_config)
        provider = MyVLModelProvider(**provider_kwargs)

        # CRITICAL: tie_word_embeddings comes from the top-level config
        provider.share_embeddings_and_output_weights = getattr(hf_config, "tie_word_embeddings", False)

        # Vision config
        provider.vision_config = hf_config.vision_config

        # VLM-specific fields from the top-level config
        provider.image_token_id = getattr(hf_config, "image_token_id", None)
        provider.video_token_id = getattr(hf_config, "video_token_id", None)
        return provider

    def mapping_registry(self) -> MegatronMappingRegistry:
        return MegatronMappingRegistry(
            # Language model mappings (prefixed with language_model.*)
            AutoMapping(
                megatron_param="language_model.embedding.word_embeddings.weight",
                hf_param="model.embed_tokens.weight",
            ),
            AutoMapping(
                megatron_param="language_model.output_layer.weight",
                hf_param="model.lm_head.weight",
            ),
            # ... language decoder layers ...
            QKVMapping(
                megatron_param="language_model.decoder.layers.*.self_attention.linear_qkv.weight",
                q="model.language_model.layers.*.self_attn.q_proj.weight",
                k="model.language_model.layers.*.self_attn.k_proj.weight",
                v="model.language_model.layers.*.self_attn.v_proj.weight",
            ),
            # Vision model mappings
            AutoMapping(
                megatron_param="vision_model.patch_embed.proj.**",
                hf_param="model.visual.patch_embed.proj.**",
            ),
            # ... vision layers ...
        )
```
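Once registered, the bridge is dispatched from the HF class name. A sketch of the typical entry point, assuming the `AutoBridge` API from Megatron Bridge and a placeholder checkpoint name:

```python
from megatron.bridge import AutoBridge

# Any HF repo whose architecture is "MyModelForConditionalGeneration"
# should dispatch to MyVLBridge (the checkpoint name is a placeholder).
bridge = AutoBridge.from_hf_pretrained("my-org/my-vl-model")
provider = bridge.to_megatron_provider()
```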
### Import types

```python
from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM  # VLM
from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM  # LLM
```
## VLM Model Class Patterns

### Option A: Megatron Vision Encoder (Qwen3.5 pattern)

Both the vision and language stacks use Megatron modules, with full parallelism support.
```python
class MyVLModel(MegatronModule):
    def __init__(self, config, pre_process=True, post_process=True, ...):
        super().__init__(config=config)
        if pre_process:
            self.vision_model = MyVisionModel(config.vision_config, ...)
        self.language_model = MyGPTModel(config, ...)

    def forward(self, input_ids, pixel_values, image_grid_thw, ...):
        # 1. Vision: pixel_values → vision_embeds
        vision_embeds = self.vision_model(pixel_values, image_grid_thw)
        # 2. Text embeddings
        text_embeds = self.language_model.embedding(input_ids)
        # 3. Scatter vision embeddings into the text sequence at image-token positions
        vision_mask = input_ids == self.config.image_token_id
        combined = text_embeds.clone()
        combined[vision_mask] = vision_embeds
        # 4. Language model forward
        return self.language_model(decoder_input=combined, ...)

    def freeze(self, freeze_language, freeze_vision, freeze_projection):
        if freeze_language:
            for p in self.language_model.parameters():
                p.requires_grad = False
        if freeze_vision:
            for p in self.vision_model.parameters():
                p.requires_grad = False
        # ... projection freeze logic ...
```
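The scatter in step 3 is plain boolean indexing. A toy illustration with made-up shapes (batch-first for readability; Megatron embeddings may be sequence-first in practice):

```python
import torch

# batch=1, seq=5, hidden=4; positions 1-2 hold image placeholder tokens
text_embeds = torch.zeros(1, 5, 4)
vision_embeds = torch.ones(2, 4)        # one row per image-token position
vision_mask = torch.tensor([[False, True, True, False, False]])

combined = text_embeds.clone()
combined[vision_mask] = vision_embeds   # masked rows replaced, in order
assert combined[0, 1].sum() == 4.0
```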
### Option B: HF Vision Encoder (Gemma3 pattern)

HF vision encoder + Megatron projector + Megatron language model. Simpler to implement than Option A.
```python
from transformers import AutoModel


class MyVLModel(MegatronModule):
    def __init__(self, config, pre_process=True, post_process=True, ...):
        super().__init__(config=config)
        if pre_process:
            self.vision_tower = AutoModel.from_config(config.vision_config)
            hook_hf_module_setattr_for_tp_grad_sync(self.vision_tower)
            self.multi_modal_projector = MyProjector(config)
        self.language_model = config.provide_language_model(pre_process, post_process)

    def forward(self, input_ids, pixel_values, ...):
        text_embeds = self.language_model.embedding(input_ids)
        if pixel_values is not None:
            image_features = self.vision_tower(pixel_values).pooler_output
            image_features = self.multi_modal_projector(image_features)
            # Mask of image placeholder positions, broadcast over the hidden dim
            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
            text_embeds.masked_scatter_(special_image_mask, image_features)
        return self.language_model(decoder_input=text_embeds, ...)
```
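`masked_scatter_` fills the masked slots in order from the source tensor, with the mask broadcasting over the hidden dimension. A toy illustration:

```python
import torch

# batch=1, seq=4, hidden=3; one image token at position 1
text_embeds = torch.zeros(1, 4, 3)
image_features = torch.full((1, 3), 7.0)  # one projected image "patch"
mask = torch.tensor([[False, True, False, False]]).unsqueeze(-1)  # (1, 4, 1)

text_embeds.masked_scatter_(mask, image_features)
print(text_embeds[0, 1])  # tensor([7., 7., 7.])
```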
## Weight Mapping Naming Conventions

VLM weight names typically use these prefixes:
| Megatron prefix | HF prefix | Component |
|---|---|---|
| `language_model.decoder.*` | `model.language_model.layers.*` | Text decoder |
| `language_model.embedding.*` | `model.embed_tokens.*` | Text embeddings |
| `language_model.output_layer.*` | `model.lm_head.*` | Output head |
| `vision_model.*` | `model.visual.*` | Vision encoder |

Check the actual HF model’s `state_dict()` keys to determine the exact naming.
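A quick way to inspect them (the checkpoint name is a placeholder, and `AutoModelForImageTextToText` is one possible auto class; use whichever class matches your model):

```python
from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained("my-org/my-vl-model")  # placeholder
for name in sorted(model.state_dict().keys()):
    print(name)
```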
## Common Mapping Types for VLMs

| Mapping Class | Use Case |
|---|---|
| `AutoMapping` | 1:1 name mapping (most weights) |
| `QKVMapping` | Fused Q/K/V projections |
| | Vision QKV (different from language) |
| `GatedMLPMapping` | `gate_proj` + `up_proj` → `linear_fc1` |
| | Weights replicated across TP ranks (e.g. `patch_embed`) |
| | MoE gate+up projections |
| | MoE down projections |
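As an example of the fused-MLP case, a sketch assuming `GatedMLPMapping` takes `gate`/`up` keyword arguments analogous to `QKVMapping`'s `q`/`k`/`v`:

```python
GatedMLPMapping(
    megatron_param="language_model.decoder.layers.*.mlp.linear_fc1.weight",
    gate="model.language_model.layers.*.mlp.gate_proj.weight",
    up="model.language_model.layers.*.mlp.up_proj.weight",
)
```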