结合RNN与Transformer双重优点,深度解析大语言模型RWKV

本文分享自华为云社区《【云驻共创】昇思MindSpore技术公开课 RWKV 模型架构深度解析》,作者:Freedom123。

一、前言

Transformer模型作为一种革命性的神经网络架构,于2017年由Vaswani等人 提出,并在诸多任务中取得了显著的成功。Transformer的核心思想是自注意力机制,通过全局建模和并行计算,极大地提高了模型对长距离依赖关系的建模能力。但是Transformer在处理长序列时面临内存和计算复杂度的问题,因为其复杂度与序列长度呈二次关系一直未业内人员所诟病。今天我们学习的RWKV,它作为对Transformers模型的替代,正在引起越来越多的开发人员的关注。RWKV模型以简单、高效、可解释性强等特点,成为自然语言处理领域的新宠。下面让我们一起来学习RWKV模型。

二、RWKV简介

RWKV(Receptance Weighted Key Value)是一个结合了RNN与Transformer双重优点的模型架构,由香港大学物理系毕业的彭博首次提出。其名称源于其 Time-mix 和 Channel-mix 层中使用的四个主要模型元素:R(Receptance):用于接收以往信息;W(Weight):是位置权重衰减向量,是可训练的模型参数; K(Key):是类似于传统注意力中 K 的向量; V(Value):是类似于传统注意力中 V 的向量。

RWKV模型作为一种革新性的大型语言模型,结合了RNN的线性复杂度和Transformer的并行处理优势,引入了Token shift和Channel Mix机制来优化位置编码和多头注意力机制,解决了传统Transformer模型在处理长序列时的计算复杂度问题。RWKV在多语言处理、小说写作、长期记忆保持等方面表现出色,可以主要应用于自然语言处理任务,例如文本分类、命名实体识别、情感分析等。

三、 RWKV模型的演进

RWKV模型之所以发展到今天的结构经历了五个阶段,从RNN结构到LSTM结构,到GRU结构,到GNMT结构,到Transformers结构,最后到RMKV结构,下面我们一一来学习每种模型结构,务必做到对模型结构都有一个清晰的认识。

1.RNN结构

RNN(Recurrent Neural Network)是循环神经网络的缩写,是一种深度学习模型,特别适用于处理序列数据。RNN具有记忆功能,可以在处理序列数据时保留之前的信息,并将其应用于当前的计算中。RNN的特点在于其具有循环连接的结构,使得信息可以在网络中传递并被持续更新。RNN由一个个时间步组成,每个时间步的输入不仅包括当前时刻的输入数据,还包括上一个时间步的隐藏状态,这样就可以在处理序列数据时考虑到上下文信息,这种结构使得RNN能够处理不定长的序列数据,如自然语言文本、时间序列数据等。RNN结构如下图所示:

左边是RNN网络,右边是RNN网络按时序展开的形式,为什么要按照时序展开?主要是RNN中 隐状态更新需要依赖上一次的隐状态信息,就是我们理解的记忆信息。RNN 的基本结构包括一个隐藏层,其中的神经元通过时间步骤连接,允许信息从一个时间步骤传递到下一个时间步骤。RNN 在每个时间步骤上接收一个输入并输出一个隐藏状态。这个隐藏状态包含了网络在当前时间步骤所看到的序列的信息。这个隐藏状态可以被用作下一个时间步骤的输入。对于一个时间步骤 t,RNN 的隐藏状态的计算如下:

尽管 RNN 具有处理序列数据的能力,但它们在处理长序列时会面临梯度消失或梯度爆炸的问题。这是因为通过时间反向传播时,梯度可能会迅速缩小或增大,导致模型难以学习长期依赖关系。为了解决梯度消失问题,出现了一些改进的 RNN 变体,如长短时记忆网络(LSTM)和门控循环单元(Gated Recurrent Unit,GRU)。这些模型通过引入门控机制,允许网络选择性地记住和遗忘信息,从而更有效地处理长序列。

