Direct Preference Optimization for Diffusion Models

DPO aims to overcome the primary drawbacks of RLHF: its instability and its reliance on a reward model trained on human preference data.

In this paper, the authors apply DPO to diffusion models by training the model to steer its denoising direction towards preferred images and away from non-preferred ones. For this objective, they use the Pick-a-Pic v2 dataset.

I’ll be using this script as a reference. Note that in the aforementioned paper the authors fine-tune the entire model, whereas the script fine-tunes a LoRA adapter on top of it.
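To make the LoRA part concrete, here is a minimal sketch of attaching an adapter to the UNet with peft. The model ID, rank, and target modules are my own illustrative choices, not values taken from the script:

    import torch
    from diffusers import StableDiffusionPipeline
    from peft import LoraConfig

    # Load a base pipeline (model ID chosen for illustration).
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )
    unet = pipe.unet
    unet.requires_grad_(False)  # freeze the base weights

    # Attach a LoRA adapter to the attention projections (rank/targets assumed).
    lora_config = LoraConfig(
        r=8,
        lora_alpha=8,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )
    unet.add_adapter(lora_config)  # only the LoRA parameters get trained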

Dataset

The Pick-a-Pic v2 dataset was collected by showing users two images generated from the same prompt and asking them to pick the better one (with an optional neutral response). The image that was not preferred is eliminated and replaced by another image generated from the same prompt. The user also got the option to switch to a different prompt and start afresh.

It has the following relevant columns:

  1. caption: the prompt used to generate both images
  2. jpg_0: first image shown to the user
  3. jpg_1: second image shown to the user
  4. label_0: 1 if the user picked jpg_0 as the better image, else 0
  5. label_1: 1 if the user picked jpg_1 as the better image, else 0

One can imagine what a single batch of images would look like (without shuffling the dataset):

{
    "images": [
        <caption_0_jpg_0>,   // winning image
        <caption_1_jpg_0>,
        <caption_0_jpg_1>,
        <caption_1_jpg_1>   // winning image
    ],
    "labels": [1,0,0,1]
}
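To get a feel for the raw data, something like the following should work; the dataset ID and the streaming flag are assumptions based on the public Hugging Face release of Pick-a-Pic v2:

    import io
    from datasets import load_dataset
    from PIL import Image

    # Stream the dataset to avoid downloading all of it up front
    # (dataset ID assumed to be the public Pick-a-Pic v2 release).
    dataset = load_dataset("yuvalkirstain/pickapic_v2", split="train", streaming=True)

    row = next(iter(dataset))
    print(row["caption"], row["label_0"], row["label_1"])

    # jpg_0 / jpg_1 are stored as raw JPEG bytes.
    img_0 = Image.open(io.BytesIO(row["jpg_0"]))
    img_1 = Image.open(io.BytesIO(row["jpg_1"]))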

Basics

This is a tiny guide to how latent diffusion models are trained.

Diffusion Breakdown

{{< math.inline >}} Add noise \(\epsilon\) to the latent of the original image. The amount of noise is proportional to the timestep. {{< /math.inline >}}

{{< math.inline >}} Feed the noisy latent into a model. The model tries to predict the noise, which gives us \(\epsilon_{pred}\). {{< /math.inline >}}

Finally, the generated sample is the noisy sample minus the predicted noise.
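As a rough sketch of these three steps in PyTorch (the shapes, the noise schedule, and the commented-out `unet` call are placeholders of my own, not the exact script code):

    import torch

    # Toy shapes: a batch of 4 latents with 4 channels at 64x64 resolution.
    latents = torch.randn(4, 4, 64, 64)
    timesteps = torch.randint(0, 1000, (4,))

    # Noise schedule: alphas_cumprod[t] shrinks as t grows, so later timesteps get more noise.
    betas = torch.linspace(1e-4, 0.02, 1000)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    a = alphas_cumprod[timesteps].view(-1, 1, 1, 1)

    # 1. Add noise proportional to the timestep.
    noise = torch.randn_like(latents)
    noisy_latents = a.sqrt() * latents + (1 - a).sqrt() * noise

    # 2. The model predicts the noise (`unet` is a placeholder for the denoiser).
    # noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=prompt_embeds).sample

    # 3. Recover an estimate of the clean latents by removing the predicted noise.
    # latents_pred = (noisy_latents - (1 - a).sqrt() * noise_pred) / a.sqrt()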

Single Training Step

  1. First, we extract all the latent vectors for each image in the batch using a pre-trained VAE. Let’s call them latents.

  2. Then we add some noise to the latents.

    noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)

    Something interesting is going on here: chunk(2)[0] keeps the noise for only the first half of the batch, and repeat(2, 1, 1, 1) copies it onto the second half (see appendix: 1 for a 1-D example of chunk and repeat).

    This makes sure that the winning and losing images generated from the same caption have the same noise added to their latents.

  3. The model takes the noisy latents and the prompt embeddings as input and predicts the noise present in the latents.

  4. MSE loss (model_losses) is computed between the predicted noise and the actual noise without any reduction, then averaged over every dimension except the batch dimension. We get a tensor of shape (batch_size,) containing the loss corresponding to each batch item.

  5. model_losses is divided into two chunks, one containing the losses for all winning samples (model_losses_w) and another containing the losses for all losing samples (model_losses_l).

  6. Then we calculate a term model_diff = model_losses_w - model_losses_l. Note that if model_diff is minimized, we guide the model’s denoising process towards generating winning samples and away from generating losing samples. (A consolidated sketch of steps 2-6 appears right after this list.)
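Here is a hedged sketch of steps 2-6 in one place, assuming latents, prompt_embeds, unet, and noise_scheduler from the surrounding steps; the batch layout (winners in the first half, losers in the second) and variable names are illustrative rather than copied verbatim from the script:

    import torch
    import torch.nn.functional as F

    bsz = latents.shape[0]  # assumed layout: first half winners, second half losers

    # Step 2: same noise (and timestep) for the winning and losing latents of each pair.
    noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)
    timesteps = torch.randint(0, 1000, (bsz // 2,), device=latents.device).repeat(2)
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # Step 3: predict the noise from the noisy latents and prompt embeddings.
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=prompt_embeds).sample

    # Step 4: per-sample MSE, averaged over everything except the batch dimension.
    model_losses = F.mse_loss(model_pred.float(), noise.float(), reduction="none")
    model_losses = model_losses.mean(dim=list(range(1, model_losses.dim())))  # shape: (bsz,)

    # Steps 5-6: split into winner/loser halves and take the difference.
    model_losses_w, model_losses_l = model_losses.chunk(2)
    model_diff = model_losses_w - model_losses_l  # smaller means the winners are favoured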

The diagram shown below is a visualization of model_losses_w and model_losses_l {{< math.inline >}} as \(loss_w\) and \(loss_l\) respectively, with \(\epsilon\) as the noise added to the image latents. {{< /math.inline >}}

(Figure: Diffusion Breakdown)
  1. We temporarily disable the LoRA adapters in the model, obtain the predicted noise from the original pre-trained model, and calculate ref_diff, which is the equivalent of model_diff for the original model.

  2. The final loss that is to be minimized is calculated as follows:

    inside_term = scale_term * (ref_diff - model_diff)
    loss = -1 * F.logsigmoid(inside_term.mean())


    Let {{< math.inline >}}\(f(x) = -\log(\mathrm{sigmoid}(x))\){{< /math.inline >}}. Then:

    {{< math.inline >}}\(\lim_{x \to -\infty} f(x) = +\infty\){{< /math.inline >}}

    {{< math.inline >}}\(\lim_{x \to +\infty} f(x) = 0\){{< /math.inline >}}

    Hence, in order to minimize the loss, we have to maximize the inside_term (for a visualization, see appendix: 2). This can be done by:

    1. Maximizing ref_diff (but that is not possible, since the original model is frozen)
    2. Minimizing model_diff, i.e., steering the model’s denoising process towards the winning samples and away from the losing samples (a sketch of this reference pass and the final loss follows below)
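Below is a hedged sketch of the reference pass and the final loss, reusing the names from the steps above; the adapter-toggling calls and the beta value are my assumptions about how this is typically done, not verbatim script code:

    import torch
    import torch.nn.functional as F

    beta = 2500.0          # DPO temperature; treated here as a training hyperparameter
    scale_term = beta / 2

    # Reference pass: same noisy latents and timesteps, but with the LoRA adapters
    # switched off so the frozen base model makes the prediction.
    with torch.no_grad():
        unet.disable_adapters()
        ref_pred = unet(noisy_latents, timesteps, encoder_hidden_states=prompt_embeds).sample
        ref_losses = F.mse_loss(ref_pred.float(), noise.float(), reduction="none")
        ref_losses = ref_losses.mean(dim=list(range(1, ref_losses.dim())))
        ref_losses_w, ref_losses_l = ref_losses.chunk(2)
        ref_diff = ref_losses_w - ref_losses_l
        unet.enable_adapters()

    # Final DPO loss: pushes model_diff below ref_diff, i.e. makes the fine-tuned
    # model favour the winning samples more strongly than the reference model does.
    inside_term = scale_term * (ref_diff - model_diff)
    loss = -1 * F.logsigmoid(inside_term.mean())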

Appendix

  1. Let’s imagine x to be a 1d tensor: [4, 5, 6, 7], then the operation x.chunk(2)[0].repeat(2) would give us the following:

    >>> x = torch.tensor([4, 5, 6, 7])
    >>> x.chunk(2)
    (tensor([4, 5]), tensor([6, 7]))
    >>> x.chunk(2)[0]
    tensor([4, 5])
    >>> x.chunk(2)[0].repeat(2)
    tensor([4, 5, 4, 5])
  2. This is what {{< math.inline >}}\(-\log(\mathrm{sigmoid}(x))\){{< /math.inline >}} looks like:
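    If you want to reproduce a plot of this shape yourself, a few lines of matplotlib along these lines should do it:

    import torch
    import torch.nn.functional as F
    import matplotlib.pyplot as plt

    # -log(sigmoid(x)) blows up for large negative x and decays to 0 for large positive x.
    x = torch.linspace(-10, 10, 500)
    y = -F.logsigmoid(x)

    plt.plot(x, y)
    plt.xlabel("x")
    plt.ylabel("-log(sigmoid(x))")
    plt.show()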