保留网络:大型语言模型的Transformer继任者

原文信息

原文题目:《Retentive Network: A Successor to Transformer for Large Language Models》

原文引用:Sun Y, Dong L, Huang S, et al. Retentive Network: A Successor to Transformer for Large Language Models[J]. arXiv preprint arXiv:2307.08621, 2023.

原文链接:https://arxiv.org/pdf/2307.08621.pdfhttps://arxiv.org/pdf/2307.08621.pdf

0.摘要

在这项工作中,我们提出了Retentive Network (RETNET)作为大型语言模型的基础架构,同时实现了训练并行性、低成本推理和良好的性能。我们从理论上推导了循环和注意力之间的联系。然后,我们提出了序列建模的保留机制,支持三种计算范式,即并行、递归和分块递归。具体而言,并行表示允许进行训练并行化。递归表示实现了低成本的O(1)推理,提高了解码吞吐量、延迟和GPU内存,而不会牺牲性能。分块递归表示以线性复杂度实现了高效的长序列建模,每个块都是并行编码的同时递归总结块。语言模型的实验结果表明,RETNET实现了有利的扩展结果、并行训练、低成本部署和高效推理。这些有趣的特性使RETNET成为大型语言模型中Transformer的强有力的继任者。代码将在https://aka.ms/retnet上提供。

图1:相较于Transformer,保留网络(RetNet)在GPU内存、吞吐量和延迟等方面实现了低成本推理、训练并行性和良好的缩放曲线。推理成本的结果以输入长度为8k进行报告。图6展示了不同序列长度的更多结果。

1.引言

Transformer[VSP+17]已成为大型语言模型[BMR+20]的事实上的架构,最初是为了解决循环模型[HS97]的顺序训练问题而提出的。然而,Transformer的训练并行性是以低效推理为代价的,因为每个步骤的复杂度为O(N),而且受限于内存的键值缓存[Sha19],这使得Transformer在部署方面不友好。不断增长的序列长度增加了GPU内存消耗,同时也增加了延迟,并降低了推理速度。

许多努力不断致力于开发下一代架构,旨在保持与Transformer相当的训练并行性和竞争性能,同时具有高效的O(1)推理。同时实现上述目标是具有挑战性的,即所谓的"不可能三角形",如图2所示。

已经有三个主要的研究方向。首先,线性化注意力[KVPF20]使用核函数ϕ(q) · ϕ(k)来近似标准的注意力分数exp(q · k),以便将自回归推理重写为循环形式。然而,模型的建模能力和性能较Transformer差,这阻碍了该方法的普及。第二个方向是回归模型,用于高效推理,但牺牲了训练并行性。为了弥补这一点,使用逐元素运算符[PAA+23]进行加速,但会损害表示能力和性能。第三个研究方向是探索用其他机制替代注意力,例如S4[GGR21]及其变体[DFS+22,PMN+23]。然而,以前的工作都无法突破"不可能三角形",与Transformer相比没有明确的胜出者。

在这项工作中,我们提出了保留网络(RetNet),实现了低成本推理、高效的长序列建模、与Transformer可比的性能以及并行模型训练。具体而言,我们引入了一种多尺度保留机制来替代多头注意力,它具有三种计算范式,即并行、循环和分块循环表示。首先,通过并行表示,我们能够充分利用GPU设备进行训练并行性。其次,通过循环表示,我们能够在内存和计算方面实现高效的O(1)推理,从而大大降低了部署成本和延迟。此外,实现过程中无需使用键值缓存技巧,大大简化了实现。第三,分块循环表示可以进行高效的长序列建模。我们同时对每个本地块进行并行编码以提高计算速度,同时对全局块进行循环编码以节省GPU内存。

