NeMo Framework PEFT with Llama2 and Mixtral-8x7B

Learning Goals

This playbook aims to demonstrate how to adapt or customize foundation models to improve performance on specific tasks.

This optimization process is known as fine-tuning, which involves adjusting the weights of a pre-trained foundation model with custom data.

Because foundation models can be very large, a variant of fine-tuning known as parameter-efficient fine-tuning (PEFT) has recently gained traction. PEFT encompasses several methods, including P-Tuning, LoRA, Adapters, and IA3.
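To make the idea concrete, the sketch below shows the core mechanism of one PEFT method, LoRA, in plain PyTorch: the pre-trained weight matrix is frozen and only a small pair of low-rank matrices is trained. This is an illustrative sketch only (the class, dimensions, and hyperparameters are invented for the example); in this playbook, NeMo applies the technique for you via its PEFT configuration.

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Minimal LoRA sketch: y = W x + (alpha / r) * B(A(x)), with W frozen."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)  # freeze the pre-trained weight
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)
        self.lora_a = nn.Linear(base.in_features, r, bias=False)   # low-rank down-projection A
        self.lora_b = nn.Linear(r, base.out_features, bias=False)  # low-rank up-projection B
        nn.init.zeros_(self.lora_b.weight)  # adapter starts as a no-op
        self.scale = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))

# Only the adapter parameters are trainable.
layer = LoRALinear(nn.Linear(4096, 4096), r=8)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
total = sum(p.numel() for p in layer.parameters())
print(f"trainable parameters: {trainable} of {total}")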

This playbook applies PEFT methods to the Llama2 and Mixtral models: you will implement and evaluate several parameter-efficient fine-tuning methods using a domain- and task-specific dataset. The playbook has been tested with P-Tuning and LoRA.

NeMo Tools and Resources

  1. NeMo Github repo

  2. NeMo Framework Training container: nvcr.io/nvidia/nemo:24.01.01.framework

Educational Resources

  1. Blog: Mastering LLM Techniques: Customization

  2. Whitepaper: LoRA: Low-Rank Adaptation of Large Language Models

  3. Blog: Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters

  4. NeMo documentation: Introduction to P-tuning

  5. NeMo notebook/tutorial: Introduction to p-tuning and prompt-tuning

Software Requirements

  1. Use the latest NeMo Framework Training container. Note that you must be logged in to the container registry to view the container page.

  2. This playbook has been tested using the container: nvcr.io/nvidia/nemo:24.01.01.framework

Hardware Requirements

  1. Llama2 7B: minimum 1xA100 80G

  2. Llama2 13B: minimum 1xA100 80G

  3. Llama2 70B: minimum 4xA100 80G

  4. Mixtral 8x7B: minimum 4xA100 80G

  5. PEFT on all listed model sizes for both Llama and Mixtral can run on 1 DGX A100 80G node (8xA100 80G).
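Before starting, you can quickly confirm that the node you are on meets these minimums. The short check below is an optional convenience, assuming PyTorch with CUDA support is available (it is inside the NeMo Framework container).

import torch

# List visible GPUs and their memory; compare against the minimums above
# (for example, 4 x A100 80G for Llama2 70B or Mixtral 8x7B).
if not torch.cuda.is_available():
    raise SystemExit("No CUDA devices are visible to PyTorch.")

for i in range(torch.cuda.device_count()):
    props = torch.cuda.get_device_properties(i)
    print(f"GPU {i}: {props.name}, {props.total_memory / 1024**3:.0f} GiB")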

Data

This playbook uses the PubMedQA dataset. For more details about the data, refer to the paper PubMedQA: A Dataset for Biomedical Research Question Answering and the PubMedQA project page.

Convert Huggingface format to NeMo format

The following two optional sections show how to convert Llama2 or Mixtral models from Huggingface format to NeMo format.

If you already have a .nemo file for Llama2 models, you can skip this step.

Step 1: Download Llama2 in Huggingface format

First, request download permission from both Huggingface and Meta, and create the destination directory. Then download the model either by logging in through the Huggingface CLI:

mkdir llama2-7b-hf
huggingface-cli login

Or use your Huggingface API token to download the model with the following Python code. Replace the value for token with your Huggingface token. In this example, the Llama2 Huggingface model is downloaded to ./llama2-7b-hf

from huggingface_hub import snapshot_download
snapshot_download(
    repo_id="meta-llama/Llama-2-7b-hf",
    local_dir="llama2-7b-hf",
    local_dir_use_symlinks=False,
    token=<YOUR HF TOKEN>
)

You can also download the Huggingface git repository directly. For the following code, make sure the resulting directory is named ./llama2-7b-hf. If you see warnings about git lfs, refer to the installation instructions on GitHub.

git clone https://huggingface.co/meta-llama/Llama-2-7b-hf
mv Llama-2-7b-hf/ llama2-7b-hf/

Step 2: Convert to .nemo format

Run the container using the following command. Note that you may have to update which device(s) are available. For more information, see the Docker documentation.

docker run --gpus device=1 --shm-size=2g --net=host --ulimit memlock=-1 --rm -it -v ${PWD}:/workspace -w /workspace -v ${PWD}/results:/results nvcr.io/nvidia/nemo:24.01.01.framework bash

The previous command will open a bash shell within the container. Using this shell, convert the Huggingface model to a .nemo model.

python /opt/NeMo/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py --in-file=./llama2-7b-hf/ --out-file=llama2-7b.nemo

Note that this process can take some time.

The generated llama2-7b.nemo file uses distributed checkpointing and can be loaded with any tensor parallel (TP) or pipeline parallel (PP) combination without reshaping/splitting.

If you already have a .nemo file for the Mixtral-8x7B model, you can skip this step.

Step 1: Download Mixtral-8x7B in Huggingface format

First, create the destination directory. Then download the model either by logging in through the Huggingface CLI:

mkdir Mixtral-8x7B-v0.1
huggingface-cli login

Or download the model using the following Python code (pass token=<YOUR HF TOKEN> to snapshot_download if your account requires authentication). In this example, the Mixtral-8x7B Huggingface model is downloaded to ./Mixtral-8x7B-v0.1

from huggingface_hub import snapshot_download
snapshot_download(
    repo_id="mistralai/Mixtral-8x7B-v0.1",
    local_dir="Mixtral-8x7B-v0.1",
    local_dir_use_symlinks=False
)

You can also download the Huggingface git repository directly. If you see warnings about git lfs, refer to the installation instructions on GitHub.

git clone https://huggingface.co/mistralai/Mixtral-8x7B-v0.1

Step 2: Convert to .nemo format

Run the container using the following command. Note that you may have to update which device(s) are available. For more information, see the Docker documentation.

docker run --gpus device=1 --shm-size=2g --net=host --ulimit memlock=-1 --rm -it -v ${PWD}:/workspace -w /workspace -v ${PWD}/results:/results nvcr.io/nvidia/nemo:24.01.01.framework bash

The previous command will open a bash shell within the container. Using this shell, convert the Huggingface model to a .nemo model.

python /opt/NeMo/scripts/nlp_language_modeling/convert_hf_mixtral_to_nemo.py --in-file=./Mixtral-8x7B-v0.1/ --out-file=Mixtral-8x7B-v0.1.nemo

Note that this process can take some time.

The generated Mixtral-8x7B-v0.1.nemo file uses distributed checkpointing and can be loaded with any tensor parallel (TP) or pipeline parallel (PP) combination without reshaping/splitting.

Prepare the PubMedQA dataset

Step 1: Download the PubMedQA dataset and run the split_dataset.py script in the cloned directory

Download the dataset

git clone https://github.com/pubmedqa/pubmedqa.git
cd pubmedqa
cd preprocess
python split_dataset.py pqal

After running the split_dataset.py script, you will see the test set as well as ten directories, each of which contains a different train/validation fold:

$ cd ../..
$ ls pubmedqa/data/ -1v
ori_pqal.json
pqal_fold0
pqal_fold1
pqal_fold2
pqal_fold3
pqal_fold4
pqal_fold5
pqal_fold6
pqal_fold7
pqal_fold8
pqal_fold9
test_ground_truth.json
test_set.json

Below is an example of what a single entry looks like in the train_set.json file; the PubMedQA train, validation, and test splits all share this format.

"18251357": { "QUESTION": "Does histologic chorioamnionitis correspond to clinical chorioamnionitis?", "CONTEXTS": [ "To evaluate the degree to which histologic chorioamnionitis, a frequent finding in placentas submitted for histopathologic evaluation, correlates with clinical indicators of infection in the mother.", "A retrospective review was performed on 52 cases with a histologic diagnosis of acute chorioamnionitis from 2,051 deliveries at University Hospital, Newark, from January 2003 to July 2003. Third-trimester placentas without histologic chorioamnionitis (n = 52) served as controls. Cases and controls were selected sequentially. Maternal medical records were reviewed for indicators of maternal infection.", "Histologic chorioamnionitis was significantly associated with the usage of antibiotics (p = 0.0095) and a higher mean white blood cell count (p = 0.018). The presence of 1 or more clinical indicators was significantly associated with the presence of histologic chorioamnionitis (p = 0.019)." ], "reasoning_required_pred": "yes", "reasoning_free_pred": "yes", "final_decision": "yes", "LONG_ANSWER": "Histologic chorioamnionitis is a reliable indicator of infection whether or not it is clinically apparent." },

Step 2: Data Preprocessing

Use the following script to convert the train/validation/test PubMedQA data into the JSONL format that NeMo expects for PEFT. In this example, the script is named preprocess_to_jsonl.py and placed inside the pubmedqa repository cloned earlier.

import json


def read_jsonl(fname):
    obj = []
    with open(fname, 'rt') as f:
        st = f.readline()
        while st:
            obj.append(json.loads(st))
            st = f.readline()
    return obj


def write_jsonl(fname, json_objs):
    with open(fname, 'wt') as f:
        for o in json_objs:
            f.write(json.dumps(o) + "\n")


def form_question(obj):
    st = ""
    st += f"QUESTION:{obj['QUESTION']}\n"
    st += "CONTEXT: "
    for i, label in enumerate(obj['LABELS']):
        st += f"{obj['CONTEXTS'][i]}\n"
    st += "TARGET: the answer to the question given the context is (yes|no|maybe): "
    return st


def convert_to_jsonl(data_path, output_path):
    data = json.load(open(data_path, 'rt'))
    json_objs = []
    for k in data.keys():
        obj = data[k]
        prompt = form_question(obj)
        completion = obj['reasoning_required_pred']
        json_objs.append({"input": prompt, "output": completion})
    write_jsonl(output_path, json_objs)
    return json_objs


def main():
    test_json_objs = convert_to_jsonl("data/test_set.json", "pubmedqa_test.jsonl")
    train_json_objs = convert_to_jsonl("data/pqal_fold0/train_set.json", "pubmedqa_train.jsonl")
    dev_json_objs = convert_to_jsonl("data/pqal_fold0/dev_set.json", "pubmedqa_val.jsonl")
    return test_json_objs, train_json_objs, dev_json_objs


if __name__ == "__main__":
    main()

You can run this script with the following:

cd pubmedqa
python preprocess_to_jsonl.py

After running the script, the pubmedqa_train.jsonl, pubmedqa_val.jsonl, and pubmedqa_test.jsonl files will appear in the directory where you placed and ran the preprocessing script:

$ cd ..
$ ls pubmedqa -1v
LICENSE
README.md
data
evaluation.py
get_human_performance.py
preprocess
preprocess_to_jsonl.py
pubmedqa_test.jsonl
pubmedqa_train.jsonl
pubmedqa_val.jsonl

Below is what the data looks like once the above script has converted PubMedQA into the format that NeMo expects for PEFT:

{ "input": "QUESTION: Failed IUD insertions in community practice: an under-recognized problem?\nCONTEXT: The data analysis was conducted to describe the rate of unsuccessful copper T380A intrauterine device (IUD) insertions among women using the IUD for emergency contraception (EC) at community family planning clinics in Utah.\n...", "output": "yes" }

Run PEFT training

The megatron_gpt_finetuning_config.yaml file configures the parameters for running PEFT training jobs in NeMo, covering both the P-Tuning and LoRA techniques.
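If you want to inspect the defaults you will be overriding, you can load this config inside the container and print the relevant sections. The snippet below is an optional sketch: it assumes the config ships with the container under /opt/NeMo at the path mirroring its GitHub location, and that omegaconf is importable (it is a NeMo dependency).

from omegaconf import OmegaConf

# Assumed in-container path, mirroring the GitHub location referenced in run_peft.sh below.
CONFIG = "/opt/NeMo/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml"

cfg = OmegaConf.load(CONFIG)

# PEFT defaults (peft_scheme plus the LoRA / P-Tuning hyperparameters).
print(OmegaConf.to_yaml(cfg.model.peft))

# Training-dataset defaults that the shell script overrides.
print(OmegaConf.to_yaml(cfg.model.data.train_ds))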

Next, we will create a shell script to run the fine-tuning. This script also contains all the environment variables for successful execution.

The environment variables specified at the top of the script assume you are at the root of the directory that contains both ./llama2-7b.nemo (or ./mixtral-8x7B.nemo, ./llama2-13b.nemo, ./llama2-70b.nemo, etc.) and the pubmedqa directory. Notice that some of the values are arrays and must start and end with square brackets "[]" (e.g., TRAIN_DS and VALID_DS).
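The bracketed-list convention is easy to get wrong, so a quick sanity check before launching can save a failed run. The helper below is a hypothetical, optional script (not part of NeMo) that parses values in the same bracketed format, checks that the training files exist, and verifies that the sampling probabilities (set as CONCAT_SAMPLING_PROBS in the script below) match the number of files and sum to 1.0.

import os

def parse_bracketed_list(value: str):
    """Parse a Hydra-style bracketed list such as "[a.jsonl,b.jsonl]" into Python strings."""
    if not (value.startswith("[") and value.endswith("]")):
        raise ValueError(f"{value!r} must start and end with square brackets")
    return [item.strip() for item in value[1:-1].split(",") if item.strip()]

train_ds = parse_bracketed_list("[pubmedqa/pubmedqa_train.jsonl]")
probs = [float(p) for p in parse_bracketed_list("[1.0]")]

assert len(train_ds) == len(probs), "one sampling probability is required per training file"
assert abs(sum(probs) - 1.0) < 1e-6, "sampling probabilities must sum to 1.0"
assert all(p > 0.0 for p in probs), "each sampling probability must be greater than 0.0"
for path in train_ds:
    assert os.path.exists(path), f"missing training file: {path}"
print("dataset configuration looks consistent")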

To run the examples against various Llama2 model sizes or Mixtral-8x7B, consult the following possible variable values. Note that different datasets might require different configurations.

Model size      TP_SIZE   PP_SIZE   GPU_COUNT
Llama2 7B       1         1         1
Llama2 13B      1         1         1
Llama2 70B      4         1         4
Mixtral-8x7B    4         1         4

Save the following to run_peft.sh

# This is the nemo model we are finetuning
# Change this to match the model you want to finetune
MODEL="./llama2-7b.nemo"

# These are the training datasets (in our case we only have one)
TRAIN_DS="[pubmedqa/pubmedqa_train.jsonl]"

# These are the validation datasets (in our case we only have one)
VALID_DS="[pubmedqa/pubmedqa_val.jsonl]"

# These are the test datasets (in our case we only have one)
TEST_DS="[pubmedqa/pubmedqa_test.jsonl]"

# These are the names of the test datasets
TEST_NAMES="[pubmedqa]"

# This is the PEFT scheme that we will be using. Set to "ptuning" for P-Tuning instead of LoRA
PEFT_SCHEME="lora"

# This is the concat sampling probability. This depends on the number of files being passed in the train set
# and the sampling probability for each file. In our case, we have one training file. Note sum of concat sampling
# probabilities should be 1.0. For example, with two entries in TRAIN_DS, CONCAT_SAMPLING_PROBS might be
# "[0.3,0.7]". For three entries, CONCAT_SAMPLING_PROBS might be "[0.3,0.1,0.6]"
# NOTE: Your entry must contain a value greater than 0.0 for each file
CONCAT_SAMPLING_PROBS="[1.0]"

# This is the tensor parallel size (splitting tensors among GPUs horizontally)
# See above matrix for proper value for the given model size
TP_SIZE=1

# This is the pipeline parallel size (splitting layers among GPUs vertically)
# See above matrix for proper value for the given model size
PP_SIZE=1

# The number of nodes to run this on
# See above matrix for proper value for the given model size
NODE_COUNT=1

# The number of total GPUs used
GPU_COUNT=1

# Where to store the finetuned model and training artifacts
OUTPUT_DIR="./results"

# Run the PEFT command by appropriately setting the values for the parameters such as the number of steps,
# model checkpoint path, batch sizes etc. For a full reference of parameter settings refer to the config at
# https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml
python /opt/NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py \
    trainer.log_every_n_steps=1 \
    trainer.precision=bf16 \
    trainer.devices=${GPU_COUNT} \
    trainer.num_nodes=1 \
    trainer.val_check_interval=20 \
    trainer.max_steps=50 \
    model.restore_from_path=${MODEL} \
    model.peft.peft_scheme=${PEFT_SCHEME} \
    model.micro_batch_size=1 \
    model.global_batch_size=128 \
    model.tensor_model_parallel_size=${TP_SIZE} \
    model.pipeline_model_parallel_size=${PP_SIZE} \
    model.megatron_amp_O2=True \
    model.activations_checkpoint_granularity=selective \
    model.activations_checkpoint_num_layers=null \
    model.activations_checkpoint_method=uniform \
    model.optim.name=fused_adam \
    model.optim.lr=1e-4 \
    model.answer_only_loss=True \
    model.data.train_ds.file_names=${TRAIN_DS} \
    model.data.validation_ds.file_names=${VALID_DS} \
    model.data.test_ds.file_names=${TEST_DS} \
    model.data.train_ds.concat_sampling_probabilities=${CONCAT_SAMPLING_PROBS} \
    model.data.train_ds.max_seq_length=2048 \
    model.data.validation_ds.max_seq_length=2048 \
    model.data.train_ds.micro_batch_size=1 \
    model.data.train_ds.global_batch_size=128 \
    model.data.validation_ds.micro_batch_size=1 \
    model.data.validation_ds.global_batch_size=128 \
    model.data.train_ds.num_workers=0 \
    model.data.validation_ds.num_workers=0 \
    model.data.test_ds.num_workers=0 \
    model.data.validation_ds.metric.name=loss \
    model.data.test_ds.metric.name=loss \
    exp_manager.create_wandb_logger=False \
    exp_manager.checkpoint_callback_params.mode=min \
    exp_manager.explicit_log_dir=${OUTPUT_DIR} \
    exp_manager.resume_if_exists=True \
    exp_manager.resume_ignore_no_checkpoint=True \
    exp_manager.create_checkpoint_callback=True \
    exp_manager.checkpoint_callback_params.monitor=validation_loss \
    ++exp_manager.checkpoint_callback_params.save_best_model=False \
    exp_manager.checkpoint_callback_params.save_nemo_on_train_end=True \
    model.save_nemo_on_validation_end=False

If you have not launched the NeMo Framework container, run the container using the following command. Note that you may have to update which device(s) are available. For more information, see the Docker documentation.

docker run --gpus device=1 --shm-size=2g --net=host --ulimit memlock=-1 --rm -it -v ${PWD}:/workspace -w /workspace -v ${PWD}/results:/results nvcr.io/nvidia/nemo:24.01.01.framework bash

Within the container, execute the run_peft.sh script you saved earlier:

bash run_peft.sh

Step 3: Run evaluation

Save the following to run_evaluation.sh. The script sets the required config values (assuming evaluation is run from the root directory that contains both the fine-tuning results directory and the pubmedqa directory) and runs evaluation. The output of this script is only the loss values; we recommend designing a dedicated evaluation methodology to understand the real-world behavior of the model. For example, common automatic evaluation metrics include ROUGE, BLEU, and F1.

# The number of total GPUs available
GPU_COUNT=1

# Change this to the nemo model you want to use
MODEL="./llama2-7b.nemo"

# This will live in whatever $OUTPUT_DIR was set to in the training script above
# The filename will match whichever peft scheme was used during training
PATH_TO_TRAINED_MODEL="./results/checkpoints/megatron_gpt_peft_lora_tuning.nemo"

# The test dataset
TEST_DS="pubmedqa/pubmedqa_test.jsonl"
TEST_NAMES="pubmedqa"

# This is the prefix, including the path and filename prefix, for the accuracy file output
# This will be combined with TEST_NAMES to create the file ./results/peft_results_test_pubmedqa_inputs_preds_labels.jsonl
OUTPUT_PREFIX="./results/peft_results"

TOKENS_TO_GENERATE=20

# This is the tensor parallel size (splitting tensors among GPUs horizontally)
TP_SIZE=1

# This is the pipeline parallel size (splitting layers among GPUs vertically)
PP_SIZE=1

# Execute
python /opt/NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_generate.py \
    model.restore_from_path=${MODEL} \
    model.peft.restore_from_path=${PATH_TO_TRAINED_MODEL} \
    trainer.devices=${GPU_COUNT} \
    model.data.test_ds.file_names=[${TEST_DS}] \
    model.data.test_ds.names=[${TEST_NAMES}] \
    model.data.test_ds.global_batch_size=4 \
    model.data.test_ds.micro_batch_size=1 \
    model.data.test_ds.tokens_to_generate=${TOKENS_TO_GENERATE} \
    model.tensor_model_parallel_size=${TP_SIZE} \
    model.megatron_amp_O2=True \
    model.pipeline_model_parallel_size=${PP_SIZE} \
    inference.greedy=True \
    model.data.test_ds.output_file_path_prefix=${OUTPUT_PREFIX} \
    model.answer_only_loss=True \
    model.data.test_ds.write_predictions_to_file=True

Run the script

bash run_evaluation.sh

After the evaluation finishes, you should see output similar (but not identical) to the following; the exact numbers will vary depending on the dataset.

────────────────────────────────────────────
      Test metric             DataLoader 0
────────────────────────────────────────────
       test_loss           0.3883868455886841
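run_evaluation.sh reports only the loss, but because write_predictions_to_file=True it also writes the generated answers to ./results/peft_results_test_pubmedqa_inputs_preds_labels.jsonl. As one example of the additional metrics mentioned above, the sketch below computes a simple token-overlap F1 between predictions and references from that file. It assumes each line contains 'pred' and 'label' fields ('pred' is also used by the accuracy script in Step 4; 'label' is our assumption based on the file name), so adjust the keys if your file differs.

import json
from collections import Counter

PREDICTIONS = "./results/peft_results_test_pubmedqa_inputs_preds_labels.jsonl"

def token_f1(prediction: str, reference: str) -> float:
    """SQuAD-style token-overlap F1 between a predicted and a reference string."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        return float(pred_tokens == ref_tokens)
    common = Counter(pred_tokens) & Counter(ref_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

scores = []
with open(PREDICTIONS, "rt") as f:
    for line in f:
        row = json.loads(line)
        scores.append(token_f1(row["pred"], row["label"]))  # adjust keys if your file differs

print(f"mean token-overlap F1 over {len(scores)} examples: {sum(scores) / len(scores):.4f}")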

[Optional] Step 4: Calculate Metrics

Save the following to run_accuracy_metric_calculation.py, which will calculate the accuracy metrics for the model and dataset.

import json

answers = []
with open("./results/peft_results_test_pubmedqa_inputs_preds_labels.jsonl", 'rt') as f:
    st = f.readline()
    while st:
        answers.append(json.loads(st))
        st = f.readline()

data_test = json.load(open("./pubmedqa/data/test_set.json", 'rt'))

results = {}
sample_id = list(data_test.keys())
for i, key in enumerate(sample_id):
    answer = answers[i]['pred']
    if 'yes' in answer:
        results[key] = 'yes'
    elif 'no' in answer:
        results[key] = 'no'
    elif 'maybe' in answer:
        results[key] = 'maybe'
    else:
        print("Malformed answer: ", answer)
        results[key] = 'maybe'

# dump results
FILENAME = "pubmedqa-peft-accuracy-results.json"
with open(FILENAME, "w") as f:
    json.dump(results, f)

Run the script

python run_accuracy_metric_calculation.py
cd pubmedqa
python evaluation.py ../pubmedqa-peft-accuracy-results.json

After the accuracy calculation finishes, you should see output similar to the following (the numbers may differ). The sample scores below can be improved by training the model further and performing hyperparameter tuning; in this playbook, we only train for 50 steps.

────────────────────────────────────────────
      Test metric             DataLoader 0
────────────────────────────────────────────
        Accuracy                0.736000
        Macro-F1                0.508380
