【论文阅读】A Closer Look at Parameter-Efficient Tuning in Diffusion Models

Abstract

大规模扩散模型功能强大,但微调定制这些模型,内存和时间效率都很低。

本文通过向大规模扩散模型中插入小的学习器(称为adapters),实现有效的参数微调。

特别地,将适配器的设计空间分解为输入位置、输出位置、函数形式的正交因子,执行方差分析(ANOVA),是一种分析离散和连续变量之间相关性的经典统计方法。

分析表明,适配器的输入位置是影响下游任务性能的关键因素,将适配器放到cross-attention块之后可以获得最好的性能。

1 Introduction

借助从海量数据中学习到的知识,大规模扩散模型作为下游任务的强先验。其中,DreamBooth对大规模扩散模型中的所有参数进行调优,生成用户想要的特定对象。但微调整个模型在计算、内存和存储成本方面是低效的。

另一种可选方法是源于NLP领域的参数高效迁移学习方法(parameter-efficient transfer learning methods),向模型中插入小型可训练模块(称为adapters),并冻结原始模型。这个方法在扩散模型领域未得到深入的研究。扩散模型中广泛使用U-Net结构,包含了更多的组件,如带有下/上采样算子的残差快、自注意力和交叉注意力等。这导致这个方法比基于Transformer的语言模型有更大的设计空间。

本文首次系统研究了大规模扩散模型中参数有效调节的设计空间。

将Stable Diffusion作为具体案例(唯一开源大规模扩散模型)。特别地,将适配器的设计空间分解为正交因子------输入位置、输出位置和函数形式。通过使用Analysis of Variance(ANOVA)对这些因素进行分析,发现输入位置是影响下游任务表现的关键因素。然后仔细研究了输入位置的选择,发现将输入位置放在交叉注意力块之后可以最大限度地鼓励网格感知输入prompt的变化,从而导致最好的性能。

根据本文的研究,在Dreambooth中介绍的个性化任务和在一组小的文本-图像对上进行微调的任务上,本文的最佳设置可以在0.75%额外参数内达到与微调方法相当甚至更好的结果。

2 Background

2.1 Diffusion Models

扩散模型通过逆转下式表示的加噪过程学习 q ( x 0 ) q(x_0) q(x0)的数据分布:
q ( x 1 : T ∣ x 0 ) = ∏ t = 0 T q ( x t ∣ x t − 1 ) q(x_{1:T}|x_0)=\prod_{t=0}^T q(x_t|x_{t-1}) q(x1:T∣x0)=t=0∏Tq(xt∣xt−1)

其中 q ( x t ∣ x t − 1 ) = N ( x t ∣ α t x t − 1 , β t I ) q(x_t|x_{t-1})=\mathcal{N}(x_t|\sqrt{\alpha_t}x_{t-1},\beta_t I) q(xt∣xt−1)=N(xt∣αt xt−1,βtI)对应一步加噪过程。反向的转变用一个高斯模型 p ( x t − 1 ∣ x t ) = N ( x t − 1 ∣ μ ( x t ) , σ t 2 I ) p(x_{t-1}|x_t)=\mathcal{N}(x_{t-1}|\mu(x_t),\sigma_t^2 I) p(xt−1∣xt)=N(xt−1∣μ(xt),σt2I)来近似,其中极大似然估计下的最优均值为
μ t ∗ ( x t ) = 1 α t ( x t − β t 1 − α ˉ t E [ ϵ ∣ x t ] ) \mu_t^*(x_t)=\frac{1}{\sqrt{\alpha_t}}(x_t-\frac{\beta_t}{1-\bar{\alpha}_t}\mathbb{E}[\epsilon|x_t]) μt∗(xt)=αt 1(xt−1−αˉtβtE[ϵ∣xt])

其中 α t ˉ = ∏ i = 1 t α i \bar{\alpha_t}=\prod_{i=1}^t \alpha_i αtˉ=∏i=1tαi, ϵ \epsilon ϵ是向 x t x_t xt注入的标准高斯噪声。为了获得最优均值,只需通过下式的噪声预测目标来估计条件期望 E [ ϵ ∣ x t ] \mathbb{E}[\epsilon|x_t] E[ϵ∣xt]:
min ⁡ θ E t , x 0 ∣ ∣ ϵ θ ( x t , t ) − ϵ ∣ ∣ 2 2 \min_\theta\mathbb{E}{t,x_0}||\epsilon\theta(x_t,t)-\epsilon||_2^2 θminEt,x0∣∣ϵθ(xt,t)−ϵ∣∣22

