nemo_automodel.components.models.common.combined_projection.state_dict_adapter#

Generic state dict adapter for models with combined projections.

This module provides a unified state dict converter that handles:

  • Separate q_proj, k_proj, v_proj <-> Combined qkv_proj

  • Separate gate_proj, up_proj <-> Combined gate_up_proj

  • Tied weights (lm_head <-> embed_tokens)

Works with any transformer model (Llama, Qwen2, etc.) that uses these projection patterns.

Module Contents#

Classes#

CombinedProjectionStateDictAdapter

Generic adapter for converting between HF and combined-projection formats.

Data#

API#

nemo_automodel.components.models.common.combined_projection.state_dict_adapter.logger#

'getLogger(…)'

class nemo_automodel.components.models.common.combined_projection.state_dict_adapter.CombinedProjectionStateDictAdapter(config)#

Generic adapter for converting between HF and combined-projection formats.

Handles conversion of:

  • Separate q_proj, k_proj, v_proj <-> Combined qkv_proj

  • Separate gate_proj, up_proj <-> Combined gate_up_proj

  • Tied weights (lm_head <-> embed_tokens) for loading HF checkpoints

Works with any transformer model config that has:

  • num_hidden_layers

  • num_attention_heads

  • num_key_value_heads

  • hidden_size

Parameters:

config – Model config (LlamaConfig, Qwen2Config, etc.)

Example

For Llama#

from transformers import LlamaConfig

adapter = CombinedProjectionStateDictAdapter(LlamaConfig.from_pretrained("meta-llama/Llama-3-8B"))

For Qwen2#

from transformers import Qwen2Config

adapter = CombinedProjectionStateDictAdapter(Qwen2Config.from_pretrained("Qwen/Qwen2.5-7B"))

Initialization

Initialize the adapter with model config.

static _gather_1d_bias(
tensor: torch.Tensor,
) → tuple[torch.Tensor, tuple | None]#

Materialize a 1-D bias DTensor as a full tensor.

Must only be called on 1-D bias tensors. Returns (gathered, orig_placement) where orig_placement is a (device_mesh, placements) tuple that _restore_1d_bias accepts, or None when the tensor is a plain (non-DTensor) bias.

static _restore_1d_bias(
tensor: torch.Tensor,
orig_placement: tuple | None,
) → torch.Tensor#

Redistribute a 1-D bias back to the placement saved by _gather_1d_bias.

No-op when orig_placement is None (plain non-DTensor bias).

_interleave_qkv(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
) → torch.Tensor#

Interleave Q, K, V by KV-head groups for TP-correct ColwiseParallel sharding.

Layout: [Q_group_0 | K_0 | V_0 | Q_group_1 | K_1 | V_1 | …] where each group has (group_size * head_dim) Q rows, head_dim K rows, head_dim V rows.
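The grouped layout can be sketched as follows. `interleave_qkv` and the toy shapes here are illustrative, not the module's actual implementation; grouping whole KV-head blocks is what keeps ColwiseParallel sharding correct, since each TP rank receives complete groups.

```python
import torch

def interleave_qkv(q, k, v, num_heads, num_kv_heads, head_dim):
    """Sketch of the KV-head-grouped layout [Q_group_0 | K_0 | V_0 | ...].

    Each group carries (num_heads // num_kv_heads) * head_dim Q rows,
    followed by head_dim K rows and head_dim V rows.
    """
    group_size = num_heads // num_kv_heads
    q_rows = group_size * head_dim
    blocks = []
    for g in range(num_kv_heads):
        blocks.append(q[g * q_rows:(g + 1) * q_rows])
        blocks.append(k[g * head_dim:(g + 1) * head_dim])
        blocks.append(v[g * head_dim:(g + 1) * head_dim])
    return torch.cat(blocks, dim=0)

# Toy shapes: 4 query heads, 2 KV heads, head_dim=2, hidden_size=8.
q = torch.randn(8, 8)  # 4 heads * head_dim rows
k = torch.randn(4, 8)  # 2 KV heads * head_dim rows
v = torch.randn(4, 8)
qkv = interleave_qkv(q, k, v, num_heads=4, num_kv_heads=2, head_dim=2)
# Row layout: [q0..q3 | k0..k1 | v0..v1 | q4..q7 | k2..k3 | v2..v3]
```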

_deinterleave_qkv(
qkv: torch.Tensor,
) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]#

De-interleave QKV from KV-head-grouped layout back to separate Q, K, V.
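The inverse can be sketched the same way; `deinterleave_qkv` below is a hypothetical stand-in that walks the grouped rows and reassembles separate Q, K, V:

```python
import torch

def deinterleave_qkv(qkv, num_heads, num_kv_heads, head_dim):
    """Sketch: undo the KV-head grouping back to separate Q, K, V."""
    q_rows = (num_heads // num_kv_heads) * head_dim  # Q rows per group
    stride = q_rows + 2 * head_dim                   # total rows per group
    qs, ks, vs = [], [], []
    for g in range(num_kv_heads):
        base = g * stride
        qs.append(qkv[base:base + q_rows])
        ks.append(qkv[base + q_rows:base + q_rows + head_dim])
        vs.append(qkv[base + q_rows + head_dim:base + stride])
    return torch.cat(qs), torch.cat(ks), torch.cat(vs)

# Build a grouped tensor by hand (4 query heads, 2 KV heads, head_dim=1).
q = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
k = torch.tensor([[10.0], [20.0]])
v = torch.tensor([[30.0], [40.0]])
qkv = torch.cat([q[0:2], k[0:1], v[0:1], q[2:4], k[1:2], v[1:2]])
q2, k2, v2 = deinterleave_qkv(qkv, num_heads=4, num_kv_heads=2, head_dim=1)
# q2, k2, v2 match the original q, k, v
```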

_interleave_gate_up(
gate: torch.Tensor,
up: torch.Tensor,
) → torch.Tensor#

Interleave gate and up row-by-row for TP-correct ColwiseParallel sharding.
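A minimal sketch of a row-by-row interleave, assuming the `[g0, u0, g1, u1, …]` layout described above (`interleave_gate_up` is illustrative, not the module's code):

```python
import torch

def interleave_gate_up(gate, up):
    """Sketch of the row-interleaved layout [g0, u0, g1, u1, ...]."""
    rows, cols = gate.shape
    # Stacking on a new dim then flattening alternates gate/up rows.
    return torch.stack([gate, up], dim=1).reshape(2 * rows, cols)

gate = torch.arange(12.0).reshape(4, 3)
up = 100.0 + torch.arange(12.0).reshape(4, 3)
gate_up = interleave_gate_up(gate, up)
# Even rows come from gate, odd rows from up.
```

Under this layout, de-interleaving is just the strided views `gate_up[0::2]` and `gate_up[1::2]`.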

_deinterleave_gate_up(
gate_up: torch.Tensor,
) → tuple[torch.Tensor, torch.Tensor]#

De-interleave gate/up from row-interleaved layout.

from_hf(
hf_state_dict: dict[str, Any],
**kwargs,
) → dict[str, Any]#

Convert HuggingFace state dict to combined-projection format.

Converts separate Q/K/V and gate/up projections to combined projections. Also handles tied weights (lm_head <-> embed_tokens) by copying embed_tokens to lm_head if lm_head is missing (common in HF Qwen2 and Llama checkpoints).

Parameters:

hf_state_dict – State dict from HuggingFace model

Returns:

State dict in combined-projection format
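A hedged sketch of the key transformation `from_hf` performs for one layer. For brevity this uses a plain concatenation; the real adapter interleaves by KV-head group and derives shapes from the model config. The layer index and toy shapes are illustrative:

```python
import torch

hf_sd = {
    "model.layers.0.self_attn.q_proj.weight": torch.randn(8, 8),
    "model.layers.0.self_attn.k_proj.weight": torch.randn(4, 8),
    "model.layers.0.self_attn.v_proj.weight": torch.randn(4, 8),
    "model.embed_tokens.weight": torch.randn(16, 8),
    # No "lm_head.weight": tied-weight checkpoints omit it.
}

out = dict(hf_sd)
q = out.pop("model.layers.0.self_attn.q_proj.weight")
k = out.pop("model.layers.0.self_attn.k_proj.weight")
v = out.pop("model.layers.0.self_attn.v_proj.weight")
out["model.layers.0.self_attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0)

# Tied-weight handling: fill in lm_head from embed_tokens when missing.
if "lm_head.weight" not in out:
    out["lm_head.weight"] = out["model.embed_tokens.weight"]
```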

to_hf(
state_dict: dict[str, Any],
exclude_key_regex: Optional[str] = None,
**kwargs,
) → dict[str, Any]#

Convert combined-projection state dict to HuggingFace format.

Splits combined qkv_proj and gate_up_proj back to separate projections. Handles both full (unsharded) and TP-sharded tensors.

Parameters:
  • state_dict – State dict from custom model (can be TP-sharded DTensors)

  • exclude_key_regex – Optional regex pattern to exclude keys

Returns:

State dict in HuggingFace format
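The reverse direction can be sketched as a `torch.split` along dim 0, again ignoring the KV-head interleaving for brevity; `q_size` and `kv_size` stand in for values the adapter would derive from the config:

```python
import torch

q_size, kv_size = 8, 4  # num_attention_heads * head_dim, num_key_value_heads * head_dim
combined = {"model.layers.0.self_attn.qkv_proj.weight": torch.randn(16, 8)}

hf_sd = {}
qkv = combined["model.layers.0.self_attn.qkv_proj.weight"]
# Split back into separate projections using config-derived row counts.
q, k, v = torch.split(qkv, [q_size, kv_size, kv_size], dim=0)
hf_sd["model.layers.0.self_attn.q_proj.weight"] = q
hf_sd["model.layers.0.self_attn.k_proj.weight"] = k
hf_sd["model.layers.0.self_attn.v_proj.weight"] = v
```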

_split_remaining_combined_projection_keys(
hf_state_dict: dict[str, Any],
) → None#

Split any remaining combined-projection keys in-place.

Handles LoRA adapter weights (lora_A, lora_B), DoRA magnitude vectors, and any base weight/bias keys that weren’t caught by the layer-indexed loop (e.g., keys with a base_model.model. prefix from PEFT saving).

For keys containing .self_attn.qkv_proj.:

  • lora_A weights (input dimension) are duplicated to q/k/v projections.

  • All other weights (lora_B, magnitude, weight, bias) are split along dim 0 using the Q/KV size ratio.

For keys containing .mlp.gate_up_proj.:

  • lora_A weights are duplicated to gate/up projections.

  • All other weights are split in half along dim 0.

Parameters:

hf_state_dict – State dict to modify in-place.
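The split rules above can be sketched with hypothetical shapes (q_size=8, kv_size=4, LoRA rank 3; the interleaved storage layout is omitted for clarity):

```python
import torch

q_size, kv_size, rank, hidden = 8, 4, 3, 16

# lora_A maps the input to the rank dim, so it is identical for q/k/v
# and is simply duplicated to each projection.
lora_a = torch.randn(rank, hidden)
q_a, k_a, v_a = lora_a, lora_a, lora_a

# lora_B (like weight/bias/magnitude) lives on the output dim and is
# split along dim 0 using the Q/KV size ratio.
lora_b = torch.randn(q_size + 2 * kv_size, rank)
q_b, k_b, v_b = torch.split(lora_b, [q_size, kv_size, kv_size], dim=0)

# gate_up_proj: everything except lora_A is split in half along dim 0.
gate_up_b = torch.randn(12, rank)
gate_b, up_b = torch.chunk(gate_up_b, 2, dim=0)
```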

_recombine_split_projection_keys(
state_dict: dict[str, Any],
) → None#

Recombine split projection LoRA/DoRA keys back to combined format.

This is the reverse of _split_remaining_combined_projection_keys. It handles LoRA adapter weights and DoRA magnitude vectors that were split for HF-PEFT compatibility during to_hf() and need to be recombined when loading back into a model with combined projections.

For keys containing .self_attn.q_proj.<suffix>:

  • lora_A weights (which were duplicated during split) are deduplicated — we take the q_proj version.

  • All other weights (lora_B, magnitude, etc.) are concatenated along dim 0 in Q, K, V order.

For keys containing .mlp.gate_proj.<suffix>:

  • lora_A weights are deduplicated — we take the gate_proj version.

  • All other weights are concatenated along dim 0 in gate, up order.

Keys that end with .weight or .bias directly on the projection (e.g., q_proj.weight) are skipped because those are already handled by the layer-indexed loop in from_hf.

Parameters:

state_dict – State dict to modify in-place.
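The recombination rules can be sketched as the inverse of the split, with hypothetical shapes matching the earlier example:

```python
import torch

# lora_B tensors are concatenated along dim 0 in Q, K, V order.
q_b, k_b, v_b = torch.randn(8, 3), torch.randn(4, 3), torch.randn(4, 3)
qkv_b = torch.cat([q_b, k_b, v_b], dim=0)

# lora_A was duplicated during the split, so deduplicate by taking
# the q_proj copy.
q_a = torch.randn(3, 16)
qkv_a = q_a

# gate/up: concatenate along dim 0 in gate, up order.
gate_b, up_b = torch.randn(6, 3), torch.randn(6, 3)
gate_up_b = torch.cat([gate_b, up_b], dim=0)
```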