# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Generator, Optional, Tuple
import torch
class ModelConverter(ABC):
"""Abstract class that defines the interface for a converter that implements model-specific conversion functions for deploying NeMo checkpoints on vLLM."""
def __init__(self, model_type: str):
self.model_type = model_type
@abstractmethod
def get_architecture(self) -> Optional[str]:
"""Returns the HF architecture name for the current model, such as 'LlamaForCausalLM'."""
pass
def convert_config(self, nemo_model_config: dict, hf_config: dict) -> None:
"""Implements any custom HF configuration adjustments in the 'hf_config' dict that are necessary for this model after the common translation takes place in NemoModelConfig's constructor."""
pass
@abstractmethod
def convert_weights(
self, nemo_model_config: dict, state_dict: dict
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Returns or yields a sequence of (name, tensor) tuples that contain model weights in the HF format."""
pass
def requires_bos_token(self) -> bool:
"""Returns True if the model requires a 'bos' token to be used at the beginning of the input sequence.
NeMo checkpoints do not store this information.
"""
return False
class LlamaConverter(ModelConverter):
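    """Converts Llama and Mistral NeMo checkpoints to the HF format expected by vLLM."""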
def get_architecture(self):
if self.model_type == "llama":
return "LlamaForCausalLM"
if self.model_type == "mistral":
return "MistralForCausalLM"
return None
def convert_weights(self, nemo_model_config, state_dict):
hidden_size = nemo_model_config["hidden_size"]
head_num = nemo_model_config["num_attention_heads"]
num_query_groups = nemo_model_config["num_query_groups"]
num_layers = nemo_model_config["num_layers"]
head_size = hidden_size // head_num
heads_per_group = head_num // num_query_groups
qkv_total_dim = head_num + 2 * num_query_groups
yield (
"model.embed_tokens.weight",
state_dict["model.embedding.word_embeddings.weight"],
)
yield ("model.norm.weight", state_dict["model.decoder.final_layernorm.weight"])
if not nemo_model_config.get("share_embeddings_and_output_weights", False):
yield ("lm_head.weight", state_dict["model.output_layer.weight"])
for layer in range(int(num_layers)):
qkv_weights = state_dict["model.decoder.layers.self_attention.linear_qkv.weight"][layer]
qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size])
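            # Megatron-style checkpoints pack QKV in interleaved groups: for each
            # query group, 'heads_per_group' Q heads are followed by one K head and
            # one V head. The index tensors below select the Q, K and V rows.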
q_slice = torch.cat(
[
torch.arange(
(heads_per_group + 2) * i,
(heads_per_group + 2) * i + heads_per_group,
)
for i in range(num_query_groups)
]
)
k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))
for name, slice in [
("q_proj", q_slice),
("k_proj", k_slice),
("v_proj", v_slice),
]:
weight_name = f"model.layers.{layer}.self_attn.{name}.weight"
yield (weight_name, qkv_weights[slice].reshape(-1, hidden_size))
linear_proj_weight = state_dict["model.decoder.layers.self_attention.linear_proj.weight"][layer]
yield (f"model.layers.{layer}.self_attn.o_proj.weight", linear_proj_weight)
gate_proj_weight, up_proj_weight = torch.chunk(
state_dict["model.decoder.layers.mlp.linear_fc1.weight"][layer],
2,
dim=0,
)
yield (f"model.layers.{layer}.mlp.gate_proj.weight", gate_proj_weight)
yield (f"model.layers.{layer}.mlp.up_proj.weight", up_proj_weight)
            mlp_down_weight = state_dict["model.decoder.layers.mlp.linear_fc2.weight"][layer]
            yield (f"model.layers.{layer}.mlp.down_proj.weight", mlp_down_weight)
input_layernorm_weight = state_dict["model.decoder.layers.self_attention.linear_qkv.layer_norm_weight"][
layer
]
yield (
f"model.layers.{layer}.input_layernorm.weight",
input_layernorm_weight,
)
post_attn_layernorm_weight = state_dict["model.decoder.layers.mlp.linear_fc1.layer_norm_weight"][layer]
yield (
f"model.layers.{layer}.post_attention_layernorm.weight",
post_attn_layernorm_weight,
)
def requires_bos_token(self):
return True
class MixtralConverter(ModelConverter):
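    """Converts Mixtral NeMo checkpoints to the HF format expected by vLLM."""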
def get_architecture(self):
if self.model_type == "mixtral":
return "MixtralForCausalLM"
return None
def convert_weights(self, nemo_model_config, state_dict):
hidden_size = nemo_model_config["hidden_size"]
head_num = nemo_model_config["num_attention_heads"]
num_query_groups = nemo_model_config["num_query_groups"]
num_layers = nemo_model_config["num_layers"]
num_moe_experts = nemo_model_config["num_moe_experts"]
head_size = hidden_size // head_num
heads_per_group = head_num // num_query_groups
qkv_total_dim = head_num + 2 * num_query_groups
yield (
"model.embed_tokens.weight",
state_dict["model.embedding.word_embeddings.weight"],
)
yield ("model.norm.weight", state_dict["model.decoder.final_layernorm.weight"])
yield ("lm_head.weight", state_dict["model.output_layer.weight"])
for layer in range(int(num_layers)):
qkv_weights = state_dict["model.decoder.layers.self_attention.linear_qkv.weight"][layer]
qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size])
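            # Same interleaved QKV layout as in LlamaConverter: per query group,
            # 'heads_per_group' Q heads followed by one K head and one V head.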
q_slice = torch.cat(
[
torch.arange(
(heads_per_group + 2) * i,
(heads_per_group + 2) * i + heads_per_group,
)
for i in range(num_query_groups)
]
)
k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))
for name, slice in [
("q_proj", q_slice),
("k_proj", k_slice),
("v_proj", v_slice),
]:
weight_name = f"model.layers.{layer}.self_attn.{name}.weight"
yield (weight_name, qkv_weights[slice].reshape(-1, hidden_size))
linear_proj_weight = state_dict["model.decoder.layers.self_attention.linear_proj.weight"][layer]
yield (f"model.layers.{layer}.self_attn.o_proj.weight", linear_proj_weight)
mlp_router_weight = state_dict["model.decoder.layers.mlp.router.weight"][layer]
yield (
f"model.layers.{layer}.block_sparse_moe.gate.weight",
mlp_router_weight,
)
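            # Each expert's fused linear_fc1 splits into HF's w1 (gate) and w3 (up);
            # its linear_fc2 maps to w2 (down).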
for expert in range(num_moe_experts):
linear_fc1_weight = state_dict["model.decoder.layers.mlp.experts.experts.linear_fc1.weight"][layer][
expert
]
gate_proj_weight, up_proj_weight = torch.chunk(linear_fc1_weight, 2, dim=0)
yield (
f"model.layers.{layer}.block_sparse_moe.experts.{expert}.w1.weight",
gate_proj_weight,
)
yield (
f"model.layers.{layer}.block_sparse_moe.experts.{expert}.w3.weight",
up_proj_weight,
)
linear_fc2_weight = state_dict["model.decoder.layers.mlp.experts.experts.linear_fc2.weight"][layer][
expert
]
yield (
f"model.layers.{layer}.block_sparse_moe.experts.{expert}.w2.weight",
linear_fc2_weight,
)
input_layernorm_weight = state_dict["model.decoder.layers.self_attention.linear_qkv.layer_norm_weight"][
layer
]
yield (
f"model.layers.{layer}.input_layernorm.weight",
input_layernorm_weight,
)
post_attn_layernorm_weight = state_dict["model.decoder.layers.pre_mlp_layernorm.weight"][layer]
yield (
f"model.layers.{layer}.post_attention_layernorm.weight",
post_attn_layernorm_weight,
)
def requires_bos_token(self):
return True
class GemmaConverter(ModelConverter):
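    """Converts Gemma NeMo checkpoints to the HF format expected by vLLM."""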
def get_architecture(self):
if self.model_type == "gemma":
return "GemmaForCausalLM"
return None
def convert_weights(self, nemo_model_config, state_dict):
num_layers = nemo_model_config["num_layers"]
num_query_groups = nemo_model_config["num_query_groups"]
head_num = nemo_model_config["num_attention_heads"]
head_size = nemo_model_config["kv_channels"]
hidden_size = nemo_model_config["hidden_size"]
zero_centered_gamma = nemo_model_config.get("layernorm_zero_centered_gamma", False)
heads_per_group = head_num // num_query_groups
yield (
"model.embed_tokens.weight",
state_dict["model.embedding.word_embeddings.weight"],
)
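        # HF Gemma's RMSNorm computes x * (1 + weight), i.e. weights are stored
        # zero-centered. If the NeMo model was not trained with zero-centered gamma,
        # every layernorm weight needs a -1 shift during conversion.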
final_layernorm_weight = state_dict["model.decoder.final_layernorm.weight"]
if not zero_centered_gamma:
final_layernorm_weight -= 1.0
yield ("model.norm.weight", final_layernorm_weight)
for layer in range(int(num_layers)):
input_layernorm_weight = state_dict["model.decoder.layers.self_attention.linear_qkv.layer_norm_weight"][
layer
]
if not zero_centered_gamma:
input_layernorm_weight -= 1.0
yield (
f"model.layers.{layer}.input_layernorm.weight",
input_layernorm_weight,
)
post_attention_layernorm_weight = state_dict["model.decoder.layers.mlp.linear_fc1.layer_norm_weight"][layer]
if not zero_centered_gamma:
post_attention_layernorm_weight -= 1.0
yield (
f"model.layers.{layer}.post_attention_layernorm.weight",
post_attention_layernorm_weight,
)
gate_up_combined_weight = state_dict["model.decoder.layers.mlp.linear_fc1.weight"][layer]
gate_size = gate_up_combined_weight.shape[0] // 2
yield (
f"model.layers.{layer}.mlp.gate_proj.weight",
gate_up_combined_weight[:gate_size, :],
)
yield (
f"model.layers.{layer}.mlp.up_proj.weight",
gate_up_combined_weight[gate_size:, :],
)
down_proj_weight = state_dict["model.decoder.layers.mlp.linear_fc2.weight"][layer]
yield (f"model.layers.{layer}.mlp.down_proj.weight", down_proj_weight)
self_attn_o_proj_weight = state_dict["model.decoder.layers.self_attention.linear_proj.weight"][layer]
yield (
f"model.layers.{layer}.self_attn.o_proj.weight",
self_attn_o_proj_weight,
)
qkv_weight = state_dict["model.decoder.layers.self_attention.linear_qkv.weight"][layer]
qkv_intermediate_size = head_num + 2 * num_query_groups
qkv_weight = qkv_weight.reshape(qkv_intermediate_size, head_size, hidden_size)
q_weight = torch.empty((head_num, head_size, hidden_size), dtype=qkv_weight.dtype)
k_weight = torch.empty((num_query_groups, head_size, hidden_size), dtype=qkv_weight.dtype)
v_weight = torch.empty((num_query_groups, head_size, hidden_size), dtype=qkv_weight.dtype)
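            # De-interleave the fused QKV rows: for each query group, copy
            # 'heads_per_group' Q heads, then one K head, then one V head.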
ptr = 0
for i in range(num_query_groups):
                q_weight[i * heads_per_group : (i + 1) * heads_per_group, :, :] = qkv_weight[
                    ptr : ptr + heads_per_group, :, :
                ]
ptr += heads_per_group
k_weight[i : i + 1, :, :] = qkv_weight[ptr : ptr + 1, :, :]
ptr += 1
v_weight[i : i + 1, :, :] = qkv_weight[ptr : ptr + 1, :, :]
ptr += 1
assert ptr == qkv_intermediate_size
q_weight = q_weight.reshape(head_num * head_size, hidden_size)
k_weight = k_weight.reshape(num_query_groups * head_size, hidden_size)
v_weight = v_weight.reshape(num_query_groups * head_size, hidden_size)
yield (f"model.layers.{layer}.self_attn.q_proj.weight", q_weight)
yield (f"model.layers.{layer}.self_attn.k_proj.weight", k_weight)
yield (f"model.layers.{layer}.self_attn.v_proj.weight", v_weight)
def requires_bos_token(self):
return True
class Starcoder2Converter(ModelConverter):
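    """Converts Starcoder2 NeMo checkpoints to the HF format expected by vLLM."""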
def get_architecture(self):
if self.model_type == "starcoder2":
return "Starcoder2ForCausalLM"
return None
def convert_config(self, nemo_model_config, hf_config):
window_sizes = nemo_model_config.get("window_size")
if window_sizes is not None:
hf_config["sliding_window"] = window_sizes[0]
        # 'tie_word_embeddings = False' means there is a separate 'lm_head.weight' tensor,
        # which this converter assumes is always present. If a starcoder2 variant without
        # it appears, we would need to copy 'model.embed_tokens.weight' into 'lm_head.weight'
        # and still set 'tie_word_embeddings = False', because at this point we cannot tell
        # whether the weight exists: this configuration is not stored in NeMo checkpoints.
hf_config["tie_word_embeddings"] = False
def convert_weights(self, nemo_model_config, state_dict):
num_layers = nemo_model_config["num_layers"]
num_query_groups = nemo_model_config["num_query_groups"]
head_num = nemo_model_config["num_attention_heads"]
hidden_size = nemo_model_config["hidden_size"]
head_size = hidden_size // head_num
heads_per_group = head_num // num_query_groups
qkv_total_dim = head_num + 2 * num_query_groups
if "bias" in nemo_model_config:
has_bias = nemo_model_config["bias"]
else:
has_bias = nemo_model_config["add_bias_linear"]
yield (
"model.embed_tokens.weight",
state_dict["model.embedding.word_embeddings.weight"],
)
yield ("model.norm.weight", state_dict["model.decoder.final_layernorm.weight"])
if has_bias:
yield ("model.norm.bias", state_dict["model.decoder.final_layernorm.bias"])
yield ("lm_head.weight", state_dict["model.output_layer.weight"])
for layer in range(int(num_layers)):
# q,k,v
qkv_weights = state_dict["model.decoder.layers.self_attention.linear_qkv.weight"][layer]
qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size])
if has_bias:
qkv_bias = state_dict["model.decoder.layers.self_attention.linear_qkv.bias"][layer]
qkv_bias = qkv_bias.reshape([qkv_total_dim, head_size])
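            # Q/K/V rows (and biases, if present) are interleaved per query group, as in LlamaConverter.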
q_slice = torch.cat(
[
torch.arange(
(heads_per_group + 2) * i,
(heads_per_group + 2) * i + heads_per_group,
)
for i in range(num_query_groups)
]
)
k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))
for name, slice in [
("q_proj", q_slice),
("k_proj", k_slice),
("v_proj", v_slice),
]:
qkv_weights_slice = qkv_weights[slice].reshape(-1, hidden_size)
yield (
f"model.layers.{layer}.self_attn.{name}.weight",
qkv_weights_slice,
)
if has_bias:
qkv_bias_slice = qkv_bias[slice].reshape(-1)
yield (
f"model.layers.{layer}.self_attn.{name}.bias",
qkv_bias_slice,
)
# Attention dense
yield (
f"model.layers.{layer}.self_attn.o_proj.weight",
state_dict["model.decoder.layers.self_attention.linear_proj.weight"][layer],
)
if has_bias:
yield (
f"model.layers.{layer}.self_attn.o_proj.bias",
state_dict["model.decoder.layers.self_attention.linear_proj.bias"][layer],
)
# MLP FC1
yield (
f"model.layers.{layer}.mlp.c_fc.weight",
state_dict["model.decoder.layers.mlp.linear_fc1.weight"][layer],
)
if has_bias:
yield (
f"model.layers.{layer}.mlp.c_fc.bias",
state_dict["model.decoder.layers.mlp.linear_fc1.bias"][layer],
)
# MLP FC2
yield (
f"model.layers.{layer}.mlp.c_proj.weight",
state_dict["model.decoder.layers.mlp.linear_fc2.weight"][layer],
)
if has_bias:
yield (
f"model.layers.{layer}.mlp.c_proj.bias",
state_dict["model.decoder.layers.mlp.linear_fc2.bias"][layer],
)
# Input LayerNorm
yield (
f"model.layers.{layer}.input_layernorm.weight",
state_dict["model.decoder.layers.self_attention.linear_qkv.layer_norm_weight"][layer],
)
if has_bias:
yield (
f"model.layers.{layer}.input_layernorm.bias",
state_dict["model.decoder.layers.self_attention.linear_qkv.layer_norm_bias"][layer],
)
# Post-attention LayerNorm
yield (
f"model.layers.{layer}.post_attention_layernorm.weight",
state_dict["model.decoder.layers.mlp.linear_fc1.layer_norm_weight"][layer],
)
if has_bias:
yield (
f"model.layers.{layer}.post_attention_layernorm.bias",
state_dict["model.decoder.layers.mlp.linear_fc1.layer_norm_bias"][layer],
)
_MODEL_CONVERTERS = {
"llama": LlamaConverter,
"mistral": LlamaConverter,
"mixtral": MixtralConverter,
"gemma": GemmaConverter,
"starcoder2": Starcoder2Converter,
}
def register_model_converter(model_type, cls):
"""Establishes a mapping from short model type to a class that converts the model from Nemo format to a vLLM compatible format."""
_MODEL_CONVERTERS[model_type] = cls
def get_model_converter(model_type) -> Optional[ModelConverter]:
"""Returns an instance of the the model conversion class for the given model type, or None."""
cls = _MODEL_CONVERTERS.get(model_type, None)
if cls is None:
return None
return cls(model_type)
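

# A minimal usage sketch (illustrative only; 'MyModelConverter' and 'mymodel'
# are hypothetical names, not part of NeMo):
#
#     class MyModelConverter(ModelConverter):
#         def get_architecture(self):
#             return "MyModelForCausalLM" if self.model_type == "mymodel" else None
#
#         def convert_weights(self, nemo_model_config, state_dict):
#             yield ("model.embed_tokens.weight",
#                    state_dict["model.embedding.word_embeddings.weight"])
#
#     register_model_converter("mymodel", MyModelConverter)
#     converter = get_model_converter("mymodel")  # returns a MyModelConverter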