其中 ϵ θ ( x t , t ) \epsilon_\theta(x_t,t) ϵθ(xt,t)是噪声预测网络,根据 ℓ 2 \ell_2 ℓ2损失的惩罚项找到最优解,满足 ϵ θ ∗ ( x t , t ) = E [ ϵ ∣ x t ] \epsilon_{\theta*}(x_t,t)=\mathbb{E}[\epsilon|x_t] ϵθ∗(xt,t)=E[ϵ∣xt]。

实践中,常常关心条件生成,为了使用扩散模型,只需要在训练时将条件信息 c c c引入到噪声预测网络中:
min ⁡ θ E t , x 0 , c ∣ ∣ ϵ θ ( x t , t , c ) − ϵ ∣ ∣ 2 2 \min_\theta\mathbb{E}{t,x_0,c}||\epsilon\theta(x_t,t,c)-\epsilon||_2^2 θminEt,x0,c∣∣ϵθ(xt,t,c)−ϵ∣∣22

2.2 The Architecture in Stable Diffusion

U-Net架构是目前扩散模型最热门的架构,架构图见图3。

图3:顶部左侧展示了基于UNet的扩散模型的框架,顶部右侧展示了扩散模型是如何从第 T − 1 T-1 T−1步从噪声数据去除噪声的。底部图展示了架构中残差快和Transformer块。适配器(图中红色块)是参数高效迁移学习方法中插入模型中的带有少量参数的模块。

U-Net由堆叠的基本块组成,每个基本块包含一个Transformer块和一个残差块。

Transformer块中,由三种子层:一个自注意力层(self-attention layer),一个交叉注意力层(cross attention layer),一个全连接前馈网络(fully connected feed-forward network)。

注意力层在queries Q ∈ R n × d k Q\in\mathbb{R}^{n\times d_k} Q∈Rn×dk、key-value pairs K ∈ R m × d k , V ∈ R m × d v K\in\mathbf{R}^{m\times d_k},V\in\mathbf{R}^{m\times d_v} K∈Rm×dk,V∈Rm×dv上进行操作:
Attn ( Q , K , V ) ∈ R n × d v = softmax ( Q K T d k ) V (1) \text{Attn}(Q,K,V)\in\mathbb{R}^{n\times d_v}=\text{softmax}(\frac{QK^T}{\sqrt{d_k}})V\tag{1} Attn(Q,K,V)∈Rn×dv=softmax(dk QKT)V(1)

其中 n n n是queries的数量, m m m是key-value pairs的数量, d k d_k dk是key的维度, d v d_v dv是value的维度。

在自注意力层中, x ∈ R n × d x x\in\mathbb{R}^{n\times d_x} x∈Rn×dx是唯一的输入。

在扩散模型的交叉注意力层中,有两个输入,分别是 x ∈ R n × d x x\in\mathbb{R}^{n\times d_x} x∈Rn×dx和 c ∈ R m × d c c\in\mathbb{R}^{m\times d_c} c∈Rm×dc, x x x是前面层的输出, c c c是条件信息。

全连接前馈网络包含两层线性变换的ReLU激活函数:
FFN ( x ) = ReLU ( x W 1 + b 1 ) W 2 + b 2 (2) \text{FFN}(x)=\text{ReLU}(xW_1+b_1)W_2+b_2\tag{2} FFN(x)=ReLU(xW1+b1)W2+b2(2)

其中 W 1 ∈ R d × d m , W 2 ∈ R d m × d W_1\in\mathbb{R}^{d\times d_m},W_2\in\mathbb{R}^{d_m\times d} W1∈Rd×dm,W2∈Rdm×d是可学习的权重, b 1 ∈ R d m , b 2 ∈ R d b_1\in\mathbb{R}^{d_m},b_2\in\mathbb{R}^d b1∈Rdm,b2∈Rd是可学习的偏差。

残差块中包含一系列的卷积层和激活函数,其中时间嵌入(time embedding)通过假发操作注入倒残差块中。

2.3 Parameter-Efficient Transfer Learning

迁移学习(Transfer Learning)是一种利用从一个任务中学习到的知识来提高相关任务性能的技术。对下游任务进行预训练然后进行迁移学习的方法被广泛使用。

