扩散模型文本组合拳👊


重要事情说三遍:

本文不适合零基础阅读,需要有一定扩散模型在NLP领域的基础。

本文不适合零基础阅读,需要有一定扩散模型在NLP领域的基础。

本文不适合零基础阅读,需要有一定扩散模型在NLP领域的基础。

👩:我简单说两句

扩散模型是一类新兴的生成模型,扩散模型将高斯分布中抽取的随机噪声转换为由样本集合指定的未知数据分布中的样本。这类生成模型最近在连续数据域取得了广泛成功,如图像、音频和视频。扩散模型在可控生成方面也取得了巨大成功,比如最近大火的各式各样的AI作画。

语言生成模型目前主要是基于Transformer的自回归模型,虽然语言模型的生成能力有目共睹了,但是语言模型的可控生成依旧是很重要的一个方向。

现在扩散模型比较热门的一个方向是如何将扩散模型应用到文本领域。这一方向主要存在几个思考点:

  • 因为文本离散数据和图像等连续域数据的差异性,似的扩散模型在离散领域的应用有限,其中离散状态向高斯噪声的逐渐过渡并不像在连续领域(如自然图像)中那样自然。

  • 扩散模型具有优秀的受控特性,如果能将其引用到文本领域,我们可以获得更好的受控生成文本。

向文本领域的一些迁移:

先前的工作曾尝试在离散状态空间上定义扩散过程,以直接建模离散数据,但这些方法在连续扩散模型之后进展较慢。

比如:

  • 《Structured Denoising Diffusion Models in Discrete State-Spaces》
  • 《Diffusionbert: Improving generative masked language models with diffusion models》
  • ...

单词会在生成序列的前向过程的开始阶段被替换为 [MASK],这一步被理解为加噪过程。然后,在去噪(恢复)阶段的结尾被逐渐恢复回其原始单词。

其他工作:

直接在词嵌入空间中学习连续扩散模型,并通过舍入步骤解码连续生成结果。这些工作将扩散模型作为可能替代自回归语言模型的方法来呈现。

今天我们要分享的这个论文,是将扩散视为自回归生成的补充工具,而不是替代品。使用连续扩散模型在预训练的编码器-解码器语言模型的潜在空间中进行学习。然后,从扩散模型中采样的连续向量可以通过预训练解码器解码为自然语言。


介绍一下做法

扩散模型

自回归语言模型是直接去建模语言的分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> P ( x ) P(x) </math>P(x),扩散模型与之不同,扩散模型做的是建立未知数据分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> P ( x ) P(x) </math>P(x) 和高斯分布之间的映射关系。映射是通过前向过程来定义的,该过程通过将高斯噪声逐步添加到从数据分布中采样得到的样本中,并通过生成过程来逐步"去噪"从高斯分布中的样本,以获取从数据分布中采样的样本。

讲述基础扩散模型ddpm的文章现在已经有很多了,这里就不再展开赘述了。放一下DDPM前向和反向扩散的算法公式:

方法

潜变量空间模型

看训练部分。 作者是在编码器-解码器构架的语言模型的中间向量表示中进行扩散的。作者使用的是BART模型,模型构架中可训练的网络只有绿色的Diffusion Transformer这一部分。

使用冻结的BAET的编码器将句子映射到连续的潜变量空间上,在这个中间变量上进行扩散模型操作,扩散结果使用BART的解码器重建结果。整个过程可以描述为:

句子 <math xmlns="http://www.w3.org/1998/Math/MathML"> w w </math>w --> BART编码结果 <math xmlns="http://www.w3.org/1998/Math/MathML"> E ( w ) E(w) </math>E(w) --> BART编码器重建结果 <math xmlns="http://www.w3.org/1998/Math/MathML"> w ≈ w ~ = D ( E ( w ) ) w ≈ \tilde w = D(E(w)) </math>w≈w~=D(E(w))

训练的Diffusion Transformer是在给定的自然语言数据集下采样连续数据 <math xmlns="http://www.w3.org/1998/Math/MathML"> x = E ( w ) x = E(w) </math>x=E(w)和 <math xmlns="http://www.w3.org/1998/Math/MathML"> w ∼ D w \sim D </math>w∼D,其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> D D </math>D为数据集。使用这些数据训练我们的扩散模型 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ^ θ ( z t , t ) \hat x_\theta(z_t,t) </math>x^θ(zt,t),从高斯噪声中恢复 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x。

在生成过程中, 我们需要采样一个潜变量 <math xmlns="http://www.w3.org/1998/Math/MathML"> z T ∼ N ( 0 , I ) z_T \sim \mathcal N(0,I) </math>zT∼N(0,I),然后对其逐步去噪,获得BART解码器可以解码的向量。图像领域通常是在固定分辨率下采样一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> z T ∈ R H × W × 3 ∼ N ( 0 , I ) z_T \in \mathbb R ^{H \times W \times 3} \sim \mathcal N(0,I) </math>zT∈RH×W×3∼N(0,I)。但是文本不行,文本每个句子的长度不同的,因此我们要采样的 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ∈ R l × d x \in \mathbb R ^{l \times d} </math>x∈Rl×d其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> l l </math>l是一个变量,因此在文本扩散模型的推理过程中需要指定一些 <math xmlns="http://www.w3.org/1998/Math/MathML"> l i l_i </math>li才可以采样 <math xmlns="http://www.w3.org/1998/Math/MathML"> z T ∈ R l i × d ∼ N ( 0 , I ) z_T \in \mathbb R ^{l_i \times d} \sim \mathcal N(0,I) </math>zT∈Rli×d∼N(0,I)。

