Important
You are viewing the NeMo 2.0 documentation. This release introduces significant changes to the API and a new library, NeMo Run. We are currently porting all features from NeMo 1.0 to 2.0. For documentation on previous versions or features not yet available in 2.0, please refer to the NeMo 24.07 documentation.
PEFT in NeMo 2.0#
Model Customization#
Customizing models enables you to adapt a general pre-trained LLM to a specific use case or domain. This process results in a fine-tuned model that benefits from the extensive pretraining data, while also yielding more accurate outputs for the specific downstream task. Model customization is achieved through supervised fine-tuning and falls into two popular categories:
Full-Parameter Fine-Tuning, which is referred to as Supervised Fine-Tuning (SFT) in NeMo
Parameter-Efficient Fine-Tuning (PEFT)
In SFT, all of the model parameters are updated to produce outputs that are adapted to the task.
PEFT, on the other hand, tunes a much smaller set of parameters, which are inserted into the base model at strategic locations as adapter modules. When fine-tuning with PEFT, the base model weights remain frozen and only the adapter modules are trained. As a result, the number of trainable parameters is significantly reduced, often to less than 1% of the base model's parameters.
While SFT often yields the best possible results, PEFT methods can achieve nearly the same accuracy at a significantly lower computational cost. As language models continue to grow in size, PEFT is gaining popularity due to its lightweight hardware requirements for training.
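As a rough illustration of this idea (a minimal PyTorch sketch, not NeMo's implementation), a LoRA-style adapter wraps a frozen linear layer and trains only two small low-rank matrices:

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Minimal LoRA-style adapter: the base layer stays frozen, only A/B are trained."""

    def __init__(self, base: nn.Linear, dim: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)                                   # freeze the base weights
        self.lora_a = nn.Linear(base.in_features, dim, bias=False)    # low-rank "down" projection
        self.lora_b = nn.Linear(dim, base.out_features, bias=False)   # low-rank "up" projection
        nn.init.zeros_(self.lora_b.weight)                            # adapter starts as a no-op
        self.scale = alpha / dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Frozen base output plus the scaled low-rank update
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))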
NeMo 2.0 supports SFT and two PEFT methods, which can be used with various transformer-based models.
| Model | SFT | LoRA | DoRA |
| --- | --- | --- | --- |
| Baichuan 2 7B | ✅ | ✅ | ✅ |
| ChatGLM 3 6B | ✅ | ✅ | ✅ |
| Gemma 2B/7B | ✅ | ✅ | ✅ |
| Gemma 2 9B/27B | ✅ | ✅ | ✅ |
| Llama 3 8B/70B | ✅ | ✅ | ✅ |
| Llama 3.1 8B/70B/405B | ✅ | ✅ | ✅ |
| Mistral 7B | ✅ | ✅ | ✅ |
| Mixtral 8x7B/8x22B | ✅ | ✅ | ✅ |
| Nemotron 3 4B/8B | ✅ | ✅ | ✅ |
| Nemotron 4 15B/22B/340B | ✅ | ✅ | ✅ |
| Qwen 2 0.5B/1.5B/7B/72B | ✅ | ✅ | ✅ |
| Starcoder 15B | ✅ | ✅ | ✅ |
| Starcoder 2 3B/7B/15B | ✅ | ✅ | ✅ |
Read more about supported PEFT methods here:
Run PEFT Training in NeMo 2.0#
Below are three examples of running a simple PEFT training loop for the Llama 3.2 1B model using NeMo 2.0. These examples showcase different levels of abstraction provided by the NeMo Framework. Once you have set up your environment following the instructions in Install NeMo Framework, you are ready to run the simple PEFT tuning script.
The easiest way to run PEFT training is with the recipe files. You can find the list of supported models and their predefined recipes here.
Note
Prerequisite: Before proceeding, please follow the example in Quickstart with NeMo-Run to familiarize yourself with NeMo-Run first.
from nemo.collections import llm
import nemo_run as run

nodes = 1
gpus_per_node = 1

recipe = llm.llama32_1b.finetune_recipe(
    dir="/checkpoints/llama3.2_1b",  # Path to store checkpoints
    name="llama3_lora",
    num_nodes=nodes,
    num_gpus_per_node=gpus_per_node,
    peft_scheme="lora",
)
# Note: "lora" is the default peft_scheme.
# Supported values are "lora", "dora", "none"/None (full fine-tuning)

# Override your PEFT configuration here, if needed. For example:
recipe.peft.target_modules = ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"]
recipe.peft.dim = 16
recipe.peft.alpha = 32

# Add other overrides here
...

run.run(recipe)
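If you want to choose the executor explicitly (see the NeMo-Run quickstart referenced above), a local run could look roughly like the sketch below; it assumes the default run.LocalExecutor settings are sufficient for your machine.

# Explicitly run the recipe with a local executor; other executors (e.g. Slurm) follow the same pattern.
run.run(recipe, executor=run.LocalExecutor())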
You can use PEFT recipes via the NeMo Run CLI (see here for more details). LoRA and DoRA are registered as factory classes, so you can specify peft=<lora/dora/none> directly in the terminal. This provides a quick and easy way to launch training jobs when you do not need to override any configuration from the default recipes.
nemo llm finetune -f llama32_1b peft=lora # acceptable values are lora/dora/none
This example uses the finetune API from the NeMo Framework LLM collection. This is a lower-level API that allows you to lay out the various configurations in a Pythonic fashion. This gives you the greatest amount of control over each configuration.
import torch
from nemo import lightning as nl
from nemo.collections import llm
from megatron.core.optimizer import OptimizerConfig

if __name__ == "__main__":
    seq_length = 2048
    global_batch_size = 16

    ## setup a finetuning dataset
    data = llm.DollyDataModule(
        seq_length=seq_length,
        global_batch_size=global_batch_size,
    )

    ## initialize a small Llama model
    llama_config = llm.Llama32Config1B()
    model = llm.LlamaModel(llama_config, tokenizer=data.tokenizer)

    ## initialize the strategy
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        pipeline_dtype=torch.bfloat16,
    )

    ## setup the optimizer
    opt_config = OptimizerConfig(
        optimizer='adam',
        lr=1e-4,
        bf16=True,
    )
    opt = nl.MegatronOptimizerModule(config=opt_config)

    trainer = nl.Trainer(
        devices=1,  ## you can change the number of devices to suit your setup
        max_steps=50,
        accelerator="gpu",
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
    )

    nemo_logger = nl.NeMoLogger(
        log_dir="test_logdir",  ## logs and checkpoints will be written here
    )

    peft = llm.peft.LoRA(dim=8, alpha=16)

    resume = nl.AutoResume(
        restore_config=nl.RestoreConfig(path="nemo://meta-llama/Llama-3.2-1B"),
    )

    # only need to import the first time the script is run
    llm.import_ckpt(model, "hf://meta-llama/Llama-3.2-1B")

    llm.finetune(
        model=model,
        data=data,
        trainer=trainer,
        peft=peft,
        log=nemo_logger,
        optim=opt,
        resume=resume,
    )
Run PEFT Inference in NeMo 2.0#
Inference with adapters is supported natively with the llm.generate
API.
Simply replace the path to the full model with the path to a PEFT checkpoint.
NeMo will infer all the information required to run inference from the checkpoint,
including the model type, adapter type, base model checkpoint path, etc.
See the llm.generate
API for more details.
The following is an example script:
import torch
from megatron.core.inference.common_inference_params import CommonInferenceParams

import nemo.lightning as nl
from nemo.collections.llm import api

strategy = nl.MegatronStrategy(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=1,
    context_parallel_size=1,
    sequence_parallel=False,
    setup_optimizers=False,
)

trainer = nl.Trainer(
    accelerator="gpu",
    devices=1,
    num_nodes=1,
    strategy=strategy,
    plugins=nl.MegatronMixedPrecision(
        precision="bf16-mixed",
        params_dtype=torch.bfloat16,
        pipeline_dtype=torch.bfloat16,
    ),
)

prompts = [
    "Hello, how are you?",
    "How many r's are in the word 'strawberry'?",
    "Which number is bigger? 10.119 or 10.19?",
]

if __name__ == "__main__":
    adapter_checkpoint = "/path/to/nemo_ckpt"  # a folder that contains "weights" and "context" subfolders
    results = api.generate(
        path=adapter_checkpoint,
        prompts=prompts,
        trainer=trainer,
        inference_params=CommonInferenceParams(temperature=0.1, top_k=10, num_tokens_to_generate=512),
        text_only=True,
    )
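    # Hedged follow-up, not part of the original example: with text_only=True, `results`
    # is expected to be a list of generated strings (one per prompt), so they can be printed directly.
    for prompt, completion in zip(prompts, results):
        print(f"PROMPT: {prompt}\nOUTPUT: {completion}\n")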
Hint
The adapter checkpoint contains only the adapter weights, not the base model weights. Why, then, is the path to the base model not required?
This is because the adapter checkpoint also stores a reference to the path of the base model it was trained with. Each adapter checkpoint must be paired with the exact base model it was trained with, so that reference is saved in weights/adapter_metadata.json alongside the adapter weights.
Therefore, if you share an adapter checkpoint with someone on a different file system, make sure the recipient updates weights/adapter_metadata.json to a valid path on their file system.
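As a sketch, updating that reference could look like the snippet below. The key name "model_ckpt_path" is a placeholder assumption, not confirmed by this page; inspect your adapter_metadata.json to find the actual field that stores the base model path before editing it.

import json
from pathlib import Path

meta_path = Path("/path/to/adapter_ckpt/weights/adapter_metadata.json")
meta = json.loads(meta_path.read_text())

# "model_ckpt_path" is a hypothetical key name; check the file for the real field name.
meta["model_ckpt_path"] = "/new/location/of/base_model_ckpt"
meta_path.write_text(json.dumps(meta, indent=2))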
Hint
Depending on the size of your model, you may need to adjust tensor_model_parallel_size, pipeline_model_parallel_size, num_devices, and num_nodes. If you are unsure what parallelism configs to set, the PEFT training recipe for that model will provide a good upper bound.
See here for a list of recipes.
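For example, a larger checkpoint (such as a 70B model) might use a configuration along these lines. The values below are illustrative assumptions only; the model's PEFT training recipe is the authoritative reference for the actual parallelism settings.

import torch
import nemo.lightning as nl

# Illustrative parallelism settings for a larger model; adjust to match the model's recipe.
strategy = nl.MegatronStrategy(
    tensor_model_parallel_size=8,   # split each layer across 8 GPUs
    pipeline_model_parallel_size=1,
    context_parallel_size=1,
    sequence_parallel=False,
    setup_optimizers=False,
)
trainer = nl.Trainer(
    accelerator="gpu",
    devices=8,       # devices must cover the tensor-parallel size
    num_nodes=1,
    strategy=strategy,
    plugins=nl.MegatronMixedPrecision(
        precision="bf16-mixed",
        params_dtype=torch.bfloat16,
        pipeline_dtype=torch.bfloat16,
    ),
)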
Merge LoRA Weights with Base Model#
If you want to run PEFT inference without changing the model architecture, NeMo also supports merging the trained LoRA weights back into the base model. This is supported by the llm.peft.merge_lora API.
from nemo.collections import llm

if __name__ == '__main__':
    llm.peft.merge_lora(
        lora_checkpoint_path="path/to/lora_checkpoint",
        output_path="path/to/merged_checkpoint",
    )
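A merged checkpoint behaves like a regular full-model checkpoint, so, as a sketch, it can be served with the same generate API shown in the inference section above. The snippet assumes `trainer` and `prompts` are set up exactly as in that example.

from nemo.collections.llm import api

# `trainer` and `prompts` are assumed to be configured as in the inference example above.
results = api.generate(
    path="path/to/merged_checkpoint",  # the merged output produced by merge_lora
    prompts=prompts,
    trainer=trainer,
    text_only=True,
)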
Export LoRA Weights to Hugging Face#
Exporting LoRA checkpoints is supported by the llm.export_ckpt API, provided that a PEFT exporter (hf-peft) is implemented for the model. To implement a PEFT exporter for your own model class, follow the example in Llama.
from pathlib import Path

from nemo.collections import llm

if __name__ == '__main__':
    llm.export_ckpt(
        path=Path("path/to/lora_checkpoint"),
        target="hf-peft",
        output_path=Path("path/to/output_HF_checkpoint"),
    )
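As a hedged sketch not covered by this page: if the exported directory follows the standard Hugging Face PEFT adapter layout, it could presumably be loaded with the Hugging Face peft library roughly as follows (the base model name and paths are placeholder assumptions).

from transformers import AutoModelForCausalLM
from peft import PeftModel

# Assumption: the exported folder is a standard HF PEFT adapter directory; verify before relying on this.
base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
model = PeftModel.from_pretrained(base, "path/to/output_HF_checkpoint")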
Note that the Hugging Face LoRA implementation is equivalent to NeMo's CanonicalLoRA, not LoRA. However, both can be converted to the Hugging Face implementation. Read more about the difference here:
Explore PEFT Design in NeMo 2.0#
If you are developing for NeMo PEFT, you are invited to read more about the design of PEFT in NeMo 2.0 here.