Fine-tuning Stable Diffusion with DRaFT+

In this tutorial, we will go through the step-by-step guide for fine-tuning Stable Diffusion model using DRaFT+ algorithm by NVIDIA. DRaFT+ is an improvement over the DRaFT algorithm by alleviating the mode collapse and improving diversity through regularization. For more technical details on the DRaFT+ algorithm, check out our technical blog.

Data Input for running DRaFT+

The data for running DRaFT+ should be a .tar file consisting of a plain prompt. You can generate a tarfile from a .txt file containing the prompts separated by new lines, such as following format:

prompt1
prompt2
prompt3
prompt4
...

Use the following script to download and save the prompts from the Pick a pic dataset:

from datasets import load_dataset

dataset = load_dataset("yuvalkirstain/pickapic_v1_no_images")
captions = dataset['train']['caption']
file_path = # path to save as a .txt file
with open(file_path, 'w') as file:
    for caption in captions:
        file.write(caption + '\n')

You can then run the following snipet to convert it to a .tar file:

import webdataset as wds

txt_file_path = # Path for the input txt file
tar_file_name = # Path for the output tar file

with open(txt_file_path, 'r') as f:
    prompts = f.readlines()
prompts = [item.strip() for item in prompts]
sink = wds.TarWriter(tar_file_name)
for index, line in enumerate(prompts):
    sink.write({
        "__key__": "sample%06d" % index,
        "txt": line.strip(),
    })
sink.close()

Reward Model

Currently, we only have support for Pickscore reward model. Since Pickscore is a CLIP-based model, you can use the conversion script from NeMo to convert it from huggingface to NeMo.

DRaFT+ Training

To launch reward model training, you must have checkpoints for UNet and VAE of a trained Stable Diffusion model and a checkpoint for the Reward Model.

To run DRaFT+ on the terminal directly:

GPFS="/path/to/nemo-aligner-repo"
TRAIN_DATA_PATH="/path/to/train_dataset.tar"
UNET_CKPT="/path/to/unet_weights.ckpt"
VAE_CKPT="/path/to/vae_weights.bin"
RM_CKPT="/path/to/reward_model.nemo"

torchrun --nproc_per_node=2 ${GPFS}/examples/mm/stable_diffusion/train_sd_draftp.py \
   trainer.num_nodes=1 \
   trainer.devices=2 \
   model.micro_batch_size=1 \
   model.global_batch_size=8 \
   model.kl_coeff=0.2 \
   model.optim.lr=0.0001 \
   model.unet_config.from_pretrained=${UNET_CKPT} \
   model.first_stage_config.from_pretrained=${VAE_CKPT} \
   rm.model.restore_from_path=${RM_CKPT} \
   model.data.trian.webdataset.local_root_path=${TRAIN_DATA_PATH} \
   exp_manager.create_wandb_logger=False \
   exp_manager.explicit_log_dir=/results

To run DRaFT+ using Slurm. The script below uses 1 node:

#!/bin/bash
#SBATCH -A <<ACCOUNT NAME>>
#SBATCH -p <<PARTITION NAME>>
#SBATCH -N 4
#SBATCH -t 4:00:00
#SBATCH -J <<JOB NAME>>
#SBATCH --ntasks-per-node=8
#SBATCH --exclusive
#SBATCH --overcommit

GPFS="/path/to/nemo-aligner-repo"

GPFS="/path/to/nemo-aligner-repo"
TRAIN_DATA_PATH="/path/to/train_dataset.tar"
UNET_CKPT="/path/to/unet_weights.ckpt"
VAE_CKPT="/path/to/vae_weights.bin"
RM_CKPT="/path/to/reward_model.nemo"

PROJECT="<<WANDB PROJECT>>"

CONTAINER=<<<CONTAINER>>> # use the latest NeMo Training container, Aligner will work there

MOUNTS="--container-mounts=MOUNTS" # mounts

RESULTS_DIR="/path/to/result_dir"

OUTFILE="${RESULTS_DIR}/rm-%j_%t.out"
ERRFILE="${RESULTS_DIR}/rm-%j_%t.err"
mkdir -p ${RESULTS_DIR}

MOUNTS="--container-mounts=MOUNTS" # mounts

read -r -d '' cmd <<EOF
echo "*******STARTING********" \
&& echo "---------------" \
&& echo "Starting training" \
&& cd ${GPFS} \
&& export PYTHONPATH="${GPFS}:${PYTHONPATH}" \
&& export HYDRA_FULL_ERROR=1 \
&& python -u ${GPFS}/examples/nlp/gpt/train_reward_model.py \
   trainer.num_nodes=1 \
   trainer.devices=8 \
   model.micro_batch_size=2 \
   model.global_batch_size=16 \
   model.kl_coeff=0.2 \
   model.optim.lr=0.0001 \
   model.unet_config.from_pretrained=${UNET_CKPT} \
   model.first_stage_config.from_pretrained=${VAE_CKPT} \
   rm.model.restore_from_path=${RM_CKPT} \
   model.data.webdataset.local_root_path=${TRAIN_DATA_PATH} \
   exp_manager.explicit_log_dir=${RESULTS_DIR} \
   exp_manager.create_wandb_logger=True \
   exp_manager.wandb_logger_kwargs.name=${NAME} \
   exp_manager.wandb_logger_kwargs.project=${PROJECT}
EOF

srun -o $OUTFILE -e $ERRFILE --container-image=$CONTAINER $MOUNTS bash -c "${cmd}"
set +x

Note

For more info on DRaFT+ hyperparameters please see the model config file:

NeMo-Aligner/examples/mm/stable_diffusion/conf/draftp_sd.yaml

DRaFT+ Results

Once you have completed fine-tuning Stable Diffusion with DRaFT+, you can run inference on your saved model using the sd_infer.py and sd_lora_infer.py scripts from the NeMo codebase. The generated images with the fine-tuned model should have better prompt alignment and aesthetic quality.