传统的迁移方法需要大量参数,计算代价高,占用内存大。

参数有效的迁移学习(Parameter-Efficient Transfer Learning)最早在NLP领域中提出,核心思想是减少需要更新的参数数量。这是通过更新模型的一部分参数,或添加额外的小模块实现的。

3 Design Space of Parameter-Efficient Transfer Learning in Diffusion Models

本文将适配器的设计空间分解为三个正交因子:输入位置、输出位置、函数形式。

3.1 Input Position and Output Position

输入位置(Input position)是指适配器从哪里输入,输出位置(Output position)是指适配器向哪里输出。规定记号如图4所示,位置根据其相邻层命名。

图4:通常,激活位置的主要名称是模型中特定块的别名,激活位置的下标解释了激活与块之间的关系。

本文的框架中,输入位置可能是图4中描述的任意一个激活位置,因此一共有10个可选位置。

由于加法可交换,对于输出位置,有些位置是等价的,如 SA out \text{SA}\text{out} SAout和 CA in \text{CA}\text{in} CAin是等价的。因此有7种可选位置。

3.2 Function Form

函数形式描述了适配器如何将输入转换成输出。本文分别给出Transformer块和残差块的函数形式,见图5,二者都包含下采样操作、激活函数、上采样操作和放缩因子。

图5:Transformer块和残差块中适配器的函数形式。

下采样操作减少输入的维度,上采样操作增加输入的维度,以保证其具有和输入相同的维度。进一步将输出乘一个放缩因子 s s s,控制其对原网络的影响强度。'

具体地,Transformer块适配器用低秩矩阵 W d o w n W_{down} Wdown和 W u p W_{up} Wup作为下采样和上采样操作;残差块适配器用 3 × 3 3\times 3 3×3卷积层 Conv d o w n \text{Conv}{down} Convdown和 Conv u p \text{Conv}{up} Convup作为下采样和上采样操作。卷积核只改变通道数,不改变空间大小。

此外,残差块适配器还使用组归一化(group normalization)处理输入。

本文在设计选择中包含了不同的激活函数和缩放因子。

激活函数包括ReLU、Sigmoid、SiLU。

放缩因子包括0.5、1.0、2.0、4.0。

4 Discover the Key Factor with Analysis of Variance

本文采用单因素方差分析(one-way analysis of variance, ANOVA)方法量化模型性能与因素之间的相关性。

ANOVA的核心思想是将总方差分为两个部分:组内方差(variation within groups, MSE)和组间方差(variation between groups, MSB)。

MSB衡量不同组均值之间的差异,MSE衡量个体观测值和其对应组均值之间的差异。

ANOVA中使用的统计检验是基于F分布,即比较组间变异与组内变异的比值(F统计量,F-statistic)。

如果F统计量足够大,则表明组间均值存在显著性差异,表明具有较强的相关性。

相关推荐
Sxiaocai11 分钟前
使用 PyTorch 实现并训练 VGGNet 用于 MNIST 分类
pytorch·深度学习·分类
GL_Rain12 分钟前
【OpenCV】Could NOT find TIFF (missing: TIFF_LIBRARY TIFF_INCLUDE_DIR)
人工智能·opencv·计算机视觉
shansjqun16 分钟前
教学内容全覆盖:航拍杂草检测与分类
人工智能·分类·数据挖掘
狸克先生19 分钟前
如何用AI写小说(二):Gradio 超简单的网页前端交互
前端·人工智能·chatgpt·交互
baiduopenmap33 分钟前
百度世界2024精选公开课:基于地图智能体的导航出行AI应用创新实践
前端·人工智能·百度地图
小任同学Alex37 分钟前
浦语提示词工程实践(LangGPT版,服务器上部署internlm2-chat-1_8b,踩坑很多才完成的详细教程,)
人工智能·自然语言处理·大模型
新加坡内哥谈技术43 分钟前
微软 Ignite 2024 大会
人工智能
江瀚视野1 小时前
Q3净利增长超预期,文心大模型调用量大增,百度未来如何分析?
人工智能
陪学1 小时前
百度遭初创企业指控抄袭,维权还是碰瓷?
人工智能·百度·面试·职场和发展·产品运营
QCN_1 小时前
湘潭大学人工智能考试复习1(软件工程)
人工智能