Tokenformer: 下一代Transformer架构

1. 导言

Transformer架构已经成为当今大模型的基石,不管是NLP还是CV领域,目前的SOTA模型基本都是基于Transformer架构的,比如NLP中目前的各种知名大模型,或者CV中的Vit等模型

本次介绍的论文标题为:Tokenformer: Rethinking Transformer Scaling with Tokenized Model Parameters," 顾名思义,本文提出了Tokenformer架构,其优势在于增量学习能力:在增加模型尺寸时,无需从头开始重新训练模型,大大降低了成本。 本文由北大和谷歌进行合作,一作是北大在读博士,考虑到代码已开源,因该是有一定含金量的。

2. Transformer vs Tokenformer - 结构比较

首先我们从顶层设计的角度,对于传统 Transformer 架构和 本文提出的 Tokenformer 架构进行比较,如下图所示:

2.1 Transformer 架构

自注意力机制是Transformer的核心,主要包括以下几个步骤:

  • 输入 :假设有一个长度为 <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T的序列,每个token用一个维度为 <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d的向量表示,记为矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> X ∈ R T × d X \in \mathbb{R}^{T \times d} </math>X∈RT×d。

  • 线性投影 : 将输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X 通过三个不同的线性层,分别得到查询(Query)、键(Key)和值(Value)矩阵,其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> W Q , W K , W V ∈ R d × d k W_Q, W_K, W_V \in \mathbb{R}^{d \times d_k} </math>WQ,WK,WV∈Rd×dk 是可学习的权重矩阵, <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk 是查询和键的维度。:

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Q = X ⋅ W Q , K = X ⋅ W K , V = X ⋅ W V Q = X \cdot W_Q, \quad K = X \cdot W_K, \quad V = X \cdot W_V </math>Q=X⋅WQ,K=X⋅WK,V=X⋅WV

  • 计算注意力分数: 通过查询和键的点积,再除以缩放因子 dk\sqrt{d_k}dk,并通过Softmax函数得到归一化的注意力权重:

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Attention ( Q , K , V ) = softmax ( Q ⋅ K ⊤ d k ) ⋅ V \text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q \cdot K^\top}{\sqrt{d_k}} \right) \cdot V </math>Attention(Q,K,V)=softmax(dk Q⋅K⊤)⋅V

  • 输出投影 : 将注意力输出通过一个线性层进行投影,其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> W O ∈ R d v × d W_O \in \mathbb{R}^{d_v \times d} </math>WO∈Rdv×d 是输出投影矩阵。:

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> O = Attention ( Q , K , V ) ⋅ W O O = \text{Attention}(Q, K, V) \cdot W_O </math>O=Attention(Q,K,V)⋅WO

如上图所示,一个Transformer层主要由两个部分组成:

  1. 多头自注意力机制(Multi-Head Self-Attention) :输入首先经过一个线性投影模块,以计算注意力模块的输入,即矩阵 Q、K 和 V。然后利用子注意力机制计算出Token之间的权重
  2. 前馈神经网络(Feed-Forward Network, FFN) :对于注意力层的输出进行投影,计算出下一层的输入

2.2 Transformer 架构的缺陷

传统Transformer在处理token与参数的交互时,依赖于固定数量的线性投影,这限制了模型的扩展性,这句话本身较难理解,因此接下来详细论述架构的缺陷。

2.2.1 模型的拓展性是什么

模型的拓展性(Scalability)指的是模型在需要更强大性能时,能够有效地增加其规模(如参数数量、计算能力等)而不导致性能下降或计算成本过高的能力。

简而言之,拓展性好的模型可以在保持或提升性能的同时,灵活且高效地扩大其规模。

2.2.2 为什么说传统Transformer的固定线性投影限制了模型的扩展性

固定线性投影 指的是,Transformer中用来生成查询、键、值的权重矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> W Q , W K , W V W_Q, W_K, W_V </math>WQ,WK,WV 是预先定义且固定的。这带来了以下几个限制:

  1. 参数数量固定 : 传统Transformer的线性层 <math xmlns="http://www.w3.org/1998/Math/MathML"> W Q , W K , W V W_Q, W_K, W_V </math>WQ,WK,WV 的维度是固定的。例如,如果输入维度 <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d 增加,那么每一层transformer中线性层的维度、输出投影层的维度也必须进行修改。这意味着模型的整体参数数量会急剧增加。

  2. 需要重新训练 : 如果要增加模型的规模(如增加 <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d 或 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk),必须从头开始训练整个模型。耗时而且需要大量计算资源。尤其是随着模型规模的增加,重新训练的成本(时间和计算资源)呈指数级增长,导致在实际应用中难以持续扩展模型。

3. TokenFormer的解决方案

为了解决模型维度固定导致的模型缺乏拓展性的问题,TokenFormer 提出了一种创新的方法,通过将模型参数视为tokens ,并利用注意力机制来处理token与参数之间的交互,从而实现更高效、更灵活的模型扩展。

3.1 模型参数Token化

传统Transformer将参数(如 <math xmlns="http://www.w3.org/1998/Math/MathML"> W Q , W K , W V W_Q, W_K, W_V </math>WQ,WK,WV)作为固定的权重矩阵来处理。而TokenFormer将这些参数表示为一组可学习的tokens。具体来说:

参数Tokens:原本transformer模型的Q、K、V投影层不再是固定的矩阵,而是转化为一组向量(tokens),例如:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> K P ∈ R n × d 1 , V P ∈ R n × d 2 K_P \in \mathbb{R}^{n \times d_1}, \quad V_P \in \mathbb{R}^{n \times d_2} </math>KP∈Rn×d1,VP∈Rn×d2

其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n 是参数tokens的数量, <math xmlns="http://www.w3.org/1998/Math/MathML"> d 1 d_1 </math>d1 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> d 2 d_2 </math>d2 分别是输入和输出的维度。

3.2. Token-Parameter Attention(Pattention)层

Pattention层 是TokenFormer的核心创新,它通过注意力机制来处理token与参数之间的交互。从而替代原本的 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q , K , V Q, K, V </math>Q,K,V,具体过程如下:

  1. 查询与参数Tokens交互 : 输入tokens作为查询(Query) ,参数tokens作为键(Key)和值(Value) ,通过注意力机制进行交互,其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X 是输入tokens, <math xmlns="http://www.w3.org/1998/Math/MathML"> K P K_P </math>KP 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> V P V_P </math>VP 是参数tokens, <math xmlns="http://www.w3.org/1998/Math/MathML"> Θ \Theta </math>Θ 是一种修改后的softmax函数,用于稳定优化:

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Pattention ( X , K P , V P ) = Θ ( X ⋅ K P ⊤ ) ⋅ V P \text{Pattention}(X, K_P, V_P) = \Theta\left(X \cdot K_P^\top\right) \cdot V_P </math>Pattention(X,KP,VP)=Θ(X⋅KP⊤)⋅VP

  1. 注意力分数( <math xmlns="http://www.w3.org/1998/Math/MathML"> Θ \Theta </math>Θ部分)的计算如下:

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> S i j = f ( A i j × τ ∑ k = 1 n ∣ A i k ∣ 2 ) , ∀ i , j ∈ { 1 , ... , n } S_{ij} = f\left( \frac{A_{ij} \times \tau}{\sum_{k=1}^n |A_{ik}|^2} \right), \quad \forall i, j \in \{1, \dots, n\} </math>Sij=f(∑k=1n∣Aik∣2Aij×τ),∀i,j∈{1,...,n}

其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> A = X ⋅ K P ⊤ A = X \cdot K_P^\top </math>A=X⋅KP⊤, <math xmlns="http://www.w3.org/1998/Math/MathML"> τ \tau </math>τ 是缩放因子(默认设置为 <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n), <math xmlns="http://www.w3.org/1998/Math/MathML"> f f </math>f 是非线性函数(如GeLU)。

总结一下,Pattention层 的详细计算流程如上图所示,首先输入的tokens: <math xmlns="http://www.w3.org/1998/Math/MathML"> X ∈ R T × d 1 X∈R^{T×d_1} </math>X∈RT×d1与 <math xmlns="http://www.w3.org/1998/Math/MathML"> K P ∈ R n × d 1 K_P \in \mathbb{R}^{n \times d_1} </math>KP∈Rn×d1计算出注意力分数,而后计算结果和 <math xmlns="http://www.w3.org/1998/Math/MathML"> V P ∈ R n × d 2 V_P \in \mathbb{R}^{n \times d_2} </math>VP∈Rn×d2相乘,即:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Pattention ( X , K P , V P ) = Θ ( A ) ⋅ V P A = X ⋅ K P ⊤ \text{Pattention}(X, K_P, V_P) = \Theta(A) \cdot V_P \\ A = X \cdot K_P^\top </math>Pattention(X,KP,VP)=Θ(A)⋅VPA=X⋅KP⊤

4. 总体结构

为方便阅读再把图扔到这:

与传统transformer结构相同,其总体上也包括两层:多头自注意力层和前馈网络层。

