NeMo Framework PEFT with Mixtral-8x7B

Learning Goals

This project aims to demonstrate how to adapt or customize foundation models to improve performance on specific tasks.

This optimization process is known as fine-tuning, which involves adjusting the weights of a pre-trained foundation model with custom data.

Considering that foundation models can be significantly large, a variant of fine-tuning has gained traction recently, known as parameter-efficient fine-tuning (PEFT). PEFT encompasses several methods, including P-Tuning, LoRA, Adapters, and IA3.
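
To make the idea concrete, the sketch below shows the core mechanism behind LoRA in a few lines of PyTorch: the pre-trained weight matrix is frozen, and only a small low-rank update (the product of two thin matrices) is trained. This is purely illustrative and is not NeMo's implementation; the class name and dimensions are made up for the example.

import torch.nn as nn


class LoRALinear(nn.Module):
    """Illustrative LoRA wrapper: y = W x + (alpha / r) * B(A(x)), with W frozen."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)                                 # freeze pre-trained weights
        self.lora_a = nn.Linear(base.in_features, r, bias=False)    # down-projection A
        self.lora_b = nn.Linear(r, base.out_features, bias=False)   # up-projection B
        nn.init.zeros_(self.lora_b.weight)                          # start as a no-op update
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))


# Only the small A/B matrices are trainable -- a tiny fraction of the full model.
layer = LoRALinear(nn.Linear(4096, 4096))
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
total = sum(p.numel() for p in layer.parameters())
print(f"trainable parameters: {trainable} / {total}")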

For those interested in a deeper understanding of these methods, we have included a list of additional materials in the Educational Resources section.

This project involves applying various fine-tuning methods to the Mixtral-8x7B model. In this playbook, you will implement and evaluate several parameter-efficient fine-tuning methods using a domain- and task-specific dataset. This playbook has been tested for P-Tuning and LoRA.

NeMo Tools and Resources

  1. NeMo GitHub repo

Educational Resources

  1. Blog: Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters

  2. NeMo documentation: Introduction to P-tuning

  3. NeMo notebook/tutorial: Introduction to p-tuning and prompt-tuning

Software Requirements

  1. Use the latest NeMo Framework Training container

  2. This playbook has been tested on: nvcr.io/nvidia/nemo:24.03.framework

Hardware Requirements

  1. Minimum 2xA100 80G (or equivalent) for PEFT. This playbook has been tested on 8xA100 40G.

If you already have a .nemo file for Mixtral-8x7B, you can skip the download and conversion steps below.

Step 1: Download Mixtral-8x7B in Hugging Face format

First create the destination directory, then download the model using the Hugging Face CLI tool:

mkdir mixtral-8x7B-hf
huggingface-cli download mistralai/Mixtral-8x7B-v0.1 --local-dir mixtral-8x7B-hf

Or, if you prefer to use the Hugging Face Hub API, you can download the checkpoint with the following Python code:

from huggingface_hub import snapshot_download

snapshot_download(repo_id="mistralai/Mixtral-8x7B-v0.1", local_dir="mixtral-8x7B-hf", local_dir_use_symlinks=False)

In this example, the Mixtral-8x7B Hugging Face model will be downloaded to ./mixtral-8x7B-hf.
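
As an optional sanity check before converting, you can confirm that the expected files landed in the destination directory. The short Python snippet below is illustrative and assumes you run it from the directory that contains mixtral-8x7B-hf.

import os

# Illustrative check: list the downloaded files and make sure the model config is present.
ckpt_dir = "mixtral-8x7B-hf"
files = sorted(os.listdir(ckpt_dir))
print("\n".join(files))
assert "config.json" in files, "config.json not found -- the download may be incomplete"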

Step 2: Convert to .nemo

Run the container using the following command

docker run --gpus all --shm-size=2g --net=host --ulimit memlock=-1 --rm -it \
    -v ${PWD}:/workspace -w /workspace -v ${PWD}/results:/results \
    nvcr.io/nvidia/nemo:24.03.framework bash

Convert the Hugging Face model to a .nemo model:

torchrun --nproc_per_node=1 /opt/NeMo/scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py \
    --input_name_or_path=./mixtral-8x7B-hf/ \
    --output_path=mixtral.nemo

The generated mixtral.nemo file uses distributed checkpointing and can be loaded with any tensor parallel (tp) or pipeline parallel (pp) combination without modifying (e.g. reshaping/splitting) the mixtral.nemo checkpoint.
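
A converted Mixtral-8x7B checkpoint is large, so it is worth confirming that the conversion actually finished before moving on. The check below is illustrative; depending on the NeMo version, the .nemo artifact may be a single archive file or a directory.

import os

# Illustrative check: confirm the converted checkpoint was written and report its size.
path = "mixtral.nemo"
if os.path.isfile(path):
    print(f"{path}: {os.path.getsize(path) / 1e9:.1f} GB file")
elif os.path.isdir(path):
    total = sum(
        os.path.getsize(os.path.join(root, name))
        for root, _, names in os.walk(path)
        for name in names
    )
    print(f"{path}: directory, {total / 1e9:.1f} GB total")
else:
    raise FileNotFoundError(path)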

Step 1: Download the PubMedQA dataset and run the split_dataset.py script in the cloned directory.

Download the dataset

git clone https://github.com/pubmedqa/pubmedqa.git
cd pubmedqa
cd preprocess
python split_dataset.py pqal

After running the split_dataset.py script, you will see the test set as well as ten directories, each of which contains a different train/validation fold:

$ cd ../..
$ ls pubmedqa/data/
ori_pqal.json  pqal_fold0  pqal_fold1  pqal_fold2  pqal_fold3  pqal_fold4  pqal_fold5  pqal_fold6  pqal_fold7  pqal_fold8  pqal_fold9  test_ground_truth.json  test_set.json
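
If you want to confirm what split_dataset.py produced, a short snippet like the following prints how many questions are in the test set and in one of the folds. It is illustrative only and assumes you run it from the directory above the cloned pubmedqa repository.

import json

# Illustrative inspection of the splits produced by split_dataset.py.
with open("pubmedqa/data/test_set.json") as f:
    test_set = json.load(f)
with open("pubmedqa/data/pqal_fold0/train_set.json") as f:
    fold0_train = json.load(f)
with open("pubmedqa/data/pqal_fold0/dev_set.json") as f:
    fold0_dev = json.load(f)

print(f"test: {len(test_set)}  fold0 train: {len(fold0_train)}  fold0 dev: {len(fold0_dev)}")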

Below is an example of what the objects look like inside the PubMedQA train, validation, and test splits:

"18251357": { "QUESTION": "Does histologic chorioamnionitis correspond to clinical chorioamnionitis?", "CONTEXTS": [ "To evaluate the degree to which histologic chorioamnionitis, a frequent finding in placentas submitted for histopathologic evaluation, correlates with clinical indicators of infection in the mother.", "A retrospective review was performed on 52 cases with a histologic diagnosis of acute chorioamnionitis from 2,051 deliveries at University Hospital, Newark, from January 2003 to July 2003. Third-trimester placentas without histologic chorioamnionitis (n = 52) served as controls. Cases and controls were selected sequentially. Maternal medical records were reviewed for indicators of maternal infection.", "Histologic chorioamnionitis was significantly associated with the usage of antibiotics (p = 0.0095) and a higher mean white blood cell count (p = 0.018). The presence of 1 or more clinical indicators was significantly associated with the presence of histologic chorioamnionitis (p = 0.019)." ], "reasoning_required_pred": "yes", "reasoning_free_pred": "yes", "final_decision": "yes", "LONG_ANSWER": "Histologic chorioamnionitis is a reliable indicator of infection whether or not it is clinically apparent." },

Step 2: Data Preprocessing

Use the script below to convert the train/validation/test PubMedQA data into the JSONL format that NeMo needs for PEFT. In this example, the script is named "preprocess_to_jsonl.py" and placed inside the pubmedqa repository cloned earlier.

import json


def read_jsonl(fname):
    # Read a JSONL file into a list of objects.
    obj = []
    with open(fname, 'rt') as f:
        st = f.readline()
        while st:
            obj.append(json.loads(st))
            st = f.readline()
    return obj


def write_jsonl(fname, json_objs):
    # Write one JSON object per line.
    with open(fname, 'wt') as f:
        for o in json_objs:
            f.write(json.dumps(o) + "\n")


def form_question(obj):
    # Build the prompt: question, concatenated contexts, and the answer cue.
    st = ""
    st += f"QUESTION:{obj['QUESTION']}\n"
    st += "CONTEXT: "
    for i, label in enumerate(obj['LABELS']):
        st += f"{obj['CONTEXTS'][i]}\n"
    st += "TARGET: the answer to the question given the context is (yes|no|maybe): "
    return st


def convert_to_jsonl(data_path, output_path):
    # Convert one PubMedQA split into NeMo's {"input": ..., "output": ...} JSONL format.
    data = json.load(open(data_path, 'rt'))
    json_objs = []
    for k in data.keys():
        obj = data[k]
        prompt = form_question(obj)
        completion = obj['reasoning_required_pred']
        json_objs.append({"input": prompt, "output": completion})
    write_jsonl(output_path, json_objs)
    return json_objs


def main():
    test_json_objs = convert_to_jsonl("data/test_set.json", "pubmedqa_test.jsonl")
    train_json_objs = convert_to_jsonl("data/pqal_fold0/train_set.json", "pubmedqa_train.jsonl")
    dev_json_objs = convert_to_jsonl("data/pqal_fold0/dev_set.json", "pubmedqa_val.jsonl")
    return test_json_objs, train_json_objs, dev_json_objs


if __name__ == "__main__":
    main()

After running the above script, you will see the pubmedqa_train.jsonl, pubmedqa_val.jsonl, and pubmedqa_test.jsonl files appear in the directory where you placed and ran the preprocessing script:

$ ls pubmedqa/
data  evaluation.py  exp  get_human_performance.py  LICENSE  preprocess  preprocess_to_jsonl.py  pubmedqa_test.jsonl  pubmedqa_train.jsonl  pubmedqa_val.jsonl  README.md

Below is what the formatting looks like once the above script has converted the PubMedQA data into the format that NeMo expects for SFT and PEFT:

{"input": "QUESTION: Failed IUD insertions in community practice: an under-recognized problem?\nCONTEXT: The data analysis was conducted to describe the rate of unsuccessful copper T380A intrauterine device (IUD) insertions among women using the IUD for emergency contraception (EC) at community family planning clinics in Utah.\n...", "output": "yes"}

Step 1: Set the experiment configs

The megatron_gpt_finetuning_config.yaml file is used to configure the parameters for running PEFT training jobs in NeMo with the P-Tuning and LoRA techniques for language model tuning. Set the environment variables and pass the paths to your train, validation, and test data files:

MODEL="YOUR PATH TO Mixtral-8x7B-7b.nemo" TRAIN_DS="[YOUR PATH TO pubmedqa/pubmedqa_train.jsonl]" VALID_DS="[YOUR PATH TO pubmedqa/pubmedqa_val.jsonl]" TEST_DS="[YOUR PATH TO pubmedqa/pubmedqa_test.jsonl]" TEST_NAMES="[pubmedqa]" SCHEME="lora"

Set the concat sampling probability. This depends on the number of files being passed in the train set and what fraction of the fine-tuning data you would like to use from each file. Note that the concat sampling probabilities must sum to 1.0. For example, the following shows how to set the concat sampling probabilities for a train set with two JSONL files:

TRAIN_DS="[/path/to/dataset_1.jsonl,/path/to/dataset_2.jsonl]" CONCAT_SAMPLING_PROBS="[0.3,0.7]"

In our example we are using a single train file, so CONCAT_SAMPLING_PROBS="[1.0]".
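
If you later blend several train files and want the sampling probabilities to be proportional to each file's size, a small helper like the one below (illustrative, not part of NeMo; the paths are hypothetical) produces values that sum to 1.0.

# Illustrative helper: derive concat sampling probabilities proportional to file sizes.
def sampling_probs(jsonl_paths):
    counts = []
    for path in jsonl_paths:
        with open(path) as f:
            counts.append(sum(1 for _ in f))
    total = sum(counts)
    return [c / total for c in counts]


# Hypothetical usage:
# print(sampling_probs(["/path/to/dataset_1.jsonl", "/path/to/dataset_2.jsonl"]))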

Step 2: Run PEFT training

Run the PEFT training command, setting values for parameters such as the number of steps, model checkpoint path, and batch sizes as appropriate. For a full reference of parameter settings, refer to the config file:

TP_SIZE=8 \
PP_SIZE=1 \
SCHEME="lora" \
torchrun --nproc_per_node=8 \
/opt/NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py \
    trainer.devices=8 \
    trainer.num_nodes=1 \
    trainer.precision=bf16 \
    trainer.val_check_interval=20 \
    trainer.max_steps=50 \
    model.megatron_amp_O2=True \
    ++model.mcore_gpt=True \
    model.tensor_model_parallel_size=${TP_SIZE} \
    model.pipeline_model_parallel_size=${PP_SIZE} \
    model.micro_batch_size=1 \
    model.global_batch_size=32 \
    model.optim.lr=1e-4 \
    model.restore_from_path=${MODEL} \
    model.data.train_ds.num_workers=0 \
    model.data.validation_ds.num_workers=0 \
    model.data.train_ds.file_names=${TRAIN_DS} \
    model.data.train_ds.concat_sampling_probabilities=[1.0] \
    model.data.validation_ds.file_names=${VALID_DS} \
    model.peft.peft_scheme=${SCHEME} \
    model.peft.lora_tuning.target_modules=[attention_qkv] \
    exp_manager.checkpoint_callback_params.mode=min

You can also set SCHEME="ptuning" to use P-Tuning instead of LoRA.
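
After training completes, the PEFT adapter weights are saved as a .nemo checkpoint inside your experiment directory; its path is what you will pass to model.peft.restore_from_path in the next step. The exact layout depends on your exp_manager settings, so a simple way to locate it is a recursive search. The nemo_experiments default below is an assumption; adjust it if you configured a different experiment directory.

import glob

# Illustrative: locate the trained PEFT checkpoint under the experiment directory.
candidates = glob.glob("nemo_experiments/**/*.nemo", recursive=True)
print("\n".join(candidates))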

Step 3: Run inference

  1. Set model.restore_from_path to the path of the Mixtral-8x7B .nemo model (mixtral.nemo from the conversion step).

  2. Set model.peft.restore_from_path to the path of the PEFT checkpoint saved inside your experiment directory.

  3. Set model.data.test_ds.file_names to the path of the pubmedqa_test.jsonl file.

  4. Set Pipeline Parallelism ($PP_SIZE) and Tensor Parallelism ($TP_SIZE) to match training values.

Please configure tokens_to_generate and output_file_path_prefix (the ${OUTPUT_PREFIX} variable below) according to your project needs.

TP_SIZE=8 \
PP_SIZE=1 \
python3 \
/opt/NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py \
    model.restore_from_path=${MODEL} \
    model.peft.restore_from_path=${PATH_TO_TRAINED_MODEL} \
    trainer.devices=8 \
    model.tensor_model_parallel_size=${TP_SIZE} \
    model.pipeline_model_parallel_size=${PP_SIZE} \
    model.data.test_ds.file_names=${TEST_DS} \
    model.data.test_ds.names=${TEST_NAMES} \
    model.global_batch_size=32 \
    model.micro_batch_size=4 \
    model.data.test_ds.tokens_to_generate=20 \
    inference.greedy=True \
    model.data.test_ds.output_file_path_prefix=${OUTPUT_PREFIX} \
    model.data.test_ds.write_predictions_to_file=True

After the evaluation finishes, you should be able to see output similar to the following

────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────
        test_loss           0.40366214513778687
   test_loss_pubmedqa       0.40366214513778687
         val_loss           0.40366214513778687
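
Since write_predictions_to_file=True writes the generated answers to a JSONL file derived from ${OUTPUT_PREFIX}, you can also score them against the labels in pubmedqa_test.jsonl. The sketch below is illustrative only: the prediction file name and the key holding the generated text (assumed here to be "pred") may differ between NeMo versions, and it assumes predictions are written in the same order as the test file, so adjust it to what you actually find in your output directory.

import json
import re

# Rough, illustrative scoring sketch -- verify the file name and field names first.
def first_label(text):
    m = re.search(r"\b(yes|no|maybe)\b", text.lower())
    return m.group(1) if m else None

with open("YOUR_OUTPUT_PREFIX_test_predictions.jsonl") as f:   # hypothetical file name
    preds = [json.loads(line) for line in f]
with open("pubmedqa_test.jsonl") as f:
    refs = [json.loads(line) for line in f]

correct = sum(first_label(p.get("pred", "")) == r["output"] for p, r in zip(preds, refs))
print(f"accuracy: {correct / len(refs):.3f}")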