我们进行了大量实验,将RetNet与Transformer及其变种进行了比较。在语言建模方面的实验结果表明,RetNet在缩放曲线和上下文学习方面始终具有竞争力。此外,RetNet的推理成本与序列长度无关。对于一个7B模型和8k序列长度,RetNet的解码速度比带有键值缓存的Transformer快8.4倍,并节省了70%的内存。在训练过程中,RetNet相比标准的Transformer节省了25-50%的内存,并且比高度优化的FlashAttention[DFE+22]加速了7倍。此外,RetNet的推理延迟对于批量大小不敏感,可以实现巨大的吞吐量。这些有趣的特性使得RetNet成为大型语言模型中强有力的Transformer继任者。

图2:RetNet使得"不可能三角形"成为可能,实现了训练并行性、良好的性能和低推理成本的同时存在。

2.保留网络

保留网络(RetNet)由L个相同的块堆叠而成,其布局与Transformer[VSP+17]中的布局相似(即残差连接和预LayerNorm)。每个RetNet块包含两个模块:多尺度保留(MSR)模块和前馈网络(FFN)模块。我们在以下章节介绍MSR模块。给定输入序列x = x1 · · · x|x|,RetNet以自回归的方式对序列进行编码。首先,将输入向量{xi} |x| i=1打包成X0 = [x1, · · · , x|x|] ∈ R |x|×dmodel,其中dmodel是隐藏维度。然后我们计算上下文化的向量表示Xl = RetNetl(Xl−1),其中l ∈ [1, L]。

2.1.保留机制

在这个部分,我们介绍了一种具有循环和并行性的保留机制。因此,我们可以以并行的方式训练模型,同时循环地进行推理。给定输入X ∈ R |x|×dmodel,我们将其投影到一维函数v(n) = Xn · wV。考虑一个将v(n)映射到o(n)的序列建模问题,其中通过状态sn进行映射。为了简化表示,我们用vn和on表示v(n)和o(n)。我们以循环方式来表达这种映射: 其中†表示共轭转置。该公式在训练实例内很容易进行并行化。总结起来,我们从如式(1)所示的循环建模开始,然后推导出如式(4)所示的并行公式。我们将原始映射v(n) 7→ o(n)视为向量,并得到保留机制如下。

其中Θ是Θ的复共轭,D ∈ R |x|×|x| 将因果屏蔽和相对距离的指数衰减组合为一个矩阵。与自注意力类似,这种并行表示使我们能够高效地使用GPU训练模型。

保留的循环表示 如图3b所示,提出的机制也可以表示为循环神经网络(RNN),这对于推理是有利的。对于第n个时间步,我们循环地获得输出为:

保留的块状循环表示 对于加速训练,特别是对于长序列,可以使用并行表示和循环表示的混合形式。我们将输入序列划分为块。在每个块内,我们按照并行表示(公式(5))进行计算。相反,跨块的信息按照循环表示(公式(6))进行传递。具体而言,设B表示块的长度。我们通过以下方式计算第i个块的保留输出:

图3:RetNet的双重形式。"GN"是GroupNorm的缩写。

2.2.门控多尺度保留

我们在每一层中使用h = dmodel/d保留头,其中d是头的维度。这些头使用不同的参数矩阵WQ,WK,WV ∈ R d×d 。此外,多尺度保留(MSR)为每个头分配不同的γ。为了简化起见,我们使不同层之间的γ相同,并保持固定。此外,我们添加了一个swish门[HG16,RZL17]来增加保留层的非线性。形式上,给定输入X,我们将该层定义为:

其中WG,WO ∈ R dmodel×dmodel是可学习的参数,GroupNorm [WH18]对每个头的输出进行归一化,遵循[SPP+19]中提出的SubLN。请注意,头部使用多个γ尺度,这导致了不同的方差统计。因此,我们单独对头部输出进行归一化。保留的伪代码总结如图4所示。

保留分数的归一化 我们利用GroupNorm的尺度不变性来提高保留层的数值精度。具体来说,在GroupNorm中乘以一个标量值不会影响输出和反向梯度,即GroupNorm(α ∗ headi) = GroupNorm(headi)。我们在公式(5)中实现了三个归一化因子。首先,我们将QK⊺归一化为QK⊺/√d。其次,我们用D˜nm = Dnm/√Pn i=1 Dni替换D。第三,设R表示保留分数R = QK⊺ ⊙ D,我们将其归一化为R˜nm = Rnm/max(|Pn i=1 Rni|,1)。然后,保留输出变为Retention(X) = RV˜。上述技巧不会影响最终结果,同时稳定正向和反向传递的数值流动,因为具有尺度不变性的属性。

