扩散模型文本组合拳👊


重要事情说三遍:

本文不适合零基础阅读,需要有一定扩散模型在NLP领域的基础。

本文不适合零基础阅读,需要有一定扩散模型在NLP领域的基础。

本文不适合零基础阅读,需要有一定扩散模型在NLP领域的基础。

👩:我简单说两句

扩散模型是一类新兴的生成模型,扩散模型将高斯分布中抽取的随机噪声转换为由样本集合指定的未知数据分布中的样本。这类生成模型最近在连续数据域取得了广泛成功,如图像、音频和视频。扩散模型在可控生成方面也取得了巨大成功,比如最近大火的各式各样的AI作画。

语言生成模型目前主要是基于Transformer的自回归模型,虽然语言模型的生成能力有目共睹了,但是语言模型的可控生成依旧是很重要的一个方向。

现在扩散模型比较热门的一个方向是如何将扩散模型应用到文本领域。这一方向主要存在几个思考点:

  • 因为文本离散数据和图像等连续域数据的差异性,似的扩散模型在离散领域的应用有限,其中离散状态向高斯噪声的逐渐过渡并不像在连续领域(如自然图像)中那样自然。

  • 扩散模型具有优秀的受控特性,如果能将其引用到文本领域,我们可以获得更好的受控生成文本。

向文本领域的一些迁移:

先前的工作曾尝试在离散状态空间上定义扩散过程,以直接建模离散数据,但这些方法在连续扩散模型之后进展较慢。

比如:

  • 《Structured Denoising Diffusion Models in Discrete State-Spaces》
  • 《Diffusionbert: Improving generative masked language models with diffusion models》
  • ...

单词会在生成序列的前向过程的开始阶段被替换为 [MASK],这一步被理解为加噪过程。然后,在去噪(恢复)阶段的结尾被逐渐恢复回其原始单词。

其他工作:

直接在词嵌入空间中学习连续扩散模型,并通过舍入步骤解码连续生成结果。这些工作将扩散模型作为可能替代自回归语言模型的方法来呈现。

今天我们要分享的这个论文,是将扩散视为自回归生成的补充工具,而不是替代品。使用连续扩散模型在预训练的编码器-解码器语言模型的潜在空间中进行学习。然后,从扩散模型中采样的连续向量可以通过预训练解码器解码为自然语言。


介绍一下做法

扩散模型

自回归语言模型是直接去建模语言的分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> P ( x ) P(x) </math>P(x),扩散模型与之不同,扩散模型做的是建立未知数据分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> P ( x ) P(x) </math>P(x) 和高斯分布之间的映射关系。映射是通过前向过程来定义的,该过程通过将高斯噪声逐步添加到从数据分布中采样得到的样本中,并通过生成过程来逐步"去噪"从高斯分布中的样本,以获取从数据分布中采样的样本。

讲述基础扩散模型ddpm的文章现在已经有很多了,这里就不再展开赘述了。放一下DDPM前向和反向扩散的算法公式:

方法

潜变量空间模型

看训练部分。 作者是在编码器-解码器构架的语言模型的中间向量表示中进行扩散的。作者使用的是BART模型,模型构架中可训练的网络只有绿色的Diffusion Transformer这一部分。

使用冻结的BAET的编码器将句子映射到连续的潜变量空间上,在这个中间变量上进行扩散模型操作,扩散结果使用BART的解码器重建结果。整个过程可以描述为:

句子 <math xmlns="http://www.w3.org/1998/Math/MathML"> w w </math>w --> BART编码结果 <math xmlns="http://www.w3.org/1998/Math/MathML"> E ( w ) E(w) </math>E(w) --> BART编码器重建结果 <math xmlns="http://www.w3.org/1998/Math/MathML"> w ≈ w ~ = D ( E ( w ) ) w ≈ \tilde w = D(E(w)) </math>w≈w~=D(E(w))

训练的Diffusion Transformer是在给定的自然语言数据集下采样连续数据 <math xmlns="http://www.w3.org/1998/Math/MathML"> x = E ( w ) x = E(w) </math>x=E(w)和 <math xmlns="http://www.w3.org/1998/Math/MathML"> w ∼ D w \sim D </math>w∼D,其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> D D </math>D为数据集。使用这些数据训练我们的扩散模型 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ^ θ ( z t , t ) \hat x_\theta(z_t,t) </math>x^θ(zt,t),从高斯噪声中恢复 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x。

在生成过程中, 我们需要采样一个潜变量 <math xmlns="http://www.w3.org/1998/Math/MathML"> z T ∼ N ( 0 , I ) z_T \sim \mathcal N(0,I) </math>zT∼N(0,I),然后对其逐步去噪,获得BART解码器可以解码的向量。图像领域通常是在固定分辨率下采样一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> z T ∈ R H × W × 3 ∼ N ( 0 , I ) z_T \in \mathbb R ^{H \times W \times 3} \sim \mathcal N(0,I) </math>zT∈RH×W×3∼N(0,I)。但是文本不行,文本每个句子的长度不同的,因此我们要采样的 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ∈ R l × d x \in \mathbb R ^{l \times d} </math>x∈Rl×d其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> l l </math>l是一个变量,因此在文本扩散模型的推理过程中需要指定一些 <math xmlns="http://www.w3.org/1998/Math/MathML"> l i l_i </math>li才可以采样 <math xmlns="http://www.w3.org/1998/Math/MathML"> z T ∈ R l i × d ∼ N ( 0 , I ) z_T \in \mathbb R ^{l_i \times d} \sim \mathcal N(0,I) </math>zT∈Rli×d∼N(0,I)。

