比Stable Diffusion便宜118倍!1890美元训出11.6亿参数高质量文生图模型

**【新智元导读】**近日,来自加州大学尔湾分校等机构的研究人员,利用延迟掩蔽、MoE、分层扩展等策略,将扩散模型的训练成本降到了 1890 美元。

训练一个扩散模型要多少钱?

之前最便宜的方法(Wuerstchen)用了 28400 美元,而像 Stable Diffusion 这样的模型还要再贵一个数量级。

大模型时代,一般人根本玩不起。想要各种文生小姐姐,还得靠厂商们负重前行

为了降低这庞大的开销,研究者们尝试了各种方案。

比如,原始的扩散模型从噪声到图像大约需要 1000 步,目前已经被减少到 20 步左右,甚至更少。

当扩散模型中的基础模块逐渐由 Unet(CNN)替换为 DiT(Transformer)之后,一些根据 Transformer 特性来做的优化也跟了上来。

比如量化,比如跳过 Attention 中的一些冗余计算,比如 pipeline。

而近日,来自加州大学尔湾分校等机构的研究人员,把「省钱」这个目标直接向前推进了一大步:

论文地址:arxiv.org/abs/2407.15...

------从头开始训练一个 11.6 亿参数的扩散模型,只需要 1890 美元!

对比 SOTA 有了一个数量级的提升,让普通人也看到了能摸一摸预训练的希望。

更重要的是,降低成本的技术并没有影响模型的性能,11.6 亿个参数给出了下面这样非常不错的效果。

除了观感,模型的数据指标也很优秀,比如下表给出的 FID 分数,非常接近 Stable Diffusion 1.5 和 DALL·E 2。

相比之下,Wuerstchen 的降成本方案则导致自己的考试分数不甚理想。

省钱的秘诀

抱着「Stretching Each Dollar」的目标,研究人员从扩散模型的基础模块 DiT 入手。

首先,序列长度是 Transformer 计算成本的大敌,需要除掉。

对于图像来说,就需要在不影响性能的情况下,尽量减少参加计算的 patch 数量(同时也减少了内存开销)。

减少图像切块数可以有两种方式,一是增大每块的尺寸,二是干掉一部分 patch(mask)。

因为前者会显著降低模型性能,所以我们考虑进行 mask 的方式。

最朴素的 mask(Naive token masking)类似于卷积 UNet 中随机裁剪的训练,但允许对图像的非连续区域进行训练。

而之前最先进的方法(MaskDiT),在输出之前增加了一个恢复重建的结构,通过额外的损失函数来训练,希望通过学习弥补丢掉的信息。

这两种 mask 都为了降低计算成本,在一开始就丢弃了大部分 patch,信息的损失显著降低了 Transformer 的整体性能,即使 MaskDiT 试图弥补,也只是获得了不太多的改进。

------丢掉信息不可取,那么怎样才能减小输入又不丢信息呢?

延迟掩蔽

本文提出了一种延迟掩蔽策略(deferred masking strategy),在 mask 之前使用混合器(patch-mixer)进行预处理,把被丢弃 patch 的信息嵌入到幸存的 patch 中,从而显著减少高 mask 带来的性能下降。

在本架构中,patch-mixer 是通过注意力层和前馈层的组合来实现的,使用二进制掩码进行 mask,整个模型的损失函数为:

与 MaskDiT 相比,这里不需要额外的损失函数,整体设计和训练更加简单。

而混合器本身是个非常轻量的结构,符合省钱的标准。

微调

由于非常高的掩蔽比(masking ratio)会显著降低扩散模型学习图像中全局结构的能力,并引入训练到测试的分布偏移,所以作者在预训练(mask)后进行了小幅度的微调(unmask)。

另外,微调还可以减轻由于使用 mask 而产生的任何不良生成伪影。

MoE 和分层扩展

MoE 能够增加模型的参数和表达能力,而不会显著增加训练成本。

作者使用基于专家选择路由的简化 MoE 层,每个专家确定路由到它的 token,而不需要任何额外的辅助损失函数来平衡专家之间的负载。

此外,作者还考虑了分层缩放方法,线性增加 Transformer 块的宽度(即注意力层和前馈层中的隐藏层尺寸)。

由于视觉模型中的更深层倾向于学习更复杂的特征,因此在更深层中使用更多的参数将带来更好的性能。

实验设置

作者使用两种 DiT 的变体:DiT-Tiny/2 和 DiT-Xl/2,patch 大小为 2。

使用具有余弦学习率衰减和高权重衰减的 AdamW 优化器训练所有模型。

模型前端使用 Stable-Diffusion-XL 模型中的四通道变分自动编码器(VAE)来提取图像特征,另外还测试了最新的 16 通道 VAE 在大规模训练(省钱版)中的性能。

作者使用 EDM 框架作为所有扩散模型的统一训练设置,使用 FID 以及 CLIP 分数来衡量图像生成模型的性能。

