Important
You are viewing the NeMo 2.0 documentation. This release introduces significant changes to the API and a new library, NeMo Run. We are currently porting all features from NeMo 1.0 to 2.0. For documentation on previous versions or features not yet available in 2.0, please refer to the NeMo 24.07 documentation.
Important
Before starting this tutorial, be sure to review the introduction for tips on setting up your NeMo-Aligner environment.
If you run into any problems, refer to NeMo’s Known Issues page. The page enumerates known issues and provides suggested workarounds where appropriate.
After completing this tutorial, refer to the evaluation documentation for tips on evaluating a trained model.
Fine-Tuning Stable Diffusion with DRaFT+#
In this tutorial, we go through a step-by-step guide for fine-tuning a Stable Diffusion model using NVIDIA's DRaFT+ algorithm. DRaFT+ enhances the DRaFT algorithm by mitigating mode collapse and improving diversity through regularization. For more technical details on the DRaFT+ algorithm, check out our technical blog.
Data Input for Running DRaFT+#
The data for running DRaFT+ should be a .tar file containing plain prompts. You can generate the .tar file from a .txt file that lists the prompts separated by new lines, in the following format:
prompt1
prompt2
prompt3
prompt4
...
Use the following script to download and save the prompts from the Pick-a-Pic dataset:
from datasets import load_dataset

dataset = load_dataset("yuvalkirstain/pickapic_v1_no_images")
captions = dataset['train']['caption']

file_path = "/path/to/prompts.txt"  # path to save as a .txt file

with open(file_path, 'w') as file:
    for caption in captions:
        file.write(caption + '\n')
You can then run the following snippet to convert it to a .tar file:
import webdataset as wds

txt_file_path = "/path/to/prompts.txt"       # path to the input .txt file
tar_file_name = "/path/to/train_dataset.tar" # path to the output .tar file

with open(txt_file_path, 'r') as f:
    prompts = [line.strip() for line in f.readlines()]

sink = wds.TarWriter(tar_file_name)
for index, prompt in enumerate(prompts):
    sink.write({
        "__key__": "sample%06d" % index,
        "txt": prompt,
    })
sink.close()
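To sanity-check the archive before training, you can read it back with webdataset. This is a minimal sketch, not part of the training pipeline; the path is a placeholder:

import webdataset as wds

# Each sample is a dict: "__key__" is a string, "txt" holds the prompt as raw bytes.
for sample in wds.WebDataset("/path/to/train_dataset.tar"):
    print(sample["__key__"], sample["txt"].decode("utf-8"))
    break  # print only the first prompt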
Reward Model#
Currently, we only support PickScore-style reward models (PickScore/HPSv2). Since PickScore is a CLIP-based model, you can use NeMo's conversion script to convert it from Hugging Face to NeMo format.
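For reference, PickScore assigns a reward to a prompt-image pair via CLIP-style scaled cosine similarity between the text and image embeddings. Below is a minimal sketch following the public PickScore release on Hugging Face (the checkpoint and processor names are from that release; this only illustrates the reward signal and is independent of the NeMo conversion):

import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

# Public PickScore v1 checkpoint and the CLIP processor it was trained with.
processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
model = AutoModel.from_pretrained("yuvalkirstain/PickScore_v1").eval()

prompt = "Bananas growing on an apple tree"
image = Image.open("/path/to/generated_image.png")  # placeholder path

image_inputs = processor(images=image, return_tensors="pt")
text_inputs = processor(text=prompt, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    # CLIP-style reward: scaled cosine similarity of normalized embeddings.
    image_emb = model.get_image_features(**image_inputs)
    image_emb = image_emb / image_emb.norm(dim=-1, keepdim=True)
    text_emb = model.get_text_features(**text_inputs)
    text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
    score = model.logit_scale.exp() * (text_emb @ image_emb.T)

print(score.item())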
DRaFT+ Training#
To start DRaFT+ training, you need checkpoints for both the UNet and VAE components of a trained Stable Diffusion model, as well as a checkpoint for the reward model.
To run DRaFT+ on the terminal directly:
GPFS="/path/to/nemo-aligner-repo" TRAIN_DATA_PATH="/path/to/train_dataset.tar" UNET_CKPT="/path/to/unet_weights.ckpt" VAE_CKPT="/path/to/vae_weights.bin" RM_CKPT="/path/to/reward_model.nemo" DRAFTP_SCRIPT="train_sd_draftp.py" # or train_sdxl_draftp.py torchrun --nproc_per_node=2 ${GPFS}/examples/mm/stable_diffusion/${DRAFTP_SCRIPT} \ trainer.num_nodes=1 \ trainer.devices=2 \ model.micro_batch_size=1 \ model.global_batch_size=8 \ model.kl_coeff=0.2 \ model.optim.lr=0.0001 \ model.unet_config.from_pretrained=${UNET_CKPT} \ model.first_stage_config.from_pretrained=${VAE_CKPT} \ rm.model.restore_from_path=${RM_CKPT} \ model.data.train.webdataset.local_root_path=${TRAIN_DATA_PATH} \ exp_manager.create_wandb_logger=False \ exp_manager.explicit_log_dir=/results
To run DRaFT+ using Slurm, use the script below, which runs on 1 node:
#!/bin/bash
#SBATCH -A <<ACCOUNT NAME>>
#SBATCH -p <<PARTITION NAME>>
#SBATCH -N 1
#SBATCH -t 4:00:00
#SBATCH -J <<JOB NAME>>
#SBATCH --ntasks-per-node=8
#SBATCH --exclusive
#SBATCH --overcommit

GPFS="/path/to/nemo-aligner-repo"
TRAIN_DATA_PATH="/path/to/train_dataset.tar"
UNET_CKPT="/path/to/unet_weights.ckpt"
VAE_CKPT="/path/to/vae_weights.bin"
RM_CKPT="/path/to/reward_model.nemo"

NAME="<<EXPERIMENT NAME>>"
PROJECT="<<WANDB PROJECT>>"

CONTAINER=<<<CONTAINER>>> # use the latest NeMo Training container, Aligner will work there
MOUNTS="--container-mounts=MOUNTS" # mounts

RESULTS_DIR="/path/to/result_dir"
OUTFILE="${RESULTS_DIR}/rm-%j_%t.out"
ERRFILE="${RESULTS_DIR}/rm-%j_%t.err"
mkdir -p ${RESULTS_DIR}

DRAFTP_SCRIPT="train_sd_draftp.py" # or train_sdxl_draftp.py

read -r -d '' cmd <<EOF
echo "*******STARTING********" \
&& echo "---------------" \
&& echo "Starting training" \
&& cd ${GPFS} \
&& export PYTHONPATH="${GPFS}:${PYTHONPATH}" \
&& export HYDRA_FULL_ERROR=1 \
&& python -u ${GPFS}/examples/mm/stable_diffusion/${DRAFTP_SCRIPT} \
    trainer.num_nodes=1 \
    trainer.devices=8 \
    model.micro_batch_size=2 \
    model.global_batch_size=16 \
    model.kl_coeff=0.2 \
    model.optim.lr=0.0001 \
    model.unet_config.from_pretrained=${UNET_CKPT} \
    model.first_stage_config.from_pretrained=${VAE_CKPT} \
    rm.model.restore_from_path=${RM_CKPT} \
    model.data.webdataset.local_root_path=${TRAIN_DATA_PATH} \
    exp_manager.explicit_log_dir=${RESULTS_DIR} \
    exp_manager.create_wandb_logger=True \
    exp_manager.wandb_logger_kwargs.name=${NAME} \
    exp_manager.wandb_logger_kwargs.project=${PROJECT}
EOF

srun --no-container-mount-home -o $OUTFILE -e $ERRFILE --container-image=$CONTAINER $MOUNTS bash -c "${cmd}"
set +x
Note
For more information on DRaFT+ hyperparameters, please see the model config files (for SD and SDXL respectively):
NeMo-Aligner/examples/mm/stable_diffusion/conf/draftp_sd.yaml
NeMo-Aligner/examples/mm/stable_diffusion/conf/draftp_sdxl.yaml
DRaFT+ Results#
Once you have completed fine-tuning Stable Diffusion with DRaFT+, you can run inference with your saved model using the sd_infer.py and sd_lora_infer.py scripts from the NeMo codebase. Images generated with the fine-tuned model should show better prompt alignment and aesthetic quality.
User-controllable Fine-Tuning with Annealed Importance Guidance (AIG)#
AIG provides the inference-time flexibility to interpolate between the base Stable Diffusion model (low rewards, high diversity) and a DRaFT+ fine-tuned model (high rewards, low diversity) to obtain images with both high rewards and high diversity. AIG inference is run by specifying comma-separated weight_type strategies that interpolate between the base and fine-tuned models.
A weight type of base uses the base model and draft uses the fine-tuned model (no interpolation is done in either case). A weight type of the form power_<float> interpolates using the exponential decay specified in the AIG paper.
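Conceptually, each weight_type defines a blending coefficient per denoising step that interpolates between the two models' parameters. The sketch below illustrates the idea only; the exact power_<float> schedule is defined in the AIG paper, and the decay form here is an illustrative assumption, not NeMo-Aligner's implementation:

import math

def aig_coefficient(step, total_steps, weight_type):
    """Illustrative blending coefficient: 0.0 -> base model, 1.0 -> DRaFT+ model."""
    if weight_type == "base":
        return 0.0
    if weight_type == "draft":
        return 1.0
    if weight_type.startswith("power_"):
        p = float(weight_type.split("_", 1)[1])
        # Assumed exponential-style decay over the denoising trajectory;
        # the precise schedule comes from the AIG paper.
        return math.exp(-p * step / total_steps)
    raise ValueError(f"unknown weight_type: {weight_type}")

def blend_weights(base_w, draft_w, coeff):
    # Per-parameter linear interpolation between the two UNet checkpoints.
    return {name: (1 - coeff) * base_w[name] + coeff * draft_w[name] for name in base_w}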
To run AIG inference on the terminal directly:
NUMNODES=1
LR=${LR:=0.00025}
INF_STEPS=${INF_STEPS:=25}
KL_COEF=${KL_COEF:=0.1}
ETA=${ETA:=0.0}
DATASET=${DATASET:="pickapic50k.tar"}
MICRO_BS=${MICRO_BS:=1}
GRAD_ACCUMULATION=${GRAD_ACCUMULATION:=4}
PEFT=${PEFT:="sdlora"}
NUM_DEVICES=${NUM_DEVICES:=8}
GLOBAL_BATCH_SIZE=$((MICRO_BS*NUM_DEVICES*GRAD_ACCUMULATION*NUMNODES))
LOG_WANDB=${LOG_WANDB:="False"}

echo "additional kwargs: ${ADDITIONAL_KWARGS}"

WANDB_NAME=SDXL_Draft_annealing
PROJECT=${PROJECT:="<<WANDB PROJECT>>"}
WEBDATASET_PATH=/path/to/${DATASET}

CONFIG_PATH="/opt/nemo-aligner/examples/mm/stable_diffusion/conf"
CONFIG_NAME=${CONFIG_NAME:="draftp_sdxl"}
UNET_CKPT="/path/to/unet.ckpt"
VAE_CKPT="/path/to/vae.ckpt"
RM_CKPT="/path/to/reward_model.nemo"
PROMPT=${PROMPT:="Bananas growing on an apple tree"}
DIR_SAVE_CKPT_PATH=/path/to/explicit_log_dir

if [ ! -z "${ACT_CKPT}" ]; then
    ACT_CKPT="model.activation_checkpointing=$ACT_CKPT "
    echo $ACT_CKPT
fi

EVAL_SCRIPT=${EVAL_SCRIPT:-"anneal_sdxl.py"}

export DEVICE="0,1,2,3,4,5,6,7" && echo "Running DRaFT+ on ${DEVICE}" && export HYDRA_FULL_ERROR=1

set -x
CUDA_VISIBLE_DEVICES="${DEVICE}" torchrun --nproc_per_node=$NUM_DEVICES /opt/nemo-aligner/examples/mm/stable_diffusion/${EVAL_SCRIPT} \
    --config-path=${CONFIG_PATH} \
    --config-name=${CONFIG_NAME} \
    model.optim.lr=${LR} \
    model.optim.weight_decay=0.0005 \
    model.optim.sched.warmup_steps=0 \
    model.sampling.base.steps=${INF_STEPS} \
    model.kl_coeff=${KL_COEF} \
    model.truncation_steps=1 \
    trainer.draftp_sd.max_epochs=5 \
    trainer.draftp_sd.max_steps=10000 \
    trainer.draftp_sd.save_interval=200 \
    trainer.draftp_sd.val_check_interval=20 \
    trainer.draftp_sd.gradient_clip_val=10.0 \
    model.micro_batch_size=${MICRO_BS} \
    model.global_batch_size=${GLOBAL_BATCH_SIZE} \
    model.peft.peft_scheme=${PEFT} \
    model.data.webdataset.local_root_path=$WEBDATASET_PATH \
    rm.model.restore_from_path=${RM_CKPT} \
    trainer.devices=${NUM_DEVICES} \
    trainer.num_nodes=${NUMNODES} \
    rm.trainer.devices=${NUM_DEVICES} \
    rm.trainer.num_nodes=${NUMNODES} \
    +prompt="${PROMPT}" \
    exp_manager.create_wandb_logger=${LOG_WANDB} \
    model.first_stage_config.from_pretrained=${VAE_CKPT} \
    model.first_stage_config.from_NeMo=True \
    model.unet_config.from_pretrained=${UNET_CKPT} \
    model.unet_config.from_NeMo=True \
    $ACT_CKPT \
    exp_manager.wandb_logger_kwargs.name=${WANDB_NAME} \
    exp_manager.resume_if_exists=True \
    exp_manager.explicit_log_dir=${DIR_SAVE_CKPT_PATH} \
    exp_manager.wandb_logger_kwargs.project=${PROJECT} \
    +weight_type='draft,base,power_2.0'