ESM-2 Fine-tuning¶
ESM-2 is a transformer-based protein language model that has achieved state-of-the-art results on a range of protein-related tasks. When fine-tuning ESM-2, the task head plays a crucial role. A task head is the additional layer, or set of layers, added on top of a pre-trained model such as ESM-2 to adapt it to a specific downstream task. In transfer learning, the pre-trained model has already learned generic features from a large-scale dataset, but these features may not be directly applicable to the task at hand. The task head adds learnable parameters that build task-specific representations on top of the pre-trained features. During fine-tuning, it learns the patterns and nuances of the downstream task, acting as the bridge that transfers knowledge from the pre-trained model to the target task.
The utilities described in this tutorial are available in:
bionemo.esm2.model.finetune
In the second part of the tutorial, we cover loading a pre-trained model, fine-tuning it for sequence-level regression/classification and token-level classification, and using the fine-tuned models for inference. For instructions on pre-training the ESM-2 model, please refer to the ESM-2 Pretraining tutorial.
Building a Regression Fine-tune Module¶
We need to define several key classes to build a fine-tuning module in the BioNeMo framework:
- Loss Reduction Class - To compute the supervised fine-tuning loss.
- Fine-Tuned Model Head - Downstream task head model.
- Fine-Tuned Model - Model that combines ESM-2 with the task head model.
- Fine-Tuning Config - Configures the fine-tuning model and loss to use in the training and inference framework.
- Dataset - Training and inference datasets for ESM2 fine-tuning.
1 - Loss Reduction Class¶
A class for calculating the supervised loss of the fine-tuned model from the targets. We inherit from the Megatron BERT masked language model loss (BERTMLMLossWithReduction) and override the forward() pass to compute the sum of squared errors of the regression head within a micro-batch. The reduce() method of BERTMLMLossWithReduction then averages over the micro-batches, i.e. computes the MSE loss.
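from typing import Dict, Tuple

import torch
from bionemo.llm.model.loss import BERTMLMLossWithReduction  # import path may vary by BioNeMo version


class RegressorLossReduction(BERTMLMLossWithReduction):
    def forward(
        self, batch: Dict[str, torch.Tensor], forward_out: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        regression_output = forward_out["regression_output"]  # [b, 1]
        targets = batch["labels"].to(dtype=regression_output.dtype)  # [b, 1]
        # number of target values in this micro-batch (the denominator of the mean)
        num_valid_tokens = torch.tensor(targets.numel(), dtype=torch.int, device=targets.device)
        loss_sum = ((regression_output - targets) ** 2).sum()  # scalar sum of squared errors
        # detached (sum, count) pair used by reduce() to average across micro-batches
        loss_sum_and_ub_size = torch.cat([loss_sum.clone().detach().view(1), num_valid_tokens.view(1)])
        return loss_sum, num_valid_tokens, {"loss_sum_and_ub_size": loss_sum_and_ub_size}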
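To see why summing squared errors in forward() and dividing by the total count in reduce() gives the MSE, here is a small sketch in plain PyTorch. It is an illustration of the arithmetic only, not the Megatron/BioNeMo reduction code: each micro-batch contributes a (loss_sum, num_elements) pair, and the final loss divides the total sum by the total count.

```python
import torch


def micro_batch_sse(predictions: torch.Tensor, targets: torch.Tensor) -> tuple:
    """Return (sum of squared errors, number of elements) for one micro-batch."""
    loss_sum = ((predictions - targets) ** 2).sum()
    return loss_sum, targets.numel()


def reduce_to_mse(parts: list) -> torch.Tensor:
    """Combine per-micro-batch (sum, count) pairs into a single mean squared error."""
    total_sum = torch.stack([s for s, _ in parts]).sum()
    total_count = sum(n for _, n in parts)
    return total_sum / total_count


torch.manual_seed(0)
preds = torch.randn(10, 1)
targets = torch.randn(10, 1)

# Split a batch of 10 into micro-batches of 4, 4 and 2 (the last one is smaller).
parts = [micro_batch_sse(p, t) for p, t in zip(preds.split(4), targets.split(4))]
mse = reduce_to_mse(parts)

# Same value as the MSE computed over the whole batch at once.
print(torch.allclose(mse, torch.nn.functional.mse_loss(preds, targets)))  # True
```

Carrying sums and counts, rather than per-micro-batch means, keeps the result exact when micro-batches have different sizes.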
2 - Fine-Tuned Model Head¶
An MLP class for sequence-level regression. This class inherits from MegatronModule and uses the fine-tuning config (TransformerConfig) to configure the regression head of the fine-tuned ESM-2 model.
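from typing import List

import torch
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig


class MegatronMLPHead(MegatronModule):
    def __init__(self, config: TransformerConfig):
        super().__init__(config)
        # hidden_size -> 256 -> 1 (a single regression value)
        layer_sizes = [config.hidden_size, 256, 1]
        self.linear_layers = torch.nn.ModuleList(
            [torch.nn.Linear(i, o) for i, o in zip(layer_sizes[:-1], layer_sizes[1:])]
        )
        self.act = torch.nn.ReLU()
        # ft_dropout is a custom field of the fine-tuning config (see step 4)
        self.dropout = torch.nn.Dropout(p=config.ft_dropout)

    def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
        ...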
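The forward() body of MegatronMLPHead is elided above. As a rough illustration only, here is a hypothetical stand-in built from plain torch.nn layers; the class name ToyMLPHead and the dropout, linear, ReLU ordering are assumptions, not the BioNeMo implementation, and it returns a single tensor rather than a list.

```python
import torch


class ToyMLPHead(torch.nn.Module):
    """Hypothetical stand-in for MegatronMLPHead using plain torch.nn layers."""

    def __init__(self, hidden_size: int, ft_dropout: float = 0.25):
        super().__init__()
        layer_sizes = [hidden_size, 256, 1]
        self.linear_layers = torch.nn.ModuleList(
            [torch.nn.Linear(i, o) for i, o in zip(layer_sizes[:-1], layer_sizes[1:])]
        )
        self.act = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(p=ft_dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Assumed ordering: dropout -> linear -> ReLU for every layer except the last,
        # which is a plain linear projection to the output size.
        x = hidden_states
        for layer in self.linear_layers[:-1]:
            x = self.act(layer(self.dropout(x)))
        return self.linear_layers[-1](self.dropout(x))


head = ToyMLPHead(hidden_size=320)  # 320 matches the embedding size of the 8M checkpoint
embeddings = torch.randn(4, 320)  # [b, hidden]: one embedding per sequence
print(head(embeddings).shape)  # torch.Size([4, 1])
```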
3 - Fine-Tuned Model¶
A fine-tuned ESM-2 model class for sequence-level regression tasks. This class inherits from the ESM2Model class and adds the custom regression head MegatronMLPHead that we created in the previous step. Optionally, you can freeze all or part of the encoder by iterating over the model parameters in the model constructor.
from bionemo.esm2.model.model import ESM2Model  # import path may vary by BioNeMo version


class ESM2FineTuneSeqModel(ESM2Model):
    def __init__(self, config, *args, post_process: bool = True, include_embeddings: bool = False, **kwargs):
        # embeddings are always requested from the parent because the regression head consumes them
        super().__init__(config, *args, post_process=post_process, include_embeddings=True, **kwargs)

        # freeze encoder parameters (the head is created after this loop, so it stays trainable)
        if config.encoder_frozen:
            for _, param in self.named_parameters():
                param.requires_grad = False

        if post_process:
            self.regression_head = MegatronMLPHead(config)

    def forward(self, *args, **kwargs):
        output = super().forward(*args, **kwargs)
        ...
        output["regression_output"] = self.regression_head(embeddings)
        return output
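In the constructor above, the freeze loop runs before self.regression_head is created, which is why only the head remains trainable. The same ordering effect can be checked with a tiny hypothetical model in plain PyTorch (ToyFineTuneModel is illustrative and not part of BioNeMo):

```python
import torch


class ToyFineTuneModel(torch.nn.Module):
    """Hypothetical stand-in: a tiny 'encoder' plus a regression head."""

    def __init__(self, hidden_size: int = 8, encoder_frozen: bool = True):
        super().__init__()
        self.encoder = torch.nn.Sequential(torch.nn.Linear(hidden_size, hidden_size), torch.nn.ReLU())
        # Freeze every parameter registered so far (only the encoder at this point).
        if encoder_frozen:
            for _, param in self.named_parameters():
                param.requires_grad = False
        # Created after the freeze loop, so its parameters keep requires_grad=True.
        self.regression_head = torch.nn.Linear(hidden_size, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.regression_head(self.encoder(x))


model = ToyFineTuneModel()
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['regression_head.weight', 'regression_head.bias']
```

Moving the head creation above the loop would freeze the head as well, leaving nothing to train.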
4 - Fine-Tuning Config¶
A dataclass that configures the fine-tuned ESM-2 model. In this example, ESM2FineTuneSeqConfig inherits from ESM2GenericConfig and adds custom arguments to set up the fine-tuned model. The configure_model() method of this dataclass is called within the Lightning module to call the model constructor with the dataclass arguments.
The common arguments among different fine-tuning tasks are:
- model_cls: The fine-tuned model class defined in the previous step (ESM2FineTuneSeqModel).
- initial_ckpt_path: The BioNeMo 2.0 ESM-2 pre-trained checkpoint.
- initial_ckpt_skip_keys_with_these_prefixes: Skips keys when loading parameters from a checkpoint. For example, we should not look for the key regression_head when initializing an ESM2FineTuneSeqModel with the encoder weights of a pre-trained model checkpoint (ESM2Model).
- get_loss_reduction_class(): Selects the appropriate MegatronLossReduction class, i.e. the RegressorLossReduction class we defined in the first step of this tutorial.
from dataclasses import dataclass, field
from typing import List, Type

# import paths may vary by BioNeMo/NeMo version
from bionemo.esm2.model.model import ESM2GenericConfig
from bionemo.llm.utils import iomixin_utils as iom
from nemo.lightning.megatron_parallel import MegatronLossReduction


@dataclass
class ESM2FineTuneSeqConfig(
    ESM2GenericConfig[ESM2FineTuneSeqModel, RegressorLossReduction], iom.IOMixinWithGettersSetters
):
    model_cls: Type[ESM2FineTuneSeqModel] = ESM2FineTuneSeqModel
    # The following checkpoint path is for NeMo 2 checkpoints. Config parameters not present in
    # self.override_parent_fields will be loaded from the checkpoint and override those values here.
    initial_ckpt_path: str | None = None
    # The typical case is fine-tuning a base model that doesn't have this head. If you are instead loading a
    # checkpoint that already has this head and want to keep its weights, drop this line or set it to [].
    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=lambda: ["regression_head"])
    encoder_frozen: bool = True  # freeze encoder parameters
    ft_dropout: float = 0.25  # MLP layer dropout

    def get_loss_reduction_class(self) -> Type[MegatronLossReduction]:
        return RegressorLossReduction
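The purpose of initial_ckpt_skip_keys_with_these_prefixes can be illustrated with a hypothetical loader in plain PyTorch. This is not the BioNeMo checkpoint-loading code: the model keys that start with a skipped prefix are not looked up in the checkpoint (a base checkpoint has no regression_head weights), so they keep their fresh initialization while the encoder weights are copied over.

```python
import torch


def load_matching_weights(model: torch.nn.Module, checkpoint: dict, skip_prefixes: list) -> list:
    """Load checkpoint tensors into model, skipping model keys that start with any prefix.

    Hypothetical helper; returns the skipped keys, which keep their fresh initialization.
    """
    model_keys = list(model.state_dict().keys())
    skipped = [k for k in model_keys if any(k.startswith(p) for p in skip_prefixes)]
    # Without skipping, checkpoint[k] would raise KeyError for the new head's keys.
    to_load = {k: checkpoint[k] for k in model_keys if k not in skipped}
    # strict=False because the skipped keys are intentionally missing from to_load
    model.load_state_dict(to_load, strict=False)
    return skipped


class ToySeqModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(4, 4)
        self.regression_head = torch.nn.Linear(4, 1)


# A "pre-trained" checkpoint containing only encoder weights, like a base model checkpoint.
checkpoint = {"encoder.weight": torch.randn(4, 4), "encoder.bias": torch.randn(4)}

model = ToySeqModel()
skipped = load_matching_weights(model, checkpoint, ["regression_head"])
print(skipped)  # ['regression_head.weight', 'regression_head.bias']
```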
5 - Dataset¶
We will use a sample dataset for demonstration purposes. To create a dataset class, extend bionemo.esm2.model.finetune.dataset.InMemoryProteinDataset. InMemoryProteinDataset provides a classmethod, from_csv, that reads data from a CSV file with a sequences column and, optionally, labels and labels_mask columns. The labels_mask column is only supported by InMemoryPerTokenValueDataset and specifies which input sequence positions to use and which to ignore during training. It is important to override the transform_label() method, which returns a torch.Tensor containing the label in the correct format. For example, this method can apply custom tokenization when the label is a string.
For sequence-level regression, InMemorySingleValueDataset (found in bionemo.esm2.model.finetune.dataset.InMemorySingleValueDataset) is an appropriate dataset class because it provides a single value per sequence. An excerpt from the class is shown below. Its from_csv() class method expects a data_path to a CSV file with sequences and labels columns.
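import numpy as np
import pandas as pd
import torch
from torch import Tensor

from bionemo.esm2.data import tokenizer
from bionemo.esm2.model.finetune.dataset import InMemoryProteinDataset


class InMemorySingleValueDataset(InMemoryProteinDataset):
    def __init__(
        self,
        sequences: pd.Series,
        labels: pd.Series,
        task_type: str = "regression",
        tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
        seed: int = np.random.SeedSequence().entropy,
    ):
        super().__init__(sequences, labels, task_type, tokenizer, seed)

    def transform_label(self, label: float) -> Tensor:
        # regression target as a 1-element float tensor, shape [1]
        return torch.tensor([label], dtype=torch.float)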
The transform_label() method allows custom transformation of raw labels, such as casting or tokenization, and needs to be adapted to your data. Here we use it to create a float tensor holding the regression value.
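The from_csv() / transform_label() pattern can be mimicked with a self-contained torch Dataset. This is a hypothetical, simplified analogue (no tokenization, no seeding) intended only to show what a single-value item looks like; ToySingleValueDataset is not a BioNeMo class, and io.StringIO stands in for a CSV file path.

```python
import io

import pandas as pd
import torch
from torch.utils.data import Dataset


class ToySingleValueDataset(Dataset):
    """Hypothetical, simplified analogue of InMemorySingleValueDataset."""

    def __init__(self, sequences: pd.Series, labels: pd.Series):
        self.sequences = sequences.reset_index(drop=True)
        self.labels = labels.reset_index(drop=True)

    @classmethod
    def from_csv(cls, data_path, label_column: str = "labels"):
        df = pd.read_csv(data_path)
        return cls(df["sequences"], df[label_column])

    def transform_label(self, label: float) -> torch.Tensor:
        # Regression: wrap the scalar in a 1-element float tensor, shape [1]
        return torch.tensor([label], dtype=torch.float)

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> dict:
        return {"sequence": self.sequences[idx], "labels": self.transform_label(self.labels[idx])}


csv_text = "sequences,labels\nLYSGDHSTQG,0.10\nGRFNVWLGGNES,0.12\n"
ds = ToySingleValueDataset.from_csv(io.StringIO(csv_text))
item = ds[0]  # {'sequence': 'LYSGDHSTQG', 'labels': float32 tensor of shape [1]}
```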
DataModule¶
To coordinate the creation of training, validation and test datasets from your data, we use a datamodule class. We can directly use or extend the ESM2FineTuneDataModule class (located at bionemo.esm2.model.finetune.datamodule.ESM2FineTuneDataModule), which defines helpful abstract methods that use your dataset class.
from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule

dataset = InMemorySingleValueDataset.from_csv(data_path)
data_module = ESM2FineTuneDataModule(
    train_dataset=dataset,
    valid_dataset=dataset,
    micro_batch_size=4,  # size of a batch processed on each device
    # batch size across all devices; must be a multiple of micro_batch_size * number of data-parallel devices
    global_batch_size=8,
)
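With micro_batch_size=4 and global_batch_size=8 on a single GPU, each optimizer step accumulates gradients over two micro-batches. The sketch below assumes the usual Megatron-style relationship global_batch_size = micro_batch_size * data_parallel_size * accumulation_steps; the helper name grad_accumulation_steps is hypothetical.

```python
def grad_accumulation_steps(global_batch_size: int, micro_batch_size: int, data_parallel_size: int) -> int:
    """Micro-batches accumulated per optimizer step under Megatron-style batching."""
    per_step = micro_batch_size * data_parallel_size
    if global_batch_size % per_step != 0:
        raise ValueError(
            f"global_batch_size={global_batch_size} must be a multiple of "
            f"micro_batch_size * data_parallel_size = {per_step}"
        )
    return global_batch_size // per_step


print(grad_accumulation_steps(8, 4, 1))  # 2 (one GPU)
print(grad_accumulation_steps(8, 4, 2))  # 1 (two data-parallel GPUs)
```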
In the next part of this tutorial we will prepare the input needed to run sequence-level regression/classification and token-level classification fine-tuning examples.
Setup and Assumptions¶
All commands should be executed inside the BioNeMo docker container, which has all ESM-2 dependencies pre-installed. For more information on how to build or pull the BioNeMo2 container, refer to the Initialization Guide.
We use %%capture --no-display --no-stderr cell_output to suppress long cell output. Comment out or delete this line in the cells below to restore the full output.
Import Required Libraries¶
%%capture --no-display --no-stderr cell_output
import os
import shutil
import warnings
import pandas as pd
warnings.filterwarnings("ignore")
warnings.simplefilter("ignore")
Work Directory¶
Set the work directory to store data and results:
cleanup: bool = True
work_dir = "/workspace/bionemo2/esm2_finetune_tutorial"

if cleanup and os.path.exists(work_dir):
    shutil.rmtree(work_dir)

if not os.path.exists(work_dir):
    os.makedirs(work_dir)
    print(f"Directory '{work_dir}' created.")
else:
    print(f"Directory '{work_dir}' already exists.")
Directory '/workspace/bionemo2/esm2_finetune_tutorial' created.
Download Pre-trained Model Checkpoints¶
The following code downloads the internally pre-trained model, esm2/8m:2.0, from the NGC registry. Please refer to the ESM-2 Model Overview for a list of available checkpoints.
from bionemo.core.data.load import load
pretrain_checkpoint_path = load("esm2/8m:2.0")
print(pretrain_checkpoint_path)
/home/ubuntu/.cache/bionemo/2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16-esm2_hf_converted_8m_checkpoint.tar.gz.untar
The example above downloads an internally trained 8M ESM-2 model. Pre-trained checkpoints can be downloaded from NGC resources either with the load function in bionemo.core.data.load, as shown above, or with the following bash command:
download_bionemo_data esm2/650m:2.0
which returns the checkpoint path (e.g. .../.cache/bionemo/975d29ee980fcb08c97401bbdfdcf8ce-esm2_650M_nemo2.tar.gz.untar).
Fine-tuning¶
We can use the ESM-2 fine-tuning script in bionemo.esm2.scripts.finetune_esm2, or the finetune_esm2 executable, to run the fine-tuning process given:
- Pre-trained checkpoint of ESM2
- Finetune config class name that configures the finetune model and loss reduction
- Path to train and validation CSV data files
- Dataset class name
To get the full list of arguments for a fine-tuning run, use:
finetune_esm2 --help
For a detailed description of the training loop and its arguments, please refer to the ESM-2 Pretraining tutorial.
Scaled LR for fine-tune head parameters¶
We can assign a different learning rate (LR) to specific layers (e.g. the task head) during fine-tuning by specifying the name of the target layer and an LR multiplier:
- --lr-multiplier: a float that scales --lr
- --scale-lr-layer: the name of the layer(s) whose LR is scaled
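The effect of these two arguments can be pictured with ordinary PyTorch optimizer parameter groups. This is a generic illustration under assumptions (a toy ModuleDict model, substring matching on parameter names, AdamW); it is not how BioNeMo wires --lr-multiplier internally.

```python
import torch

# Hypothetical toy model with an encoder and a head named "regression_head".
model = torch.nn.ModuleDict({"encoder": torch.nn.Linear(8, 8), "regression_head": torch.nn.Linear(8, 1)})

base_lr = 1e-5  # plays the role of --lr
lr_multiplier = 1e2  # plays the role of --lr-multiplier
scale_lr_layer = "regression_head"  # plays the role of --scale-lr-layer

scaled = [p for n, p in model.named_parameters() if scale_lr_layer in n]
others = [p for n, p in model.named_parameters() if scale_lr_layer not in n]

optimizer = torch.optim.AdamW(
    [
        {"params": others, "lr": base_lr},
        {"params": scaled, "lr": base_lr * lr_multiplier},  # 1e-5 * 1e2 = 1e-3 for the head
    ]
)
print([group["lr"] for group in optimizer.param_groups])
```

With the values used in the commands below (--lr 1e-5, --lr-multiplier 1e2), the head would train at roughly 1e-3 while any unfrozen encoder parameters stay at 1e-5.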
Due to Megatron limitations, the training log reports steps/iterations rather than epochs. Therefore, the Training epoch counter stays at zero while iteration and global_step increase during training, as in the following example:
Training epoch 0, iteration| ... | global_step: | reduced_train_loss: ... | val_loss: ...
To achieve the equivalent of epoch-based training, choose the number of training steps (num_steps) so that:
num_steps * global_batch_size = len(dataset) * desired_num_epochs
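The relation above can be turned into a small helper that rounds up when the product is not divisible. The function name is hypothetical. The example assumes a global batch size of 4 (micro-batch size 4 on one GPU), which is consistent with the consumed_samples=200.0 reached after 50 steps in the checkpoint names later in this tutorial; with the 10-sequence demo dataset, 50 steps then correspond to 20 epochs.

```python
def num_steps_for_epochs(dataset_len: int, global_batch_size: int, num_epochs: int) -> int:
    """Smallest num_steps with num_steps * global_batch_size >= dataset_len * num_epochs."""
    total_samples = dataset_len * num_epochs
    return (total_samples + global_batch_size - 1) // global_batch_size  # integer ceiling


print(num_steps_for_epochs(10, 4, 20))  # 50
```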
Sequence-level Regression¶
For the purposes of this demo, we assume the dataset consists of a small set of protein sequences, each labeled with the target value len(sequence) / 100.0.
artificial_sequence_data = [
"TLILGWSDKLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
"LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
"GRFNVWLGGNESKIRQVLKAVKEIGVSPTLFAVYEKN",
"DELTALGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
"KLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
"LFGAIGNAISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
"LGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
"LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
"ISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
"SGSKASSDSQDANQCCTSCEDNAPATSYCVECSEPLCETCVEAHQRVKYTKDHTVRSTGPAKT",
]
regression_data = [(seq, len(seq) / 100.0) for seq in artificial_sequence_data]
# Create a DataFrame
df = pd.DataFrame(regression_data, columns=["sequences", "labels"])
# Save the DataFrame to a CSV file
regression_data_path = os.path.join(work_dir, "regression_data.csv")
df.to_csv(regression_data_path, index=False)
We will use the sequence-level fine-tuning config ESM2FineTuneSeqConfig and the single-value dataset InMemorySingleValueDataset, and set the task type to regression. In addition to the model and dataset configuration, the following CLI call defines the MLP task head by specifying its hidden layer size (mlp-hidden-size), output layer size (mlp-target-size) and dropout (mlp-ft-dropout).
%%capture --no-display cell_output
! finetune_esm2 \
--restore-from-checkpoint-path {pretrain_checkpoint_path} \
--train-data-path {regression_data_path} \
--valid-data-path {regression_data_path} \
--config-class ESM2FineTuneSeqConfig \
--dataset-class InMemorySingleValueDataset \
--task-type "regression" \
--mlp-ft-dropout 0.25 \
--mlp-hidden-size 256 \
--mlp-target-size 1 \
--experiment-name "sequence-level-regression" \
--num-steps 50 \
--num-gpus 1 \
--limit-val-batches 10 \
--val-check-interval 10 \
--log-every-n-steps 10 \
--encoder-frozen \
--lr 1e-5 \
--lr-multiplier 1e2 \
--scale-lr-layer "regression_head" \
--result-dir {work_dir} \
--micro-batch-size 4 \
--label-column "labels" \
--precision "bf16-mixed"
The previous cell runs the fine-tuning and saves the checkpoints at the end of the run.
To avoid long text output, the log of the previous cell is captured and stored in the cell_output variable. To view the log, uncomment and execute the next cell:
# print(cell_output)
# check that checkpoint was correctly created for step 50
regression_output_ckpt = f"{work_dir}/sequence-level-regression/checkpoints/checkpoint-step=49-consumed_samples=200.0"
! ls -l {regression_output_ckpt}
total 8
drwxr-xr-x 2 ubuntu ubuntu 4096 Jun 26 22:31 context
drwxr-xr-x 2 ubuntu ubuntu 4096 Jun 26 22:31 weights
We can now run inference with the checkpoint saved in the previous step by passing it to the --checkpoint-path argument of the infer_esm2 executable.
The input --data-path for inference is a CSV file with a sequences column. You must also provide the appropriate --config-class name to load the model from the checkpoint. For a detailed description of inference arguments, please refer to the ESM-2 Inference tutorial.
# Create a DataFrame for testing
df = pd.DataFrame(artificial_sequence_data, columns=["sequences"])
# Save the DataFrame to a CSV file
test_regression_data_path = os.path.join(work_dir, "test_regression.csv")
df.to_csv(test_regression_data_path, index=False)
regression_results_path = f"{work_dir}/sequence-level-regression/infer/"
%%capture --no-display cell_output
! infer_esm2 --checkpoint-path {regression_output_ckpt} \
--config-class ESM2FineTuneSeqConfig \
--data-path {test_regression_data_path} \
--results-path {regression_results_path} \
--micro-batch-size 3 \
--num-gpus 1 \
--precision "bf16-mixed" \
--include-embeddings \
--include-input-ids
The inference results are written to a .pt file, which can be loaded with PyTorch:
import numpy as np
import torch

results = torch.load(f"{regression_results_path}/predictions__rank_0.pt")

for key, val in results.items():
    if val is not None:
        print(f"{key}\t{val.shape}")

# flatten [b, 1] -> [b] so the subtraction below is element-wise instead of broadcasting to [b, b]
regression_prediction = results["regression_output"].detach().cpu().float().numpy().reshape(-1)
regression_target = np.asarray([x for _, x in regression_data])
test_loss = np.mean((regression_target - regression_prediction) ** 2)
assert test_loss < 0.025

print(f"test loss is {test_loss}")
input_ids	torch.Size([10, 1024])
embeddings	torch.Size([10, 320])
regression_output	torch.Size([10, 1])
test loss is 0.019693604125976566
Sequence-level Classification¶
Similarly, for a sequence-level classification task, we can create a dataset by labeling our sequences with arbitrary class names and use Label2IDTokenizer in the transform_label() method of InMemorySingleValueDataset to map the class names to IDs.
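Conceptually, the label tokenizer maintains a mapping between class names and integer IDs. A hypothetical pure-Python sketch of that idea follows; ToyLabelVocab is not the Label2IDTokenizer API, and assigning IDs in order of first appearance is an assumption of this sketch.

```python
import torch


class ToyLabelVocab:
    """Hypothetical minimal label <-> id mapping (not the Label2IDTokenizer API)."""

    def __init__(self, labels):
        # ids assigned in order of first appearance (dict.fromkeys keeps insertion order)
        self.vocab = {label: i for i, label in enumerate(dict.fromkeys(labels))}
        self.inverse = {i: label for label, i in self.vocab.items()}

    def transform_label(self, label: str) -> torch.Tensor:
        # classification target: a scalar integer class id
        return torch.tensor(self.vocab[label], dtype=torch.long)

    def ids_to_labels(self, ids):
        return [self.inverse[i] for i in ids]


vocab = ToyLabelVocab(["E_class", "C_class", "H_class", "H_class", "C_class"])
print(vocab.vocab)  # {'E_class': 0, 'C_class': 1, 'H_class': 2}
```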
class_labels = [
"E_class",
"C_class",
"H_class",
"H_class",
"C_class",
"H_class",
"H_class",
"C_class",
"H_class",
"C_class",
]
sequence_classification_data = [(seq, label) for seq, label in zip(artificial_sequence_data, class_labels)]
# Create a DataFrame
df = pd.DataFrame(sequence_classification_data, columns=["sequences", "labels"])
# Save the DataFrame to a CSV file
sequence_classification_data_path = os.path.join(work_dir, "sequence_classification_data.csv")
df.to_csv(sequence_classification_data_path, index=False)
Since this task is also sequence-level fine-tuning, we use ESM2FineTuneSeqConfig and the single-value dataset InMemorySingleValueDataset, but set the task type to classification. For classification, the MLP output layer size (mlp-target-size) should be set to the number of classes in the dataset (3 in this example).
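Rather than counting classes by hand, the value for --mlp-target-size can be read off the labels column. A minimal sketch with pandas, using the same ten class labels as the dataset above:

```python
import pandas as pd

labels = pd.Series(
    ["E_class", "C_class", "H_class", "H_class", "C_class",
     "H_class", "H_class", "C_class", "H_class", "C_class"]
)
mlp_target_size = labels.nunique()  # number of distinct classes
print(mlp_target_size)  # 3
```

In the notebook itself, the same value could be obtained from the saved CSV, e.g. pd.read_csv(sequence_classification_data_path)["labels"].nunique().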
%%capture --no-display cell_output
! finetune_esm2 \
--restore-from-checkpoint-path {pretrain_checkpoint_path} \
--train-data-path {sequence_classification_data_path} \
--valid-data-path {sequence_classification_data_path} \
--config-class ESM2FineTuneSeqConfig \
--dataset-class InMemorySingleValueDataset \
--task-type "classification" \
--mlp-ft-dropout 0.25 \
--mlp-hidden-size 256 \
--mlp-target-size 3 \
--experiment-name "sequence-level-classification" \
--num-steps 50 \
--num-gpus 1 \
--val-check-interval 10 \
--log-every-n-steps 10 \
--limit-val-batches 10 \
--encoder-frozen \
--lr 1e-5 \
--lr-multiplier 1e2 \
--scale-lr-layer "classification_head" \
--result-dir {work_dir} \
--label-column "labels" \
--micro-batch-size 4 \
--precision "bf16-mixed"
# print(cell_output)
# check that checkpoint was correctly created for step 50
sequence_classification_output_ckpt = (
f"{work_dir}/sequence-level-classification/checkpoints/checkpoint-step=49-consumed_samples=200.0"
)
! ls -l {sequence_classification_output_ckpt}
total 8
drwxr-xr-x 2 ubuntu ubuntu 4096 Jun 26 22:32 context
drwxr-xr-x 2 ubuntu ubuntu 4096 Jun 26 22:32 weights
# Create a test DataFrame
df = pd.DataFrame(artificial_sequence_data, columns=["sequences"])
# Save the DataFrame to a CSV file
test_sequence_classification_data_path = os.path.join(work_dir, "test_sequence_classification.csv")
df.to_csv(test_sequence_classification_data_path, index=False)
sequence_classification_results_path = f"{work_dir}/sequence-level-classification/infer/"
%%capture --no-display cell_output
! infer_esm2 --checkpoint-path {sequence_classification_output_ckpt} \
--config-class ESM2FineTuneSeqConfig \
--data-path {test_sequence_classification_data_path} \
--results-path {sequence_classification_results_path} \
--micro-batch-size 3 \
--num-gpus 1 \
--precision "bf16-mixed" \
--include-embeddings \
--include-input-ids
import torch
results = torch.load(f"{sequence_classification_results_path}/predictions__rank_0.pt")
for key, val in results.items():
    if val is not None:
        print(f"{key}\t{val.shape}")
sequence_classification_prediction = results["classification_output"].detach().cpu().float().numpy()
sequence_classification_target = np.asarray([list(dict.fromkeys(class_labels)).index(item) for item in class_labels])
predicted_classes = np.argmax(sequence_classification_prediction, axis=1)
test_acc = np.mean(predicted_classes == sequence_classification_target)
print(f"sequence_classification_target {sequence_classification_target}")
print(f"predicted_classes {predicted_classes}")
assert test_acc > 0.55
print(f"test acc is {test_acc}")
input_ids	torch.Size([10, 1024])
embeddings	torch.Size([10, 320])
classification_output	torch.Size([10, 3])
sequence_classification_target [0 1 2 2 1 2 2 1 2 1]
predicted_classes [0 0 2 2 0 2 2 0 2 0]
test acc is 0.6
Token-level Classification¶
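For this task, we assign a secondary structure label (H: helix, E: strand, C: coil) to each token in the sequence, then fine-tune with the token-level config ESM2FineTuneTokenConfig and the per-token dataset InMemoryPerTokenValueDataset. The task head is a CNN configured with the --cnn-hidden-size, --cnn-dropout and --cnn-num-classes arguments (3 classes here):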
secondary_structure_labels = [
"EEEECCCCCHHHHHHHHHHHHHHHCCCEEEEEECCCHHHHHHHHHCCCCCCCCCEEE",
"CCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC",
"HHHHHCCCCCHHHHHHHHHHHHHHCCCHHHHHHHHHH",
"HHHHHHHHHHCCCHHHHHCCCCCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC",
"CHHHHHHHHHHHHHHHCCCEEEEEECCCHHHHHHHHHCCCCCCCCCEEE",
"HHHHHHHHHHHHHCHHHHHHHHHHHHCCCEECCCEEEECCEEEEECC",
"HHHHHCCCHHHHHCCCCCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC",
"CCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC",
"HHHHHCHHHHHHHHHHHHCCCEECCCEEEECCEEEEECC",
"CCCCCCCCCCCCCCCCCCCCCCCCCCEEECCCCEEECHHHHHHHHHCCCCCCCCEEECCCCCC",
]
token_classification_data = [
(seq, label) for (seq, label) in zip(artificial_sequence_data, secondary_structure_labels)
]
# Create a DataFrame
df = pd.DataFrame(token_classification_data, columns=["sequences", "labels"])
# Save the DataFrame to a CSV file
token_classification_data_path = os.path.join(work_dir, "token_classification_data.csv")
df.to_csv(token_classification_data_path, index=False)
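Because each character of a labels string is the target for one residue, a length mismatch between a sequence and its label string would misalign the per-token targets. A hypothetical pre-flight check (not part of BioNeMo) that could be run on artificial_sequence_data and secondary_structure_labels:

```python
def check_per_token_labels(sequences, labels, allowed="HEC"):
    """Raise ValueError if a label string does not match its sequence residue-for-residue."""
    if len(sequences) != len(labels):
        raise ValueError(f"{len(sequences)} sequences but {len(labels)} label strings")
    for i, (seq, lab) in enumerate(zip(sequences, labels)):
        if len(seq) != len(lab):
            raise ValueError(f"row {i}: sequence length {len(seq)} != label length {len(lab)}")
        unexpected = set(lab) - set(allowed)
        if unexpected:
            raise ValueError(f"row {i}: unexpected label characters {sorted(unexpected)}")


check_per_token_labels(["LYSG", "GRF"], ["CCHH", "EEC"])  # passes silently
```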
%%capture --no-display cell_output
! finetune_esm2 \
--restore-from-checkpoint-path {pretrain_checkpoint_path} \
--train-data-path {token_classification_data_path} \
--valid-data-path {token_classification_data_path} \
--config-class ESM2FineTuneTokenConfig \
--dataset-class InMemoryPerTokenValueDataset \
--task-type "classification" \
--cnn-dropout 0.25 \
--cnn-hidden-size 32 \
--cnn-num-classes 3 \
--experiment-name "token-level-classification" \
--num-steps 50 \
--num-gpus 1 \
--val-check-interval 10 \
--log-every-n-steps 10 \
--label-column "labels" \
--limit-val-batches 10 \
--encoder-frozen \
--lr 1e-5 \
--lr-multiplier 1e2 \
--scale-lr-layer "classification_head" \
--result-dir {work_dir} \
--micro-batch-size 4 \
--precision "bf16-mixed"
# print(cell_output)
# check that checkpoint was correctly created for step 50
token_classification_output_ckpt = (
f"{work_dir}/token-level-classification/checkpoints/checkpoint-step=49-consumed_samples=200.0"
)
! ls -l {token_classification_output_ckpt}
total 8
drwxr-xr-x 2 ubuntu ubuntu 4096 Jun 26 22:33 context
drwxr-xr-x 2 ubuntu ubuntu 4096 Jun 26 22:33 weights
We can now run inference with the checkpoint saved in the previous step by passing it to the --checkpoint-path argument of the infer_esm2 executable.
The input --data-path for inference is a CSV file with a sequences column. You must also provide the appropriate --config-class name to load the model from the checkpoint. For a detailed description of inference arguments, please refer to the ESM-2 Inference tutorial.
# Create a DataFrame
df = pd.DataFrame(artificial_sequence_data, columns=["sequences"])
# Save the DataFrame to a CSV file
test_token_classification_data_path = os.path.join(work_dir, "test_token_classification.csv")
df.to_csv(test_token_classification_data_path, index=False)
token_classification_results_path = f"{work_dir}/token-level-classification/infer/"
%%capture --no-display --no-stderr cell_output
! infer_esm2 --checkpoint-path {token_classification_output_ckpt} \
--config-class ESM2FineTuneTokenConfig \
--data-path {test_token_classification_data_path} \
--results-path {token_classification_results_path} \
--micro-batch-size 3 \
--num-gpus 1 \
--precision "bf16-mixed" \
--include-embeddings \
--include-hiddens \
--include-input-ids
# print(cell_output)
The inference results are written to a .pt file, which can be loaded with PyTorch:
import torch
results = torch.load(f"{token_classification_results_path}/predictions__rank_0.pt")
for key, val in results.items():
    if val is not None:
        print(f"{key}\t{val.shape}")
hidden_states	torch.Size([10, 1024, 320])
input_ids	torch.Size([10, 1024])
embeddings	torch.Size([10, 320])
classification_output	torch.Size([10, 1024, 3])
We can use the label tokenizer to convert the classification output to class names. Note that, for demonstration purposes, this example uses a small dataset of artificial sequences. You may experience over-fitting and observe no change in the validation metrics; this amount of data and such a short training run will not produce accurate predictions.
from bionemo.esm2.data.tokenizer import get_tokenizer
tokenizer = get_tokenizer()
tokens = tokenizer.all_tokens
aa_tokens = ["L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K", "Q", "N", "F", "Y", "M", "H", "W", "C"]
aa_indices = [i for i, token in enumerate(tokens) if token in aa_tokens]
extra_indices = [i for i, token in enumerate(tokens) if token not in aa_tokens]
input_ids = results["input_ids"] # b, s
# mask that is True at the 20 standard amino-acid tokens and False for all other tokens (special tokens, padding, ...)
mask = ~torch.isin(input_ids, torch.tensor(extra_indices))

from bionemo.llm.data.label2id_tokenizer import Label2IDTokenizer

label_tokenizer = Label2IDTokenizer()
label_tokenizer = label_tokenizer.build_vocab(pd.Series(secondary_structure_labels).sort_values(inplace=False).values)

output_ids = torch.argmax(results["classification_output"], dim=-1)

print("Predicted Secondary Structures:")
for i in range(output_ids.shape[0]):
    ss_ids = output_ids[i][mask[i]]
    print(label_tokenizer.ids_to_text(ss_ids.tolist()))
Predicted Secondary Structures:
HEEECCCCCHHHHHHHHHHHHHHHCCCEEEEEECCCHHHHHHHHHCCCCCCCCCEEE
CCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC
CHHHHCCCCCHHHHHHHHHHHHHHCCCHHHHHHHHHC
HHHHHHHHHHCCCHHHHHCCCCCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC
CHHHHHHHHHHHHHHHCCCEEEEEECCCHHHHHHHHHCCCCCCCCCEEE
HHHHHHHHHHHHHCHHHHHHHHHHHHCCCEECCCEEEECCEEEEECC
HHHHHCCCHHHHHCCCCCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC
CCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC
HHHHHCHHHHHHHHHHHHCCCEECCCEEEECCEEEEECC
CCCCCCCCCCCCCCCCCCCCCCCCCCEEECCCCEEECHHHHHHHHHCCCCCCCCEECCCCCCC
Fine-tuning with LoRA¶
Fine-tuning with LoRA is supported. In this regime, the encoder and embedding layers are frozen and LoRA weights are added to them, while the classification and regression heads remain trainable. LoRA fine-tuning is supported for all of the task types above. The weights written to the results directory contain only the LoRA weights and the classification or regression head, so both the original pre-trained weights and the fine-tuned weights are needed for further inference and training.
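The LoRA technique itself can be sketched in a few lines of PyTorch. This is the generic formulation (frozen base weight W plus a trainable low-rank update scaled by alpha / r, with B initialized to zero), not the NeMo/BioNeMo implementation behind --lora-finetune; the rank r=4 and alpha=8.0 are arbitrary choices for the example.

```python
import torch


class LoRALinear(torch.nn.Module):
    """Generic LoRA wrapper: y = W x + (alpha / r) * B(A(x)), with W frozen."""

    def __init__(self, base: torch.nn.Linear, r: int = 4, alpha: float = 8.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # pre-trained weights stay frozen
        self.lora_a = torch.nn.Linear(base.in_features, r, bias=False)
        self.lora_b = torch.nn.Linear(r, base.out_features, bias=False)
        torch.nn.init.zeros_(self.lora_b.weight)  # B = 0, so the wrapper starts identical to base
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))


torch.manual_seed(0)
base = torch.nn.Linear(16, 16)
layer = LoRALinear(base, r=4)
x = torch.randn(2, 16)
print(torch.allclose(layer(x), base(x)))  # True at initialization
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)  # 128 = 4*16 (A) + 16*4 (B), versus 272 frozen base parameters
```

Since only A and B (plus the task head) change during training, saving just those small tensors is enough, which is why the original pre-trained checkpoint is still required at inference time.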
%%capture --no-display --no-stderr cell_output
data_path = os.path.join(work_dir, "token_classification_data.csv")
! finetune_esm2 \
--restore-from-checkpoint-path {pretrain_checkpoint_path} \
--train-data-path {data_path} \
--valid-data-path {data_path} \
--config-class ESM2FineTuneTokenConfig \
--dataset-class InMemoryPerTokenValueDataset \
--task-type "classification" \
--cnn-dropout 0.25 \
--cnn-hidden-size 32 \
--cnn-num-classes 3 \
--experiment-name "lora-token-level-classification" \
--num-steps 50 \
--num-gpus 1 \
--val-check-interval 10 \
--log-every-n-steps 10 \
--encoder-frozen \
--lr 5e-3 \
--lr-multiplier 1e2 \
--scale-lr-layer "classification_head" \
--result-dir {work_dir} \
--micro-batch-size 2 \
--precision "bf16-mixed" \
--lora-finetune
lora_checkpoint_path = (
f"{work_dir}/lora-token-level-classification/checkpoints/checkpoint-step=49-consumed_samples=100.0-last/weights"
)
results_path = f"{work_dir}/lora-token-level-classification/infer/"
print(results_path)
/workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/infer/
%%capture --no-display --no-stderr cell_output
data_path = test_token_classification_data_path  # CSV with a "sequences" column, created for token-level inference above
! infer_esm2 --checkpoint-path {pretrain_checkpoint_path} \
--config-class ESM2FineTuneTokenConfig \
--data-path {data_path} \
--results-path {results_path} \
--micro-batch-size 3 \
--num-gpus 1 \
--precision "bf16-mixed" \
--include-embeddings \
--include-hiddens \
--include-input-ids \
--lora-checkpoint-path {lora_checkpoint_path}