那如何确定 <math xmlns="http://www.w3.org/1998/Math/MathML"> l i l_i </math>li?我们从数据给出的长度经验分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> L ( D ) \mathcal{L}(\mathcal{D}) </math>L(D)中采样 <math xmlns="http://www.w3.org/1998/Math/MathML"> l i l_i </math>li。

<math xmlns="http://www.w3.org/1998/Math/MathML"> Pr ⁡ ( ℓ = ℓ i ) = 1 ∣ D ∣ ∑ w ∈ D 1 { w ∈ R ℓ i × ∣ V ∣ } \operatorname{Pr}\left(\ell=\ell_i\right)=\frac{1}{|\mathcal{D}|} \sum_{\mathbf{w} \in \mathcal{D}} \mathbb{1}\left\{\mathbf{w} \in \mathbb{R}^{\ell_i \times|\mathcal{V}|}\right\} </math>Pr(ℓ=ℓi)=∣D∣1∑w∈D1{w∈Rℓi×∣V∣}.

为了生成过程,我们先对潜变量进行采样,首先采样一定长度的 <math xmlns="http://www.w3.org/1998/Math/MathML"> l i ∼ L ( D ) l_i \sim \mathcal{L}(\mathcal{D}) </math>li∼L(D) ,然后采样潜变量 <math xmlns="http://www.w3.org/1998/Math/MathML"> z T ∈ R l i × d ∼ N ( 0 , I ) z_T \in \mathbb{R}^{l_i \times d} \sim \mathcal{N}(0, I) </math>zT∈Rli×d∼N(0,I)。

self-condition

这里是借助了《Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning》中的self-condition,这篇文章是Hinton组的作品,使用二进制位的思想用扩散模型生成离散数据,实验中包含了图像字幕生成任务。

使用self-condition可以提高采样的质量。

普通的扩散模型在推理过程中是一个关于潜变量和时间 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t的计算: <math xmlns="http://www.w3.org/1998/Math/MathML"> x ~ t = x ^ θ ( z t , t ) \tilde{x}t=\hat{x}\theta\left(z_t, t \right) </math>x~t=x^θ(zt,t)

使用self-condition就是在前一时间步的输出上对网络进行条件化: <math xmlns="http://www.w3.org/1998/Math/MathML"> x ~ t = x ^ θ ( z t , t , x ~ t + 1 ) \tilde{x}t=\hat{x}\theta\left(z_t, t, \tilde{x}_{t+1}\right) </math>x~t=x^θ(zt,t,x~t+1)

Class-Conditional Diffusion

假设我们的数据集是有 <math xmlns="http://www.w3.org/1998/Math/MathML"> C C </math>C类标签的数据,因此我们数据集可以表示为 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( w , y ) ∈ D , w ∈ R ℓ × ∣ V ∣ , y ∈ { 1 , 2 , . . . , C } (\mathbf w,y) \in \mathcal D, \mathbf w \in \mathbb R^{\ell \times |\mathcal V|},y \in \{1,2,...,C \} </math>(w,y)∈D,w∈Rℓ×∣V∣,y∈{1,2,...,C}。

我们通过在类标签上条件化来训练一个类条件扩散模型 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ~ t = x ^ θ ( z t , t , y ) \tilde{\mathbf x}{t} = \hat{\mathbf x}\theta(\mathbf z_t,t, y) </math>x~t=x^θ(zt,t,y)。 与Self-Conditioning一样,

以概率 <math xmlns="http://www.w3.org/1998/Math/MathML"> p p </math>p将条件 <math xmlns="http://www.w3.org/1998/Math/MathML"> y i y_i </math>yi替换为空标签 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ϕ y_\phi </math>yϕ以保持无条件生成的能力。

在推理时,我们可以选择一些类 <math xmlns="http://www.w3.org/1998/Math/MathML"> y y </math>y来指导采样过程。采样潜变量 <math xmlns="http://www.w3.org/1998/Math/MathML"> z T ∼ N ( 0 , I ) \mathbf{z}T \sim \mathcal{N}(0, \mathbf{I}) </math>zT∼N(0,I)并给定类标签 <math xmlns="http://www.w3.org/1998/Math/MathML"> y y </math>y,计算所有时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t的 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ~ t = x ^ θ ( z t , t , y ) \tilde{\mathbf x}{t} = \hat{\mathbf x}_\theta(\mathbf z_t,t, y) </math>x~t=x^θ(zt,t,y)以此进行控制文本生成。


完整模型

完整的过程表示如下:

相关推荐
查理零世19 分钟前
【算法】经典博弈论问题——巴什博弈 python
开发语言·python·算法
神探阿航24 分钟前
第十五届蓝桥杯大赛软件赛省赛C/C++ 大学 B 组
java·算法·蓝桥杯
皮肤科大白42 分钟前
如何在data.table中处理缺失值
学习·算法·机器学习
有Li43 分钟前
基于深度学习的微出血自动检测及解剖尺度定位|文献速递-视觉大模型医疗图像应用
人工智能·深度学习
熙曦Sakura1 小时前
【深度学习】微积分
人工智能·深度学习
HyperAI超神经1 小时前
【TVM教程】为 ARM CPU 自动调优卷积网络
arm开发·人工智能·python·深度学习·机器学习·tvm·编译器
IT古董2 小时前
【深度学习】常见模型-卷积神经网络(Convolutional Neural Networks, CNN)
人工智能·深度学习·cnn
Luzem03192 小时前
使用scikit-learn中的KNN包实现对鸢尾花数据集的预测
人工智能·深度学习·机器学习
AI趋势预见2 小时前
使用AI生成金融时间序列数据:解决股市场的数据稀缺问题并提升信噪比
人工智能·深度学习·神经网络·语言模型·金融
不能只会打代码2 小时前
蓝桥杯例题一
算法·蓝桥杯