DPO aims to overcome the primary drawbacks of RLHF: its unstable nature and its dependence on a reward model trained on human preference data.
In this paper, the authors apply DPO to diffusion models by training the model to steer its denoising directions towards preferred images and away from non-preferred ones. For this objective, they use the Pick-a-Pic v2 dataset.
I’ll be using this script as a reference. Note that in the aforementioned paper they fine-tune the entire model, but in the script we see that they fine-tune a LoRA adapter on the model.
The Pick-a-Pic v2 dataset was collected by showing users 2 images and asking them to pick the better one (with an optional neutral response). The losing image is eliminated and replaced by another image generated from the same prompt. The user also has the option to switch to a different prompt and start afresh.
It has the following relevant columns:

- `caption`: the prompt that was used to generate the images
- `jpg_0`: first image shown to the user
- `jpg_1`: second image shown to the user
- `label_0`: is `1` if the user picked `jpg_0` as the better image
- `label_1`: is `1` if the user picked `jpg_1` as the better image

One can imagine how a single batch of images would look (without shuffling the dataset):
{
"images": [
<caption_0_jpg_0>, // winning image
<caption_1_jpg_0>,
<caption_0_jpg_1>,
<caption_1_jpg_1> // winning image
],
"labels": [1,0,0,1]
}
This is a tiny guide to how latent diffusion models are trained.
{{< math.inline >}} Add noise \(\epsilon\) to the latent of the original image. The amount of noise is proportional to the timestep. {{< /math.inline >}}
{{< math.inline >}} Feed the noisy latent into a model. The model tries to predict the noise, which gives us \(\epsilon_{pred}\). {{< /math.inline >}}
Finally, the generated sample is the noisy sample minus the predicted noise.
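The three steps above can be sketched in heavily simplified PyTorch. This is just an illustration: the shapes are invented, the linear noising schedule below is made up (a real scheduler uses `alphas_cumprod`), and the "model" is faked with a perfect prediction.

```python
import torch
import torch.nn.functional as F

# toy stand-ins: real latents come from a VAE encoder, and a UNet predicts the noise
latents = torch.randn(4, 4, 64, 64)        # (batch, channels, height, width)
timesteps = torch.randint(0, 1000, (4,))   # one timestep per sample
noise = torch.randn_like(latents)

# step 1: add timestep-proportional noise (simplified; not the script's scheduler)
t = (timesteps.float() / 1000).view(-1, 1, 1, 1)
noisy_latents = (1 - t).sqrt() * latents + t.sqrt() * noise

# step 2: the model would predict the noise from (noisy_latents, timesteps,
# prompt embeddings); here we fake a perfect prediction to show the target
noise_pred = noise

# training loss: how close the predicted noise is to the actual noise
loss = F.mse_loss(noise_pred, noise)
```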
First, we extract all the latent vectors for each image in the batch using a pre-trained VAE. Let’s call them `latents`.
Then we add some noise to the latents.
`noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)`
Something interesting is going on here. Let’s take a deeper look into each part.
- `.chunk(2)[0]`: divides the tensor into 2 chunks (along dim 0) and then selects the first one (see appendix: 1).
- `.repeat(2, 1, 1, 1)`: repeats the tensor 2 times along dim 0 and leaves every other dim unrepeated.

This makes sure that images generated from the same caption have the same noise added to their latents.
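As a quick sanity check (shapes here are invented and much smaller than real latents), the `chunk`-then-`repeat` pattern does give both halves of the batch identical noise:

```python
import torch

latents = torch.randn(4, 4, 8, 8)  # batch of 4: two captions x two images each
noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)

# the second half of the batch gets exactly the same noise as the first half,
# so each pair of images from the same caption shares its noise
print(torch.equal(noise[:2], noise[2:]))  # True
```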
The model takes the noisy latents and the prompt embeddings as input and predicts the noise present in the latents.
MSE loss (`model_losses`) is computed between the predicted noise and the actual noise, with no reduction and then averaged over every dimension except the batch dimension. We get a tensor of shape `batch_size` containing the MSE loss corresponding to each batch item.
`model_losses` is divided into two chunks: one containing the losses for all winning samples (`model_losses_w`) and another containing the losses for all losing samples (`model_losses_l`).
Then we calculate a term `model_diff = model_losses_w - model_losses_l`. Note that if `model_diff` is minimized, we guide the model’s denoising process towards generating winning samples and away from generating losing samples.
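A minimal sketch of this step, with made-up tensors standing in for the UNet's predictions (the variable names follow the script; the values and shapes do not):

```python
import torch
import torch.nn.functional as F

# pretend per-pixel predictions and targets for a batch of 4 latents
noise_pred = torch.randn(4, 4, 8, 8)
noise = torch.randn(4, 4, 8, 8)

# per-sample MSE: no reduction, then average over all non-batch dims
model_losses = F.mse_loss(noise_pred, noise, reduction="none").mean(dim=[1, 2, 3])

# first half of the batch holds the winning samples, second half the losing ones
model_losses_w, model_losses_l = model_losses.chunk(2)
model_diff = model_losses_w - model_losses_l  # minimizing this prefers winners
```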
The diagram shown below is a visualization of `model_losses_w` and `model_losses_l` {{< math.inline >}} as \(loss_w\) and \(loss_l\) respectively, with \(\epsilon\) as the noise added to the image latents. {{< /math.inline >}}
We temporarily disable the LoRA adapters in the model, obtain the predicted noise from the original pre-trained model, and calculate `ref_diff`, which is equivalent to `model_diff` but for the original model.
The final loss that is to be minimized is calculated as follows:
`inside_term = scale_term * (ref_diff - model_diff)`
`loss = -1 * F.logsigmoid(inside_term.mean())`
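With toy numbers this computation looks as follows; the `scale_term` value and the diff values below are placeholders, not what the script actually produces:

```python
import torch
import torch.nn.functional as F

# made-up per-pair diffs: the fine-tuned model prefers winners a bit more
# strongly than the frozen reference model does
model_diff = torch.tensor([-0.2, -0.1])
ref_diff = torch.tensor([0.0, 0.0])
scale_term = 1.0  # placeholder; in the script this comes from a hyperparameter

inside_term = scale_term * (ref_diff - model_diff)  # positive in this toy case
loss = -1 * F.logsigmoid(inside_term.mean())
# inside_term.mean() = 0.15, so loss = -log(sigmoid(0.15)) ≈ 0.62
```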
{{< math.inline >}} Let \(f(x) = -\log(\mathrm{sigmoid}(x))\). Then: {{< /math.inline >}}

{{< math.inline >}} \(\lim_{x \to -\infty} f(x) = +\infty\) {{< /math.inline >}}

{{< math.inline >}} \(\lim_{x \to +\infty} f(x) = 0\) {{< /math.inline >}}
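These limits are easy to check numerically (a quick sanity check, not code from the script):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-20.0, 0.0, 20.0])
f = -F.logsigmoid(x)
# f(-20) ≈ 20 (grows without bound), f(0) = log(2) ≈ 0.693, f(20) ≈ 0
```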
Hence, in order to minimize the `loss`, we have to maximize the `inside_term` (for a visualization, see appendix: 2). This can be done by:
- increasing `ref_diff` (but that is not possible since the original model is frozen)
- decreasing `model_diff`, i.e. steering the model’s denoising process towards the winning samples and away from the losing samples

Let’s imagine `x` to be a 1d tensor, `[4, 5, 6, 7]`; then the operation `x.chunk(2)[0].repeat(2)` would give us the following:
>>> x = torch.tensor([4, 5, 6, 7])
>>> x.chunk(2)
(tensor([4, 5]), tensor([6, 7]))
>>> x.chunk(2)[0]
tensor([4, 5])
>>> x.chunk(2)[0].repeat(2)
tensor([4, 5, 4, 5])
This is how {{< math.inline >}} \(-\log(\mathrm{sigmoid}(x))\) {{< /math.inline >}} looks: