Warning: This post is still being written and is not complete, I just uploaded a draft.
This post is basically what I learned while watching this video by DeepFindr.
Diffusion models work by destroying an input gradually until it looks like noise and then recovering the input image from that. The forward process is hardcoded, and the reverse process is trainable.
In the reverse process, the task of the model is to predict the noise that was added in each step to the input image.
We need 3 things for training a diffusion model:

1. A noise scheduler that sequentially adds noise to the input (the forward process).
2. A model that predicts the noise in an image (here, a U-Net).
3. A way to encode the current timestep (positional embeddings).
{{< math.inline >}}
In simple words, we iteratively add noise to the image, where the amount of noise added per step depends on a parameter \(\beta_t\).
{{< /math.inline >}}
In fancy math terms, this is how we perform the Markov process:
$$
q(x_{1:T}|x_0) = \prod_{t=1}^{t=T}q(x_t|x_{t-1})
$$
{{< math.inline >}}
\(x_{1:T}\) = the set of samples where every subsequent item is noisier, starting from the original image. \(x_1\) is the input image after adding some noise for the first time (i.e. the first step) and \(x_T\) is the noisiest sample.
\(\prod_{t=1}^{T}q(x_t|x_{t-1})\) is the product of the noise samples for all values of \(t\) from 1 to \(T\).
{{< /math.inline >}}
Diving deeper into each noise sample {{< math.inline >}}\(q(x_t|x_{t-1})\){{< /math.inline >}}
First, let’s see how it’s defined:
$$
q(x_t|x_{t-1}) = N(x_t;\sqrt{1-\beta_t}x_{t-1}, \beta_tI)
$$
{{< math.inline >}}
\(\beta_t\) determines the variance of the noise added to the image at each step.
{{< /math.inline >}}
{{< math.inline >}}
\(x_{t-1}\) is the previous, less noisy image.
{{< /math.inline >}}
{{< math.inline >}}
\(\sqrt{1-\beta_t}\) scales the mean of the noise to be added. Thus one can say that the mean of our distribution is \(\sqrt{1-\beta_t}\) multiplied by \(x_{t-1}\) (for each pixel).
{{< /math.inline >}}
{{< math.inline >}}
\(I\) is the identity matrix.
{{< /math.inline >}}
{{< math.inline >}}
The sequence of such betas \(\beta_1, \beta_2, \ldots, \beta_T\) is known as the variance schedule. It determines how much noise we want to add at each of the time steps.
{{< /math.inline >}}
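To make this concrete, here is a minimal sketch (not from the post) of a single forward step applied to one image tensor; `single_forward_step`, `x_prev`, and `beta_t` are illustrative names:

```python
import torch

def single_forward_step(x_prev, beta_t):
    # Sample epsilon ~ N(0, I) with the same shape as the image
    noise = torch.randn_like(x_prev)
    # Mean is sqrt(1 - beta_t) * x_{t-1}, variance is beta_t * I
    return (1.0 - beta_t) ** 0.5 * x_prev + beta_t ** 0.5 * noise

x_0 = torch.rand(3, 64, 64)                    # a toy "image" in [0, 1]
x_1 = single_forward_step(x_0, beta_t=0.0001)  # a slightly noisier version of x_0
```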
Diving deeper into {{< math.inline >}}\(\beta_t\){{< /math.inline >}}
Let us imagine for a second that we have an image with a single pixel and try to understand what the equation above means:
{{< math.inline >}}
\(q(x_t|x_{t-1})\) = the value of the next pixel (q of \(x_t\) given \(x_{t-1}\))
{{< /math.inline >}}
{{< math.inline >}}
\(\sqrt{1-\beta_t}\,x_{t-1}\) is the mean of the distribution from which we would sample the next pixel.
{{< /math.inline >}}
{{< math.inline >}}
\(\beta_t\) is the variance.
{{< /math.inline >}}
{{< math.inline >}} So increasing \(\beta_t\) would result in the distribution shifting to the left (the mean \(\sqrt{1-\beta_t}\,x_{t-1}\) moves towards zero) and also becoming more flattened (w i d e r). Kind of like the blue distribution shown below. {{< /math.inline >}}
Beta determines how fast we converge towards a mean of zero, i.e. towards a standard Gaussian distribution. Beta increases linearly with each time step (from around 0.0001 to 0.02 over 200 steps).
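As a quick sanity check on those numbers, here is a small illustrative snippet (not part of the original code) showing the per-step mean scaling sqrt(1 - beta) and noise standard deviation sqrt(beta) at both ends of such a linear schedule:

```python
import torch

betas = torch.linspace(0.0001, 0.02, 200)  # linear variance schedule over 200 steps

# Early steps keep almost all of the previous image and add very little noise...
print((1 - betas[0]).sqrt(), betas[0].sqrt())    # ~0.99995, 0.01
# ...while later steps keep a bit less and add noticeably more noise.
print((1 - betas[-1]).sqrt(), betas[-1].sqrt())  # ~0.98995, ~0.1414
```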
{{< math.inline >}}
The neat thing about Gaussians is that the sum of Gaussians is also a Gaussian, which means it's pretty easy to pre-compute the noisy image at forward time-step \(t\).
Now, for convenience, we define a new variable \(\alpha_t = 1 - \beta_t\). Since beta is scaled up, alpha is scaled down at each step. You can think of alpha as the variable which determines how much information is conserved from the previous image in each time step.
The nice part is that we can just take the cumulative product of the alphas (\(\bar{\alpha_t} = \prod_{s=1}^{t}\alpha_s\)) and then compute the image at a forward step \(t\) without having to calculate all the way until step \(t-1\) first. This way, we can re-define the noise sampling as follows:
{{< /math.inline >}}
$$
q(x_t|x_0) = N(x_t;\sqrt{\bar{\alpha_t}}x_{0}, (1 - \bar{\alpha_t})I)
$$
{{< math.inline >}} Notice how this distribution depends only on \(x_0\) and not on \(x_{t-1}\), yet it gives the noisy pixel value at time step \(t\). {{< /math.inline >}}
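Here is a small numeric check (my own sketch, not from the post) that chaining the per-step scalings really gives the closed-form coefficient:

```python
import torch

T = 200
betas = torch.linspace(0.0001, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

t = 150
# Applying the per-step mean scaling sqrt(1 - beta) for t steps in a row...
step_by_step = torch.prod(torch.sqrt(1.0 - betas[:t]))
# ...gives the same factor as the closed-form sqrt(alpha_bar_t).
closed_form = torch.sqrt(alphas_cumprod[t - 1])
print(step_by_step, closed_form)  # equal up to floating point error
```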
I’ll try to explain things line-by-line:
```python
import torch
import torch.nn.functional as F

def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    # Interpolates between two values over a pre-defined number of timesteps.
    # Returns a tensor of betas.
    return torch.linspace(start, end, timesteps)

def get_index_from_list(vals, t, x_shape):
    """
    Returns a specific index t of a passed list of values vals
    while considering the batch dimension.
    """
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

def forward_diffusion_sample(x_0, t, device="cpu"):
    # Takes the input image and the timestep number t as input
    # and returns its noisy version at timestep t.
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )
    # mean + variance
    return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
        + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)
```
The return statement in the above snippet is equivalent to:
$$
x_t = \sqrt{\bar{\alpha_t}}x_0 + \sqrt{1-\bar{\alpha_t}}\epsilon
$$
{{< math.inline >}} Where \(\epsilon\) is the noise added into the image. {{< /math.inline >}}
Then, we pre-compute the variables for convenience:
```python
T = 200  # number of timesteps, matching the ~200-step schedule mentioned above
betas = linear_beta_schedule(timesteps=T)

alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
```
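With these tensors in place, a hypothetical usage of `forward_diffusion_sample` could look like this (batch size and image shape are made up for illustration):

```python
x_0 = torch.rand(4, 3, 64, 64) * 2 - 1           # a fake batch of images scaled to [-1, 1]
t = torch.randint(0, T, (4,), dtype=torch.long)  # a random timestep for each image
x_noisy, noise = forward_diffusion_sample(x_0, t)
print(x_noisy.shape, noise.shape)                # both: torch.Size([4, 3, 64, 64])
```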
The authors proposed to use a U-Net, which is an image-to-image model. Here the input is the noisy image, and the output is the "noise" predicted by the model.
{{< math.inline >}} So if we have an input \(x_t = x_0 + \epsilon\), then the model's output would ideally be \(\epsilon\) or something close to it. {{< /math.inline >}}
Note that we also have to let the model know which timestep the input image is at; this is done with the help of positional embeddings.
```python
import math
from torch import nn

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
```
Intuition behind SinusoidalPositionEmbeddings
(btw this also applies to transformers in general)
To understand this, first let’s take a look at how binary numbers work. You might be able to notice here that the right-most bit (red) oscillates between 0 and 1 the fastest and the subsequent ones (blue, green, yellow) oscillate in slower and slower rates.
Now if we use these values as positional encodings, the last 2 bits are good at giving fine-grained relative positions between 2 positions which are close to each other. While the larger bits are good at encoding the position at a larger scale. This way, the model would get both the overall location of each token but would also be able to distinguish between tokens which are very close to each other.
But in real life, binary numbers take up lots of space. So instead we use something fancier: sinusoids.
In the figure above, you can see how the first column (left) oscillates the fastest and every column after that one oscillates slower and slower. These are sinusoidal positional embeddings with a depth of 128.
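To get a feel for what the module above produces, here is a small hypothetical usage (the dimension of 32 is arbitrary):

```python
time_emb = SinusoidalPositionEmbeddings(dim=32)
t = torch.arange(200)    # all 200 timesteps at once
emb = time_emb(t)
print(emb.shape)         # torch.Size([200, 32]): one 32-dim embedding per timestep
```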
I will not go into depth about the U-Net architecture itself, because it can be and has been replaced with other models for diffusion. You can just think of it as a model which takes an image as an input and is supposed to predict the noise in the input (which gets subtracted).
The important thing to note here is that we do not concatenate the time embeddings with the data; instead, we add them right after passing the input through the first conv layer of a U-Net block.
```python
class Block(nn.Module):
    """
    A single building block from the U-Net
    """
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2*in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        # First Conv
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        # Extend last 2 dimensions
        time_emb = time_emb[(..., ) + (None, ) * 2]
        # This is where we add in the time embedding:
        # it combines time step and image information
        h = h + time_emb
        # Second Conv
        h = self.bnorm2(self.relu(self.conv2(h)))
        # Down or Upsample
        return self.transform(h)
```
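A quick hypothetical usage of `Block`, with made-up shapes, showing where the time embedding enters and how the default (downsampling) path halves the spatial resolution:

```python
block = Block(in_ch=64, out_ch=128, time_emb_dim=32)
x = torch.randn(4, 64, 32, 32)   # a batch of 64-channel feature maps
t_emb = torch.randn(4, 32)       # e.g. the (projected) sinusoidal timestep embeddings
out = block(x, t_emb)
print(out.shape)                 # torch.Size([4, 128, 16, 16]): the stride-2 conv halves H and W
```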
The loss function is basically the L1 loss between the predicted noise and the original noise:
```python
def get_loss(model, x_0, t):
    """
    model: U-Net model through which we pass the image
    x_0: original input image
    t: a specific time step

    functions within:
    * forward_diffusion_sample: returns the noisy version of the original image
      and the noise that was added for time step t
    * model.__call__: takes the noisy image and the time step as input and tries
      to predict the noise that was added for time step t
    """
    x_noisy, noise = forward_diffusion_sample(x_0, t, device)
    noise_pred = model(x_noisy, t)
    return F.l1_loss(noise, noise_pred)
```
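Putting the pieces together, a bare-bones training loop might look roughly like this; `model` (some U-Net built from the blocks above) and `dataloader` are placeholders I'm assuming exist:

```python
from torch.optim import Adam

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
optimizer = Adam(model.parameters(), lr=1e-3)

for epoch in range(100):
    for batch in dataloader:
        optimizer.zero_grad()
        # Pick a random timestep for every image in the batch
        t = torch.randint(0, T, (batch.shape[0],), device=device).long()
        loss = get_loss(model, batch.to(device), t)
        loss.backward()
        optimizer.step()
```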
Sampling works in the opposite direction: given the current noisy image at timestep t, the model predicts the noise, we subtract the (appropriately scaled) prediction to get the mean of the previous, less noisy image, and, unless we are at the last step (t = 0), we add back a bit of fresh noise scaled by the posterior variance:
```python
@torch.no_grad()
def sample_timestep(x, t):
    """
    Calls the model to predict the noise in the image and returns
    the denoised image.
    Applies noise to this image, if we are not in the last step yet.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)

    if t == 0:
        return model_mean
    else:
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise
```
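To actually generate an image, here is a rough sketch of the full sampling loop (assuming `model`, `T`, and `device` are defined as before; the image size is a placeholder): start from pure noise and call `sample_timestep` from T - 1 down to 0.

```python
@torch.no_grad()
def sample_image(img_size=64):
    # Start from pure Gaussian noise and denoise it step by step
    img = torch.randn(1, 3, img_size, img_size, device=device)
    for i in reversed(range(T)):
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(img, t)
    return img
```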
[1] - DeepFindr’s video