图4:保留的三种计算范式的伪代码。

2.3.保留网络的整体架构

对于一个L层的保留网络,我们堆叠多尺度保留(MSR)和前馈网络(FFN)来构建模型。形式上,将输入序列{xi} |x| i=1通过一个词嵌入层转换为向量。我们使用打包的嵌入X0 = [x1, · · · , x|x| ] ∈ R |x|×dmodel作为输入,并计算模型输出XL:

其中LN(·)是LayerNorm [BKH16]。FFN部分的计算方式为FFN(X) = gelu(XW1)W2,其中W1、W2是参数矩阵。

训练:训练过程中我们使用并行(公式(5))和分块递归(公式(7))表示。在序列或块内部进行并行化能够有效利用GPU加速计算。更重要的是,分块递归在长序列训练中非常有用,既节省了计算量也节省了内存消耗。

推理:推理过程中采用了递归表示(公式(6)),这非常适合自回归解码。O(1)复杂度减少了内存和推理延迟,同时实现了相同的结果。

表1:从不同角度比较模型。RetNet实现了训练并行化、恒定的推理成本、线性的长序列内存复杂度和良好的性能。

2.4.与之前方法的关系与区别

表1从不同的角度比较了RetNet和先前的方法。比较结果反映了图2中呈现的"不可能的三角形"。此外,由于分块递归表示,RetNet对于长序列具有线性的内存复杂度。我们还总结了与具体方法的比较如下。

Transformer RetNet的并行表示与Transformer [VSP+17]具有相似的思路。最相关的Transformer变体是Lex Transformer [SDP+22],它使用xPos作为位置嵌入。如公式(3)所述,retention的推导与xPos相一致。与注意力机制相比,retention去除了softmax并启用了递归形式,从而显著提高了推理性能。

S4与公式(2)不同,如果Qn和Kn不具备内容感知性,公式可以退化为S4 [GGR21],其中O = (QK⊺, QAK⊺, .., QA|x|-1K⊺) * V。

线性注意力:线性注意力变体通常使用各种核函数ϕ(qi)ϕ(kj )/P|x|n=1 ϕ(qi)ϕ(kn)来替代softmax函数。然而,线性注意力在有效编码位置信息方面存在困难,使得模型的性能较低。此外,我们重新审视序列建模的过程,而不是旨在近似softmax。

AFT/RWKV Attention Free Transformer (AFT) 将点积注意力简化为逐元素操作,并将softmax移动到键向量上。RWKV使用指数衰减替换了AFT的位置嵌入,并在训练和推理过程中对模型进行递归运行。相比之下,retention保留了高维状态以编码序列信息,这有助于提高模型的表达能力和性能。

xPos/RoPE 与为Transformer提出的相对位置嵌入方法相比,公式(3)呈现了与xPos [SDP+22]和RoPE [SLP+21]类似的表达式。

Sub-LayerNorm 如公式(8)所示,retention层使用了Sub-LayerNorm [WMH+22]来对输出进行归一化。由于多尺度建模导致头部具有不同的方差,我们用GroupNorm替代了原始的LayerNorm。

3.实验

我们进行了语言建模的实验来评估RetNet。我们使用各种基准测试来评估所提出的架构,包括语言建模性能以及下游任务的零/少样本学习。此外,对于训练和推理,我们还比较了速度、内存消耗和延迟。

表2:语言建模实验中模型的大小和学习超参数。

图5:随着模型规模的扩大,困惑度下降。我们经验证实,当模型大小大于2B时,RetNet往往优于Transformer。

3.1.设置

