# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import re
import tarfile
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Tuple
import torch
import yaml
from nemo_export.tarutils import TarPath
def replace_number_add_offset(key, offset_value):
    """Find the layer number in a state dict key and add a numeric offset to that number."""
    if offset_value == 0:
        return key

    pattern = r"layers\.(\d+)"

    def add_offset(match):
        return "layers." + str(int(match.group(1)) + offset_value)

    return re.sub(pattern, add_offset, key)
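
# For example (hypothetical key, purely illustrative):
# replace_number_add_offset("model.layers.3.self_attention.adapter_layer.lora_kqv_adapter.linear_in.weight", 8)
# returns "model.layers.11.self_attention.adapter_layer.lora_kqv_adapter.linear_in.weight".
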
def rename_qkv_keys(key):
    """Map a fused ".lora_kqv_adapter." key to the three unfused q/k/v adapter keys."""
new_keys = []
new_keys.append(key.replace(".lora_kqv_adapter.", ".lora_unfused_kqv_adapter.q_adapter."))
new_keys.append(key.replace(".lora_kqv_adapter.", ".lora_unfused_kqv_adapter.k_adapter."))
new_keys.append(key.replace(".lora_kqv_adapter.", ".lora_unfused_kqv_adapter.v_adapter."))
return new_keys
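
# For example, a fused key such as "...adapter_layer.lora_kqv_adapter.linear_in.weight"
# (hypothetical) is mapped by rename_qkv_keys to the three unfused keys
# "...lora_unfused_kqv_adapter.q_adapter.linear_in.weight",
# "...lora_unfused_kqv_adapter.k_adapter.linear_in.weight" and
# "...lora_unfused_kqv_adapter.v_adapter.linear_in.weight".
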
def convert_lora_weights_to_canonical(
config: Dict[str, Any], lora_weights: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""This function converts nemo style (fused) lora weights to canonical (unfused) LoRA weights.
Namely, it unfuses the QKV adapter layers and the H-to-4H adapter layers.
Returns:
Dict[str, torch.Tensor]: The new LoRA weights with unfused layers.
"""
hidden_size = int(config["hidden_size"])
num_heads = int(config["num_attention_heads"])
head_size = hidden_size // num_heads
num_query_groups = int(config.get("num_query_groups", num_heads)) # num_kv_heads
heads_per_group = num_heads // num_query_groups
qkv_total_dim = num_heads + 2 * num_query_groups
adapter_size = config["peft"]["lora_tuning"]["adapter_dim"]
q_slice = torch.cat(
[
torch.arange(
(heads_per_group + 2) * group_idx,
(heads_per_group + 2) * group_idx + heads_per_group,
)
for group_idx in range(num_query_groups)
]
)
k_slice = torch.arange(heads_per_group, qkv_total_dim, heads_per_group + 2)
v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, heads_per_group + 2)
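
    # Illustration with an assumed small GQA config (not from any real model): with
    # num_attention_heads=8 and num_query_groups=4, heads_per_group=2 and qkv_total_dim=16.
    # The fused QKV rows are laid out per query group as [q, q, k, v], so
    # q_slice=[0, 1, 4, 5, 8, 9, 12, 13], k_slice=[2, 6, 10, 14], v_slice=[3, 7, 11, 15].
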
qkv_keys_to_update = []
hto4h_keys_to_update = []
for key in lora_weights.keys():
if "lora_kqv_adapter" in key:
qkv_keys_to_update.append(key)
if "lora_hto4h_adapter" in key:
hto4h_keys_to_update.append(key)
# unfuse QKV layer
for key in qkv_keys_to_update:
if "linear_in" in key:
assert lora_weights[key].size(0) == adapter_size
for new_key in rename_qkv_keys(key):
lora_weights[new_key] = lora_weights[key]
assert len(lora_weights[new_key].size()) == 2
elif "linear_out" in key:
assert lora_weights[key].size(1) == adapter_size
            for new_key, qkv_slice in zip(rename_qkv_keys(key), [q_slice, k_slice, v_slice]):
                lora_weights[new_key] = (
                    lora_weights[key]
                    .reshape((qkv_total_dim, head_size, adapter_size))[qkv_slice]
                    .reshape((-1, adapter_size))
                )
assert len(lora_weights[new_key].size()) == 2
lora_weights.pop(key)
# This maps to gate_up_proj in HF, but we need to split it up into gate_proj and up_proj
for key in hto4h_keys_to_update:
gate_proj_key = key.replace(".lora_hto4h_adapter.", ".lora_unfused_hto4h_adapter.gate_adapter.")
up_proj_key = key.replace(".lora_hto4h_adapter.", ".lora_unfused_hto4h_adapter.up_adapter.")
module_weight = lora_weights[key]
if "linear_in" in key:
# lora_a gets duplicated
lora_weights[gate_proj_key] = module_weight
lora_weights[up_proj_key] = module_weight
elif "linear_out" in key:
# lora_b gets split
split_size = module_weight.shape[0]
gate_up_split = module_weight.split(split_size // 2)
lora_weights[gate_proj_key] = gate_up_split[0]
lora_weights[up_proj_key] = gate_up_split[1]
lora_weights.pop(key)
return lora_weights
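
# Shape sketch for convert_lora_weights_to_canonical (hypothetical sizes, for
# illustration only): with hidden_size=4096, num_attention_heads=32, num_query_groups=8
# and adapter_dim=32, head_size is 128 and a fused lora_kqv_adapter.linear_out weight of
# shape [(32 + 2 * 8) * 128, 32] is split into a q_adapter weight of shape [32 * 128, 32]
# and k_adapter / v_adapter weights of shape [8 * 128, 32] each. A fused
# lora_hto4h_adapter.linear_out weight is simply split in half along dim 0 into the
# gate_adapter and up_adapter weights.
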
def convert_lora_nemo_to_canonical(lora_nemo, save_path, hf_format=False, donor_hf_config=None):
    """Convert a .nemo LoRA checkpoint to canonical (unfused) LoRA format.

    If hf_format is True, the result is written to save_path as a HuggingFace-PEFT-style
    directory (adapter_model.bin and adapter_config.json); otherwise it is repacked as a
    tar archive at save_path. Returns the resulting state dict and config.
    """
with TarPath(lora_nemo) as archive:
with (archive / "model_config.yaml").open("r") as config_file:
lora_config = yaml.load(config_file, Loader=yaml.SafeLoader)
tp_size = lora_config.get("tensor_model_parallel_size", 1)
pp_size = lora_config.get("pipeline_model_parallel_size", 1)
        lora_state_dict = [{} for _ in range(tp_size)]
for pp in range(pp_size):
for tp in range(tp_size):
if tp_size == 1:
ckpt_file = archive / "model_weights.ckpt"
elif pp_size == 1:
ckpt_file = archive / f"mp_rank_{tp:02d}/model_weights.ckpt"
else:
ckpt_file = archive / f"tp_rank_{tp:02d}_pp_rank_{pp:03d}/model_weights.ckpt"
with ckpt_file.open("rb") as f:
weights = torch.load(f, map_location=torch.device("cpu"))
if pp == 0:
lora_state_dict[tp] = weights
else:
# calculate layer offset
layer_offset = lora_config["num_layers"] // pp_size * pp
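                        # e.g. (assuming num_layers=32 and pp_size=4): weights loaded from
                        # pp rank 1 get layer_offset=8, so local "layers.0".."layers.7"
                        # keys become global "layers.8".."layers.15".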
for key, value in weights.items():
new_key = replace_number_add_offset(key, layer_offset)
lora_state_dict[tp][new_key] = value
    # TODO: currently only supports tp=1
lora_state_dict = lora_state_dict[0]
if lora_config["peft"]["lora_tuning"].get("variant", "nemo") == "nemo":
lora_config["peft"]["lora_tuning"]["variant"] = "canonical"
lora_state_dict = convert_lora_weights_to_canonical(lora_config, lora_state_dict)
if hf_format:
lora_state_dict, target_modules = reformat_module_names_to_hf(lora_state_dict)
Path(save_path).mkdir(parents=True, exist_ok=True)
torch.save(lora_state_dict, f"{save_path}/adapter_model.bin")
if donor_hf_config is not None:
with open(donor_hf_config) as hf_config_file:
adapter_config = json.load(hf_config_file)
else:
adapter_config = {}
adapter_config["peft_type"] = "LORA"
adapter_config["r"] = lora_config["peft"]["lora_tuning"]["adapter_dim"]
adapter_config["lora_alpha"] = lora_config["peft"]["lora_tuning"]["alpha"]
adapter_config["target_modules"] = target_modules
with open(f"{save_path}/adapter_config.json", "w") as f:
json.dump(adapter_config, f, indent=4)
else:
with tempfile.TemporaryDirectory() as tmpdir:
with open(f"{tmpdir}/model_config.yaml", "w") as f:
yaml.dump(lora_config, f)
torch.save(lora_state_dict, f"{tmpdir}/model_weights.ckpt")
dirname = os.path.dirname(save_path)
os.makedirs(dirname, exist_ok=True)
with tarfile.open(save_path, "w:") as tar:
tar.add(tmpdir, arcname=".")
return lora_state_dict, lora_config
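

# Minimal usage sketch (illustrative only; the input and output paths below are
# placeholders, not files that ship with this module):
if __name__ == "__main__":
    # Convert a NeMo-format LoRA checkpoint into a HuggingFace-PEFT-style adapter
    # directory containing adapter_model.bin and adapter_config.json.
    state_dict, config = convert_lora_nemo_to_canonical(
        lora_nemo="/path/to/lora_checkpoint.nemo",  # placeholder input archive
        save_path="/path/to/hf_adapter_dir",  # placeholder output directory
        hf_format=True,
    )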