扩散模型文本组合拳👊


重要事情说三遍:

本文不适合零基础阅读,需要有一定扩散模型在NLP领域的基础。

本文不适合零基础阅读,需要有一定扩散模型在NLP领域的基础。

本文不适合零基础阅读,需要有一定扩散模型在NLP领域的基础。

👩:我简单说两句

扩散模型是一类新兴的生成模型,扩散模型将高斯分布中抽取的随机噪声转换为由样本集合指定的未知数据分布中的样本。这类生成模型最近在连续数据域取得了广泛成功,如图像、音频和视频。扩散模型在可控生成方面也取得了巨大成功,比如最近大火的各式各样的AI作画。

语言生成模型目前主要是基于Transformer的自回归模型,虽然语言模型的生成能力有目共睹了,但是语言模型的可控生成依旧是很重要的一个方向。

现在扩散模型比较热门的一个方向是如何将扩散模型应用到文本领域。这一方向主要存在几个思考点:

  • 因为文本离散数据和图像等连续域数据的差异性,似的扩散模型在离散领域的应用有限,其中离散状态向高斯噪声的逐渐过渡并不像在连续领域(如自然图像)中那样自然。

  • 扩散模型具有优秀的受控特性,如果能将其引用到文本领域,我们可以获得更好的受控生成文本。

向文本领域的一些迁移:

先前的工作曾尝试在离散状态空间上定义扩散过程,以直接建模离散数据,但这些方法在连续扩散模型之后进展较慢。

比如:

  • 《Structured Denoising Diffusion Models in Discrete State-Spaces》
  • 《Diffusionbert: Improving generative masked language models with diffusion models》
  • ...

单词会在生成序列的前向过程的开始阶段被替换为 MASK,这一步被理解为加噪过程。然后,在去噪(恢复)阶段的结尾被逐渐恢复回其原始单词。

其他工作:

直接在词嵌入空间中学习连续扩散模型,并通过舍入步骤解码连续生成结果。这些工作将扩散模型作为可能替代自回归语言模型的方法来呈现。

今天我们要分享的这个论文,是将扩散视为自回归生成的补充工具,而不是替代品。使用连续扩散模型在预训练的编码器-解码器语言模型的潜在空间中进行学习。然后,从扩散模型中采样的连续向量可以通过预训练解码器解码为自然语言。


介绍一下做法

扩散模型

自回归语言模型是直接去建模语言的分布 P ( x ) P(x) P(x),扩散模型与之不同,扩散模型做的是建立未知数据分布 P ( x ) P(x) P(x) 和高斯分布之间的映射关系。映射是通过前向过程来定义的,该过程通过将高斯噪声逐步添加到从数据分布中采样得到的样本中,并通过生成过程来逐步"去噪"从高斯分布中的样本,以获取从数据分布中采样的样本。

讲述基础扩散模型ddpm的文章现在已经有很多了,这里就不再展开赘述了。放一下DDPM前向和反向扩散的算法公式:

方法

潜变量空间模型

看训练部分。 作者是在编码器-解码器构架的语言模型的中间向量表示中进行扩散的。作者使用的是BART模型,模型构架中可训练的网络只有绿色的Diffusion Transformer这一部分。

使用冻结的BAET的编码器将句子映射到连续的潜变量空间上,在这个中间变量上进行扩散模型操作,扩散结果使用BART的解码器重建结果。整个过程可以描述为:

句子 w w w --> BART编码结果 E ( w ) E(w) E(w) --> BART编码器重建结果 w ≈ w ~ = D ( E ( w ) ) w ≈ \tilde w = D(E(w)) w≈w~=D(E(w))

训练的Diffusion Transformer是在给定的自然语言数据集下采样连续数据 x = E ( w ) x = E(w) x=E(w)和 w ∼ D w \sim D w∼D,其中 D D D为数据集。使用这些数据训练我们的扩散模型 x ^ θ ( z t , t ) \hat x_\theta(z_t,t) x^θ(zt,t),从高斯噪声中恢复 x x x。

在生成过程中, 我们需要采样一个潜变量 z T ∼ N ( 0 , I ) z_T \sim \mathcal N(0,I) zT∼N(0,I),然后对其逐步去噪,获得BART解码器可以解码的向量。图像领域通常是在固定分辨率下采样一个 z T ∈ R H × W × 3 ∼ N ( 0 , I ) z_T \in \mathbb R ^{H \times W \times 3} \sim \mathcal N(0,I) zT∈RH×W×3∼N(0,I)。但是文本不行,文本每个句子的长度不同的,因此我们要采样的 x ∈ R l × d x \in \mathbb R ^{l \times d} x∈Rl×d其中 l l l是一个变量,因此在文本扩散模型的推理过程中需要指定一些 l i l_i li才可以采样 z T ∈ R l i × d ∼ N ( 0 , I ) z_T \in \mathbb R ^{l_i \times d} \sim \mathcal N(0,I) zT∈Rli×d∼N(0,I)。