**参数分配:**为了进行公平比较,我们重新分配了MSR和FFN中的参数。在这里,为了简单起见,我们用d表示dmodel。在Transformer中,自注意力机制中有大约4d^2个参数,其中WQ、WK、WV、WO∈R^d×d,FFN中有8d^2个参数,其中中间维度为4d。相比之下,RetNet中的retention层有8d^2个参数,其中WQ、WK∈R^d×d,WG、WV∈R^d×2d,WO∈R^2d×d。注意,V的头部维度是Q、K的两倍。扩展的维度通过WO投影回d。为了保持参数数量与Transformer相同,RetNet中FFN的中间维度为2d。同时,我们在实验中将头部维度设置为256,即查询和键的维度为256,值的维度为512。为了公平比较,我们在不同的模型大小中保持γ相同,其中γ=1−e^(linspace(log 1/32,log 1/512,h))∈R^h,而不是公式(8)中的默认值。

**语言模型训练:**如表2所示,我们从头开始训练各种规模的语言模型(即1.3B、2.7B和6.7B)。训练语料库是The Pile [GBB+20]、C4 [DMI+21]和The Stack [KLBA+22]的精选编译版本。我们添加了标记来表示序列的开始。训练的批量大小为4M个标记,最大长度为2048。我们使用100B个标记,即25k步来训练模型。我们使用AdamW [LH19]优化器,设置β1=0.9,β2=0.98,并将权重衰减设置为0.05。预热步数为375,使用线性学习率衰减。为了保证训练稳定性,参数初始化遵循DeepNet [WMD+22]的方法。实现基于TorchScale [MWH+22]。我们使用512个AMD MI200 GPU来训练模型。

3.2.与Transformer的对比

**语言建模:**如图5所示,我们针对基于Transformer和RetNet的语言模型在验证集上报告了困惑度。我们使用三种模型大小(即1.3B、2.7B和6.7B)展示了规模曲线。RetNet在性能上与Transformer相当。更重要的是,结果表明RetNet在大小扩展方面更具优势。除了性能,RetNet在我们的实验中训练非常稳定。实验结果表明,对于大型语言模型,RetNet是Transformer的一个强有力的竞争对手。经验证实,当模型大小大于2B时,RetNet开始优于Transformer。我们还在附录B中总结了具有不同上下文长度的语言建模结果。

**零样本和少样本在下游任务上的评估。**我们还在广泛的下游任务上比较了语言模型。我们使用6.7B模型进行了零样本和4样本的学习评估。如表3所示,数据集包括HellaSwag (HS) [ZHB+19]、BoolQ [CLC+19]、COPA [WPN+19]、PIQA [BZB+20]、Winograd、Winogrande [LDM12]和StoryCloze (SC) [MRL+17]。准确率与图5中的语言建模困惑度一致。在零样本和上下文学习设置中,RetNet在性能上与Transformer相当。

表3:基于Transformer和RetNet进行零样本学习和少样本学习。模型大小为6.7B。 表4:Transformer(Trm)、带有FlashAttention的Transformer(Trm+FlashAttn)和RetNet的训练成本。我们报告了内存消耗和训练吞吐量(每秒处理的单词数;wps)。

3.3.训练代价

如表4所示,我们比较了Transformer和RetNet的训练速度和内存消耗,其中训练序列长度为8192。我们还与FlashAttention [DFE+22]进行了比较,FlashAttention通过重新计算和内核融合来提高速度并减少GPU内存IO。相比之下,我们使用纯粹的PyTorch代码实现了RetNet,并将内核融合或类似FlashAttention的加速留待未来工作。我们使用基于块的递归保留表示,如方程(7)所述。块大小设置为512。我们使用八个Nvidia A100-80GB GPU进行评估,因为FlashAttention对A100进行了高度优化。对于6.7B和13B模型,启用了张量并行。

实验结果表明,在训练过程中,RetNet比Transformer更节约内存,并且具有更高的吞吐量。即使与FlashAttention相比,RetNet在速度和内存成本方面仍然具有竞争力。此外,由于不依赖于特定的内核,RetNet在其他平台上的高效训练变得更加容易。例如,我们在具有良好吞吐量的AMD MI200集群上训练RetNet模型。值得注意的是,RetNet通过高级实现(如内核融合)有进一步降低成本的潜力。

