NeMo Framework PEFT with Mistral-7B


Learning Goals

This project aims to demonstrate how to adapt or customize foundation models to improve performance on specific tasks.

This optimization process is known as fine-tuning, which involves adjusting the weights of a pre-trained foundation model with custom data.

Considering that foundation models can be significantly large, a variant of fine-tuning has gained traction recently, known as parameter-efficient fine-tuning (PEFT). PEFT encompasses several methods, including P-Tuning, LoRA, Adapters, and IA3.
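
To build intuition for why PEFT is so much cheaper than full fine-tuning, the following minimal PyTorch sketch illustrates the core idea behind LoRA: the pre-trained weight matrix stays frozen and only a small low-rank correction is trained. The rank, scaling factor, and layer size below are arbitrary illustration values, not NeMo's defaults.

# Minimal LoRA illustration (assumes PyTorch is installed). The rank, alpha,
# and layer dimensions are illustrative only and unrelated to NeMo's defaults.
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        # Freeze the pre-trained projection; only the low-rank factors train.
        self.base.weight.requires_grad_(False)
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)
        self.lora_a = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank))
        self.scale = alpha / rank

    def forward(self, x):
        # Frozen projection plus a trainable low-rank correction.
        return self.base(x) + self.scale * (x @ self.lora_a.T @ self.lora_b.T)

layer = LoRALinear(nn.Linear(4096, 4096))
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
total = sum(p.numel() for p in layer.parameters())
print(f"trainable params: {trainable} / {total}")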

For those interested in a deeper understanding of these methods, we have included a list of additional material in the Educational Resources section.

This project involves applying various fine-tuning methods to the Mistral-7B model. In this playbook, you will implement and evaluate several parameter-efficient fine-tuning methods using a domain- and task-specific dataset. The playbook has been tested with P-Tuning and LoRA.

NeMo Tools and Resources

Educational Resources

Software Requirements

  • Use the latest NeMo Framework Training container

  • This playbook has been tested on: nvcr.io/nvidia/nemo:24.05. It is expected to work similarly on other environments.

Hardware Requirements

  • Minimum 1xA100 80G for PEFT on Mistral-7B

  • This playbook has been tested on 8xA100 80G

If you already have a .nemo file for the Mistral-7B model, you can skip Steps 1 and 2 below.

Step 1: Download Mistral-7B in Hugging Face format

Request download permission and create the destination directory. Two options are available.
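
Before using either option, if the repository requires authentication for your account, log in with a Hugging Face access token. The snippet below is a minimal sketch using the huggingface_hub Python API; it assumes you have accepted the model's terms on huggingface.co and created an access token.

# Authenticate with Hugging Face (only needed if the repository is gated for
# your account). Calling login() with no arguments prompts for a token
# interactively; you can also pass token="hf_...".
from huggingface_hub import login

login()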

To download using the CLI tool:

mkdir mistral-7B-hf
huggingface-cli download mistralai/Mistral-7B-v0.1 --local-dir mistral-7B-hf

To download using the Hugging Face API, run the following Python code:

from huggingface_hub import snapshot_download
snapshot_download(repo_id="mistralai/Mistral-7B-v0.1", local_dir="mistral-7B-hf", local_dir_use_symlinks=False)

In this example, the Mistral-7B Hugging Face model will be downloaded to ./mistral-7B-hf.
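
Before converting the checkpoint, you can optionally confirm that the download completed. The expected contents (model configuration, tokenizer files, and safetensors weight shards) are an assumption here, since the exact file list depends on the repository revision.

# Optional sanity check: list the downloaded files. You should see the model
# configuration, tokenizer files, and weight shards; exact names vary by
# repository revision.
import os

for name in sorted(os.listdir("mistral-7B-hf")):
    print(name)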

Step 2: Convert to .nemo

Run the container using the following command:

docker run --gpus device=1 --shm-size=2g --net=host --ulimit memlock=-1 --rm -it \
    -v ${PWD}:/workspace -w /workspace \
    -v ${PWD}/results:/results \
    nvcr.io/nvidia/nemo:24.05 bash

Convert the Hugging Face model to the .nemo format:

python3 /opt/NeMo/scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py --input_name_or_path=./mistral-7B-hf/ --output_path=mistral.nemo

The generated mistral.nemo file uses distributed checkpointing and can be loaded with any tensor parallel (tp) or pipeline parallel (pp) combination without modifying (e.g. reshaping/splitting) the mistral.nemo checkpoint.

Step 3: Download the PubMedQA dataset and run the split_dataset.py script in the cloned directory

Download the dataset:

git clone https://github.com/pubmedqa/pubmedqa.git
cd pubmedqa
cd preprocess
python split_dataset.py pqal

After running the split_dataset.py script, you will see the test set files as well as ten directories, each of which contains a different train/validation fold:

$ cd ../..
$ ls pubmedqa/data/
ori_pqal.json  pqal_fold0  pqal_fold1  pqal_fold2  pqal_fold3  pqal_fold4  pqal_fold5
pqal_fold6  pqal_fold7  pqal_fold8  pqal_fold9  test_ground_truth.json  test_set.json

The following example shows what a single record looks like in the train_set.json file; the train, validation, and test splits of PubMedQA all share this structure.

"18251357": {
    "QUESTION": "Does histologic chorioamnionitis correspond to clinical chorioamnionitis?",
    "CONTEXTS": [
        "To evaluate the degree to which histologic chorioamnionitis, a frequent finding in placentas submitted for histopathologic evaluation, correlates with clinical indicators of infection in the mother.",
        "A retrospective review was performed on 52 cases with a histologic diagnosis of acute chorioamnionitis from 2,051 deliveries at University Hospital, Newark, from January 2003 to July 2003. Third-trimester placentas without histologic chorioamnionitis (n = 52) served as controls. Cases and controls were selected sequentially. Maternal medical records were reviewed for indicators of maternal infection.",
        "Histologic chorioamnionitis was significantly associated with the usage of antibiotics (p = 0.0095) and a higher mean white blood cell count (p = 0.018). The presence of 1 or more clinical indicators was significantly associated with the presence of histologic chorioamnionitis (p = 0.019)."
    ],
    "reasoning_required_pred": "yes",
    "reasoning_free_pred": "yes",
    "final_decision": "yes",
    "LONG_ANSWER": "Histologic chorioamnionitis is a reliable indicator of infection whether or not it is clinically apparent."
},

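If you want to inspect the raw data yourself, the short sketch below loads one fold and prints a record's question and label. It assumes you are running from the directory that contains the cloned pubmedqa repository.

# Peek at one training record from fold 0 (run from the directory that
# contains the cloned pubmedqa repository).
import json

with open("pubmedqa/data/pqal_fold0/train_set.json", "rt") as f:
    data = json.load(f)

pmid, record = next(iter(data.items()))
print(pmid)
print(record["QUESTION"])
print(record["final_decision"])
print(len(record["CONTEXTS"]), "context paragraphs")
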
Step 4: Data preprocessing

Use the following script to convert the train, validation, and test PubMedQA data into the JSONL format that NeMo expects for PEFT. In this example, we have named the script preprocess_to_jsonl.py and placed it inside the pubmedqa repository we previously cloned.

import json


def read_jsonl(fname):
    obj = []
    with open(fname, 'rt') as f:
        st = f.readline()
        while st:
            obj.append(json.loads(st))
            st = f.readline()
    return obj


def write_jsonl(fname, json_objs):
    with open(fname, 'wt') as f:
        for o in json_objs:
            f.write(json.dumps(o) + "\n")


def form_question(obj):
    # Build the prompt: question, concatenated context paragraphs, and the target cue.
    st = ""
    st += f"QUESTION:{obj['QUESTION']}\n"
    st += "CONTEXT: "
    for i, label in enumerate(obj['LABELS']):
        st += f"{obj['CONTEXTS'][i]}\n"
    st += "TARGET: the answer to the question given the context is (yes|no|maybe): "
    return st


def convert_to_jsonl(data_path, output_path):
    # Convert one PubMedQA split into the {"input": ..., "output": ...} JSONL format NeMo expects.
    data = json.load(open(data_path, 'rt'))
    json_objs = []
    for k in data.keys():
        obj = data[k]
        prompt = form_question(obj)
        completion = obj['reasoning_required_pred']
        json_objs.append({"input": prompt, "output": completion})
    write_jsonl(output_path, json_objs)
    return json_objs


def main():
    test_json_objs = convert_to_jsonl("data/test_set.json", "pubmedqa_test.jsonl")
    train_json_objs = convert_to_jsonl("data/pqal_fold0/train_set.json", "pubmedqa_train.jsonl")
    dev_json_objs = convert_to_jsonl("data/pqal_fold0/dev_set.json", "pubmedqa_val.jsonl")
    return test_json_objs, train_json_objs, dev_json_objs


if __name__ == "__main__":
    main()

After running the script, the pubmedqa_train.jsonl, pubmedqa_val.jsonl, and pubmedqa_test.jsonl files will appear in the directory from which you ran the preprocessing script.

$ ls pubmedqa/
data  evaluation.py  exp  get_human_performance.py  LICENSE  preprocess  preprocess_to_jsonl.py
pubmedqa_test.jsonl  pubmedqa_train.jsonl  pubmedqa_val.jsonl  README.md

The following example shows what the formatting will look like after the script has converted the PubMedQA data into the format that NeMo expects for PEFT:

{"input": "QUESTION: Failed IUD insertions in community practice: an under-recognized problem?\nCONTEXT: The data analysis was conducted to describe the rate of unsuccessful copper T380A intrauterine device (IUD) insertions among women using the IUD for emergency contraception (EC) at community family planning clinics in Utah.\n...", "output": "yes"}
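
As a quick sanity check on the converted files, the sketch below counts the training examples and their label distribution. It assumes the JSONL files were written into the pubmedqa directory as shown above.

# Count examples and label distribution in the converted training split.
import json
from collections import Counter

labels = Counter()
with open("pubmedqa/pubmedqa_train.jsonl", "rt") as f:
    for line in f:
        example = json.loads(line)
        labels[example["output"]] += 1

print(sum(labels.values()), "training examples")
print(dict(labels))  # counts of "yes" / "no" / "maybe"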

Step 5: Run the PEFT fine-tuning script

The NeMo Framework supports several fine-tuning techniques, such as P-Tuning and LoRA, that can be used to adapt a base model. In this section, we show how to adapt the Mistral-7B model to the PubMedQA dataset using the LoRA technique.

To this end, we will use the following bash script (filename: run_peft.sh). The script contains the paths to the .nemo checkpoint and the dataset, as well as other PEFT hyperparameters such as the batch size and learning rate. These hyperparameters can be passed via the CLI or a config file; for a full reference of the default parameters, refer to the config file.

Save the following as a bash script with filename run_peft.sh and run it using bash run_peft.sh.

# This is the nemo model we are fine-tuning
# This should point to the mistral.nemo as created in the checkpoint conversion script.
MODEL="./mistral.nemo"

# These are the training datasets (in our case we only have one)
TRAIN_DS="[pubmedqa/pubmedqa_train.jsonl]"

# These are the validation datasets (in our case we only have one)
VALID_DS="[pubmedqa/pubmedqa_val.jsonl]"

# These are the test datasets (in our case we only have one)
TEST_DS="[pubmedqa/pubmedqa_test.jsonl]"

# These are the names of the test datasets
TEST_NAMES="[pubmedqa]"

# This is the PEFT scheme that we will be using. Set to "ptuning" for P-Tuning instead of LoRA
PEFT_SCHEME="lora"

# This is the concat sampling probability. This depends on the number of files being passed in the train set
# and the sampling probability for each file. In our case, we have one training file. Note that the sum of the
# concat sampling probabilities should be 1.0. For example, with two entries in TRAIN_DS, CONCAT_SAMPLING_PROBS
# might be "[0.3,0.7]". For three entries, CONCAT_SAMPLING_PROBS might be "[0.3,0.1,0.6]"
# NOTE: Your entry must contain a value greater than 0.0 for each file
CONCAT_SAMPLING_PROBS="[1.0]"

# Tensor parallelism
TP_SIZE=8

# Pipeline parallelism
PP_SIZE=1

python3 \
/opt/NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py \
    trainer.devices=8 \
    trainer.num_nodes=1 \
    trainer.precision=bf16 \
    trainer.val_check_interval=20 \
    trainer.max_steps=50 \
    model.megatron_amp_O2=True \
    ++model.mcore_gpt=True \
    model.tensor_model_parallel_size=${TP_SIZE} \
    model.pipeline_model_parallel_size=${PP_SIZE} \
    model.micro_batch_size=1 \
    model.global_batch_size=32 \
    model.optim.lr=1e-4 \
    model.restore_from_path=${MODEL} \
    model.data.train_ds.num_workers=0 \
    model.data.validation_ds.num_workers=0 \
    model.data.train_ds.file_names=${TRAIN_DS} \
    model.data.train_ds.concat_sampling_probabilities=${CONCAT_SAMPLING_PROBS} \
    model.data.validation_ds.file_names=${VALID_DS} \
    model.peft.peft_scheme=${PEFT_SCHEME} \
    model.peft.lora_tuning.target_modules=[attention_qkv] \
    exp_manager.checkpoint_callback_params.mode=min

You can also set PEFT_SCHEME="ptuning" to use P-Tuning instead of LoRA.

Step 6: Run inference

The NeMo framework allows users to evaluate their trained model, as shown in the following script (named run_evaluation.sh). In this example, we use greedy decoding and generate up to 20 tokens. After saving the file, you can execute it using bash run_evaluation.sh.

# Change this to the nemo model you want to use
MODEL="./mistral.nemo"

# This is the PEFT checkpoint produced by the training run above (under the default exp_manager output directory)
# The filename will match whichever peft scheme was used during training
PATH_TO_TRAINED_MODEL="./nemo_experiments/megatron_gpt_peft_lora_tuning/checkpoints/megatron_gpt_peft_lora_tuning.nemo"

# These are the test datasets (in our case we only have one)
TEST_DS="[pubmedqa/pubmedqa_test.jsonl]"

# These are the names of the test datasets
TEST_NAMES="[pubmedqa]"

# This is the prefix, including the path and filename prefix, for the accuracy file output
# This will be combined with TEST_NAMES to create the file ./results/peft_results_test_pubmedqa_inputs_preds_labels.jsonl
OUTPUT_PREFIX="./results/peft_results"

# This is the tensor parallel size (splitting tensors among GPUs horizontally)
TP_SIZE=8

# This is the pipeline parallel size (splitting layers among GPUs vertically)
PP_SIZE=1

python3 \
/opt/NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_generate.py \
    model.restore_from_path=${MODEL} \
    model.peft.restore_from_path=${PATH_TO_TRAINED_MODEL} \
    trainer.devices=8 \
    model.tensor_model_parallel_size=${TP_SIZE} \
    model.pipeline_model_parallel_size=${PP_SIZE} \
    model.data.test_ds.file_names=${TEST_DS} \
    model.data.test_ds.names=${TEST_NAMES} \
    model.global_batch_size=32 \
    model.micro_batch_size=4 \
    model.data.test_ds.tokens_to_generate=20 \
    inference.greedy=True \
    model.data.test_ds.output_file_path_prefix=${OUTPUT_PREFIX} \
    model.data.test_ds.write_predictions_to_file=True

After the evaluation finishes, you should see output similar to the following:

──────────────────────────────────────────────────────────────
        Test metric                  DataLoader 0
──────────────────────────────────────────────────────────────
        test_loss                 0.3633725643157959
        test_loss_pubmedqa        0.3633725643157959
        val_loss                  0.3633725643157959
──────────────────────────────────────────────────────────────
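
Because write_predictions_to_file=True was set, the per-example predictions are also written to ./results/peft_results_test_pubmedqa_inputs_preds_labels.jsonl. The sketch below computes a simple exact-match accuracy from that file; the "pred" and "label" field names are an assumption about the output schema, so check the file and adjust them if needed.

# Compute exact-match accuracy from the predictions file written during
# evaluation. NOTE: the "pred" and "label" keys are assumptions about the
# JSONL schema; inspect the file and adjust if the field names differ.
import json

path = "./results/peft_results_test_pubmedqa_inputs_preds_labels.jsonl"
correct = 0
total = 0
with open(path, "rt") as f:
    for line in f:
        row = json.loads(line)
        pred = row["pred"].strip().lower()
        label = row["label"].strip().lower()
        correct += int(pred.startswith(label))  # generated text may include extra tokens
        total += 1

print(f"accuracy: {correct / total:.3f} ({correct}/{total})")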
