论文阅读笔记-Get To The Point: Summarization with Pointer-Generator Networks

前言

最近看2021ACL的文章,碰到了Copying Mechanism和Coverage mechanism两种技巧,甚是感兴趣的翻阅了原文进行阅读,Copying Mechanism的模型CopyNet已经进行阅读并写了阅读笔记,如下:
论文阅读笔记:Copying Mechanism缓解未登录词问题的模型--CopyNet

而本篇文章则是讲Coverage mechanism,当然这篇并不是Coverage mechanism最初的技巧原文(最早出现在这一篇:Statistical machine translation),本篇只是将这个技巧进行改进使其更加适用于RNN-base的Seq2Seq模型。Copying Mechanism和Coverage mechanism两个技巧的提出都比较早,但是其应用得当,在特定任务上给模型带来效果提升会令人意想不到。

本论文主要围绕解决Seq2Seq模型应用于摘要生成时主要存在两个问题:

  • 难以准确复述原文的事实细节、无法处理原文中的未登录词(OOV)
  • 生成的摘要中存在重复的片段

对于OOV的问题,一种很自然的想法就是将source doc也纳入输出词的考虑范围,即可以直接从source doc中复制相关相应的token作为输出,这一点在CopyNet中应用的效果很不错。而对于重复词的问题,需要通过一种手段,利用之前所生成的token来影响当前time step的决策(可以认为是已出现的token概率进行惩罚),从而避免产生重复词,不过论文作者为了避免影响模型的效果,对不同的模型任务进行了改进,比如额外加上了coverage loss来将token位置也给考虑进去。

模型细节

encoder部分采用一个单层双向LSTM,输入原文的词向量序列,输出一个编码后的隐层状态序列 h i h_i hi。decoder部分采用一个单层单向LSTM,每一步的输入是前一步预测的词的词向量,同时输出一个解码的状态序列 s t s_t st,用于当前步的预测。attention具体的计算公式为:
e i t = v T t a n h ( W h h i + W s s t + b a t t n ) e_i^t=v^Ttanh(W_hh_i+W_ss_t+b_{attn}) eit=vTtanh(Whhi+Wsst+battn)
a t = s o f t m a x ( e t ) a_t=softmax(e_t) at=softmax(et)

其中 h i , s t h_i,s_t hi,st分别是source doc进行双向LSTM编码的hidden state和cell state, W , b W,b W,b则是参数。在计算出当前步的attention分布后,对encoder输出的隐层做加权平均,获得输入序列的动态表示,即context-vector:
h t ∗ = ∑ i a i t h i h_t^*=\sum_ia_i^th_i ht∗=i∑aithi

在不使用Copy Mechanism的情况下,我们的Seq2Seq是依靠decoder输出的隐层和context-vector,共同决定当前time step预测在词表上的概率分布:
P v o c a b = s o f t m a x ( V ′ ( V [ s t , h t ∗ ] + b ) + b ′ ) P_{vocab}=softmax(V^{'}(V[s_t,h_t^*]+b)+b^{'}) Pvocab=softmax(V′(V[st,ht∗]+b)+b′)

Copying Mechanism

而论文则是在预测的每一个time step,通过动态计算一个生成概率 p g e n p_{gen} pgen,巧妙的把seq2seq模型和pointer network结合起来,使得即保留了seq2seq模型保持抽象生成的能力,也保留了pointer network直接从原文中取词的Copy能力:
p g e n = σ ( w h ∗ T h t ∗ + w s T s t + w x T + b p t r ) p_{gen}=\sigma(w_{h^*}^Th_t^*+w_s^Ts_t+w_x^T+b_{ptr}) pgen=σ(wh∗Tht∗+wsTst+wxT+bptr)
P ( W ) = p g e n P v o c a b ( w ) + ( 1 − p g e n ) ∑ i : w i a i t P(W)=p_{gen}P_{vocab}(w)+(1-p_{gen})\sum_{i:w_i}a_i^t P(W)=pgenPvocab(w)+(1−pgen)i:wi∑ait

其中, σ \sigma σ 是sigmoid激活函数,这样就直接把seq2seq模型计算的attention分布作为pointer network的输出,源代码实现上通过参数复用,大大降低了模型的复杂度,如下:

复制代码
with tf.variable_scope('calculate_pgen'):
p_gen = linear([context_vector, state.c, state.h, x], 1, True) # Tensor shape (batch_size, 1)
p_gen = tf.sigmoid(p_gen)
p_gens.append(p_gen)
Coverage mechanism