3.4.推断代价

如图6所示,我们比较了Transformer和RetNet在推断过程中的内存消耗、吞吐量和延迟。Transformer重复使用先前解码的令牌的KV缓存。RetNet使用了方程(6)中描述的递归表示。我们在实验中使用A100-80GB GPU对6.7B模型进行评估。图6显示,RetNet在推断成本方面优于Transformer。

内存消耗 如图6a所示,由于KV缓存,Transformer的内存消耗呈线性增长。相比之下,RetNet的内存消耗即使对于长序列也保持一致,需要更少的GPU内存来容纳RetNet。RetNet的额外内存消耗几乎可以忽略不计(约为3%),而模型权重占据了97%的内存。

吞吐量 如图6b所示,随着解码长度的增加,Transformer的吞吐量下降。相比之下,通过利用保留的递归表示,RetNet在解码过程中具有更高且与长度无关的吞吐量。

延迟 延迟是部署中的一个重要指标,它极大地影响用户体验。我们在图6c中报告了解码延迟。实验结果显示,增加批量大小会使Transformer的延迟增加。此外,随着输入长度的增加,Transformer的延迟增长更快。为了使延迟可接受,我们必须限制批量大小,这会损害Transformer的整体推断吞吐量。相比之下,RetNet的解码延迟优于Transformer,并且在不同的批量大小和输入长度下基本保持不变。

图6:使用6.7B模型的Transformer和RetNet的推断成本。RetNet在内存消耗、吞吐量和延迟方面优于Transformer。

(a) Transformer和RetNet的GPU内存消耗。

(b) Transformer和RetNet的吞吐量。

(c) 不同批量大小下的推断延迟。

3.5.与Transformer变体的对比

除了Transformer之外,我们还将RetNet与各种高效的Transformer变体进行了比较,包括线性变换器 [KVPF20]、RWKV [PAA+23]、H3 [DFS+22]和Hyena [PMN+23]。所有模型都有200M的参数,16层,隐藏维度为1024。对于H3,我们将头维度设置为8。对于RWKV,我们使用TimeMix模块替代自注意力层,同时保持FFN层与其他模型一致,以进行公平比较。我们使用0.5M标记的批量大小进行了10k步的训练。大多数超参数和训练语料库与第3.1节中保持一致。

表5报告了领域内验证集和其他领域外语料库(例如,Project Gutenberg 2019-2022 (PG22) [SDP+22]、QMSum [ZYY+21]、GovReport [HCP+21]、SummScreen [CCWG21, SSI+22])上的困惑度数据。总体而言,RetNet在不同的数据集上表现优于先前的方法。RetNet不仅在领域内语料库上取得了更好的评估结果,还在几个领域外数据集上获得了更低的困惑度。这种优越的性能使RetNet成为Transformer的一个强有力的继任者,除了显著降低成本的好处(第3.3和3.4节)。

此外,我们还讨论了比较方法的训练和推断效率。假设d表示隐藏维度,n表示序列长度。在训练过程中,RWKV的令牌混合复杂度为O(dn),而Hyena的复杂度为O(dn log n),使用快速傅里叶变换进行加速。上述两种方法通过使用逐元素操作来减少训练FLOPS,以换取建模能力。与之相比,保留模型的分块递归表示复杂度为O(dn(b + h)),其中b是分块大小,h是头维度,我们通常将b设置为512,h设置为256。对于大型模型(即较大的d)或序列长度,额外的b + h对性能影响可以忽略不计。因此,RetNet的训练效率非常高,同时不会牺牲建模性能。在推断过程中,与其他高效架构相比,Hyena与Transformer具有相同的复杂度(即每步为O(n)),而其他方法可以进行O(1)的解码。

表5:语言建模的困惑度结果。RetNet在领域内评估集和各种领域外语料库上均优于其他架构。

表6:领域内和领域外语料库上的消融实验结果。

3.6.消融研究

我们对RetNet的各种设计选择进行了消融实验,并在表6中报告了语言建模的结果。评估设置和指标与第3.5节相同。

