这篇文章最初发表在 NVIDIA 技术博客上。
借助互联网级数据,AI 生成内容的计算需求显著增加,数据中心在数周或数月内全力运行单个模型,更不用说通常作为服务提供的高生成推理成本。在这种情况下,牺牲性能的次优算法设计是一个代价高昂的错误。
近期,AI 生成的图像、视频和音频内容取得了很大进展,降噪扩散 ------ 一种以迭代方式将随机噪声塑造成新数据样本的技术。我们的团队最近发表的一篇研究论文 《阐明基于扩散的生成模型的设计空间》 获得了 NeurIPS 2022 杰出论文奖,该论文识别出了文档中看似复杂的方法背后的简单核心机制。从对基础知识的清晰认识开始,我们能够发现在质量和计算效率方面的先进实践。
降噪扩散
降噪是指从图像中消除传感器噪声或从录音中消除声等操作。本文将使用图像作为运行示例,但该过程也适用于许多其他领域。此任务非常适合卷积神经网络。
这与生成新图像有什么关系?想象一下,图像上有大量噪点。确实,原始图像丢失了这么多。是否可以使用降噪器来揭示某些随机可能隐藏在所有噪音下的图像?令人惊讶的是,答案是肯定的。
这是降噪扩散的一个简单本质:首先,随机绘制一张纯白色噪声的图像,然后通过反复将其输入到神经降噪器,在噪声级别(例如一次 2%)消除噪声。逐渐地,从噪声下方出现随机干净的图像。生成内容(猫和狗的图片?英语口语短语的音频波形?驾驶视频片段?)的分布由降噪器网络训练时使用的数据集确定。
图 1.降噪漫射通过重复降噪揭示纯噪声中的新图像
以下代码是对如何在假设神经网络函数的情况下实现这一想法的初步猜测denoise
可用。
# start with an image of pure large-magnitude noise |
sigma = 80 # initial noise level |
x = sigma * torch.randn(img_shape) |
for step in range``(``256``): |
# keep 98% of current noisy image, and mix in 2% of denoising |
x = 0.98 * x + 0.02 * denoise(x, sigma) |
# keep track of current noise level |
sigma *``= 0.98 |
如果您看过该领域的代码库或科学论文(充满了方程页),您可能会惊讶地发现,这个几乎微不足道的代码实际上是一个理论上有效的实现,称为概率流普通微分方程求解器.尽管此代码段并非最佳,但令人惊讶的是,它体现了论文中解释的许多关键良好实践。该团队的顶级最终采样器本质上只是几行。
该函数如何denoise
?其核心同样非常简单:降噪器必须输出可能隐藏在噪声下的所有可能的清晰图像的模糊平均值。各种噪声级别下的预期输出可能如图 2 所示。
图 2.不同噪声级别下的理想降噪器输出示例。在高噪声级别下,图像细节不确定,且输出保持模糊
使用基本损失(即输出与清晰目标之间的均方误差)训练降噪器网络(通常为 U-Net)可以精确地达到此结果。旨在提高输出清晰度的更精细的损失实际上是有害的,并且违反了理论。请记住,即使任务在概念上很简单,大多数现有的降噪器也不是专门针对它进行训练的。
许多明显的数学复杂性都源自于 该理论的原理。该理论可以从各种形式中构建,其中最流行的两种是 Markov 链和随机微分方程。虽然每种方法都可以归结为使用经过训练的降噪器的降噪循环,但它们为不同的实际实现开辟了广阔且令人困惑的空间,并为做出错误选择提供了机会的雷区。
该论文回顾了数学复杂性的各个层,直接揭示了标准化框架中易于分析的实际设计选择。
本文通过可视化和代码介绍了团队的主要发现和直觉。我们将介绍三个主题:
- 直观地概述了理论降噪扩散背后的原理
- 设计选择采样(在您已经拥有经过训练的降噪器时用于生成图像)
- 设计选择训练降噪器
是什么让扩散发挥作用?
首先,本节将回顾基础知识,并构建理论框架来证明这段简单代码的正确性。我们在微分方程的框架下找到了大部分见解,这一框架最初在通过随机微分方程进行基于分数的生成建模中被提出。虽然方程和数学概念可能看起来复杂,但它们对于理解核心概念并不是必要的。偶尔提及这些概念是有益的,因为它们往往只是用另一种语言描述代码中所完成的具体事务。
想象一张 RGB 图像 x ,其形状为 [3, 64, 64]
。首先,我们考虑一种简单的破坏方向,即通过逐渐在图像上添加噪点来调整图像。(当然,这与最终目标相反。)
for step in range``(``1000``): |
x = x + 0.1 * torch.randn_like(x) |
这实际上是(适当地斜视)与简单 SDE 对应的随机微分方程 (SDE) 求解器 表示图像的更改 x 在短时间步长内为随机白噪点。在这里 解决 只是意味着模拟 SDE 描述的过程的特定随机数值实现。
微分方程的一个优点是它们具有直观的几何解释。您可以将这一过程视为图像在像素值空间中进行类似布朗运动(著名的 Brownian 运动或 Wiener 过程)的随机漫步。如果您将x视为仅仅是一个数字("单像素图像"),那么您可以根据下图来描绘其变化过程。真实情况与此完全相同,只不过是在更高维度上进行,因此无法在二维显示器上直观显示。
图 3.渐进式噪点加法是在像素值空间中随机移动
通过使用许多不同的起始图像和随机路径研究这种演变,您开始看到混乱中的一些顺序。想象一下,这些弯曲的路径堆叠在一起。它们平均会随着时间的推移而改变形状。
左边缘的复杂数据模式(您可以隐喻性地想象分别对应于猫和狗图像的两个峰值)逐渐混合并简化为右边缘的无特征 Blob.这是无处不在的正态分布或纯白色噪声。
图 4.所有数据集图像绘制的随机路径建立了随时间演变的密度
高级目标(生成建模)是以某种方式找到从图 4 左侧真实隐藏数据分布中对新图像进行采样的技巧,即实际的新图像,可能位于数据集中,但 .您可以轻松地从右侧的纯噪声状态中进行采样,使用randn
.是否可以反向运行上述降噪过程,以随机采样干净的图像结束(图 5)?
图 5.从噪声开始到随机生成图像结束的反向随机行走
遵循从右边缘开始的随机路径后,有什么能保证左边缘有正确的图像,而不仅仅是更多的噪点?需要额外的力量来将图像朝着每个步骤的数据轻轻拉取。
SDE 的理论提供了一个很好的解决方案。在不深入探讨技术细节的情况下,它确实可以反转时间方向,这样做就会自动为受欢迎的数据吸引力引入一个额外的术语。该力将噪声图像拉向均方优化降噪。这可以通过经过训练的神经网络进行估计(此处,sigma 是当前的噪声级别):
您甚至可以调整这两个术语的权重,前提是您注意保持降噪总速率不变。将此想法带到消除噪音的极限只会导致完全确定性的普通微分方程(ODE),完全没有随机分量。然后演变遵循平滑的轨迹,图像只是从固定噪声下方逐渐消失(图 6)。
图 6.确定性常微分方程引起的平滑演进
请注意图 6 中的曲线轨迹如何将右边缘的初始随机噪声连接到左边缘唯一生成的图像。事实上,ODE 为每个初始噪声建立了不同的轨迹。将这些曲线想象成推动我们的图像的流体的流线。在生成过程中,任务只是简单地从一开始就尽可能准确地遵循流线。从右侧的随机点开始,在每个步骤中,公式(实际上是降噪器网络)都会显示流线指向当前图像的位置。在其方向上英寸一点并重复。这就是生成过程。
图 7 显示求解器的每个步骤都会将时间向后推进选定的数量 (dt),并参考 ODE 公式(以及降噪器网络),以确定如何在时间步长内更改图像。
图 7.求解器的每个步骤都会将时间向后推进一些选定的数量 (dt),并参考 ODE 公式以确定如何在时间步长内更改图像
后续部分仅分析确定性版本,因为随机性模糊了确定性图片提供的几何见解。尽管随机性在适当调整后具有有益的纠错特性,但其使用起来相对繁琐,可以视为一种辅助手段。有关更多详细信息,请参阅阐明基于扩散的生成模型的设计空间。
用于采样以生成图像的设计选择
正如简介中所述,是决定性能的细节。关键的困难在于网络给出的步进方向是有效的仅在当前噪声级别附近.尝试在不停止重新评估的情况下立即减少过多的噪点会导致在图像中添加不应该存在的内容。这表现为不同程度的图像质量降低:难以形容的模糊和颗粒化、颜色和强度伪影、面部失真和缺乏一致性以及其他更高级别的细节等。
在 1D 可视化中,这对应于从起始流线开始的步长,如图 8 所示。请注意箭头(表示可能采取的步长)与曲线之间的空隙。
图 8.线性步长(直箭头)与真实曲线流线的近似值可能较差
常见的强力解决方案是简单地执行大量极短的步骤,以避免被丢弃。但是,这很昂贵,因为每个步骤都需要对降噪器网络进行完整的评估。这就像爬行而不是运行:安全但缓慢。
我们的采样器设计在不影响质量的情况下大幅减少了所需步骤的数量。策略有三个方面:
- 设计 ODE,使其流线尽可能笔直,因此易于遵循 (噪音调度)
- 确定哪些噪声级别仍需额外小心步进 (时间步长离散)
- 采取更明智的步骤,充分利用每个(高阶求解器)
理顺流程以减少步骤
问题的关键在于流线的曲率。如果它们是直线,就很容易遵循。可以采取一个漫长的直线步骤,一直到噪声级别 0,而不必担心从曲线上掉下来。实际上,一些曲率是不可避免地内置在设置中的。能否减少?
事实证明,上一节中开发的理论在这方面做出了一些糟糕的选择。例如,您可以通过指定不同的噪声表来构建不同版本的 ODE.回想一下,1D 可视化是通过在每个步骤中添加相同数量的噪声来构建的。如果以不同的时变速率添加,则会在不同的时间(不同的时间表)达到每个噪声级别。这相当于延伸和压缩时间轴。
图 9 显示了几个不同的 ODE,这些 ODE 是由不同的噪声表选项引起的。
图 9.不同的噪声表会导致不同的流线曲率。在某些表中,线性步长与曲线的近似值优于其他表
请注意,这会产生重构流线的副作用。事实上,这些线在其中一个调度表中几乎是笔直的。这确实是团队所主张的。表示步骤的箭头现在几乎与曲线完全对齐。因此,与其他选择相比,可以减少很多步骤(图 10)。
图 10.我们团队选择的降噪计划。虽然低噪声级别(左边缘)仍然存在一些曲率,但在演进的大部分时间里,流线几乎都是笔直的
图 10 显示了随着时间的推移,噪音水平呈线性增长的进度。与先前的固定速率加法示例相比,噪音水平最初快速增长,但随后放慢。换言之,时间成为噪音水平的代名词。在不深入讨论此处的技术细节的情况下,这一特定选择提供了非常直观的求解器算法。这是我们论文中的算法 1,没有可选的第 6 行到第 8 行,使用了建议的时间表,并在经过一些整理之后:
# a (poor) placeholder example time discretization |
timesteps = np.linspace(``80``, 0``, num_steps) |
# sample an image of random noise at first noise level |
x = torch.randn(img_shape) * timesteps[``0``] |
# iterate through pairs of adjacent noise levels |
for t_curr, t_next in zip``(timesteps[:``-``1``], timesteps[``1``:]): |
# fraction of noise we keep in this iteration |
blend = t_next / t_curr |
# mix in the denoised image |
x = blend * x + (``1``-``blend) * denoise(x, t_curr) |
代码仅对简介中提到的内容进行了轻微的泛化处理,其实并没有比这更简单的了。这个算法如此简单,以至于人们会好奇,为什么它没有在 2015 年以启发式算法的形式被提出 ------ 也许当时这个想法看起来太过荒谬,不切实际。顺便提一下,2015 年的论文讨论了降噪扩散,使用无平衡热力学的深度无监督学习,但是措辞包含了复杂的数学术语。多年来,其潜在价值一直未受到足够的重视。
在低噪音水平下小心步进
这清楚地凸显了另一种设计选择,在大多数处理中,这种选择是模糊的,并且与噪点安排纠结在一起:时间步长的选择。先前代码片段中使用的线性间距实际上是一个糟糕的选择。从经验(以及根据自然图像统计推理)来看,很明显,细节在低噪点附近显示得更快。在 1D 可视化中,图形右侧的大部分几乎没有发生,但随后流线突然转向左侧的两个池中的一个。这意味着在高噪点级别下可以实现长步长,但在接近低噪点级别时必须放慢速度(图 11)。
图 11.高噪声级别下的时间步长和低噪声级别下的时间步长
我们的论文以经验为基础,研究了在低噪声级别与高噪声级别下,步长的相对长度。以下代码片段对时间步长作出了简单而可靠的修改。大致上,将其中的数字提高到 7 的(注意将其扩展到 0 到 80 的原始范围)。这严重偏移了低噪声级别的步长:
sigma_max = 80 |
sigma_min = 0.002 # leave a microscopic bit of noise for stability |
rho = 7 |
step_indices = torch.arange(num_steps) |
timesteps = (sigma_max *``* (``1 / rho) \ |
+ step_indices / (num_steps - 1``) \ |
* (sigma_min *``* (``1 / rho) - sigma_max *``* (``1 / rho))) *``* rho |
高阶求解器,可实现更准确的步骤
ODE 视点支持使用更精致的高阶求解器,该求解器本质上采用曲线而不是线性步骤。这在尝试遵循曲线流线时显得尤为有利。尽管如此,其优势并不总是明显,因为估计局部曲率需要额外的神经网络评估。团队测试了一系列方法,并一致认为所谓的二阶 Heun 方案是最佳选择(见图 12)。这需要在代码中添加几行(详见阐明基于扩散的生成模型设计空间的算法 1),虽然每次迭代的成本翻倍,但所需的迭代次数却减少到了一小部分。
Heun 步骤具有很好的几何解释和代码中的简单实现。像以前一样采取初步步骤,然后采取第二步,从着陆点返回一半。注意最终校正步骤如何比原始步骤更接近实际流线(图 12)。
图 12.Heun 步骤以几何图形显示
结合所有这些改进,现在只需对降噪器进行 30 到 80 次评估即可,而在之前的大多数工作中,评估降噪器的次数是 250 到 1000 次。
用于训练降噪器的设计选择
这是一个流畅高效的降噪步骤链。到目前为止,我们假设每个步骤都可以称为易于训练的降噪器denoise (x, sigma)
输入噪点图像和指示其噪点级别的数字。但如何对其进行 参数化和训练以获得最佳结果?
理论上有效的此类网络训练的最基本形式(此处 PyTorch 模块实例化为denoise
)看起来类似于以下内容:
# WARNING: this code illustrates poor choices across the board! |
for clean_image in training_data: # we'll ignore minibatching for brevity |
# pick a random noise level to train at |
sigma = np.random.uniform(``0``, 80``) |
# add noise with this level |
noisy_image = clean_image + sigma * torch.randn_like(clean_image) |
# feed to network under training |
denoised_image = denoise(noisy_image, sigma) |
# compute mean square loss |
loss = (denoised_image - clean_image).square().``sum``() |
# ... plus the usual backpropagation and parameter updates |
该理论要求使用白噪点和均方损失,并触及打算用于采样的所有噪声级别。在这些限制范围内,可以很大程度上重新排列计算。以下小节确定并解决了本代码中的每个严重实际问题。
请注意,网络架构本身将不会得到解决。本次讨论在很大程度上是正交的,与层数量、形状和大小、注意力或转换器的使用等无关。对于论文中的所有结果,都采用了之前工作中的网络架构。
网络友好型数值大小
根据经验,我们已将这些示例中的最大噪点级别选为足够大的数字,以完全淹没图像。因此,有时会向降噪器馈入像素值大约在 -1 到 1 范围内的图像(当噪点级别非常低时),有时还会馈入超出 -- 100 到 100 范围的图像。这会引发红旗,因为众所周知,如果神经网络的输入在不同示例之间的规模上存在巨大差异,则会受到不稳定训练和最终性能不佳的影响。 来标准化规模。
有些人通过修改 ODE 本身来解决这一问题,例如,采样过程使噪声图像保持在固定幅度范围内,而不是允许其随着时间的推移而扩展(即所谓的保持差异 扩展时间表)。遗憾的是,这再次扭曲了流线,破坏了上一节中介绍的拉直的好处。
下面是一个不存在此类数值缺点的简单解决方案。噪声级别是已知的,因此只需将噪声图像扩展到标准大小,然后再将其输入到网络中即可。它将通过训练自动适应不同的比例约定,但会消除有问题的范围变化。
要做到这一点,最好的方法是保持denoise
从外部调用者(ODE 求解器和训练循环)的角度来看没有变化,但在内部改变其利用网络的方式。将实际的原始网络层隔离到自己的黑子模块中net
并使用大小管理代码("preconditioning")将其包装在denoise
:
sigma_data = 0.5 # approximate standard deviation of ImageNet pixels |
def denoise(noisy_image, sigma): |
noisy_image_variance = sigma``*``*``2 + sigma_data``*``*``2 |
scaled_noisy_image = noisy_image / noisy_image_variance *``* 0.5 |
return net(scaled_noisy_image, sigma) |
此处,噪声图像除以其预期标准差,使其大致达到单位方差。
作为次要细节(未在此处显示),同样也会将噪声级别标签输入扭曲为net
使用对数函数使其更均匀地分布在 -1 到 1 的范围内。
预测图像与噪声
如果您熟悉现有的扩散方法,您可能已经注意到,大多数方法训练网络来预测噪声(单位方差),而不是训练清晰的图像,而是将其明确扩展到已知的噪声级别sigma
然后通过从输入中减去来恢复降噪图像。
事实证明,特别是在低噪声级别下,这是个好主意,但在高噪声级别下,则是个坏主意。由于大多数图像细节在相对较低的噪声级别下会突然显示出来,因此好处大于缺点。
为什么在低噪声级别下这样做是个好主意?这种方法从输入中回收近乎清晰的图像,并且仅使用网络向其添加少量的噪声校正。重要的是,网络输出显式缩小(通过sigma
)来匹配噪声级别。因此,如果网络发生了一些错误(就像往常一样),该错误也会缩小,并且没有机会搞乱图像。这可以最大限度地减少不可靠的学习网络的贡献,并最大限度地重复使用输入中已知的内容。
为什么在高噪声级别下这是个坏主意?它最终会根据大噪声大小提高网络输出。因此,网络发生的任何小错误现在都会成为降噪器输出中的大错误。
更好的选择是持续过渡,其中网络=预测(负)噪声和清晰图像的噪声级别相关混合。然后将其与适当数量的噪声输入混合,以消除噪声。
本文介绍了一种计算混合权重作为噪声级别函数的原则性方法。确切的统计参数在某种程度上涉及,因此本文不会尝试完整复制它。基本上,它询问的是导致网络输出放大最小的混合系数。实现非常简单。最后一个返回行替换为以下代码,其中c_skip
和c_out
混合因子分别控制输入的回收量和网络的贡献量。
return c_skip * noisy_image + c_out * net(scaled_noisy_image, sigma) |
均衡噪声级别的梯度反馈大小
完成降噪器内部结构后,本节将解决 straw-man 训练代码片段中的噪声级别问题。不对损失应用任何与噪声级别相关的缩放是一种(较差)隐性选择。就像编写了以下内容:
weight = 1 |
loss = weight * (denoised_image - clean_image).square().``sum``() |
问题在于,由于降噪器内部的各种缩放,此损失值对于某些噪声级别来说很大,而对于其他噪声级别来说则较小。因此,对网络权重进行的更新(梯度反馈)的大小也将取决于噪声级别。这就像对不同的噪声级别使用不同的学习率,没有充分的理由。
在另一种情况下,统一大小会导致训练更加稳定和成功。幸运的是,一个简单的独立于数据的统计公式给出了每个噪声级别的预期损失幅度。weight
相应地将大小调整回 1.
分配 训练工作量
一种很有吸引力的误用weight
还可以根据噪音的相对重要性来衡量噪音水平,以便在需要的地方引导更多的网络容量。但是,通过在这些重要的噪音水平上更频繁地进行训练,可以在不影响强度的情况下实现相同的目标。图 13 从概念上说明了团队所主张的劳动分工。
在整个训练过程中,每个噪声级别都会为网络权重提供梯度更新(箭头)。另外,我们使用两个各自的机制来控制这些更新的大小和数量。默认情况下,大小(箭头的长度)和频率(数量)都不加控制地取决于噪声级别。该团队主张进行劳动分工,其中损失扩展会标准化长度,而噪声级别分布决定在每个级别的训练频率。
图 13.默认情况下,大小(箭头长度)和频率(数量)均取决于不受控制的噪声级别
不出所料,从均匀分布中选择训练噪声级别的代码示例并非易事。该理论在此选择中提供的指导很少,因为它取决于数据集的特征。在非常低的噪声级别下,进展极小,因为预测无噪点图像的噪声实际上是不可能的(但也无关紧要)。相反,在非常高的噪声级别下,优化降噪(数据集图像的模糊平均值)相当容易预测。中间部分提供了可以取得进展的广泛级别。
在实践中,我们从公式中选择了随机训练噪声级别,sigma = torch.exp(P_mean + P_std * torch.randn([]))
在哪里P_mean
和P_std
指定用于训练的平均噪声级别,以及该值周围的随机化宽度。选择此特定公式的原因很简单,因为它是绘制跨多个数量级的非负随机值的直接启发式方法。这些参数的值经过经验调优,但在常规图像数据集中证明相当可靠。
总结一下,以下是一个最小部分,其中汇集了原始训练代码中讨论过的所有更改,包括任何省略的公式:
P_mean = -``1.2 # average noise level (logarithmic) |
P_std = 1.2 # spread of random noise levels |
sigma_data = 0.5 # ImageNet standard deviation |
def denoise(noisy_image, sigma): |
# Input, output and skip scale |
c_in = 1 / torch.sqrt(sigma_data``*``*``2 + sigma``*``*``2``) |
c_out = sigma * sigma_data / torch.sqrt(sigma``*``*``2 + sigma_data``*``*``2``) |
c_skip = sigma_data``*``*``2 / (sigma``*``*``2 + sigma_data``*``*``2``) |
c_noise = torch.log(sigma) / 4 # noise label warp |
# mix the input and network output to extract the clean image |
return c_skip * noisy_image + \ |
c_out * net(c_in * noisy_image, c_noise) |
for clean_image in training_data: # we'll ignore minibatching for brevity |
# random noise level |
sigma = torch.exp(P_mean + P_std * torch.randn([])) |
noisy_image = clean_image \ |
+ sigma * torch.randn_like(clean_image) |
denoised_image = denoise(noisy_image, sigma) |
# weighted least squares loss |
weight = (sigma``*``*``2 + sigma_data``*``*``2``) / (sigma * sigma_data)``*``*``2 |
loss = weight * (denoised_image - clean_image).square().``sum``() |
# ... plus backpropagation and optimizer update |
结果和结论
本博文中展示的所有结果都通过彻底的数值实验证明是有益的,详情请参阅《阐明基于扩散的生成模型的设计空间》。采用所有改进的最终效果是对先前工作的显著进步。特别是,在竞争激烈的 ImageNet 64 × 64 类别中,我们保持了世界纪录的 FID 指标一段时间。此外,我们在生成过程中大幅减少了降噪器评估的数量,从而实现了这一记录。
我们相信,这些发现对于未来其他数据模式、改进的网络架构或更高分辨率的图像仍然具有相关性。当然,在不同的环境中应用模型时,我们仍应注意基本推理。例如,在采用潜在扩散或提高分辨率时,许多常量(例如最大噪声级别为 80,或训练噪声级别分布的位置和宽度)肯定需要进行调整。
要查看我们的官方实现以及预训练网络,请访问 NVlabs/edm GitHub 上的代码。该代码是一种简洁且精简的实现,遵循论文的符号和惯例,可以作为实验和构建这些想法的绝佳起点。请注意,我们包含了多个函数和类,这些函数和类用于重现先前方法以便进行比较,但使用或学习我们的方法并不需要这些函数和类。有关特别相关的代码,请参阅:
generate.py
edm_sampler
实现了完整的采样器,包括可选的随机性
training/
loss.EDMLoss
损失函数和权重networks.EDMPrecond
用于规模管理和混合预测networks.DhariwalUNet
用于重新实现常用的 ADM 网络架构
该团队最近发布了一篇后续研究论文,分析和改进扩散模型的训练动力学。在这项工作中,他们通过深入研究降噪器网络的设计和训练,实现了前所未有的生成质量。