Encoder Fine-tuning#

This notebook serves as a demo for implementing our EncoderFineTuning class from scratch, hooking up the data, setting up the configs, and creating a fine-tuning script.

Overview#

The task head plays a crucial role in fine-tuning for a downstream task. In transfer learning, a pre-trained model learns generic features from a large-scale dataset, but these features may not be directly applicable to the specific task at hand. Adding an MLP task head, which consists of one or more fully connected layers, lets the model adapt to the target task: the head learns task-specific representations on top of the pre-trained features. In this way it acts as the bridge between the pre-trained model and the downstream task.
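As a rough illustration of the idea, the sketch below attaches a small MLP regression head to a frozen stand-in encoder using plain PyTorch. The `encoder` here is a hypothetical placeholder (a single linear layer), and the layer sizes are arbitrary; it is not the BioNeMo implementation, only a minimal picture of "frozen features in, task-specific prediction out".

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a pretrained encoder; a real one would be restored from a checkpoint.
encoder = nn.Linear(16, 32)
for param in encoder.parameters():
    param.requires_grad = False  # freeze the encoder so only the task head is trained

# MLP task head: one hidden layer followed by a single regression output.
task_head = nn.Sequential(
    nn.Linear(32, 128),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(128, 1),
    nn.Flatten(start_dim=0),  # (batch, 1) -> (batch,)
)

inputs = torch.randn(4, 16)   # toy batch of 4 already-featurized inputs
targets = torch.randn(4)      # toy regression targets

with torch.no_grad():
    embeddings = encoder(inputs)  # frozen, generic features

predictions = task_head(embeddings)
loss = nn.MSELoss()(predictions, targets)
loss.backward()  # gradients flow into the task head only

# Names of the task head parameters that received a gradient
trainable = [name for name, p in task_head.named_parameters() if p.grad is not None]
```

Only the head's weights receive gradients here, which is the situation the `encoder_frozen` setting later in this tutorial is meant to control.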

To help you create your own task/prediction head, we provide the EncoderFineTuning abstract base class, which lets you quickly implement a feed-forward network for training on a downstream task.

Setup and Assumptions#

This tutorial assumes that a copy of the BioNeMo framework repo exists on a workstation or server and has been mounted inside the container at /workspace/bionemo, as described in the Code Development section of the Quickstart Guide. This path is referred to as BIONEMO_WORKSPACE in the tutorial.

All commands should be executed inside the BioNeMo docker container.

Create or place the code below in the $BIONEMO_WORKSPACE/example/<molecule_or_protein>/<model_name>/ folder and run it from there, adjusting the path to your use case.

Getting Started#

In this notebook we will implement our own downstream task model based on EncoderFineTuning, using an MLP regressor prediction head.

To successfully accomplish this we need to define some key classes/files:

  • Custom dataset class - defines functions to process our dataset and prepare batches

  • BioNeMo data module class - performs additional data-driven functions such as creation of train/val/test datasets

  • Downstream Task Model class - extends the BioNeMo EncoderFineTuning class, which provides abstract methods that help you define your prediction head architecture, loss function, and the pretrained model encoder that you want to fine-tune

  • Config yaml file - to specify model parameters and control behavior of model at runtime

  • Training script - launches model training of our downstream task model

Data Setup#

Dataset class#

Create a dataset class by extending from torch.utils.data.Dataset or from BioNeMo’s dataset classes found in bionemo.data.datasets.

For the purposes of this demo, we’ll assume we are using the FreeSolv dataset from MoleculeNet to train our prediction head, and our downstream task will be to predict the hydration free energy of small molecules in water. The custom BioNeMo dataset class SingleValueDataset (found in bionemo.data.datasets.single_value_dataset.SingleValueDataset) is therefore appropriate, as it facilitates predicting a single value per input.

An excerpt from the class is shown below:

from torch.utils.data import Dataset

class SingleValueDataset(Dataset):
    def __init__(self, datafiles, max_seq_length, emb_batch_size=None, 
                 model=None, input_column: str='SMILES', 
                 target_column: str='y', task: str="regression"):
        ...  # body omitted in this excerpt

SingleValueDataset accepts the path to the data, the column name of the input, the column name of the target values and other parameters. To customize it for your data, extend the SingleValueDataset class in a similar way.
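To make the "one input column, one target column" idea concrete, here is a minimal, standard-library-only sketch of a map-style dataset. The class name `ToySingleValueDataset`, the returned dictionary keys and the CSV rows are all hypothetical and exist only for illustration; the real SingleValueDataset does more (for example, it can compute embeddings with the encoder model), and a real implementation would subclass torch.utils.data.Dataset.

```python
import csv
import io


class ToySingleValueDataset:
    """Hypothetical stand-in: one input string and one numeric target per row."""

    def __init__(self, csv_file, input_column="SMILES", target_column="y"):
        reader = csv.DictReader(csv_file)
        # Keep only the two columns we care about, converting the target to float
        self.samples = [
            (row[input_column], float(row[target_column])) for row in reader
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sequence, target = self.samples[idx]
        return {"sequence": sequence, "target": target}


# Made-up rows for illustration only (not real FreeSolv measurements)
toy_csv = io.StringIO("smiles,expt\nCCO,-1.0\nc1ccccc1,-2.5\n")
dataset = ToySingleValueDataset(toy_csv, input_column="smiles", target_column="expt")
first = dataset[0]
```

Implementing `__len__` and `__getitem__` is what allows a PyTorch DataLoader to index and batch a map-style dataset.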

Data module#

To coordinate the creation of training, validation and testing datasets from your data, we need to use a BioNeMo data module class. To do this we simply extend the BioNeMoDataModule class (located at bionemo.core.BioNeMoDataModule) which defines helpful abstract methods that use your dataset class. At minimum, we need to define our __init__(), train_dataset(), val_dataset(), test_dataset() when extending BioNeMoDataModule.

We have already done this and created the SingleValueDataModule (located at bionemo.data.datasets.single_value_dataset.SingleValueDataModule) for use with the SingleValueDataset. Note the use of our SingleValueDataset class inside the _create_dataset() function.

class SingleValueDataModule(BioNeMoDataModule):
    def __init__(self, cfg, trainer, model):
        super().__init__(cfg, trainer)
        self.model = model
        self.parent_cfg = cfg
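        if self.cfg.task_type not in ["regression", "classification"]:
            # adjacent string literals are joined before .format() is applied,
            # so the placeholder in the first literal is filled in correctly
            raise ValueError(
                "Invalid task_type was provided: {}. "
                "Supported task_type: 'classification' and 'regression'".format(self.cfg.task_type))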
        if self.cfg.task_type == "classification":
            self.tokenizer = Label2IDTokenizer()
        else:
            self.tokenizer = None

    def _update_tokenizer(self, tokenizer, labels):
        tokenizer = tokenizer.build_vocab(labels)
        return tokenizer

    # helper function for creating Datasets
    def _create_dataset(self, split, files):
        datafiles = os.path.join(self.cfg.dataset_path, 
                                 split, 
                                 files)
        datafiles = expand_dataset_paths(datafiles, ".csv")
        dataset = SingleValueDataset(
            datafiles=datafiles, 
            max_seq_length=self.parent_cfg.seq_length,
            emb_batch_size=self.cfg.emb_batch_size,
            model=self.model, 
            input_column=self.cfg.sequence_column, 
            target_column=self.cfg.target_column
            )
        if self.tokenizer is not None:
            self.tokenizer = self._update_tokenizer(
                self.tokenizer, 
                dataset.labels.reshape(-1, 1)
                )
            dataset.labels = get_data._tokenize_labels([self.tokenizer], dataset.labels.reshape(1, 1, -1), [self.cfg.num_classes])[0][0] 
        return dataset

    # uses our _create_dataset function to instantiate a training dataset
    def train_dataset(self):
        """Creates a training dataset
        Returns:
            Dataset: dataset to use for training
        """
        self.train_ds = self._create_dataset("train", 
                                             self.cfg.dataset.train)
        return self.train_ds

    def val_dataset(self):
        """Creates a validation dataset
        Returns:
            Dataset: dataset to use for validation
        """
        if "val" in self.cfg.dataset:
            self.val_ds = self._create_dataset("val", 
                                             self.cfg.dataset.val)
            return self.val_ds
        else:
            pass

    def test_dataset(self):
        """Creates a testing dataset
        Returns:
            Dataset: dataset to use for testing
        """
        if "test" in self.cfg.dataset:
            self.test_ds = self._create_dataset("test", 
                                                self.cfg.dataset.test)
            return self.test_ds
        else:
            pass
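The path handling in _create_dataset() can be sketched in isolation. The snippet below only reproduces the os.path.join step, using the dataset_path and split file names that appear in the example config later in this tutorial; the subsequent call to expand_dataset_paths with the ".csv" extension is not reproduced here, and the plain dictionary is a hypothetical stand-in for the config object.

```python
import os

# Values taken from the example config in this tutorial
dataset_path = "/data/physchem/SAMPL"
dataset_cfg = {"train": "x000", "val": "x000", "test": "x000"}  # stand-in for cfg.dataset

# Same join as in _create_dataset(): <dataset_path>/<split>/<files>
split_paths = {
    split: os.path.join(dataset_path, split, files)
    for split, files in dataset_cfg.items()
}

for split, path in split_paths.items():
    print(f"{split}: {path}")
```

Each split therefore lives in its own subfolder of dataset_path, which is why the same file name can be used for train, val and test.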

Setup Downstream Task Model Class#

Now that our dataset classes are created, we are ready to create the model class that defines the architecture needed to train on our downstream task. BioNeMo provides the EncoderFineTuning class, which lets us add a prediction head to a pretrained model by extending the class and overriding its abstract methods.

Let’s create a class, DownstreamTaskModel, based on EncoderFineTuning, where we will set up our task head and the encoder model. We will also use our MLPModel class; however, you can implement your own model to use with your class.

Note that we are required to implement the abstract methods within EncoderFineTuning.

import torch
import torch.nn as nn
import bionemo.utils
from functools import lru_cache
from nemo.utils.model_utils import import_class_by_path
from bionemo.model.core import MLPModel
from bionemo.model.core.encoder_finetuning import EncoderFineTuning

#import a BioNeMo data module or your custom data module
from bionemo.data.datasets.single_value_dataset import SingleValueDataModule

class DownstreamTaskModel(EncoderFineTuning):

    def __init__(self, cfg, trainer):
        super().__init__(cfg.model, trainer=trainer) 

        # store config parameters within object so they can be accessed easily
        self.full_cfg = cfg

        # we want our downstream model to behave differently based on whether the
        # encoder_frozen config parameter is set to True or False so we store it for 
        # convenient access within the object
        self.encoder_frozen = self.full_cfg.model.encoder_frozen
        self.batch_target_name = self.cfg.data.target_column

    def configure_optimizers(self):
        super().setup_optimization(optim_config=self.cfg.finetuning_optim)

        if self._scheduler is None:
            return self._optimizer
        else:
            return [self._optimizer], [self._scheduler]

    # use this function to define what the loss func of the task head should be
    def build_loss_fn(self):
        return bionemo.utils.lookup_or_use(torch.nn, self.cfg.downstream_task.loss_func)

    # define the architecture of our prediction task head for the downstream task
    def build_task_head(self):

        # we create an instance of MLPModel using parameters defined in the config file
        # choose the right task head architecture based on your downstream task (for example, regression vs classification)
        regressor = MLPModel(layer_sizes=[self.encoder_model.cfg.model.hidden_size, self.cfg.downstream_task.hidden_layer_size, self.cfg.downstream_task.n_outputs],
            dropout=0.1,
        )

        # we can use pytorch libraries to further define our architecture and tensor operations
        task_head = nn.Sequential(regressor, nn.Flatten(start_dim=0))
        return task_head

    # returns the model from which we will use the pretrained encoder
    def setup_encoder_model(self, cfg, trainer):
        infer_class = import_class_by_path(self.full_cfg.infer_target)
        pretrained_model = infer_class(
            self.full_cfg, 
            freeze=self.encoder_frozen, #determines whether encoders weights are trainable
            restore_path=self.full_cfg.restore_from_path,
            training=not self.cfg.encoder_frozen)
        return pretrained_model

    # use this function to define all your data operations
    # in this example, we use the config parameter to determine the value of our model variable
    # then we pass it into an instance of SingleValueDataModule()
    @lru_cache
    def data_setup(self):
        if self.encoder_frozen:
            model = self.encoder_model
        else:
            model = None
        self.data_module = SingleValueDataModule(
            self.cfg, self.trainer, model=model
        )

    # ensures that we create our necessary datasets 
    def on_fit_start(self):
        self.build_train_valid_test_datasets()
        return super().on_fit_start()

    # function that simply instantiates our datasets and stores them within our object
    def build_train_valid_test_datasets(self):
        self._train_ds = self.data_module.get_sampled_train_dataset()
        self._validation_ds = self.data_module.get_sampled_val_dataset()
        self._test_ds = self.data_module.get_sampled_test_dataset()

    # define the behavior for retrieving embeddings from encoder
    def encoder_forward(self, bart_model, batch: dict):
        if self.encoder_frozen:
            enc_output = batch["embeddings"]
        else:
            enc_output = bart_model.seq_to_embeddings(batch["embeddings"])
        return enc_output

    # define additional operations on the encoder output
    # in this example we simply convert the values of the tensor to float
    # see forward() in encoder_finetuning.py for additional information
    def extract_for_task_head(self, input_tensor):
        return input_tensor.float()
    
    def get_target_from_batch(self, batch):
        ret = batch['target']

        return ret.float()
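The requirement to implement every abstract method follows from how Python abstract base classes behave. The toy sketch below uses only the standard library; `ToyFineTuningBase` and its two methods are hypothetical and far simpler than EncoderFineTuning, but the mechanism is the same: a subclass that leaves an abstract method undefined cannot be instantiated.

```python
from abc import ABC, abstractmethod


class ToyFineTuningBase(ABC):
    """Toy base class (not BioNeMo's): wires subclass-provided parts together."""

    def __init__(self):
        self.loss_fn = self.build_loss_fn()
        self.task_head = self.build_task_head()

    @abstractmethod
    def build_loss_fn(self):
        ...

    @abstractmethod
    def build_task_head(self):
        ...

    def forward(self, embedding):
        return self.task_head(embedding)


class IncompleteModel(ToyFineTuningBase):
    # build_task_head is deliberately left unimplemented
    def build_loss_fn(self):
        return lambda pred, target: (pred - target) ** 2


class CompleteModel(IncompleteModel):
    def build_task_head(self):
        # trivial "head": mean of the embedding values
        return lambda embedding: sum(embedding) / len(embedding)


try:
    IncompleteModel()
    error_message = None
except TypeError as exc:
    error_message = str(exc)  # names the missing abstract method

model = CompleteModel()
prediction = model.forward([1.0, 2.0, 3.0])
loss = model.loss_fn(prediction, 1.5)
```

The error is raised at instantiation time, so a missing override is caught before any training starts.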

Config YAML#

Now that we have our DownstreamTaskModel defined, let’s create a config yaml file (downstream_task_example.yaml) that will define specific values of tunable hyperparameters, file paths and other important parameters needed by our model.

An example config file can be found in examples/molecule/megamolbart/conf/finetune_config.yaml.

Most importantly, our config file:

  • provides the path to our pretrained model using the ‘restore_from_path’ parameter

  • defines the model parameters, including the loss_func, hidden_layer_size and n_outputs used by our prediction head

  • sets important data-related parameters such as task_type, dataset_path, sequence_column and target_column

name: downstream_task_example
defaults: 
  - pretrain_small_span_aug
do_training: True # set to false if data preprocessing steps must be completed
do_testing: True # set to true to run evaluation on test data after training, requires test_dataset section
restore_from_path: /model/molecule/megamolbart/megamolbart.nemo
target: bionemo.model.molecule.megamolbart.MegaMolBARTModel
infer_target: bionemo.model.molecule.megamolbart.infer.MegaMolBARTInference

trainer:
  devices: 1 # number of GPUs or CPUs
  num_nodes: 1
  max_epochs: 100 # use max_steps instead with NeMo Megatron models
  max_steps: 10000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  val_check_interval: 8 # set to integer when using steps to determine frequency of validation, use fraction with epochs
  limit_val_batches: 20 # number of batches in validation step, use fraction for fraction of data, 0 to disable
  limit_test_batches: 100 # number of batches in test step, use fraction for fraction of data, 0 to disable

exp_manager:
  wandb_logger_kwargs:
    project: ${name}_finetuning
    name: ${name}_finetuning_encoder_frozen_${model.encoder_frozen}
  checkpoint_callback_params:
    monitor: val_loss # metric used to select best checkpoints
    mode: min # use min or max of monitored metric to select best checkpoints
    filename: '${name}-${model.name}--{val_loss:.2f}-{step}-{consumed_samples}'
  resume_if_exists: True

model:
  encoder_frozen: True
  post_process: False
  micro_batch_size: 32 # NOTE: adjust to occupy ~ 90% of GPU memory
  global_batch_size: null
  tensor_model_parallel_size: 1  # model parallelism
  
  downstream_task:
    n_outputs: 1
    hidden_layer_size: 128
    loss_func: MSELoss

  data:
    # Finetuning data params
    task_type: 'regression'
    dataset_path: /data/physchem/SAMPL
    sequence_column: 'smiles'
    target_column: 'expt'
    emb_batch_size: ${model.micro_batch_size}
    dataset:
      train: x000
      val: x000
      test: x000
    num_workers: 8
  
  finetuning_optim:
    name: adam
    lr: 0.001
    betas:
      - 0.9
      - 0.999
    eps: 1e-8
    weight_decay: 0.01
    sched:
      name: WarmupAnnealing
      min_lr: 0.00001
      last_epoch: -1
      warmup_steps: 100
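The comment next to max_steps gives a formula for consumed samples, which can be checked with a few lines of arithmetic. The sketch below plugs in max_steps and micro_batch_size from the config above; data_parallel_size=1 and accumulate_grad_batches=1 are assumptions (a single device with no gradient accumulation), since neither is set explicitly in this file.

```python
def consumed_samples(global_step, micro_batch_size,
                     data_parallel_size=1, accumulate_grad_batches=1):
    """consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches"""
    return (global_step * micro_batch_size
            * data_parallel_size * accumulate_grad_batches)


# Samples seen if training runs for the full max_steps: 10000 * 32 * 1 * 1
total = consumed_samples(global_step=10000, micro_batch_size=32)

# Samples seen between two validation runs (val_check_interval: 8)
per_validation = consumed_samples(global_step=8, micro_batch_size=32)

print(total, per_validation)
```

Under these assumptions the run consumes 320,000 samples in total and validates every 256 samples, which helps when choosing max_steps for a small dataset.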

Training Script#

Finally, we’ll need a training script to launch our model training:

from nemo.core.config import hydra_runner
from nemo.utils import logging
from omegaconf.omegaconf import OmegaConf
from bionemo.model.utils import (setup_trainer,)

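# import our model class; the module name is an assumption, adjust it to
# wherever you saved the DownstreamTaskModel class
from downstream_task_model import DownstreamTaskModel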

@hydra_runner(config_path="conf", config_name="downstream_task_example") 
def main(cfg) -> None:

    logging.info("\n\n************* Finetune config ****************")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    trainer = setup_trainer(
         cfg, builder=None)

    # we instantiate our model 
    model = DownstreamTaskModel(cfg, trainer)

    if cfg.do_training:
        logging.info("************** Starting Training ***********")
        trainer.fit(model) # train our downstream task model using the dataset defined in config
        logging.info("************** Finished Training ***********")

    if cfg.do_testing:
        if "test" in cfg.model.data.dataset:
            trainer.test(model)
        else:
            raise UserWarning("Test dataset file was not provided. Specify 'model.data.dataset.test' in the yaml config")

if __name__ == '__main__':
    main()

We can launch our training by simply calling:

python training_script.py

More examples of training models for downstream tasks in BioNeMo can be found in our physicochemical property prediction notebook.