2.LSTM结构

LSTM全称Long Short Term Memory networks,是普通RNN的变体,可以有效解决长期依赖的问题。LSTM的核心是元胞(cell)状态,输入的信息从上方的水平线经过元胞,期间只与其他实线进行少量交互,表示一些线性变换,这使得输入的信息能够较完整的保存下来,也就是说可以保留长期记忆。而LSTM对信息进行选择性的保留,是通过门控机制进行实现的。门结构可以控制通过元胞信息的多少。它实际上是对输入信息进行线性变换后,再通过一个sigmoid层来实现的,最终将输入最终转为一个系数向量,值的范围在0~1,可以理解为保留的信息的占比。如果值为0,则表示将对应的信息全部丢弃,如果为1则表示将对应的信息全部保留。LSTM共有三种门结构,分别是遗忘门、输入门、输出门,LSTM结构如下图所示:

遗忘门用来控制在元胞(cell) 状态里哪些信息需要进行遗忘,以使在流动的过程中进行适当的更新。它接收作为输入参数,通过sigmoid层得到对应的遗忘门的参数。具体公式如下:

接下来就需要更新细胞状态了。首先LSTM需要生成一个用来更新的候选值,记为,通过tanh层来实现。然后还需要一个输入门参数来决定更新的信息,同样通过sigmoid层实现。最后将相乘得到更新的信息,同时将上面得到的遗忘门和旧元胞状态相乘,以忘掉其中的一些信息,二者相结合,便得到更新后的状态,具体公式如下:

LSTM需要计算最后的输出信息,该输出信息主要由元胞状态决定,但是需要经过输出门进行过滤处理。首先要将元胞状态的值规范化到[-1,1],这通过tanh层来实现。然后依然由一个sigmoid层得到输出门参数,最后将和规范化后的元胞状态进行点乘,得到最终过滤后的结果,具体公式如下:

3.GRU结构

GRU (Gated Recurrent Unit)是一种用于循环神经网络(RNN)的门控机制,旨在解决长期依赖问题并缓解梯度消失或爆炸现象。GRU的结构比LSTM (Long Short-Term Memory)更简单,它包含两个门:更新门(update gate)和重置门(reset gate)。更新门负责控制前一时刻的状态信息对当前时刻状态的影响,其值越大,表明引入的前一时刻状态信息越多。重置门则控制忽略前一时刻状态信息的程度,其值越小,表明忽略得越多。GRU结构如下图所示:

它只有两个门,对应输出更新门(update gate)向量:和重置门(reset gate)向量:,更新门负责控制上一时刻状态信息对当前时刻状态的影响,更新门的值越大说明上一时刻的状态信息带入越多。而重置门负责控制忽略前一时刻的状态信息的程度,重置门的值越小说明忽略的越多。接下来,"重置"之后的重置门向量 与前一时刻状态卷积,再将输入进行拼接,再通过激活函数tanh来将数据放缩到-1~1的范围内。这里包含了输入数据,并且将上一时刻状态的卷积结果添加到当前的隐藏状态,通过此方法来记忆当前时刻的状态。

最后一个步骤是更新记忆阶段,此阶段同时进遗忘和记忆两个步骤,使用同一个门控同时进行遗忘和选择记忆(LSTM是多个门控制) 。

4.GNMT结构

NMT是神经网络翻译系统,通常会含用两个RNN,一个用来接受输入文本,另一个用来产生目标语句,但是这样的神经网络系统有三个弱点:1.训练速度很慢并且需要巨大的计算资源,由于数量众多的参数,其翻译速度也远低于传统的基于短语的翻译系统(PBMT);2.对罕见词的处理很无力,而直接复制原词在很多情况下肯定不是一个好的解决方法;3.在处理长句子的时候会有漏翻的现象。而GNMT中,RNN使用的是8层(实际上Encoder是9层,输入层是双向LSTM。)含有残差连接的神经网络,残差连接可以帮助某些信息,比如梯度、位置信息等的传递。同时,attention层与decoder的底层以及encoder的顶层相连接,如下图所示:

GNMT encoder将输入语句变成一系列的向量,每个向量代表原语句的一个词,decoder会使用这些向量以及其自身已经生成的词,生成下一个词。encoder和decoder通过attention network连接,这使得decoder可以在产生目标词时关注原语句的不同部分。上面提到,多层堆叠的LSTM网络通常会比层数少的网络有更好的性能,然而,简单的错层堆叠会造成训练的缓慢以及容易受到剃度爆炸或梯度消失的影响,在实验中,简单堆叠在4层工作良好,6层简单堆叠性能还好的网络很少见,8层的就更罕见了,为了解决这个问题,在模型中引入了残差连接,如图:

一句话的译文所需要的关键词可能在出现在原文的任何位置,而且原文中的信息可能是从右往左的,也可能分散并且分离在原文的不同位置,因为为了获得原文更多更全面的信息,双向RNN可能是个很好的选择,在本文的模型结构中,只在Encoder的第一层使用了双向RNN,其余的层仍然是单向RNN,粉色的LSTM从左往右的处理句子,绿色的LSTM从右往左,二者的输出先是连接,然后再传给下一层的LSTM,如下图Bi-directions RNN示意图:

5.Transformers结构

Transformer模型是一种基于自注意力机制的神经网络模型,旨在处理序列数据,特别是在自然语言处理领域得到了广泛应用。Transformer模型的核心是自注意力机制(Self-Attention Mechanism),它允许模型关注序列中每个元素之间的关系。这种机制通过计算注意力权重来为序列中的每个位置分配权重,然后将加权的位置向量作为输出。模型结构上,Transformer由一个编码器堆栈和一个解码器堆栈组成,它们都由多个编码器和解码器组成。编码器主要由多头自注意力(Multi-Head Self-Attention)和前馈神经网络组成,而解码器在此基础上加入了编码器-解码器注意力模块。Transformer结构如下所示:

基于Transformer 结构的编码器和解码器结构上图所示,左侧和右侧分别对应着编码器(Encoder)和解码器(Decoder)结构。它们均由若干个基本的Transformer 块(Block)组成(对应着图中的灰色框)。这里N× 表示进行了N 次堆叠。每个Transformer 块都接收一个向量序列。主要涉及到如下几个模块:

1)嵌入表示层:对于输入文本序列,首先通过输入嵌入层(Input Embedding)将每个单词转换为其相对应的向量表示。通常直接对每个单词创建一个向量表示。由于Transfomer 模型不再使用基于循环的方式建模文本输入,序列中不再有任何信息能够提示模型单词之间的相对位置关系。在送入编码器端建模其上下文语义之前,一个非常重要的操作是在词嵌入中加入位置编码(Positional Encoding)这一特征。具体来说,序列中每一个单词所在的位置都对应一个向量。这一向量会与单词表示对应相加并送入到后续模块中做进一步处理。在训练的过程当中,模型会自动地学习到如何利用这部分位置信息。

2)注意力层:自注意力(Self-Attention)操作是基于Transformer 的机器翻译模型的基本操作,在源语言的编码和目标语言的生成中频繁地被使用以建模源语言、目标语言任意两个单词之间的依赖关系。给定由单词语义嵌入及其位置编码叠加得到的输入表示{xi ∈ Rd}ti=1,为了实现对上下文语义依赖的建模,进一步引入在自注意力机制中涉及到的三个元素:查询qi(Query),键ki(Key),值vi(Value)。在编码输入序列中每一个单词的表示的过程中,这三个元素用于计算上下文单词所对应的权重得分。直观地说,这些权重反映了在编码当前单词的表示时,对于上下文不同部分所需要的关注程度。

3)前馈层:前馈层接受自注意力子层的输出作为输入,并通过一个带有Relu 激活函数的两层全连接网络对输入进行更加复杂的非线性变换。实验证明,这一非线性变换会对模型最终的性能产生十分重要的影响。