架构我们消融了方程(8)中描述的swish gate和GroupNorm。表6显示上述两个组件提高了最终性能。首先,门控模块对于增强非线性和改进模型能力至关重要。需要注意的是,在去除门控后,我们使用了与Transformer相同的参数分配。其次,保留中的分组归一化平衡了多头输出的方差,提高了训练稳定性和语言建模结果。

多尺度衰减 方程(8)表明我们使用不同的γ作为保留头部的衰减率。在消融研究中,我们检查了去除γ衰减(即"- γ decay")和在所有头部上应用相同的衰减率(即"- multi-scale decay")。具体来说,消除γ衰减等效于γ = 1。在第二种设置中,我们将所有头部的γ设置为127/128。表6表明,衰减机制和使用多个衰减率都可以提高语言建模性能。

头部维度 从方程(1)的循环透视来看,头部维度意味着隐藏状态的记忆容量。在消融研究中,我们将默认头部维度从256降低到64,即查询和键的维度为64,值的维度为128。我们保持隐藏维度dmodel不变,因此头部的数量增加了。表6中的实验结果显示,较大的头部维度能够实现更好的性能。

4.总结

在本文中,我们提出了适用于序列建模的保留网络(RetNet),它能够实现多种表示,包括并行、循环和分块循环。与Transformer相比,RetNet在推理效率(记忆、速度和延迟),训练并行化和性能方面取得了显著的改进。上述优势使RetNet成为大型语言模型的理想替代品,特别是考虑到O(1)推理复杂度带来的部署优势。在未来,我们希望在模型大小[CDH+22]和训练步骤方面扩展RetNet。此外,通过压缩长期记忆,保留可以与结构化提示[HSD+22b]有效配合工作。我们还将使用RetNet作为骨干架构来训练多模态大型语言模型[HSD+22a, HDW+23, PWD+23]。此外,我们对在各种边缘设备上部署RetNet模型(如手机)很感兴趣。

A.超参数

B.不同上下文长度的分组结果

如表8所示,我们报告了不同上下文长度下的语言建模结果。为了使数据可比较,我们使用2048个文本块作为评估数据,并仅计算最后128个标记的困惑度。实验结果表明,RetNet在不同上下文长度下的性能优于Transformer。此外,RetNet可以利用更长的上下文来获得更好的结果。

表8:使用不同上下文长度的RetNet和Transformer的语言建模困惑度。结果显示,RetNet在序列长度方面具有一致的优势。

相关推荐
Watermelo6172 分钟前
详解js柯里化原理及用法,探究柯里化在Redux Selector 的场景模拟、构建复杂的数据流管道、优化深度嵌套函数中的精妙应用
开发语言·前端·javascript·算法·数据挖掘·数据分析·ecmascript
乐之者v7 分钟前
leetCode43.字符串相乘
java·数据结构·算法
QQ同步助手7 分钟前
如何正确使用人工智能:开启智慧学习与创新之旅
人工智能·学习·百度
AIGC大时代10 分钟前
如何使用ChatGPT辅助文献综述,以及如何进行优化?一篇说清楚
人工智能·深度学习·chatgpt·prompt·aigc
流浪的小新15 分钟前
【AI】人工智能、LLM学习资源汇总
人工智能·学习
A懿轩A1 小时前
C/C++ 数据结构与算法【数组】 数组详细解析【日常学习,考研必备】带图+详细代码
c语言·数据结构·c++·学习·考研·算法·数组
古希腊掌管学习的神1 小时前
[搜广推]王树森推荐系统——矩阵补充&最近邻查找
python·算法·机器学习·矩阵
云边有个稻草人1 小时前
【优选算法】—复写零(双指针算法)
笔记·算法·双指针算法
半盏茶香1 小时前
在21世纪的我用C语言探寻世界本质 ——编译和链接(编译环境和运行环境)
c语言·开发语言·c++·算法
martian6651 小时前
【人工智能数学基础篇】——深入详解多变量微积分:在机器学习模型中优化损失函数时应用
人工智能·机器学习·微积分·数学基础