文本生成:解码器

最早期的生成任务,是大家所熟知的 "编码器-解码器结构", 简单来说,就是将输入的序列信息编码,由解码器对编码后的信息作出解释。这种结构已经在机器翻译、文本摘要、复述生 成等多种生成任务上得到应用,并取得了不错的效果。其本质就是将源语言序列映射到目标语言的词序列,可以被看做是一个以源端句子为条件,来建模目标端句子的语言模型。

解码器结构

如图所示,输入的是已经分好词的源端句子"what is zakat used for"。编码器接收到输入之后,首先会根据词表将词映射到对应的词嵌入表示 (word embedding),于是就得到了网络的输入 {x1, ..., xl},输入表示经过一个循环神经网络,得到句子表示 h,公式如下:

<math xmlns="http://www.w3.org/1998/Math/MathML"> h t = f ( x t , h t − 1 ) , h_t = f (x_t,h_{t−1}), </math>ht=f(xt,ht−1), t = 1,2,...,l and h0 = 0

编码器得到的表示,可以作为解码器的初始化状态,也就是第一个时刻的一个输入。解码器还需要输出层来做预测,预测下一个单词的生成概率,并作出决策。由于解码器每一个时刻都需要接受前一时刻的隐层和当前输入的词表示来计算当前时刻的隐层表示,并用输出层得到下一个要生成的单词的概率,因此,需要对输入的句子做一些标记,来告诉解码器什么时候开始解码,生成第一个单词,什么时候停止解码,也就是说需要标记生成的开始和结束。一般,在输入句子的开始和结束位置加上两个标记,"[S]"表示句子的开始,"[EOS]"表示句子结束。 如下公式展示了在 t 时刻解码器生成单词的计算过程:

<math xmlns="http://www.w3.org/1998/Math/MathML"> p ( y ) = s o f t m a x ( V h t + b ) p(y)=softmax(Vht+b) </math>p(y)=softmax(Vht+b)

其中 V 是权重矩阵,b 是偏置向量,softmax 是归一化指数函数,将网络输出值归一化成目标端输出词的概率。softmax函数的计算公式如下:

<math xmlns="http://www.w3.org/1998/Math/MathML"> s o f t m a x ( o k ) = e x p ( o k ) / ∑ k ′ = 1 K e x p ( o k ′ ) softmax(o^k)= exp(o^k)/\sum_{k'=1}^{K} exp(o^{k′}) </math>softmax(ok)=exp(ok)/∑k′=1Kexp(ok′)

解码第一个单词的时候,接收的输入是 "[S]",解码器解码出 "[EOS]" 的时候,结束整个解码过程。

词表特殊处理

值得注意的是,词向量可以由网络在训练过程中学习得到,也可以使用其他预先训练的词向量,比如 Word2vec,Glove等。词表一般是根据训练语料统计词频,选择词频大于某个阈值的所有词作为词表,此外,如果训练过程中出现了词表中没有见过的单词,通常用符号 "[UNK]" 来表示。编码器端输入的句子,往往长短不一,在批量训练过程中,需要对句子做预处理,将其处理为长度一样的句子,具体做法就是把所有句子都扩展为最长句子的长度,较短的句子通过在末端不断加入 "[PAD]" 标记来扩展长度,因此需要在词表中额外加入四个特殊标记符: "[S]", "[EOS]", "[UNK]", "[PAD]"。

损失函数

至此,我们已经得到了一个序列生成的过程,还需要设计一个合适的损失函数来训练模型,以得到合适的网络参数。显然,我们的损失函树的目标应当是减少模型输出与真实目标端序列的误差,比较常用的损失函数是交叉熵损失函数,首先定义在 t 时刻的损失函数如下:

<math xmlns="http://www.w3.org/1998/Math/MathML"> E t ( y t , y ^ t ) = − y t l o g y ^ t E_t(y_t,\hat{y}_t) = −y_tlog\hat{y}_t </math>Et(yt,y^t)=−ytlogy^t

其中,y_t 表示 t 时刻的标注答案, 是一个只有一个是 1,其他都是 0 的 one-hot 形式的向量;\hat{y}t 是我们预测出来的结果,与 y_t 的维度一样,但它是一个概率向量,表示每个词出现的概率。

每个时刻都需要作出预测,因此整个序列预测结束之后,损失函数为:

<math xmlns="http://www.w3.org/1998/Math/MathML"> E = ∑ t E t ( y t , y ^ t ) = ∑ t − y t l o g y ^ t E = \sum_tE_t(y_t, \hat{y}_t) = \sum_t−y_tlog\hat{y}_t </math>E=∑tEt(yt,y^t)=∑t−ytlogy^t

网络训练的目标就是最小化这个交叉熵函数,可以通过梯度下降 (gradient descent) 算法来更新网络参数。

总结

解码器结构是整个生成的基础,在训练过程中,每步解码都会使用目标端真实的单词作为输入,比如我们已经知道了第 一个要生成的单词是"how",因此我们就可以直接把"how"作为模型的输 入。但是在测试过程中,我们并不知道真实的输入应该是什么,需要设计一种策略解决该问题。

一个简单的策略就是贪心搜索 (greedy search),每个时刻都直接选择预测概 率最大的词作为下一个时刻的输入,贪心搜索算法效率较高, 但是往往无法得到最优解,在实际应用中,往往会采用柱搜索算法,与贪心搜索不同,柱搜索算法每次都保留 K 个概率最大的搜索路径。

下一篇文章,将介绍几种不同的解码方式

相关推荐
夜雨飘零13 分钟前
基于Pytorch实现的说话人日志(说话人分离)
人工智能·pytorch·python·声纹识别·说话人分离·说话人日志
菌菌的快乐生活16 分钟前
理解支持向量机
算法·机器学习·支持向量机
爱喝热水的呀哈喽19 分钟前
《机器学习》支持向量机
人工智能·决策树·机器学习
大山同学21 分钟前
第三章线性判别函数(二)
线性代数·算法·机器学习
minstbe23 分钟前
AI开发:使用支持向量机(SVM)进行文本情感分析训练 - Python
人工智能·python·支持向量机
月眠老师26 分钟前
AI在生活各处的利与弊
人工智能
axxy200040 分钟前
leetcode之hot100---240搜索二维矩阵II(C++)
数据结构·算法
四口鲸鱼爱吃盐41 分钟前
Pytorch | 从零构建MobileNet对CIFAR10进行分类
人工智能·pytorch·分类
苏言の狗42 分钟前
Pytorch中关于Tensor的操作
人工智能·pytorch·python·深度学习·机器学习
黑客Ash1 小时前
安全算法基础(一)
算法·安全