Pretrain from Scratch, Continue Training from an Existing Checkpoint, and Fine-tune ESM-2nv on Custom Data#
Demo Objectives#
Continue Training from a Model Checkpoint
Objective: Continue pretraining an ESM-2nv foundation model on antibody sequences from an existing .nemo checkpoint.
Steps: Prepare the data, and resume training with the BioNeMo pretraining script.
Downstream Head Fine-tuning
Objective: Fine-tune ESM-2nv for predicting antibody function with an additional prediction head.
Steps: Collect the data, and use existing downstream prediction head training scripts in BioNeMo for token-level classification.
Full Parameter Fine-tuning on Antibody Sequences
Objective: Fine-tune an ESM-2nv foundation model and head on antibody sequences to enhance recognition of specific sequence patterns.
Steps: Prepare dataset, and fine-tune ESM-2nv.
Low-Rank Adaptation (LoRA) Fine-tuning
Objective: Apply LoRA to ESM-2nv for antibody sequences to improve efficiency and robustness.
Steps: Integrate LoRA adapters, and fine-tune adapters while freezing core weights.
For this purpose, we will use data available from the Therapeutics Data Commons (TDC) for the prediction of amino acid binding positions in antibody sequences.
Setup#
Ensure that you have read through the Getting Started section, can run the BioNeMo Framework Docker container, and have configured the NGC Command Line Interface (CLI) within the container. It is assumed that this notebook is being executed from within the container.
The cells below use the Jupyter magic `%%capture --no-display --no-stderr cell_output` to suppress their output. Comment out or delete this line in the cells below to restore full output.
Import and install all required packages#
%%capture --no-display --no-stderr cell_output
! pip install PyTDC
import os
import pandas as pd
import warnings
# Import the loader to download and split datasets from the Therapeutics Data Commons (https://tdcommons.ai/)
from tdc.single_pred import Paratope
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
Home Directory#
bionemo_home = "/workspace/bionemo"
os.environ['BIONEMO_HOME'] = bionemo_home
os.chdir(bionemo_home)
Data Download and Preprocessing#
Dataset Overview: This dataset focuses on paratope prediction, which involves identifying the active binding regions within an antibody. It compiles sequences from SAbDab, encompassing both the heavy and light chains of the antibody.
Objective: The task is token-level classification. For a given sequence of amino acids, the goal is to identify the specific amino acid tokens that are active in binding. In this context, `X` represents the amino acid sequence, while `Y` denotes the indices of active binding positions within `X`.
Dataset Details: The dataset comprises sequences from 1,023 antibody chains.
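To make the labeling scheme concrete, here is a minimal, hypothetical sketch (the sequence and positions below are made up for illustration) showing how the binding-position indices in `Y` select residues of `X`. It assumes 1-based indices, mirroring the position adjustment used in the `encode_sequence` function further down.

# Toy illustration of the token-level classification setup (made-up values).
toy_X = "EVQLVESGGGLVQPGGSLRLS"   # hypothetical amino acid sequence
toy_Y = [3, 7, 8]                 # hypothetical binding-position indices

# Residues at the listed positions belong to the paratope (assuming 1-based indices).
paratope_residues = [(pos, toy_X[pos - 1]) for pos in toy_Y]
print(paratope_residues)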
# Specify dataset for download
dataset_name = 'SAbDab_Liberis'
print(f"Preparing raw {dataset_name} dataset from TD Commons.")
data = Paratope(name = dataset_name)
data_df = data.get_data()
splits = data.get_split()
data_df.head()
Found local copy...
Loading...
Done!
Preparing raw SAbDab_Liberis dataset from TD Commons.
|   | Antibody_ID | Antibody | Y |
|---|---|---|---|
| 0 | 2hh0_H | LEQSGAELVKPGASVKLSCTASGFNIEDSYIHWVKQRPEQGLEWIG... | [49, 80, 81, 82, 101] |
| 1 | 1u8q_B | ITLKESGPPLVKPTQTLTLTCSFSGFSLSDFGVGVGWIRQPPGKAL... | [30, 31, 53, 83, 84, 85, 104, 105, 106, 107, 1... |
| 2 | 4ydl_H | EVRLVQSGNQVRKPGASVRISCEASGYKFIDHFIHWVRQVPGHGLE... | [52, 67, 68, 85, 86, 87, 106, 107] |
| 3 | 4ydl_L | EIVLTQSPGTLSLSPGETATLSCRTSQGILSNQLAWHQQRRGQPPR... | [30] |
| 4 | 1mhp_X | EVQLVESGGGLVQPGGSLRLSCAASGFTFSRYTMSWVRQAPGKGLE... | [52, 82, 83, 84, 103, 104] |
Each antibody sequence (`Antibody`) is a string of amino acids, where the order and composition determine its function and specificity. Within each antibody sequence, specific positions (`Y`) are crucial for its function and mark the residues that belong to the paratope of the antibody. Here, we define a function (`encode_sequence`) that encodes these positions by initializing a label sequence with placeholders (`N` for non-paratope positions) and marking the positions of interest with the label `P`, denoting amino acids that belong to the paratope. The dataset is then divided into subsets (`train`, `val`, `test`), and each subset undergoes this encoding.
To train ESM-2nv on custom sequences, we will also need to create FASTA files.
base_data_dir = os.path.join(bionemo_home, 'data')
task_name = "paratope"
!mkdir -p {base_data_dir}/processed/{dataset_name}
SAbDab_dir = os.path.join(base_data_dir, 'processed', dataset_name)
def encode_sequence(row: pd.Series) -> str:
sequence = list('N' * len(row['Antibody'])) # Create a list of 'N's the same length as the sequence
# Check if row['Y'] is a string that needs to be evaluated
if isinstance(row['Y'], str):
positions = eval(row['Y']) # Convert string representation of list to actual list
else:
positions = row['Y'] # Assume row['Y'] is already in the correct format (e.g., a list)
for pos in positions:
adj_pos = pos - 1 # Adjust the position to 0-based indexing
sequence[adj_pos] = 'P' # Encode the position as 'P'
return ''.join(sequence) # Convert the list back to a string
# List of split names, assuming they are 'train', 'valid', and 'test'
# Update 'valid' to 'val' for the folder name
split_names = ['train', 'val', 'test']
for split_name in split_names:
# Adjust the key for accessing the validation split if necessary
split_key = 'valid' if split_name == 'val' else split_name
# Get the DataFrame for this split
df = splits[split_key]
# Apply the function to each row
df['Encoded'] = df.apply(encode_sequence, axis=1)
# Adjust the directory structure for saving, now including the task_name
task_specific_dir = os.path.join(SAbDab_dir, task_name, split_name)
os.makedirs(task_specific_dir, exist_ok=True) # Ensure the directory exists
df = df[['Antibody', 'Encoded']] # Reorder the columns
# Save the modified DataFrame to the new path
df.to_csv(os.path.join(task_specific_dir, f"x000.csv"), index=False)
print(f"Encoded sequences saved as x000.csv in {task_specific_dir}")
# Save as FASTA
fasta_path = os.path.join(task_specific_dir, f"x000.fasta")
with open(fasta_path, 'w') as fasta_file:
for index, row in df.iterrows():
fasta_file.write(f">Sequence_{index}\n{row['Antibody']}\n")
print(f"Encoded sequences saved as x000.fasta in {task_specific_dir}")
Encoded sequences saved as x000.csv in /workspace/bionemo/data/processed/SAbDab_Liberis/paratope/train
Encoded sequences saved as x000.fasta in /workspace/bionemo/data/processed/SAbDab_Liberis/paratope/train
Encoded sequences saved as x000.csv in /workspace/bionemo/data/processed/SAbDab_Liberis/paratope/val
Encoded sequences saved as x000.fasta in /workspace/bionemo/data/processed/SAbDab_Liberis/paratope/val
Encoded sequences saved as x000.csv in /workspace/bionemo/data/processed/SAbDab_Liberis/paratope/test
Encoded sequences saved as x000.fasta in /workspace/bionemo/data/processed/SAbDab_Liberis/paratope/test
encoded_df = pd.read_csv(os.path.join(SAbDab_dir, task_name, 'train', 'x000.csv'))
encoded_df.head()
|   | Antibody | Encoded |
|---|---|---|
| 0 | LEQSGAELVKPGASVKLSCTASGFNIEDSYIHWVKQRPEQGLEWIG... | NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN... |
| 1 | ITLKESGPPLVKPTQTLTLTCSFSGFSLSDFGVGVGWIRQPPGKAL... | NNNNNNNNNNNNNNNNNNNNNNNNNNNNNPPNNNNNNNNNNNNNNN... |
| 2 | EVRLVQSGNQVRKPGASVRISCEASGYKFIDHFIHWVRQVPGHGLE... | NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN... |
| 3 | EVQLSESGGGFVKPGGSLRLSCEASGFTFNNYAMGWVRQAPGKGLE... | NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN... |
| 4 | QVQLVQPGTAMKSLGSSLTITCRVSGDDLGSFHFGTYFMIWVRQAP... | NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNPPPPPNNNNNNNNNNN... |
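As a quick, optional sanity check (a minimal sketch that only uses the files written above), we can verify that every encoded label string has the same length as its antibody sequence and count how many positions were marked as paratope:

# Optional sanity check on the encoded training split.
check_df = pd.read_csv(os.path.join(SAbDab_dir, task_name, 'train', 'x000.csv'))

# Every label string should be exactly as long as its sequence.
length_match = (check_df['Antibody'].str.len() == check_df['Encoded'].str.len()).all()
print(f"All label lengths match their sequences: {length_match}")

# Fraction of residues labeled as paratope ('P') across the split.
n_paratope = check_df['Encoded'].str.count('P').sum()
n_total = check_df['Encoded'].str.len().sum()
print(f"Paratope residues: {n_paratope} / {n_total} ({n_paratope / n_total:.1%})")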
Download Model Checkpoints#
The following code will download the pretrained model `esm2nv_650M_converted.nemo` from the NGC registry.

In BioNeMo FW, there are numerous ESM models available, including ESM-1nv, ESM-2nv 8M with randomly initialized weights, ESM-2nv fine-tuned for secondary structure downstream prediction tasks with LoRA, ESM-2nv 650M, and ESM-2nv 3B. We also have a configuration file for training ESM-2nv 15B available at `examples/protein/esm2nv/conf/pretrain_esm2_15B.yaml`, if needed.

For demo purposes, we have chosen to showcase the ESM-2nv 650M model. For more details on ESM-1nv or ESM-2nv, consult the corresponding model cards. To find the model names and checkpoint names, please refer to the `artifacts_paths.yaml` file.
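To see which checkpoints are listed, you can inspect that file directly. The sketch below assumes it sits at the root of the BioNeMo workspace (`/workspace/bionemo/artifacts_paths.yaml`); adjust the path if your installation places it elsewhere.

import yaml

# Assumed location of the artifacts file; adjust if your container differs.
artifacts_file = os.path.join(bionemo_home, "artifacts_paths.yaml")

if os.path.exists(artifacts_file):
    with open(artifacts_file) as f:
        artifacts = yaml.safe_load(f)
    # Print only the top-level keys to get an overview; the exact structure may differ.
    print(list(artifacts.keys()))
else:
    print(f"{artifacts_file} not found; check the BioNeMo documentation for its location.")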
# Define the NGC CLI API KEY and ORG for the model download
# If these variables are not already set in the container, uncomment below
# to define and set with your API KEY and ORG
# api_key = <YOUR_API_KEY>
# ngc_cli_org = <YOUR_ORG>
# Update the environment variable
# os.environ['NGC_CLI_API_KEY'] = api_key
# os.environ['NGC_CLI_ORG'] = ngc_cli_org
# Set variables and paths for model and checkpoint
model_name = "esm2nv"
model_version = "esm2nv_650m"
actual_checkpoint_name = "esm2nv_650M_converted.nemo"
model_path = os.path.join(bionemo_home, 'models')
checkpoint_path = os.path.join(model_path, actual_checkpoint_name)
os.environ['MODEL_PATH'] = model_path
%%capture --no-display --no-stderr cell_output
if not os.path.exists(checkpoint_path):
!cd /workspace/bionemo && \
python download_artifacts.py --model_dir models --models {model_version}
else:
print(f"Model {model_version} already exists at {model_path}.")
Setting up paths to the data used for model training:
config_dir = os.path.join(bionemo_home, f'examples/protein/{model_name}/conf')
train_fasta = os.path.join(SAbDab_dir, f'{task_name}/train/x000.fasta')
val_fasta = os.path.join(SAbDab_dir, f'{task_name}/val/x000.fasta')
test_fasta = os.path.join(SAbDab_dir, f'{task_name}/test/x000.fasta')
paratope_dir = os.path.join(SAbDab_dir, 'paratope_custom_dataset')
! mkdir -p {paratope_dir}
Preprocessing and Pretraining from Scratch#
Performing preprocessing on the data to transform it into a format that can be used by the model.
%%capture --no-display --no-stderr cell_output
! cd {bionemo_home} && python examples/protein/esm2nv/pretrain.py \
--config-path={config_dir} \
--config-name=pretrain_esm2_650M \
++do_training=False \
++do_preprocessing=True \
++trainer.devices=1 \
++model.data.train.custom_pretraining_fasta_path={train_fasta} \
++model.data.val.custom_pretraining_fasta_path={val_fasta} \
++model.data.test.custom_pretraining_fasta_path={test_fasta} \
++model.data.dataset_path={paratope_dir} \
++model.data.train.dataset_path={paratope_dir} \
++exp_manager.create_wandb_logger=false
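The exact layout of the preprocessed output depends on the config, so a simple way to see what was produced is to walk the target directory (a minimal sketch with no assumptions about file names):

# List whatever the preprocessing step wrote under the custom dataset directory.
for root, _dirs, files in os.walk(paratope_dir):
    for name in sorted(files):
        print(os.path.join(root, name))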
Pretrain from scratch#
This will take approximately 15 minutes on an A1000 GPU.
%%capture --no-display --no-stderr cell_output
! cd {bionemo_home} && python examples/protein/esm2nv/pretrain.py \
--config-path={config_dir} \
--config-name=pretrain_esm2_650M \
name={model_name}_from_scratch_antibodies \
++do_training=True \
++trainer.devices=1 \
++trainer.max_steps=1 \
++trainer.val_check_interval=1 \
++model.data.train.custom_pretraining_fasta_path={train_fasta} \
++model.data.val.custom_pretraining_fasta_path={val_fasta} \
++model.data.test.custom_pretraining_fasta_path={test_fasta} \
++model.data.dataset_path={paratope_dir} \
++model.data.train.dataset_path={paratope_dir} \
++model.micro_batch_size=1 \
++exp_manager.create_wandb_logger=false
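Where the run writes its results depends on the exp_manager settings in the config, so rather than hard-coding a results path, the sketch below simply searches the workspace for recently written checkpoints (assumption: the run saves `.nemo` or `.ckpt` files somewhere under `BIONEMO_HOME`):

import glob

# Search for checkpoints written by the training run (excluding the downloaded models directory).
candidates = glob.glob(os.path.join(bionemo_home, "**", "*.nemo"), recursive=True) \
           + glob.glob(os.path.join(bionemo_home, "**", "*.ckpt"), recursive=True)
candidates = [p for p in candidates if not p.startswith(model_path)]
for path in sorted(candidates, key=os.path.getmtime, reverse=True)[:5]:
    print(path)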
Continue Pretraining, Add a Downstream Head, Perform Full Parameter Fine-Tuning for ESM-2nv on Antibody Sequences#
1. Continue training from a model checkpoint#
In BioNeMo, you can easily continue training ESM-2nv on antibody sequences from a `.nemo` checkpoint.

To continue pretraining the foundation model, use the `pretrain.py` script and set `exp_manager.resume_if_exists=True`. This loads the model weights from the existing `esm2nv_650M_converted.nemo` checkpoint file, maintains the metadata from the previous run (for example, `max_steps`), and picks up from the learning rate at the end of the previous run. You can replace this file with another, but ensure that you select the config file that matches the model of your choice.
%%capture --no-display --no-stderr cell_output
! cd {bionemo_home} && python /workspace/bionemo/examples/protein/esm2nv/pretrain.py \
--config-path={config_dir} \
--config-name=pretrain_esm2_650M \
name={model_name}_antibodies_continued \
do_training=True \
++trainer.devices=1 \
++trainer.max_steps=1 \
++trainer.val_check_interval=1 \
++model.data.train.custom_pretraining_fasta_path={train_fasta} \
++model.data.val.custom_pretraining_fasta_path={val_fasta} \
++model.data.test.custom_pretraining_fasta_path={test_fasta} \
++model.data.dataset_path={paratope_dir} \
++model.data.train.dataset_path={paratope_dir} \
++model.micro_batch_size=1 \
++exp_manager.create_wandb_logger=false \
++exp_manager.resume_if_exists=true
2. Downstream Head Fine-Tuning#
First, note that we are not using the `pretrain.py` script but rather the `downstream_flip.py` script. This script was originally created for downstream fine-tuning on the FLIP dataset. In addition to this Python script, we will use a YAML file that already exists in BioNeMo for the token-level classification task, named `downstream_flip_sec_str`. We will override its configurations using Hydra. In particular, we are not pretraining the encoder here; instead, we add a prediction head, which in this case is a Conv2D head for token-level classification.

We will need to adjust the `dwnstr_task_validation` configuration as well as the data used by the model. In addition to setting the correct data paths, it is necessary to specify the number of classes we are predicting under `target_sizes` as a list, as these values are used by the CNN head. You can also provide mask columns; otherwise, set them to `null` as a list. The `target_column` should be the column in the dataframe containing the labels, in this case sequences labeled with `N` and `P` characters. Along with the labels, we also need to specify the sequence column.

Importantly, we need to set the encoder path to `esm2nv_650M_converted.nemo`. By default, the `encoder_frozen` parameter is set to `True`, meaning that the foundation model weights are kept fixed.
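The value passed to `target_sizes` should equal the number of distinct label characters in the `Encoded` column (here `N` and `P`, hence `[2]`). A minimal sketch to confirm this from the training CSV written earlier:

# Count the distinct label characters in the training split to justify target_sizes=[2].
labels_df = pd.read_csv(os.path.join(SAbDab_dir, task_name, 'train', 'x000.csv'))
label_alphabet = set("".join(labels_df['Encoded']))
print(f"Label classes: {sorted(label_alphabet)} -> target_sizes=[{len(label_alphabet)}]")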
train_data = os.path.join(SAbDab_dir, f'{task_name}/train/x000.csv')
val_data = os.path.join(SAbDab_dir, f'{task_name}/val/x000.csv')
test_data = os.path.join(SAbDab_dir, f'{task_name}/test/x000.csv')
%%capture --no-display --no-stderr cell_output
! cd {bionemo_home} && python examples/protein/downstream/downstream_flip.py \
--config-path={config_dir} \
--config-name=downstream_flip_sec_str \
name={model_name}_with_head \
do_training=True \
do_testing=True \
++data.dataset_path={SAbDab_dir} \
++trainer.devices=1 \
++trainer.max_steps=1 \
++trainer.val_check_interval=1 \
++model.data.dataset.train={train_data} \
++model.data.dataset.val={val_data} \
++model.data.dataset.test={test_data} \
++model.data.target_column=['Encoded'] \
++model.data.sequence_column="Antibody" \
++model.data.target_sizes=[2] \
++model.data.mask_column=[null] \
++model.micro_batch_size=1 \
++model.data.task_name={task_name} \
++model.restore_encoder_path={checkpoint_path} \
++model.dwnstr_task_validation.dataset.dataset_path={SAbDab_dir} \
++model.data.preprocessed_data_path={SAbDab_dir} \
++exp_manager.create_wandb_logger=false
3. Full Parameter Fine-Tuning#
Fine-tuning the foundation model will require us to use the `downstream_flip.py` script and set `restore_encoder_path` to load the model weights from the existing checkpoint file. Also, ensure that the encoder weights are not frozen by setting `model.encoder_frozen=False`.
%%capture --no-display --no-stderr cell_output
! cd {bionemo_home} && python examples/protein/downstream/downstream_flip.py \
--config-path={config_dir} \
--config-name=downstream_flip_sec_str \
name={model_name}_full_fine_tuning \
do_training=True \
do_testing=True \
++data.dataset_path={SAbDab_dir} \
++trainer.devices=1 \
++trainer.max_steps=1 \
++trainer.val_check_interval=1 \
++model.data.dataset.train={train_data} \
++model.data.dataset.val={val_data} \
++model.data.dataset.test={test_data} \
++model.data.target_column=['Encoded'] \
++model.data.sequence_column="Antibody" \
++model.data.target_sizes=[2] \
++model.data.mask_column=[null] \
++model.micro_batch_size=1 \
++model.data.task_name={task_name} \
++model.restore_encoder_path={checkpoint_path} \
++model.dwnstr_task_validation.dataset.dataset_path={SAbDab_dir} \
++model.data.preprocessed_data_path={SAbDab_dir} \
++exp_manager.create_wandb_logger=false \
++model.encoder_frozen=False \
++exp_manager.resume_if_exists=false
4. Low-Rank Adaptation (LoRA) fine-tuning#
A few notable changes in the `downstream_sec_str_LORA.yaml` file are:

- `model.encoder_frozen=False`: set to `False` when using PEFT.
- `model.peft.enabled=True`: set to `True` to enable PEFT.
- `model.peft.lora_tuning.adapter_dim`: allows setting different values for the rank used in the matrix decomposition. This hyperparameter helps maximize performance on your data, as it determines the number of trainable parameters (see the sketch after this list).
- `model.peft.lora_tuning.layer_selection`: selects the layers in which to add LoRA adapters. For example, `[1,12]` will add LoRA to layer 1 (lowest) and layer 12; `null` will apply adapters to all layers.
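To get intuition for how `adapter_dim` controls the number of trainable parameters, here is a rough back-of-the-envelope sketch. It assumes adapters are added only to the attention query/key/value projections of a 33-layer model with hidden dimension 1280 (the approximate shape of ESM-2 650M); the actual count depends on which modules BioNeMo wraps with adapters.

def lora_trainable_params(rank: int, hidden: int = 1280, n_layers: int = 33, adapted_mats_per_layer: int = 3) -> int:
    """Rough LoRA parameter count: each adapted (hidden x hidden) matrix gains
    two low-rank factors of shapes (rank x hidden) and (hidden x rank)."""
    per_matrix = 2 * rank * hidden
    return per_matrix * adapted_mats_per_layer * n_layers

for r in (8, 16, 32):
    print(f"adapter_dim={r}: ~{lora_trainable_params(r) / 1e6:.2f}M trainable parameters")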
Following these instructions and reimplementing the `ESM2nvLoRAModel` class in the `bionemo/model/protein/esm1nv/esm1nv_model.py` script, you can perform LoRA fine-tuning for ESM-1nv as well.

For more details about LoRA, please see this notebook.
%%capture --no-display --no-stderr cell_output
! cd {bionemo_home} && python examples/protein/downstream/downstream_flip.py \
--config-path={config_dir} \
--config-name=downstream_sec_str_LORA \
name={model_name}_LORA \
do_training=True \
do_testing=True \
++data.dataset_path={SAbDab_dir} \
++trainer.devices=1 \
++trainer.max_steps=1 \
++trainer.max_epochs=1 \
++trainer.val_check_interval=1 \
++model.encoder_frozen=False \
++model.data.task_name={task_name} \
++model.restore_encoder_path={checkpoint_path} \
++model.data.preprocessed_data_path={SAbDab_dir} \
++model.data.dataset.train={train_data} \
++model.data.dataset.val={val_data} \
++model.data.dataset.test={test_data} \
++model.data.target_column=['Encoded'] \
++model.data.sequence_column="Antibody" \
++model.data.target_sizes=[2] \
++model.data.mask_column=[null] \
++model.dwnstr_task_validation.dataset.target_column=['Encoded'] \
++model.dwnstr_task_validation.dataset.sequence_column="Antibody" \
++model.dwnstr_task_validation.dataset.target_sizes=[2] \
++model.dwnstr_task_validation.data_impl_kwargs.csv_mmap.data_col=1 \
++model.dwnstr_task_validation.dataset.mask_column=[null] \
++model.dwnstr_task_validation.dataset.dataset_path={SAbDab_dir} \
++exp_manager.create_wandb_logger=false \
++exp_manager.resume_if_exists=false
In this demo, we explored how to continue training ESM-2nv, add a downstream head, perform full-parameter fine-tuning (both the foundation model and the head), and apply LoRA fine-tuning for a token-level classification task on antibody sequences.