bridge.models.conversion.model_bridge#
Module Contents#
Classes#

| Class | Description |
|---|---|
| MegatronWeightTuple | Tuple representing a Megatron model weight with its metadata. |
| HFWeightTuple | Tuple representing a HuggingFace model weight with its metadata. |
| WeightConversionTask | A unified task for converting weights between HuggingFace and Megatron formats. |
| MegatronModelBridge | High-level orchestrator for HuggingFace ↔ Megatron model conversions. |
Functions#

| Function | Description |
|---|---|
| _megatron_local_name_to_global | Adjust layer number and expert number from local to global numbering. |
| is_tensor_parallel | Check if a parameter is tensor parallel distributed. |
| get_model_bridge | Get the appropriate model bridge for a given HuggingFace architecture. |
| stream_weights_megatron_to_hf | Bridge Megatron model state to HuggingFace format. |
| register_bridge_implementation | Register a bridge implementation with the dispatch system. |
| create_bridge_decorator | Create a decorator for registering bridge implementations. |
Data#
API#
- bridge.models.conversion.model_bridge.logger#
‘getLogger(…)’
- bridge.models.conversion.model_bridge.MappingT#
‘TypeVar(…)’
- bridge.models.conversion.model_bridge.HFPreTrained#
‘TypeVar(…)’
- bridge.models.conversion.model_bridge.ModelProviderTarget#
‘TypeVar(…)’
- bridge.models.conversion.model_bridge.MegatronModel#
‘TypeVar(…)’
- bridge.models.conversion.model_bridge._BridgeImplClass#
‘TypeVar(…)’
- class bridge.models.conversion.model_bridge.MegatronWeightTuple#
Bases:
typing.NamedTuple
Tuple representing a Megatron model weight with its metadata.
- param_name: str#
- weight: torch.Tensor#
- vp_stage: int#
- class bridge.models.conversion.model_bridge.HFWeightTuple#
Bases:
typing.NamedTuple
Tuple representing a HuggingFace model weight with its metadata.
- param_name: str#
- weight: torch.Tensor#
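For illustration, both classes behave like any other typing.NamedTuple: fields can be read by attribute or unpacked positionally. A minimal sketch (the parameter names and tensors below are placeholders, and the tuples are assumed to be imported from this module):
.. code-block:: python

   import torch

   # Placeholder weight tuples (names and shapes are illustrative only).
   m_tuple = MegatronWeightTuple(
       param_name="decoder.layers.0.mlp.linear_fc1.weight",
       weight=torch.zeros(8, 4),
       vp_stage=0,
   )
   hf_tuple = HFWeightTuple(param_name="model.layers.0.mlp.gate_proj.weight", weight=torch.zeros(8, 4))

   # Access by attribute or unpack positionally.
   name, weight, vp_stage = m_tuple
   print(hf_tuple.param_name, tuple(hf_tuple.weight.shape))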
- class bridge.models.conversion.model_bridge.WeightConversionTask#
Bases:
typing.Generic[bridge.models.conversion.model_bridge.MappingT]
A unified task for converting weights between HuggingFace and Megatron formats.
This class combines both HF->Megatron and Megatron->HF conversion tasks since they have different method names (hf_to_megatron vs megatron_to_hf) and can coexist safely.
The task encapsulates all information needed for weight conversion in either direction, with different fields being relevant depending on the conversion type.
.. attribute:: param_name
Unwrapped, local parameter name (no module. prefixes).
- Type:
str
.. attribute:: mapping
Concrete MegatronParamMapping instance responsible for weight transformation and distribution.
- Type:
MappingT
.. attribute:: pp_rank
Pipeline-parallel rank that owns the parameter (required for saves).
- Type:
Optional[int]
.. attribute:: vp_stage
Virtual-pipeline stage index (required for loads).
- Type:
Optional[int]
.. attribute:: megatron_module
Reference to the Megatron model or sub-module that owns the parameter (required for loads).
- Type:
Optional[torch.nn.Module]
.. attribute:: param_weight
The actual parameter tensor that will receive the converted weight (required for loads).
- Type:
Optional[torch.Tensor]
- param_name: str#
- mapping: bridge.models.conversion.model_bridge.MappingT#
- pp_rank: Optional[int]#
- vp_stage: Optional[int]#
- megatron_module: Optional[torch.nn.Module]#
- param_weight: Optional[torch.Tensor]#
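As a rough sketch of how these fields fit together, a load-direction (HF → Megatron) task could look like the following. This assumes the class can be constructed directly from the fields listed above; the mapping, model, and tensor names are placeholders, since tasks are normally produced by build_conversion_tasks rather than written by hand:
.. code-block:: python

   # Hypothetical construction of a load-direction task (normally produced by
   # MegatronModelBridge.build_conversion_tasks, not created manually).
   task = WeightConversionTask(
       param_name="decoder.layers.0.self_attention.linear_qkv.weight",  # local, unwrapped name
       mapping=qkv_mapping,                # a concrete MegatronParamMapping instance (placeholder)
       vp_stage=0,                         # virtual-pipeline stage that owns the parameter (loads)
       megatron_module=megatron_model[0],  # module owning the parameter (loads)
       param_weight=target_param,          # tensor that receives the converted weight (loads)
   )
   # pp_rank is only relevant for the save (Megatron -> HF) direction.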
- bridge.models.conversion.model_bridge._megatron_local_name_to_global(
- models: megatron.core.transformer.module.MegatronModule | List[megatron.core.transformer.module.MegatronModule],
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- param_name: str,
- vp_stage: Optional[int] = None,
Adjust layer number and expert number from local to global numbering.
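To make the local → global renumbering concrete, here is a hypothetical illustration; the exact offsets depend on the pipeline-parallel and expert-parallel configuration, so treat the numbers as an example of the idea rather than the function's actual behavior:
.. code-block:: python

   # Hypothetical illustration: 32 decoder layers split over 4 pipeline-parallel
   # stages (8 local layers per stage), viewed from pp_rank 2.
   local_name = "decoder.layers.0.mlp.linear_fc1.weight"  # local layer index 0
   layers_per_stage = 32 // 4
   pp_rank = 2
   global_layer = pp_rank * layers_per_stage + 0  # local layer 0 -> global layer 16
   global_name = local_name.replace("layers.0.", f"layers.{global_layer}.")
   print(global_name)  # decoder.layers.16.mlp.linear_fc1.weight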
- class bridge.models.conversion.model_bridge.MegatronModelBridge#
Bases:
typing.Generic[bridge.models.conversion.model_bridge.HFPreTrained, bridge.models.conversion.model_bridge.ModelProviderTarget, bridge.models.conversion.model_bridge.MegatronModel]
High-level orchestrator for HuggingFace ↔ Megatron model conversions.
This abstract base class provides the framework for converting models between HuggingFace and Megatron formats. It acts as an orchestrator that coordinates the conversion process without directly handling the complex details of tensor parallelism or weight transformations.
The bridge pattern separates concerns:
- MegatronModelBridge: Orchestrates the overall conversion process
- MegatronMappingRegistry: Manages parameter name mappings
- MegatronParamMapping: Handles actual weight transformations and distribution
Key responsibilities:
- Build conversion tasks that map each parameter to its appropriate bridge
- Execute tasks with proper error handling and progress tracking
- Provide utilities for configuration translation
- Handle virtual pipeline parallelism (VP) complexities
To implement a bridge for a new model architecture:
1. Create a subclass decorated with @MegatronModelBridge.register_bridge:
.. code-block:: python
   @MegatronModelBridge.register_bridge(source=LlamaForCausalLM, target=GPTModel)
   class MegatronCausalLlamaBridge(MegatronModelBridge):
       pass
2. Implement provider_bridge to create Megatron configurations:
.. code-block:: python
   def provider_bridge(self, hf_pretrained) -> LlamaModelProvider:
       return LlamaModelProvider(
           num_layers=hf_pretrained.config.num_hidden_layers,
           hidden_size=hf_pretrained.config.hidden_size,
           ...
       )
3. Implement mapping_registry to define weight mappings:
.. code-block:: python
   def mapping_registry(self) -> MegatronMappingRegistry:
       return MegatronMappingRegistry(
           AutoMapping(
               megatron_param="embedding.word_embeddings.weight",
               hf_param="model.embed_tokens.weight"
           ),
           ...
       )
.. rubric:: Example
.. code-block:: python
   # The bridge is typically not instantiated directly.
   # Instead, use AutoBridge, which handles this.
   bridge = AutoBridge.from_hf_pretrained("meta-llama/Llama-3-8B")
   provider = bridge.to_megatron_provider()
.. note::
This class uses generic type parameters to ensure type safety:
- HFPreTrained: The HuggingFace model type
- ModelProviderTarget: The Megatron model provider type
- MegatronModel: The Megatron model type
- abstractmethod provider_bridge(
- hf_pretrained: bridge.models.conversion.model_bridge.HFPreTrained,
Create a Megatron model provider from HuggingFace configuration.
This abstract method must be implemented by subclasses to translate HuggingFace model configurations into Megatron model provider instances. The provider contains all necessary configuration for creating Megatron models.
- Parameters:
hf_pretrained (HFPreTrained) – HuggingFace model or configuration containing the source model’s architecture details.
- Returns:
A configured model provider instance (e.g., GPTModelProvider, LlamaModelProvider) ready to create Megatron models.
- Return type:
ModelProviderTarget
.. rubric:: Example
.. code-block:: python
   def provider_bridge(self, hf_pretrained):
       return LlamaModelProvider(
           num_layers=hf_pretrained.config.num_hidden_layers,
           hidden_size=hf_pretrained.config.hidden_size,
           num_attention_heads=hf_pretrained.config.num_attention_heads,
           ffn_hidden_size=hf_pretrained.config.intermediate_size,
           # ... other configuration mappings
       )
- abstractmethod mapping_registry() → megatron.bridge.models.conversion.mapping_registry.MegatronMappingRegistry#
Define weight mappings between HuggingFace and Megatron formats.
This abstract method must be implemented by subclasses to specify how parameters map between the two formats. The returned MegatronMappingRegistry contains all param mappings needed for the model architecture.
- Returns:
MegatronMappingRegistry containing all weight mapping definitions.
- Return type:
MegatronMappingRegistry
.. rubric:: Example
.. code-block:: python
   def mapping_registry(self):
       return MegatronMappingRegistry(
           AutoMapping(
               megatron_param="embedding.word_embeddings.weight",
               hf_param="model.embed_tokens.weight"
           ),
           QKVMapping(
               megatron_param="decoder.layers.*.self_attention.linear_qkv.weight",
               q="model.layers.*.self_attn.q_proj.weight",
               k="model.layers.*.self_attn.k_proj.weight",
               v="model.layers.*.self_attn.v_proj.weight"
           ),
           # ... more param mappings
       )
- _megatron_global_param_names_all_pp_ranks(
- megatron_model: Union[bridge.models.conversion.model_bridge.MegatronModel, List[bridge.models.conversion.model_bridge.MegatronModel]],
Get all parameter names across all pipeline parallel ranks.
- _with_progress_tracking(
- tasks,
- description: str,
- show_progress: bool = True,
Helper method to wrap an iterable with progress tracking.
- Parameters:
tasks – Iterable of tasks to process
description – Description for the progress bar
show_progress – Whether to show progress (defaults to True)
- Yields:
Items from the tasks iterable while updating progress
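A minimal sketch of how this internal helper might be used inside the bridge; the task iterable and the per-task work shown here are placeholders:
.. code-block:: python

   # Hypothetical internal usage: iterate conversion tasks while a progress
   # bar is updated (shown on rank 0 only).
   for task in self._with_progress_tracking(tasks, description="Loading HF weights"):
       process(task)  # placeholder for the per-task conversion work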
- load_weights_hf_to_megatron(
- hf_pretrained: bridge.models.conversion.model_bridge.HFPreTrained,
- megatron_model: Union[bridge.models.conversion.model_bridge.MegatronModel, List[bridge.models.conversion.model_bridge.MegatronModel]],
Load HuggingFace weights into Megatron models.
This method orchestrates the complete weight loading process from HuggingFace format to Megatron’s distributed format. It builds a conversion task and executes it with proper progress tracking and error handling.
The actual weight transformations and distribution are delegated to the appropriate MegatronParamMapping instances based on the state mappings.
- Parameters:
hf_pretrained (HFPreTrained) – HuggingFace model or state source containing the weights to load.
megatron_model (Union[MegatronModel, List[MegatronModel]]) – Megatron model instance or list of model instances (one per virtual pipeline stage).
- Returns:
The input megatron_model as a list with loaded weights.
- Return type:
List[MegatronModel]
Process:
1. Build a task mapping each Megatron parameter to its source.
2. For each parameter in the task:
   - Fetch source weights from HuggingFace state
   - Apply format transformation via the param mapping
   - Distribute to appropriate TP/PP ranks
   - Copy into the Megatron parameter
.. rubric:: Example
.. code-block:: python
   hf_model = PreTrainedCausalLM.from_pretrained("gpt2")
   megatron_model = create_megatron_model()  # Single model or list
   bridge.load_weights_hf_to_megatron(hf_model, megatron_model)
.. note::
Progress is shown only on rank 0 to avoid cluttered output in distributed environments.
- Raises:
ValueError – If hf_pretrained doesn’t have a state attribute or if weight shapes don’t match.
AttributeError – If required HF weights are missing.
- stream_weights_hf_to_megatron(
- hf_pretrained: bridge.models.conversion.model_bridge.HFPreTrained,
- megatron_model: Union[bridge.models.conversion.model_bridge.MegatronModel, List[bridge.models.conversion.model_bridge.MegatronModel]],
- conversion_tasks: Optional[List[bridge.models.conversion.model_bridge.WeightConversionTask]] = None,
Generator variant of load_weights_hf_to_megatron for streaming weight conversion.
This method provides a memory-efficient way to convert weights by yielding them one at a time instead of loading all at once. Useful for processing very large models or when implementing custom weight handling logic.
- Parameters:
hf_pretrained (HFPreTrained) – HuggingFace model or state source containing the weights.
megatron_model (Union[MegatronModel, List[MegatronModel]]) – Megatron model instance or list of model instances to extract configuration from.
conversion_tasks (Optional[List[WeightConversionTask]]) – Pre-built conversion tasks. If not provided, tasks will be built automatically from the models.
- Yields:
MegatronWeightTuple –
Named tuples containing:
- vp_stage: Index of the model in the megatron_model list
- param_name: Name of the parameter
- weight: Transformed weight tensor for this rank
.. rubric:: Example
.. code-block:: python
   # Process weights one by one
   for weight_tuple in bridge.stream_weights_hf_to_megatron(hf_model, megatron_model):
       print(f"Processing {weight_tuple.param_name}: {weight_tuple.weight.shape}")
       # Custom processing logic here

   # Or use pre-built conversion tasks
   tasks = bridge.build_conversion_tasks(hf_model, megatron_model)
   for weight_tuple in bridge.stream_weights_hf_to_megatron(hf_model, megatron_model, tasks):
       print(f"Processing {weight_tuple.param_name}: {weight_tuple.weight.shape}")
.. note:: Only yields weights that belong to the current rank after TP/PP distribution.
- Raises:
ValueError – If input parameters are invalid.
- stream_weights_megatron_to_hf(
- megatron_model: Union[bridge.models.conversion.model_bridge.MegatronModel, List[bridge.models.conversion.model_bridge.MegatronModel]],
- hf_pretrained: bridge.models.conversion.model_bridge.HFPreTrained,
- cpu: bool = True,
- show_progress: bool = True,
- conversion_tasks: Optional[List[bridge.models.conversion.model_bridge.WeightConversionTask]] = None,
Export Megatron weights to HuggingFace format.
This method orchestrates the conversion of weights from Megatron’s distributed format back to HuggingFace format. It handles gathering from tensor parallel ranks, broadcasting across pipeline parallel ranks, and format conversions. All ranks receive the full tensors.
The export order is determined automatically:
1. First tries safetensors order (if key_to_filename_map is available)
2. Falls back to HuggingFace state dict order
- Parameters:
megatron_model (Union[MegatronModel, List[MegatronModel]]) – Megatron model instance or list of model instances (one per virtual pipeline stage).
hf_pretrained (HFPreTrained) – HuggingFace model/config for metadata and mapping info.
cpu (bool, optional) – Whether to move tensors to CPU before yielding. Defaults to True.
show_progress (bool, optional) – Display progress bar during export. Defaults to True.
conversion_tasks (Optional[List[WeightConversionTask]]) – Pre-built conversion tasks. If not provided, tasks will be built automatically from the models.
- Yields:
HFWeightTuple – Named tuples of (param_name, weight_tensor) in HF format.
.. rubric:: Example
.. code-block:: python
   # Export weights
   for name, weight in bridge.stream_weights_megatron_to_hf(megatron_model, hf_config):
       print(f"Exported {name}: {weight.shape}")

   # Or use pre-built conversion tasks
   tasks = bridge.build_conversion_tasks(hf_config, megatron_model)
   for name, weight in bridge.stream_weights_megatron_to_hf(
       megatron_model, hf_config, conversion_tasks=tasks
   ):
       print(f"Exported {name}: {weight.shape}")
- Raises:
ValueError – If input parameters are invalid.
.. note:: All ranks yield the full tensors after gathering from distributed format.
- dtype_from_hf(config, default=None)#
Extract torch dtype from a HuggingFace config.
This utility method handles the conversion of dtype specifications in HuggingFace configs to PyTorch dtype objects. Supports both direct torch.dtype objects and string representations.
- Parameters:
config – HuggingFace configuration object with a torch_dtype attribute.
default (Any, optional) – Default value to return if torch_dtype is not str or torch.dtype. Defaults to None.
- Returns:
The corresponding PyTorch dtype.
- Return type:
torch.dtype
- Raises:
AssertionError – If config doesn’t have torch_dtype attribute.
ValueError – If torch_dtype is neither a string nor torch.dtype.
.. rubric:: Example
.. code-block:: python
   dtype = bridge.dtype_from_hf(hf_config)
   print(dtype)  # torch.float16
- dtype_from_str(dtype: str) → torch.dtype#
Convert a string precision identifier to equivalent torch dtype.
This utility method handles various string representations of PyTorch data types, including common abbreviations and mixed precision formats.
- Parameters:
dtype (str) – String representation of dtype (e.g., “float16”, “fp16”, “bf16-mixed”).
- Returns:
Corresponding PyTorch dtype (defaults to float32 if unknown).
- Return type:
torch.dtype
Supported formats:
- float16/fp16/16/16-mixed → torch.float16
- bfloat16/bf16-mixed → torch.bfloat16
- Others → torch.float32 (default)
.. rubric:: Example
.. code-block:: python
   dtype = bridge.dtype_from_str("fp16")
   print(dtype)  # torch.float16

   dtype = bridge.dtype_from_str("bf16-mixed")
   print(dtype)  # torch.bfloat16
- make_vocab_size_divisible_by(vocab_size: int) → int#
Calculate an appropriate divisor for vocabulary size padding.
Megatron requires vocabulary sizes to be divisible by certain values for efficient tensor parallelism. This method finds the largest power of 2 (up to 128) that evenly divides the vocabulary size.
- Parameters:
vocab_size (int) – Original vocabulary size from the model.
- Returns:
Largest power of 2 (≤ 128) that divides vocab_size.
- Return type:
int
.. rubric:: Example
.. code-block:: python
   # For vocab_size=50257 (GPT-2)
   divisor = bridge.make_vocab_size_divisible_by(50257)
   print(divisor)  # 1 (50257 is odd, so no larger power of 2 divides it)

   # For vocab_size=32000 (Llama)
   divisor = bridge.make_vocab_size_divisible_by(32000)
   print(divisor)  # 128
.. note::
The returned value is used by Megatron to potentially pad the vocabulary to ensure efficient parallelization.
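A minimal sketch of how such a helper could work, assuming it simply walks down the powers of 2 from 128; this illustrates the documented behavior and is not the actual implementation:
.. code-block:: python

   def make_vocab_size_divisible_by(vocab_size: int) -> int:
       # Return the largest power of 2 (up to 128) that evenly divides vocab_size.
       for divisor in (128, 64, 32, 16, 8, 4, 2, 1):
           if vocab_size % divisor == 0:
               return divisor
       return 1  # unreachable: every integer is divisible by 1

   print(make_vocab_size_divisible_by(32000))  # 128
   print(make_vocab_size_divisible_by(50257))  # 1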
- _get_provider_from_model(
- model: megatron.core.transformer.module.MegatronModule,
Extract provider/config from model.
- _unwrap_name(name: str) → str#
Unwrap name from DDP or other wrappers.
- Parameters:
name – Parameter name that may have ‘module.’ prefixes
- Returns:
Unwrapped parameter name with ‘module.’ prefixes removed
.. rubric:: Example
‘module.module.decoder.weight’ -> ‘decoder.weight’
- megatron_model: Union[bridge.models.conversion.model_bridge.MegatronModel, List[bridge.models.conversion.model_bridge.MegatronModel]],
Broadcast shared embeddings and output weights across embedding group.
When embeddings and output weights are shared and pipeline parallelism is enabled, this method ensures all ranks in the embedding group have the same weights by broadcasting from rank 0.
- Parameters:
megatron_model – Megatron model instance or list of model instances.
- build_conversion_tasks(
- hf_pretrained: bridge.models.conversion.model_bridge.HFPreTrained,
- megatron_model: List[bridge.models.conversion.model_bridge.MegatronModel],
Construct the conversion tasks between HF and Megatron.
The algorithm walks over every parameter of every destination model, asks the MegatronMappingRegistry whether it has a mapping for that parameter, and – if the corresponding HF weights actually exist – yields an _HFLoadTask describing exactly how that parameter will be populated.
- classmethod register_bridge(
- *,
- source: Type[transformers.modeling_utils.PreTrainedModel],
- target: Type[bridge.models.conversion.model_bridge.MegatronModel],
Class decorator for registering bridge implementations.
This decorator registers a MegatronModelBridge subclass with the dispatch system, enabling automatic routing of conversions based on the source HuggingFace model type and target Megatron model type.
- Parameters:
source (Type[PreTrainedModel]) – HuggingFace PreTrainedModel class (e.g., LlamaForCausalLM).
target (Type[MegatronModel]) – Megatron model class (e.g., GPTModel).
- Returns:
Decorator function that registers the bridge implementation.
- Return type:
Callable[[_BridgeImplClass], _BridgeImplClass]
.. rubric:: Example
.. code-block:: python
   @MegatronModelBridge.register_bridge(source=LlamaForCausalLM, target=GPTModel)
   class MegatronCausalLlamaBridge(MegatronModelBridge):
       def provider_bridge(self, hf_pretrained):
           # Implementation
           pass

       def mapping_registry(self):
           # Implementation
           pass
.. note::
The decorated class is registered with multiple dispatchers to handle different conversion scenarios. The registration is automatic when the class is defined.
- bridge.models.conversion.model_bridge.is_tensor_parallel(param) → bool#
Check if a parameter is tensor parallel distributed.
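A small usage sketch, assuming a Megatron model object is available; exactly which attribute marks a tensor-parallel parameter is an implementation detail (Megatron-Core conventionally sets a tensor_model_parallel flag on such parameters):
.. code-block:: python

   # Illustrative only: count tensor-parallel parameters in a Megatron model.
   tp_params = [
       name
       for name, param in megatron_model.named_parameters()
       if is_tensor_parallel(param)
   ]
   print(f"{len(tp_params)} tensor-parallel parameters")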
- bridge.models.conversion.model_bridge.get_model_bridge(
- hf_architecture,
Get the appropriate model bridge for a given HuggingFace architecture.
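An illustrative call, assuming a bridge has been registered for the architecture; the argument form shown (passing the HF model class) is an assumption, as the function may instead accept an architecture name string or a loaded model:
.. code-block:: python

   from transformers import LlamaForCausalLM

   # Look up the bridge registered for this HF architecture (argument form assumed).
   bridge = get_model_bridge(LlamaForCausalLM)
   provider = bridge.provider_bridge(hf_pretrained)  # hf_pretrained assumed loaded elsewhere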
- bridge.models.conversion.model_bridge.stream_weights_megatron_to_hf(
- dispatch_instance: bridge.models.conversion.model_bridge.MegatronModel,
- megatron_model: Union[bridge.models.conversion.model_bridge.MegatronModel, List[bridge.models.conversion.model_bridge.MegatronModel]],
- hf_pretrained: bridge.models.conversion.model_bridge.HFPreTrained,
- cpu: bool = True,
- show_progress: bool = True,
- conversion_tasks: Optional[List[bridge.models.conversion.model_bridge.WeightConversionTask]] = None,
Bridge Megatron model state to HuggingFace format.
- bridge.models.conversion.model_bridge.register_bridge_implementation(
- *,
- source: Type[transformers.modeling_utils.PreTrainedModel],
- target: Type[megatron.core.transformer.module.MegatronModule],
- bridge_class: Type[bridge.models.conversion.model_bridge.MegatronModelBridge],
Register a bridge implementation with the dispatch system.
- Parameters:
source – HuggingFace PreTrainedModel class (e.g., LlamaForCausalLM)
target – Megatron model class (e.g., GPTModel)
bridge_class – MegatronModelBridge implementation class
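An illustrative direct registration call, mirroring what the MegatronModelBridge.register_bridge decorator does internally; the bridge class named below is hypothetical:
.. code-block:: python

   from transformers import LlamaForCausalLM

   # Hypothetical direct registration with the dispatch system.
   register_bridge_implementation(
       source=LlamaForCausalLM,
       target=GPTModel,
       bridge_class=MegatronCausalLlamaBridge,
   )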
- bridge.models.conversion.model_bridge.create_bridge_decorator(
- *,
- source: Type[transformers.modeling_utils.PreTrainedModel],
- target: Type[megatron.core.transformer.module.MegatronModule],
Create a decorator for registering bridge implementations.
- Parameters:
source – HuggingFace PreTrainedModel class
target – Megatron model class
- Returns:
Decorator function that registers the bridge implementation
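For illustration, the returned decorator can be built once and then applied to a bridge class; the bridge class below is hypothetical:
.. code-block:: python

   # Build a registration decorator, then apply it to a bridge class.
   register_llama = create_bridge_decorator(source=LlamaForCausalLM, target=GPTModel)

   @register_llama
   class MegatronCausalLlamaBridge(MegatronModelBridge):
       def provider_bridge(self, hf_pretrained): ...
       def mapping_registry(self): ...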