Important

NeMo 2.0 is an experimental feature and is currently released only in the dev container: nvcr.io/nvidia/nemo:dev. Please refer to the Migration Guide for information on getting started.

Parameter Efficient Fine-Tuning (PEFT)

NeMo 2.0 introduces a complete overhaul of Parameter Efficient Fine-Tuning (PEFT). The new design formulates PEFT as a Model Transform that freezes the base model and inserts trainable adapters at specific locations within the model.

The hierarchy of class objects is as follows:

  • LoRA → PEFT → ModelTransform → lightning.Callback

  • Other PEFT methods → PEFT

  • Other model customizations → ModelTransform

ModelTransform Class

The backbone of PEFT is the Model Transform mechanism, which is a PyTorch Lightning callback that mutates the model architecture at the start of fitting or validation.

The callback applies a transformation function to the model when fitting or validation begins, not when the model is first initialized. This ordering lets the original checkpoint be loaded into the unmodified architecture first, with the transformation applied afterward. PEFT relies on this, because its adapters are inserted around base weights that have already been loaded.

The transformation function is expected to be defined on the LightningModule as an attribute called model_transform.
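The ordering described above (load the checkpoint, then transform at fit or validation start) can be illustrated with a small, framework-free sketch. It assumes nothing about NeMo internals: `ModelTransformSketch`, `ToyModule`, and `add_adapters` are hypothetical names, and only the hook names `on_fit_start` and `on_validation_start` mirror Lightning's callback API. The run-once guard is also an assumption of this sketch.

```python
# Framework-free sketch of the ModelTransform idea. Hook names mirror
# Lightning's Callback.on_fit_start / on_validation_start, but every class
# here is a hypothetical stand-in, not NeMo's implementation.


class ModelTransformSketch:
    """Calls pl_module.model_transform(pl_module) when fitting or validation starts."""

    def __init__(self):
        self._applied = False  # assumption: guard so the transform runs only once

    def _apply(self, pl_module):
        transform = getattr(pl_module, "model_transform", None)
        if transform is not None and not self._applied:
            transform(pl_module)
            self._applied = True

    def on_fit_start(self, trainer, pl_module):
        self._apply(pl_module)

    def on_validation_start(self, trainer, pl_module):
        self._apply(pl_module)


class ToyModule:
    def __init__(self, model_transform=None):
        self.layers = {"linear_qkv": "base", "linear_proj": "base"}
        self.model_transform = model_transform  # attribute the callback looks up
        self.checkpoint_loaded = False

    def load_checkpoint(self):
        # The base checkpoint is loaded into the *original* architecture.
        self.checkpoint_loaded = True


def add_adapters(module):
    # Runs after the checkpoint load, so it wraps already-loaded base layers.
    assert module.checkpoint_loaded, "transform ran before the checkpoint was loaded"
    module.layers = {name: f"adapter({layer})" for name, layer in module.layers.items()}


module = ToyModule(model_transform=add_adapters)
module.load_checkpoint()                     # 1) load base weights
callback = ModelTransformSketch()
callback.on_fit_start(None, module)          # 2) transform applied when fitting starts
callback.on_validation_start(None, module)   # no second application
print(module.layers)
```

Because the transform is deferred to a training-time hook, the checkpoint loader never has to know about the adapters.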

In addition to PEFT, Model Transform allows other customizations to models without modifying or duplicating the PyTorch model source code in Megatron Core. Here are some examples:

  • Using a custom embedding layer with output scaling (e.g. Gemma models).

  • Using a custom attention layer with logit softcapping (e.g. Gemma 2 models).

  • Adding a classification or embedding head to the output.
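As a minimal sketch of the first customization above, the transform below wraps an existing embedding so that its output is scaled, leaving the original embedding class untouched. It assumes a scale of sqrt(hidden_size), in the style of Gemma, and uses plain Python lists in place of tensors; `Embedding`, `ScaledEmbedding`, `ToyModel`, and `scale_embeddings` are hypothetical names, not NeMo or Megatron Core classes.

```python
import math


class Embedding:
    """Toy embedding: maps token ids to vectors via a lookup table."""

    def __init__(self, table):
        self.table = table

    def __call__(self, token_ids):
        return [list(self.table[t]) for t in token_ids]


class ScaledEmbedding:
    """Wraps an existing embedding and multiplies its output by a constant."""

    def __init__(self, inner, scale):
        self.inner = inner
        self.scale = scale

    def __call__(self, token_ids):
        return [[v * self.scale for v in row] for row in self.inner(token_ids)]


class ToyModel:
    def __init__(self):
        self.embedding = Embedding({0: [1.0, 0.0, 0.0, 0.0], 1: [0.0, 0.5, 0.0, 0.0]})


def scale_embeddings(model):
    # Assumption: scale by sqrt(hidden_size); the base Embedding class is unchanged.
    hidden_size = len(model.embedding.table[0])
    model.embedding = ScaledEmbedding(model.embedding, math.sqrt(hidden_size))


model = ToyModel()
scale_embeddings(model)
out = model.embedding([0, 1])  # sqrt(4) = 2.0, so every entry is doubled
```

The same wrap-and-replace pattern would apply to a softcapped attention layer or an added output head.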

PEFT Class

The PEFT class defines an interface with functionalities common to all PEFT methods, such as:

  • Freezing the base model weights

  • Saving only trainable weights to the checkpoint

  • Loading two checkpoints (base model and adapter) at inference time.
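The three responsibilities listed above can be sketched without any deep learning framework. In this illustration, parameters are plain records with a `requires_grad` flag. The parameter names (such as `layers.0.linear_qkv.adapter.lora_a`) and the helpers `freeze`, `trainable_state_dict`, and `load_for_inference` are hypothetical and do not reflect NeMo's actual checkpoint layout.

```python
from dataclasses import dataclass


@dataclass
class Param:
    value: list
    requires_grad: bool = True


def freeze(params):
    """Freeze every existing (base model) parameter."""
    for p in params.values():
        p.requires_grad = False


def trainable_state_dict(params):
    """Keep only trainable entries, i.e. the adapter weights, for the checkpoint."""
    return {name: p.value for name, p in params.items() if p.requires_grad}


def load_for_inference(base_ckpt, adapter_ckpt):
    """Combine the base-model checkpoint with the adapter checkpoint."""
    merged = dict(base_ckpt)
    merged.update(adapter_ckpt)
    return merged


params = {
    "layers.0.linear_qkv.weight": Param([0.1, 0.2]),
    "layers.0.linear_proj.weight": Param([0.3, 0.4]),
}
freeze(params)
# Adapters are inserted after freezing, so they are the only trainable params.
params["layers.0.linear_qkv.adapter.lora_a"] = Param([0.0, 0.0])

adapter_ckpt = trainable_state_dict(params)  # small: adapter weights only
base_ckpt = {name: p.value for name, p in params.items() if not p.requires_grad}
full = load_for_inference(base_ckpt, adapter_ckpt)
```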

Moreover, the PEFT class applies the transform with the walk function of the NeMo 2.0 API. This recursively iterates over all modules in a model and substitutes certain modules with user-defined modules. The substitution criteria can be based on the module name, module prefix, or module index.
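A recursive walk with name-, prefix-, or index-based substitution might look like the sketch below. This is not NeMo's walk function. `Module`, `Wrapped`, `walk`, and the `(module, full_name, index)` callback signature are hypothetical. One design choice is an assumption: once a child is substituted, the walk does not descend into the replacement, so adapters are not wrapped again.

```python
class Module:
    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)


class Wrapped(Module):
    """Hypothetical replacement module that wraps the original one."""

    def __init__(self, inner):
        super().__init__(inner.name, [inner])


def walk(module, fn, prefix=""):
    """Depth-first traversal; fn(child, full_name, index) may return a substitute."""
    for i, child in enumerate(module.children):
        full_name = f"{prefix}.{child.name}" if prefix else child.name
        new_child = fn(child, full_name, i)
        if new_child is child:
            walk(child, fn, full_name)  # unchanged: keep descending
        else:
            module.children[i] = new_child  # substituted: don't descend into it
    return module


def layer(i):
    return Module(str(i), [Module("linear_qkv"), Module("linear_proj")])


model = Module("root", [Module("decoder", [layer(0), layer(1)])])

replaced = []


def wrap_qkv_in_layer_1(module, full_name, index):
    # Criteria combine the module name with a prefix that selects layer 1 only.
    if module.name == "linear_qkv" and full_name.startswith("decoder.1."):
        replaced.append(full_name)
        return Wrapped(module)
    return module


walk(model, wrap_qkv_in_layer_1)
```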

For this reason, the PEFT class is not meant to be used directly. Instead, each PEFT method subclasses it and implements the abstract transform() method.
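The subclassing contract can be expressed with Python's abc module. In this sketch, `Callback` and `ModelTransform` are empty stand-ins for lightning.Callback and NeMo's ModelTransform, and the `transform(module, name=None, prefix=None)` signature is an assumption for illustration. It shows that the base class cannot be instantiated, while a concrete method can.

```python
from abc import ABC, abstractmethod


class Callback:
    """Stand-in for lightning.Callback."""


class ModelTransform(Callback):
    """Stand-in for the ModelTransform callback."""


class PEFT(ModelTransform, ABC):
    """Shared PEFT machinery would live here; the substitution rule does not."""

    @abstractmethod
    def transform(self, module, name=None, prefix=None):
        """Return a replacement for `module`, or `module` itself to leave it as-is."""


class IdentityPEFT(PEFT):
    """Hypothetical method that never substitutes anything."""

    def transform(self, module, name=None, prefix=None):
        return module


try:
    PEFT()  # abstract: transform() is not implemented
    base_is_abstract = False
except TypeError:
    base_is_abstract = True

method = IdentityPEFT()
```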

LoRA Class (or Any Other PEFT Method)

As an example, the LoRA transform function replaces a linear layer with a “LoRA linear” layer that adds a parallel computation path for the adapter outputs, which serve as the low-rank update to the linear layer. Each PEFT method defines its own such wrapper layer in NeMo 2.0; for LoRA, it is called AdapterParallelAdd.
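The parallel path can be sketched with plain lists: the wrapper returns base(x) + scale · B(Ax), where A is r × d_in and B is d_out × r. `Linear`, `LoRAAdapter`, and `ParallelAdd` are hypothetical stand-ins (ParallelAdd only mimics the role of AdapterParallelAdd). Zero-initializing B, so that the wrapped layer initially reproduces the base layer exactly, is a common LoRA convention that this sketch assumes.

```python
def matvec(W, x):
    """Multiply a row-major matrix W by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


class Linear:
    def __init__(self, W):
        self.W = W  # frozen base weight, d_out x d_in

    def __call__(self, x):
        return matvec(self.W, x)


class LoRAAdapter:
    """Low-rank path: x -> scale * B(A x), with A: r x d_in and B: d_out x r."""

    def __init__(self, A, B, scale=1.0):
        self.A, self.B, self.scale = A, B, scale

    def __call__(self, x):
        return [self.scale * v for v in matvec(self.B, matvec(self.A, x))]


class ParallelAdd:
    """Hypothetical stand-in for AdapterParallelAdd: base(x) + adapter(x)."""

    def __init__(self, base, adapter):
        self.base, self.adapter = base, adapter

    def __call__(self, x):
        return [b + a for b, a in zip(self.base(x), self.adapter(x))]


base = Linear([[1.0, 0.0], [0.0, 1.0]])   # d_in = d_out = 2
A = [[1.0, 1.0]]                          # rank r = 1
layer = ParallelAdd(base, LoRAAdapter(A, [[0.0], [0.0]]))  # B starts at zero
x = [2.0, 4.0]
before = layer(x)                # [2.0, 4.0]: identical to base(x)

layer.adapter.B = [[0.5], [0.0]]  # pretend training updated B
after = layer(x)                 # [5.0, 4.0] = [2, 4] + B(Ax) = [2, 4] + [3, 0]
```

Only A and B would receive gradients; the base weight W stays frozen.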

To determine if substitution should occur on a given module, the LoRA transformation function compares the current module name against a list of target_modules. By default, this list includes all four linear modules in a transformer layer.
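One plausible way to match a module against target_modules is to compare the last component of its dotted name. In the sketch below, the four default names are an assumption based on Megatron Core's linear layer naming (only `linear_qkv` and `linear_proj` appear in this page's own example). The full module paths and the helper `should_wrap` are illustrative, not NeMo's code.

```python
# Assumption: the four linear modules of a transformer layer.
DEFAULT_TARGETS = ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"]


def should_wrap(full_name, target_modules):
    """Match on the last component of a dotted module path."""
    return full_name.rsplit(".", 1)[-1] in target_modules


# Illustrative module paths for one transformer layer.
names = [
    "decoder.layers.0.self_attention.linear_qkv",
    "decoder.layers.0.self_attention.linear_proj",
    "decoder.layers.0.mlp.linear_fc1",
    "decoder.layers.0.mlp.linear_fc2",
    "decoder.layers.0.input_layernorm",
]

default_matches = [n for n in names if should_wrap(n, DEFAULT_TARGETS)]
attention_only = [n for n in names if should_wrap(n, ["linear_qkv", "linear_proj"])]
```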

In addition to the transformation function, the LoRA class also serves as a dataclass that contains configurations pertaining to LoRA, such as the list of target_modules, the LoRA dimension, LoRA dropout probability, etc.
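The idea of one class that is both the configuration and the transform can be sketched with a dataclass. `ToyLoRA` is hypothetical. Its default values are placeholders rather than NeMo's defaults, and its `transform` returns a description of the adapter shapes instead of building real layers. `field(default_factory=...)` is used because a mutable list cannot be a plain dataclass default.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyLoRA:
    # Placeholder defaults; the real LoRA class may use different values.
    target_modules: List[str] = field(
        default_factory=lambda: ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"]
    )
    dim: int = 8
    dropout: float = 0.0

    def __post_init__(self):
        if self.dim <= 0:
            raise ValueError("dim (the LoRA rank) must be positive")
        if not 0.0 <= self.dropout < 1.0:
            raise ValueError("dropout must be in [0, 1)")

    def transform(self, name: str, in_features: int, out_features: int) -> Optional[dict]:
        """Describe the adapter to insert for `name`, or None to leave it unchanged."""
        if name not in self.target_modules:
            return None
        return {
            "A_shape": (self.dim, in_features),   # down-projection
            "B_shape": (out_features, self.dim),  # up-projection
            "dropout": self.dropout,
        }


lora = ToyLoRA(target_modules=["linear_qkv", "linear_proj"], dim=4)
qkv_plan = lora.transform("linear_qkv", 16, 48)
fc1_plan = lora.transform("linear_fc1", 16, 64)  # None: not a target
```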

Usage

As an example, here is how to set up a LoRA training run.

First, specify the LoRA configuration you would like to run, such as the rank of the adapters (dim), the linear layers to apply LoRA to (target_modules), and the dropout probability.

from nemo.collections import llm
lora = llm.peft.LoRA(target_modules=['linear_qkv', 'linear_proj'], dim=32, dropout=0.0)
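The dim argument is the adapter rank r. For a linear layer with d_in inputs and d_out outputs, LoRA trains r(d_in + d_out) parameters (A and B, ignoring biases) instead of d_in · d_out. The quick calculation below uses a hypothetical 4096 × 4096 projection with dim=32.

```python
def full_params(d_in: int, d_out: int) -> int:
    """Trainable parameters when fine-tuning the whole weight matrix."""
    return d_in * d_out


def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters in LoRA's A (rank x d_in) and B (d_out x rank)."""
    return rank * (d_in + d_out)


d = 4096  # hypothetical square projection size
n_full = full_params(d, d)      # 16,777,216
n_lora = lora_params(d, d, 32)  # 262,144
print(f"{n_lora:,} / {n_full:,} = {n_lora / n_full:.4%}")  # 1.5625%
```

Raising dim increases adapter capacity linearly, while the cost stays far below full fine-tuning for typical ranks.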

Then, pass the initialized lora object to both the model (as model_transform) and the trainer (as a ModelTransform callback):

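from nemo import lightning as nl
model = llm.Mistral7BModel(model_transform=lora)  # lora is the model's transform function
trainer = nl.Trainer(..., callbacks=[lora])        # lora applies it when fitting starts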

You are now ready to launch LoRA training!

trainer.fit(model, data)  # data: the data module for your fine-tuning dataset (not shown)