DIALOGPT:大规模生成式预训练用于对话响应生成

摘要

我们提出了一个大规模、可调节的神经对话响应生成模型,DIALOGPT(对话生成预训练变换器)。该模型训练于从2005年至2017年间Reddit评论链中提取的1.47亿次类似对话的交流,DIALOGPT扩展了Hugging Face的PyTorch变换器,在单轮对话设置中,无论是自动评估还是人工评估,其表现都接近人类水平。我们展示了利用DIALOGPT的对话系统比强大的基线系统生成更相关、内容更丰富、上下文更一致的响应。预训练模型和训练流程已公开发布,以促进神经响应生成的研究和更智能的开放域对话系统的开发。

1 引言

我们介绍了DIALOGPT,一个可调节的千兆词级别神经网络模型,用于生成对话响应,训练数据来源于Reddit。最近,使用基于变换器架构的大规模预训练取得了巨大的经验成功(Radford等,2018;Devlin等,2019;Raffel等,2019)。例如,OpenAI的GPT-2(Radford等,2018)已经证明,在非常大的数据集上训练的变换器模型可以捕捉文本数据中的长期依赖关系,并生成流畅、词汇多样且内容丰富的文本。这类模型能够以细粒度捕捉文本数据,并产生高分辨率的输出,紧密模仿人类书写的真实世界文本。

DIALOGPT扩展了GPT-2,以应对对话神经响应生成的挑战。神经响应生成是文本生成的一个子类别,其目标是生成与提示相关的自然外观文本(与任何训练实例不同)。然而,建模对话提出了独特的挑战,因为人类对话,其中可能包含两个参与者的竞争目标,本质上在潜在响应的范围内更加多样化(Li等,2016a;Zhang等,2018;Gao等,2019a,b)。因此,它提出了比其他文本生成任务(如神经机器翻译、文本摘要和释义)更典型的一对多问题。人类对话通常也更非正式、嘈杂,并且在文本聊天形式中,经常包含非正式的缩写或句法/词汇错误。

大多数开放域神经响应生成系统存在内容或风格不一致(Li等,2016b;Zhang等,2019;Gao等,2019c)、缺乏长期上下文信息(Serban等,2017)和平淡无奇(Li等,2016a;Zhang等,2018;Qin等,2019)的问题。虽然这些问题可以通过专门设计来提高信息内容的建模策略来缓解,但像GPT-2(Radford等,2018)这样的基于变换器的架构,它使用多层自注意力机制以计算高效的方式允许完全连接的跨注意力到完整上下文,似乎是探索更通用解决方案的自然选择。例如,变换器模型允许长期依赖信息更好地跨时间保存(Radford等,2018),从而提高了内容一致性。由于它们的深层结构(在GPT-2中最多48层),它们也具有更高的模型容量,并且比基于RNN的方法更有效地利用大规模数据集(超过1亿个训练实例)(Vaswani等,2017)。

与GPT-2类似,DIALOGPT被构建为一个自回归(AR)语言模型,并使用多层变换器作为模型架构。然而,与GPT-2不同的是,DIALOGPT是在从Reddit讨论链中提取的大规模对话对/会话上进行训练的。我们的假设是,这应该使DIALOGPT能够以更细的粒度捕捉对话流中的P(目标;源)联合分布。实际上,这正是我们观察到的:由DIALOGPT生成的句子多样化且包含特定于源提示的信息,类似于GPT-2为连续文本生成的内容。我们已经在公共基准数据集(DSTC-7)和一个新的6k多参考测试数据集(从Reddit帖子中提取)上评估了预训练模型。DIALOGPT在自动和人工评估中都达到了最先进的结果,将性能提升到接近人类响应的质量。

我们已经发布了源代码和预训练模型,以促进未来的研究。我们的模型可以轻松地利用和适应新的对话数据集,尤其是训练示例较少的数据集。DIALOGPT包还包含一个基于Huggingface PyTorch变换器(HuggingFace,2019)的开源训练管道(数据提取/准备和模型训练/评估)。

2 数据集

