Learning Goals
This project aims to demonstrate how to adapt or customize foundation models to improve performance on specific tasks.
This optimization process is known as fine-tuning, which involves adjusting the weights of a pre-trained foundation model with custom data.
Because foundation models can be very large, a variant of fine-tuning known as parameter-efficient fine-tuning (PEFT) has gained traction recently. PEFT encompasses several methods, including P-Tuning, LoRA, Adapters, and IA3.
For those interested in a deeper understanding of these methods, we have included a list of additional resources at the end of this document.
This project involves applying various fine-tuning methods to the Llama 2 model. In this playbook you will implement and evaluate several parameter-efficient fine-tuning methods using a domain- and task-specific dataset. This playbook has been tested for P-Tuning and LoRA.
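As background, the core idea behind LoRA can be illustrated in a few lines of PyTorch. The sketch below is purely conceptual and is not how NeMo implements LoRA internally; the class name and hyperparameters are illustrative only. Instead of updating a frozen pre-trained weight W, LoRA trains a low-rank update B @ A so that the effective weight becomes W + (alpha / r) * B @ A, which is what makes the method parameter-efficient.

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    # Illustrative LoRA wrapper around a frozen linear layer (not NeMo's implementation).
    def __init__(self, in_features, out_features, r=8, alpha=16):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.base.weight.requires_grad = False  # the pre-trained weight stays frozen
        self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01)  # low-rank factor A
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))        # low-rank factor B, zero-initialized
        self.scaling = alpha / r

    def forward(self, x):
        # Frozen path plus the trainable low-rank update.
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling

Only lora_A and lora_B (plus any other adapter parameters) are trained, so the number of trainable parameters is a small fraction of the full model.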
NeMo Tools and Resources
Educational Resources
Blog: Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters
NeMo documentation: Introduction to P-tuning
NeMo notebook/tutorial: Introduction to p-tuning and prompt-tuning
Software Requirements
Use the latest NeMo Framework Training container
This playbook has been tested on:
nvcr.io/ea-bignlp/ga-participants/nemofw-training:23.08.03
Hardware Requirements
Minimum 1xA100 80G for PEFT on 7B. This playbook has been tested on 8xA100 80G.
PEFT on all model sizes (7B, 13B, 70B) can run on 1 DGX A100 80G node (8xA100 80G).
If you already have a .nemo checkpoint for the Llama 2 model, you can skip the download and conversion steps below.
Step 1: Download Llama 2 in Hugging Face format
First, request download permission from both Hugging Face and Meta, and create the destination directory. Then download the model either by logging in with the Hugging Face CLI first:
mkdir llama2-7b-hf
huggingface-cli login
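and then running snapshot_download without an explicit token; the credentials cached by huggingface-cli login are used automatically (a minimal sketch, assuming the huggingface_hub package is available in your environment):

from huggingface_hub import snapshot_download
# Uses the token cached by `huggingface-cli login`.
snapshot_download(repo_id="meta-llama/Llama-2-7b-hf", local_dir="llama2-7b-hf", local_dir_use_symlinks=False)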
Alternatively, pass your Hugging Face API token directly to the following Python code:
from huggingface_hub import snapshot_download
snapshot_download(repo_id="meta-llama/Llama-2-7b-hf", local_dir="llama2-7b-hf", local_dir_use_symlinks=False, token="<YOUR HF TOKEN>")
In this example, the Llama 2 Hugging Face model will be downloaded to ./llama2-7b-hf.
Step 2: Convert to .nemo
Run the container using the following command (this exposes a single GPU, which is enough for the conversion; use --gpus all if you plan to run the multi-GPU training steps in the same container):
docker run --gpus device=1 --shm-size=2g --net=host --ulimit memlock=-1 --rm -it -v ${PWD}:/workspace -w /workspace -v ${PWD}/results:/results nvcr.io/ea-bignlp/ga-participants/nemofw-training:23.08.03 bash
Convert the Hugging Face model to a .nemo model:
python /opt/NeMo/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py --in-file=./llama2-7b-hf/ --out-file=llama2-7b.nemo
The generated llama2-7b.nemo file uses distributed checkpointing and can be loaded with any tp/pp combination without reshaping/splitting.
Step 1: Download the PubMedQA dataset and run the split_dataset.py script in the cloned directory.
Download the dataset
git clone https://github.com/pubmedqa/pubmedqa.git
cd pubmedqa
cd preprocess
python split_dataset.py pqal
After running the split_dataset.py script, you will see the test set as well as ten directories, each containing a different train/validation fold:
$ cd ../..
$ ls pubmedqa/data/
ori_pqal.json
pqal_fold0
pqal_fold1
pqal_fold2
pqal_fold3
pqal_fold4
pqal_fold5
pqal_fold6
pqal_fold7
pqal_fold8
pqal_fold9
test_ground_truth.json
test_set.json
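As an optional sanity check, you can inspect one of the folds. Each fold holds a train and dev split of the labeled PQA-L data; the counts in the comment below are approximate and depend on the split script.

import json

# Inspect one cross-validation fold produced by split_dataset.py.
with open("pubmedqa/data/pqal_fold0/train_set.json") as f:
    train = json.load(f)
with open("pubmedqa/data/pqal_fold0/dev_set.json") as f:
    dev = json.load(f)
print(len(train), len(dev))  # roughly 450 train / 50 dev examples per fold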
Below is an example of what the objects look like in the PubMedQA train, validation, and test splits:
"18251357": {
"QUESTION": "Does histologic chorioamnionitis correspond to clinical chorioamnionitis?",
"CONTEXTS": [
"To evaluate the degree to which histologic chorioamnionitis, a frequent finding in placentas submitted for histopathologic evaluation, correlates with clinical indicators of infection in the mother.",
"A retrospective review was performed on 52 cases with a histologic diagnosis of acute chorioamnionitis from 2,051 deliveries at University Hospital, Newark, from January 2003 to July 2003. Third-trimester placentas without histologic chorioamnionitis (n = 52) served as controls. Cases and controls were selected sequentially. Maternal medical records were reviewed for indicators of maternal infection.",
"Histologic chorioamnionitis was significantly associated with the usage of antibiotics (p = 0.0095) and a higher mean white blood cell count (p = 0.018). The presence of 1 or more clinical indicators was significantly associated with the presence of histologic chorioamnionitis (p = 0.019)."
],
"reasoning_required_pred": "yes",
"reasoning_free_pred": "yes",
"final_decision": "yes",
"LONG_ANSWER": "Histologic chorioamnionitis is a reliable indicator of infection whether or not it is clinically apparent."
},
Step 2: Data Preprocessing
Use the script below to convert the train/validation/test PubMedQA data into the JSONL format that NeMo expects for PEFT. In this example, we name the script preprocess_to_jsonl.py and place it inside the pubmedqa repository cloned earlier.
import json

def read_jsonl(fname):
    # Read a JSONL file into a list of dicts.
    obj = []
    with open(fname, 'rt') as f:
        st = f.readline()
        while st:
            obj.append(json.loads(st))
            st = f.readline()
    return obj

def write_jsonl(fname, json_objs):
    # Write a list of dicts as one JSON object per line.
    with open(fname, 'wt') as f:
        for o in json_objs:
            f.write(json.dumps(o) + "\n")

def form_question(obj):
    # Build the prompt: the question, all context passages, and the answer cue.
    st = ""
    st += f"QUESTION:{obj['QUESTION']}\n"
    st += "CONTEXT: "
    for i, label in enumerate(obj['LABELS']):
        st += f"{obj['CONTEXTS'][i]}\n"
    st += "TARGET: the answer to the question given the context is (yes|no|maybe): "
    return st

def convert_to_jsonl(data_path, output_path):
    # Convert one PubMedQA split into NeMo's {"input": ..., "output": ...} JSONL format.
    data = json.load(open(data_path, 'rt'))
    json_objs = []
    for k in data.keys():
        obj = data[k]
        prompt = form_question(obj)
        completion = obj['reasoning_required_pred']
        json_objs.append({"input": prompt, "output": completion})
    write_jsonl(output_path, json_objs)
    return json_objs

def main():
    test_json_objs = convert_to_jsonl("data/test_set.json", "pubmedqa_test.jsonl")
    train_json_objs = convert_to_jsonl("data/pqal_fold0/train_set.json", "pubmedqa_train.jsonl")
    dev_json_objs = convert_to_jsonl("data/pqal_fold0/dev_set.json", "pubmedqa_val.jsonl")
    return test_json_objs, train_json_objs, dev_json_objs

if __name__ == "__main__":
    main()
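Run the script from inside the cloned pubmedqa directory so that the relative data/ paths resolve (this assumes you saved it as preprocess_to_jsonl.py, as described above):

cd pubmedqa
python preprocess_to_jsonl.py
cd ..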
After running the above script, you will see the pubmedqa_train.jsonl, pubmedqa_val.jsonl, and pubmedqa_test.jsonl files appear in the directory where you copied and ran the preprocessing script:
$ ls pubmedqa/
data
evaluation.py
exp
get_human_performance.py
LICENSE
preprocess
preprocess_to_jsonl.py
pubmedqa_test.jsonl
pubmedqa_train.jsonl
pubmedqa_val.jsonl
README.md
Below is what the data will look like once we have used the above script to convert PubMedQA into the format that NeMo expects for SFT and PEFT:
{"input": "QUESTION: Failed IUD insertions in community practice: an under-recognized problem?\nCONTEXT: The data analysis was conducted to describe the rate of unsuccessful copper T380A intrauterine device (IUD) insertions among women using the IUD for emergency contraception (EC) at community family planning clinics in Utah.\n...",
"output": "yes"}
Step 1: Set the experiment configs
The megatron_gpt_peft_tuning_config.yaml file configures the parameters for running PEFT training jobs in NeMo with the P-Tuning and LoRA techniques. Set the following environment variables, passing the paths to your train, validation, and test data files:
MODEL="YOUR PATH TO llama2-7b.nemo"
TRAIN_DS="[YOUR PATH TO pubmedqa/pubmedqa_train.jsonl]"
VALID_DS="[YOUR PATH TO pubmedqa/pubmedqa_val.jsonl]"
TEST_DS="[YOUR PATH TO pubmedqa/pubmedqa_test.jsonl]"
TEST_NAMES="[pubmedqa]"
SCHEME="lora"
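The training command in Step 2 also references ${TP_SIZE} and ${PP_SIZE}. For the 7B model, tensor and pipeline parallel sizes of 1 are sufficient (an assumption consistent with the 13B/70B overrides listed at the end of Step 2):

TP_SIZE=1
PP_SIZE=1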
Set the concat sampling probabilities. These depend on the number of files passed in the train set and the fraction of the fine-tuning data you want to draw from each file. Note that the concat sampling probabilities must sum to 1.0. For example, the following sets the concat sampling probabilities for a train set with two JSONL files:
TRAIN_DS="[/path/to/dataset_1.jsonl,/path/to/dataset_2.jsonl]"
CONCAT_SAMPLING_PROBS="[0.3,0.7]"
In our example we are using one train file, so CONCAT_SAMPLING_PROBS="[1.0]".
Step 2: Run PEFT training
Run the PEFT training command, setting appropriate values for parameters such as the number of steps, model checkpoint path, and batch sizes. For a full reference of parameter settings, refer to the config file:
torchrun --nproc_per_node=8 \
/opt/NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_peft_tuning.py \
trainer.devices=8 \
trainer.num_nodes=1 \
trainer.precision=bf16 \
trainer.val_check_interval=20 \
trainer.max_steps=50 \
model.megatron_amp_O2=False \
++model.mcore_gpt=True \
model.tensor_model_parallel_size=${TP_SIZE} \
model.pipeline_model_parallel_size=${PP_SIZE} \
model.micro_batch_size=1 \
model.global_batch_size=8 \
model.restore_from_path=${MODEL} \
model.data.train_ds.num_workers=0 \
model.data.validation_ds.num_workers=0 \
model.data.train_ds.file_names=${TRAIN_DS} \
model.data.train_ds.concat_sampling_probabilities=[1.0] \
model.data.validation_ds.file_names=${VALID_DS} \
model.peft.peft_scheme=${SCHEME}
Set SCHEME="ptuning" to use P-Tuning instead of LoRA.
Change the following settings for 13B PEFT:
model.tensor_model_parallel_size=2
model.pipeline_model_parallel_size=1
Change the following settings for 70B PEFT:
model.tensor_model_parallel_size=8
model.pipeline_model_parallel_size=1
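If you drive these overrides through the same ${TP_SIZE} and ${PP_SIZE} environment variables used in the training command above, the equivalent settings are:

# 13B
TP_SIZE=2
PP_SIZE=1
# 70B
TP_SIZE=8
PP_SIZE=1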
Step 3: Run inference
Set model.restore_from_path to the path for the llama2-7b.nemo model.
Set model.peft.restore_from_path to the path for the PEFT checkpoint that will be saved inside of your experiment directory.
Set model.data.test_ds.file_names to the path of the pubmedqa_test.jsonl file.
Configure tokens_to_generate and output_file_path_prefix according to your project needs.
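The inference command below reuses MODEL, TEST_DS, and TEST_NAMES from the training setup and additionally needs the following two variables (the values shown are placeholders; use the actual paths from your training run):

PATH_TO_TRAINED_MODEL="YOUR PATH TO THE PEFT .nemo CHECKPOINT SAVED DURING TRAINING"
OUTPUT_PREFIX="pubmedqa_peft_predictions"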
python /opt/NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py \
model.restore_from_path=${MODEL} \
model.peft.restore_from_path=${PATH_TO_TRAINED_MODEL} \
trainer.devices=8 \
model.data.test_ds.file_names=${TEST_DS} \
model.data.test_ds.names=${TEST_NAMES} \
model.data.test_ds.global_batch_size=32 \
model.data.test_ds.micro_batch_size=4 \
model.data.test_ds.tokens_to_generate=20 \
inference.greedy=True \
model.data.test_ds.output_file_path_prefix=${OUTPUT_PREFIX} \
model.data.test_ds.write_predictions_to_file=True
After the evaluation finishes, you should be able to see output similar to the following
────────────────────────────────────────────
Test metric DataLoader 0
────────────────────────────────────────────
test_loss 0.40366214513778687
test_loss_pubmedqa 0.40366214513778687
val_loss 0.40366214513778687
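Because write_predictions_to_file=True, the evaluation also writes the generated answers to a JSONL file whose name starts with your OUTPUT_PREFIX. As an optional follow-up, a small script along the following lines can turn those predictions into a yes/no/maybe accuracy. The file name and the "pred"/"label" field names below are assumptions; adjust them to match the predictions file your NeMo version actually produces.

import json

# Hypothetical predictions file written by megatron_gpt_peft_eval.py; adjust the name
# to match your OUTPUT_PREFIX and NeMo version.
preds_file = "pubmedqa_peft_predictions_test_pubmedqa_inputs_preds_labels.jsonl"

correct = total = 0
with open(preds_file) as f:
    for line in f:
        obj = json.loads(line)
        pred_text = obj["pred"].strip().lower()
        label = obj["label"].strip().lower()
        # Compare only the first generated token (stripped of punctuation) to the gold answer.
        first = pred_text.split()[0].strip(".,:;") if pred_text else ""
        correct += int(first == label)
        total += 1
print(f"accuracy: {correct / total:.3f} on {total} examples")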