其中W1, b1,W2, b2 表示前馈子层的参数。实验结果表明,增大前馈子层隐状态的维度有利于提升最终翻译结果的质量,因此,前馈子层隐状态的维度一般比自注意力子层要大。

4) 残差连接与层归一化:由Transformer 结构组成的网络结构通常都是非常庞大。编码器和解码器均由很多层基本的Transformer 块组成,每一层当中都包含复杂的非线性映射,这就导致模型的训练比较困难。因此,研究者们在Transformer 块中进一步引入了残差连接与层归一化技术以进一步提升训练的稳定性。具体来说,残差连接主要是指使用一条直连通道直接将对应子层的输入连接到输出上去,从而避免由于网络过深在优化过程中潜在的梯度消失问题。

Transformer 模型由于其处理局部和长程依赖关系的能力以及可并行化训练的特点而成为一个强大的替代方案,如 GPT-3、ChatGPT、GPT-4、LLaMA 和 Chinchilla 等都展示了这种架构的能力,推动了自然语言处理领域的前沿。尽管取得了这些重大进展,Transformer 中固有的自注意力机制带来了独特的挑战,主要是由于其二次复杂度造成的。这种复杂性使得该架构在涉及长输入序列或资源受限情况下计算成本高昂且占用内存。这也促使了大量研究的发布,旨在改善 Transformer 的扩展性,但往往以牺牲一些特性为代价。正是在此背景之下,一个由 27 所大学、研究机构组成的开源研究团队,联合发表论文《 RWKV: Reinventing RNNs for the Transformer Era 》,文中介绍了一种新型模型:RWKV(Receptance Weighted Key Value),这是一种新颖的架构,有效地结合了 RNN 和 Transformer 的优点,同时规避了两者的缺点。RWKV能够缓解 Transformer 所带来的内存瓶颈和二次方扩展问题,实现更有效的线性扩展,同时保留了使 Transformer 在这个领域占主导的一些性质。

四、 RWKV模型

RWKV是一个结合了RNN与Transformer双重优点的模型架构,是一个RNN架构的模型,但是可以像transformer一样高效训练。RWKV 模型通过 Time-mix 和 Channel-mix 层的组合,以及 distance encoding 的使用,实现了更高效的 Transformer 结构,并且增强了模型的表达能力和泛化能力。Time-mix 层与 AFT(Attention Free Transformer)层相似,采用了一种注意力归一化的方法,以消除传统 Transformer 模型中存在的计算浪费问题。Channel-mix 层则与 GeLU(Gated Linear Unit)层相似,使用了一个 gating mechanism 来控制每条通道的输入和输出。另外,RWKV 模型采用了类似于 AliBi 编码的位置编码方式,将每个位置的信息添加到模型的输入中,以增强模型的时序信息处理能力。这种位置编码方式称为 distance encoding,它考虑了不同位置之间的距离衰减特性,RWKV结构如下图所示:

这里我们以下图的自回归例子学习RWKV的推理过程,用 (x,y ) 表示样本数据和样本标签,图中有3对数据: (my,name) , (my name ,is ) 和 (my name is , Bob) . 另外,在语言模型中,标记偏移(token shift)是一种常见的技术,用于训练模型以预测给定上下文中下一个标记(单词、字符或子词单元)的任务。下图中的标记偏移技术是向右移动一个位置,生成的三个token-shift 为:"0 my","my name","name is"。为什么要进行标记偏移呢? 这是因为这样做具有递归嵌套的思想,比如:"name"向量与"my"向量有关,而"is"向量与"name"向量有关,所以"is"向量自然与"name"向量有关。好处是:给融入循环神经网络思想带来了便利的同时还保持了并行性。具体流程下面的Time-Mix模块和Channel-Mix模块会详细介绍。如下图所示,这两个模块是RWKV架构的主要模块。Time-Mix模块可以看成根据隐状态(State)生成候选预测向量,Channel-Mix模块则可以看成生成最终的预测向量。