除此之外,针对重复词问题,论文使用Coverage mechanism,Coverage模型的重点在于预测过程中,维护一个coverage vector:
c t = ∑ t ′ = 0 t − 1 a t ′ c^t=\sum_{t^{'}=0}^{t-1}a^{t^{'}} ct=t′=0∑t−1at′

这个向量是过去所有预测步计算的attention分布的累加和,记录着模型已经关注过source doc的哪些token,并且让这个coverage vector影响当前time step的attention计算:
e i t = v T t a n h ( W h h i + W s s t + w c c i t + b a t t n ) e_i^t=v^Ttanh(W_hh_i+W_ss_t+w_cc_i^t+b_{attn}) eit=vTtanh(Whhi+Wsst+wccit+battn)

这样做的目的在于,在模型进行当前time step进行attention计算的时候,告诉它之前它已经关注过的token,希望避免出现连续attention到某几个token上的情况。同时,考虑到重复token的位置的影响,coverage模型还添加一个额外的coverage loss,来对重复的attention作惩罚:
c o v l o s s t = ∑ i m i n ( a i t , c i t ) covloss_t=\sum_imin(a_i^t,c_i^t) covlosst=i∑min(ait,cit)

这样这个loss只会对重复的attention产生惩罚,并不会强制要求模型关注原文中的每一个词。加上词表预测的损失函数采用交叉熵:
l o s s = − 1 T ∑ t = 0 T l o g P ( w t ∗ ) loss=-\frac{1}{T}\sum_{t=0}^TlogP(w_t^*) loss=−T1t=0∑TlogP(wt∗)

最终,模型的整体损失函数为:
l o s s t = − l o g P ( w t ∗ ) + λ ∑ i m i n ( a i t , c i t ) loss_t=-logP(w_t^*)+\lambda\sum_imin(a_i^t,c_i^t) losst=−logP(wt∗)+λi∑min(ait,cit)

文章在实验部分提到,如果移除了covloss,单纯依靠coverage vector去影响attention的计算并不能缓解重复token的问题,模型还是会重复地attention到某些token上。而加上covloss的模型训练上也比较trick,需要先用主函数训练好一个收敛的模型,然后再把covloss加上,做个finetune,不然的话效果还是不好。

实验结果

论文用的数据集是CNN/DailyMail数据集,可以看到论文的模型在该任务上有着明显的提升。

下面是三种模型对同一篇原文生成的摘要,橘色的是最终coverage vector在原文上的分布,红色的是事实细节和OOV问题,绿色的是生成摘要时 p g e n p_{gen} pgen 的大小。

总结


本文模型改善了抽象文本摘要中存在的主要问题,但与具象摘要结果相比仍然存在差距,同时考虑到新闻文章重要信息普遍集中分布于前部分的特性,抽象摘要模型受到了一定影响,模型的普适性需要进一步地验证。

相关推荐
玄同7651 小时前
从 0 到 1:用 Python 开发 MCP 工具,让 AI 智能体拥有 “超能力”
开发语言·人工智能·python·agent·ai编程·mcp·trae
小瑞瑞acd1 小时前
【小瑞瑞精讲】卷积神经网络(CNN):从入门到精通,计算机如何“看”懂世界?
人工智能·python·深度学习·神经网络·机器学习
火车叼位2 小时前
也许你不需要创建.venv, 此规范使python脚本自备依赖
python
火车叼位2 小时前
脚本伪装:让 Python 与 Node.js 像原生 Shell 命令一样运行
运维·javascript·python
孤狼warrior2 小时前
YOLO目标检测 一千字解析yolo最初的摸样 模型下载,数据集构建及模型训练代码
人工智能·python·深度学习·算法·yolo·目标检测·目标跟踪
Katecat996632 小时前
YOLO11分割算法实现甲状腺超声病灶自动检测与定位_DWR方法应用
python
玩大数据的龙威2 小时前
农经权二轮延包—各种地块示意图
python·arcgis
ZH15455891312 小时前
Flutter for OpenHarmony Python学习助手实战:数据库操作与管理的实现
python·学习·flutter
belldeep3 小时前
python:用 Flask 3 , mistune 2 和 mermaid.min.js 10.9 来实现 Markdown 中 mermaid 图表的渲染
javascript·python·flask
喵手3 小时前
Python爬虫实战:电商价格监控系统 - 从定时任务到历史趋势分析的完整实战(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·电商价格监控系统·从定时任务到历史趋势分析·采集结果sqlite存储