bridge.models.conversion.param_mapping#
Module Contents#
Classes#
MegatronParamMapping – Abstract base class for weight conversion between Megatron and external formats.
DirectMapping – Direct 1:1 weight mapping with no transformation or tensor parallelism.
ColumnParallelMapping – Mapping for column-parallel linear and embedding weights.
RowParallelMapping – Mapping for row-parallel linear weights.
ReplicatedMapping – Mapping for weights that are fully replicated across TP ranks.
AutoMapping – Smart mapping that automatically detects and applies the correct parallelism strategy.
QKVMapping – Mapping for interleaved Query/Key/Value attention projection weights.
GatedMLPMapping – Mapping for gated-MLP projection weights (SwiGLU / GeGLU).
Functions#
merge_qkv_biases – Merge separate Q, K, V bias vectors into Megatron’s interleaved QKV format.
split_qkv_biases – Split Megatron’s interleaved QKV bias into separate Q, K, V biases.
merge_qkv_weights – Merge separate Q, K, V weight matrices into Megatron’s interleaved QKV format.
split_qkv_weights – Split Megatron’s interleaved QKV tensor into separate Q, K, V matrices.
Data#
API#
- bridge.models.conversion.param_mapping.WeightType#
‘TypeVar(…)’
- bridge.models.conversion.param_mapping.logger#
‘getLogger(…)’
- class bridge.models.conversion.param_mapping.MegatronParamMapping(
- megatron_param: str,
- hf_param: Union[str, Dict[str, str]],
Bases: abc.ABC, typing.Generic[bridge.models.conversion.param_mapping.WeightType]
Abstract base class for weight conversion between Megatron and external formats.
This class provides the foundation for all weight mappings, handling the complex conversions between Megatron-Core’s distributed tensor formats and standard (typically HuggingFace) formats. Each concrete mapping implements specific transformation logic while inheriting common parallel communication patterns.
Key responsibilities:
Format transformation (e.g., QKV merging/splitting, gated MLP handling)
Tensor parallel (TP) distribution and gathering across GPUs
Pipeline parallel (PP) broadcasting between pipeline stages
Wildcard pattern resolution for layer-wise mappings
The mapping abstraction ensures that higher-level code doesn’t need to know about the parallel topology or format differences - it just requests a conversion and the mapping handles all the complexity.
Public helper methods for subclasses:
broadcast_from_pp_rank: Broadcast tensors across pipeline stages
broadcast_obj_from_pp_rank: Broadcast Python objects across PP ranks
broadcast_tensor_to_tp_ranks: Broadcast within TP group
scatter_to_tp_ranks: Distribute tensor shards to TP ranks
gather_from_tp_ranks: Collect tensor shards from TP ranks
.. rubric:: Example
.. code-block:: python
   class MyCustomMapping(MegatronParamMapping[torch.Tensor]):
       def hf_to_megatron(self, hf_weights, megatron_module):
           # Custom transformation logic
           transformed = hf_weights.t()  # Example: transpose
           # Use helpers for distribution
           return self.scatter_to_tp_ranks(...)

       def megatron_to_hf(self, megatron_weights, megatron_module):
           # Broadcast from owning PP rank
           weight = self.broadcast_from_pp_rank(megatron_weights)
           # Gather from TP ranks and transform
           gathered = self.gather_from_tp_ranks(weight)
           return {"custom_weight": gathered[0].t()}
Initialization
Initialize the weight mapping.
- Parameters:
megatron_param (str) – Megatron parameter name pattern (supports * wildcards).
hf_param (Union[str, Dict[str, str]]) – External format name pattern(s).
- property tp_group#
Get the tensor model parallel group.
- property tp_rank: int#
Get the tensor model parallel rank.
- property tp_size: int#
Get the tensor model parallel size.
- property pp_rank: int#
Get the pipeline model parallel rank.
- property pp_size: int#
Get the pipeline model parallel size.
- property ep_rank: int#
Get the expert model parallel rank.
- property ep_size: int#
Get the expert model parallel size.
- property etp_rank: int#
Get the expert tensor parallel rank.
- property etp_size: int#
Get the expert tensor parallel size.
- property is_expert: bool#
Check if this mapping is for an expert parameter.
- _resolve_names(
- captures: Tuple[str, ...],
Resolve wildcard patterns with captured values.
Handles both ** (any characters) and * (digits) wildcards in order. ** patterns are processed before * patterns to avoid conflicts.
- resolve(
- captures: Tuple[str, ...],
Create a new mapping with resolved wildcards.
This default implementation works for mappings with a (megatron_param, hf_param) constructor.
- Parameters:
captures (Tuple[str, ...]) – Captured wildcard values.
- Returns:
A new mapping instance with resolved names.
- Return type:
MegatronParamMapping
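A small usage sketch of wildcard resolution; the layer index is hypothetical, and exposing the resolved patterns as megatron_param / hf_param attributes is an assumption made only for illustration:
.. code-block:: python

   mapping = AutoMapping(
       megatron_param="decoder.layers.*.mlp.linear_fc1.weight",
       hf_param="model.layers.*.mlp.gate_proj.weight",
   )
   resolved = mapping.resolve(("3",))
   # Assumed attributes, shown for illustration only:
   # resolved.megatron_param == "decoder.layers.3.mlp.linear_fc1.weight"
   # resolved.hf_param == "model.layers.3.mlp.gate_proj.weight"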
- abstractmethod hf_to_megatron(
- hf_weights: bridge.models.conversion.param_mapping.WeightType,
- megatron_module: torch.nn.Module,
Convert hf_weights TO Megatron format.
This method handles:
Format transformation (if needed)
Tensor parallel distribution (if self.tp_size > 1)
- Parameters:
hf_weights (WeightType) – Source hf_weights in external format.
megatron_module (nn.Module) – Target Megatron module (for config access).
- Returns:
Weight tensor ready for the current TP rank.
- Return type:
torch.Tensor
- abstractmethod megatron_to_hf(
- megatron_weights: Optional[torch.Tensor],
- megatron_module: Optional[torch.nn.Module],
Convert weights FROM Megatron format.
This method handles:
Pipeline parallel broadcasting (if weight is on different PP rank)
Tensor parallel gathering (if needed)
Format transformation
- Parameters:
megatron_weights (Optional[torch.Tensor]) – Weight tensor from current rank (None if on different PP rank).
megatron_module (Optional[nn.Module]) – Module for config access (None if on different PP rank).
- Returns:
Converted weights (empty dict if not on TP rank 0).
- Return type:
Dict[str, torch.Tensor]
- broadcast_from_pp_rank(
- tensor: Optional[torch.Tensor],
- cache_key: Optional[str] = None,
Broadcast a tensor from the pipeline-parallel rank that owns it.
Broadcasts to all PP ranks. This mirrors the behaviour of broadcast_from_megatron_pp in the original MMapping implementation and additionally keeps the tensor-parallel metadata (tensor_model_parallel, partition_dim) consistent on every rank.
- Parameters:
tensor (Optional[torch.Tensor]) – The local tensor if the current PP rank owns it, None otherwise.
- Returns:
The broadcasted tensor on every PP rank, or None if no PP rank owned the tensor (which indicates a bug in the calling code).
- Return type:
Optional[torch.Tensor]
- broadcast_obj_from_pp_rank(
- obj: Optional[Any],
- cache_key: Optional[str] = None,
Broadcast any Python object from the PP rank that owns it.
This method is useful for broadcasting configuration objects or other metadata across pipeline parallel ranks. Results are cached after the first call to avoid redundant broadcasts.
- Parameters:
obj (Optional[Any]) – Object to broadcast (None on non-owning ranks).
cache_key (Optional[str]) – Optional cache key. If not provided, no caching will be performed.
- Returns:
Broadcasted object on all ranks.
- Return type:
Any
- Raises:
ValueError – If the object exists on multiple PP ranks or on none.
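A minimal usage sketch inside a custom mapping subclass; the cache key name is arbitrary and only illustrates the caching behaviour described above:
.. code-block:: python

   # Sketch of a megatron_to_hf override in a custom mapping (illustrative only).
   def megatron_to_hf(self, megatron_weights, megatron_module):
       # Only the owning PP rank passes a non-None config; every PP rank gets it back,
       # and repeated calls with the same cache_key reuse the cached result.
       config = self._get_config(megatron_module) if megatron_module is not None else None
       config = self.broadcast_obj_from_pp_rank(config, cache_key="transformer_config")
       weight = self.broadcast_from_pp_rank(megatron_weights)
       ...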
- clear_broadcast_cache()#
Clear the broadcast object cache.
This can be useful for testing or if the objects being broadcast might change during the lifetime of the mapping.
- clear_tensor_spec_output_cache()#
Clear the tensor spec output cache.
This can be useful for testing or if the tensor spec output might change during the lifetime of the mapping.
- broadcast_tensor_to_tp_ranks(
- tensor: torch.Tensor,
- src_rank: int = 0,
Broadcast a tensor to all TP ranks.
- Parameters:
tensor (torch.Tensor) – The tensor to broadcast.
src_rank (int, optional) – The source rank within the TP group. Defaults to 0.
- Returns:
The broadcasted tensor.
- Return type:
torch.Tensor
- scatter_to_tp_ranks(
- splits: Optional[List[torch.Tensor]],
- output_shape: torch.Size,
- dtype: torch.dtype,
- device: torch.device,
- src_rank: int = 0,
Scatter tensor splits to TP ranks.
- Parameters:
splits (Optional[List[torch.Tensor]]) – A list of tensor shards to scatter. Only rank src_rank needs this.
output_shape (torch.Size) – The shape of the output tensor on each rank.
dtype (torch.dtype) – The data type of the output tensor.
device (torch.device) – The device for the output tensor.
src_rank (int, optional) – The source rank for the scatter operation. Defaults to 0.
- Returns:
The scattered tensor shard on the current rank.
- Return type:
torch.Tensor
- gather_from_tp_ranks(
- tensor: torch.Tensor,
Gather tensors from all TP ranks.
- Parameters:
tensor (torch.Tensor) – The tensor shard to be gathered from the current rank.
- Returns:
A list of tensor shards from all TP ranks.
- Return type:
List[torch.Tensor]
- _count_wildcard_groups(pattern: str) → int#
Count the number of wildcard capture groups in a pattern.
- Parameters:
pattern – Pattern string with * and ** wildcards
- Returns:
Number of capture groups that will be generated
.. note::
** counts as 1 group, * counts as 1 group. ** must be counted before * to avoid double-counting.
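A short sketch of this counting rule (not the module's actual implementation):
.. code-block:: python

   def count_wildcard_groups(pattern: str) -> int:
       double = pattern.count("**")                    # each ** is one capture group
       single = pattern.replace("**", "").count("*")   # count * only after removing **
       return double + single

   assert count_wildcard_groups("decoder.layers.*.mlp.**") == 2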
- _validate_patterns()#
Validate wildcard consistency between patterns.
- _normalize_expert_param_name(param_name: str) → str#
Normalize expert parameter name by replacing trailing numbers with 0. e.g. experts.weight15 -> experts.weight0, experts.bias15 -> experts.bias0
- Parameters:
param_name (str) – Parameter name that may end with a number.
- Returns:
Parameter name with trailing number replaced by 0.
- Return type:
str
- _get_config(module: torch.nn.Module) → Any#
Extract configuration from module hierarchy.
- gather_from_ep_ranks(
- megatron_weights: Optional[torch.Tensor],
- megatron_module: Optional[megatron.core.transformer.module.MegatronModule],
- hf_param_name: Optional[str],
Handle expert parallel weight gathering for MoE models.
This method gathers expert weights across expert-parallel (EP) ranks and returns a mapping from HF parameter names to the corresponding tensors from each EP rank. Call this only for confirmed expert parameters (self.is_expert is True), typically after TP gathering/concatenation in the export path (Megatron → HF).
Behavior and notation:
Let E be the total number of experts (e.g., config.num_moe_experts) and S be the expert-parallel size (ep_size). We assume E % S == 0.
Each EP rank owns E/S experts. For a given parameter name, we infer a local expert index L (0 ≤ L < E/S) on the current EP rank from the global expert id embedded in the name (works for both .weight and .bias).
The set of global expert ids that correspond to this local index L across all EP ranks is: {L + k * (E/S) | k ∈ [0, S-1]}.
Communication and outputs:
We perform an all_gather over the EP group to collect the tensor from every EP rank into a list ordered by EP rank id.
For each EP rank k, we construct the HF parameter name by replacing the expert id in
hf_param_name
with (L + k * (E/S)), preserving the rest of the path, and map that name to the gathered tensor from rank k.
Example:
E = 8, S = 2 → E/S = 4. Experts are distributed as: Rank 0: [0, 1, 2, 3], Rank 1: [4, 5, 6, 7]. If the local index L = 0 (derived from the param name), this returns: {"…experts.0.weight": tensor_from_rank0, "…experts.4.weight": tensor_from_rank1}
- Parameters:
megatron_weights (Optional[torch.Tensor]) – The local expert weight tensor (after any TP handling) on this EP rank.
megatron_module (Optional[MegatronModule]) – The Megatron module containing configuration (used to determine E and E/S). Can be None on non-owning PP ranks; values will be broadcast across PP.
hf_param_name (Optional[str]) – HF parameter name template for the current (local) expert on this rank. The expert id within this string is replaced with the appropriate global expert ids for each EP rank.
- Returns:
Mapping from HF parameter names (one per EP rank) to the corresponding expert tensors gathered from each EP rank.
- Return type:
Dict[str, torch.Tensor]
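The expert-id arithmetic from the example above, as a standalone sketch (values are illustrative):
.. code-block:: python

   E, S = 8, 2                    # total experts, expert-parallel size
   per_rank = E // S              # experts owned by each EP rank (E/S = 4)
   local_idx = 0                  # local expert index L inferred from the param name
   global_ids = [local_idx + k * per_rank for k in range(S)]
   # global_ids == [0, 4]: rank 0 contributes expert 0, rank 1 contributes expert 4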
- maybe_dequantize(tensor: torch.Tensor) → torch.Tensor#
Dequantize FP8 tensor if needed.
- class bridge.models.conversion.param_mapping.DirectMapping#
Bases: bridge.models.conversion.param_mapping.MegatronParamMapping[torch.Tensor]
Direct 1:1 weight mapping with no transformation or tensor parallelism.
- hf_to_megatron(
- hf_weights: torch.Tensor,
- megatron_module: torch.nn.Module,
Direct copy - no transformation or distribution.
- megatron_to_hf(
- megatron_weights: Optional[torch.Tensor],
- megatron_module: Optional[torch.nn.Module],
Direct copy with PP broadcast.
- class bridge.models.conversion.param_mapping.ColumnParallelMapping#
Bases: bridge.models.conversion.param_mapping.MegatronParamMapping[torch.Tensor]
Mapping for column-parallel linear and embedding weights.
Column-parallel layers in Megatron split the output dimension across tensor parallel ranks. This is used for layers where each rank computes a portion of the output features independently, such as:
Embedding layers (split vocabulary)
Linear layers producing hidden states (e.g., QKV projections, MLP up projections)
The weight matrix is partitioned along dimension 0 (rows), so each TP rank holds a subset of output features while maintaining all input features.
Sharding pattern
Original weight: [output_features, input_features]
Rank 0: [output_features/tp_size, input_features]
Rank 1: [output_features/tp_size, input_features]
…
Forward path (HuggingFace → Megatron)
Validate divisibility: output dimension must be divisible by tp_size
Split: Chunk tensor along dim 0 into tp_size equal parts
Scatter: Distribute chunks to respective TP ranks
Reverse path (Megatron → HuggingFace)
Broadcast: Ensure all PP ranks have the tensor
Gather: Collect chunks from all TP ranks
Concatenate: Reassemble along dim 0 on rank 0
.. rubric:: Example
.. code-block:: python
   # For a weight of shape [4096, 1024] with tp_size=4:
   # Each rank gets [1024, 1024] after column-parallel split
   mapping = ColumnParallelMapping("linear.weight", "transformer.linear.weight")
   megatron_weights = mapping.hf_to_megatron(hf_weight, megatron_module)
   # megatron_weights.shape = [1024, 1024] on each rank
.. note::
This mapping also handles bias terms, which are 1D tensors split along their only dimension following the same pattern.
- hf_to_megatron(
- hf_weights: torch.Tensor,
- megatron_module: torch.nn.Module,
Split weight along dim 0 and distribute to TP ranks.
- megatron_to_hf(
- megatron_weights: Optional[torch.Tensor],
- megatron_module: Optional[torch.nn.Module],
Gather from all TP ranks and concatenate.
- class bridge.models.conversion.param_mapping.RowParallelMapping#
Bases: bridge.models.conversion.param_mapping.MegatronParamMapping[torch.Tensor]
Mapping for row-parallel linear weights.
Megatron shards row-parallel tensors along dimension 1 (the input dimension of a linear layer).
Forward path (external → Megatron)
Rank 0 validates that the second dimension is divisible by tp_size.
Rank 0 splits the tensor with torch.chunk(..., dim=1), producing tp_size equally-sized shards.
The shards are scattered so that every TP rank receives exactly one shard matching the shape of its local Megatron parameter.
Reverse path (Megatron → external)
The local Megatron parameter (which may live on any PP rank) is broadcast to all PP ranks so that the gather step can be collective.
All TP ranks gather their shard.
Rank 0 concatenates the gathered list along dim 1 to reconstruct the original unsharded weight and emits it under the external (HF) name.
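A minimal sketch of the dim-1 sharding round trip described above, assuming tp_size = 4 and illustrative shapes (the actual scatter/gather collectives are handled by the mapping's helper methods):
.. code-block:: python

   import torch

   tp_size = 4
   full_weight = torch.randn(1024, 4096)              # [output_features, input_features]
   assert full_weight.shape[1] % tp_size == 0
   shards = torch.chunk(full_weight, tp_size, dim=1)  # each shard: [1024, 1024]

   # Export direction: concatenating the gathered shards along dim 1 restores the original.
   restored = torch.cat(shards, dim=1)
   assert torch.equal(restored, full_weight)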
- hf_to_megatron(
- hf_weights: torch.Tensor,
- megatron_module: torch.nn.Module,
Split weight along dim 1 and distribute to TP ranks.
- megatron_to_hf(
- megatron_weights: Optional[torch.Tensor],
- megatron_module: Optional[torch.nn.Module],
Gather from all TP ranks and concatenate.
- class bridge.models.conversion.param_mapping.ReplicatedMapping#
Bases: bridge.models.conversion.param_mapping.MegatronParamMapping[torch.Tensor]
Mapping for weights that are fully replicated across TP ranks.
Examples: layer-norm scales, biases, router weights in MoE, etc.
These tensors exist in exactly the same form on every TP rank, so the mapping logic is trivial – but we still need to broadcast across TP ranks during load (HF → Megatron) and ensure we do not emit duplicates during export (Megatron → HF).
- hf_to_megatron(
- hf_weights: torch.Tensor,
- megatron_module: torch.nn.Module,
Replicate weight to all TP ranks.
- megatron_to_hf(
- megatron_weights: Optional[torch.Tensor],
- megatron_module: Optional[torch.nn.Module],
Return weight only from rank 0 to avoid duplication.
- class bridge.models.conversion.param_mapping.AutoMapping(megatron_param: str, hf_param: str)#
Bases: bridge.models.conversion.param_mapping.MegatronParamMapping[torch.Tensor]
Smart mapping that automatically detects and applies the correct parallelism strategy.
This mapping eliminates the need to manually specify whether a layer is column-parallel, row-parallel, or replicated. It examines the Megatron module at runtime and delegates to the appropriate specialized mapping.
Detection strategy
Check module class name against a registry of known types
If unknown, examine module attributes (tensor_model_parallel, partition_dim)
Delegate to appropriate mapping: ColumnParallel, RowParallel, or Replicated
This abstraction is particularly useful for model-agnostic code where you don’t know the parallelism type ahead of time, or when working with models that mix different parallelism strategies.
Built-in module recognition
Column-parallel: ColumnParallelLinear, VocabParallelEmbedding, etc.
Row-parallel: RowParallelLinear, TERowParallelLinear
Replicated: LayerNorm, RMSNorm, and other normalization layers
.. rubric:: Example
.. code-block:: python
   # Automatically handles any weight type
   mapping = AutoMapping(
       megatron_param="decoder.layers.*.mlp.linear_fc1.weight",
       hf_param="model.layers.*.mlp.gate_proj.weight"
   )

   # Works with column-parallel layers
   megatron_weights = mapping.hf_to_megatron(hf_weight, column_parallel_module)

   # Also works with normalization layers
   norm_weight = mapping.hf_to_megatron(hf_norm, layer_norm_module)

   # Register custom module types
   AutoMapping.register_module_type("MyCustomLinear", "column")
.. note::
If the parallelism type cannot be determined, the mapping will raise a descriptive error suggesting how to fix the issue.
Initialization
Initialize TP-aware mapping.
- _MODULE_TYPE_REGISTRY: Dict[str, set]#
None
- classmethod register_module_type(module_name: str, parallelism_type: str)#
Register a new module type for automatic parallelism detection.
- Parameters:
module_name (str) – The name of the module class (e.g., ‘MyColumnLinear’).
parallelism_type (str) – One of ‘column’, ‘row’, or ‘replicated’.
- _get_or_create_mapping(
- parallelism_type: str,
Get or create the appropriate mapping for the given type.
- _detect_parallelism_type(module: torch.nn.Module) → str#
Detect parallelism type from module.
- hf_to_megatron(
- hf_weights: torch.Tensor,
- megatron_module: torch.nn.Module,
Delegate to appropriate mapping based on module type.
- megatron_to_hf(
- megatron_weights: Optional[torch.Tensor],
- megatron_module: Optional[torch.nn.Module],
Delegate to appropriate mapping based on module type.
- class bridge.models.conversion.param_mapping.QKVMapping(megatron_param: str, q: str, k: str, v: str)#
Bases: bridge.models.conversion.param_mapping.MegatronParamMapping[typing.Dict[str, torch.Tensor]]
Mapping for interleaved Query/Key/Value attention projection weights.
This mapping handles the conversion between separate Q, K, V matrices used in standard transformers and Megatron’s optimized interleaved format. The interleaving pattern groups queries with their corresponding key-value pairs to maximize GEMM efficiency during attention computation.
External format (HuggingFace)
Separate tensors: q_proj, k_proj, v_proj
Each of shape [hidden_size, hidden_size] or [hidden_size, head_dim * num_heads]
Megatron format
Single interleaved tensor following grouped query attention (GQA) pattern
Interleaving order: [q1...qn, k1, v1, q1...qn, k2, v2, ...]
Where n = num_attention_heads / num_query_groups
Key features
Format conversion: Handles merging/splitting with proper interleaving
Grouped Query Attention: Supports different numbers of Q and KV heads
Tensor parallelism: Delegates to AutoMapping for distribution
.. rubric:: Example
.. code-block:: python
   # Create mapping for attention weights
   mapping = QKVMapping(
       megatron_param="decoder.layers.*.self_attention.linear_qkv.weight",
       q="model.layers.*.self_attn.q_proj.weight",
       k="model.layers.*.self_attn.k_proj.weight",
       v="model.layers.*.self_attn.v_proj.weight"
   )

   # Convert from HuggingFace to Megatron
   qkv_weights = {"q": q_tensor, "k": k_tensor, "v": v_tensor}
   megatron_qkv = mapping.hf_to_megatron(qkv_weights, megatron_module)

   # Convert from Megatron to HuggingFace
   hf_weights = mapping.megatron_to_hf(megatron_qkv, megatron_module)
   # Returns: {"q_proj.weight": ..., "k_proj.weight": ..., "v_proj.weight": ...}
.. note::
This mapping automatically handles both regular multi-head attention (same number of Q, K, V heads) and grouped query attention (fewer KV heads than Q heads) based on the model configuration.
Initialization
Initialize QKV mapping.
- Parameters:
megatron_param (str) – Megatron QKV parameter name pattern.
q (str) – Query weight name pattern.
k (str) – Key weight name pattern.
v (str) – Value weight name pattern.
- hf_to_megatron(
- hf_weights: Dict[str, torch.Tensor],
- megatron_module: torch.nn.Module,
Merge Q, K, V into interleaved format and distribute.
- megatron_to_hf(
- megatron_weights: Optional[torch.Tensor],
- megatron_module: Optional[torch.nn.Module],
Gather QKV shards and split into Q, K, V.
- resolve(
- captures: Tuple[str, ...],
Return a new resolved QKVMapping instance.
- class bridge.models.conversion.param_mapping.GatedMLPMapping(megatron_param: str, gate: str, up: str)#
Bases: bridge.models.conversion.param_mapping.MegatronParamMapping[typing.Dict[str, torch.Tensor]]
Mapping for gated-MLP projection weights (SwiGLU / GeGLU).
Checkpoint formats expose two independent matrices:
G – gate projection
U – up projection
Megatron concatenates them row-wise ([G; U]) so that a single GEMM can produce both activations.
Responsibilities handled by this mapping
Concatenate / split – convert between [G; U] (Megatron) and the separate {G, U} matrices (external).
Tensor-parallel distribution – correctly splits gate and up projections separately before concatenating corresponding shards, ensuring each TP rank gets the proper [gate_shard; up_shard] format.
TP Distribution Strategy
For tensor parallelism, this mapping:
Splits gate and up matrices separately along output dimension (dim 0)
Concatenates corresponding shards: [gate_shard_i; up_shard_i] for rank i
This ensures each rank’s concatenated tensor matches the expected shape
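A minimal sketch of this distribution strategy, assuming tp_size = 2 and illustrative shapes (the actual scatter/gather collectives are handled by the mapping):
.. code-block:: python

   import torch

   tp_size = 2
   gate = torch.randn(8192, 4096)   # G: [ffn_hidden_size, hidden_size]
   up = torch.randn(8192, 4096)     # U: [ffn_hidden_size, hidden_size]

   gate_shards = torch.chunk(gate, tp_size, dim=0)
   up_shards = torch.chunk(up, tp_size, dim=0)

   # Each TP rank i receives [gate_shard_i; up_shard_i], not a naive chunk of [G; U].
   per_rank = [torch.cat([g, u], dim=0) for g, u in zip(gate_shards, up_shards)]
   assert per_rank[0].shape == (2 * 8192 // tp_size, 4096)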
Initialization
Initialize gated MLP mapping.
- Parameters:
megatron_param (str) – Megatron MLP parameter name pattern.
gate (str) – Gate projection weight name pattern.
up (str) – Up projection weight name pattern.
- hf_to_megatron(
- hf_weights: Dict[str, torch.Tensor],
- megatron_module: torch.nn.Module,
Split gate and up separately, then concatenate corresponding shards.
- megatron_to_hf(
- megatron_weights: Optional[torch.Tensor],
- megatron_module: Optional[torch.nn.Module],
Gather concatenated shards and split into gate and up.
- resolve(
- captures: Tuple[str, ...],
Return a new resolved GatedMLPMapping instance.
- bridge.models.conversion.param_mapping.merge_qkv_biases(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- q: torch.Tensor,
- k: torch.Tensor,
- v: torch.Tensor,
Merge separate Q, K, V bias vectors into Megatron’s interleaved QKV format.
- Parameters:
config (TransformerConfig) – Transformer configuration.
q (torch.Tensor) – Query projection biases [hidden_size].
k (torch.Tensor) – Key projection biases [kv_hidden_size].
v (torch.Tensor) – Value projection biases [kv_hidden_size].
- Returns:
Interleaved QKV biases in Megatron format as 1D tensor.
- Return type:
torch.Tensor
- bridge.models.conversion.param_mapping.split_qkv_biases(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- qkv: torch.Tensor,
Split Megatron’s interleaved QKV bias into separate Q, K, V biases.
- Parameters:
config (TransformerConfig) – Transformer configuration.
qkv (torch.Tensor) – Interleaved QKV biases in Megatron format (1D tensor).
- Returns:
Tuple of (Q, K, V) bias vectors.
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
- bridge.models.conversion.param_mapping.merge_qkv_weights(
- provider: megatron.core.transformer.transformer_config.TransformerConfig,
- q: torch.Tensor,
- k: torch.Tensor,
- v: torch.Tensor,
Merge separate Q, K, V weight matrices into Megatron’s interleaved QKV format.
- Parameters:
provider (TransformerConfig) – Model configuration provider.
q (torch.Tensor) – Query projection weights [hidden_size, hidden_size] or bias [hidden_size].
k (torch.Tensor) – Key projection weights [kv_hidden_size, hidden_size] or bias [kv_hidden_size].
v (torch.Tensor) – Value projection weights [kv_hidden_size, hidden_size] or bias [kv_hidden_size].
- Returns:
Interleaved QKV weights in Megatron format.
- Return type:
torch.Tensor
- bridge.models.conversion.param_mapping.split_qkv_weights(
- provider: megatron.core.transformer.transformer_config.TransformerConfig,
- qkv: torch.Tensor,
Split Megatron’s interleaved QKV tensor into separate Q, K, V matrices.
- Parameters:
provider (TransformerConfig) – Model configuration provider.
qkv (torch.Tensor) – Interleaved QKV weights in Megatron format.
- Returns:
Tuple of (Q, K, V) weight matrices.
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
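A hedged round-trip sketch of these helpers; the TransformerConfig fields shown (num_query_groups, kv_channels, and the illustrative sizes) are assumptions chosen to match the shapes documented above:
.. code-block:: python

   import torch
   from megatron.core.transformer.transformer_config import TransformerConfig
   from bridge.models.conversion.param_mapping import merge_qkv_weights, split_qkv_weights

   config = TransformerConfig(
       num_layers=2,
       hidden_size=1024,
       num_attention_heads=16,
       num_query_groups=4,   # GQA: fewer KV heads than Q heads
       kv_channels=64,
   )

   q = torch.randn(16 * 64, 1024)   # [hidden_size, hidden_size]
   k = torch.randn(4 * 64, 1024)    # [kv_hidden_size, hidden_size]
   v = torch.randn(4 * 64, 1024)    # [kv_hidden_size, hidden_size]

   qkv = merge_qkv_weights(config, q, k, v)      # interleaved Megatron layout
   q2, k2, v2 = split_qkv_weights(config, qkv)   # round trip back to separate matrices
   assert torch.equal(q, q2) and torch.equal(k, k2) and torch.equal(v, v2)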