那如何确定 l i l_i li?我们从数据给出的长度经验分布 L ( D ) \mathcal{L}(\mathcal{D}) L(D)中采样 l i l_i li。

Pr ⁡ ( ℓ = ℓ i ) = 1 ∣ D ∣ ∑ w ∈ D 1 { w ∈ R ℓ i × ∣ V ∣ } \operatorname{Pr}\left(\ell=\ell_i\right)=\frac{1}{|\mathcal{D}|} \sum_{\mathbf{w} \in \mathcal{D}} \mathbb{1}\left\{\mathbf{w} \in \mathbb{R}^{\ell_i \times|\mathcal{V}|}\right\} Pr(ℓ=ℓi)=∣D∣1∑w∈D1{w∈Rℓi×∣V∣}.

为了生成过程,我们先对潜变量进行采样,首先采样一定长度的 l i ∼ L ( D ) l_i \sim \mathcal{L}(\mathcal{D}) li∼L(D) ,然后采样潜变量 z T ∈ R l i × d ∼ N ( 0 , I ) z_T \in \mathbb{R}^{l_i \times d} \sim \mathcal{N}(0, I) zT∈Rli×d∼N(0,I)。

self-condition

这里是借助了《Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning》中的self-condition,这篇文章是Hinton组的作品,使用二进制位的思想用扩散模型生成离散数据,实验中包含了图像字幕生成任务。

使用self-condition可以提高采样的质量。

普通的扩散模型在推理过程中是一个关于潜变量和时间 t t t的计算: x ~ t = x ^ θ ( z t , t ) \tilde{x}t=\hat{x}\theta\left(z_t, t \right) x~t=x^θ(zt,t)

使用self-condition就是在前一时间步的输出上对网络进行条件化: x ~ t = x ^ θ ( z t , t , x ~ t + 1 ) \tilde{x}t=\hat{x}\theta\left(z_t, t, \tilde{x}_{t+1}\right) x~t=x^θ(zt,t,x~t+1)

Class-Conditional Diffusion

假设我们的数据集是有 C C C类标签的数据,因此我们数据集可以表示为 ( w , y ) ∈ D , w ∈ R ℓ × ∣ V ∣ , y ∈ { 1 , 2 , . . . , C } (\mathbf w,y) \in \mathcal D, \mathbf w \in \mathbb R^{\ell \times |\mathcal V|},y \in \{1,2,...,C \} (w,y)∈D,w∈Rℓ×∣V∣,y∈{1,2,...,C}。

我们通过在类标签上条件化来训练一个类条件扩散模型 x ~ t = x ^ θ ( z t , t , y ) \tilde{\mathbf x}{t} = \hat{\mathbf x}\theta(\mathbf z_t,t, y) x~t=x^θ(zt,t,y)。 与Self-Conditioning一样,

以概率 p p p将条件 y i y_i yi替换为空标签 y ϕ y_\phi yϕ以保持无条件生成的能力。

在推理时,我们可以选择一些类 y y y来指导采样过程。采样潜变量 z T ∼ N ( 0 , I ) \mathbf{z}T \sim \mathcal{N}(0, \mathbf{I}) zT∼N(0,I)并给定类标签 y y y,计算所有时间步 t t t的 x ~ t = x ^ θ ( z t , t , y ) \tilde{\mathbf x}{t} = \hat{\mathbf x}_\theta(\mathbf z_t,t, y) x~t=x^θ(zt,t,y)以此进行控制文本生成。


完整模型

完整的过程表示如下:

相关推荐
小羊在睡觉3 小时前
力扣84. 柱状图中最大的矩形
后端·算法·leetcode·golang·go
3DVisionary3 小时前
蓝光三维扫描:医疗制造的精度焦虑怎么解
人工智能·算法·制造·蓝光三维扫描·医疗制造·三维检测·义齿检测
好评笔记3 小时前
机器学习面试八股——常用损失函数
人工智能·深度学习·算法·机器学习·校招
weixin_468466853 小时前
全局与局部注意力机制新手实战指南
人工智能·python·深度学习·算法·自然语言处理·transformer·注意力机制
小糖学代码4 小时前
LLM系列:环境搭建:5.Python-dotenv 环境变量管理
人工智能·python·深度学习·神经网络
_日拱一卒4 小时前
LeetCode:994腐烂的橘子
java·数据结构·算法·leetcode·深度优先
珂朵莉MM4 小时前
第七届全球校园人工智能算法精英大赛-算法巅峰赛产业命题赛第3赛季优化题--束搜索
人工智能·算法
Omics Pro5 小时前
首个!外源天然产物综合性代谢图谱
数据库·人工智能·算法·机器学习·r语言
voidmort5 小时前
3. 微调(Fine-tuning)与强化学习(RL)的核心思想
python·深度学习·算法
人道领域6 小时前
【LeetCode刷题日记】669.修剪二叉搜索树
开发语言·python·算法