Fine-tune pre-trained models in BioNeMo

This example covers the general fine-tuning capabilities of the BioNeMo framework and of the NeMo framework on which BioNeMo is based.

Transfer learning is an important machine learning technique that uses a model’s knowledge of one task to improve its performance on another. Fine-tuning is one technique for performing transfer learning. It is an essential part of the recipe behind many state-of-the-art results, in which a base model is first pre-trained on a task with abundant training data and then fine-tuned on other tasks of interest where training data is limited or even scarce.
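To make the idea concrete, here is a toy illustration in plain Python (not BioNeMo code): a linear model y ≈ w·x + b is "pre-trained" on plentiful data from one task and then fine-tuned for a few gradient steps on four points from a related task, and the result is compared with training from zero for the same number of steps. The tasks, learning rate, and step counts are arbitrary assumptions chosen only to show the effect of starting from pre-trained weights.

```python
from typing import List, Tuple

Point = Tuple[float, float]

def mse(w: float, b: float, data: List[Point]) -> float:
    """Mean squared error of y ≈ w * x + b on data."""
    return sum((w * x + b - y) ** 2 for x, y in data) / len(data)

def train(w: float, b: float, data: List[Point], steps: int, lr: float = 0.1) -> Tuple[float, float]:
    """Plain gradient descent on mean squared error, starting from (w, b)."""
    n = len(data)
    for _ in range(steps):
        grad_w = sum(2 * (w * x + b - y) * x for x, y in data) / n
        grad_b = sum(2 * (w * x + b - y) for x, y in data) / n
        w, b = w - lr * grad_w, b - lr * grad_b
    return w, b

# Pre-training task with abundant data: y = 2x + 1 (100 points).
source = [(i / 99, 2 * (i / 99) + 1) for i in range(100)]
# Related downstream task with scarce data: y = 2x + 1.2 (4 points).
target = [(x, 2 * x + 1.2) for x in (0.0, 1 / 3, 2 / 3, 1.0)]

pretrained = train(0.0, 0.0, source, steps=2000)
finetuned = train(*pretrained, target, steps=5)  # start from pre-trained weights
scratch = train(0.0, 0.0, target, steps=5)       # same budget, zero initialization

print(f"fine-tuned MSE:   {mse(*finetuned, target):.4f}")
print(f"from-scratch MSE: {mse(*scratch, target):.4f}")
```

With the same five steps, the fine-tuned model ends far closer to the downstream task because it starts near a good solution.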

Setup and Assumptions

This tutorial assumes that a copy of the BioNeMo framework repository exists on a workstation or server and has been mounted inside the container at /workspace/bionemo, as described in the Code Development section of the Quickstart Guide. This tutorial refers to that path as BIONEMO_WORKSPACE.

All commands should be executed inside the BioNeMo Docker container.

Fine-tune the MegaMolBART (MMB) model in BioNeMo

Fine-tuning Configuration

The BioNeMo framework supports fine-tuning by partially or fully loading pre-trained weights from a checkpoint into the currently instantiated model. The parameters of the instantiated model should match those in the pre-trained checkpoint so that the weights load correctly.
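Partial loading can be sketched in plain Python. Here a "state dict" is modeled as a dictionary mapping parameter names to flat lists of floats, and load_matching_weights is a hypothetical helper, not a BioNeMo or NeMo function: parameters whose names and sizes match the checkpoint are copied, and the rest keep their current values unless strict checking rejects them.

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for a state dict: parameter name -> flat list of floats.
StateDict = Dict[str, List[float]]

def load_matching_weights(model: StateDict, checkpoint: StateDict,
                          strict: bool = True) -> Tuple[List[str], List[str]]:
    """Copy checkpoint values into model wherever name and size match.

    Returns (missing, mismatched): model parameters absent from the checkpoint
    and parameters whose sizes differ. With strict=True either case is an error.
    """
    missing = [name for name in model if name not in checkpoint]
    mismatched = [name for name in model
                  if name in checkpoint and len(checkpoint[name]) != len(model[name])]
    if strict and (missing or mismatched):
        raise ValueError(f"missing={missing}, mismatched={mismatched}")
    for name, values in model.items():
        if name in checkpoint and len(checkpoint[name]) == len(values):
            model[name] = list(checkpoint[name])  # matching weights are loaded
    return missing, mismatched

# Usage: the encoder loads; the resized head keeps its fresh initialization.
model = {"encoder.w": [0.0, 0.0], "head.w": [0.0]}
ckpt = {"encoder.w": [0.5, -0.3], "head.w": [1.0, 2.0]}
print(load_matching_weights(model, ckpt, strict=False))
```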

Pre-trained weights are provided as a path to a .nemo model file through the restore_from_path parameter, which can be set by either:

  • adding restore_from_path to the config YAML file

  • passing restore_from_path as a command-line argument to your script
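BioNeMo training scripts take Hydra-style command-line overrides, which is why the commands in this tutorial mix plain key=value arguments with ++key=value ones. The sketch below is a simplified, hypothetical model of Hydra's override prefixes applied to a nested dict (it is not Hydra's implementation, and it keeps every value as a string): a plain key=value may only override a key that already exists, +key=value may only add a new key, and ++key=value adds or overrides.

```python
from typing import Any, Dict

def apply_override(cfg: Dict[str, Any], override: str) -> None:
    """Apply one Hydra-style override such as 'trainer.devices=8' to cfg in place."""
    if override.startswith("++"):
        mode, override = "force", override[2:]   # add or override
    elif override.startswith("+"):
        mode, override = "add", override[1:]     # key must not exist yet
    else:
        mode = "set"                             # key must already exist
    dotted_key, value = override.split("=", 1)
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for part in parents:
        if part not in node:
            if mode == "set":
                raise KeyError(f"{dotted_key} is not in the config")
            node[part] = {}
        node = node[part]
    if mode == "set" and leaf not in node:
        raise KeyError(f"{dotted_key} is not in the config; use + or ++ to add it")
    if mode == "add" and leaf in node:
        raise KeyError(f"{dotted_key} already exists; use ++ to override it")
    node[leaf] = value

# Usage with a hypothetical base config that already defines restore_from_path.
cfg = {"restore_from_path": None, "trainer": {"devices": 1}, "model": {"data": {}}}
for arg in ("restore_from_path=/models/pretrained.nemo",
            "trainer.devices=8",
            "++model.data.dataset.train=x000"):
    apply_override(cfg, arg)
print(cfg)
```

Under these rules, setting restore_from_path without a prefix only works if the config already declares that key; setting it in the YAML file and overriding it on the command line end up in the same place.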

# Dataset files: use x000 for a single file, or x_OP_000..186_CL_ for a range
python examples/molecule/megamolbart/pretrain.py \
    --config-path=<path to dir of configs> \
    --config-name=<name of config without .yaml> \
    do_training=True \
    ++model.data.dataset.train=<data files> \
    ++model.data.dataset.val=<data files> \
    ++model.data.dataset.test=<data files> \
    trainer.devices=$NGC_GPUS_PER_NODE \
    trainer.accelerator='gpu' \
    restore_from_path="<path to .nemo model file>"
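The dataset file notation used in this tutorial (x000 for a single file, x_OP_000..186_CL_ for a range) can be read as a bracketed numeric range. The sketch below expands it, assuming that _OP_ and _CL_ stand in for "[" and "]" so the range survives command-line parsing, and that the zero-padding width follows the start index; expand_dataset_range is a hypothetical helper, not a BioNeMo function.

```python
import re
from typing import List

def expand_dataset_range(spec: str) -> List[str]:
    """Expand 'x_OP_000..186_CL_' into ['x000', 'x001', ..., 'x186']."""
    match = re.fullmatch(r"(.*)_OP_(\d+)\.\.(\d+)_CL_(.*)", spec)
    if match is None:
        return [spec]  # a single file such as 'x000'
    prefix, start, end, suffix = match.groups()
    width = len(start)  # keep the zero padding of the start index
    return [f"{prefix}{i:0{width}d}{suffix}" for i in range(int(start), int(end) + 1)]

files = expand_dataset_range("x_OP_000..186_CL_")
print(len(files), files[0], files[-1])
```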

The same approach works for fine-tuning any other BioNeMo model: change the training script path to that of the model of interest. For example, to fine-tune the ESM-1nv model:

# Dataset files: use x000 for a single file, or x_OP_000..186_CL_ for a range
python examples/protein/esm1nv/pretrain.py \
    --config-path=<path to dir of configs> \
    --config-name=<name of config without .yaml> \
    do_training=True \
    ++model.data.dataset.train=<data files> \
    ++model.data.dataset.val=<data files> \
    ++model.data.dataset.test=<data files> \
    trainer.devices=$NGC_GPUS_PER_NODE \
    trainer.accelerator='gpu' \
    restore_from_path="<path to .nemo model file>"
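Because only the script path (and typically the config) changes between models, the commands above can be generated from one template. This is a minimal sketch; finetune_command and the example config name and file paths are hypothetical placeholders, and the resulting list could be launched inside the container with subprocess.run(cmd, check=True).

```python
import shlex
from typing import List

def finetune_command(script: str, config_path: str, config_name: str,
                     train: str, val: str, test: str, nemo_file: str,
                     devices: int = 1) -> List[str]:
    """Build the fine-tuning command for any BioNeMo pretrain.py script."""
    return [
        "python", script,
        f"--config-path={config_path}",
        f"--config-name={config_name}",
        "do_training=True",
        f"++model.data.dataset.train={train}",
        f"++model.data.dataset.val={val}",
        f"++model.data.dataset.test={test}",
        f"trainer.devices={devices}",
        "trainer.accelerator=gpu",
        f"restore_from_path={nemo_file}",
    ]

# Usage: swap the script path to target a different model.
cmd = finetune_command("examples/protein/esm1nv/pretrain.py", "conf", "my_finetune_config",
                       "x000", "x000", "x000", "/models/esm1nv.nemo", devices=8)
print(shlex.join(cmd))
```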

Note

The dataset used for fine-tuning must be in a format compatible with the datasets used to train the model, for example, SMILES for small molecules or FASTA for proteins. A mismatch in the expected dataset format can result in pickle errors such as the following:

_pickle.PicklingError: Can't pickle <class 'Boost.Python.ArgumentError'>: import of module 'Boost.Python' failed
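Since a format mismatch may only surface as an opaque error deep in data loading, a quick check before launching can save time. The sketch below is a rough FASTA sanity check in plain Python; parse_fasta and the accepted alphabet (the 20 standard amino acids plus B, O, U, X, Z) are assumptions rather than BioNeMo utilities, so adjust them to the tokenizer of the model you fine-tune.

```python
from typing import List

# Assumed alphabet: 20 standard amino acids plus common ambiguity/rare codes.
PROTEIN_ALPHABET = set("ACDEFGHIKLMNPQRSTVWYBOUXZ")

def parse_fasta(text: str) -> List[str]:
    """Return the sequences in a FASTA string; raise ValueError on malformed input."""
    sequences: List[str] = []
    current: List[str] = []
    seen_header = False
    for lineno, raw in enumerate(text.splitlines(), start=1):
        line = raw.strip()
        if not line:
            continue
        if line.startswith(">"):
            if current:
                sequences.append("".join(current))
                current = []
            seen_header = True
        elif not seen_header:
            raise ValueError(f"line {lineno}: sequence data before any '>' header")
        else:
            bad = set(line.upper()) - PROTEIN_ALPHABET
            if bad:
                raise ValueError(f"line {lineno}: unexpected characters {sorted(bad)}")
            current.append(line.upper())
    if current:
        sequences.append("".join(current))
    if not sequences:
        raise ValueError("no sequences found")
    return sequences

print(parse_fasta(">p1\nMKTAYIAK\nQRQISFVK\n>p2\nMSEQ\n"))
```

A SMILES file passed to this check fails immediately (no header, or characters such as "(" and "="), which is the kind of mismatch the note above warns about.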

Loading a pre-trained model

To load a checkpoint for fine-tuning within a BioNeMo training script, use both:

  • the restore_from() method from NeMo

  • BioNeMoSaveRestoreConnector

These instructions apply to loading fully trained checkpoints for fine-tuning. To resume an unfinished training experiment instead, use the Experiment Manager and set its resume_if_exists flag to True.
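The difference between resuming and fine-tuning can be summarized as a small decision rule. In this simplified, hypothetical sketch (choose_start is not a NeMo function, and the "*last.ckpt" file pattern is an assumption about how the last checkpoint is named), resuming continues an unfinished run from its last training checkpoint, which in PyTorch Lightning also restores optimizer state and the step counter, whereas fine-tuning starts a new run from the weights stored in a .nemo file.

```python
from pathlib import Path
from typing import Optional, Tuple

def choose_start(exp_dir: Path, resume_if_exists: bool,
                 restore_from_path: Optional[str]) -> Tuple[str, Optional[str]]:
    """Return ("resume" | "finetune" | "scratch", path or None)."""
    if resume_if_exists:
        last = sorted(exp_dir.rglob("*last.ckpt"))  # assumed checkpoint naming
        if last:
            return "resume", str(last[-1])  # continue the unfinished run
    if restore_from_path:
        return "finetune", restore_from_path  # new run from pre-trained weights
    return "scratch", None  # random initialization

print(choose_start(Path("results/my_experiment"), True, "/models/pretrained.nemo"))
```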

For more granular control over how a pre-trained model is restored, BioNeMo provides the BioNeMoSaveRestoreConnector. Based on NeMo's NLPSaveRestoreConnector, it allows the embedding matrix to change. Used together with NeMo's restore_from() method, it lets you set the vocabulary size when the model is loaded, if needed. An example can be found in the ProtT5-nv pretraining script (examples/protein/prott5nv/pretrain.py):

model = ProtT5nvModel.restore_from(
    cfg.restore_from_path, cfg.model, trainer=trainer,
    # 128 is the padded vocabulary size used by MegatronT5Model
    save_restore_connector=BioNeMoSaveRestoreConnector(vocab_size=128),
    # allow loading weights whose embeddings do not match (for example, alibi)
    strict=False,
)
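The vocab_size=128 argument in the example above refers to a padded vocabulary. Megatron-based models round the tokenizer's vocabulary size up to a multiple of make_vocab_size_divisible_by (commonly 128) times the tensor-parallel size, so a small tokenizer still yields a 128-row embedding table. The sketch below shows that arithmetic and one way to adapt a checkpoint's embedding table to a different vocabulary size; it illustrates the idea only and is not BioNeMoSaveRestoreConnector's actual code. resize_embedding and its initialization scale are assumptions.

```python
import math
import random
from typing import List

def padded_vocab_size(vocab_size: int, divisible_by: int = 128, tensor_parallel: int = 1) -> int:
    """Round vocab_size up to a multiple of divisible_by * tensor_parallel."""
    multiple = divisible_by * tensor_parallel
    return math.ceil(vocab_size / multiple) * multiple

def resize_embedding(old: List[List[float]], new_vocab: int,
                     seed: int = 0, std: float = 0.02) -> List[List[float]]:
    """Return an embedding table with new_vocab rows.

    Rows shared by both vocabularies are copied from the checkpoint, extra rows
    are drawn from N(0, std^2), and surplus checkpoint rows are dropped.
    """
    rng = random.Random(seed)
    dim = len(old[0])
    resized = [list(row) for row in old[:new_vocab]]
    while len(resized) < new_vocab:
        resized.append([rng.gauss(0.0, std) for _ in range(dim)])
    return resized

print(padded_vocab_size(30), padded_vocab_size(30, tensor_parallel=2))
```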