Fine-tuning Stable Diffusion with DRaFT+
This tutorial walks through fine-tuning a Stable Diffusion model with NVIDIA's DRaFT+ algorithm. DRaFT+ improves on the DRaFT algorithm by using regularization to alleviate mode collapse and improve sample diversity. For more technical details on the DRaFT+ algorithm, check out our technical blog.
Data Input for Running DRaFT+
The data for running DRaFT+ should be a .tar file containing plain-text prompts. You can generate the tar file from a .txt file that lists the prompts one per line, in the following format:
prompt1
prompt2
prompt3
prompt4
...
Use the following script to download the prompts from the Pick-a-Pic dataset and save them to a .txt file:
from datasets import load_dataset

# Download the caption-only variant of the Pick-a-Pic dataset.
dataset = load_dataset("yuvalkirstain/pickapic_v1_no_images")
captions = dataset['train']['caption']

file_path = "/path/to/prompts.txt"  # placeholder: path to save as a .txt file
with open(file_path, 'w') as file:
    for caption in captions:
        file.write(caption + '\n')
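Because the prompt file must hold exactly one prompt per line, it can help to clean the captions before writing them. The following is a minimal sketch, not part of the NeMo tooling: the helper name `normalize_prompts` is hypothetical, and it assumes that captions may contain embedded newlines, empty strings, or repeats. Whether to drop duplicates is a judgment call for your training run.

```python
def normalize_prompts(captions):
    """Return cleaned prompts: one line each, no empties, no duplicates.

    Order of first appearance is preserved.
    """
    seen = set()
    cleaned = []
    for caption in captions:
        # Collapse all whitespace runs (including newlines) into single spaces.
        prompt = " ".join(caption.split())
        if prompt and prompt not in seen:
            seen.add(prompt)
            cleaned.append(prompt)
    return cleaned
```

You could pass `captions` through this helper before the loop that writes the .txt file.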
You can then run the following snippet to convert it to a .tar file:
import webdataset as wds

txt_file_path = "/path/to/prompts.txt"  # placeholder: path for the input txt file
tar_file_name = "/path/to/train_dataset.tar"  # placeholder: path for the output tar file

with open(txt_file_path, 'r') as f:
    prompts = [line.strip() for line in f]

# Write each prompt as one sample with a zero-padded key and a "txt" field.
sink = wds.TarWriter(tar_file_name)
for index, line in enumerate(prompts):
    sink.write({
        "__key__": "sample%06d" % index,
        "txt": line,
    })
sink.close()
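To confirm the conversion worked, you can read the archive back and inspect a few prompts. The sketch below uses only the Python standard library; the helper name `read_prompts_from_tar` is hypothetical, and it assumes each sample's "txt" field was stored as a UTF-8 member named `<key>.txt`, which is how the writer above lays out its output.

```python
import tarfile


def read_prompts_from_tar(tar_path):
    """Return a list of (member_name, prompt) pairs for every .txt member."""
    prompts = []
    with tarfile.open(tar_path, "r") as tar:
        for member in tar:
            if member.isfile() and member.name.endswith(".txt"):
                # extractfile returns a file-like object for regular files.
                data = tar.extractfile(member).read()
                prompts.append((member.name, data.decode("utf-8")))
    return prompts
```

For example, `read_prompts_from_tar("/path/to/train_dataset.tar")[:3]` returns the first three entries, which should match the first three lines of the .txt file.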
Reward Model
Currently, only the PickScore reward model is supported. Because PickScore is a CLIP-based model, you can use the conversion script from NeMo to convert it from Hugging Face format to NeMo format.
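To make "CLIP-based" concrete: a CLIP-style scorer embeds the prompt and the image, normalizes both embeddings, and takes their scaled dot product as the score. The sketch below illustrates only that arithmetic on plain Python lists. The function names and the `logit_scale` default are illustrative assumptions, and it is not the NeMo reward model implementation.

```python
import math


def _normalize(vec):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in vec))
    if norm == 0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in vec]


def clip_style_score(text_emb, image_emb, logit_scale=1.0):
    """Scaled cosine similarity between a text and an image embedding."""
    if len(text_emb) != len(image_emb):
        raise ValueError("embeddings must have the same dimension")
    t = _normalize(text_emb)
    i = _normalize(image_emb)
    return logit_scale * sum(a * b for a, b in zip(t, i))
```

A higher score means the image embedding points in a direction closer to the prompt embedding. In a real model the embeddings come from the text and image encoders.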
DRaFT+ Training
To launch DRaFT+ training, you must have the UNet and VAE checkpoints of a trained Stable Diffusion model and a checkpoint for the reward model.
To run DRaFT+ directly from the terminal:
GPFS="/path/to/nemo-aligner-repo"
TRAIN_DATA_PATH="/path/to/train_dataset.tar"
UNET_CKPT="/path/to/unet_weights.ckpt"
VAE_CKPT="/path/to/vae_weights.bin"
RM_CKPT="/path/to/reward_model.nemo"

torchrun --nproc_per_node=2 ${GPFS}/examples/mm/stable_diffusion/train_sd_draftp.py \
    trainer.num_nodes=1 \
    trainer.devices=2 \
    model.micro_batch_size=1 \
    model.global_batch_size=8 \
    model.kl_coeff=0.2 \
    model.optim.lr=0.0001 \
    model.unet_config.from_pretrained=${UNET_CKPT} \
    model.first_stage_config.from_pretrained=${VAE_CKPT} \
    rm.model.restore_from_path=${RM_CKPT} \
    model.data.train.webdataset.local_root_path=${TRAIN_DATA_PATH} \
    exp_manager.create_wandb_logger=False \
    exp_manager.explicit_log_dir=/results
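The batch-size settings in the command above are related. In NeMo-style configurations, the global batch size is typically the micro batch size times the number of data-parallel workers times the number of gradient accumulation steps. The sketch below computes the implied accumulation steps. The helper name is hypothetical, and it assumes every GPU is a data-parallel worker (no tensor or pipeline parallelism).

```python
def grad_accumulation_steps(global_batch_size, micro_batch_size,
                            num_nodes, devices_per_node):
    """Return the gradient accumulation steps implied by the batch settings.

    Assumes data-parallel size == num_nodes * devices_per_node.
    """
    data_parallel_size = num_nodes * devices_per_node
    samples_per_step = micro_batch_size * data_parallel_size
    if global_batch_size % samples_per_step != 0:
        raise ValueError(
            "global_batch_size must be a multiple of "
            "micro_batch_size * num_nodes * devices_per_node"
        )
    return global_batch_size // samples_per_step
```

With the values from the command above (global batch size 8, micro batch size 1, 1 node, 2 devices), the helper returns 4, meaning each optimizer step accumulates gradients over four micro batches per GPU under these assumptions.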
To run DRaFT+ using Slurm, use the script below, which runs on 1 node:
#!/bin/bash
#SBATCH -A <<ACCOUNT NAME>>
#SBATCH -p <<PARTITION NAME>>
#SBATCH -N 1
#SBATCH -t 4:00:00
#SBATCH -J <<JOB NAME>>
#SBATCH --ntasks-per-node=8
#SBATCH --exclusive
#SBATCH --overcommit

GPFS="/path/to/nemo-aligner-repo"
TRAIN_DATA_PATH="/path/to/train_dataset.tar"
UNET_CKPT="/path/to/unet_weights.ckpt"
VAE_CKPT="/path/to/vae_weights.bin"
RM_CKPT="/path/to/reward_model.nemo"

PROJECT="<<WANDB PROJECT>>"
NAME="<<WANDB RUN NAME>>" # placeholder: run name passed to the W&B logger below
CONTAINER="<<CONTAINER>>" # use the latest NeMo Training container, Aligner will work there
MOUNTS="--container-mounts=MOUNTS" # mounts

RESULTS_DIR="/path/to/result_dir"
OUTFILE="${RESULTS_DIR}/rm-%j_%t.out"
ERRFILE="${RESULTS_DIR}/rm-%j_%t.err"
mkdir -p ${RESULTS_DIR}

read -r -d '' cmd <<EOF
echo "*******STARTING********" \
&& echo "---------------" \
&& echo "Starting training" \
&& cd ${GPFS} \
&& export PYTHONPATH="${GPFS}:${PYTHONPATH}" \
&& export HYDRA_FULL_ERROR=1 \
&& python -u ${GPFS}/examples/mm/stable_diffusion/train_sd_draftp.py \
    trainer.num_nodes=1 \
    trainer.devices=8 \
    model.micro_batch_size=2 \
    model.global_batch_size=16 \
    model.kl_coeff=0.2 \
    model.optim.lr=0.0001 \
    model.unet_config.from_pretrained=${UNET_CKPT} \
    model.first_stage_config.from_pretrained=${VAE_CKPT} \
    rm.model.restore_from_path=${RM_CKPT} \
    model.data.webdataset.local_root_path=${TRAIN_DATA_PATH} \
    exp_manager.explicit_log_dir=${RESULTS_DIR} \
    exp_manager.create_wandb_logger=True \
    exp_manager.wandb_logger_kwargs.name=${NAME} \
    exp_manager.wandb_logger_kwargs.project=${PROJECT}
EOF

srun -o $OUTFILE -e $ERRFILE --container-image=$CONTAINER $MOUNTS bash -c "${cmd}"
set +x
Note
For more information on DRaFT+ hyperparameters, see the model config file:
NeMo-Aligner/examples/mm/stable_diffusion/conf/draftp_sd.yaml
DRaFT+ Results
Once you have completed fine-tuning Stable Diffusion with DRaFT+, you can run inference on the saved model using the sd_infer.py and sd_lora_infer.py scripts from the NeMo codebase. Images generated with the fine-tuned model should show better prompt alignment and aesthetic quality.