数据集是从2005年至2017年间从Reddit抓取的评论链中提取的。Reddit讨论可以自然地扩展为树结构的回复链,因为回复一个线程的线程形成了后续线程的根节点。我们将从根节点到叶节点的每条路径提取为包含多轮对话的训练实例。

我们通过以下方式过滤数据:移除(1)源或目标中包含URL的实例,(2)目标中包含至少三个单词重复的实例,(3)响应中不包含至少一个前50个最常用英语单词(例如,"the","of","a")的实例,因为这可能表明它可能不是一个英语句子,(4)响应中包含特殊标记如"["或"]"的实例,因为这可能是标记语言,(5)源和目标序列总长度超过200个单词的实例,(6)目标中包含通过短语匹配大型屏蔽列表识别的冒犯性语言的实例。我们还排除了大量被识别为可能包含冒犯性内容的子reddit。此外,我们积极过滤掉平淡无奇的内容,例如,移除响应中包含90%的三元组(这些三元组已被看到超过1000次)的实例。通常这些响应信息量不大,约占数据的1%。

过滤后,数据集包含147,116,725个对话实例,总计18亿个单词。

3 方法

3.1 模型架构

我们在GPT-2(Radford等,2018)架构的基础上训练了DIALOGPT模型。GPT-2变换器模型采用通用变换器语言模型(Vaswani等,2017),并利用一堆掩码多头自注意力层来训练大规模网络文本数据。从头开始或基于用户特定提示生成的文本看起来都很真实。GPT-2的成功表明,变换器语言模型能够在细粒度水平上刻画人类语言数据分布,这可能是由于大模型容量和卓越的效率。

我们的模型继承了GPT-2(Radford等,2018)的特性,这是一个具有层归一化的12至48层变换器,我们修改了考虑模型深度的初始化方案,并为分词器使用了字节对编码(Sennrich等,2016)。我们遵循OpenAI GPT-2的方法,将多轮对话会话建模为长文本,并将生成任务框架化为语言建模。我们首先将对话会话中的所有对话轮次连接成一个长文本 x 1 , ⋯   , x N x_1,\cdots,x_N x1,⋯,xN( N N N是序列长度),以文本结束标记结尾。我们将源句子(对话历史)表示为 S = x 1 , ⋯   , x m S=x_{1},\cdots,x_{m} S=x1,⋯,xm,目标句子(真实响应)表示为 T = x m + 1 , ⋯   , x N T=x_{m+1},\cdots,x_{N} T=xm+1,⋯,xN,条件概率 P ( T ∣ S ) P(T|S) P(T∣S)可以写为一系列条件概率的乘积:
p ( T ∣ S ) = ∏ n = m + 1 N p ( x n ∣ x 1 , ⋯   , x n − 1 ) (1) p(T|S)=\prod_{n=m+1}^Np(x_n|x_1,\cdots,x_{n-1})\tag{1} p(T∣S)=n=m+1∏Np(xn∣x1,⋯,xn−1)(1)

对于一个多轮对话会话 T 1 , ⋯   , T K T_1,\cdots,T_K T1,⋯,TK,(1)可以写成 p ( T K , ⋯   , T 2 ∣ T 1 ) p(T_{K},\cdots,T_{2}|T_{1}) p(TK,⋯,T2∣T1),这本质上是条件概率 p ( T i ∣ T i − 1 , ⋯   , T 1 ) p(T_i|T_{i-1},\cdots,T_1) p(Ti∣Ti−1,⋯,T1)的乘积。因此,优化单一目标 p ( T K , ⋯   , T 2 ∣ T 1 ) p(T_K,\cdots,T_2|T_1) p(TK,⋯,T2∣T1)可以被视为优化所有 p ( T i ∣ T i − 1 , ⋯   , T 1 ) p(T_i|T_{i-1},\cdots,T_1) p(Ti∣Ti−1,⋯,T1)源-目标对。我们的实现基于开源的PyTorch-transformer仓库。

3.2 互信息最大化