1.Time Mixing模块

对于t时刻,给定单词和前一个单词,Time-Mix模块公式如下:

其中,即AFT注意力机制中的,且,不同的是,偏置值以控制以及距离越远权重越低,这相当于可学习的位置编码;u也是位置编码,表示t时刻认为过去哪些时刻比较重要;表示token-shift,由得出。因此,对于上述公式,给定的t时刻,唯一不能确定的就是, 因为需要得到前t-1个Key向量和Value向量。此时的Time-Mix模块看起来还是规范的注意力机制,但可以写成递归循环的形式。Time-Mix模块之间存在着从左往右的Statas传递,即引入了隐状态。其融合循环神经网络思想的Time-Mix模块的数据流动如下图所示,

的计算可以写成递归循环的形式:

其中,隐状态虽然依然需要t时刻之前的计算,但过程已经简化,可以将隐状态的计算独立出来,也降低了内存要求,并且转化为张量积计算也大大提升了并行性。论文作者也为WKV设计了特定的CUDA内核,还提出了实践技巧,下次讨论。

2.Channel Mixing模块

Channel-Mix模块可以看成是为了生成最终的预测向量。输入Time-Mix模块的向量和向量还没有上文信息,经由Time-Mix模块得出的向量和 向量融合了他们各自时刻之前的上文信息,我猜测Channel-Mix模块如它命名所示,将不同时刻的信息进一步融合得到最终的预测变量,其公式为:

因为向量和向量具有上文的信息,所以不用值Value向量,最后应用一个类似遗忘门的操作,丢弃不必要的历史信息。

3.RWKV的优势

1)高效训练和推理:RWKV 模型既可以像传统 Transformer 模型一样高效训练,也具有类似于 RNN 的推理能力。这使得 RWKV 模型可以支持串行模式和高效推理,也可以支持并行模式(并行推理训练)和长程记忆。

2)支持高效训练:RWKV 模型使用了 Time-mix 和 Channel-mix 层,以消除传统 Transformer 模型中存在的计算浪费问题。这使得 RWKV 模型在训练过程中具有更高的效率和更快的速度。

3)支持大规模自然语言处理任务:RWKV 模型可以处理大规模的自然语言处理任务,如文本分类、命名实体识别、情感分析等。

4)可扩展性强:RWKV 模型具有良好的可扩展性,可以方便地进行模型扩展和改进,以适应不同任务的需求。

4.RWKV模型参数

目前官方已经就RWKV开源了多个模型。主要是Raven系列模型,Raven是基于RWKV-4架构在Pile数据集上训练和微调的大模型,做过指令微调或者chat微调版本。此外,也包括了非Raven版本的RWKV-4的模型。

五、 RWKV模型代码阅读

1.RWKV模型推理代码

代码解释:

1~2行:引入代码需要的库

4行:对输出进行校验

6~9行:加载RWKV/rwkv-4-169m-pile模型,并且输入提示词

11~12行:运行模型,解码生成内容

13行:期望输出与真实输出内容进行校验

2.Channel Mixing模块代码:

x通道混合层接受与此标记对应的输入,以及x与前一个标记对应的输入,我们称之为last_x。last_x存储在这个 RWKV 层的state. 其余输入是学习RWKV 的 parameters。首先,我们使用学习的权重对x和进行线性插值last_x。我们将此插值x作为输入运行到具有平方 relu 激活的 2 层前馈网络,最后与另一个前馈网络的 sigmoid 激活相乘(在经典 RNN 术语中,这称为门控)。请注意,就内存使用而言,矩阵Wk,Wr,Wv包含几乎所有参数(1024×1024 matrices它们是矩阵,而其他变量只是 1024 维向量)。矩阵乘法(@在 python 中)贡献了绝大多数所需的计算。

3.Time mixing模块代码:

时间混合的开始类似于通道混合,通过将此标记的插入x到最后一个标记的x。然后我们应用学到的矩阵以获得"key", "value" and "receptance"向量。

