# Adding New Model Support in Megatron-Bridge
## Phase 1: Discovery
### Step 1 — Get the HF model link
Ask the user for the HuggingFace model link (e.g. https://huggingface.co/Qwen/Qwen3.5-VL-27B).
If the model is not public, ask the user to provide the config.json file directly.
### Step 2 — Fetch and analyze config.json
Read the model's `config.json` from HuggingFace (or from the user-provided file). Key fields to extract:
- `model_type` — used for `@register_bridge(model_type=...)`
- `architectures` — the HF model class name (used for `source=...` in registration)
- `tie_word_embeddings` — critical for weight tying
- Architecture fields: `num_hidden_layers`, `hidden_size`, `intermediate_size`, `num_attention_heads`, `num_key_value_heads`, `vocab_size`, `max_position_embeddings`, `rope_theta`, etc.
- MoE fields (if present): `num_local_experts`, `num_experts_per_tok`, `moe_intermediate_size`
- MLA fields (if present): `q_lora_rank`, `kv_lora_rank`, `qk_nope_head_dim`, `qk_rope_head_dim`
If there are config fields you don't recognize from previously supported models (check `CONFIG_MAPPING` in `model_bridge.py` and existing bridges), this likely indicates a new architectural block (e.g., a novel attention variant, custom normalization, or a new layer type). Ask the user to provide the HuggingFace `modeling_*.py` implementation of that block so you can understand the computation and create the correct Megatron-side mapping or custom module.
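The unfamiliar-field check above can be sketched as a simple key diff against a known model's config. Both dicts here are abridged, illustrative stand-ins, not real configs, and `linear_attention_freq` is a hypothetical example of a novel field:

```python
# Toy configs: a known reference (Qwen2-style) vs. a new model's config.json.
known_cfg = {"model_type": "qwen2", "num_hidden_layers": 28, "hidden_size": 3584,
             "num_attention_heads": 28, "rope_theta": 1000000.0}
new_cfg = {"model_type": "newmodel", "num_hidden_layers": 32, "hidden_size": 4096,
           "num_attention_heads": 32, "rope_theta": 1000000.0,
           "linear_attention_freq": 4}  # unfamiliar key -> likely a new block

# Keys present in the new config but absent from the reference are the ones
# worth asking the user about (and requesting modeling_*.py for).
unknown_fields = sorted(set(new_cfg) - set(known_cfg))
print(unknown_fields)
```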
### Step 3 — Determine VLM vs LLM
VLM (Vision-Language Model) if config.json contains:

- `text_config` AND `vision_config` sub-configs
- Note: VLMs may or may not have "VL" in the name

LLM (Text-only) if:

- No `text_config`/`vision_config`
- Single flat config for the language model
This distinction affects:

- Which files to create (VLMs need a `model.py` combining vision + language)
- Where to read config fields from (`text_config` vs top-level for VLMs)
- Test patterns (VLMs need vision inputs in functional tests)
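The detection rule above amounts to a presence check on the parsed config, which can be sketched as:

```python
# Minimal sketch of the VLM-vs-LLM check, operating on a parsed config.json dict.
def is_vlm(cfg: dict) -> bool:
    # VLMs carry both sub-configs; text-only LLMs have a single flat config.
    return "text_config" in cfg and "vision_config" in cfg

# Toy configs for illustration (abridged, not real model configs).
vlm_cfg = {"text_config": {"hidden_size": 4096}, "vision_config": {"hidden_size": 1152}}
llm_cfg = {"hidden_size": 4096, "num_hidden_layers": 32}
print(is_vlm(vlm_cfg), is_vlm(llm_cfg))  # True False
```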
### Step 4 — Check for quantized weights (FP8 / FP4)
Inspect the HF checkpoint's `model.safetensors` (or `model.safetensors.index.json`) for quantized weight dtypes such as `float8_e4m3fn` (FP8) or `uint8`/`uint4` with accompanying `*_scale_inv` or `*_scale` tensors. Common signs:

- `config.json` mentions `quantization_config` or dtype fields like `"torch_dtype": "float8_e4m3fn"`
- Safetensors contain `weight_scale_inv` keys alongside the main weight keys
- The model card mentions FP8/FP4/INT4 weights

Why this matters: The bridge's `import_ckpt` path does not automatically dequantize — it loads raw quantized values as-is. This produces a silently broken model (random-level loss, huge grad norms) instead of raising an error.
Fix: Dequantize before conversion. Two approaches:

1. Standalone script (recommended for user-facing models) — Write a `dequant_fp8_for_bridge.py` in the model's examples folder. Reference: `examples/models/vlm/ministral3/dequant_fp8_for_bridge.py`. The pattern is `w_bf16 = fp8_weight.to(torch.bfloat16) * weight_scale_inv`.
2. In-bridge hook — Override `maybe_modify_loaded_hf_weight()` in the bridge class to dequantize on the fly during import:

   ```python
   def maybe_modify_loaded_hf_weight(self, hf_param, hf_state_dict):
       weight = hf_state_dict[hf_param]
       scale_key = hf_param + "_scale_inv"
       if weight.dtype == torch.float8_e4m3fn and scale_key in hf_state_dict:
           return weight.to(torch.bfloat16) * hf_state_dict[scale_key].to(torch.bfloat16)
       return weight
   ```
Always add a sanity check in the verification workflow (e.g., print the std of a weight tensor — quantized models typically have std ≈ 13 before dequantization vs std ≈ 0.006 after).
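The std-based sanity check can be illustrated with a toy simulation (pure Python, no torch; the magnitudes are the illustrative ≈13 vs ≈0.006 from above, not real checkpoint values):

```python
import random
import statistics

# Simulated weight values: raw FP8-coded values have a much larger spread
# than the properly dequantized bf16 weights.
random.seed(0)
raw = [random.gauss(0.0, 13.0) for _ in range(10_000)]  # stands in for raw quantized values
scale_inv = 0.006 / 13.0                                # toy per-tensor inverse scale
dequant = [v * scale_inv for v in raw]                  # after dequantization

print(round(statistics.pstdev(raw), 1))      # ~13: weights were NOT dequantized
print(round(statistics.pstdev(dequant), 4))  # ~0.006: plausible bf16 weight scale
```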
## Phase 2: Bridge Support
### File structure
LLM — Reference: Qwen2 (`src/megatron/bridge/models/qwen/qwen2_bridge.py`)
src/megatron/bridge/models/<model>/
├── __init__.py
├── <model>_bridge.py       # Config + weight mappings
└── <model>_provider.py     # (optional) Only if custom provide() or recipe presets needed
VLM — Reference: Qwen3.5-VL (`src/megatron/bridge/models/qwen_vl/`)
src/megatron/bridge/models/<model>/
├── __init__.py
├── <model>_bridge.py       # Config + weight mappings
├── <model>_provider.py     # Megatron config + model construction
└── modelling_<model>/      # If using Megatron vision encoder
    ├── __init__.py
    └── model.py            # Combines vision + language
OR with HF vision encoder (Reference: Gemma3-VL):
src/megatron/bridge/models/<model>/
├── __init__.py
├── <model>_bridge.py
├── <model>_provider.py
└── modeling_<model>.py     # HF vision + Megatron language wrapper
### Implementation order
LLM:

1. Bridge — Register the bridge, implement `provider_bridge()` and `mapping_registry()`. The bridge calls `super().provider_bridge()` to get a `GPTModelProvider` from `CONFIG_MAPPING`, then sets model-specific attributes on it. No separate provider file is needed for most models.
2. Provider (optional) — Only if the model needs extra dataclass fields for serialization, custom `provide()` logic, or predefined size variants for recipes.
VLM:

1. Provider — VLMs always need a custom provider subclass with a custom `provide()` that instantiates the combined vision+language model.
2. Bridge — Register the bridge with `provider=MyVLModelProvider`. The bridge manually calls `hf_config_to_provider_kwargs(text_config)` and instantiates the custom provider.
3. Model class — Combine the vision encoder and language decoder.
For detailed patterns, see:
VLM: vlm-patterns.md
LLM: llm-patterns.md
### Critical: tie_word_embeddings for VLMs
For VLMs, `tie_word_embeddings` lives on the top-level HF config, NOT on `text_config`. Always read from the parent config:
provider.share_embeddings_and_output_weights = getattr(hf_config, "tie_word_embeddings", False)
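The pitfall can be demonstrated with a toy HF-style config (a `SimpleNamespace` stand-in, not a real `transformers` config object):

```python
from types import SimpleNamespace

# Toy VLM config: tie_word_embeddings sits on the parent, not on text_config.
hf_config = SimpleNamespace(
    tie_word_embeddings=True,
    text_config=SimpleNamespace(num_hidden_layers=32, hidden_size=4096),
)

correct = getattr(hf_config, "tie_word_embeddings", False)            # reads True
wrong = getattr(hf_config.text_config, "tie_word_embeddings", False)  # silently False
print(correct, wrong)  # True False
```

Reading from `text_config` does not raise; the `getattr` fallback silently disables weight tying, which is why the parent-config read matters.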
### Critical: Config field location for VLMs
When reading the HF config for VLMs, check whether each field is in:

- `hf_config` (top-level) — e.g. `tie_word_embeddings`, `image_token_id`, `video_token_id`
- `hf_config.text_config` — e.g. `num_hidden_layers`, `hidden_size`, etc.
- `hf_config.vision_config` — e.g. vision encoder dimensions
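A hypothetical helper (not part of Megatron-Bridge) can make the lookup order explicit and fail loudly instead of guessing a default:

```python
from types import SimpleNamespace

def resolve_cfg_field(hf_config, name):
    # Search order: top-level, then text_config, then vision_config.
    owners = (hf_config,
              getattr(hf_config, "text_config", None),
              getattr(hf_config, "vision_config", None))
    for owner in owners:
        if owner is not None and hasattr(owner, name):
            return getattr(owner, name)
    raise AttributeError(f"{name!r} not found on hf_config or its sub-configs")

# Toy config for illustration; values are placeholders.
cfg = SimpleNamespace(
    image_token_id=151655,
    text_config=SimpleNamespace(hidden_size=4096),
    vision_config=SimpleNamespace(hidden_size=1152),
)
# hidden_size exists on both sub-configs; the search order picks text_config's.
print(resolve_cfg_field(cfg, "image_token_id"), resolve_cfg_field(cfg, "hidden_size"))
```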
### Update FLOPs calculator for new architectural blocks
If the model introduces a new computational block that differs from standard attention or MLP
(e.g., Gated DeltaNet / GDN linear attention, Multi-Token Prediction / MTP heads, Mamba SSM layers),
update the FLOPs calculator in src/megatron/bridge/training/utils/flop_utils.py so that
training throughput metrics (TFLOPs/GPU) are accurate.
When to update: Any time the new block has different FLOPs-per-token than standard self-attention or standard MLP. Common cases:
- Linear attention variants (GDN, RetNet, RWKV) — replace the O(s²) attention term with the block's actual operation count
- MTP / speculative decoding heads — add FLOPs for the extra projection and norm layers
- SSM layers (Mamba) — different recurrence FLOPs than attention
- Novel MoE routing — may change the effective expert count
How to update:
1. Read the existing `transformer_flops()` function in `flop_utils.py` to understand the structure.
2. Add a conditional block gated on a config attribute (e.g., `experimental_attention_variant`, `mtp_num_layers`). Follow the existing MoE pattern for config validation — raise on invalid types, assert list lengths, and use direct attribute access instead of `getattr` with fallback defaults so that misconfigurations fail explicitly.
3. Compute the per-layer FLOPs for the new block and blend it with the standard attention term based on the layer pattern.
4. Add unit tests in `tests/unit_tests/training/utils/test_flop_utils.py` that verify:
   - New-block FLOPs differ from the pure-attention baseline
   - The exact formula matches hand-computed expected values
   - Varying the block ratio (e.g., `linear_attention_freq`) changes FLOPs
Reference PR: #2925 — the GDN FLOPs calculator, adding GDN support with both the calculator code and comprehensive tests.
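The blending step can be sketched as follows; the function name, the `linear_attention_freq` semantics (every Nth layer keeps standard attention), and the per-layer FLOP inputs are simplified placeholders, not the actual `flop_utils.py` implementation:

```python
def blended_attention_flops(num_layers: int, linear_attention_freq: int,
                            std_layer_flops: float, linear_layer_flops: float) -> float:
    """Blend per-layer FLOPs when only every Nth layer uses O(s^2) attention."""
    if linear_attention_freq <= 0:
        raise ValueError("linear_attention_freq must be positive")
    n_std = num_layers // linear_attention_freq   # layers keeping standard attention
    n_linear = num_layers - n_std                 # layers using the linear block
    return n_std * std_layer_flops + n_linear * linear_layer_flops

# Varying the ratio changes the total FLOPs, which is exactly what the
# unit tests described above should assert.
print(blended_attention_flops(48, 4, 1.0e12, 2.5e11))
```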
## Phase 3: Recipe Support
Recipes provide pre-configured training settings for each model size.
LLM recipes: src/megatron/bridge/recipes/<family>/<model>.py
VLM recipes: src/megatron/bridge/recipes/<family>/<model>.py
Each recipe file defines functions for each model size + training mode:
- `<model>_<size>_sft_config()` — Full supervised fine-tuning
- `<model>_<size>_peft_config()` — LoRA/DoRA parameter-efficient fine-tuning
- `<model>_<size>_pretrain_config()` — Pretraining (LLM only, usually)
For detailed recipe patterns, see recipe-patterns.md.
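The size-variant naming pattern can be sketched with a toy dataclass; real recipes return Megatron-Bridge config objects, and every name and value below is an illustrative placeholder:

```python
from dataclasses import dataclass

@dataclass
class RecipeConfig:
    """Toy stand-in for the real training config object."""
    hf_model_id: str
    tensor_model_parallel_size: int
    micro_batch_size: int
    lr: float

def mymodel_8b_sft_config() -> RecipeConfig:
    # One function per (model, size, training mode) combination.
    return RecipeConfig(hf_model_id="my-org/mymodel-8b",
                        tensor_model_parallel_size=2,
                        micro_batch_size=1,
                        lr=2e-5)

print(mymodel_8b_sft_config())
```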
### Export checklist
- Family `__init__.py` — import and add to `__all__`
- Top-level `src/megatron/bridge/recipes/__init__.py` — wildcard import
- `train_any_basic.py` — add to `config_map`, the docstring, and `--model` choices
## Phase 4: Tests
### Unit tests (no GPU)
tests/unit_tests/models/<model>/
├── __init__.py
├── test_<model>_bridge.py     # Mock HF config → verify provider mapping
└── test_<model>_provider.py   # (optional) Only if custom provider subclass exists
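The "mock HF config → verify provider mapping" pattern can be sketched in miniature; the mapping function and field names here are placeholders, not the real bridge API — actual tests exercise `provider_bridge()` against the repo's fixtures:

```python
# Toy mapping under test: HF config keys -> provider kwargs.
def hf_to_provider_kwargs(cfg: dict) -> dict:
    return {
        "num_layers": cfg["num_hidden_layers"],
        "hidden_size": cfg["hidden_size"],
        "num_attention_heads": cfg["num_attention_heads"],
    }

def test_provider_mapping():
    # No weights downloaded: a tiny mock config is enough to check the mapping.
    mock_cfg = {"num_hidden_layers": 2, "hidden_size": 64, "num_attention_heads": 4}
    kwargs = hf_to_provider_kwargs(mock_cfg)
    assert kwargs == {"num_layers": 2, "hidden_size": 64, "num_attention_heads": 4}

test_provider_mapping()
print("ok")
```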
### Functional tests (GPU)
tests/functional_tests/models/<model>/
├── __init__.py
├── test_<model>_conversion.py   # Toy model HF→Megatron roundtrip
└── test_<model>_provider.py     # compare_provider_configs (optional)
For detailed test patterns, see tests-and-examples.md.
## Phase 5: Docs and Examples
### Examples
LLM examples: examples/models/<model>/
VLM examples: examples/models/vlm/<model>/
examples/models/<model>/ # LLM
examples/models/vlm/<model>/ # VLM
├── README.md
├── conversion.sh     # HF→Megatron conversion commands (real model)
├── inference.sh      # Generation commands (real model, reasonable output)
├── slurm_sft.sh      # SFT training on SLURM
└── slurm_peft.sh     # PEFT training on SLURM
Key deliverable requirement: `conversion.sh` and `inference.sh` must target a real published model (e.g. Qwen/Qwen3-8B, not a toy). The inference script must produce reasonable output — for LLMs a coherent text continuation, for VLMs a plausible image description. This is the acceptance bar: conversion runs cleanly and generation makes sense.
### Documentation
Add a model page at docs/models/<type>/<model>.md covering:
Supported variants and sizes
Conversion commands
Training examples (SFT, PEFT)
Known limitations
## Verification Workflow
After implementing bridge support, prompt the user to run these commands on the cluster:
### 1. Smoke test (single GPU)
uv run python -c "
from megatron.bridge import AutoBridge
bridge = AutoBridge.from_hf_pretrained('<org>/<model>')
provider = bridge.to_megatron_provider()
provider.tensor_model_parallel_size = 1
provider.pipeline_model_parallel_size = 1
provider.finalize()
model = provider.provide_distributed_model(wrap_with_ddp=False)
bridge.load_hf_weights(model)
for i, (name, tensor) in enumerate(bridge.export_hf_weights(model, cpu=True)):
print(name, tuple(tensor.shape))
if i > 10: break
"
### 2. Conversion roundtrip (multi-GPU)
uv run python examples/conversion/convert_checkpoints.py import \
--hf-model <org>/<model> \
--megatron-path /workspace/<model> \
--torch-dtype bfloat16
uv run python examples/conversion/convert_checkpoints.py export \
--hf-model <org>/<model> \
--megatron-path /workspace/<model>/iter_0000000 \
--hf-path /workspace/<model>-hf-export
### 3. Generation test
For LLMs:
uv run python examples/conversion/hf_to_megatron_generate_text.py \
--hf_model_path <org>/<model> --prompt "Hello"
For VLMs:
uv run python examples/conversion/hf_to_megatron_generate_vlm.py \
--hf_model_path <org>/<model> \
--image_path "https://example.com/image.jpeg" \
--prompt "Describe this image."
### 4. Run tests
uv run python -m pytest tests/unit_tests/models/<model>/ -v
uv run python -m pytest tests/functional_tests/models/<model>/ -v --run-gpu
## Quick Decision Tree
User wants to add a model
│
├─ Has HF link? ── No ──> Ask for link (or config.json if private)
│
├─ Has text_config + vision_config? ── Yes ──> VLM path
│     ├─ Has Megatron vision encoder? ──> Megatron encoder (Qwen3.5 pattern)
│     └─ No Megatron encoder ──> HF encoder (Gemma3 pattern)
│
└─ No vision config ──> LLM path (Qwen2 / GPT-OSS pattern)
      ├─ Standard GPT-style? ──> Bridge only (no provider subclass needed)
      └─ Custom components? ──> Bridge + custom provider or modeling module