开放域文本生成模型因生成平淡、信息量不足的样本而闻名。为了解决这个问题,我们实现了一种最大互信息(MMI)评分函数(Li等,2016a;Zhang等,2018)。MMI使用预训练的反向模型从给定响应中预测源句子,即 P ( Source ∣ Target ) P(\text{Source}|\text{Target}) P(Source∣Target)。我们首先使用Top-K采样生成一组假设,然后利用 P ( Source ∣ Hypothesis ) P(\text{Source}|\text{Hypothesis}) P(Source∣Hypothesis)的概率对所有假设进行重新排序。直观上,最大化反向模型似然会惩罚平淡的假设,因为频繁和重复的假设可能与许多可能的查询相关联,从而为任何特定查询产生较低的概率。

我们还尝试使用策略梯度(Williams,1992)和样本平均基线优化奖励 R = P ( Source ∣ Hypothesis ) R = P(\text{Source}|\text{Hypothesis}) R=P(Source∣Hypothesis),遵循Zhang等(2018)的方法。验证奖励可以稳定提高,但与RNN架构下的训练不同,我们观察到强化学习(RL)训练容易收敛到一个退化的局部最优解,即假设简单地重复源句子(即鹦鹉学舌模型),此时互信息被最大化。我们推测,由于变换器强大的模型表示能力,它们可能会陷入局部最优。我们将正则化RL训练的研究留给未来的工作。

4 结果

4.1 实验细节

我们训练了三种不同规模的模型,总参数分别为117M、345M和762M。模型规格遵循Radford等(2018)(表1)。

我们的模型使用了一个包含50,257个条目的词汇表,并在16台配备NVLink的Nvidia V100机器上进行了训练。我们使用了Noam学习率调度器,预热步数为16000。学习率根据验证损失选择。每个模型训练到验证损失不再有进展为止。对于小型和中型模型,我们训练了最多5个周期;对于大型模型,我们最多训练了3个周期。

加速训练

为了加速训练过程并适应GPU内存限制,我们首先将所有训练数据压缩到一个惰性加载的数据库文件中,以便仅在需要时加载数据(预取大块数据以减少访问频率)。我们还利用单独的异步数据进程来扩展训练。结果,训练时间随GPU数量的增加近似线性下降。我们进一步采用了动态批处理策略,将长度相似的对话分组到同一批次中,从而提高了训练吞吐量。

4.2 DSTC-7 对话生成挑战

DSTC(对话系统技术挑战)第7赛道(Galley等,2019)是一个端到端的对话建模任务,其目标是通过注入基于外部知识的信息生成超越闲聊的对话响应。该任务与通常认为的目标导向、任务导向或任务完成对话不同,因为它没有特定或预定义的目标(例如预订航班或在餐厅订桌)。相反,它针对的是类似人类的互动,其中潜在目标通常是模糊的或事先未知的,类似于工作和其他生产环境(例如头脑风暴会议)中人们共享信息的情况。

DSTC-7测试数据包含来自Reddit的对话线程。为了创建一个多参考测试集,我们利用了包含6个或更多响应的对话会话。结合其他过滤标准(如轮次长度),最终生成了一个包含2208个实例的5参考测试集。(对于每个实例,6个人类响应中的一个被留出以评估人类在此任务上的表现。)需要注意的是,我们的训练数据与测试集的时间跨度不同。

我们使用标准的机器翻译指标进行自动评估,包括BLEU(Papineni等,2002)、METEOR(Lavie和Agarwal,2007)和NIST(Doddington,2002)。NIST是BLEU的变体,通过信息增益对n-gram匹配进行加权,即间接惩罚信息量不足的n-gram。我们还使用熵(Zhang等,2018)和Dist-n(Li等,2016a)来评估词汇多样性。更多细节见Galley等(2019)。

我们将DIALOGPT与基于(Li等,2016a)的内部竞争性序列到序列模型PERSONALITYCHAT进行了比较,该模型在Twitter数据上训练,并已作为微软Azure的认知服务投入生产。表2总结了自动评估结果。具有345M参数并使用束搜索的DIALOGPT在大多数指标上取得了最高的自动评分。345M参数的DIALOGPT在所有指标上均优于117M参数。束搜索(束宽为10)显著提高了BLEU和DIST分数,并略微提高了NIST和METEOR分数。需要注意的是,我们的模型在源-目标对上进行了微调,并未利用DSTC训练集中的基础信息。推测模型在预训练期间学习了背景信息,因此不受缺乏基础文档的限制。

