nemo_automodel.components.models.common.combined_projection.state_dict_adapter#

Generic state dict adapter for models with combined projections.

This module provides a unified state dict converter that handles:

  • Separate q_proj, k_proj, v_proj <-> Combined qkv_proj

  • Separate gate_proj, up_proj <-> Combined gate_up_proj

  • Tied weights (lm_head <-> embed_tokens)

Works with any transformer model (Llama, Qwen2, etc.) that uses these projection patterns.

Module Contents#

Classes#

CombinedProjectionStateDictAdapter

Generic adapter for converting between HF and combined-projection formats.

Data#

API#

nemo_automodel.components.models.common.combined_projection.state_dict_adapter.logger#

'getLogger(…)'

class nemo_automodel.components.models.common.combined_projection.state_dict_adapter.CombinedProjectionStateDictAdapter(config)#

Generic adapter for converting between HF and combined-projection formats.

Handles conversion of:

  • Separate q_proj, k_proj, v_proj <-> Combined qkv_proj

  • Separate gate_proj, up_proj <-> Combined gate_up_proj

  • Tied weights (lm_head <-> embed_tokens) for loading HF checkpoints

Works with any transformer model config that has:

  • num_hidden_layers

  • num_attention_heads

  • num_key_value_heads

  • hidden_size

Parameters:

config – Model config (LlamaConfig, Qwen2Config, etc.)

Example#

For Llama:

from transformers import LlamaConfig
adapter = CombinedProjectionStateDictAdapter(LlamaConfig.from_pretrained("meta-llama/Llama-3-8B"))

For Qwen2:

from transformers import Qwen2Config
adapter = CombinedProjectionStateDictAdapter(Qwen2Config.from_pretrained("Qwen/Qwen2.5-7B"))

Initialization

Initialize the adapter with the model config.

static _gather_1d_if_needed(
tensor: torch.Tensor,
divisor: int,
) -> torch.Tensor#

Gather a 1-D DTensor on dim 0 when the local shard isn’t divisible by divisor.

FSDP2’s shard_placement_fn only accepts Shard placements (not Replicate), so the 1-D bias vectors of combined projections end up with Shard(0) even though their interleaved layout may not divide evenly across the FSDP shard count. This helper gathers such biases to the full tensor before the reshape/split operations in the state-dict adapter.

Weights are handled by FSDP Shard(1), so they never need this.
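The core check is just dim-0 divisibility. A minimal pure-Python sketch (the function name and arguments are illustrative, not the module's actual API):

```python
def needs_gather(local_dim0: int, divisor: int) -> bool:
    # A Shard(0) local piece of a 1-D bias can only be reshaped/split
    # locally when its length is a whole multiple of the interleave divisor.
    return local_dim0 % divisor != 0

print(needs_gather(128, 128))  # False -> local reshape is safe
print(needs_gather(96, 128))   # True  -> gather to the full tensor first
```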

_interleave_qkv(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
) -> torch.Tensor#

Interleave Q, K, V by KV-head groups for TP-correct ColwiseParallel sharding.

Layout: [Q_group_0 | K_0 | V_0 | Q_group_1 | K_1 | V_1 | …] where each group has (group_size * head_dim) Q rows, head_dim K rows, head_dim V rows.
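The grouped layout can be illustrated with row labels in place of tensor rows (head counts below are made up for the example; the real function operates on torch weight matrices):

```python
def interleave_qkv_rows(q_rows, k_rows, v_rows, num_kv_heads, head_dim):
    """Reorder rows into [Q_group_0 | K_0 | V_0 | Q_group_1 | ...].

    q_rows/k_rows/v_rows are flat lists of rows (dim 0 of the weights).
    Each KV head owns `group_size` query heads.
    """
    num_q_heads = len(q_rows) // head_dim
    group_size = num_q_heads // num_kv_heads
    out = []
    for g in range(num_kv_heads):
        out += q_rows[g * group_size * head_dim:(g + 1) * group_size * head_dim]
        out += k_rows[g * head_dim:(g + 1) * head_dim]
        out += v_rows[g * head_dim:(g + 1) * head_dim]
    return out

# 4 Q heads, 2 KV heads, head_dim 2 -> each KV group carries 2 Q heads.
q = [f"q{i}" for i in range(8)]  # 4 heads * head_dim 2
k = [f"k{i}" for i in range(4)]  # 2 heads * head_dim 2
v = [f"v{i}" for i in range(4)]
rows = interleave_qkv_rows(q, k, v, num_kv_heads=2, head_dim=2)
# rows = [q0..q3, k0, k1, v0, v1, q4..q7, k2, k3, v2, v3]
```

Because each KV-head group is contiguous, a ColwiseParallel shard that splits on group boundaries keeps every rank's Q rows paired with their own K/V rows.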

_deinterleave_qkv(
qkv: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]#

De-interleave QKV from KV-head-grouped layout back to separate Q, K, V.

Works for full (unsharded), TP-sharded, and FSDP-sharded tensors. When the tensor is a DTensor whose dim-0 local shard doesn’t align with the QKV group boundary, the tensor is gathered to full first.
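Ignoring the DTensor gather path, the inverse mapping amounts to walking the KV-head groups and slicing each one back apart (the function name and list-based rows are illustrative):

```python
def deinterleave_qkv_rows(rows, num_q_heads, num_kv_heads, head_dim):
    """Split [Q_group | K | V | ...] rows back into separate Q, K, V lists."""
    group_size = num_q_heads // num_kv_heads
    q_per_group = group_size * head_dim
    stride = q_per_group + 2 * head_dim  # total rows per KV-head group
    q, k, v = [], [], []
    for g in range(num_kv_heads):
        block = rows[g * stride:(g + 1) * stride]
        q += block[:q_per_group]
        k += block[q_per_group:q_per_group + head_dim]
        v += block[q_per_group + head_dim:]
    return q, k, v

# An interleaved layout for 4 Q heads, 2 KV heads, head_dim 2.
rows = ["q0", "q1", "q2", "q3", "k0", "k1", "v0", "v1",
        "q4", "q5", "q6", "q7", "k2", "k3", "v2", "v3"]
q, k, v = deinterleave_qkv_rows(rows, num_q_heads=4, num_kv_heads=2, head_dim=2)
```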

_interleave_gate_up(
gate: torch.Tensor,
up: torch.Tensor,
) -> torch.Tensor#

Interleave gate and up row-by-row for TP-correct ColwiseParallel sharding.

_deinterleave_gate_up(
gate_up: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]#

De-interleave gate/up from row-interleaved layout.
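Row-by-row interleaving and its inverse can be sketched with row labels standing in for tensor rows (names are illustrative; the real functions operate on torch tensors):

```python
def interleave_gate_up(gate_rows, up_rows):
    # Alternate rows: [g0, u0, g1, u1, ...], so any contiguous dim-0
    # shard of the combined matrix holds matching gate/up row pairs.
    out = []
    for g, u in zip(gate_rows, up_rows):
        out += [g, u]
    return out

def deinterleave_gate_up(rows):
    # Even-indexed rows are gate, odd-indexed rows are up.
    return rows[::2], rows[1::2]

gate, up = ["g0", "g1", "g2"], ["u0", "u1", "u2"]
combined = interleave_gate_up(gate, up)
restored = deinterleave_gate_up(combined)
```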

from_hf(
hf_state_dict: dict[str, Any],
**kwargs,
) -> dict[str, Any]#

Convert HuggingFace state dict to combined-projection format.

Converts separate Q/K/V and gate/up projections to combined projections. Also handles tied weights (lm_head <-> embed_tokens) by copying embed_tokens to lm_head if lm_head is missing (common in HF Qwen2 and Llama checkpoints).

Parameters:

hf_state_dict – State dict from HuggingFace model

Returns:

State dict in combined-projection format
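The key-level effect of the conversion can be sketched without real tensors; this only shows which HF keys merge into which combined names, not the tensor interleaving (the helper function is illustrative):

```python
import re

def combine_keys_sketch(hf_keys):
    """Map HF per-projection key names to their combined-projection names."""
    combined = set()
    for key in hf_keys:
        key = re.sub(r"self_attn\.(q|k|v)_proj", "self_attn.qkv_proj", key)
        key = re.sub(r"mlp\.(gate|up)_proj", "mlp.gate_up_proj", key)
        combined.add(key)
    return sorted(combined)

keys = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.k_proj.weight",
    "model.layers.0.self_attn.v_proj.weight",
    "model.layers.0.mlp.gate_proj.weight",
    "model.layers.0.mlp.up_proj.weight",
]
print(combine_keys_sketch(keys))
# ['model.layers.0.mlp.gate_up_proj.weight',
#  'model.layers.0.self_attn.qkv_proj.weight']
```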

to_hf(
state_dict: dict[str, Any],
exclude_key_regex: Optional[str] = None,
**kwargs,
) -> dict[str, Any]#

Convert combined-projection state dict to HuggingFace format.

Splits combined qkv_proj and gate_up_proj back to separate projections. Handles both full (unsharded) and TP-sharded tensors.

Parameters:
  • state_dict – State dict from custom model (can be TP-sharded DTensors)

  • exclude_key_regex – Optional regex pattern to exclude keys

Returns:

State dict in HuggingFace format

_split_remaining_combined_projection_keys(
hf_state_dict: dict[str, Any],
) -> None#

Split any remaining combined-projection keys in-place.

Handles LoRA adapter weights (lora_A, lora_B), DoRA magnitude vectors, and any base weight/bias keys that weren’t caught by the layer-indexed loop (e.g., keys with a base_model.model. prefix from PEFT saving).

For keys containing .self_attn.qkv_proj.:

  • lora_A weights (input dimension) are duplicated to q/k/v projections.

  • All other weights (lora_B, magnitude, weight, bias) are split along dim 0 using the Q/KV size ratio.

For keys containing .mlp.gate_up_proj.:

  • lora_A weights are duplicated to gate/up projections.

  • All other weights are split in half along dim 0.

Parameters:

hf_state_dict – State dict to modify in-place.
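The Q/KV size ratio used for the dim-0 split falls out of the config's head geometry. A sketch of the arithmetic (the function name is illustrative, not the module's real code):

```python
def qkv_split_sizes(num_attention_heads, num_key_value_heads, hidden_size):
    # Rows of the combined qkv_proj output dimension split back into
    # Q / K / V blocks sized by the head counts from the config.
    head_dim = hidden_size // num_attention_heads
    q_size = num_attention_heads * head_dim
    kv_size = num_key_value_heads * head_dim
    return q_size, kv_size, kv_size

# Llama-3-8B-like geometry: 32 Q heads, 8 KV heads, hidden size 4096.
print(qkv_split_sizes(32, 8, 4096))  # (4096, 1024, 1024)
```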

_recombine_split_projection_keys(
state_dict: dict[str, Any],
) -> None#

Recombine split projection LoRA/DoRA keys back to combined format.

This is the reverse of _split_remaining_combined_projection_keys. It handles LoRA adapter weights and DoRA magnitude vectors that were split for HF-PEFT compatibility during to_hf() and need to be recombined when loading back into a model with combined projections.

For keys containing .self_attn.q_proj.<suffix>:

  • lora_A weights (which were duplicated during split) are deduplicated — we take the q_proj version.

  • All other weights (lora_B, magnitude, etc.) are concatenated along dim 0 in Q, K, V order.

For keys containing .mlp.gate_proj.<suffix>:

  • lora_A weights are deduplicated — we take the gate_proj version.

  • All other weights are concatenated along dim 0 in gate, up order.

Keys that end with .weight or .bias directly on the projection (e.g., q_proj.weight) are skipped because those are already handled by the layer-indexed loop in from_hf.

Parameters:

state_dict – State dict to modify in-place.
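The dedupe-then-concatenate behavior for one layer can be sketched with lists standing in for tensors (key names follow the usual HF-PEFT convention but are assumptions here; the real adapter concatenates along dim 0):

```python
def recombine_lora_sketch(state_dict):
    """Illustrative recombination of one layer's split q/k/v LoRA keys."""
    out = dict(state_dict)
    p = "model.layers.0.self_attn"
    # lora_A was duplicated at split time, so keep only the q_proj copy.
    out[f"{p}.qkv_proj.lora_A.weight"] = out.pop(f"{p}.q_proj.lora_A.weight")
    out.pop(f"{p}.k_proj.lora_A.weight")
    out.pop(f"{p}.v_proj.lora_A.weight")
    # lora_B is concatenated along dim 0 in Q, K, V order.
    out[f"{p}.qkv_proj.lora_B.weight"] = (
        out.pop(f"{p}.q_proj.lora_B.weight")
        + out.pop(f"{p}.k_proj.lora_B.weight")
        + out.pop(f"{p}.v_proj.lora_B.weight")
    )
    return out

split_sd = {
    "model.layers.0.self_attn.q_proj.lora_A.weight": ["A"],
    "model.layers.0.self_attn.k_proj.lora_A.weight": ["A"],
    "model.layers.0.self_attn.v_proj.lora_A.weight": ["A"],
    "model.layers.0.self_attn.q_proj.lora_B.weight": ["Bq"],
    "model.layers.0.self_attn.k_proj.lora_B.weight": ["Bk"],
    "model.layers.0.self_attn.v_proj.lora_B.weight": ["Bv"],
}
combined = recombine_lora_sketch(split_sd)
# combined has exactly two keys: qkv_proj.lora_A.weight (the q copy)
# and qkv_proj.lora_B.weight (["Bq", "Bk", "Bv"]).
```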