nemo_automodel.components.models.common.combined_projection.state_dict_adapter#
Generic state dict adapter for models with combined projections.
This module provides a unified state dict converter that handles:

- Separate q_proj, k_proj, v_proj <-> Combined qkv_proj
- Separate gate_proj, up_proj <-> Combined gate_up_proj
- Tied weights (lm_head <-> embed_tokens)

Works with any transformer model (Llama, Qwen2, etc.) that uses these projection patterns.
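As a quick illustration of the layout being converted, the combined qkv_proj weight is the row-wise concatenation of the separate projection weights (and likewise for gate_up_proj). The following is a minimal sketch with toy shapes; the variable names and sizes are illustrative only and not taken from the module.

```python
import torch

# Toy shapes standing in for a real config: hidden_size=16, 4 query heads,
# 2 key/value heads, head_dim=4 (all illustrative).
hidden, q_out, kv_out = 16, 16, 8
q_w = torch.randn(q_out, hidden)
k_w = torch.randn(kv_out, hidden)
v_w = torch.randn(kv_out, hidden)

# Combined projection weight: row-wise concatenation of the separate weights.
qkv_w = torch.cat([q_w, k_w, v_w], dim=0)  # shape (q_out + 2 * kv_out, hidden)

# Converting back splits along the same dimension, with sizes derived from the config.
q_back, k_back, v_back = torch.split(qkv_w, [q_out, kv_out, kv_out], dim=0)
assert torch.equal(q_back, q_w) and torch.equal(v_back, v_w)
```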
Module Contents#

Classes#

- CombinedProjectionStateDictAdapter: Generic adapter for converting between HF and combined-projection formats.

Functions#

- _is_dtensor: Check if tensor is a DTensor without importing DTensor directly.
- _safe_split: Split tensor handling both regular tensors and DTensors without triggering redistribution.
- _safe_concat: Concatenate tensors handling both regular tensors and DTensors without triggering redistribution.

Data#

- logger
API#
- nemo_automodel.components.models.common.combined_projection.state_dict_adapter.logger#
'getLogger(...)'
- nemo_automodel.components.models.common.combined_projection.state_dict_adapter._is_dtensor(tensor: torch.Tensor) -> bool#
Check if tensor is a DTensor without importing DTensor directly.
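One way such a check can be written without importing DTensor is to match on the class name; the function name and logic below are a sketch only, not the module's actual implementation.

```python
import torch

def is_dtensor_sketch(tensor: torch.Tensor) -> bool:
    # Illustrative: walk the MRO and match on the class name, so DTensor never
    # needs to be imported at module load time.
    return any(cls.__name__ == "DTensor" for cls in type(tensor).__mro__)

print(is_dtensor_sketch(torch.zeros(2)))  # False for a regular tensor
```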
- nemo_automodel.components.models.common.combined_projection.state_dict_adapter._safe_split(
- tensor: torch.Tensor,
- split_sizes: list[int],
- dim: int = 0,
- )#
Split tensor handling both regular tensors and DTensors without triggering redistribution.
For DTensors, extracts local shard, splits it, and rewraps each piece as DTensor. For regular tensors, performs normal split.
- Parameters:
tensor – Tensor to split (can be DTensor or regular tensor)
split_sizes – Split sizes computed based on global/local tensor size
dim – Dimension to split along
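A sketch of how such a split can avoid redistribution, assuming the public torch.distributed.tensor API (PyTorch >= 2.4) and that split_sizes already describe the passed tensor's local shard; the function name and details are illustrative, not the module's actual code.

```python
import torch

def safe_split_sketch(tensor, split_sizes, dim=0):
    # Illustrative only; the real _safe_split may differ.
    if type(tensor).__name__ == "DTensor":
        from torch.distributed.tensor import DTensor  # assumes PyTorch >= 2.4 public API

        # Split the local shard directly, then rewrap each piece with the parent
        # tensor's mesh and placements, so no collective communication is triggered.
        local_pieces = torch.split(tensor.to_local(), split_sizes, dim=dim)
        return [
            DTensor.from_local(p, device_mesh=tensor.device_mesh, placements=tensor.placements)
            for p in local_pieces
        ]
    return list(torch.split(tensor, split_sizes, dim=dim))
```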
- nemo_automodel.components.models.common.combined_projection.state_dict_adapter._safe_concat(
- tensors: list[torch.Tensor],
- dim: int = 0,
- )#
Concatenate tensors handling both regular tensors and DTensors without triggering redistribution.
For DTensors, extracts local shards, concatenates them, and rewraps as DTensor. For regular tensors, performs normal concat.
- Parameters:
tensors – List of tensors to concatenate (all DTensors or all regular tensors)
dim – Dimension to concatenate along
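The reverse operation can follow the same pattern; again a sketch under the same assumptions (public DTensor API, all inputs sharded identically), not the module's actual code.

```python
import torch

def safe_concat_sketch(tensors, dim=0):
    # Illustrative only; the real _safe_concat may differ.
    if tensors and type(tensors[0]).__name__ == "DTensor":
        from torch.distributed.tensor import DTensor  # assumes PyTorch >= 2.4 public API

        # Concatenate the local shards and rewrap with the first tensor's
        # mesh/placements, avoiding any redistribution.
        combined = torch.cat([t.to_local() for t in tensors], dim=dim)
        return DTensor.from_local(
            combined, device_mesh=tensors[0].device_mesh, placements=tensors[0].placements
        )
    return torch.cat(tensors, dim=dim)
```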
- class nemo_automodel.components.models.common.combined_projection.state_dict_adapter.CombinedProjectionStateDictAdapter(config)#
Generic adapter for converting between HF and combined-projection formats.
Handles conversion of:

- Separate q_proj, k_proj, v_proj <-> Combined qkv_proj
- Separate gate_proj, up_proj <-> Combined gate_up_proj
- Tied weights (lm_head <-> embed_tokens) for loading HF checkpoints

Works with any transformer model config that has:

- num_hidden_layers
- num_attention_heads
- num_key_value_heads
- hidden_size
- Parameters:
config – Model config (LlamaConfig, Qwen2Config, etc.)
Example

    # For Llama
    from transformers import LlamaConfig
    adapter = CombinedProjectionStateDictAdapter(LlamaConfig.from_pretrained("meta-llama/Llama-3-8B"))

    # For Qwen2
    from transformers import Qwen2Config
    adapter = CombinedProjectionStateDictAdapter(Qwen2Config.from_pretrained("Qwen/Qwen2.5-7B"))
Initialization
Initialize the adapter with the model config.
- from_hf(
- hf_state_dict: dict[str, Any],
- **kwargs,
- )#
Convert HuggingFace state dict to combined-projection format.
Converts separate Q/K/V and gate/up projections to combined projections. Also handles tied weights (lm_head <-> embed_tokens) by copying embed_tokens to lm_head if lm_head is missing (common in HF Qwen2 and Llama checkpoints).
- Parameters:
hf_state_dict – State dict from HuggingFace model
- Returns:
State dict in combined-projection format
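A usage sketch; the model id is only an example, and the exact combined key paths depend on the model's layer naming.

```python
from transformers import AutoConfig, AutoModelForCausalLM

from nemo_automodel.components.models.common.combined_projection.state_dict_adapter import (
    CombinedProjectionStateDictAdapter,
)

model_id = "Qwen/Qwen2.5-7B"  # example only
config = AutoConfig.from_pretrained(model_id)
hf_model = AutoModelForCausalLM.from_pretrained(model_id)

adapter = CombinedProjectionStateDictAdapter(config)
combined_sd = adapter.from_hf(hf_model.state_dict())
# The separate q_proj/k_proj/v_proj and gate_proj/up_proj entries are replaced
# by combined qkv_proj and gate_up_proj entries (exact key paths are model-dependent).
```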
- to_hf(
- state_dict: dict[str, Any],
- exclude_key_regex: Optional[str] = None,
- **kwargs,
- )#
Convert combined-projection state dict to HuggingFace format.
Splits combined qkv_proj and gate_up_proj back to separate projections. Handles both full (unsharded) and TP-sharded tensors.
- Parameters:
state_dict – State dict from custom model (can be TP-sharded DTensors)
exclude_key_regex – Optional regex pattern to exclude keys
- Returns:
State dict in HuggingFace format
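A usage sketch for the reverse direction, continuing the from_hf example above; the regex pattern is purely illustrative.

```python
# Convert back to the HuggingFace layout, e.g. before saving a checkpoint.
hf_sd = adapter.to_hf(combined_sd)

# Optionally drop keys matching a pattern during conversion (illustrative pattern).
hf_sd_filtered = adapter.to_hf(combined_sd, exclude_key_regex=r".*rotary_emb.*")
```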