bridge.models.conversion.model_bridge

Module Contents

Classes

MegatronWeightTuple

Tuple representing a Megatron model weight with its metadata.

HFWeightTuple

Tuple representing a HuggingFace model weight with its metadata.

WeightConversionTask

A unified task for converting weights between HuggingFace and Megatron formats.

MegatronModelBridge

High-level orchestrator for HuggingFace ↔ Megatron model conversions.

Functions

_megatron_local_name_to_global

Adjust layer number and expert number from local to global numbering.

is_tensor_parallel

Check if a parameter is tensor parallel distributed.

get_model_bridge

Get the appropriate model bridge for a given HuggingFace architecture.

stream_weights_megatron_to_hf

Bridge Megatron model state to HuggingFace format.

register_bridge_implementation

Register a bridge implementation with the dispatch system.

create_bridge_decorator

Create a decorator for registering bridge implementations.

Data

API

bridge.models.conversion.model_bridge.logger

‘getLogger(…)’

bridge.models.conversion.model_bridge.MappingT

‘TypeVar(…)’

bridge.models.conversion.model_bridge.HFPreTrained

‘TypeVar(…)’

bridge.models.conversion.model_bridge.ModelProviderTarget

‘TypeVar(…)’

bridge.models.conversion.model_bridge.MegatronModel

‘TypeVar(…)’

bridge.models.conversion.model_bridge._BridgeImplClass

‘TypeVar(…)’

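class bridge.models.conversion.model_bridge.MegatronWeightTuple

Bases: typing.NamedTuple

Tuple representing a Megatron model weight with its metadata.

param_name: str

Name of the parameter.

weight: torch.Tensor

Transformed weight tensor for this rank.

vp_stage: int

Index of the model in the megatron_model list (virtual pipeline stage).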

class bridge.models.conversion.model_bridge.HFWeightTuple

Bases: typing.NamedTuple

Tuple representing a HuggingFace model weight with its metadata.

param_name: str

Parameter name in HuggingFace format.

weight: torch.Tensor

Weight tensor in HuggingFace format.
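Because both tuples subclass typing.NamedTuple, their field order fixes how they unpack positionally, which is what makes loops such as `for name, weight in ...` work on exported items. The sketch below re-declares the documented fields for illustration only, with `Any` standing in for torch.Tensor so it runs without PyTorch; it is not the module's source.

```python
from typing import Any, NamedTuple


class MegatronWeightTuple(NamedTuple):
    """Illustrative re-declaration of the documented fields (Any stands in for torch.Tensor)."""

    param_name: str
    weight: Any
    vp_stage: int


class HFWeightTuple(NamedTuple):
    """Illustrative re-declaration of the documented fields (Any stands in for torch.Tensor)."""

    param_name: str
    weight: Any


# Positional unpacking follows the declared field order.
hf_item = HFWeightTuple("model.embed_tokens.weight", [[0.1, 0.2]])
name, weight = hf_item

# Fields can also be accessed by name.
mg_item = MegatronWeightTuple("embedding.word_embeddings.weight", [[0.1, 0.2]], vp_stage=0)
print(name, mg_item.vp_stage)
```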

class bridge.models.conversion.model_bridge.WeightConversionTask

Bases: typing.Generic[bridge.models.conversion.model_bridge.MappingT]

A unified task for converting weights between HuggingFace and Megatron formats.

A single class serves both HF → Megatron and Megatron → HF conversions: the two directions invoke different mapping methods (hf_to_megatron vs. megatron_to_hf), so they can coexist safely.

The task encapsulates all information needed for weight conversion in either direction, with different fields being relevant depending on the conversion type.

Attributes:
  • param_name (str) – Unwrapped, local parameter name (no ‘module.’ prefixes).

  • mapping (MappingT) – Concrete MegatronParamMapping instance responsible for weight transformation and distribution.

  • pp_rank (Optional[int]) – Pipeline-parallel rank that owns the parameter (required for saves).

  • vp_stage (Optional[int]) – Virtual-pipeline stage index (required for loads).

  • megatron_module (Optional[torch.nn.Module]) – Reference to the Megatron model or sub-module that owns the parameter (required for loads).

  • param_weight (Optional[torch.Tensor]) – The actual parameter tensor that will receive the converted weight (required for loads).

param_name: str

mapping: bridge.models.conversion.model_bridge.MappingT

pp_rank: Optional[int]

vp_stage: Optional[int]

megatron_module: Optional[torch.nn.Module]

param_weight: Optional[torch.Tensor]

bridge.models.conversion.model_bridge._megatron_local_name_to_global(
models: megatron.core.transformer.module.MegatronModule | List[megatron.core.transformer.module.MegatronModule],
config: megatron.core.transformer.transformer_config.TransformerConfig,
param_name: str,
vp_stage: Optional[int] = None,
) -> str

Adjust layer number and expert number from local to global numbering.
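Conceptually, each pipeline stage (and virtual stage) numbers its layers from zero, and each expert-parallel rank numbers its local experts from zero, so a local name must be shifted by per-rank offsets to obtain the global name. The sketch below assumes the offsets are already known and that names contain `layers.<i>.` and, for MoE weights, `local_experts.<j>.`; the helper `local_name_to_global` is hypothetical, and the real function derives its offsets from the config and vp_stage.

```python
import re


def local_name_to_global(param_name: str, layer_offset: int, expert_offset: int = 0) -> str:
    """Hypothetical helper: shift local layer/expert indices by per-rank offsets.

    Assumes names contain 'layers.<i>.' and, for MoE weights,
    'local_experts.<j>.'; only the first occurrence of each is rewritten.
    """

    def shift_layer(match):
        return f"layers.{int(match.group(1)) + layer_offset}."

    def shift_expert(match):
        return f"local_experts.{int(match.group(1)) + expert_offset}."

    name = re.sub(r"layers\.(\d+)\.", shift_layer, param_name, count=1)
    return re.sub(r"local_experts\.(\d+)\.", shift_expert, name, count=1)


# Pipeline stage holding layers 16..31: local layer 3 is global layer 19.
print(local_name_to_global("decoder.layers.3.mlp.linear_fc1.weight", layer_offset=16))

# Expert-parallel rank 2 with 4 local experts: local expert 1 is global expert 9.
print(
    local_name_to_global(
        "decoder.layers.0.mlp.experts.local_experts.1.linear_fc1.weight",
        layer_offset=16,
        expert_offset=8,
    )
)
```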

class bridge.models.conversion.model_bridge.MegatronModelBridge

Bases: typing.Generic[bridge.models.conversion.model_bridge.HFPreTrained, bridge.models.conversion.model_bridge.ModelProviderTarget, bridge.models.conversion.model_bridge.MegatronModel]

High-level orchestrator for HuggingFace ↔ Megatron model conversions.

This abstract base class provides the framework for converting models between HuggingFace and Megatron formats. It acts as an orchestrator that coordinates the conversion process without directly handling the complex details of tensor parallelism or weight transformations.

The bridge pattern separates concerns:

  • MegatronModelBridge: Orchestrates the overall conversion process

  • MegatronMappingRegistry: Manages parameter name mappings

  • MegatronParamMapping: Handles actual weight transformations and distribution

Key responsibilities:

  1. Build conversion tasks that map each parameter to its appropriate param mapping

  2. Execute tasks with proper error handling and progress tracking

  3. Provide utilities for configuration translation

  4. Handle virtual pipeline parallelism (VP) complexities

To implement a bridge for a new model architecture:

  1. Create a subclass decorated with @MegatronModelBridge.register_bridge:

    .. code-block:: python

     @MegatronModelBridge.register_bridge(source=LlamaForCausalLM, target=GPTModel)
     class MegatronCausalLlamaBridge(MegatronModelBridge):
         pass
    
  2. Implement provider_bridge to create Megatron configurations:

    .. code-block:: python

     def provider_bridge(self, hf_pretrained) -> LlamaModelProvider:
         return LlamaModelProvider(
             num_layers=hf_pretrained.config.num_hidden_layers,
             hidden_size=hf_pretrained.config.hidden_size,
             ...
         )
    
  3. Implement mapping_registry to define weight mappings:

    .. code-block:: python

     def mapping_registry(self) -> MegatronMappingRegistry:
         return MegatronMappingRegistry(
             AutoMapping(
                 megatron_param="embedding.word_embeddings.weight",
                 hf_param="model.embed_tokens.weight"
             ),
             ...
         )
    

.. rubric:: Example

.. code-block:: python

   from megatron.bridge import AutoBridge

   # The bridge is typically not instantiated directly;
   # use AutoBridge instead, which handles this.
   bridge = AutoBridge.from_hf_pretrained("meta-llama/Meta-Llama-3-8B")
   provider = bridge.to_megatron_provider()

.. note::

This class uses generic type parameters to ensure type safety:

  • HFPreTrained: The HuggingFace model type

  • ModelProviderTarget: The Megatron model provider type

  • MegatronModel: The Megatron model type

abstractmethod provider_bridge(
hf_pretrained: bridge.models.conversion.model_bridge.HFPreTrained,
) -> bridge.models.conversion.model_bridge.ModelProviderTarget

Create a Megatron model provider from HuggingFace configuration.

This abstract method must be implemented by subclasses to translate HuggingFace model configurations into Megatron model provider instances. The provider contains all necessary configuration for creating Megatron models.

Parameters:

hf_pretrained (HFPreTrained) – HuggingFace model or configuration containing the source model’s architecture details.

Returns:

A configured model provider instance (e.g., GPTModelProvider, LlamaModelProvider) ready to create Megatron models.

Return type:

ModelProviderTarget

.. rubric:: Example

.. code-block:: python

   def provider_bridge(self, hf_pretrained):
       return LlamaModelProvider(
           num_layers=hf_pretrained.config.num_hidden_layers,
           hidden_size=hf_pretrained.config.hidden_size,
           num_attention_heads=hf_pretrained.config.num_attention_heads,
           ffn_hidden_size=hf_pretrained.config.intermediate_size,
           # ... other configuration mappings
       )

abstractmethod mapping_registry() -> megatron.bridge.models.conversion.mapping_registry.MegatronMappingRegistry

Define weight mappings between HuggingFace and Megatron formats.

This abstract method must be implemented by subclasses to specify how parameters map between the two formats. The returned MegatronMappingRegistry contains all param mappings needed for the model architecture.

Returns:

MegatronMappingRegistry containing all weight mapping definitions.

Return type:

MegatronMappingRegistry

.. rubric:: Example

.. code-block:: python

   def mapping_registry(self):
       return MegatronMappingRegistry(
           AutoMapping(
               megatron_param="embedding.word_embeddings.weight",
               hf_param="model.embed_tokens.weight",
           ),
           QKVMapping(
               megatron_param="decoder.layers.*.self_attention.linear_qkv.weight",
               q="model.layers.*.self_attn.q_proj.weight",
               k="model.layers.*.self_attn.k_proj.weight",
               v="model.layers.*.self_attn.v_proj.weight",
           ),
           # ... more param mappings
       )

_megatron_global_param_names_all_pp_ranks(
megatron_model: Union[bridge.models.conversion.model_bridge.MegatronModel, List[bridge.models.conversion.model_bridge.MegatronModel]],
) -> List[str]

Get all parameter names across all pipeline parallel ranks.

_with_progress_tracking(
tasks,
description: str,
show_progress: bool = True,
)

Helper method to wrap an iterable with progress tracking.

Parameters:
  • tasks – Iterable of tasks to process

  • description – Description for the progress bar

  • show_progress – Whether to show progress (defaults to True)

Yields:

Items from the tasks iterable while updating progress
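A wrapper like this can be written as a generator that yields each task unchanged and reports progress after the caller has processed it. The sketch below is an assumption-laden illustration: `rank` and `report` are hypothetical parameters standing in for the distributed rank lookup and the progress-bar backend, and progress is reported only on rank 0, matching the note on load_weights_hf_to_megatron.

```python
from typing import Callable, Iterable, Iterator, Optional, TypeVar

T = TypeVar("T")


def _print_progress(description: str, done: int) -> None:
    print(f"{description}: {done} done")


def with_progress_tracking(
    tasks: Iterable[T],
    description: str,
    show_progress: bool = True,
    rank: int = 0,
    report: Optional[Callable[[str, int], None]] = None,
) -> Iterator[T]:
    """Yield each task unchanged; report progress on rank 0 only (sketch)."""
    report = report or _print_progress
    if not show_progress or rank != 0:
        yield from tasks
        return
    for done, task in enumerate(tasks, start=1):
        yield task  # the caller processes the task before the generator resumes
        report(description, done)


# Usage: prints "Converting: 1 done" ... "Converting: 3 done".
for item in with_progress_tracking(["a", "b", "c"], "Converting"):
    pass
```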

load_weights_hf_to_megatron(
hf_pretrained: bridge.models.conversion.model_bridge.HFPreTrained,
megatron_model: Union[bridge.models.conversion.model_bridge.MegatronModel, List[bridge.models.conversion.model_bridge.MegatronModel]],
) -> List[bridge.models.conversion.model_bridge.MegatronModel]

Load HuggingFace weights into Megatron models.

This method orchestrates the complete weight-loading process from HuggingFace format to Megatron’s distributed format. It builds the conversion tasks and executes them with progress tracking and error handling.

The actual weight transformations and distribution are delegated to the MegatronParamMapping instances defined in the bridge’s mapping registry.

Parameters:
  • hf_pretrained (HFPreTrained) – HuggingFace model or state source containing the weights to load.

  • megatron_model (Union[MegatronModel, List[MegatronModel]]) – Megatron model instance or list of model instances (one per virtual pipeline stage).

Returns:

The input megatron_model as a list with loaded weights.

Return type:

List[MegatronModel]

Process:

  1. Build conversion tasks that map each Megatron parameter to its HuggingFace source weights

  2. For each task:

    • Fetch source weights from HuggingFace state

    • Apply format transformation via the param mapping

    • Distribute to appropriate TP/PP ranks

    • Copy into the Megatron parameter
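The four steps above can be mimicked on a single process with plain nested lists. This is a toy sketch, not the library's implementation: `ToyTask`, `load_hf_to_megatron`, and the row-sharding rule are hypothetical, weights are lists instead of tensors, a `None` task means "nothing to load on this rank", and only a one-to-one mapping with a dim-0 TP split is modelled.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple

Matrix = List[List[float]]


def identity(weight: Matrix) -> Matrix:
    return weight


@dataclass
class ToyTask:
    """Hypothetical stand-in for WeightConversionTask (one-to-one mapping)."""

    param_name: str  # Megatron parameter to populate
    hf_name: str  # HuggingFace weight to read
    transform: Callable[[Matrix], Matrix] = identity  # stand-in for hf_to_megatron


def shape(weight: Matrix) -> Tuple[int, int]:
    return len(weight), (len(weight[0]) if weight else 0)


def shard_rows(weight: Matrix, tp_rank: int, tp_size: int) -> Matrix:
    """Keep this rank's contiguous block of rows (a dim-0 split)."""
    rows = len(weight) // tp_size
    return weight[tp_rank * rows:(tp_rank + 1) * rows]


def load_hf_to_megatron(
    hf_state: Dict[str, Matrix],
    megatron_params: Dict[str, Matrix],
    tasks: List[Optional[ToyTask]],
    tp_rank: int = 0,
    tp_size: int = 1,
) -> Dict[str, Matrix]:
    for task in tasks:
        if task is None:  # no work for this slot on this rank
            continue
        source = hf_state[task.hf_name]  # fetch the HF weight
        converted = task.transform(source)  # apply the format transformation
        local = shard_rows(converted, tp_rank, tp_size)  # keep this TP rank's shard
        target = megatron_params[task.param_name]
        if shape(local) != shape(target):
            raise ValueError(f"Shape mismatch for {task.param_name}: {shape(local)} vs {shape(target)}")
        megatron_params[task.param_name] = [row[:] for row in local]  # copy into the parameter
    return megatron_params


hf = {"model.embed_tokens.weight": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]}
mg = {"embedding.word_embeddings.weight": [[0.0, 0.0], [0.0, 0.0]]}
tasks = [ToyTask("embedding.word_embeddings.weight", "model.embed_tokens.weight"), None]
print(load_hf_to_megatron(hf, mg, tasks, tp_rank=1, tp_size=2))
```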

.. rubric:: Example

.. code-block:: python

   hf_model = PreTrainedCausalLM.from_pretrained("gpt2")
   megatron_model = create_megatron_model()  # placeholder: single model or list
   megatron_model = bridge.load_weights_hf_to_megatron(hf_model, megatron_model)  # returns a list

.. note::

Progress is shown only on rank 0 to avoid cluttered output in distributed environments.

Raises:
  • ValueError – If hf_pretrained doesn’t have a state attribute or if weight shapes don’t match.

  • AttributeError – If required HF weights are missing.

stream_weights_hf_to_megatron(
hf_pretrained: bridge.models.conversion.model_bridge.HFPreTrained,
megatron_model: Union[bridge.models.conversion.model_bridge.MegatronModel, List[bridge.models.conversion.model_bridge.MegatronModel]],
conversion_tasks: Optional[List[bridge.models.conversion.model_bridge.WeightConversionTask]] = None,
) -> Iterable[bridge.models.conversion.model_bridge.MegatronWeightTuple]

Generator variant of load_weights_hf_to_megatron for streaming weight conversion.

This method provides a memory-efficient way to convert weights by yielding them one at a time instead of loading all at once. Useful for processing very large models or when implementing custom weight handling logic.

Parameters:
  • hf_pretrained (HFPreTrained) – HuggingFace model or state source containing the weights.

  • megatron_model (Union[MegatronModel, List[MegatronModel]]) – Megatron model instance or list of model instances to extract configuration from.

  • conversion_tasks (Optional[List[WeightConversionTask]]) – Pre-built conversion tasks. If not provided, tasks will be built automatically from the models.

Yields:

MegatronWeightTuple – Named tuples containing:

  • param_name: Name of the parameter

  • weight: Transformed weight tensor for this rank

  • vp_stage: Index of the model in the megatron_model list

.. rubric:: Example

.. code-block:: python

   # Process weights one by one
   for weight_tuple in bridge.stream_weights_hf_to_megatron(hf_model, megatron_model):
       print(f"Processing {weight_tuple.param_name}: {weight_tuple.weight.shape}")
       # Custom processing logic here

   # Or use pre-built conversion tasks
   tasks = bridge.build_conversion_tasks(hf_model, megatron_model)
   for weight_tuple in bridge.stream_weights_hf_to_megatron(hf_model, megatron_model, tasks):
       print(f"Processing {weight_tuple.param_name}: {weight_tuple.weight.shape}")

.. note:: Only yields weights that belong to the current rank after TP/PP distribution.

Raises:

ValueError – If input parameters are invalid.

stream_weights_megatron_to_hf(
megatron_model: Union[bridge.models.conversion.model_bridge.MegatronModel, List[bridge.models.conversion.model_bridge.MegatronModel]],
hf_pretrained: bridge.models.conversion.model_bridge.HFPreTrained,
cpu: bool = True,
show_progress: bool = True,
conversion_tasks: Optional[List[bridge.models.conversion.model_bridge.WeightConversionTask]] = None,
) -> Iterable[bridge.models.conversion.model_bridge.HFWeightTuple]

Export Megatron weights to HuggingFace format.

This method orchestrates the conversion of weights from Megatron’s distributed format back to HuggingFace format. It handles gathering from tensor parallel ranks, broadcasting across pipeline parallel ranks, and format conversions. All ranks receive the full tensors.
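Gathering from tensor-parallel ranks conceptually means concatenating each rank's shard along the parameter's partition dimension. The sketch below shows only that reassembly step on a single process with nested lists, under the assumption that shards are already collected in TP-rank order; the distributed all-gather and any mapping-specific reordering (for example fused QKV layouts) are deliberately omitted, and `gather_tp_shards` is a hypothetical name.

```python
from typing import List

Matrix = List[List[float]]


def gather_tp_shards(shards: List[Matrix], partition_dim: int) -> Matrix:
    """Reassemble a full 2-D weight from per-rank shards listed in TP-rank order."""
    if partition_dim == 0:
        # e.g. column-parallel linear weights: each rank holds a block of output rows
        return [row[:] for shard in shards for row in shard]
    if partition_dim == 1:
        # e.g. row-parallel linear weights: each rank holds a block of input columns
        return [[x for shard in shards for x in shard[i]] for i in range(len(shards[0]))]
    raise ValueError("partition_dim must be 0 or 1 for a 2-D weight")


rank0 = [[1.0, 2.0], [3.0, 4.0]]
rank1 = [[5.0, 6.0], [7.0, 8.0]]
print(gather_tp_shards([rank0, rank1], partition_dim=0))  # [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
print(gather_tp_shards([rank0, rank1], partition_dim=1))  # [[1.0, 2.0, 5.0, 6.0], [3.0, 4.0, 7.0, 8.0]]
```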

The export order is determined automatically:

  • First tries safetensors order (if key_to_filename_map is available)

  • Falls back to HuggingFace state dict order

Parameters:
  • megatron_model (Union[MegatronModel, List[MegatronModel]]) – Megatron model instance or list of model instances (one per virtual pipeline stage).

  • hf_pretrained (HFPreTrained) – HuggingFace model/config for metadata and mapping info.

  • cpu (bool, optional) – Whether to move tensors to CPU before yielding. Defaults to True.

  • show_progress (bool, optional) – Display progress bar during export. Defaults to True.

  • conversion_tasks (Optional[List[WeightConversionTask]]) – Pre-built conversion tasks. If not provided, tasks will be built automatically from the models.

Yields:

HFWeightTuple – Named tuples of (param_name, weight_tensor) in HF format.

.. rubric:: Example

.. code-block:: python

   # Export weights
   for name, weight in bridge.stream_weights_megatron_to_hf(megatron_model, hf_config):
       print(f"Exported {name}: {weight.shape}")

   # Or use pre-built conversion tasks
   tasks = bridge.build_conversion_tasks(hf_config, megatron_model)
   for name, weight in bridge.stream_weights_megatron_to_hf(
       megatron_model, hf_config, conversion_tasks=tasks
   ):
       print(f"Exported {name}: {weight.shape}")

Raises:

ValueError – If input parameters are invalid.

.. note:: All ranks yield the full tensors after gathering from distributed format.

dtype_from_hf(config, default=None)

Extract torch dtype from a HuggingFace config.

This utility method handles the conversion of dtype specifications in HuggingFace configs to PyTorch dtype objects. Supports both direct torch.dtype objects and string representations.

Parameters:
  • config – HuggingFace configuration object with a torch_dtype attribute.

  • default (Any, optional) – Value to return if torch_dtype is neither a str nor a torch.dtype. Defaults to None.

Returns:

The corresponding PyTorch dtype.

Return type:

torch.dtype

Raises:
  • AssertionError – If config doesn’t have torch_dtype attribute.

  • ValueError – If torch_dtype is neither a string nor a torch.dtype and no default is provided.

.. rubric:: Example

.. code-block:: python

   dtype = bridge.dtype_from_hf(hf_config)
   print(dtype)  # e.g. torch.float16

dtype_from_str(dtype: str) -> torch.dtype

Convert a string precision identifier to equivalent torch dtype.

This utility method handles various string representations of PyTorch data types, including common abbreviations and mixed precision formats.

Parameters:

dtype (str) – String representation of dtype (e.g., “float16”, “fp16”, “bf16-mixed”).

Returns:

Corresponding PyTorch dtype (defaults to float32 if unknown).

Return type:

torch.dtype

Supported formats:

  • float16, fp16, 16, 16-mixed → torch.float16

  • bfloat16, bf16-mixed → torch.bfloat16

  • Any other string → torch.float32 (default)
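The resolution rules for dtype_from_str and dtype_from_hf can be sketched without PyTorch by returning dtype names as strings. This sketch follows the documented alias list literally (so a plain "bf16", which is not listed, falls through to float32), assumes string values in a HF config are resolved the same way dtype_from_str resolves them, and omits the torch.dtype pass-through case; the function names are hypothetical.

```python
from types import SimpleNamespace
from typing import Any, Optional

FP16_ALIASES = {"float16", "fp16", "16", "16-mixed"}
BF16_ALIASES = {"bfloat16", "bf16-mixed"}


def dtype_name_from_str(dtype: str) -> str:
    """Return the torch dtype *name* for a precision string (sketch)."""
    if dtype in FP16_ALIASES:
        return "float16"
    if dtype in BF16_ALIASES:
        return "bfloat16"
    return "float32"  # every other string, including plain "bf16"


def dtype_name_from_hf(config: Any, default: Optional[str] = None) -> str:
    """Resolve config.torch_dtype given as a string (sketch)."""
    assert hasattr(config, "torch_dtype"), "config must define torch_dtype"
    value = config.torch_dtype
    if isinstance(value, str):
        return dtype_name_from_str(value)
    if default is not None:
        return default
    raise ValueError(f"Unsupported torch_dtype: {value!r}")


print(dtype_name_from_str("bf16-mixed"))  # bfloat16
print(dtype_name_from_hf(SimpleNamespace(torch_dtype="float16")))  # float16
print(dtype_name_from_hf(SimpleNamespace(torch_dtype=None), default="float32"))  # float32
```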

.. rubric:: Example

.. code-block:: python

   dtype = bridge.dtype_from_str("fp16")
   print(dtype)  # torch.float16

   dtype = bridge.dtype_from_str("bf16-mixed")
   print(dtype)  # torch.bfloat16

make_vocab_size_divisible_by(vocab_size: int) -> int

Calculate an appropriate divisor for vocabulary size padding.

Megatron requires vocabulary sizes to be divisible by certain values for efficient tensor parallelism. This method finds the largest power of 2 (up to 128) that evenly divides the vocabulary size.

Parameters:

vocab_size (int) – Original vocabulary size from the model.

Returns:

Largest power of 2 (≤ 128) that divides vocab_size.

Return type:

int

.. rubric:: Example

.. code-block:: python

   # For vocab_size=50257 (GPT-2)
   divisor = bridge.make_vocab_size_divisible_by(50257)
   print(divisor)  # 1 (50257 is odd)

   # For vocab_size=32000 (Llama)
   divisor = bridge.make_vocab_size_divisible_by(32000)
   print(divisor)  # 128

.. note::

The returned value is used by Megatron to potentially pad the vocabulary to ensure efficient parallelization.
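Finding the largest power of 2 up to 128 that divides the vocabulary size amounts to starting at 128 and halving until the division is exact. The sketch below shows that loop, plus a hypothetical `padded_vocab_size` helper that assumes Megatron-style padding to a multiple of divisor × tensor-parallel size; treat the padding rule as an assumption rather than this module's behavior.

```python
def make_vocab_size_divisible_by(vocab_size: int) -> int:
    """Largest power of two, at most 128, that divides vocab_size."""
    divisor = 128
    while vocab_size % divisor != 0:
        divisor //= 2  # terminates at 1, which divides every integer
    return divisor


def padded_vocab_size(vocab_size: int, divisor: int, tp_size: int = 1) -> int:
    """Hypothetical helper: round vocab_size up to a multiple of divisor * tp_size."""
    multiple = divisor * tp_size
    return (vocab_size + multiple - 1) // multiple * multiple


print(make_vocab_size_divisible_by(50257))  # 1
print(make_vocab_size_divisible_by(32000))  # 128
print(padded_vocab_size(32000, 128, tp_size=8))  # 32768
```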

_get_provider_from_model(
model: megatron.core.transformer.module.MegatronModule,
) -> bridge.models.conversion.model_bridge.ModelProviderTarget

Extract provider/config from model.

_unwrap_name(name: str) -> str

Unwrap name from DDP or other wrappers.

Parameters:

name – Parameter name that may have ‘module.’ prefixes

Returns:

Unwrapped parameter name with ‘module.’ prefixes removed

.. rubric:: Example

‘module.module.decoder.weight’ -> ‘decoder.weight’
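Stripping wrapper prefixes can be done by repeatedly removing a leading ‘module.’. The sketch below assumes only leading prefixes are removed (a ‘module.’ in the middle of a name is left alone); `unwrap_name` is an illustrative stand-in, not the method itself.

```python
def unwrap_name(name: str) -> str:
    """Strip leading 'module.' prefixes added by DDP-style wrappers (sketch)."""
    prefix = "module."
    while name.startswith(prefix):
        name = name[len(prefix):]
    return name


print(unwrap_name("module.module.decoder.weight"))  # decoder.weight
print(unwrap_name("decoder.weight"))  # decoder.weight
```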

_broadcast_shared_embeddings(
megatron_model: Union[bridge.models.conversion.model_bridge.MegatronModel, List[bridge.models.conversion.model_bridge.MegatronModel]],
) -> None

Broadcast shared embeddings and output weights across embedding group.

When embeddings and output weights are shared and pipeline parallelism is enabled, this method ensures all ranks in the embedding group have the same weights by broadcasting from rank 0.

Parameters:

megatron_model – Megatron model instance or list of model instances.

build_conversion_tasks(
hf_pretrained: bridge.models.conversion.model_bridge.HFPreTrained,
megatron_model: List[bridge.models.conversion.model_bridge.MegatronModel],
) -> List[None | bridge.models.conversion.model_bridge.WeightConversionTask]

Construct the conversion tasks between HF and Megatron.

The algorithm walks over every parameter of every destination model, asks the MegatronMappingRegistry whether it has a mapping for that parameter and, if the corresponding HF weights actually exist, creates a WeightConversionTask describing exactly how that parameter will be populated. The returned list may also contain None entries.
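The lookup step can be illustrated with wildcard patterns like those in mapping_registry, where `*` stands for one dotted component such as a layer index. The sketch below is a hypothetical one-to-one simplification: `build_tasks` returns `(megatron_name, hf_name)` pairs instead of WeightConversionTask objects, uses a plain dict instead of MegatronMappingRegistry, and does not model many-to-one mappings such as QKVMapping.

```python
import re
from typing import Dict, Iterable, List, Optional, Tuple


def compile_pattern(pattern: str):
    """Turn 'decoder.layers.*.x' into a regex where '*' captures one dotted component."""
    return re.compile("^" + re.escape(pattern).replace(r"\*", r"([^.]+)") + "$")


def build_tasks(
    megatron_names: Iterable[str],
    mappings: Dict[str, str],
    hf_keys: Iterable[str],
) -> List[Optional[Tuple[str, str]]]:
    """Hypothetical one-to-one version: (megatron_name, hf_name) or None per parameter."""
    available = set(hf_keys)
    compiled = [(compile_pattern(mg), hf) for mg, hf in mappings.items()]
    tasks: List[Optional[Tuple[str, str]]] = []
    for name in megatron_names:
        task = None
        for regex, hf_pattern in compiled:
            match = regex.match(name)
            if match is None:
                continue
            hf_name = hf_pattern
            for captured in match.groups():
                hf_name = hf_name.replace("*", captured, 1)  # fill wildcards left to right
            if hf_name in available:  # only create a task if the HF weight exists
                task = (name, hf_name)
            break  # first matching mapping wins
        tasks.append(task)
    return tasks


mappings = {
    "embedding.word_embeddings.weight": "model.embed_tokens.weight",
    "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight",
}
hf_keys = ["model.embed_tokens.weight", "model.layers.0.mlp.down_proj.weight"]
names = [
    "embedding.word_embeddings.weight",
    "decoder.layers.0.mlp.linear_fc2.weight",
    "decoder.layers.1.mlp.linear_fc2.weight",  # HF weight missing -> None
    "output_layer.weight",  # no mapping -> None
]
tasks = build_tasks(names, mappings, hf_keys)
print(tasks)
```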

classmethod register_bridge(
*,
source: Type[transformers.modeling_utils.PreTrainedModel],
target: Type[bridge.models.conversion.model_bridge.MegatronModel],
) -> Callable[[bridge.models.conversion.model_bridge._BridgeImplClass], bridge.models.conversion.model_bridge._BridgeImplClass]

Class decorator for registering bridge implementations.

This decorator registers a MegatronModelBridge subclass with the dispatch system, enabling automatic routing of conversions based on the source HuggingFace model type and target Megatron model type.

Parameters:
  • source (Type[PreTrainedModel]) – HuggingFace PreTrainedModel class (e.g., LlamaForCausalLM).

  • target (Type[MegatronModel]) – Megatron model class (e.g., GPTModel).

Returns:

Decorator function that registers the bridge implementation.

Return type:

Callable[[_BridgeImplClass], _BridgeImplClass]

.. rubric:: Example

.. code-block:: python

   @MegatronModelBridge.register_bridge(source=LlamaForCausalLM, target=GPTModel)
   class MegatronCausalLlamaBridge(MegatronModelBridge):
       def provider_bridge(self, hf_pretrained):
           # Implementation
           pass

       def mapping_registry(self):
           # Implementation
           pass

.. note::

The decorated class is registered with multiple dispatchers to handle different conversion scenarios. The registration is automatic when the class is defined.

bridge.models.conversion.model_bridge.is_tensor_parallel(param) -> bool

Check if a parameter is tensor parallel distributed.
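Megatron Core marks tensor-parallel-sharded parameters with a `tensor_model_parallel` attribute (set by `set_tensor_model_parallel_attributes`, together with `partition_dim`). The sketch below assumes the check simply reads that attribute and treats a missing attribute as "replicated"; it uses SimpleNamespace objects as stand-ins for parameters so it runs without PyTorch.

```python
from types import SimpleNamespace


def is_tensor_parallel(param) -> bool:
    """True if the parameter carries the (assumed) TP-sharding marker."""
    return bool(getattr(param, "tensor_model_parallel", False))


# Stand-ins for parameters: a sharded linear weight and a replicated LayerNorm weight.
sharded = SimpleNamespace(tensor_model_parallel=True, partition_dim=0)
replicated = SimpleNamespace()

print(is_tensor_parallel(sharded), is_tensor_parallel(replicated))  # True False
```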

bridge.models.conversion.model_bridge.get_model_bridge(
hf_architecture,
) -> bridge.models.conversion.model_bridge.MegatronModelBridge

Get the appropriate model bridge for a given HuggingFace architecture.

bridge.models.conversion.model_bridge.stream_weights_megatron_to_hf(
dispatch_instance: bridge.models.conversion.model_bridge.MegatronModel,
megatron_model: Union[bridge.models.conversion.model_bridge.MegatronModel, List[bridge.models.conversion.model_bridge.MegatronModel]],
hf_pretrained: bridge.models.conversion.model_bridge.HFPreTrained,
cpu: bool = True,
show_progress: bool = True,
conversion_tasks: Optional[List[bridge.models.conversion.model_bridge.WeightConversionTask]] = None,
) -> Iterable[bridge.models.conversion.model_bridge.HFWeightTuple]

Bridge Megatron model state to HuggingFace format.

bridge.models.conversion.model_bridge.register_bridge_implementation(
*,
source: Type[transformers.modeling_utils.PreTrainedModel],
target: Type[megatron.core.transformer.module.MegatronModule],
bridge_class: Type[bridge.models.conversion.model_bridge.MegatronModelBridge],
) -> None

Register a bridge implementation with the dispatch system.

Parameters:
  • source – HuggingFace PreTrainedModel class (e.g., LlamaForCausalLM)

  • target – Megatron model class (e.g., GPTModel)

  • bridge_class – MegatronModelBridge implementation class

bridge.models.conversion.model_bridge.create_bridge_decorator(
*,
source: Type[transformers.modeling_utils.PreTrainedModel],
target: Type[megatron.core.transformer.module.MegatronModule],
) -> Callable[[Type[bridge.models.conversion.model_bridge.MegatronModelBridge]], Type[bridge.models.conversion.model_bridge.MegatronModelBridge]]

Create a decorator for registering bridge implementations.

Parameters:
  • source – HuggingFace PreTrainedModel class

  • target – Megatron model class

Returns:

Decorator function that registers the bridge implementation
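The relationship between a registration function and a decorator factory can be sketched with a plain dictionary keyed by the (source, target) pair. This is a simplified illustration, not the module's dispatch system: `BRIDGE_REGISTRY`, `register_impl`, `bridge_decorator`, and `lookup_bridge` are hypothetical names, and the placeholder classes stand in for LlamaForCausalLM, GPTModel, and a bridge subclass.

```python
from typing import Callable, Dict, Tuple, TypeVar

C = TypeVar("C", bound=type)

# Hypothetical registry keyed by (HF model class, Megatron model class).
BRIDGE_REGISTRY: Dict[Tuple[type, type], type] = {}


def register_impl(*, source: type, target: type, bridge_class: type) -> None:
    """Record which bridge class handles the (source, target) pair."""
    BRIDGE_REGISTRY[(source, target)] = bridge_class


def bridge_decorator(*, source: type, target: type) -> Callable[[C], C]:
    """Build a class decorator that registers the class and returns it unchanged."""

    def decorator(bridge_class: C) -> C:
        register_impl(source=source, target=target, bridge_class=bridge_class)
        return bridge_class

    return decorator


def lookup_bridge(source: type, target: type) -> type:
    """Return the registered bridge class or raise a readable error."""
    try:
        return BRIDGE_REGISTRY[(source, target)]
    except KeyError:
        raise ValueError(f"No bridge registered for {source.__name__} -> {target.__name__}") from None


# Placeholder classes standing in for LlamaForCausalLM, GPTModel and a bridge.
class FakeLlamaForCausalLM: ...


class FakeGPTModel: ...


@bridge_decorator(source=FakeLlamaForCausalLM, target=FakeGPTModel)
class FakeLlamaBridge: ...


print(lookup_bridge(FakeLlamaForCausalLM, FakeGPTModel).__name__)  # FakeLlamaBridge
```

Returning the class unchanged from the decorator keeps registration a side effect, so the decorated class can still be imported and subclassed normally.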