Pretrain from Scratch, Continue Training from an Existing Checkpoint, and Fine-tune ESM-2nv on Custom Data#

NOTE This notebook was tested on a single A1000 GPU and is compatible with BioNeMo Framework v1.6, v1.7, and v1.8, with an expected runtime of approximately 2 hours for the ESM-2nv 650M model. This notebook is specific to the ESM-2nv model only.

Demo Objectives#

  1. Continue Training from a Model Checkpoint

    • Objective: Continue pretraining the ESM-2nv foundation model on antibody sequences, starting from an existing .nemo checkpoint.

    • Steps: Prepare the antibody dataset, and resume pretraining with the pretrain.py script and exp_manager.resume_if_exists=True.

  2. Downstream Head Fine-tuning

    • Objective: Fine-tune ESM-2nv for predicting antibody function with an additional prediction head.

    • Steps: Collect the data, and use existing downstream prediction head training scripts in BioNeMo for token-level classification.

  3. Full Parameter Fine-tuning on Antibody Sequences

    • Objective: Fine-tune an ESM-2nv foundation model and head on antibody sequences to enhance recognition of specific sequence patterns.

    • Steps: Prepare dataset, and fine-tune ESM-2nv.

  4. Low-Rank Adaptation (LoRA) Fine-tuning

    • Objective: Apply LoRA to ESM-2nv for antibody sequences to improve efficiency and robustness.

    • Steps: Integrate LoRA adapters, and fine-tune adapters while freezing core weights.

For this purpose, we will use data available from the Therapeutics Data Commons (TDC) for the prediction of amino acid binding in antibody sequences.

Setup#

Ensure that you have read through the Getting Started section, can run the BioNeMo Framework Docker container, and have configured the NGC Command Line Interface (CLI) within the container. It is assumed that this notebook is being executed from within the container.

NOTE Some of the cells below generate long text output. We're using
%%capture --no-display --no-stderr cell_output
to suppress this output. Comment out or delete this line in the cells below to restore full output.

Import and install all required packages#

%%capture --no-display --no-stderr cell_output
! pip install PyTDC

import os
import pandas as pd
import warnings

# Import the library used to download and split datasets from the Therapeutics Data Commons https://tdcommons.ai/
from tdc.single_pred import Paratope

warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

Home Directory#

bionemo_home = "/workspace/bionemo"
os.environ['BIONEMO_HOME'] = bionemo_home
os.chdir(bionemo_home)

Data Download and Preprocessing#

Dataset Overview: This dataset focuses on paratope prediction, which involves identifying the active binding regions within an antibody. It compiles sequences from SAbDab, encompassing both the heavy and light chains of the antibody.

Objective: The task involves classifying at the token level. For a given sequence of amino acids, the goal is to identify the specific amino acid tokens that are active in binding. In this context, X represents the amino acid sequence, while Y denotes the indices of active binding positions within X.

Dataset Details: The dataset comprises sequences from 1,023 antibody chains.

# Specify dataset for download
dataset_name = 'SAbDab_Liberis'
print(f"Preparing raw {dataset_name} dataset from TD Commons.")
data = Paratope(name = dataset_name)
data_df = data.get_data()
splits = data.get_split()
data_df.head()
Found local copy...
Loading...
Done!
Preparing raw SAbDab_Liberis dataset from TD Commons.
Antibody_ID Antibody Y
0 2hh0_H LEQSGAELVKPGASVKLSCTASGFNIEDSYIHWVKQRPEQGLEWIG... [49, 80, 81, 82, 101]
1 1u8q_B ITLKESGPPLVKPTQTLTLTCSFSGFSLSDFGVGVGWIRQPPGKAL... [30, 31, 53, 83, 84, 85, 104, 105, 106, 107, 1...
2 4ydl_H EVRLVQSGNQVRKPGASVRISCEASGYKFIDHFIHWVRQVPGHGLE... [52, 67, 68, 85, 86, 87, 106, 107]
3 4ydl_L EIVLTQSPGTLSLSPGETATLSCRTSQGILSNQLAWHQQRRGQPPR... [30]
4 1mhp_X EVQLVESGGGLVQPGGSLRLSCAASGFTFSRYTMSWVRQAPGKGLE... [52, 82, 83, 84, 103, 104]
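
Before encoding, it can be useful to check how the 1,023 chains are distributed across the splits returned by data.get_split() above (an optional check, not part of the original workflow):

# Number of antibody chains in each TDC split (keys are 'train', 'valid', 'test')
for split_key, split_df in splits.items():
    print(f"{split_key}: {len(split_df)} sequences")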

Each antibody sequence (Antibody) is a string of amino acids whose order and composition determine its function and specificity. Within each sequence, the positions listed in Y mark the amino acids that belong to the antibody's paratope. Below, we define a function (encode_sequence) that encodes these positions by initializing a label string of placeholders (N for non-paratope positions) and marking the positions of interest with P, denoting amino acids that belong to the paratope. The dataset is then divided into subsets (train, val, test), and the encoding is applied to each subset.

For ESM-2 to be trained on custom sequences, we will also need to create FASTA files.

base_data_dir = os.path.join(bionemo_home, 'data')
task_name = "paratope"
!mkdir -p {base_data_dir}/processed/{dataset_name}

SAbDab_dir = os.path.join(base_data_dir, 'processed', dataset_name)

def encode_sequence(row: pd.Series) -> str:
    sequence = list('N' * len(row['Antibody']))  # Create a list of 'N's the same length as the sequence
    # Check if row['Y'] is a string that needs to be evaluated
    if isinstance(row['Y'], str):
        positions = eval(row['Y'])  # Convert string representation of list to actual list
    else:
        positions = row['Y']  # Assume row['Y'] is already in the correct format (e.g., a list)
    for pos in positions:
        adj_pos = pos - 1  # Adjust the position to 0-based indexing
        sequence[adj_pos] = 'P'  # Encode the position as 'P'
    return ''.join(sequence)  # Convert the list back to a string

# TDC returns splits keyed 'train', 'valid', and 'test';
# we use 'val' instead of 'valid' for the output folder names
split_names = ['train', 'val', 'test']

for split_name in split_names:
    # Adjust the key for accessing the validation split if necessary
    split_key = 'valid' if split_name == 'val' else split_name
    # Get the dataframe for this split
    df = splits[split_key]
    # Apply the function to each row
    df['Encoded'] = df.apply(encode_sequence, axis=1)

    # Adjust the directory structure for saving, now including the task_name
    task_specific_dir = os.path.join(SAbDab_dir, task_name, split_name)
    os.makedirs(task_specific_dir, exist_ok=True)  # Ensure the directory exists
    df = df[['Antibody', 'Encoded']]  # Keep only the sequence and encoded label columns
    # Save the modified DataFrame to the new path
    df.to_csv(os.path.join(task_specific_dir, "x000.csv"), index=False)
    print(f"Encoded sequences saved as x000.csv in {task_specific_dir}")
    
    # Save as FASTA
    fasta_path = os.path.join(task_specific_dir, "x000.fasta")
    with open(fasta_path, 'w') as fasta_file:
        for index, row in df.iterrows():
            fasta_file.write(f">Sequence_{index}\n{row['Antibody']}\n")
    print(f"Encoded sequences saved as x000.fasta in {task_specific_dir}")
Encoded sequences saved as x000.csv in /workspace/bionemo/data/processed/SAbDab_Liberis/paratope/train
Encoded sequences saved as x000.fasta in /workspace/bionemo/data/processed/SAbDab_Liberis/paratope/train
Encoded sequences saved as x000.csv in /workspace/bionemo/data/processed/SAbDab_Liberis/paratope/val
Encoded sequences saved as x000.fasta in /workspace/bionemo/data/processed/SAbDab_Liberis/paratope/val
Encoded sequences saved as x000.csv in /workspace/bionemo/data/processed/SAbDab_Liberis/paratope/test
Encoded sequences saved as x000.fasta in /workspace/bionemo/data/processed/SAbDab_Liberis/paratope/test
encoded_df = pd.read_csv(os.path.join(SAbDab_dir, task_name, 'train', 'x000.csv'))
encoded_df.head()
Antibody Encoded
0 LEQSGAELVKPGASVKLSCTASGFNIEDSYIHWVKQRPEQGLEWIG... NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN...
1 ITLKESGPPLVKPTQTLTLTCSFSGFSLSDFGVGVGWIRQPPGKAL... NNNNNNNNNNNNNNNNNNNNNNNNNNNNNPPNNNNNNNNNNNNNNN...
2 EVRLVQSGNQVRKPGASVRISCEASGYKFIDHFIHWVRQVPGHGLE... NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN...
3 EVQLSESGGGFVKPGGSLRLSCEASGFTFNNYAMGWVRQAPGKGLE... NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN...
4 QVQLVQPGTAMKSLGSSLTITCRVSGDDLGSFHFGTYFMIWVRQAP... NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNPPPPPNNNNNNNNNNN...
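
As an optional sanity check (purely illustrative, using only the dataframe loaded above), we can confirm that each encoded label string has the same length as its antibody sequence:

# Optional check: every encoded label string should align 1:1 with its sequence
assert (encoded_df['Antibody'].str.len() == encoded_df['Encoded'].str.len()).all(), \
    "Encoded labels must match sequence lengths"
print(f"Label/sequence lengths match for all {len(encoded_df)} training examples.")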

Download Model Checkpoints#

The following code will download the pretrained model esm2nv_650M_converted.nemo from the NGC registry.

In BioNeMo FW, there are numerous ESM models available, including ESM-1nv, ESM-2nv 8M with randomly initialized weights, ESM-2nv fine-tuned for secondary structure downstream prediction tasks with LoRA, ESM-2nv 650M, and ESM-2nv 3B. We also have a configuration file for training ESM-2nv 15B available at examples/protein/esm2nv/conf/pretrain_esm2_15B.yaml, if needed.

For demo purposes, we have chosen to showcase the ESM-2nv 650M model. For more details on ESM-1nv or ESM-2nv, consult the corresponding model cards. To find the model and checkpoint names, please refer to the artifacts_paths.yaml file.

# Define the NGC CLI API KEY and ORG for the model download
# If these variables are not already set in the container, uncomment below
# to define and set with your API KEY and ORG
# api_key = <YOUR_API_KEY>
# ngc_cli_org = <YOUR_ORG>
# Update the environment variable
# os.environ['NGC_CLI_API_KEY'] = api_key
# os.environ['NGC_CLI_ORG'] = ngc_cli_org

# Set variables and paths for model and checkpoint
model_name = "esm2nv" 
model_version = "esm2nv_650m" 
actual_checkpoint_name = "esm2nv_650M_converted.nemo"
model_path = os.path.join(bionemo_home, 'models')
checkpoint_path = os.path.join(model_path, actual_checkpoint_name)
os.environ['MODEL_PATH'] = model_path
%%capture --no-display --no-stderr cell_output
if not os.path.exists(checkpoint_path):
    !cd /workspace/bionemo && \
    python download_artifacts.py --model_dir models --models {model_version}
else:
    print(f"Model {model_version} already exists at {model_path}.")

Setting up paths to the data used for model training:

config_dir = os.path.join(bionemo_home, f'examples/protein/{model_name}/conf')
train_fasta = os.path.join(SAbDab_dir, f'{task_name}/train/x000.fasta')
val_fasta = os.path.join(SAbDab_dir, f'{task_name}/val/x000.fasta')
test_fasta = os.path.join(SAbDab_dir, f'{task_name}/test/x000.fasta')
paratope_dir = os.path.join(SAbDab_dir, 'paratope_custom_dataset')
! mkdir -p {paratope_dir}

Preprocessing and Pretraining from Scratch#

  • Perform preprocessing to transform the data into a format that can be used by the model.

%%capture --no-display --no-stderr cell_output
! cd {bionemo_home} && python examples/protein/esm2nv/pretrain.py \
  --config-path={config_dir} \
  --config-name=pretrain_esm2_650M \
  ++do_training=False \
  ++do_preprocessing=True \
  ++trainer.devices=1 \
  ++model.data.train.custom_pretraining_fasta_path={train_fasta} \
  ++model.data.val.custom_pretraining_fasta_path={val_fasta} \
  ++model.data.test.custom_pretraining_fasta_path={test_fasta} \
  ++model.data.dataset_path={paratope_dir} \
  ++model.data.train.dataset_path={paratope_dir} \
  ++exp_manager.create_wandb_logger=false
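
The exact files produced by preprocessing depend on the BioNeMo version, so rather than assuming specific names, we can simply list what was written under the custom dataset directory (an optional, illustrative check):

# List everything preprocessing wrote under the custom dataset directory
for root, dirs, files in os.walk(paratope_dir):
    for file_name in sorted(files):
        print(os.path.relpath(os.path.join(root, file_name), paratope_dir))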

Pretrain from Scratch#

This will take approximately 15 minutes on an A1000 GPU.

%%capture --no-display --no-stderr cell_output
! cd {bionemo_home} && python examples/protein/esm2nv/pretrain.py \
    --config-path={config_dir} \
    --config-name=pretrain_esm2_650M \
    name={model_name}_from_scratch_antibodies \
    ++do_training=True \
    ++trainer.devices=1 \
    ++trainer.max_steps=1 \
    ++trainer.val_check_interval=1 \
    ++model.data.train.custom_pretraining_fasta_path={train_fasta} \
    ++model.data.val.custom_pretraining_fasta_path={val_fasta} \
    ++model.data.test.custom_pretraining_fasta_path={test_fasta} \
    ++model.data.dataset_path={paratope_dir} \
    ++model.data.train.dataset_path={paratope_dir} \
    ++model.micro_batch_size=1 \
    ++exp_manager.create_wandb_logger=false
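
Where the resulting checkpoints are written depends on the exp_manager defaults in your BioNeMo version. As an illustrative helper (the directory names below are assumptions; adjust them to your configuration), we can search a couple of likely output roots for recently written checkpoint files:

import glob

# Search likely experiment output roots for checkpoints written by the run above
for results_root in [os.path.join(bionemo_home, 'results'), os.path.join(bionemo_home, 'nemo_experiments')]:
    checkpoints = sorted(glob.glob(os.path.join(results_root, '**', '*.ckpt'), recursive=True),
                         key=os.path.getmtime)
    print(results_root, '->', [os.path.basename(c) for c in checkpoints[-3:]])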

Continue Pretraining, Add a Downstream Head, and Perform Full-Parameter Fine-Tuning of ESM-2nv on Antibody Sequences#

1. Continue Training from a Model Checkpoint#

In BioNeMo, you can easily continue training ESM-2nv on antibody sequences from a .nemo checkpoint.

IMPORTANT: For demonstration purposes, the `max_steps` and `val_check_interval` parameters in the fine-tuning process have been adjusted to lower values.

To continue pretraining the foundation model, use the pretrain.py script and set exp_manager.resume_if_exists=True. This loads the model weights from the existing esm2nv_650M_converted.nemo checkpoint, keeps metadata from the previous run (e.g. max_steps), and resumes from the learning rate reached at the end of that run. You can replace this checkpoint with another, but make sure to select the configuration file that matches the model of your choice.

%%capture --no-display --no-stderr cell_output
! cd {bionemo_home} && python /workspace/bionemo/examples/protein/esm2nv/pretrain.py \
    --config-path={config_dir} \
    --config-name=pretrain_esm2_650M \
    name={model_name}_antibodies_continued \
    do_training=True \
    ++trainer.devices=1 \
    ++trainer.max_steps=1 \
    ++trainer.val_check_interval=1 \
    ++model.data.train.custom_pretraining_fasta_path={train_fasta} \
    ++model.data.val.custom_pretraining_fasta_path={val_fasta} \
    ++model.data.test.custom_pretraining_fasta_path={test_fasta} \
    ++model.data.dataset_path={paratope_dir} \
    ++model.data.train.dataset_path={paratope_dir} \
    ++model.micro_batch_size=1 \
    ++exp_manager.create_wandb_logger=false \
    ++exp_manager.resume_if_exists=true

2. Downstream Head Fine-Tuning#

First, note that we are not using the pretrain.py script but rather the downstream_flip.py script, which was originally created for downstream fine-tuning on the FLIP dataset. In addition to this Python script, we will use a YAML file that already exists in BioNeMo for the token-level classification task, named downstream_flip_sec_str, and override its configurations using Hydra. In particular, we do not want to update the pretrained encoder here; instead, we want to add and train a prediction head, which in this case is a Conv2D head for token-level classification.

We will need to adjust the dwnstr_task_validation configurations as well as the data used by the model. In addition to setting the correct data paths, it is necessary to specify the number of classes we are predicting under target_sizes as a list, as these will be used by the CNN. You can also provide mask columns; otherwise, set them to null as a list. The target_column should be the column in the dataframe where we have the labels, in this case, sequences labeled with N and P characters. Along with the labels, we need to specify the sequence column as well.

Importantly, we need to set the encoder path to esm2nv_650M_converted.nemo. By default, the encoder_frozen parameter is set to True, meaning that the foundation model weights are fixed.

train_data = os.path.join(SAbDab_dir, f'{task_name}/train/x000.csv')
val_data = os.path.join(SAbDab_dir, f'{task_name}/val/x000.csv')
test_data = os.path.join(SAbDab_dir, f'{task_name}/test/x000.csv')
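
Since target_sizes=[2] tells the Conv2D head to predict two classes per token (N vs. P), it is worth checking how imbalanced those classes are in the training split (an optional check using only the CSV written earlier):

# Token-level class balance for the paratope task (N = non-paratope, P = paratope)
label_string = pd.read_csv(train_data)['Encoded'].str.cat()
label_counts = pd.Series(list(label_string)).value_counts()
print(label_counts)
print("Paratope (P) fraction:", round(label_counts.get('P', 0) / label_counts.sum(), 3))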
%%capture --no-display --no-stderr cell_output
! cd {bionemo_home} && python examples/protein/downstream/downstream_flip.py \
    --config-path={config_dir} \
    --config-name=downstream_flip_sec_str \
    name={model_name}_with_head \
    do_training=True \
    do_testing=True \
    ++data.dataset_path={SAbDab_dir} \
    ++trainer.devices=1 \
    ++trainer.max_steps=1 \
    ++trainer.val_check_interval=1 \
    ++model.data.dataset.train={train_data} \
    ++model.data.dataset.val={val_data} \
    ++model.data.dataset.test={test_data} \
    ++model.data.target_column=['Encoded'] \
    ++model.data.sequence_column="Antibody" \
    ++model.data.target_sizes=[2] \
    ++model.data.mask_column=[null] \
    ++model.micro_batch_size=1 \
    ++model.data.task_name={task_name} \
    ++model.restore_encoder_path={checkpoint_path} \
    ++model.dwnstr_task_validation.dataset.dataset_path={SAbDab_dir} \
    ++model.data.preprocessed_data_path={SAbDab_dir} \
    ++exp_manager.create_wandb_logger=false

3. Full Parameter Fine-Tuning#

Fine-tuning the foundation model will require us to use the downstream_flip.py script and set restore_encoder_path to load the model weights from the existing checkpoint file. Also, ensure that the encoder weights are not frozen by setting model.encoder_frozen=False.

%%capture --no-display --no-stderr cell_output
! cd {bionemo_home} && python examples/protein/downstream/downstream_flip.py \
    --config-path={config_dir} \
    --config-name=downstream_flip_sec_str \
    name={model_name}_full_fine_tuning \
    do_training=True \
    do_testing=True \
    ++data.dataset_path={SAbDab_dir} \
    ++trainer.devices=1 \
    ++trainer.max_steps=1 \
    ++trainer.val_check_interval=1 \
    ++model.data.dataset.train={train_data} \
    ++model.data.dataset.val={val_data} \
    ++model.data.dataset.test={test_data} \
    ++model.data.target_column=['Encoded'] \
    ++model.data.sequence_column="Antibody" \
    ++model.data.target_sizes=[2] \
    ++model.data.mask_column=[null] \
    ++model.micro_batch_size=1 \
    ++model.data.task_name={task_name} \
    ++model.restore_encoder_path={checkpoint_path} \
    ++model.dwnstr_task_validation.dataset.dataset_path={SAbDab_dir} \
    ++model.data.preprocessed_data_path={SAbDab_dir} \
    ++exp_manager.create_wandb_logger=false \
    ++model.encoder_frozen=False \
    ++exp_manager.resume_if_exists=false

4. Low-Rank Adaptation (LoRA) Fine-Tuning#

A few notable changes in the downstream_sec_str_LORA.yaml file are:

  • model.encoder_frozen=False: must be set to False when using PEFT.

  • model.peft.enabled=True: set to True to enable PEFT.

  • model.peft.lora_tuning.adapter_dim: sets the rank used in the low-rank matrix decomposition, and therefore the number of trainable parameters. Tune this hyperparameter to maximize performance on your data (see the sketch after this list).

  • model.peft.lora_tuning.layer_selection: selects the layers in which to add LoRA adapters. For example, [1,12] will add LoRA to layer 1 (lowest) and layer 12; null applies adapters to all layers.
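
To make adapter_dim concrete, below is a minimal, self-contained PyTorch sketch of the LoRA idea (this is not BioNeMo's implementation; the class and parameter names are illustrative, and it assumes PyTorch is available in the container): the frozen base weight is augmented with a trainable low-rank update of rank r, so only the two small matrices receive gradients.

import torch
import torch.nn as nn

class ToyLoRALinear(nn.Module):
    """Illustrative only: a frozen linear layer plus a trainable rank-r update."""
    def __init__(self, in_features: int, out_features: int, r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)
        self.base.weight.requires_grad_(False)  # core weights stay frozen
        self.base.bias.requires_grad_(False)
        self.lora_a = nn.Parameter(torch.randn(r, in_features) * 0.01)  # A: r x in
        self.lora_b = nn.Parameter(torch.zeros(out_features, r))        # B: out x r (zero-init)
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Output = frozen projection + scaled low-rank update; only A and B are trained
        return self.base(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scaling

layer = ToyLoRALinear(1280, 1280, r=16)  # r plays the role of adapter_dim
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(f"Trainable LoRA parameters: {trainable:,}")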

NOTE LoRA is currently not supported for ESM-1nv.

However, by following these instructions and reimplementing the ESM2nvLoRAModel class in the bionemo/model/protein/esm1nv/esm1nv_model.py script for ESM-1nv, you can enable LoRA for that model as well.

For more details about LoRA, please see this notebook.

%%capture --no-display --no-stderr cell_output
! cd {bionemo_home} && python examples/protein/downstream/downstream_flip.py \
    --config-path={config_dir} \
    --config-name=downstream_sec_str_LORA \
    name={model_name}_LORA \
    do_training=True \
    do_testing=True \
    ++data.dataset_path={SAbDab_dir} \
    ++trainer.devices=1 \
    ++trainer.max_steps=1 \
    ++trainer.max_epochs=1 \
    ++trainer.val_check_interval=1 \
    ++model.encoder_frozen=False \
    ++model.data.task_name={task_name} \
    ++model.restore_encoder_path={checkpoint_path} \
    ++model.data.preprocessed_data_path={SAbDab_dir} \
    ++model.data.dataset.train={train_data} \
    ++model.data.dataset.val={val_data} \
    ++model.data.dataset.test={test_data} \
    ++model.data.target_column=['Encoded'] \
    ++model.data.sequence_column="Antibody" \
    ++model.data.target_sizes=[2] \
    ++model.data.mask_column=[null] \
    ++model.dwnstr_task_validation.dataset.target_column=['Encoded'] \
    ++model.dwnstr_task_validation.dataset.sequence_column="Antibody" \
    ++model.dwnstr_task_validation.dataset.target_sizes=[2] \
    ++model.dwnstr_task_validation.data_impl_kwargs.csv_mmap.data_col=1 \
    ++model.dwnstr_task_validation.dataset.mask_column=[null] \
    ++model.dwnstr_task_validation.dataset.dataset_path={SAbDab_dir} \
    ++exp_manager.create_wandb_logger=false \
    ++exp_manager.resume_if_exists=false

In this demo, we explored how to pretrain ESM-2nv from scratch on custom sequences, continue training from an existing checkpoint, add a downstream prediction head, perform full-parameter fine-tuning (both the foundation model and the head), and apply LoRA fine-tuning for a token-level classification task on antibody sequences.