文本编码器选择了最常用的 CLIP 模型,尽管 T5-xxl 这种较大的模型在文本合成等具有挑战性的任务上表现更好,但为了省钱的目标,这里没有采用。

训练数据集

使用三个真实图像数据集(Conceptual Captions、Segment Anything、TextCaps),包含 2200 万个图像文本对。

由于 SA1B 不提供真实的字幕,这里使用 LLaVA 模型生成的合成字幕。作者还在大规模训练中添加了两个包含 1500 万个图像文本对的合成图像数据集:JourneyDB 和 DiffusionDB。

对于小规模消融,研究人员通过从较大的 COYO-700M 数据集中对 10 个 CIFAR-10 类的图像进行二次采样,构建了一个名为 cifar-captions 的文本到图像数据集。

评估

使用 DiT-Tiny/2 模型和 cifar-captions 数据集(256×256 分辨率)进行所有评估实验。

对每个模型进行 60K 优化步骤的训练,并使用 AdamW 优化器和指数移动平均值(最后 10K 步平滑系数为 0.995)。

延迟掩蔽

实验的基线选择我们上面提到的 Naive masking,而本文的延迟掩蔽则加入一个轻量的 patch-mixer,参数量小于主干网络的 10%。

一般来说,丢掉的 patch 越多(高 masking ratio),模型的性能会越差,比如 MaskDiT 在超过 50% 后表现大幅下降。

这里的对比实验采用默认的超参数(学习率 1.6×10e-4、0.01 的权重衰减和余弦学习率)来训练两个模型。

上图的结果显示了延迟屏蔽方法在 FID、Clip-FID 和 Clip score 三个指标上都获得了提升。

并且,与基线的性能差距随着掩蔽率的增加而扩大。在掩蔽率为 75% 的情况下,朴素掩蔽会将 FID 分数降低至 16.5,而本文的方法则达到 5.03,更接近于无掩蔽时的 FID 分数(3.79)。

超参数

沿着训练 LLM 的一般思路,这里比较两个任务的超参数选择。

首先,在前馈层中,SwiGLU 激活函数优于 GELU。其次,较高的权重衰减会带来更好的图像生成性能。

另外,与 LLM 训练不同的是,当对 AdamW 二阶矩 (β) 使用更高的运行平均系数时,本文的扩散模型可以达到更好的性能。

最后,作者发现使用少量的训练步骤,而将学习率增加到最大可能值(直到训练不稳定)也显著提高了图像生成性能。

混合器的设计

大力出奇迹一般都是对的,作者也观察到使用更大的 patch-mixer 后,模型性能得到持续改善。

然而,本着省钱的目的,这里还是选择使用小型的混合器。

作者将噪声分布修改为 (−0.6, 1.2),这改善了字幕和生成图像之间的对齐。

如下图所示,在 75% masking ratio 下,作者还研究了采用不同 patch 大小所带来的影响。

当连续区域变多(patch 变大)时,模型的性能会下降,因此保留随机屏蔽每个 patch 的原始策略。

分层缩放

这个实验训练了 DiT-Tiny 架构的两种变体,一种具有恒定宽度,另一种采用分层缩放的结构。

两种方法都使用 Naive masking,并调整 Transformer 的尺寸,保证两种情况下的模型算力相同,同时执行相同的训练步骤和训练时间。

由上表结果可知发现,在所有三个性能指标上,分层缩放方法都优于基线的恒定宽度方法,这表明分层缩放方法更适合 DiT 的掩蔽训练。

参考资料:

arxiv.org/abs/2407.15...

相关推荐
深圳南柯电子2 分钟前
深圳南柯电子|电子设备EMC测试整改:常见问题与解决方案
人工智能
Kai HVZ3 分钟前
《OpenCV计算机视觉》--介绍及基础操作
人工智能·opencv·计算机视觉
biter00888 分钟前
opencv(15) OpenCV背景减除器(Background Subtractors)学习
人工智能·opencv·学习
吃个糖糖14 分钟前
35 Opencv 亚像素角点检测
人工智能·opencv·计算机视觉
IT古董1 小时前
【漫话机器学习系列】017.大O算法(Big-O Notation)
人工智能·机器学习
凯哥是个大帅比1 小时前
人工智能ACA(五)--深度学习基础
人工智能·深度学习
m0_748232921 小时前
DALL-M:基于大语言模型的上下文感知临床数据增强方法 ,补充
人工智能·语言模型·自然语言处理
szxinmai主板定制专家2 小时前
【国产NI替代】基于FPGA的32通道(24bits)高精度终端采集核心板卡
大数据·人工智能·fpga开发
海棠AI实验室2 小时前
AI的进阶之路:从机器学习到深度学习的演变(三)
人工智能·深度学习·机器学习
机器懒得学习2 小时前
基于YOLOv5的智能水域监测系统:从目标检测到自动报告生成
人工智能·yolo·目标检测