文本生成:解码器

最早期的生成任务,是大家所熟知的 "编码器-解码器结构", 简单来说,就是将输入的序列信息编码,由解码器对编码后的信息作出解释。这种结构已经在机器翻译、文本摘要、复述生 成等多种生成任务上得到应用,并取得了不错的效果。其本质就是将源语言序列映射到目标语言的词序列,可以被看做是一个以源端句子为条件,来建模目标端句子的语言模型。

解码器结构

如图所示,输入的是已经分好词的源端句子"what is zakat used for"。编码器接收到输入之后,首先会根据词表将词映射到对应的词嵌入表示 (word embedding),于是就得到了网络的输入 {x1, ..., xl},输入表示经过一个循环神经网络,得到句子表示 h,公式如下:

<math xmlns="http://www.w3.org/1998/Math/MathML"> h t = f ( x t , h t − 1 ) , h_t = f (x_t,h_{t−1}), </math>ht=f(xt,ht−1), t = 1,2,...,l and h0 = 0

编码器得到的表示,可以作为解码器的初始化状态,也就是第一个时刻的一个输入。解码器还需要输出层来做预测,预测下一个单词的生成概率,并作出决策。由于解码器每一个时刻都需要接受前一时刻的隐层和当前输入的词表示来计算当前时刻的隐层表示,并用输出层得到下一个要生成的单词的概率,因此,需要对输入的句子做一些标记,来告诉解码器什么时候开始解码,生成第一个单词,什么时候停止解码,也就是说需要标记生成的开始和结束。一般,在输入句子的开始和结束位置加上两个标记,"[S]"表示句子的开始,"[EOS]"表示句子结束。 如下公式展示了在 t 时刻解码器生成单词的计算过程:

<math xmlns="http://www.w3.org/1998/Math/MathML"> p ( y ) = s o f t m a x ( V h t + b ) p(y)=softmax(Vht+b) </math>p(y)=softmax(Vht+b)

其中 V 是权重矩阵,b 是偏置向量,softmax 是归一化指数函数,将网络输出值归一化成目标端输出词的概率。softmax函数的计算公式如下:

<math xmlns="http://www.w3.org/1998/Math/MathML"> s o f t m a x ( o k ) = e x p ( o k ) / ∑ k ′ = 1 K e x p ( o k ′ ) softmax(o^k)= exp(o^k)/\sum_{k'=1}^{K} exp(o^{k′}) </math>softmax(ok)=exp(ok)/∑k′=1Kexp(ok′)

解码第一个单词的时候,接收的输入是 "[S]",解码器解码出 "[EOS]" 的时候,结束整个解码过程。

词表特殊处理

值得注意的是,词向量可以由网络在训练过程中学习得到,也可以使用其他预先训练的词向量,比如 Word2vec,Glove等。词表一般是根据训练语料统计词频,选择词频大于某个阈值的所有词作为词表,此外,如果训练过程中出现了词表中没有见过的单词,通常用符号 "[UNK]" 来表示。编码器端输入的句子,往往长短不一,在批量训练过程中,需要对句子做预处理,将其处理为长度一样的句子,具体做法就是把所有句子都扩展为最长句子的长度,较短的句子通过在末端不断加入 "[PAD]" 标记来扩展长度,因此需要在词表中额外加入四个特殊标记符: "[S]", "[EOS]", "[UNK]", "[PAD]"。

损失函数

至此,我们已经得到了一个序列生成的过程,还需要设计一个合适的损失函数来训练模型,以得到合适的网络参数。显然,我们的损失函树的目标应当是减少模型输出与真实目标端序列的误差,比较常用的损失函数是交叉熵损失函数,首先定义在 t 时刻的损失函数如下:

<math xmlns="http://www.w3.org/1998/Math/MathML"> E t ( y t , y ^ t ) = − y t l o g y ^ t E_t(y_t,\hat{y}_t) = −y_tlog\hat{y}_t </math>Et(yt,y^t)=−ytlogy^t

其中,y_t 表示 t 时刻的标注答案, 是一个只有一个是 1,其他都是 0 的 one-hot 形式的向量;\hat{y}t 是我们预测出来的结果,与 y_t 的维度一样,但它是一个概率向量,表示每个词出现的概率。

每个时刻都需要作出预测,因此整个序列预测结束之后,损失函数为:

<math xmlns="http://www.w3.org/1998/Math/MathML"> E = ∑ t E t ( y t , y ^ t ) = ∑ t − y t l o g y ^ t E = \sum_tE_t(y_t, \hat{y}_t) = \sum_t−y_tlog\hat{y}_t </math>E=∑tEt(yt,y^t)=∑t−ytlogy^t

网络训练的目标就是最小化这个交叉熵函数,可以通过梯度下降 (gradient descent) 算法来更新网络参数。

总结

解码器结构是整个生成的基础,在训练过程中,每步解码都会使用目标端真实的单词作为输入,比如我们已经知道了第 一个要生成的单词是"how",因此我们就可以直接把"how"作为模型的输 入。但是在测试过程中,我们并不知道真实的输入应该是什么,需要设计一种策略解决该问题。

一个简单的策略就是贪心搜索 (greedy search),每个时刻都直接选择预测概 率最大的词作为下一个时刻的输入,贪心搜索算法效率较高, 但是往往无法得到最优解,在实际应用中,往往会采用柱搜索算法,与贪心搜索不同,柱搜索算法每次都保留 K 个概率最大的搜索路径。

下一篇文章,将介绍几种不同的解码方式

相关推荐
NAGNIP11 小时前
一文搞懂深度学习中的通用逼近定理!
人工智能·算法·面试
冬奇Lab12 小时前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab12 小时前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP16 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年16 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼16 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS16 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区18 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈18 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang18 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx