自称超越Transformer的新一代大模型RWKV是什么

论文地址:arxiv.org/pdf/2305.13048v2

项目地址:github

论文题目为:《RWKV: Reinventing RNNs for the Transformer Era》

自 Vaswani 等人于 2017 年首次提出 Attention Is All You Need 之后,基于 transformer 的强大的模型一直在不断地涌现,它们在 NLP 相关任务上的表现远远超过基于 RNN (Recurrent Neural Networks, 递归神经网络) 的 SoTA 模型,甚至多数认为 RNN 已死。而本文将介绍一个集 RNN 和 transformer 两者的优势于一身的全新网络架构 --RWKV!现已在 HuggingFace transformers 库中支持。

1 背景与动机

1.1 背景:简单介绍RNN和Transformer

RNN(Recurrent Neural Network,循环神经网络)

RNN(Recurrent Neural Network,循环神经网络)是一种专门用于处理序列数据的神经网络。它能够处理前后数据点之间的依赖关系,这使得RNN特别适合于处理时间序列数据或者任何形式的序列,如文本、语音或视频。RNN的核心特点是它具有循环连接,这允许网络在处理序列的每个元素时保持一定的"记忆"。

RNN的工作原理是,它在序列的每个时间步骤接收输入,并产生输出,同时将一部分信息传递到下一个时间步骤。这种传递的信息通常被称为"隐藏状态"(hidden state),它能够捕捉序列中过去的信息。RNN的这种结构使得它能够处理可变长度的序列,并且能够处理长期依赖关系。

Token

Token在自然语言处理(NLP)中是一个重要的概念,它指的是文本中的一个基本单位。在不同的上下文中,token可以有不同的含义:

  1. 单词或词汇单元:在一些NLP任务中,token可以是一个单词或者一个字符,这是文本分析的最小单位。
  2. 子词单元:在一些现代的NLP模型中,为了更好地处理词汇的变体和形态变化,会将单词进一步分割成更小的单元,这些单元被称为subword tokens。
  3. 标记化:在文本处理中,将原始文本转换成token序列的过程称为标记化(tokenization)。这是文本预处理的重要步骤,它使得机器学习模型能够理解和处理文本数据。

在RNN中,序列数据通常会先被标记化,转换成一系列的token,然后这些token被用作RNN的输入。RNN通过处理这些token序列,学习序列中的模式和依赖关系,进而用于各种NLP任务,如语言建模、机器翻译、情感分析等。

Transformer

Transformer 是一种深度学习模型,由 Vaswani 等人在 2017 年的论文《Attention Is All You Need》中首次提出。它主要用于处理序列数据,尤其在自然语言处理(NLP)领域取得了革命性的进展。Transformer 模型完全基于注意力机制(Attention Mechanism),摒弃了之前序列模型中常用的循环神经网络结构。

主要特点

  1. 自注意力机制(Self-Attention):Transformer 通过自注意力机制使模型能够在序列中的每个位置都同时考虑其他位置,这有助于捕捉序列内部的长距离依赖关系。自注意力机制的核心是计算序列中每个元素对其他所有元素的注意力分数,然后根据这些分数对元素进行加权求和。

  2. 并行化处理:由于自注意力机制不依赖于序列中元素之间的循环或递归调用,Transformer 可以高效地并行处理整个序列,这在传统的循环神经网络中是难以实现的。

  3. 编码器-解码器架构:标准的 Transformer 模型由编码器和解码器两个部分组成。编码器处理输入序列,解码器生成输出序列。两部分都由多个相同的层组成,每层都包含自注意力模块和前馈神经网络。

  4. 多头注意力(Multi-Head Attention):Transformer 通过多头注意力机制进一步提升模型的表达能力。它将查询(Query)、键(Key)和值(Value)通过不同的线性投影分割成多个头,然后并行计算每个头的注意力输出,最后将这些输出合并,提供更丰富的信息表示。

  5. 位置编码:由于 Transformer 本身无法捕捉序列中元素的顺序信息,因此需要加入位置编码来提供序列中每个元素的位置信息。位置编码通常是与输入嵌入相加的固定或可学习的向量。

  6. 层正规化(Layer Normalization)残差连接(Residual Connections):Transformer 在每个子层(自注意力层和前馈网络层)的输出上应用层正规化,并使用残差连接,有助于避免深层网络中的梯度消失问题,使得深层网络的训练成为可能。

1.2 动机

部分内容来自拥抱脸的介绍:RWKV -- transformer 与 RNN 的强强联合 (huggingface.co)

由于 RNN 在计算每一时刻的预测值时使用的都是同一组网络权重,因此 RNN 很难解决长距离序列信息的记忆问题,这一定程度上也是训练过程中梯度消失导致的。为解决这个问题,相继有新的网络架构被提出,如 LSTM 或者 GRU,其中 transformer 是已被证实最有效的架构。

在 transformer 架构中,不同时刻的输入 token 可以在 self-attention 模块中并行处理。首先 token 经过 Q、K、V 权重矩阵做线性变换投影到不同的空间,得到的 Q、K 矩阵用于计算注意力分数 (通过 softmax,如下图所示),然后乘以 V 的隐状态得到最终的隐状态,这种架构设计可以有效缓解长距离序列问题,同时具有比 RNN 更快的训练和推理速度。

但是:

  • Transformer的局限性:尽管Transformer在自然语言处理(NLP)任务中取得了革命性的进展,但其自注意力机制的计算复杂度随着序列长度呈二次方增长,这在处理长序列时会导致显著的内存和计算负担。
  • RNN的局限性:RNN虽然在内存和计算需求上呈线性增长,但因为难以并行处理和可扩展性差,通常无法达到与Transformer相同的性能。

因此提出RWKV,RWKV 的灵感来自于 Apple 公司的 Attention Free Transformer。RWKV 该架构经过精心简化和优化,可以转换为 RNN。除此此外,为使 RWKV 性能媲美 GPT,还额外使用了许多技巧,例如 TokenShiftSmallInitEmb (使用的完整技巧列表在 官方 GitHub 仓库的 README 中 说明)。对于 RWKV 的训练,现有的项目仓库可以将参数量扩展到 14B,并且迭代修了 RWKV-4 的一些训练问题,例如数值不稳定性等。

2 RWKV架构

2.1 线性注意力机制

RWKV 模型架构与经典的 transformer 模型架构非常相似 (例如也包含 embedding 层、Layer Normalization、用于预测下一 token 的因果语言模型头、以及多个完全相同的网络层等),唯一的区别在于注意力层,它与传统的 transformer 模型架构完全不同,因此 RWKV 的注意力计算公式也不一样。

线性注意力机制是Transformer模型中自注意力机制的一个变体,旨在减少计算复杂度,特别是针对序列长度的二次方增长问题。在标准的Transformer模型中,自注意力的计算复杂度是O(T^2d),其中T是序列长度,d是特征维度。这种计算复杂度在处理长序列时会迅速变得不可行。线性注意力机制通过将复杂度降低到O(Td),使得模型能够更高效地处理长序列。

基本原理

线性注意力机制的核心思想是将传统的点积注意力(dot-product attention)替换为一种更高效的计算方式,同时保持对序列中各个元素间关系的捕捉能力。在点积注意力中,每个元素对其他所有元素的注意力是通过计算它们的点积并应用softmax函数来实现的。而在线性注意力中,这种计算被替换为一种更直接的方法。

计算过程

  1. 键向量和查询向量的变换:首先,对于序列中的每个元素,我们将其表示为查询(Q)、键(K)和值(V)向量。这些向量可以通过输入序列的线性变换得到。

  2. 注意力分数的计算:在传统的自注意力中,注意力分数是通过计算查询和所有键的点积得到的。在线性注意力中,我们使用一种线性化的方法来近似这种点积。一种常见的方法是使用一个可学习的权重向量w来与键向量进行点积,然后将结果与查询向量进行点积,以此来模拟传统的点积注意力:

    在线性注意力中,这个计算可以被近似为:

    其中W是一个可学习的权重矩阵。

  3. 值向量的加权求和:计算完注意力分数后,我们使用这些分数对值向量进行加权求和,得到最终的输出。

优点

  • 计算效率:线性注意力机制显著降低了计算复杂度,从O(T^2d)降低到O(Td),使得模型能够更高效地处理长序列。
  • 内存效率:由于计算复杂度的降低,线性注意力也减少了内存的使用,这对于大规模的模型和长序列尤为重要。

缺点

  • 精度损失:由于线性化近似,线性注意力可能会损失一些精度,尤其是在捕捉序列中复杂依赖关系时。
  • 灵活性限制:相比于传统的自注意力机制,线性注意力在模拟不同元素间复杂交互的能力上可能有所限制。

2.2 模型公式化

RWKV可以被公式化为Transformer或RNN,这使得它在训练时能够并行化计算,并在推理时保持线性复杂度。

模型公式化的组成

  1. Receptance (R): 接收向量,用于捕捉和存储过去的信息。
  2. Weight (W): 位置权重衰减向量,一个可训练的参数,用于模拟时间衰减。
  3. Key (K): 键向量,在传统的注意力机制中用于与查询向量计算关系分数。
  4. Value (V): 值向量,在注意力机制中用于与计算得到的权重相乘,生成输出。

公式化过程

RWKV模型的核心是其独特的注意力机制,即WKV操作符,它将传统的点积注意力替换为一种线性注意力形式。以下是RWKV模型中一些关键的公式化步骤:

  1. 时间混合 (Time Mixing):

    • 时间混合层通过RWKV操作符结合了时间维度上的混合,允许模型在序列的不同时间步之间传递信息。
    • 公式化可以表示为:
    • 其中,w 是时间衰减因子,u 是当前时间步的加权因子。
  2. 通道混合 (Channel Mixing):

    • 通道混合层则处理特征维度上的混合,允许模型在不同特征通道之间共享信息。
    • 公式化可以表示为:
    • 其中,Rt′和 Kt′分别是通道混合层的接收向量和键向量。
  3. 输出门控 (Output Gating):

    • 输出门控通过sigmoid函数控制信息流,增强模型对信息的选择性传递。
    • 公式化可以表示为:
    • 其中,σσ 表示sigmoid函数,WoWo 是输出权重。
  4. 序列计算 (Sequential Computation):

    • RWKV模型在序列的每个时间步上递归地计算上述操作,从而实现序列的动态处理。
    • 公式化可以表示为:
    • 其中,HtHt 是第t步的隐藏状态,LayerNorm是层归一化操作。
  5. 传统的 RNN 模型无法并行训练,而 RWKV 更像一个 "线性 GPT",因此比 GPT 训练得更快。

  6. 传统的 RNN 模型无法利用很长距离的上下文信息 (LSTM 用作语言模型时也只能有效处理约 100 个 token),而 RWKV 可以处理数千个甚至更多的 token

2.3 参数规模

研究者们将RWKV模型的参数规模扩展到140亿,这是迄今为止训练的最大密集RNN,并且发现其性能与同样规模的Transformer相当。

3 性能与效率

3.1 基准测试

论文通过在多个NLP任务上的测试,展示了RWKV在大规模模型上的性能和效率。

3.2 预训练模型

作者发布了从1.69亿到140亿参数的预训练模型,并在Pile数据集上进行了训练

4 技术细节

4.1 时间混合和通道混合

RWKV模型由堆叠的残差块组成,每个块包含时间混合和通道混合子块,这些子块利用过去的信息。
一个RWKV块(左)和完整的RWKV剩余块内的元素,配备了一个用于语言建模的最终头部

4.2 RWKV操作符

模型中的WKV操作符与传统的注意力机制相似,但通过相对位置和时间衰减向量来修改,以实现循环行为。
用于语言建模的RWKV架构

4.3 输出门控

在时间混合和通道混合块中使用sigmoid激活函数的输出门控来控制信息流。

5 训练与推理

5.1 训练阶段

  1. 并行化训练

    • RWKV模型利用Transformer架构的优势,实现训练过程中的并行化。这与传统的RNN不同,后者由于其递归性质,在训练时通常需要逐步迭代,难以实现并行处理。
    • 并行化处理可以显著加快训练速度,尤其是在处理大规模数据集时。
  2. 时间并行模式

    • 时间并行模式允许模型在训练时同时处理序列中的所有元素,这得益于RWKV的线性注意力机制,它不需要在时间步之间交换信息。
    • 这种模式下,模型的计算复杂度主要来自于矩阵乘法操作,这些操作可以很容易地在现代硬件上并行执行。
  3. 优化策略

    • 为了提高训练效率,RWKV模型采用了多种优化策略,包括自定义CUDA内核和小型初始化嵌入等。
    • 自定义CUDA内核可以针对特定的计算任务优化性能,而小型初始化嵌入有助于模型更快地收敛。

