nemo_automodel.components.models.common.combined_projection.state_dict_adapter#
Generic state dict adapter for models with combined projections.
This module provides a unified state dict converter that handles:
Separate q_proj, k_proj, v_proj <-> Combined qkv_proj
Separate gate_proj, up_proj <-> Combined gate_up_proj
Tied weights (lm_head <-> embed_tokens)
Works with any transformer model (Llama, Qwen2, etc.) that uses these projection patterns.
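The benefit of the combined layout can be sketched in a few lines (hypothetical small shapes; plain concatenation is used here for brevity, whereas the adapter's actual qkv layout is interleaved by KV-head groups, as described for _interleave_qkv below):

```python
import torch

# Hypothetical small dimensions for illustration only.
hidden, q_rows, kv_rows = 8, 8, 4
q = torch.randn(q_rows, hidden)
k = torch.randn(kv_rows, hidden)
v = torch.randn(kv_rows, hidden)
x = torch.randn(2, hidden)

# One fused matmul over the combined weight replaces three separate matmuls.
qkv = torch.cat([q, k, v], dim=0)
q_out, k_out, v_out = (x @ qkv.T).split([q_rows, kv_rows, kv_rows], dim=-1)
```

The split sizes follow directly from the per-projection row counts, which is why the adapter needs the head counts and hidden size from the config.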
Module Contents#
Classes#
CombinedProjectionStateDictAdapter: Generic adapter for converting between HF and combined-projection formats.
Data#
API#
- nemo_automodel.components.models.common.combined_projection.state_dict_adapter.logger#
'getLogger(...)'
- class nemo_automodel.components.models.common.combined_projection.state_dict_adapter.CombinedProjectionStateDictAdapter(config)#
Generic adapter for converting between HF and combined-projection formats.
Handles conversion of:
Separate q_proj, k_proj, v_proj <-> Combined qkv_proj
Separate gate_proj, up_proj <-> Combined gate_up_proj
Tied weights (lm_head <-> embed_tokens) for loading HF checkpoints
Works with any transformer model config that has:
num_hidden_layers
num_attention_heads
num_key_value_heads
hidden_size
- Parameters:
config – Model config (LlamaConfig, Qwen2Config, etc.)
.. rubric:: Example
For Llama#
from transformers import LlamaConfig
adapter = CombinedProjectionStateDictAdapter(LlamaConfig.from_pretrained("meta-llama/Llama-3-8B"))
For Qwen2#
from transformers import Qwen2Config
adapter = CombinedProjectionStateDictAdapter(Qwen2Config.from_pretrained("Qwen/Qwen2.5-7B"))
Initialization
Initialize the adapter with model config.
- static _gather_1d_if_needed(tensor: torch.Tensor, divisor: int)#
Gather a 1-D DTensor on dim 0 when the local shard isn’t divisible by divisor.
FSDP2’s shard_placement_fn only accepts Shard placements (not Replicate), so 1-D bias vectors of combined projections end up with Shard(0) even though their interleaved layout may not divide evenly across the FSDP shard count. This helper gathers such biases to full before reshape/split operations in the state-dict adapter. Weights are handled by FSDP with Shard(1), so they never need this.
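The gather condition can be sketched as follows (a hedged approximation: the real static method operates on torch.distributed DTensors inside an initialized device mesh; the duck-typed full_tensor check here is illustrative only):

```python
import torch

def gather_1d_if_needed(tensor: torch.Tensor, divisor: int) -> torch.Tensor:
    # Hedged sketch of the documented behavior for 1-D bias vectors.
    if hasattr(tensor, "full_tensor"):  # DTensor-like
        local = tensor.to_local()
        if local.shape[0] % divisor != 0:
            # Local shard does not align with the interleaved layout:
            # gather the full bias before any reshape/split.
            return tensor.full_tensor()
    return tensor  # plain tensors (or aligned shards) pass through unchanged
```

On a plain tensor the function is a no-op, which matches the adapter only needing the gather for misaligned DTensor shards.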
- _interleave_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor)#
Interleave Q, K, V by KV-head groups for TP-correct ColwiseParallel sharding.
Layout: [Q_group_0 | K_0 | V_0 | Q_group_1 | K_1 | V_1 | …] where each group has (group_size * head_dim) Q rows, head_dim K rows, head_dim V rows.
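The documented layout can be reproduced with a few reshape operations (a hypothetical re-implementation for full 2-D weights; the function name and signature here are illustrative, not the adapter's exact internals):

```python
import torch

def interleave_qkv(q, k, v, num_heads, num_kv_heads, head_dim):
    # Build [Q_group_0 | K_0 | V_0 | Q_group_1 | K_1 | V_1 | ...]:
    # each group carries (group_size * head_dim) Q rows, then head_dim K rows,
    # then head_dim V rows.
    group_size = num_heads // num_kv_heads
    qg = q.view(num_kv_heads, group_size * head_dim, -1)
    kg = k.view(num_kv_heads, head_dim, -1)
    vg = v.view(num_kv_heads, head_dim, -1)
    return torch.cat([qg, kg, vg], dim=1).flatten(0, 1)
```

Because every KV-head group is contiguous, a ColwiseParallel shard over dim 0 keeps each rank's Q rows paired with their own K/V rows.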
- _deinterleave_qkv(qkv: torch.Tensor)#
De-interleave QKV from KV-head-grouped layout back to separate Q, K, V.
Works for full (unsharded), TP-sharded, and FSDP-sharded tensors. When the tensor is a DTensor whose dim-0 local shard doesn’t align with the QKV group boundary, the tensor is gathered to full first.
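The inverse operation for the full (unsharded) case can be sketched as (hypothetical helper; the real method also handles TP- and FSDP-sharded DTensors, including the gather-to-full fallback described above):

```python
import torch

def deinterleave_qkv(qkv, num_heads, num_kv_heads, head_dim):
    # Undo the KV-head-grouped layout for a full 2-D weight.
    group_size = num_heads // num_kv_heads
    grouped = qkv.view(num_kv_heads, (group_size + 2) * head_dim, -1)
    q, k, v = grouped.split([group_size * head_dim, head_dim, head_dim], dim=1)
    return q.flatten(0, 1), k.flatten(0, 1), v.flatten(0, 1)
```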
- _interleave_gate_up(gate: torch.Tensor, up: torch.Tensor)#
Interleave gate and up row-by-row for TP-correct ColwiseParallel sharding.
- _deinterleave_gate_up(gate_up: torch.Tensor)#
De-interleave gate/up from row-interleaved layout.
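The row-by-row gate/up interleave and its inverse are simpler than the QKV case, since the two halves are equal-sized (a hypothetical sketch for full 2-D weights):

```python
import torch

def interleave_gate_up(gate, up):
    # Row-by-row interleave: [g0, u0, g1, u1, ...]
    return torch.stack([gate, up], dim=1).flatten(0, 1)

def deinterleave_gate_up(gate_up):
    # Undo the row-by-row interleave back to separate gate and up weights.
    rows = gate_up.shape[0] // 2
    grouped = gate_up.view(rows, 2, -1)
    return grouped[:, 0].clone(), grouped[:, 1].clone()
```

With each gate row immediately followed by its up row, any even-sized dim-0 shard contains matching gate/up pairs, which is what makes ColwiseParallel sharding TP-correct.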
- from_hf(hf_state_dict: dict[str, Any], **kwargs)#
Convert HuggingFace state dict to combined-projection format.
Converts separate Q/K/V and gate/up projections to combined projections. Also handles tied weights (lm_head <-> embed_tokens) by copying embed_tokens to lm_head if lm_head is missing (common in HF Qwen2 and Llama checkpoints).
- Parameters:
hf_state_dict – State dict from HuggingFace model
- Returns:
State dict in combined-projection format
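At the key level, the conversion looks roughly like this (a hedged sketch using plain concatenation for brevity; the real adapter interleaves the weights by KV-head groups, and `from_hf_sketch` is an illustrative name, not the API):

```python
import torch

def from_hf_sketch(hf_sd):
    out = dict(hf_sd)
    layer_ids = {k.split(".")[2] for k in hf_sd if k.startswith("model.layers.")}
    for i in layer_ids:
        attn, mlp = f"model.layers.{i}.self_attn", f"model.layers.{i}.mlp"
        # Combine separate q/k/v and gate/up weights under the combined keys.
        out[f"{attn}.qkv_proj.weight"] = torch.cat(
            [out.pop(f"{attn}.{n}_proj.weight") for n in ("q", "k", "v")], dim=0
        )
        out[f"{mlp}.gate_up_proj.weight"] = torch.cat(
            [out.pop(f"{mlp}.gate_proj.weight"), out.pop(f"{mlp}.up_proj.weight")],
            dim=0,
        )
    # Tied weights: HF Qwen2/Llama checkpoints often omit lm_head.weight.
    if "lm_head.weight" not in out and "model.embed_tokens.weight" in out:
        out["lm_head.weight"] = out["model.embed_tokens.weight"]
    return out
```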
- to_hf(state_dict: dict[str, Any], exclude_key_regex: Optional[str] = None, **kwargs)#
Convert combined-projection state dict to HuggingFace format.
Splits combined qkv_proj and gate_up_proj back to separate projections. Handles both full (unsharded) and TP-sharded tensors.
- Parameters:
state_dict – State dict from custom model (can be TP-sharded DTensors)
exclude_key_regex – Optional regex pattern to exclude keys
- Returns:
State dict in HuggingFace format
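The reverse direction, including the key-exclusion regex, can be sketched for full plain tensors as follows (hedged: q_rows / kv_rows are hypothetical parameters standing in for the config-derived split sizes, and the real method de-interleaves rather than plainly splitting):

```python
import re
import torch

def to_hf_sketch(sd, q_rows, kv_rows, exclude_key_regex=None):
    out = {}
    for key, t in sd.items():
        if exclude_key_regex and re.fullmatch(exclude_key_regex, key):
            continue  # drop keys matching the exclusion pattern
        if key.endswith("qkv_proj.weight"):
            q, k, v = t.split([q_rows, kv_rows, kv_rows], dim=0)
            for name, part in (("q", q), ("k", k), ("v", v)):
                out[key.replace("qkv_proj", f"{name}_proj")] = part
        elif key.endswith("gate_up_proj.weight"):
            gate, up = t.chunk(2, dim=0)  # equal halves along dim 0
            out[key.replace("gate_up_proj", "gate_proj")] = gate
            out[key.replace("gate_up_proj", "up_proj")] = up
        else:
            out[key] = t
    return out
```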
- _split_remaining_combined_projection_keys(hf_state_dict: dict[str, Any])#
Split any remaining combined-projection keys in-place.
Handles LoRA adapter weights (lora_A, lora_B), DoRA magnitude vectors, and any base weight/bias keys that weren’t caught by the layer-indexed loop (e.g., keys with a base_model.model. prefix from PEFT saving).
For keys containing .self_attn.qkv_proj.: lora_A weights (input dimension) are duplicated to q/k/v projections. All other weights (lora_B, magnitude, weight, bias) are split along dim 0 using the Q/KV size ratio.
For keys containing .mlp.gate_up_proj.: lora_A weights are duplicated to gate/up projections. All other weights are split in half along dim 0.
- Parameters:
hf_state_dict – State dict to modify in-place.
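The qkv split rule for a single key can be sketched as (hedged: `split_qkv_key`, q_rows, and kv_rows are illustrative names standing in for the method's internals and the config-derived sizes):

```python
import torch

def split_qkv_key(key, tensor, q_rows, kv_rows, out):
    if ".lora_A." in key:
        # lora_A maps the shared input dimension: duplicate it verbatim.
        for name in ("q_proj", "k_proj", "v_proj"):
            out[key.replace("qkv_proj", name)] = tensor
    else:
        # lora_B, DoRA magnitude, weight, bias: split along dim 0
        # using the Q/KV size ratio.
        q, k, v = tensor.split([q_rows, kv_rows, kv_rows], dim=0)
        out[key.replace("qkv_proj", "q_proj")] = q
        out[key.replace("qkv_proj", "k_proj")] = k
        out[key.replace("qkv_proj", "v_proj")] = v
```

lora_A is duplicated rather than split because it consumes the layer input, which all three projections share; only the output-dimension factors differ per projection.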
- _recombine_split_projection_keys(state_dict: dict[str, Any])#
Recombine split projection LoRA/DoRA keys back to combined format.
This is the reverse of _split_remaining_combined_projection_keys. It handles LoRA adapter weights and DoRA magnitude vectors that were split for HF-PEFT compatibility during to_hf() and need to be recombined when loading back into a model with combined projections.
For keys containing .self_attn.q_proj.<suffix>: lora_A weights (which were duplicated during split) are deduplicated; we take the q_proj version. All other weights (lora_B, magnitude, etc.) are concatenated along dim 0 in Q, K, V order.
For keys containing .mlp.gate_proj.<suffix>: lora_A weights are deduplicated; we take the gate_proj version. All other weights are concatenated along dim 0 in gate, up order.
Keys that end with .weight or .bias directly on the projection (e.g., q_proj.weight) are skipped because those are already handled by the layer-indexed loop in from_hf.
- Parameters:
state_dict – State dict to modify in-place.
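The attention half of this recombination can be sketched as (hedged: `recombine_qkv_keys` is an illustrative stand-in for the method's qkv handling; the real method also covers gate/up keys and modifies the dict in-place):

```python
import torch

def recombine_qkv_keys(state_dict):
    out = dict(state_dict)
    for q_key in [k for k in state_dict if ".self_attn.q_proj." in k]:
        if q_key.endswith(("q_proj.weight", "q_proj.bias")):
            continue  # base weights: handled by the layer-indexed loop in from_hf
        k_key = q_key.replace("q_proj", "k_proj")
        v_key = q_key.replace("q_proj", "v_proj")
        combined = q_key.replace("q_proj", "qkv_proj")
        if ".lora_A." in q_key:
            # lora_A was duplicated during the split: keep the q_proj copy.
            out[combined] = out.pop(q_key)
            out.pop(k_key, None)
            out.pop(v_key, None)
        else:
            # lora_B, magnitude, etc.: concatenate along dim 0 in Q, K, V order.
            out[combined] = torch.cat(
                [out.pop(q_key), out.pop(k_key), out.pop(v_key)], dim=0
            )
    return out
```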