DIALOGPT的自动评分高于人类评分。这并不意味着生成的内容比人类更"真实",而可能是由于对话的一对多性质。如图1所示,多个人类响应(R1-R4)可以很好地对应一个源话语。在不失一般性的情况下,假设R1-R3是用于测试的"真实参考",而R4是用于计算"人类"评分的"保留"人类响应。在语义空间中,训练良好的模型生成的响应Rg可能会倾向于位于所有可能响应的几何中心附近,因为训练目标是生成最可能的响应。这可能接近所有训练实例的几何平均值,从而"平均化"这些实例。因此,生成的响应Rg与R1-R3的"语义距离"(表现为更高的自动评分,如BLEU)可能比目标人类响应R4更小。

4.3 新的Reddit多参考数据集

我们进一步在包含6K个实例的多参考测试集上评估了DIALOGPT。结果如表3所示。我们在两种设置下测试了我们的方法:从头开始训练和使用GPT-2作为预训练模型进行微调。在这两种设置中,较大的模型始终优于较小的模型。比较从头开始训练与基于预训练GPT-2模型的微调,当应用于较小模型时,使用GPT-2模型带来了更大的性能提升。同样,最佳系统DIALOGPT(345M,带束搜索)在BLEU上的得分高于人类。从头开始训练的较大模型(345M和762M)与基于GPT-2微调的模型表现相当。

4.4 使用MMI重新排序响应

我们按照第3.2节所述执行互信息最大化。具体来说,我们使用基于GPT-2中型模型微调的345M模型,通过Top-K采样(K=10)为每个输入源句子生成16个样本。然后使用反向模型进行重新排序,该反向模型也是基于GPT-2中型模型微调的345M模型。选择反向模型损失最低的响应进行评估。结果总结在表3的倒数第二行中。可以看出,与贪婪生成相比,MMI重新排序生成的响应更加多样化,具有更高的NIST、METEOR、熵和Dist分数,但BLEU略有下降。

4.5 生成示例

我们在表4(交互式聊天)和表5(带有用户提示的自播放机器人)中提供了生成的对话示例。输出基于Top-K采样。有趣的是,我们的模型在一定程度上展示了回答常识性问题的能力,这可能是由于从Reddit数据中可以学习到丰富的信息。在某些情况下,系统并未给出"期望"的答案,而是生成了一个替代的合理答案。我们的观察表明,系统能够比RNN模型更好地处理多轮生成,并且在上下文一致性方面表现更佳(表5)。

4.6 人工评估

我们从Reddit 6K测试数据集中随机抽取了2000个测试样本,并通过众包进行了人工评估。系统被配对,每对系统的输出随机呈现给3位评审员,评审员根据相关性、信息量以及生成内容的人类相似性,使用3点Likert量表对其进行评分。评审员需要通过资格测试,并且实施了垃圾检测机制。总体评审员对相关性、信息量和人类相似性的偏好以原始数字和占总数的百分比形式呈现在表7中。可以明显看出,DialoGPT相较于PersonalityChat更受青睐。表7还表明,"普通"的DialoGPT中型模型可能已经接近人类响应的质量。出乎意料的是,我们发现评审员可能更喜欢MMI变体而非人类响应,这可能是因为许多真实的人类响应是随意或独特的,或者与评审员不熟悉的网络梗相关。(有关此效应的背景条件,请参见第4.2节。)更多细节,包括显著性测试和使用的人工评估模板,请参见附录。

5 相关工作