5.2 推理阶段

推理阶段是模型将学到的知识应用到新数据上,进行预测或决策的过程。RWKV模型在推理时采用以下策略:

  1. 序列化推理

    • 与训练阶段的并行化不同,RWKV在推理时采用序列化处理,这与RNN的处理方式类似。
    • 在序列化推理中,模型逐个处理序列中的元素,每个元素的输出依赖于之前的计算结果,这使得模型在处理长序列时具有线性的时间复杂度。
  2. 时间序列模式

    • 时间序列模式允许模型在推理时利用其RNN结构的优势,通过递归地更新内部状态来处理序列数据。
    • 这种模式下,模型可以有效地处理长序列,同时保持较低的内存和计算需求。
  3. 输出门控和状态更新

    • RWKV模型在每个时间步使用输出门控机制来控制信息的流动,这有助于模型在推理时更加专注于重要的信息。
    • 状态更新是RWKV模型推理的核心,模型通过更新其内部状态来捕捉序列中的长期依赖关系。
  4. 效率与性能

    • RWKV模型在推理时展现出了高效的性能,这得益于其线性复杂度的计算特性和优化的算法设计。
    • 这种设计使得RWKV模型在处理实际应用中的长序列数据时,既能保持较高的准确率,又能实现快速响应。

6 优化策略

6.1 自定义内核

  • 为了提高计算效率,特别是在执行WKV操作时,作者开发了自定义的CUDA内核。
  • 这些内核针对特定的计算任务进行了优化,利用GPU的并行处理能力,以加速模型的训练和推理过程。

6.2 小初始化嵌入

  • 在训练的初期阶段,作者采用了小型初始化嵌入的方法,即用较小的值初始化嵌入矩阵。
  • 这种方法有助于模型从初始状态快速收敛,因为它减少了初始阶段的噪声,并允许模型更平稳地开始学习过程。

6.3 时间并行模式

  • RWKV模型在训练时采用时间并行模式,这意味着模型可以同时处理序列中的所有元素。
  • 这种并行化处理减少了训练时间,因为它允许模型在多个时间步上并行执行计算,而不是逐个时间步顺序执行。

总结

  • RWKV为序列数据处理提供了一种新的高效且可扩展的架构,通过线性复杂度的注意力机制和有效的训练动态,展示了与传统Transformer相当的性能。

以上是读论文的内容,简单的记录下这一最新的网络,目前不少研究是基于该框架的,希望有多一个新的浪头

相关推荐
九圣残炎11 分钟前
【从零开始的LeetCode-算法】1456. 定长子串中元音的最大数目
java·算法·leetcode
lulu_gh_yu17 分钟前
数据结构之排序补充
c语言·开发语言·数据结构·c++·学习·算法·排序算法
成富33 分钟前
文本转SQL(Text-to-SQL),场景介绍与 Spring AI 实现
数据库·人工智能·sql·spring·oracle
丫头,冲鸭!!!37 分钟前
B树(B-Tree)和B+树(B+ Tree)
笔记·算法
Re.不晚41 分钟前
Java入门15——抽象类
java·开发语言·学习·算法·intellij-idea
CSDN云计算1 小时前
如何以开源加速AI企业落地,红帽带来新解法
人工智能·开源·openshift·红帽·instructlab
艾派森1 小时前
大数据分析案例-基于随机森林算法的智能手机价格预测模型
人工智能·python·随机森林·机器学习·数据挖掘
hairenjing11231 小时前
在 Android 手机上从SD 卡恢复数据的 6 个有效应用程序
android·人工智能·windows·macos·智能手机
小蜗子1 小时前
Multi‐modal knowledge graph inference via media convergenceand logic rule
人工智能·知识图谱
SpikeKing1 小时前
LLM - 使用 LLaMA-Factory 微调大模型 环境配置与训练推理 教程 (1)
人工智能·llm·大语言模型·llama·环境配置·llamafactory·训练框架