发布日期:2023/05/18
主页地址:http://myhz0606.com/article/ddpm
1 从直觉上理解DDPM
在详细推到公式之前,我们先从直觉上理解一下什么是扩散
对于常规的生成模型,如GAN,VAE,它直接从噪声数据生成图像,我们不妨记噪声数据为\(z\),其生成的图片为\(x\)
对于常规的生成模型:
学习一个解码函数(即我们需要学习的模型)p,实现 \(p(z)=x\)
\[z \stackrel{p} \longrightarrow x \tag{1} \]
常规方法只需要一次预测即能实现噪声到目标的映射,虽然速度快,但是效果不稳定。
常规生成模型的训练过程(以VAE为例)
\[x \stackrel{q} \longrightarrow z \stackrel{p} \longrightarrow \widehat{x} \tag{2} \]
对于diffusion model
它将噪声到目标的过程进行了多步拆解。不妨假设一共有\(T+1\)个时间步,第\(T\)个时间步 \(x_T\)是噪声数据,第0个时间步的输出是目标图片\(x_0\)。其过程可以表述为:
\[z = x_T \stackrel{p} \longrightarrow x_{T-1} \stackrel{p} \longrightarrow \cdots \stackrel{p} \longrightarrow x_{1} \stackrel{p} \longrightarrow x_0 \tag{3} \]
对于DDPM它采用的是一种自回归式的重建方法,每次的输入是当前的时刻及当前时刻的噪声图片。也就是说它把噪声到目标图片的生成分成了T步,这样每一次的预测相当于是对残差的预测。优势是重建效果稳定,但速度较慢。
训练整体pipeline包含两个过程
2 diffusion pipeline
2.1前置知识:
高斯分布的一些性质
(1)如果\(X \sim \mathcal{N}(\mu, \sigma^2)\),且\(a\)与\(b\)是实数,那么\(aX+b \sim \mathcal{N}(a\mu+b, (a\sigma)^2)\)
(2)如果\(X \sim \mathcal{N}(\mu(x), \sigma^2(x)) ,Y \sim \mathcal{N}(\mu(y), \sigma^2(y)),\)且\(X,Y\)是统计独立的正态随机变量,则它们的和也满足高斯分布(高斯分布可加性).
\[X+Y \sim \mathcal{N}(\mu(x)+\mu{(y), \sigma^2(x) + \sigma^2(y)}) \\ X-Y \sim \mathcal{N}(\mu(x)-\mu{(y), \sigma^2(x) + \sigma^2(y)}) \tag{4} \]
均值为\(\mu\)方差为\(\sigma\)的高斯分布的概率密度函数为
\[\begin{align*} f(x) &= \frac{1}{\sqrt{2\pi} \sigma } \exp \left ({- \frac{(x - \mu)^2}{2\sigma^2}} \right) \\ &= \frac{1}{\sqrt{2\pi} \sigma } \exp \left[ -\frac{1}{2} \left( \frac{1}{\sigma^2}x^2 - \frac{2\mu}{\sigma^2}x + \frac{\mu^2}{\sigma^2} \right ) \right] \tag{5} \end{align*} \]
2.2 加噪过程
1 前向过程:将图片数据映射为噪声
每一个时刻都要添加高斯噪声,后一个时刻都是由前一个时刻加噪声得到。(其实每一个时刻加的噪声就是训练所用的标签)。即
\[x_0 \stackrel{q} \longrightarrow x_1 \stackrel{q} \longrightarrow x_{2} \stackrel{q} \longrightarrow \cdots \stackrel{q} \longrightarrow x_{T-1} \stackrel{q} \longrightarrow x_T=z \tag{6} \]
下面我们详细来看
记\(\beta_t = 1 - \alpha_t,\beta_t\)随\(t\)的增加而增大(论文中[2]从0.0001 -> 0.02) (这是因为一开始加一点噪声就很明显,后面需要增大噪声的量才明显).DDPM将加噪声过程建模为一个马尔可夫过程\(q(x_{1:T}|x_0):= \prod \limits_{t=1}^Tq(x_t|x_{t-1}) ,\)其中\(q(x_t|x_{t-1}):=\mathcal{N}(x_t; \sqrt{\alpha_t}x_{t-1}, (1 - \alpha_t) \textbf{I})\)
\[\begin{align*} x_t &= \sqrt{\alpha_t}x_{t-1} + \sqrt{(1 - \alpha_t)}z_t \\ &= \sqrt{\alpha_t}x_{t-1} + \sqrt{\beta_t}z_t \tag{7} \end{align*} \]
\(x_t\)为在t时刻的图片,当\(t=0\)时为原图;\(z_t\)为在t时刻所加的噪声,服从标准正态分布\(z_t \sim \mathcal{N}(0, \textbf{I});\)\(\alpha_t\)是常数,是自己定义的变量;从上式可见,随着\(T\)增大,\(x_t\)越来越接近纯高斯分布.
同理:
\[x_{t-1} = \sqrt{\alpha_{t-1}}x_{t-2} + \sqrt{1 - \alpha_{t-1}}z_{t-1} \tag{8} \]
将式(8)代入式(7)可得:
\[\begin{align*} x_t &= \sqrt{\alpha_t} (\sqrt{\alpha_{t-1}}x_{t-2} + \sqrt{1 - \alpha_{t-1}}z_{t-1}) + \sqrt{1 - \alpha_t}z_t \\ &= \sqrt{\alpha_t \alpha_{t-1}}x_{t-2} + (\sqrt{\alpha_t (1 - \alpha_{t-1})} z_{t-1} + \sqrt{1 - \alpha_t}z_t) \tag{9} \end{align*} \]
由于\(z_{t-1}\)服从均值为0,方差为1的高斯分布(即标准正态分布),根据定义\(\sqrt{\alpha_t (1 - \alpha_{t-1})} z_{t-1}\)服从的是均值为0,方差为\(\alpha_t (1 - \alpha_{t-1})\)的高斯分布.即\(\sqrt{\alpha_t (1 - \alpha_{t-1})} z_{t-1} \sim \mathcal{N}(0, \alpha_t (1 - \alpha_{t-1})\textbf{I})\).同理可得\(\sqrt{1 - \alpha_t}z_t \sim \mathcal{N}(0, (1 - \alpha_t)\textbf{I})\).则**(高斯分布可加性,可以通过定义推得,不赘述)**
\[ (\sqrt{\alpha_t (1 - \alpha_{t-1})} , z_{t-1} + \sqrt{1 - \alpha_t}z_t) \sim \mathcal{N}(0, \alpha_t (1 - \alpha_{t-1}) + 1 - \alpha_t) = \mathcal{N}(0, 1 - \alpha_t \alpha_{t-1}) \tag{10} \]
我们不妨记\(\overline{z}{t-2} \sim \mathcal{N}(0, \textbf{I}),\)则\(\sqrt{1 - \alpha_t \alpha{t-1}} \overline{z}{t-2} \sim \mathcal{N}(0, (1 - \alpha_t \alpha{t-1})\textbf{I})\)则式(10)最终可改写为
\[x_t = \sqrt{\alpha_t \alpha_{t-1}} x_{t-2} + \sqrt{1 - \alpha_t \alpha_{t-1}} \overline{z}_{t-2} \tag{11} \]
通过递推,容易得到
\[\begin{align*} x_t &= \sqrt{\alpha_t \alpha_{t-1} \cdots \alpha_1} x_0 + \sqrt{1 - \alpha_t \alpha_{t-1} \dots \alpha_1} \overline{z}0 \\ &= \sqrt{\prod{i=1}^{t} {\alpha_i}}x_0 + \sqrt{1 - \prod_{i=1}^{t} {\alpha_i}} \overline {z}0 \\ &\stackrel{\mathrm{令} \overline{\alpha}{t} = \prod_{i=1}^{t} {\alpha_i}} = \sqrt{\overline{\alpha}{t}}x_0+\sqrt{1 - \overline{\alpha}{t}}\overline{z}_{0} \tag{12} \end{align*} \]
其中\(\overline{z}_{0} \sim \mathcal{N}(0, \mathrm{I}),x_0\)为原图.从式(13)可见,**我们可以从\(x_0\)得到任意时刻的\(x_t\)的分布,**而无需按照时间顺序递推!这极大提升了计算效率.
\[\begin{align*} q(x_t|x_0) &= \mathcal{N}(x_t; \mu{(x_t, t)},\sigma^2{(x_t, t)}{}\textbf{I}) \\ &= \mathcal{N}(x_t; \sqrt{\overline{\alpha}{t}}x_0,(1 - \overline{\alpha}{t})\textbf{I}) \tag{13} \end{align*} \]
⚠️加噪过程是确定的,没有模型的介入. 其目的是制作训练时标签
2.3 去噪过程
给定\(x_T\)如何求出\(x_0\)呢?直接求解是很难的,作者给出的方案是:我们可以一步一步求解.即学习一个解码函数\(p\),这个\(p\)能够知道\(x_{t}\)到\(x_{t-1}\)的映射规则.如何定义这个\(p\)是问题的关键.有了\(p\),只需从\(x_{t}\)到\(x_{t-1}\)逐步迭代,即可得出\(x_0\).
\[z = x_T \stackrel{p} \longrightarrow x_{T-1} \stackrel{p} \longrightarrow \cdots \stackrel{p} \longrightarrow x_{1} \stackrel{p} \longrightarrow x_0 \tag{14} \]
去噪过程是加噪过程的逆向.如果说加噪过程是求给定初始分布\(x_0\)求任意时刻的分布\(x_t\),即\(q(x_t|x_0)\)那么去噪过程所求的分布就是给定任意时刻的分布\(x_t\)求其初始时刻的分布\(x_0,\)即\(p(x_0|x_t)\) ,通过马尔可夫假设,可以对上述问题进行化简
\[\begin{align*} p(x_0|x_t) &= p(x_0|x1)p(x1|x2)\cdots p(x_{t-1}| x_t) \\ &= \prod_{i=0}^{t-1}{p(x_i|x_{i+1})} \tag{15} \end{align*} \]
如何求\({p(x_{t-1}|x_{t})}\)呢?前面的加噪过程我们大力气推到出了\({q(x_{t}|x_{t-1})},\)我们可以通过贝叶斯公式把它利用起来
\[p(x_{t-1}|x_t) = \frac{p(x_{t}|x_{t-1})p(x_{t-1})}{p(x_t)} \tag{16} \]
⚠️这里的(去噪)\(p\)和上面的(加噪)\(q\)只是对分布的一种符号记法。
有了式(17)还是一头雾水,\(p(x_t)\)和\(p(x_{t-1})\)都不知道啊!该怎么办呢?这就要借助模型的威力了.下面来看如何构建我们的模型.
延续加噪过程的推导\(p(x_t|x_0)\)和\(p(x_{t-1}|x_0)\)我们是可以知道的.因此若我们知道初始分布\(x_0\),则
\[\begin{align*} p(x_{t-1}|x_t,x_0) &= \frac{p(x_{t}|x_{t-1}, x_0)p(x_{t-1}|x_0)}{p(x_t|x_0)} &(17) \\ &= \frac{\mathcal{N}(x_t; \sqrt{\alpha_t}x_{t-1}, (1 - \alpha_t) \textbf{I} ) \mathcal{N}(x_{t-1}; \sqrt{\overline{\alpha}{t-1}}x_0,(1 - \overline{\alpha}{t-1}) \textbf{I})} { \mathcal{N}(x_t; \sqrt{\overline{\alpha}{t}}x_0,(1 - \overline{\alpha}{t}) \textbf{I} )} &(18) \\ &\stackrel{将式(5)代入} \propto \frac{ \exp \left ({- \frac{(x_t - \sqrt{\alpha_t}x_{t-1} )^2}{2 (1 - \alpha_t)}} \right) \exp \left ({- \frac{(x_{t-1} - \sqrt{\overline{\alpha}{t-1}}x_0 )^2}{2 (1 - \overline{\alpha}{t-1})}} \right) } { \exp \left ({- \frac{(x_{t} - \sqrt{\overline{\alpha}{t}}x_0 )^2}{2 (1 - \overline{\alpha}{t})}} \right) } &(19) \\ &= \exp \left [-\frac{1}{2} \left ( \frac{(x_t - \sqrt{\alpha_t}x_{t-1} )^2}{1 - \alpha_t} + \frac{(x_{t-1} - \sqrt{\overline{\alpha}{t-1}}x_0 )^2}{1 - \overline{\alpha}{t-1}} - \frac{(x_{t} - \sqrt{\overline{\alpha}{t}}x_0 )^2}{1 - \overline{\alpha}{t}} \right) \right] &(20) \\ &= \exp \left [ -\frac{1}{2} \left( \left( \frac{\alpha_t}{1-\alpha_t} + \frac{1}{1 - \overline{\alpha}{t-1}} \right)x^2{t-1} - \left ( \frac{2\sqrt{\overline{\alpha_{t}}}}{1 - \alpha_t}x_t + \frac{2 \sqrt{\overline{\alpha}{t-1}}} {1 - \overline{\alpha}{t-1} }x_0 \right)x_{t-1} + C(x_t, x_0) \right) \right] &(21) \end{align*} \]
结合高斯分布的定义(6)来看式(22),不难发现\(p(x_{t-1}|x_t,x_0)\)也是服从高斯分布的.并且结合式(6)我们可以求出其方差和均值
⚠️式17做了一个近似\(p(x_t|x_{t-1}, x_0) =p(x_t| x_{t-1}),\)能做这个近似原因是一阶马尔科夫假设,当前时间点只依赖前一个时刻的时间点.
\[\begin{align*} \frac{1}{\sigma_2} &= \frac{\alpha_t}{1-\alpha_t} + \frac{1}{1 - \overline{\alpha}{t-1}} &(22) \\ \frac{2\mu}{\sigma^2} &= \frac{2\sqrt{\overline{\alpha{t}}}}{1 - \alpha_t}x_t + \frac{2 \sqrt{\overline{\alpha}{t-1}}} {1 - \overline{\alpha}{t-1} }x_0 &(23) \end{align*} \]
可以求得:
\[\begin{align*} \sigma^2 &= \frac{1 - \overline{\alpha}{t-1}}{1 - \overline{\alpha}{t}} (1 - \alpha_t) \\ \mu &= \frac{\sqrt{\alpha_t} (1 - \overline{\alpha}_{t-1})} {1 - \overline{\alpha}t}x_t + \frac{\sqrt{\overline{\alpha}{t-1}} (1 - \alpha_t) }{1 - \overline{\alpha}_t}x_0 \tag{24} \end{align*} \]
通过上式,我们可得
\[p(x_{t-1}|x_t,x_0) = \mathcal{N}(x_{t-1}; \frac{\sqrt{\alpha_t} (1 - \overline{\alpha}{t-1})} {1 - \overline{\alpha}t}x_t + \frac{\sqrt{\overline{\alpha}{t-1}} (1 - \alpha_t) }{1 - \overline{\alpha}t}x_0 , (\frac{1 - \overline{\alpha}{t-1}}{1 - \overline{\alpha}{t}} (1 - \alpha_t)) \textbf{I}) \tag{25} \]
该式是真实的条件分布.我们目标是让模型学到的条件分布\(p_\theta(x_{t-1}|x_t)\)尽可能的接近真实的条件分布\(p(x_{t-1}|x_t, x_0).\)从上式可以看到方差是个固定量,那么我们要做的就是让\(p(x_{t-1}|x_t, x_0)\)与\(p_\theta(x_{t-1}|x_t)\)的均值尽可能的对齐,即
(这个结论也可以通过最小化上述两个分布的KL散度推得)
\[\mathrm{arg} \mathop{min}\theta \parallel u(x_0, x_t), u\theta(x_t, t) \parallel \tag{26} \]
下面的问题变为:如何构造\(u_\theta(x_t, t)\)来使我们的优化尽可能的简单
我们注意到\(\mu(x_0, x_t)与\mu_\theta(x_t, t)\)都是关于\(x_t\)的函数,不妨让他们的\(x_t\)保持一致,则可将\(\mu_\theta(x_t, t)\)写成
\[\mu_\theta(x_t, t) = \frac{\sqrt{\alpha_t} (1 - \overline{\alpha}_{t-1})} {1 - \overline{\alpha}t}x_t + \frac{\sqrt{\overline{\alpha}{t-1}} (1 - \alpha_t) }{1 - \overline{\alpha}t} f\theta(x_t, t) \tag{27} \]
\(f_\theta(x_t, t)\)是我们需要训练的模型.这样对齐均值的问题就转化成了: **给定\(x_t, t\)来预测原始图片输入\(x_0.\)**根据上文的加噪过程,我们可以很容易制造训练所需的数据对! (Dalle2的训练采用的是这个方式).事情到这里就结束了吗?
DDPM作者表示直接从\(x_t\)到\(x_0\)的预测数据跨度太大了,且效果一般.我们可以将式(12)做一下变形
\[\begin{align*} x_t &= \sqrt{\overline{\alpha}{t}}x_0+\sqrt{1 - \overline{\alpha}{t}}\overline{z}{0} \\ x_0 &= \frac{1}{\sqrt{\overline{\alpha}{t}}}(x_t - \sqrt{1 - \overline{\alpha}{t}}\overline{z}{0}) \tag{28} \end{align*} \]
代入到式(24)中
\[\begin{align*} \mu &= \frac{\sqrt{\alpha_t} (1 - \overline{\alpha}{t-1})} {1 - \overline{\alpha}t}x_t + \frac{\sqrt{\overline{\alpha}{t-1}} (1 - \alpha_t) }{1 - \overline{\alpha}t} \frac{1}{\sqrt{\overline{a}{t}}}(x_t - \sqrt{1 - \overline{a}{t}}\overline{z}{0}) \\ &= \frac{\sqrt{\alpha_t} (1 - \overline{\alpha}{t-1})} {1 - \overline{\alpha}t}x_t + \frac{(1 - \alpha_t) }{1 - \overline{\alpha}t} \frac{1}{\sqrt{\alpha}{t}}(x_t - \sqrt{1 - \overline{\alpha}{t}}\overline{z}{0}) \\ &\stackrel{合并x_t} = \frac{\alpha_t(1 - \overline{\alpha}{t-1}) + (1 - \alpha_t) }{\sqrt{\alpha}_t (1 - \overline{\alpha}_t)}x_t - \frac{\sqrt{1 - \overline{\alpha}_t}(1 - \alpha_t) }{\sqrt{\alpha_t}(1 - \overline{\alpha}_t)}\overline{z}_0 \\ &= \frac{1 - \overline{\alpha}_t}{\sqrt{\alpha}_t (1 - \overline{\alpha}_t)}x_t - \frac{1 - \alpha_t }{\sqrt{\alpha_t}\sqrt{1 - \overline{\alpha}_t}}\overline{z}_0 \\ &= \frac{1}{\sqrt{\alpha}_t}x_t - \frac{1 - \alpha_t }{\sqrt{\alpha_t}\sqrt{1 - \overline{\alpha}_t}}\overline{z}_0 \tag{29} \end{align*} \]
经过这次化简,我们将\(\mu{(x_0, x_t)} \Rightarrow \mu{(x_t, \overline{z}_0)},\)其中\(\overline{z}_0 \sim \mathcal{N}(0, \textbf{I}),\)可以将式(29)转变为
\[\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} x_t - \frac{1 - \alpha_t }{\sqrt{\alpha_t}\sqrt{1 - \overline{\alpha}t}}f\theta(x_t, t) \tag{30} \]
此时对齐均值的问题就转化成:给定\(x_t, t\)预测\(x_t\)加入的噪声\(\overline{z}_0\), 也就是说我们的模型预测的是噪声\(f_\theta{(x_t, t)} = \epsilon_{\theta}(x_t, t) \simeq \overline{z}_0\)
2.3.1 训练与采样过程
训练的目标就是这所有时刻两个噪声的差异的期望越小越好(用MSE或L1-loss).
\[\mathbb{E}{t \sim T } \parallel \epsilon - \epsilon{\theta}(x_t, t)\parallel_2 ^2 \tag{31} \]
下图为论文提供的训练和采样过程
2.3.2 采样过程
通过以上讨论,我们推导出\(p_\theta(x_{t-1}|x_t)\)高斯分布的均值和方差.\(p_\theta(x_{t-1}|x_t)=\mathcal{N}(x_{t-1}; \mu_{\theta}(x_t, t), \sigma^2(t) \textbf{I})\),根据文献[1]从一个高斯分布中采样一个随机变量可用一个重参数化技巧进行近似
\[\begin{align*} x_{t-1} &= \mu_{\theta}(x_t, t) + \sigma(t) \epsilon,其中 \epsilon \in \mathcal{N}(\epsilon; 0, \textbf{I}) \\ & = \frac{1}{\sqrt{\alpha_t}} (x_t - \frac{1 - \alpha_t }{\sqrt{1 - \overline{\alpha}t}}\epsilon\theta(x_t, t)) + \sigma(t) \epsilon &\tag{32} \end{align*} \]
式(32)和论文给出的采样递推公式一致.
至此,已完成DDPM整体的pipeline.
还没想明白的点,为什么不能根据(7)的变形来进行采样计算呢?
\[x_{t-1} = \frac{1}{\sqrt{\alpha_t}}x_t - \sqrt{\frac{1 - \alpha_t}{\alpha_t}} f_\theta(x_t, t) \tag{33} \]
3 从代码理解训练&预测过程
3.1 训练过程
参考代码仓库: https://github.com/lucidrains/denoising-diffusion-pytorch/tree/main/denoising_diffusion_pytorch
已知项: 我们假定有一批N张图片\(\{x_i |i=1, 2, \cdots, N\}\)
第一步 : 随机采样K组成batch,如\(\mathrm{x\_start}= \{ x_k|k=1,2, \cdots, K \}, \mathrm{Shape}(\mathrm{x\_start}) = (K, C, H, W)\)
第二步: 随机采样一些时间步
python
t = torch.randint(0, self.num_timesteps, (b,), device=device).long() # 随机采样时间步
第三步: 随机采样噪声
python
noise = default(noise, lambda: torch.randn_like(x_start)) # 基于高斯分布采样噪声
第四步 : 计算\(\mathrm{x\_start}\)在所采样的时间步的输出\(x_T\)(即加噪声).(根据公式12)
python
def linear_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
betas = linear_beta_schedule(timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
def extract(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def q_sample(x_start, t, noise=None):
"""
\begin{eqnarray}
x_t &=& \sqrt{\alpha_t}x_{t-1} + \sqrt{(1 - \alpha_t)}z_t \nonumber \\
&=& \sqrt{\alpha_t}x_{t-1} + \sqrt{\beta_t}z_t
\end{eqnarray}
"""
return (
extract(sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
x = q_sample(x_start = x_start, t = t, noise = noise) # 这就是x0在时间步T的输出
第五步 : 预测噪声.输入\(x_T,t\)到噪声预测模型,来预测此时的噪声\(\hat{z}t = \epsilon\theta(x_T, t)\).论文用到的模型结构是Unet,与传统Unet的输入有所不同的是增加了一个时间步的输入.
python
model_out = self.model(x, t, x_self_cond=None) # 预测噪声
这里面有一个需要注意的点:模型是如何对时间步进行编码并使用的
- 首先会对时间步进行一个编码,将其变为一个向量,以正弦编码为例
python
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
"""
Args:
x (Tensor), shape like (B,)
"""
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
# 时间步的编码pipeline如下,本质就是将一个常数映射为一个向量
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(fourier_dim, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim)
)
- 将时间步的embedding嵌入到Unet的block中,使模型能够学习到时间步的信息
python
class Block(nn.Module):
def __init__(self, dim, dim_out, groups = 8):
super().__init__()
self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
def forward(self, x, scale_shift = None):
x = self.proj(x)
x = self.norm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift # 将时间向量一分为2,一份用于提升幅值,一份用于修改相位
x = self.act(x)
return x
class ResnetBlock(nn.Module):
def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
super().__init__()
self.mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_emb_dim, dim_out * 2)
) if exists(time_emb_dim) else None
self.block1 = Block(dim, dim_out, groups = groups)
self.block2 = Block(dim_out, dim_out, groups = groups)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, time_emb = None):
scale_shift = None
if exists(self.mlp) and exists(time_emb):
time_emb = self.mlp(time_emb)
time_emb = rearrange(time_emb, 'b c -> b c 1 1')
scale_shift = time_emb.chunk(2, dim = 1)
h = self.block1(x, scale_shift = scale_shift)
h = self.block2(h)
return h + self.res_conv(x)
第六步:计算损失,反向传播.计算预测的噪声与实际的噪声的损失,损失函数可以是L1或mse
python
@property
def loss_fn(self):
if self.loss_type == 'l1':
return F.l1_loss
elif self.loss_type == 'l2':
return F.mse_loss
else:
raise ValueError(f'invalid loss type {self.loss_type}')
通过不断迭代上述6步即可完成模型的训练
3.2采样过程
第一步:随机从高斯分布采样一张噪声图片,并给定采样时间步
python
img = torch.randn(shape, device=device)
第二步: 根据预测的当前时间步的噪声,通过公式计算当前时间步的均值和方差
python
posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) # 式(24)x_0的系数
posterior_mean_coef = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) # 式(24) x_t的系数
def extract(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
) # 求出此时的均值
posterior_variance = extract(self.posterior_variance, t, x_t.shape) # 求出此时的方差
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) # 对方差取对数,可能为了数值稳定性
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
preds = self.model_predictions(x, t, x_self_cond) # 预测噪声
x_start = preds.pred_x_start # 模型预测的是在x_t时间步噪声,x_start是根据公式(12)求
if clip_denoised:
x_start.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
return model_mean, posterior_variance, posterior_log_variance, x_start
第三步: 根据公式(32)计算得到前一个时刻图片\(x_{t-1}\)
python
@torch.no_grad()
def p_sample(self, x, t: int, x_self_cond = None, clip_denoised = True):
b, *_, device = *x.shape, x.device
batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = clip_denoised) # 计算当前分布的均值和方差
noise = torch.randn_like(x) if t > 0 else 0. # 从高斯分布采样噪声
pred_img = model_mean + (0.5 * model_log_variance).exp() * noise # 根据
return pred_img, x_start
通过迭代以上三步,直至\(T=0\)完成采样.
思考和讨论
DDPM区别与传统的VAE与GAN采用了一种新的范式实现了更高质量的图像生成.但实践发现,需要较大的采样步数才能得到较好的生成结果.由于其采样过程是一个马尔可夫的推理过程,导致会有较大的耗时.后续工作如DDIM针对该特性做了优化,数十倍降低采样所用时间。