文本生成:解码器

最早期的生成任务,是大家所熟知的 "编码器-解码器结构", 简单来说,就是将输入的序列信息编码,由解码器对编码后的信息作出解释。这种结构已经在机器翻译、文本摘要、复述生 成等多种生成任务上得到应用,并取得了不错的效果。其本质就是将源语言序列映射到目标语言的词序列,可以被看做是一个以源端句子为条件,来建模目标端句子的语言模型。

解码器结构

如图所示,输入的是已经分好词的源端句子"what is zakat used for"。编码器接收到输入之后,首先会根据词表将词映射到对应的词嵌入表示 (word embedding),于是就得到了网络的输入 {x1, ..., xl},输入表示经过一个循环神经网络,得到句子表示 h,公式如下:

<math xmlns="http://www.w3.org/1998/Math/MathML"> h t = f ( x t , h t − 1 ) , h_t = f (x_t,h_{t−1}), </math>ht=f(xt,ht−1), t = 1,2,...,l and h0 = 0

编码器得到的表示,可以作为解码器的初始化状态,也就是第一个时刻的一个输入。解码器还需要输出层来做预测,预测下一个单词的生成概率,并作出决策。由于解码器每一个时刻都需要接受前一时刻的隐层和当前输入的词表示来计算当前时刻的隐层表示,并用输出层得到下一个要生成的单词的概率,因此,需要对输入的句子做一些标记,来告诉解码器什么时候开始解码,生成第一个单词,什么时候停止解码,也就是说需要标记生成的开始和结束。一般,在输入句子的开始和结束位置加上两个标记,"[S]"表示句子的开始,"[EOS]"表示句子结束。 如下公式展示了在 t 时刻解码器生成单词的计算过程:

<math xmlns="http://www.w3.org/1998/Math/MathML"> p ( y ) = s o f t m a x ( V h t + b ) p(y)=softmax(Vht+b) </math>p(y)=softmax(Vht+b)

其中 V 是权重矩阵,b 是偏置向量,softmax 是归一化指数函数,将网络输出值归一化成目标端输出词的概率。softmax函数的计算公式如下:

<math xmlns="http://www.w3.org/1998/Math/MathML"> s o f t m a x ( o k ) = e x p ( o k ) / ∑ k ′ = 1 K e x p ( o k ′ ) softmax(o^k)= exp(o^k)/\sum_{k'=1}^{K} exp(o^{k′}) </math>softmax(ok)=exp(ok)/∑k′=1Kexp(ok′)

解码第一个单词的时候,接收的输入是 "[S]",解码器解码出 "[EOS]" 的时候,结束整个解码过程。

词表特殊处理

值得注意的是,词向量可以由网络在训练过程中学习得到,也可以使用其他预先训练的词向量,比如 Word2vec,Glove等。词表一般是根据训练语料统计词频,选择词频大于某个阈值的所有词作为词表,此外,如果训练过程中出现了词表中没有见过的单词,通常用符号 "[UNK]" 来表示。编码器端输入的句子,往往长短不一,在批量训练过程中,需要对句子做预处理,将其处理为长度一样的句子,具体做法就是把所有句子都扩展为最长句子的长度,较短的句子通过在末端不断加入 "[PAD]" 标记来扩展长度,因此需要在词表中额外加入四个特殊标记符: "[S]", "[EOS]", "[UNK]", "[PAD]"。

损失函数

至此,我们已经得到了一个序列生成的过程,还需要设计一个合适的损失函数来训练模型,以得到合适的网络参数。显然,我们的损失函树的目标应当是减少模型输出与真实目标端序列的误差,比较常用的损失函数是交叉熵损失函数,首先定义在 t 时刻的损失函数如下:

<math xmlns="http://www.w3.org/1998/Math/MathML"> E t ( y t , y ^ t ) = − y t l o g y ^ t E_t(y_t,\hat{y}_t) = −y_tlog\hat{y}_t </math>Et(yt,y^t)=−ytlogy^t

其中,y_t 表示 t 时刻的标注答案, 是一个只有一个是 1,其他都是 0 的 one-hot 形式的向量;\hat{y}t 是我们预测出来的结果,与 y_t 的维度一样,但它是一个概率向量,表示每个词出现的概率。

每个时刻都需要作出预测,因此整个序列预测结束之后,损失函数为:

<math xmlns="http://www.w3.org/1998/Math/MathML"> E = ∑ t E t ( y t , y ^ t ) = ∑ t − y t l o g y ^ t E = \sum_tE_t(y_t, \hat{y}_t) = \sum_t−y_tlog\hat{y}_t </math>E=∑tEt(yt,y^t)=∑t−ytlogy^t

网络训练的目标就是最小化这个交叉熵函数,可以通过梯度下降 (gradient descent) 算法来更新网络参数。

总结

解码器结构是整个生成的基础,在训练过程中,每步解码都会使用目标端真实的单词作为输入,比如我们已经知道了第 一个要生成的单词是"how",因此我们就可以直接把"how"作为模型的输 入。但是在测试过程中,我们并不知道真实的输入应该是什么,需要设计一种策略解决该问题。

一个简单的策略就是贪心搜索 (greedy search),每个时刻都直接选择预测概 率最大的词作为下一个时刻的输入,贪心搜索算法效率较高, 但是往往无法得到最优解,在实际应用中,往往会采用柱搜索算法,与贪心搜索不同,柱搜索算法每次都保留 K 个概率最大的搜索路径。

下一篇文章,将介绍几种不同的解码方式

相关推荐
nancy_princess4 小时前
clip实验
人工智能·深度学习
南境十里·墨染春水4 小时前
C++传记(面向对象)虚析构函数 纯虚函数 抽象类 final、override关键字
开发语言·c++·笔记·算法
飞哥数智坊4 小时前
TRAE Friends@济南第4次活动:100+极客集结,2小时极限编程燃爆全场!
人工智能
AI自动化工坊4 小时前
ProofShot实战:给AI编码助手添加可视化验证,提升前端开发效率3倍
人工智能·ai·开源·github
飞哥数智坊4 小时前
一场直播涨粉 2 万的背后!OpenClaw + 飞书,正在重塑软件交付的方式
人工智能
2301_797172754 小时前
基于C++的游戏引擎开发
开发语言·c++·算法
飞哥数智坊4 小时前
养虾记第3期:安装、调教、落地,这场沙龙我们全聊了
人工智能
再不会python就不礼貌了4 小时前
从工具到个人助理——AI Agent的原理、演进与安全风险
人工智能·安全·ai·大模型·transformer·ai编程
AI医影跨模态组学4 小时前
Radiother Oncol 空军军医大学西京医院等团队:基于纵向CT的亚区域放射组学列线图预测食管鳞状细胞癌根治性放化疗后局部无复发生存期
人工智能·深度学习·医学影像·影像组学
A尘埃5 小时前
神经网络的激活函数+损失函数
人工智能·深度学习·神经网络·激活函数