nemo_automodel.components.models.common.combined_projection.state_dict_adapter#

Generic state dict adapter for models with combined projections.

This module provides a unified state dict converter that handles:

  • Separate q_proj, k_proj, v_proj <-> Combined qkv_proj

  • Separate gate_proj, up_proj <-> Combined gate_up_proj

  • Tied weights (lm_head <-> embed_tokens)

Works with any transformer model (Llama, Qwen2, etc.) that uses these projection patterns.
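
For orientation, the per-layer weight-key mapping might be summarized as follows; the key names are hypothetical Llama/Qwen2-style names used purely for illustration:

```python
# Hypothetical per-layer mapping (HF layout -> combined layout); each
# combined weight is the concatenation of its HF sources along dim 0.
HF_TO_COMBINED = {
    "self_attn.q_proj.weight": "self_attn.qkv_proj.weight",
    "self_attn.k_proj.weight": "self_attn.qkv_proj.weight",
    "self_attn.v_proj.weight": "self_attn.qkv_proj.weight",
    "mlp.gate_proj.weight": "mlp.gate_up_proj.weight",
    "mlp.up_proj.weight": "mlp.gate_up_proj.weight",
}
```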

Module Contents#

Classes#

CombinedProjectionStateDictAdapter

Generic adapter for converting between HF and combined-projection formats.

Functions#

_is_dtensor

Check if tensor is a DTensor without importing DTensor directly.

_safe_split

Split a tensor, handling both regular tensors and DTensors, without triggering redistribution.

_safe_concat

Concatenate tensors, handling both regular tensors and DTensors, without triggering redistribution.

Data#

logger

API#

nemo_automodel.components.models.common.combined_projection.state_dict_adapter.logger#

'getLogger(…)'

nemo_automodel.components.models.common.combined_projection.state_dict_adapter._is_dtensor(tensor: torch.Tensor) → bool#

Check if tensor is a DTensor without importing DTensor directly.
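
A name-based check is one way to do this without importing DTensor; a minimal sketch, not necessarily the actual implementation:

```python
import torch

def _is_dtensor(tensor: torch.Tensor) -> bool:
    # Sketch: compare the class name rather than importing
    # torch.distributed.tensor.DTensor, so the check works even when
    # the distributed tensor package is unavailable.
    return type(tensor).__name__ == "DTensor"
```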

nemo_automodel.components.models.common.combined_projection.state_dict_adapter._safe_split(
tensor: torch.Tensor,
split_sizes: list[int],
dim: int = 0,
) → list[torch.Tensor]#

Split a tensor, handling both regular tensors and DTensors, without triggering redistribution.

For DTensors, extracts the local shard, splits it, and rewraps each piece as a DTensor. For regular tensors, performs a normal split. A minimal sketch appears after the parameter list.

Parameters:
  • tensor – Tensor to split (can be DTensor or regular tensor)

  • split_sizes – Split sizes computed based on global/local tensor size

  • dim – Dimension to split along
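
A minimal sketch of the described behavior, assuming the `torch.distributed.tensor` import path of recent PyTorch releases (the real helper avoids importing DTensor at module level):

```python
import torch
from torch.distributed.tensor import DTensor

def _safe_split(tensor, split_sizes, dim=0):
    # Sketch: for a DTensor, split the local shard and rewrap each piece
    # with the original mesh and placements, so no communication occurs.
    if type(tensor).__name__ == "DTensor":
        pieces = torch.split(tensor.to_local(), split_sizes, dim=dim)
        return [
            DTensor.from_local(p, tensor.device_mesh, tensor.placements)
            for p in pieces
        ]
    return list(torch.split(tensor, split_sizes, dim=dim))
```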

nemo_automodel.components.models.common.combined_projection.state_dict_adapter._safe_concat(
tensors: list[torch.Tensor],
dim: int = 0,
) → torch.Tensor#

Concatenate tensors, handling both regular tensors and DTensors, without triggering redistribution.

For DTensors, extracts the local shards, concatenates them, and rewraps the result as a DTensor. For regular tensors, performs a normal concat. A minimal sketch appears after the parameter list.

Parameters:
  • tensors – List of tensors to concatenate (all DTensors or all regular tensors)

  • dim – Dimension to concatenate along
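
A minimal sketch of the described behavior, under the same import-path assumption as the `_safe_split` sketch above:

```python
import torch
from torch.distributed.tensor import DTensor

def _safe_concat(tensors, dim=0):
    # Sketch: for DTensors, concatenate the local shards and rewrap the
    # result with the first tensor's mesh and placements (no communication).
    if type(tensors[0]).__name__ == "DTensor":
        local = torch.cat([t.to_local() for t in tensors], dim=dim)
        return DTensor.from_local(
            local, tensors[0].device_mesh, tensors[0].placements
        )
    return torch.cat(tensors, dim=dim)
```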

class nemo_automodel.components.models.common.combined_projection.state_dict_adapter.CombinedProjectionStateDictAdapter(config)#

Generic adapter for converting between HF and combined-projection formats.

Handles conversion of:

  • Separate q_proj, k_proj, v_proj <-> Combined qkv_proj

  • Separate gate_proj, up_proj <-> Combined gate_up_proj

  • Tied weights (lm_head <-> embed_tokens) for loading HF checkpoints

Works with any transformer model config that has:

  • num_hidden_layers

  • num_attention_heads

  • num_key_value_heads

  • hidden_size

Parameters:

config – Model config (LlamaConfig, Qwen2Config, etc.)

Example#

```python
# For Llama
from transformers import LlamaConfig
adapter = CombinedProjectionStateDictAdapter(
    LlamaConfig.from_pretrained("meta-llama/Llama-3-8B")
)

# For Qwen2
from transformers import Qwen2Config
adapter = CombinedProjectionStateDictAdapter(
    Qwen2Config.from_pretrained("Qwen/Qwen2.5-7B")
)
```

Initialization

Initialize the adapter with the model config.

from_hf(
hf_state_dict: dict[str, Any],
**kwargs,
) → dict[str, Any]#

Convert HuggingFace state dict to combined-projection format.

Converts separate Q/K/V and gate/up projections to combined projections. Also handles tied weights (lm_head <-> embed_tokens) by copying embed_tokens to lm_head when lm_head is missing (common in HF Qwen2 and Llama checkpoints). A minimal sketch of the Q/K/V merge appears below.

Parameters:

hf_state_dict – State dict from HuggingFace model

Returns:

State dict in combined-projection format
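
Conceptually, the Q/K/V merge reduces to concatenating the three projection weights along the output (row) dimension. A minimal single-layer sketch; the key names follow the HF Llama/Qwen2 convention, and the combined `qkv_proj` key name is an assumption for illustration:

```python
import torch

def merge_qkv(hf_state_dict: dict, layer: int) -> None:
    # Hypothetical key names following the HF Llama/Qwen2 layout.
    prefix = f"model.layers.{layer}.self_attn."
    q = hf_state_dict.pop(prefix + "q_proj.weight")
    k = hf_state_dict.pop(prefix + "k_proj.weight")
    v = hf_state_dict.pop(prefix + "v_proj.weight")
    # Output rows are stacked in order: [q; k; v].
    hf_state_dict[prefix + "qkv_proj.weight"] = torch.cat([q, k, v], dim=0)
```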

to_hf(
state_dict: dict[str, Any],
exclude_key_regex: Optional[str] = None,
**kwargs,
) → dict[str, Any]#

Convert combined-projection state dict to HuggingFace format.

Splits combined qkv_proj and gate_up_proj back to separate projections. Handles both full (unsharded) and TP-sharded tensors. A minimal sketch of the split appears below.

Parameters:
  • state_dict – State dict from the custom model (may contain TP-sharded DTensors)

  • exclude_key_regex – Optional regex pattern to exclude keys

Returns:

State dict in HuggingFace format
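
The reverse split sizes can be recovered from the config attributes the adapter requires. A minimal sketch for the unsharded case; deriving head_dim from hidden_size is an assumption for configs without an explicit head_dim attribute:

```python
import torch

def split_qkv(qkv: torch.Tensor, config) -> tuple[torch.Tensor, ...]:
    # Assumed derivation for configs lacking an explicit head_dim.
    head_dim = config.hidden_size // config.num_attention_heads
    q_size = config.num_attention_heads * head_dim   # == hidden_size
    kv_size = config.num_key_value_heads * head_dim  # smaller under GQA
    # Undo the row-wise concatenation performed by from_hf.
    return torch.split(qkv, [q_size, kv_size, kv_size], dim=0)
```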