Important

NeMo 2.0 is an experimental feature and currently released in the dev container only: nvcr.io/nvidia/nemo:dev. Please refer to the Migration Guide for information on getting started.

DreamBooth

Model Introduction

DreamBooth [MM-MODELS-DB2] is a fine-tuning technique for personalizing large text-to-image diffusion models such as Stable Diffusion. These models are powerful, but they cannot faithfully reproduce the specific subjects shown in a given reference set. With DreamBooth, only a few images of a subject are needed to fine-tune a pretrained text-to-image model so that it learns to bind a unique identifier to that subject. The identifier can then be used in prompts to synthesize fully novel photorealistic images of the subject in different scenes.

NeMo’s DreamBooth is built on the Stable Diffusion framework. Its architecture mirrors Stable Diffusion (refer to Model Configuration). The difference lies in the training process, which uses a different dataset and, when enabled, adds the prior preservation loss.

  • Prior Preservation Loss

    Fine-tuning large pretrained language models on specific tasks, or text-to-image diffusion models on a small dataset, often causes problems such as language drift and reduced output diversity. The idea behind the prior preservation loss is simple. The model is also supervised on samples that the original model generated for the subject's class, and the loss includes the discrepancy between the model-predicted noise and the target noise on these samples. The weight of this loss term is set by model.prior_loss_weight.

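import torch
# The batch stacks instance samples first and prior (regularization) samples second.
model_pred, model_pred_prior = torch.chunk(model_output, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Instance term: MSE between predicted and target noise on the subject images.
loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Prior term: the same MSE on the self-generated class images, scaled by prior_loss_weight.
prior_loss = torch.nn.functional.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
loss = loss + prior_loss * self.prior_loss_weight

  • Training Dataset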

    Unlike other NeMo multimodal models, NeMo’s DreamBooth does not require data stored in the webdataset format. A sample dataset is available at [MM-MODELS-DB1]. For each subject you want the model to learn, place its images (typically 3-5) in a folder and set model.data.instance_dir to that path. When training with the prior preservation loss, images generated by the original model are stored in a separate folder referenced by model.data.regularization_dir. NeMo’s DreamBooth implementation can generate these regularization images automatically.
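The prior preservation loss computation shown earlier can be traced by hand with the following framework-free sketch. It replaces torch tensors with flat Python lists so the arithmetic is easy to follow. Like the snippet, it assumes that the first half of the batch holds the instance samples and the second half holds the prior (regularization) samples. `prior_preservation_loss` and `mse` are hypothetical names used only for illustration. They are not NeMo functions.

```python
def mse(pred, target):
    """Mean squared error between two equal-length, non-empty lists of floats."""
    assert pred and len(pred) == len(target), "inputs must be non-empty and aligned"
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def prior_preservation_loss(model_output, target, prior_loss_weight):
    """Hypothetical sketch of the DreamBooth loss on flat lists.

    The first half of the batch is assumed to hold instance samples and the
    second half prior (regularization) samples, mirroring
    torch.chunk(x, 2, dim=0) on an evenly sized batch.
    """
    assert len(model_output) == len(target) and len(model_output) % 2 == 0
    half = len(model_output) // 2
    instance_loss = mse(model_output[:half], target[:half])  # subject images
    prior_loss = mse(model_output[half:], target[half:])     # self-generated class images
    return instance_loss + prior_loss * prior_loss_weight


# Instance MSE = 0.5, prior MSE = 0.5, weight 0.5 as in the example configuration.
print(prior_preservation_loss([0.0, 1.0, 2.0, 3.0], [0.0, 0.0, 2.0, 2.0], 0.5))  # 0.75
```

Setting prior_loss_weight to 0 leaves only the instance term. Larger values give the self-generated class images more influence on the gradient.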

Model Configuration

Please refer to Model Configuration for how to configure Stable Diffusion. Here we show DreamBooth-specific configurations.

Prior Preservation Loss

model:
  with_prior_preservation: False
  prior_loss_weight: 0.5
  train_text_encoder: False
  restore_from_path: /ckpts/nemo-v1-5-188000-ema.nemo # Only used to generate regularization images, so a .nemo checkpoint is required

  data:
    instance_dir: /datasets/instance_dir
    instance_prompt: a photo of a sks dog
    regularization_dir: /datasets/nemo_dogs
    regularization_prompt: a photo of a dog
    num_reg_images: 10
    num_images_per_prompt: 4
    resolution: 512
    center_crop: True

  • train_text_encoder: Determines whether the text encoder is fine-tuned alongside the U-Net.

  • with_prior_preservation: Controls whether regularization data and the prior preservation loss are used. If set to False, model.prior_loss_weight and model.restore_from_path are ignored. If set to True, the behavior depends on the number of images in model.data.regularization_dir:

    1. If there are fewer images than model.data.num_reg_images:

      • model.restore_from_path must point to a .nemo checkpoint so that the inference pipeline can generate regularization images.

      • model.data.num_images_per_prompt acts as the inference batch size. It sets the number of images generated in one pass and is limited by available GPU memory.

      • model.data.regularization_prompt is the text prompt the inference pipeline uses to generate images. It is usually model.data.instance_prompt without the unique token (for example, "a photo of a dog" instead of "a photo of a sks dog").

      • Once these parameters are set, the inference pipeline runs until the regularization directory contains the required number of images.

    2. If the count is equal to or greater than model.data.num_reg_images:

      • Training proceeds without calling the inference pipeline, and the parameters above are ignored.
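The decision described in the list above can be sketched as follows. This is a minimal, hypothetical illustration of the counting logic, not NeMo's code. `plan_regularization` and `strip_identifier` are names invented here. The sketch also assumes that each inference pass produces exactly num_images_per_prompt images, so the last pass may overshoot num_reg_images. This page does not say whether NeMo trims that final batch.

```python
import math


def plan_regularization(with_prior_preservation, existing_images,
                        num_reg_images, num_images_per_prompt):
    """Hypothetical sketch: number of inference passes needed before training.

    Returns 0 when the inference pipeline does not need to be called.
    """
    if not with_prior_preservation:
        # prior_loss_weight and restore_from_path are ignored in this case.
        return 0
    missing = num_reg_images - existing_images
    if missing <= 0:
        # Enough regularization images already exist, so skip inference.
        return 0
    # Each pass generates num_images_per_prompt images (the inference batch size).
    # Round up, because a partial batch still needs a full pass.
    return math.ceil(missing / num_images_per_prompt)


def strip_identifier(instance_prompt, identifier):
    """Hypothetical helper: derive a regularization prompt by dropping the
    unique identifier token from the instance prompt."""
    return " ".join(word for word in instance_prompt.split() if word != identifier)


# Values from the example configuration: num_reg_images=10, num_images_per_prompt=4.
print(plan_regularization(True, 0, 10, 4))    # 3
print(plan_regularization(True, 12, 10, 4))   # 0
print(strip_identifier("a photo of a sks dog", "sks"))  # a photo of a dog
```

The ceiling division is the reason num_images_per_prompt matters beyond memory. A smaller batch means more passes but less overshoot past num_reg_images.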

Training with Cached Latents

model:
    use_cached_latents: True

    data:
        num_workers: 4
        instance_dir: /datasets/instance_dir
        instance_prompt: a photo of a sks dog
        regularization_dir: /datasets/nemo_dogs
        regularization_prompt: a photo of a dog
        cached_instance_dir: #/datasets/instance_dir_cached
        cached_reg_dir: #/datasets/nemo_dogs_cached

  • use_cached_latents: Determines whether training encodes images on the fly or uses pre-computed (cached) latents.

  • cached_instance_dir:

    • If use_cached_latents is enabled and this directory is specified with latents in .pt format, training uses the latents instead of the original images.

    • If the cached directory is not provided, or the number of latent files does not match the number of original images, the Variational Autoencoder (VAE) computes the image latents before training and saves the results to disk.

  • cached_reg_dir: Follows the same logic as cached_instance_dir, applied to the regularization images and contingent on the model.with_prior_preservation setting.
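The caching rules above reduce to a small check per image directory. The sketch below is a hypothetical, standard-library-only illustration. `needs_latent_precompute` is an invented name, and the set of image extensions is an assumption. Comparing the count of `.pt` files with the count of image files is also an assumption about how "the number of latent files" is checked, not NeMo's exact implementation.

```python
import tempfile
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}  # assumed image extensions


def needs_latent_precompute(use_cached_latents, image_dir, cached_dir):
    """Hypothetical sketch: True if the VAE must encode latents before training.

    False means either encoding happens online (caching disabled) or the
    cached .pt files can be used as-is.
    """
    if not use_cached_latents:
        # Online encoding: images are encoded during training, no cache involved.
        return False
    if not cached_dir or not Path(cached_dir).is_dir():
        # No cache directory to read from, so latents must be computed first.
        return True
    num_images = sum(1 for p in Path(image_dir).iterdir()
                     if p.suffix.lower() in IMAGE_SUFFIXES)
    num_latents = sum(1 for p in Path(cached_dir).iterdir() if p.suffix == ".pt")
    # A count mismatch means the cache is stale or incomplete.
    return num_latents != num_images


with tempfile.TemporaryDirectory() as images, tempfile.TemporaryDirectory() as cache:
    for name in ("1.jpg", "2.jpg", "3.jpg"):
        Path(images, name).touch()
    print(needs_latent_precompute(True, images, cache))  # True: cache is empty
    for name in ("1.pt", "2.pt", "3.pt"):
        Path(cache, name).touch()
    print(needs_latent_precompute(True, images, cache))  # False: counts match
```

The same check would run once for the instance images and, when prior preservation is enabled, once more for the regularization images.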

Reference

MM-MODELS-DB1

Google. Dreambooth. 2023. URL: https://github.com/google/dreambooth/tree/main/dataset.

MM-MODELS-DB2

Nataniel Ruiz, Yuanzhen Li, Varun Jampani, Yael Pritch, Michael Rubinstein, and Kfir Aberman. Dreambooth: fine tuning text-to-image diffusion models for subject-driven generation. 2022. URL: https://arxiv.org/abs/2208.12242.