那如何确定 <math xmlns="http://www.w3.org/1998/Math/MathML"> l i l_i </math>li?我们从数据给出的长度经验分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> L ( D ) \mathcal{L}(\mathcal{D}) </math>L(D)中采样 <math xmlns="http://www.w3.org/1998/Math/MathML"> l i l_i </math>li。

<math xmlns="http://www.w3.org/1998/Math/MathML"> Pr ⁡ ( ℓ = ℓ i ) = 1 ∣ D ∣ ∑ w ∈ D 1 { w ∈ R ℓ i × ∣ V ∣ } \operatorname{Pr}\left(\ell=\ell_i\right)=\frac{1}{|\mathcal{D}|} \sum_{\mathbf{w} \in \mathcal{D}} \mathbb{1}\left\{\mathbf{w} \in \mathbb{R}^{\ell_i \times|\mathcal{V}|}\right\} </math>Pr(ℓ=ℓi)=∣D∣1∑w∈D1{w∈Rℓi×∣V∣}.

为了生成过程,我们先对潜变量进行采样,首先采样一定长度的 <math xmlns="http://www.w3.org/1998/Math/MathML"> l i ∼ L ( D ) l_i \sim \mathcal{L}(\mathcal{D}) </math>li∼L(D) ,然后采样潜变量 <math xmlns="http://www.w3.org/1998/Math/MathML"> z T ∈ R l i × d ∼ N ( 0 , I ) z_T \in \mathbb{R}^{l_i \times d} \sim \mathcal{N}(0, I) </math>zT∈Rli×d∼N(0,I)。

self-condition

这里是借助了《Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning》中的self-condition,这篇文章是Hinton组的作品,使用二进制位的思想用扩散模型生成离散数据,实验中包含了图像字幕生成任务。

使用self-condition可以提高采样的质量。

普通的扩散模型在推理过程中是一个关于潜变量和时间 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t的计算: <math xmlns="http://www.w3.org/1998/Math/MathML"> x ~ t = x ^ θ ( z t , t ) \tilde{x}t=\hat{x}\theta\left(z_t, t \right) </math>x~t=x^θ(zt,t)

使用self-condition就是在前一时间步的输出上对网络进行条件化: <math xmlns="http://www.w3.org/1998/Math/MathML"> x ~ t = x ^ θ ( z t , t , x ~ t + 1 ) \tilde{x}t=\hat{x}\theta\left(z_t, t, \tilde{x}_{t+1}\right) </math>x~t=x^θ(zt,t,x~t+1)

Class-Conditional Diffusion

假设我们的数据集是有 <math xmlns="http://www.w3.org/1998/Math/MathML"> C C </math>C类标签的数据,因此我们数据集可以表示为 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( w , y ) ∈ D , w ∈ R ℓ × ∣ V ∣ , y ∈ { 1 , 2 , . . . , C } (\mathbf w,y) \in \mathcal D, \mathbf w \in \mathbb R^{\ell \times |\mathcal V|},y \in \{1,2,...,C \} </math>(w,y)∈D,w∈Rℓ×∣V∣,y∈{1,2,...,C}。

我们通过在类标签上条件化来训练一个类条件扩散模型 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ~ t = x ^ θ ( z t , t , y ) \tilde{\mathbf x}{t} = \hat{\mathbf x}\theta(\mathbf z_t,t, y) </math>x~t=x^θ(zt,t,y)。 与Self-Conditioning一样,

以概率 <math xmlns="http://www.w3.org/1998/Math/MathML"> p p </math>p将条件 <math xmlns="http://www.w3.org/1998/Math/MathML"> y i y_i </math>yi替换为空标签 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ϕ y_\phi </math>yϕ以保持无条件生成的能力。

在推理时,我们可以选择一些类 <math xmlns="http://www.w3.org/1998/Math/MathML"> y y </math>y来指导采样过程。采样潜变量 <math xmlns="http://www.w3.org/1998/Math/MathML"> z T ∼ N ( 0 , I ) \mathbf{z}T \sim \mathcal{N}(0, \mathbf{I}) </math>zT∼N(0,I)并给定类标签 <math xmlns="http://www.w3.org/1998/Math/MathML"> y y </math>y,计算所有时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t的 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ~ t = x ^ θ ( z t , t , y ) \tilde{\mathbf x}{t} = \hat{\mathbf x}_\theta(\mathbf z_t,t, y) </math>x~t=x^θ(zt,t,y)以此进行控制文本生成。


完整模型

完整的过程表示如下:

相关推荐
白榆maple13 分钟前
(蓝桥杯C/C++)——基础算法(下)
算法
喵~来学编程啦14 分钟前
【论文精读】LPT: Long-tailed prompt tuning for image classification
人工智能·深度学习·机器学习·计算机视觉·论文笔记
JSU_曾是此间年少17 分钟前
数据结构——线性表与链表
数据结构·c++·算法
此生只爱蛋1 小时前
【手撕排序2】快速排序
c语言·c++·算法·排序算法
咕咕吖2 小时前
对称二叉树(力扣101)
算法·leetcode·职场和发展
-Nemophilist-2 小时前
机器学习与深度学习-1-线性回归从零开始实现
深度学习·机器学习·线性回归
九圣残炎2 小时前
【从零开始的LeetCode-算法】1456. 定长子串中元音的最大数目
java·算法·leetcode
lulu_gh_yu2 小时前
数据结构之排序补充
c语言·开发语言·数据结构·c++·学习·算法·排序算法
丫头,冲鸭!!!3 小时前
B树(B-Tree)和B+树(B+ Tree)
笔记·算法
Re.不晚3 小时前
Java入门15——抽象类
java·开发语言·学习·算法·intellij-idea