bridge.models.conversion.param_mapping#

Module Contents#

Classes#

MegatronParamMapping

Abstract base class for weight conversion between Megatron and external formats.

DirectMapping

Direct 1:1 weight mapping with no transformation or tensor parallelism.

ColumnParallelMapping

Mapping for column-parallel linear and embedding weights.

RowParallelMapping

Mapping for row-parallel linear weights.

ReplicatedMapping

Mapping for weights that are fully replicated across TP ranks.

AutoMapping

Smart mapping that automatically detects and applies the correct parallelism strategy.

QKVMapping

Mapping for interleaved Query/Key/Value attention projection weights.

GatedMLPMapping

Mapping for gated-MLP projection weights (SwiGLU / GeGLU).

Functions#

merge_qkv_biases

Merge separate Q, K, V bias vectors into Megatron’s interleaved QKV format.

split_qkv_biases

Split Megatron’s interleaved QKV bias into separate Q, K, V biases.

merge_qkv_weights

Merge separate Q, K, V weight matrices into Megatron’s interleaved QKV format.

split_qkv_weights

Split Megatron’s interleaved QKV tensor into separate Q, K, V matrices.

Data#

API#

bridge.models.conversion.param_mapping.WeightType#

‘TypeVar(…)’

class bridge.models.conversion.param_mapping.MegatronParamMapping(
megatron_param: str,
hf_param: Union[str, Dict[str, str]],
)#

Bases: abc.ABC, typing.Generic[bridge.models.conversion.param_mapping.WeightType]

Abstract base class for weight conversion between Megatron and external formats.

This class provides the foundation for all weight mappings, handling the complex conversions between Megatron-Core’s distributed tensor formats and standard (typically HuggingFace) formats. Each concrete mapping implements specific transformation logic while inheriting common parallel communication patterns.

Key responsibilities:

  • Format transformation (e.g., QKV merging/splitting, gated MLP handling)

  • Tensor parallel (TP) distribution and gathering across GPUs

  • Pipeline parallel (PP) broadcasting between pipeline stages

  • Wildcard pattern resolution for layer-wise mappings

The mapping abstraction ensures that higher-level code does not need to know about the parallel topology or format differences; it simply requests a conversion and the mapping handles the complexity.

Public helper methods for subclasses:

  • broadcast_from_pp_rank: Broadcast tensors across pipeline stages

  • broadcast_obj_from_pp_rank: Broadcast Python objects across PP ranks

  • broadcast_tensor_to_tp_ranks: Broadcast within TP group

  • scatter_to_tp_ranks: Distribute tensor shards to TP ranks

  • gather_from_tp_ranks: Collect tensor shards from TP ranks

.. rubric:: Example

.. code-block:: python

class MyCustomMapping(MegatronParamMapping[torch.Tensor]):
    def hf_to_megatron(self, hf_weights, megatron_module):
        # Custom transformation logic
        transformed = hf_weights.t()  # Example: transpose
        # Use helpers for distribution
        return self.scatter_to_tp_ranks(...)

    def megatron_to_hf(self, megatron_weights, megatron_module):
        # Broadcast from owning PP rank
        weight = self.broadcast_from_pp_rank(megatron_weights)
        # Gather from TP ranks and transform
        gathered = self.gather_from_tp_ranks(weight)
        return {"custom_weight": gathered[0].t()}

Initialization

Initialize the weight mapping.

Parameters:
  • megatron_param (str) – Megatron parameter name pattern (supports * wildcards).

  • hf_param (Union[str, Dict[str, str]]) – External format name pattern(s).

property tp_group#

Get the tensor model parallel group.

property tp_rank: int#

Get the tensor model parallel rank.

property tp_size: int#

Get the tensor model parallel size.

property pp_rank: int#

Get the pipeline model parallel rank.

property pp_size: int#

Get the pipeline model parallel size.

property ep_rank: int#

Get the expert model parallel rank.

property ep_size: int#

Get the expert model parallel size.

property etp_rank: int#

Get the expert tensor parallel rank.

property etp_size: int#

Get the expert tensor parallel size.

property is_expert: bool#

Check if this mapping is for an expert parameter.

_resolve_names(
captures: Tuple[str, ...],
) Tuple[str, Union[str, Dict[str, str]]]#
resolve(
captures: Tuple[str, ...],
) bridge.models.conversion.param_mapping.MegatronParamMapping#

Create a new mapping with resolved wildcards.

This default implementation works for mappings with a (megatron_param, hf_param) constructor.

Parameters:

captures (Tuple[str, ...]) – Captured wildcard values.

Returns:

A new mapping instance with resolved names.

Return type:

MegatronParamMapping
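
For illustration, a minimal sketch of resolving a layer wildcard (the attribute names shown in the comments are assumptions based on the constructor arguments):

.. code-block:: python

# Resolve the "*" capture to a concrete layer index.
mapping = AutoMapping(
    megatron_param="decoder.layers.*.mlp.linear_fc2.weight",
    hf_param="model.layers.*.mlp.down_proj.weight",
)
resolved = mapping.resolve(("3",))
# Both patterns now refer to layer 3, e.g. (assumed attributes):
#   resolved.megatron_param == "decoder.layers.3.mlp.linear_fc2.weight"
#   resolved.hf_param == "model.layers.3.mlp.down_proj.weight"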

abstractmethod hf_to_megatron(
hf_weights: bridge.models.conversion.param_mapping.WeightType,
megatron_module: torch.nn.Module,
) torch.Tensor#

Convert hf_weights TO Megatron format.

This method handles:

  1. Format transformation (if needed)

  2. Tensor parallel distribution (if self.tp_size > 1)

Parameters:
  • hf_weights (WeightType) – Source hf_weights in external format.

  • megatron_module (nn.Module) – Target Megatron module (for config access).

Returns:

Weight tensor ready for the current TP rank.

Return type:

torch.Tensor

abstractmethod megatron_to_hf(
megatron_weights: Optional[torch.Tensor],
megatron_module: Optional[torch.nn.Module],
) Dict[str, torch.Tensor]#

Convert weights FROM Megatron format.

This method handles:

  1. Pipeline parallel broadcasting (if weight is on different PP rank)

  2. Tensor parallel gathering (if needed)

  3. Format transformation

Parameters:
  • megatron_weights (Optional[torch.Tensor]) – Weight tensor from current rank (None if on different PP rank).

  • megatron_module (Optional[nn.Module]) – Module for config access (None if on different PP rank).

Returns:

Converted weights (empty dict if not on TP rank 0).

Return type:

Dict[str, torch.Tensor]

broadcast_from_pp_rank(
tensor: Optional[torch.Tensor],
) Optional[torch.Tensor]#

Broadcast a tensor from the pipeline-parallel rank that owns it.

Broadcasts to all PP ranks. This mirrors the behaviour of broadcast_from_megatron_pp in the original MMapping implementation and additionally keeps the tensor-parallel metadata (tensor_model_parallel, partition_dim) consistent on every rank.

Parameters:

tensor (Optional[torch.Tensor]) – The local tensor if the current PP rank owns it. None otherwise.

Returns:

The broadcasted tensor on every PP rank, or None if no PP rank owned the tensor (which indicates a bug in the calling code).

Return type:

Optional[torch.Tensor]

broadcast_obj_from_pp_rank(
obj: Optional[Any],
) Any#

Broadcast any Python object from the PP rank that owns it.

This method is useful for broadcasting configuration objects or other metadata across pipeline parallel ranks.

Parameters:

obj (Optional[Any]) – Object to broadcast (None on non-owning ranks).

Returns:

Broadcasted object on all ranks.

Return type:

Any

Raises:

ValueError – If object exists on multiple ranks or no ranks.
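
A minimal usage sketch (assumed context: inside a mapping method where only the owning PP rank can read the module configuration):

.. code-block:: python

# Only the PP rank that owns the module has a config locally; the other
# ranks pass None and receive the broadcast copy.
config = self._get_config(megatron_module) if megatron_module is not None else None
config = self.broadcast_obj_from_pp_rank(config)
# After the call, every PP rank holds an equivalent config object.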

broadcast_tensor_to_tp_ranks(
tensor: torch.Tensor,
src_rank: int = 0,
) torch.Tensor#

Broadcast a tensor to all TP ranks.

Parameters:
  • tensor (torch.Tensor) – The tensor to broadcast.

  • src_rank (int, optional) – The source rank within the TP group. Defaults to 0.

Returns:

The broadcasted tensor.

Return type:

torch.Tensor

scatter_to_tp_ranks(
splits: Optional[List[torch.Tensor]],
output_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
src_rank: int = 0,
) torch.Tensor#

Scatter tensor splits to TP ranks.

Parameters:
  • splits (Optional[List[torch.Tensor]]) – A list of tensor shards to scatter. Only rank src_rank needs this.

  • output_shape (torch.Size) – The shape of the output tensor on each rank.

  • dtype (torch.dtype) – The data type of the output tensor.

  • device (torch.device) – The device for the output tensor.

  • src_rank (int, optional) – The source rank for the scatter operation. Defaults to 0.

Returns:

The scattered tensor shard on the current rank.

Return type:

torch.Tensor
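
A hedged sketch of the load path, assuming a column-style split along dim 0 and that megatron_module.weight is the local Megatron parameter:

.. code-block:: python

def hf_to_megatron(self, hf_weights, megatron_module):
    target = megatron_module.weight  # local shard on this TP rank (assumed attribute)
    splits = None
    if self.tp_rank == 0:
        # Only the source rank needs the full list of shards.
        splits = list(torch.chunk(hf_weights, self.tp_size, dim=0))
    return self.scatter_to_tp_ranks(
        splits,
        output_shape=target.shape,
        dtype=target.dtype,
        device=target.device,
    )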

gather_from_tp_ranks(
tensor: torch.Tensor,
) List[torch.Tensor]#

Gather tensors from all TP ranks.

Parameters:

tensor (torch.Tensor) – The tensor shard to be gathered from the current rank.

Returns:

A list of tensor shards from all TP ranks.

Return type:

List[torch.Tensor]
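
A hedged sketch of the export path, combining broadcast_from_pp_rank with gather_from_tp_ranks (the HF parameter name is illustrative):

.. code-block:: python

def megatron_to_hf(self, megatron_weights, megatron_module):
    # Broadcast first so every PP rank can join the collective gather.
    weight = self.broadcast_from_pp_rank(megatron_weights)
    shards = self.gather_from_tp_ranks(weight)
    if self.tp_rank != 0:
        return {}  # only TP rank 0 emits the reassembled tensor
    return {"my_param.weight": torch.cat(shards, dim=0)}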

_validate_patterns()#

Validate wildcard consistency between patterns.

_normalize_expert_param_name(param_name: str) str#

Normalize an expert parameter name by replacing its trailing number with 0, e.g. experts.weight15 -> experts.weight0 and experts.bias15 -> experts.bias0.

Parameters:

param_name (str) – Parameter name that may end with a number.

Returns:

Parameter name with trailing number replaced by 0.

Return type:

str
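
A minimal sketch of the normalization described above (an assumed implementation, shown only to illustrate the behaviour):

.. code-block:: python

import re

def normalize_expert_param_name(param_name: str) -> str:
    # Replace a trailing integer (the local expert index) with 0.
    return re.sub(r"\d+$", "0", param_name)

normalize_expert_param_name("experts.weight15")  # -> "experts.weight0"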

_get_config(module: torch.nn.Module) Any#

Extract configuration from module hierarchy.

gather_from_ep_ranks(
megatron_weights: Optional[torch.Tensor],
megatron_module: Optional[megatron.core.transformer.module.MegatronModule],
hf_param_name: Optional[str],
) Dict[str, torch.Tensor]#

Handle expert parallel weight gathering for MoE models.

This method gathers expert weights across expert parallel ranks and should only be called when the parameter is confirmed to be an expert weight. For example, with expert parallel size = 2 and 8 total experts, the experts are distributed as Rank 0: [0, 1, 2, 3] and Rank 1: [4, 5, 6, 7]; the method then returns a dictionary with one entry per EP rank, e.g. {0: weight0, 4: weight4}.

Parameters:
  • megatron_weights (Optional[torch.Tensor]) – The local expert weight tensor.

  • megatron_module (Optional[MegatronModule]) – The megatron module containing config.

  • hf_param_name (Optional[str]) – HF parameter name for this weight, used to build the keys of the returned dictionary.

Returns:

Dictionary of expert weights mapped to HF parameter names.

Return type:

Dict[str, torch.Tensor]

class bridge.models.conversion.param_mapping.DirectMapping#

Bases: bridge.models.conversion.param_mapping.MegatronParamMapping[torch.Tensor]

Direct 1:1 weight mapping with no transformation or tensor parallelism.

hf_to_megatron(
hf_weights: torch.Tensor,
megatron_module: torch.nn.Module,
) torch.Tensor#

Direct copy - no transformation or distribution.

megatron_to_hf(
megatron_weights: Optional[torch.Tensor],
megatron_module: Optional[torch.nn.Module],
) Dict[str, torch.Tensor]#

Direct copy with PP broadcast.

class bridge.models.conversion.param_mapping.ColumnParallelMapping#

Bases: bridge.models.conversion.param_mapping.MegatronParamMapping[torch.Tensor]

Mapping for column-parallel linear and embedding weights.

Column-parallel layers in Megatron split the output dimension across tensor parallel ranks. This is used for layers where each rank computes a portion of the output features independently, such as:

  • Embedding layers (split vocabulary)

  • Linear layers producing hidden states (e.g., QKV projections, MLP up projections)

The weight matrix is partitioned along dimension 0 (rows), so each TP rank holds a subset of output features while maintaining all input features.

Sharding pattern

  • Original weight: [output_features, input_features]

  • Rank 0: [output_features/tp_size, input_features]

  • Rank 1: [output_features/tp_size, input_features]

  • …

Forward path (HuggingFace → Megatron)

  1. Validate divisibility: output dimension must be divisible by tp_size

  2. Split: Chunk tensor along dim 0 into tp_size equal parts

  3. Scatter: Distribute chunks to respective TP ranks

Reverse path (Megatron → HuggingFace)

  1. Broadcast: Ensure all PP ranks have the tensor

  2. Gather: Collect chunks from all TP ranks

  3. Concatenate: Reassemble along dim 0 on rank 0

.. rubric:: Example

.. code-block:: python

# For a weight of shape [4096, 1024] with tp_size=4:
# Each rank gets [1024, 1024] after column-parallel split
mapping = ColumnParallelMapping("linear.weight", "transformer.linear.weight")
megatron_weights = mapping.hf_to_megatron(hf_weight, megatron_module)
# megatron_weights.shape = [1024, 1024] on each rank

.. note::

This mapping also handles bias terms, which are 1D tensors split along their only dimension following the same pattern.

hf_to_megatron(
hf_weights: torch.Tensor,
megatron_module: torch.nn.Module,
) torch.Tensor#

Split weight along dim 0 and distribute to TP ranks.

megatron_to_hf(
megatron_weights: Optional[torch.Tensor],
megatron_module: Optional[torch.nn.Module],
) Dict[str, torch.Tensor]#

Gather from all TP ranks and concatenate.

class bridge.models.conversion.param_mapping.RowParallelMapping#

Bases: bridge.models.conversion.param_mapping.MegatronParamMapping[torch.Tensor]

Mapping for row-parallel linear weights.

Megatron shards row-parallel tensors along dimension 1 (the input dimension of a linear layer).

Forward path (external → Megatron)

  1. Rank 0 validates that the second dimension is divisible by tp_size.

  2. Rank 0 splits the tensor with torch.chunk(..., dim=1) producing tp_size equally-sized shards.

  3. The shards are scattered so that every TP rank receives exactly one shard matching the shape of its local Megatron parameter.

Reverse path (Megatron → external)

  1. The local Megatron parameter (which may live on any PP rank) is broadcast to all PP ranks so that the gather step can be collective.

  2. All TP ranks gather their shard.

  3. Rank 0 concatenates the gathered list along dim 1 to reconstruct the original unsharded weight and emits it under the external (HF) name.
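
.. rubric:: Example

A minimal sketch mirroring the ColumnParallelMapping example; the parameter names are illustrative.

.. code-block:: python

# For a weight of shape [1024, 4096] with tp_size=4:
# Each rank gets [1024, 1024] after the row-parallel split along dim 1
mapping = RowParallelMapping("linear_fc2.weight", "transformer.linear2.weight")
megatron_weights = mapping.hf_to_megatron(hf_weight, megatron_module)
# megatron_weights.shape = [1024, 1024] on each rank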

hf_to_megatron(
hf_weights: torch.Tensor,
megatron_module: torch.nn.Module,
) torch.Tensor#

Split weight along dim 1 and distribute to TP ranks.

megatron_to_hf(
megatron_weights: Optional[torch.Tensor],
megatron_module: Optional[torch.nn.Module],
) Dict[str, torch.Tensor]#

Gather from all TP ranks and concatenate.

class bridge.models.conversion.param_mapping.ReplicatedMapping#

Bases: bridge.models.conversion.param_mapping.MegatronParamMapping[torch.Tensor]

Mapping for weights that are fully replicated across TP ranks.

Examples: layer-norm scales, biases, router weights in MoE, etc.

These tensors exist in exactly the same form on every TP rank, so the mapping logic is trivial – but we still need to broadcast across TP ranks during load (HF → Megatron) and ensure we do not emit duplicates during export (Megatron → HF).
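
.. rubric:: Example

A minimal sketch with hypothetical parameter names; the same tensor ends up on every TP rank.

.. code-block:: python

mapping = ReplicatedMapping("decoder.final_layernorm.weight", "model.norm.weight")
local = mapping.hf_to_megatron(hf_norm_weight, norm_module)  # replicated on all TP ranks
# On export, only TP rank 0 returns the weight; other ranks return {}.
exported = mapping.megatron_to_hf(local, norm_module)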

hf_to_megatron(
hf_weights: torch.Tensor,
megatron_module: torch.nn.Module,
) torch.Tensor#

Replicate weight to all TP ranks.

megatron_to_hf(
megatron_weights: Optional[torch.Tensor],
megatron_module: Optional[torch.nn.Module],
) Dict[str, torch.Tensor]#

Return weight only from rank 0 to avoid duplication.

class bridge.models.conversion.param_mapping.AutoMapping(megatron_param: str, hf_param: str)#

Bases: bridge.models.conversion.param_mapping.MegatronParamMapping[torch.Tensor]

Smart mapping that automatically detects and applies the correct parallelism strategy.

This mapping eliminates the need to manually specify whether a layer is column-parallel, row-parallel, or replicated. It examines the Megatron module at runtime and delegates to the appropriate specialized mapping.

Detection strategy

  1. Check module class name against a registry of known types

  2. If unknown, examine module attributes (tensor_model_parallel, partition_dim)

  3. Delegate to appropriate mapping: ColumnParallel, RowParallel, or Replicated

This abstraction is particularly useful for model-agnostic code where you don’t know the parallelism type ahead of time, or when working with models that mix different parallelism strategies.

Built-in module recognition

  • Column-parallel: ColumnParallelLinear, VocabParallelEmbedding, etc.

  • Row-parallel: RowParallelLinear, TERowParallelLinear

  • Replicated: LayerNorm, RMSNorm, and other normalization layers

.. rubric:: Example

.. code-block:: python

# Automatically handles any weight type
mapping = AutoMapping(
    megatron_param="decoder.layers.*.mlp.linear_fc1.weight",
    hf_param="model.layers.*.mlp.gate_proj.weight"
)

# Works with column-parallel layers
megatron_weights = mapping.hf_to_megatron(hf_weight, column_parallel_module)

# Also works with normalization layers
norm_weight = mapping.hf_to_megatron(hf_norm, layer_norm_module)

# Register custom module types
AutoMapping.register_module_type("MyCustomLinear", "column")

.. note::

If the parallelism type cannot be determined, the mapping will raise a descriptive error suggesting how to fix the issue.

Initialization

Initialize TP-aware mapping.

_MODULE_TYPE_REGISTRY: Dict[str, set]#

None

classmethod register_module_type(module_name: str, parallelism_type: str)#

Register a new module type for automatic parallelism detection.

Parameters:
  • module_name (str) – The name of the module class (e.g., ‘MyColumnLinear’).

  • parallelism_type (str) – One of ‘column’, ‘row’, or ‘replicated’.

_get_or_create_mapping(
parallelism_type: str,
) bridge.models.conversion.param_mapping.MegatronParamMapping[torch.Tensor]#

Get or create the appropriate mapping for the given type.

_detect_parallelism_type(module: torch.nn.Module) str#

Detect parallelism type from module.

hf_to_megatron(
hf_weights: torch.Tensor,
megatron_module: torch.nn.Module,
) torch.Tensor#

Delegate to appropriate mapping based on module type.

megatron_to_hf(
megatron_weights: Optional[torch.Tensor],
megatron_module: Optional[torch.nn.Module],
) Dict[str, torch.Tensor]#

Delegate to appropriate mapping based on module type.

class bridge.models.conversion.param_mapping.QKVMapping(megatron_param: str, q: str, k: str, v: str)#

Bases: bridge.models.conversion.param_mapping.MegatronParamMapping[typing.Dict[str, torch.Tensor]]

Mapping for interleaved Query/Key/Value attention projection weights.

This mapping handles the conversion between separate Q, K, V matrices used in standard transformers and Megatron’s optimized interleaved format. The interleaving pattern groups queries with their corresponding key-value pairs to maximize GEMM efficiency during attention computation.

External format (HuggingFace)

  • Separate tensors: q_proj, k_proj, v_proj

  • Each of shape [hidden_size, hidden_size] or [head_dim * num_heads, hidden_size]

Megatron format

  • Single interleaved tensor following grouped query attention (GQA) pattern

  • Interleaving order: [q1...qn, k1, v1, qn+1...q2n, k2, v2, ...]

  • Where n = num_attention_heads / num_query_groups

Key features

  1. Format conversion: Handles merging/splitting with proper interleaving

  2. Grouped Query Attention: Supports different numbers of Q and KV heads

  3. Tensor parallelism: Delegates to AutoMapping for distribution

.. rubric:: Example

.. code-block:: python

# Create mapping for attention weights
mapping = QKVMapping(
    megatron_param="decoder.layers.*.self_attention.linear_qkv.weight",
    q="model.layers.*.self_attn.q_proj.weight",
    k="model.layers.*.self_attn.k_proj.weight",
    v="model.layers.*.self_attn.v_proj.weight"
)

# Convert from HuggingFace to Megatron
qkv_weights = {"q": q_tensor, "k": k_tensor, "v": v_tensor}
megatron_qkv = mapping.hf_to_megatron(qkv_weights, megatron_module)

# Convert from Megatron to HuggingFace
hf_weights = mapping.megatron_to_hf(megatron_qkv, megatron_module)
# Returns: {"q_proj.weight": ..., "k_proj.weight": ..., "v_proj.weight": ...}

.. note::

This mapping automatically handles both regular multi-head attention (same number of Q, K, V heads) and grouped query attention (fewer KV heads than Q heads) based on the model configuration.

Initialization

Initialize QKV mapping.

Parameters:
  • megatron_param (str) – Megatron QKV parameter name pattern.

  • q (str) – Query weight name pattern.

  • k (str) – Key weight name pattern.

  • v (str) – Value weight name pattern.

hf_to_megatron(
hf_weights: Dict[str, torch.Tensor],
megatron_module: torch.nn.Module,
) torch.Tensor#

Merge Q, K, V into interleaved format and distribute.

megatron_to_hf(
megatron_weights: Optional[torch.Tensor],
megatron_module: Optional[torch.nn.Module],
) Dict[str, torch.Tensor]#

Gather QKV shards and split into Q, K, V.

resolve(
captures: Tuple[str, ...],
) bridge.models.conversion.param_mapping.MegatronParamMapping#

Return a new resolved QKVMapping instance.

class bridge.models.conversion.param_mapping.GatedMLPMapping(megatron_param: str, gate: str, up: str)#

Bases: bridge.models.conversion.param_mapping.MegatronParamMapping[typing.Dict[str, torch.Tensor]]

Mapping for gated-MLP projection weights (SwiGLU / GeGLU).

Checkpoint formats expose two independent matrices:

  • G – gate projection

  • U – up projection

Megatron concatenates them row-wise ([G; U]) so that a single GEMM can produce both activations.

Responsibilities handled by this mapping

  1. Concatenate / split – convert between [G; U] (Megatron) and the separate {G, U} matrices (external).

  2. Tensor-parallel distribution – correctly splits gate and up projections separately before concatenating corresponding shards, ensuring each TP rank gets the proper [gate_shard; up_shard] format.

TP distribution strategy

For tensor parallelism, this mapping:

  • Splits gate and up matrices separately along output dimension (dim 0)

  • Concatenates corresponding shards: [gate_shard_i; up_shard_i] for rank i

  • This ensures each rank’s concatenated tensor matches the expected shape
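
.. rubric:: Example

A hedged sketch with hypothetical HuggingFace parameter names; the input dictionary keys mirror the QKVMapping example and are assumptions.

.. code-block:: python

mapping = GatedMLPMapping(
    megatron_param="decoder.layers.*.mlp.linear_fc1.weight",
    gate="model.layers.*.mlp.gate_proj.weight",
    up="model.layers.*.mlp.up_proj.weight",
)

# HF -> Megatron: gate and up are split per rank, then concatenated as [G; U]
fc1_weight = mapping.hf_to_megatron({"gate": gate_tensor, "up": up_tensor}, megatron_module)

# Megatron -> HF: returns the separate matrices under the external names
hf_weights = mapping.megatron_to_hf(fc1_weight, megatron_module)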

Initialization

Initialize gated MLP mapping.

Parameters:
  • megatron_param (str) – Megatron MLP parameter name pattern.

  • gate (str) – Gate projection weight name pattern.

  • up (str) – Up projection weight name pattern.

hf_to_megatron(
hf_weights: Dict[str, torch.Tensor],
megatron_module: torch.nn.Module,
) torch.Tensor#

Split gate and up separately, then concatenate corresponding shards.

megatron_to_hf(
megatron_weights: Optional[torch.Tensor],
megatron_module: Optional[torch.nn.Module],
) Dict[str, torch.Tensor]#

Gather concatenated shards and split into gate and up.

resolve(
captures: Tuple[str, ...],
) bridge.models.conversion.param_mapping.MegatronParamMapping#

Return a new resolved GatedMLPMapping instance.

bridge.models.conversion.param_mapping.merge_qkv_biases(
config: megatron.core.transformer.transformer_config.TransformerConfig,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
) torch.Tensor#

Merge separate Q, K, V bias vectors into Megatron’s interleaved QKV format.

Parameters:
  • config (TransformerConfig) – Transformer configuration.

  • q (torch.Tensor) – Query projection biases [hidden_size].

  • k (torch.Tensor) – Key projection biases [kv_hidden_size].

  • v (torch.Tensor) – Value projection biases [kv_hidden_size].

Returns:

Interleaved QKV biases in Megatron format as 1D tensor.

Return type:

torch.Tensor

bridge.models.conversion.param_mapping.split_qkv_biases(
config: megatron.core.transformer.transformer_config.TransformerConfig,
qkv: torch.Tensor,
) Tuple[torch.Tensor, torch.Tensor, torch.Tensor]#

Split Megatron’s interleaved QKV bias into separate Q, K, V biases.

Parameters:
  • config (TransformerConfig) – Transformer configuration.

  • qkv (torch.Tensor) – Interleaved QKV biases in Megatron format (1D tensor).

Returns:

Tuple of (Q, K, V) bias vectors.

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

bridge.models.conversion.param_mapping.merge_qkv_weights(
provider: megatron.core.transformer.transformer_config.TransformerConfig,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
) torch.Tensor#

Merge separate Q, K, V weight matrices into Megatron’s interleaved QKV format.

Parameters:
  • provider (TransformerConfig) – Model configuration provider.

  • q (torch.Tensor) – Query projection weights [hidden_size, hidden_size] or bias [hidden_size].

  • k (torch.Tensor) – Key projection weights [kv_hidden_size, hidden_size] or bias [kv_hidden_size].

  • v (torch.Tensor) – Value projection weights [kv_hidden_size, hidden_size] or bias [kv_hidden_size].

Returns:

Interleaved QKV weights in Megatron format.

Return type:

torch.Tensor

bridge.models.conversion.param_mapping.split_qkv_weights(
provider: megatron.core.transformer.transformer_config.TransformerConfig,
qkv: torch.Tensor,
) Tuple[torch.Tensor, torch.Tensor, torch.Tensor]#

Split Megatron’s interleaved QKV tensor into separate Q, K, V matrices.

Parameters:
  • provider (TransformerConfig) – Model configuration provider.

  • qkv (torch.Tensor) – Interleaved QKV weights in Megatron format.

Returns:

Tuple of (Q, K, V) weight matrices.

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
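
As an illustration of the round trip (assuming config is a TransformerConfig whose attention geometry matches the tensors; the bias helpers follow the same pattern with 1D tensors):

.. code-block:: python

# Merge separate HF projections into Megatron's interleaved layout, then split back.
qkv_weight = merge_qkv_weights(config, q_proj_weight, k_proj_weight, v_proj_weight)
q_weight, k_weight, v_weight = split_qkv_weights(config, qkv_weight)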