六、与其他模型的比较

1.复杂度对比

从和Transformer,Reformer,Performer,Linear Transformers,AFT-full,AFT-local,MEGA等模型的复杂度比较中可以看的出来,RWKV模型的时间复杂度和空间负责度都是最低的,费别为O(Td)和O(d),其中T 表示序列长度,d 表示特征维度,c 表示 MEGA 的二次注意力块大小。

2.精度对比

RWKV 似乎可以像 SOTA transformer一样缩放。至少多达140亿个参数。在同等规模参数中,RWKV-4系列与Pythia和GPT-J比都是很有优势的,对比如下图所示:

3.推理速度和内存占用

RWKV网络与不同类型的Transformer性能的实验结果对比如下图所示。RWKV时间消耗随序列长度是线性增加,且时间消耗远小于各种类型的Transformer。

RWKV与Transformer预训练模型(BLOOM、OPT、Pythia)效果对比测试如下图所示。在六个基准测试中(Winogrande、PIQA、ARC-C、ARC-E、LAMBADA 和 SciQ),RWKV 与开源二次复杂度 transformer 模型 Pythia、OPT 和 BLOOM 具有相当的竞争力。RWKV 甚至在四个任务(PIQA、OBQA、ARC-E 和 COPA)中胜过了 Pythia 和 GPT-Neo。

下图显示,增加上下文长度会导致 Pile 上的测试损失降低,这表明 RWKV 能够有效利用较长的上下文信息。

七、小结

本节我们学习了RWKV模型,我们掌握了RWKV模型结构的整个演进过程,从最初的RNN结构,到LSTM结构,到GRU结构,到GNTM模型,到Transformers模型,最后到RWKV模型,我们学习了每种模型结构出现的原因,以及其对应的优势和不足。接下来,我们学习了RWKV模型,Time Mixing模块和Channel Mixing模块。我们通过学习RWKV模型的python代码,对RWKV模型从复杂度,精度,推理速度,内存占用等四个维度和其他模型进行了对比。

通过本节学习,我们对RWKV模型有了一个全面的认识,RWKV模型正在作为一颗在大模型领域的新星正在受到越来越多社区开发者的关注,希望RWKV模型在接下来的版本迭代过程中能给大家带来更多的惊喜。

点击关注,第一时间了解华为云新鲜技术

相关推荐
威化饼的一隅8 分钟前
【多模态】swift-3框架使用
人工智能·深度学习·大模型·swift·多模态
机器学习之心32 分钟前
BiTCN-BiGRU基于双向时间卷积网络结合双向门控循环单元的数据多特征分类预测(多输入单输出)
深度学习·分类·gru
MorleyOlsen2 小时前
【Trick】解决服务器cuda报错——RuntimeError: cuDNN error: CUDNN_STATUS_NOT_INITIALIZED
运维·服务器·深度学习
伯牙碎琴2 小时前
智能体实战(需求分析助手)二、需求分析助手第一版实现(支持需求提取、整理、痛点分析、需求分类、优先级分析、需求文档生成等功能)
ai·大模型·agent·需求分析·智能体
愚者大大4 小时前
1. 深度学习介绍
人工智能·深度学习
liuming19924 小时前
Halcon中histo_2dim(Operator)算子原理及应用详解
图像处理·人工智能·深度学习·算法·机器学习·计算机视觉·视觉检测
聆思科技AI芯片4 小时前
实操给桌面机器人加上超拟人音色
人工智能·机器人·大模型·aigc·多模态·智能音箱·语音交互
长风清留扬5 小时前
机器学习中的密度聚类算法:深入解析与应用
人工智能·深度学习·机器学习·支持向量机·回归·聚类
程序员非鱼5 小时前
深度学习任务简介:分类、回归和生成
人工智能·深度学习·分类·回归·生成
γ..5 小时前
基于MATLAB的图像增强
开发语言·深度学习·神经网络·学习·机器学习·matlab·音视频