目前已有多个开源工具包用于大规模预训练变换器模型。Huggingface的Conv-AI迁移学习仓库(Wolf等,2019)包含基于GPT-2变换器语言模型的对话AI系统迁移学习代码,该模型在ConvAI-2对话竞赛中取得了最先进的性能。DLGnet(Olabiyi和Mueller,2019)是一个在对话数据集上训练的大型变换器模型,在多轮对话生成中表现良好。AllenNLP(Gardner等,2018)是一个用于多种自然语言处理任务的工具包,包括大规模预训练的双向LSTM句子表示学习框架ELMo(Peters等,2018)。Texar(Hu等,2018)专注于文本生成,包括风格迁移和可控生成,并结合了强化学习能力和序列建模工具。DeepPavlov(Burtsev等,2018)是一个专注于任务导向对话的流行框架,包含多个问答和情感分类的演示和预训练模型。Icecaps(Shiv等,2019)是一个响应生成工具包,支持基于个性或外部知识的对话生成以及多任务训练。ConvAI2挑战赛(Dinan等,2019)专注于个性化对话。ParlAI(Miller等,2017)是另一个用于开发任务导向对话系统的库,包含基于众包数据训练的知识驱动聊天机器人预训练模型。Text-to-Text Transformer(Raffel等,2019)统一了多种文本建模任务,并在各种自然语言生成和理解基准测试中取得了最先进的成果。

6 局限性与风险

DIALOGPT仅作为模型发布,解码器的实现责任由用户承担。尽管我们在训练前努力减少明显冒犯性数据,DIALOGPT仍有可能生成可能引发冒犯的输出。输出可能反映数据中隐含的性别和其他历史偏见。使用该模型生成的响应可能倾向于表达对不道德、偏见或冒犯性命题的同意(或反之,对道德命题的反对)。这些是当前基于大规模自然数据集训练的最先进端到端对话模型的已知问题。发布DIALOGPT的主要动机是使研究人员能够研究这些问题并开发缓解策略。无论如何,使用DIALOGPT生成的不当内容不应被视为作者或微软公司的观点或价值观的反映。

7 结论

我们发布了一个基于大规模真实世界Reddit数据集训练的开放域预训练模型DIALOGPT。该工具包包含分布式训练管道和多个预训练模型,可以在几小时内对中等规模的自定义数据集进行微调以获得对话模型。DIALOGPT完全开源且易于部署,允许用户扩展预训练对话系统以使用各种数据集进行引导训练。它作为构建新颖应用和方法的基础模块。检测和控制有害输出将是未来研究的重点。我们将研究利用强化学习进一步提高生成响应的相关性,并防止模型生成严重不当的响应。

A 人工评估的附加细节

均值差异的显著性测试通过10,000次自举迭代进行。P值在α = 0.05的水平下计算。结果如表8所示。345M(2)和762M(6)模型之间的差异不显著。值得注意的是,345M模型(2)与人类响应(1)之间的差异在统计上也不显著。人工评估的模板如图2所示。

相关推荐
码码哈哈0.01 小时前
「拼好帧」小黄鸭 Lossless Scaling 软件介绍与下载
人工智能
LittleNyima2 小时前
【代码解读】阿里最新开源视频生成模型 Wan 2.1 实现解析
人工智能·stable diffusion·aigc·音视频
张焚雪4 小时前
关于卷积神经网络的一份介绍
人工智能·深度学习·神经网络·cnn
楼台的春风4 小时前
【图像的读写与基本操作】
图像处理·人工智能·深度学习·opencv·算法·计算机视觉·嵌入式
唔皇万睡万万睡4 小时前
基于帧差分法的车辆检测系统(Matlab)
人工智能·计算机视觉·目标跟踪
程序员古德4 小时前
《论边缘计算及其应用》审题技巧 - 系统架构设计师
人工智能·边缘计算·边云协同·软件项目·设计实现·论述框架
寻道码路5 小时前
深度剖析 Video-RAG:厦门大学和罗切斯特大学联合推出的一种用于长视频理解的检索增强生成技术
人工智能·语言模型·开源·aigc·音视频·ai编程
一只蜗牛儿5 小时前
Sherpa-ONNX:说话人识别与语音识别自动开启(VAD)+ Python API 完整指南
人工智能·python·语音识别
结衣结衣.5 小时前
【OpenCV】入门教学
图像处理·人工智能·python·opencv