4.1 多头自注意力(Single-Head Variant:

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Q = Pattention ( X , K P Q , V P Q ) K = Pattention ( X , K P K , V P K ) V = Pattention ( X , K P V , V P V ) X att = softmax ( Q ⋅ K ⊤ d ) ⋅ V O att = Pattention ( X att , K P O , V P O ) \begin{aligned} Q &= \text{Pattention}(X, K_{P_Q}, V_{P_Q}) \\ K &= \text{Pattention}(X, K_{P_K}, V_{P_K}) \\ V &= \text{Pattention}(X, K_{P_V}, V_{P_V}) \\ X_{\text{att}} &= \text{softmax}\left( \frac{Q \cdot K^\top}{\sqrt{d}} \right) \cdot V \\ O_{\text{att}} &= \text{Pattention}(X_{\text{att}}, K_{P_O}, V_{P_O}) \end{aligned} </math>QKVXattOatt=Pattention(X,KPQ,VPQ)=Pattention(X,KPK,VPK)=Pattention(X,KPV,VPV)=softmax(d Q⋅K⊤)⋅V=Pattention(Xatt,KPO,VPO)

其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> K P Q , V P Q K_{P_Q}, V_{P_Q} </math>KPQ,VPQ 等是不同投影的参数tokens。即,首先计算出Q, K, V,然后同样计算自注意力,然后将计算结果放入Pattention层。

4.2 前馈网络(Feed-Forward Network, FFN)

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> O ffn = Pattention ( X ffn , K P ffn , V P ffn ) O_{\text{ffn}} = \text{Pattention}(X_{\text{ffn}}, K_{P_{\text{ffn}}}, V_{P_{\text{ffn}}}) </math>Offn=Pattention(Xffn,KPffn,VPffn)

其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> X ffn X_{\text{ffn}} </math>Xffn 是经过Layer Normalization后的中间表示。

这里也可以看到,相对于Transformer,Tokenformer就是将所有的投影层从固定的全连接网络也变成了Pattention层

4.3 与transformer的比较

下方公式左侧代表传统Transformer的自注意力机制,右侧代表tokenformer的自注意力机制:

从上边的图中可以清楚看到,相对于transformer,本论文只是将投影层与连接层替换成了新的层。

5. 可扩展性

之前说过,相对于transformer,tokenformer主要是解决可拓展性的问题,那么假设我们要增加参数数量,或者要增加输入维度,tokenformer如何进行增量学习?

如上图所示,若需要扩展模型,可以简单地追加新的 <math xmlns="http://www.w3.org/1998/Math/MathML"> K P K_P </math>KP 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> V P V_P </math>VP tokens:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> K P scale = [ K P old , K P new ] V P scale = [ V P old , V P new ] K_{P_{\text{scale}}} = [K_{P_{\text{old}}}, K_{P_{\text{new}}}]\\ V_{P_{\text{scale}}} = [V_{P_{\text{old}}}, V_{P_{\text{new}}}] </math>KPscale=[KPold,KPnew]VPscale=[VPold,VPnew]

这样,模型的参数量可以按需扩展。

初始化策略:新增的参数tokens初始化为零,类似于LoRA技术(Low-Rank Adaptation),确保模型能够在保持原有知识的基础上,快速适应新的参数扩展。

6. 实验部分

与从零重训练的 Transformer 相比,如上图所示,Y 轴代表模型性能,X 轴代表训练成本。 蓝线代表使用 3000 亿个 token 从头开始训练的 Transformer 模型,不同的圆圈大小代表不同的模型大小。

其他线条代表 Tokenformer 模型,不同颜色代表不同的Token数量。例如,红线从 1.24 亿个参数开始,扩展到 14 亿个参数,其训练集为从300B token中抽样出的30B Token。最终版本模型的性能与相同规模的 Transformer 相当,但训练成本却大大降低。

黄线显示,使用 60B个Token来训练的增量版本在更低的训练成本下,性能已经比 Transformer 更优。

相关推荐
井底哇哇4 小时前
ChatGPT是强人工智能吗?
人工智能·chatgpt
Coovally AI模型快速验证4 小时前
MMYOLO:打破单一模式限制,多模态目标检测的革命性突破!
人工智能·算法·yolo·目标检测·机器学习·计算机视觉·目标跟踪
AI浩4 小时前
【面试总结】FFN(前馈神经网络)在Transformer模型中先升维再降维的原因
人工智能·深度学习·计算机视觉·transformer
可为测控5 小时前
图像处理基础(4):高斯滤波器详解
人工智能·算法·计算机视觉
一水鉴天5 小时前
为AI聊天工具添加一个知识系统 之63 详细设计 之4:AI操作系统 之2 智能合约
开发语言·人工智能·python
倔强的石头1065 小时前
解锁辅助驾驶新境界:基于昇腾 AI 异构计算架构 CANN 的应用探秘
人工智能·架构
orion-orion6 小时前
贝叶斯机器学习:高斯分布及其共轭先验
机器学习·统计学习
佛州小李哥6 小时前
Agent群舞,在亚马逊云科技搭建数字营销多代理(Multi-Agent)(下篇)
人工智能·科技·ai·语言模型·云计算·aws·亚马逊云科技
说私域6 小时前
社群裂变+2+1链动新纪元:S2B2C小程序如何重塑企业客户管理版图?
大数据·人工智能·小程序·开源
程序猿阿伟7 小时前
《探秘鸿蒙Next:如何保障AI模型轻量化后多设备协同功能一致》
人工智能·华为·harmonyos