RoPE外推的缩放法则 ------ 尝试外推RoPE至1M上下文
知乎 :RoPE外推的缩放法则 ------ 尝试外推RoPE至1M上下文
arXiv :[2310.05209] Scaling Laws of RoPE-based Extrapolation
本文主要介绍 RoPE 外推的缩放法则,相关的背景、理论、验证、思考。主要内容包括四点:一,RoPE 外推的近期相关工作;二,放大和缩小 RoPE旋转角的底数(全文简称base)并在原始长度上续训都会改善其外推效果;三,在原始长度上续训,RoPE 的外推效果和 base 大小之间的关系;四,在更长长度上续训,RoPE 的外推效果和 base 大小之间的关系。
与以往的外推研究不同,本文并没有聚焦一个具体的外推方案,而是给出了一套改进RoPE外推能力的框架,及其对应的数学解释、实验验证。在这个框架下,本文不仅给出了 任意base 任意续训长度时 模型外推表现如何,同时给出了 给定期望上下文长度时应该如何调整RoPE实现定长外推,没有给定期望上下文长度时应该如何调整RoPE实现不定长外推。
1. 引言背景:RoPE的外推研究
1.1 基础:RoPE 与 外推
关于RoPE的提出、原理、解释、实现等的内容,笔者已经在先前关于 在预训练阶段改进RoPE外推 的系列博客中,给出了详细论述,详情可参考 Transformer位置编码(基础、意义)(由于LLaMA基于RoPE给出了良好的初始化参数,并且已有的研究主要聚焦微调和测试阶段的RoPE改进,因此笔者调整了研究方向,完成了这份工作)。这里为了明确符号使用,先简单回顾一下RoPE的基础公式以及RoPE外推问题的提出。
Transformer模型要求显式地编码模型的位置信息,这其中旋转位置编码(Rotary Position Embedding,RoPE)是当下最流行的位置编码方案。对于Transformer中 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t位置的query向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> q t = [ q t ( 0 ) , ⋯ , q t ( d − 1 ) ] ∈ R d \bm{q}_t=\begin{bmatrix}q_t^{\tiny(0)},\cdots,q_t^{\tiny(d-1)}\end{bmatrix}\in\mathbb{R}^d </math>qt=[qt(0),⋯,qt(d−1)]∈Rd 以及 <math xmlns="http://www.w3.org/1998/Math/MathML"> s s </math>s位置的key向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> k s = [ k s ( 0 ) , ⋯ , k s ( d − 1 ) ] ∈ R d \bm{k}_s=\begin{bmatrix}k_s^{\tiny(0)},\cdots,k_s^{\tiny(d-1)}\end{bmatrix}\in\mathbb{R}^d </math>ks=[ks(0),⋯,ks(d−1)]∈Rd,RoPE首先将 <math xmlns="http://www.w3.org/1998/Math/MathML"> q t , k s q_t,k_s </math>qt,ks在特征维度方向上两两维度一组,每两个维度组成一个复数,对应复平面中的一个向量:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> q ~ t = [ q ~ t ( 0 ) , ⋯ , q ~ t ( d / 2 − 1 ) ] q ~ t ( n ) = q t ( 2 n ) + i q t ( 2 n + 1 ) k ~ s = [ k ~ s ( 0 ) , ⋯ , k ~ s ( d / 2 − 1 ) ] k ~ s ( n ) = k s ( 2 n ) + i k s ( 2 n + 1 ) . \begin{equation}\begin{gathered} \tilde{\bm{q}}_t=\begin{bmatrix}\tilde{q}_t^{\tiny(0)},\cdots,\tilde{q}_t^{\tiny(d/2-1)}\end{bmatrix} \quad \tilde{q}_t^{\tiny(n)}=q_t^{\tiny(2n)}+iq_t^{\tiny(2n+1)} \\ \tilde{\bm{k}}_s=\begin{bmatrix}\tilde{k}_s^{\tiny(0)},\cdots,\tilde{k}_s^{\tiny(d/2-1)}\end{bmatrix} \quad \tilde{k}_s^{\tiny(n)}=k_s^{\tiny(2n)}+ik_s^{\tiny(2n+1)} \end{gathered}\text{.}\tag{1}\end{equation} </math>q~t=[q~t(0),⋯,q~t(d/2−1)]q~t(n)=qt(2n)+iqt(2n+1)k~s=[k~s(0),⋯,k~s(d/2−1)]k~s(n)=ks(2n)+iks(2n+1).(1)
接着RoPE通过,让 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ~ t , k ~ s \tilde{\bm{q}}_t,\tilde{\bm{k}}s </math>q~t,k~s和一个基于超参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \bm{\theta} </math>θ的复数旋转矩阵对应位相乘,实现绝对位置信息的注入。通过自注意力计算展开,如公式2所示,可以发现在attention score中,相对位置信息通过cos/sin的形式表示。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> A t , s = R e [ ( q ~ t ⊙ e i t θ ) ⋅ ( k ~ s ⊙ e i s θ ) T ] = R e [ ∑ n = 0 d / 2 − 1 q ~ t ( n ) e i t θ n ( k ~ s ( n ) e i s θ n ) ∗ ] = R e [ ∑ n = 0 d / 2 − 1 q ~ t ( n ) k ~ s ( n ) ∗ e i ( t − s ) θ n ] = ∑ n = 0 d / 2 − 1 ( q t ( 2 n ) k s ( 2 n ) + q t ( 2 n + 1 ) k s ( 2 n + 1 ) ) cos ( t − s ) θ n + ( q t ( 2 n ) k s ( 2 n + 1 ) − q t ( 2 n + 1 ) k s ( 2 n ) ) sin ( t − s ) θ n . \begin{equation}\begin{aligned} \bm{A}{t,s}&=\mathrm{Re}\begin{bmatrix}\left(\tilde{\bm{q}}_t\odot{e^{it\bm{\theta}}}\right)\cdot\left(\tilde{\bm{k}}s\odot{e^{is\bm{\theta}}}\right)^T\end{bmatrix} \\ &=\mathrm{Re}\begin{bmatrix}\sum{n=0}^{d/2-1}{\tilde{q}_t^{\tiny(n)}e^{it\theta_n}\left(\tilde{k}s^{\tiny(n)}e^{is\theta_n}\right)^*}\end{bmatrix}=\mathrm{Re}\begin{bmatrix}{\sum{n=0}^{d/2-1}\tilde{q}_t^{\tiny(n)}\tilde{k}s^{\tiny(n)}{}^{*}e^{i(t-s)\theta_n}}\end{bmatrix} \\ &=\sum{n=0}^{d/2-1}{\begin{aligned}&\left(q_t^{\tiny(2n)}k_s^{\tiny(2n)}+q_t^{\tiny(2n+1)}k_s^{\tiny(2n+1)}\right)\cos{(t-s)\theta_n}+\\ &\left(q_t^{\tiny(2n)}k_s^{\tiny(2n+1)}-q_t^{\tiny(2n+1)}k_s^{\tiny(2n)}\right)\sin{(t-s)\theta_n}\end{aligned}} \end{aligned}\text{ .}\tag{2}\end{equation} </math>At,s=Re[(q~t⊙eitθ)⋅(k~s⊙eisθ)T]=Re[∑n=0d/2−1q~t(n)eitθn(k~s(n)eisθn)∗]=Re[∑n=0d/2−1q~t(n)k~s(n)∗ei(t−s)θn]=n=0∑d/2−1(qt(2n)ks(2n)+qt(2n+1)ks(2n+1))cos(t−s)θn+(qt(2n)ks(2n+1)−qt(2n+1)ks(2n))sin(t−s)θn .(2)
其中,复数旋转矩阵使用的超参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \bm{\theta} </math>θ满足:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ = [ θ 0 , ⋯ , θ d / 2 − 1 ] θ n = 1000 0 − 2 n / d . \begin{equation}\begin{aligned} \bm{\theta}=\begin{bmatrix}\theta_0,\cdots,\theta_{d/2-1}\end{bmatrix}\quad \theta_n=10000^{-2n/d} \end{aligned}\text{.}\tag{3}\end{equation} </math>θ=[θ0,⋯,θd/2−1]θn=10000−2n/d.(3)
尽管RoPE可以理论上可以编码任意长度的绝对位置信息,并且通过三角计算将任意长度的相对位置信息呈现出来,RoPE仍然存在外推问题(length extrapolation problem),即对于基于RoPE的大语言模型,测试长度超过训练长度之后,模型的效果会有显著的崩坏,具体表现为语言建模困惑度急剧攀升。对此,已经有很多研究给出了不同的来源解释以及应对策略。在笔者看来,这些应对策略的研究工作主要可以分为两个流派。流派一,限制注意力的工作,例如滑动窗口,以及各种滑窗的变体,软窗口、块窗口、平行窗口、箭形窗口 等;流派二,调整旋转角的工作,例如线性内插,以及在这之后陆续提出的NTK方法、调整base续训等。
1.2 流派一:限制注意力的工作
滑动窗口,毋庸置疑,是最简单有效的外推方法,在Transformer升级之路:7、长度外推性与局部注意力中被称为"超强基线",基于滑动窗口以及其变体的外推研究一直是外推研究中不可忽视的力量。不过由于研究兴趣相左(主要因为以下三点:Transformer提出的动机在于全局的感知,但是滑动窗口又限制了这一点;滑动窗口不具有可拆分性,与RoPE提出动机相悖;直到最近才集成进FlashAttention2),笔者对此了解不多,以下仅对其变体做简单介绍如下表所示。
除了最基础的滑动窗口外,表格列出了各种滑动窗口的变体,例如 ALiBi(与RoPE无关),是将 滑动窗口 与 T5 bias 结合,公式如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> A t , s = x t W Q W K T x s T + b ( s − t ) b = 2 − 8 m \begin{equation}\bm{A}{t,s}=\bm{x}{t}\bm{W}{Q}\bm{W}{K}^{T}\bm{x}_{s}^{T}+b(s-t) \quad b=2^{-\frac{8}{m}} \tag{4}\end{equation} </math>At,s=xtWQWKTxsT+b(s−t)b=2−m8(4)
通过让attention score加上一个线性偏置(每个注意力头不同)实现一个软词窗,让模型关注相对距离较近的,给予相对距离较远的惩罚。而 xPos 则是将 滑动窗口 和 RoPE 结合,公式如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> A t , s = R e [ ∑ n = 1 d / 2 ζ n t e i t θ n q ~ t ( n ) ( ζ n − s e i s θ n k ~ s ( n ) ) ∗ ] = ∑ n = 1 d / 2 R e [ q ~ t ( n ) k ~ s ( n ) ∗ ζ n t − s e i ( t − s ) θ n ] ζ n = 2 n / d + γ 1 + γ , γ = 0.4 , θ n = 10000 − 2 n / d \begin{equation}\begin{gathered} \bm{A}{t,s}=\mathrm{Re}\begin{bmatrix}\sum{n=1}^{d/2}{\zeta_n^{t}e^{it\theta_n}\tilde{q}_t^{\tiny(n)}\left(\zeta_n^{-s}e^{is\theta_n}\tilde{k}s^{\tiny(n)}\right)^*}\end{bmatrix}=\sum{n=1}^{d/2}{\mathrm{Re}\begin{bmatrix}\tilde{q}_t^{\tiny(n)}{\tilde{k}_s^{\tiny(n)}}^*\zeta_n^{t-s}e^{i(t-s)\theta_n}\end{bmatrix}} \\[1ex] \zeta_n=\frac{2n/d+\gamma}{1+\gamma}, \ \gamma=0.4 , \quad\theta_n={10000}^{-2n/d} \end{gathered}\tag{5}\end{equation} </math>At,s=Re[∑n=1d/2ζnteitθnq~t(n)(ζn−seisθnk~s(n))∗]=n=1∑d/2Re[q~t(n)k~s(n)∗ζnt−sei(t−s)θn]ζn=1+γ2n/d+γ, γ=0.4,θn=10000−2n/d(5)
通过让attention score乘上一个指数衰减的系数(每个维度不同)实现远距离抑制。xPos的可贵之处在于其提出的指数衰减系数是可拆分的。上述两个词窗变体都应用于预训练阶段,与之相对,主流的词窗方法往往都应用于测试阶段,例如,在xPos论文中,作者还提出了一个分块的词窗方法(Blockwise Causal Attention,BCA ),如图1所示,超出训练长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> T train T_\text{train} </math>Ttrain后,每个token至多能看到 <math xmlns="http://www.w3.org/1998/Math/MathML"> T train T_\text{train} </math>Ttrain个,至少能看到 <math xmlns="http://www.w3.org/1998/Math/MathML"> T train / 2 T_\text{train}/2 </math>Ttrain/2个token。应用BCA,xPos的效果会得到进一步提升。
虽然限制感知范围是一个比较直觉可行的方案,但是对于上下文学习这样的学习范式,上文(包含若干例子以及任务提示)中的每个例子都是需要感知到的,由此就引出了另外一种词窗变体,平行词窗(Parallel Context Window)。PCW最早提出平行词窗,PCW让输入的提示跟每个例子之间都会有注意力计算,但是例子与例子之间并不做注意力计算,即整个attention score上会有一个如图1所示的mask。同时,每个例子结尾的下标保持一致,以此为起点位置编码往两侧延伸,由此就实现了仿佛每个例子都是平行地贴在任务提示之前一样。在这个基础上又有NBCE,如图1所示,NBCE相当于通过朴素贝叶斯理论,对平行词窗做了一个拓展改进。但是这些平行词窗都是具备无序性,即打乱上文中示例的顺序并不改变最终的结果,仅针对上下文学习比较合适。
图1 几种有代表性的基于限制注意力的外推工作。
最近,有一种新型的词窗逐渐为人们所关注,即 箭形词窗( <math xmlns="http://www.w3.org/1998/Math/MathML"> Λ \Lambda </math>Λ-shaped window)。箭形词窗最早见于LM-Infinit,相较于原始的滑动窗口,箭形词窗在保留了邻近token的同时,还保留了最开始的token,StreamingLLM对此做出了一些解释,指出这样做是因为大模型倾向于对靠前的token给予较大的attention score,并且取得了惊人的外推效果。对于上述的平行词窗、箭形词窗等方法,可以肯定的是,它们有利于外推、有利于attention加速,并且在一些任务上取得了很好地效果。但是这里仍然存在一些问题:为什么RoPE结合原始词窗将相对长度限制在训练长度之后,困惑度仍然会随实际长度增加而增加?为什么attention score具有如此的倾向性?丢弃掉部分token是否会对某些任务产生负面影响?这些都有待后续的探讨
1.3 流派二:调整旋转角的工作
从最早的三角式位置编码(Sinusoidal Position Embedding)到旋转位置编码 RoPE,关于 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn的取值几乎一直雷打不动是 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n = 10000 − 2 n / d \theta_n={10000}^{-2n/d} </math>θn=10000−2n/d。然而线性内插(Linear Interpolation)及其同期工作打破了这一传统,开启调整旋转角改进RoPE外推能力研究的序幕,并且在过去的三个月内异军突起,立刻成为了RoPE外推改进,至少是微调阶段改进的主流,主要工作如下表所示。
线性内插本身的想法很简单,如公式6所示,通过让RoPE的位置下标去除以一个系数,把cos/sin内的取值约束到训练长度范围以内;
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> A t , s = R e [ ∑ n = 0 d / 2 − 1 q ~ t ( n ) k ~ s ( n ) ∗ e i ( t − s ) λ θ n ] = R e [ ∑ n = 0 d / 2 − 1 q ~ t ( n ) k ~ s ( n ) ∗ e i ( t − s ) θ n λ ] λ = T tune T train \begin{equation}\bm{A}{t,s}=\mathrm{Re}\begin{bmatrix}{\sum\limits{n=0}^{d/2-1}\tilde{q}_t^{\tiny(n)}{\tilde{k}s^{\tiny(n)}}^{*}e^{i\frac{(t-s)}{\lambda}\theta_n}}\end{bmatrix}=\mathrm{Re}\begin{bmatrix}{\sum\limits{n=0}^{d/2-1}\tilde{q}t^{\tiny(n)}{\tilde{k}s^{\tiny(n)}}^{*}e^{i(t-s)\frac{\theta_n}{\lambda}}}\end{bmatrix} \quad \lambda=\frac{T\text{tune}}{T\text{train}} \tag{6}\end{equation} </math>At,s=Re[n=0∑d/2−1q~t(n)k~s(n)∗eiλ(t−s)θn]=Re[n=0∑d/2−1q~t(n)k~s(n)∗ei(t−s)λθn]λ=TtrainTtune(6)
但是,这个系数既可以理解为除在下标上,也可以理解为除在旋转角 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn上,通过一个常数让 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn缩小,意图去表征更长的上下文特征。由于 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn从小到大对应低频至高频的不同特征,除以一个常数显然过于简单,也由此后续的研究例如Giraffe,让每个维度乘上一个随维度自适应变化的系数。系数和维度之间满足幂函数的关系,因此该操作被称为幂校正(Power Scaling),此外Giraffe还对于校正后较小的 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn直接设为0。
相较于 线性内插的常数校正、Giraffe的幂校正,最早为NTK方法提出的指数校正在后续的研究中,得到了更广泛的应用与关注。相较于前两者,NTK方法提出的校正更加符合原始 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn的取值方式,如公式7所示,fixed NTK仅通过放大 旋转角的底数 (简称 base),即实现让不同维度的 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn自适应缩小,大的 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn缩小幅度较小仍对应高频特征,小的 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn缩小幅度较大仍适配低频特征。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> A t , s = R e [ ∑ n = 0 d / 2 − 1 q ~ t ( n ) k ~ s ( n ) ∗ e i ( t − s ) θ n ] θ n = ( 10000 ⋅ α ) − 2 n / d \begin{equation}\bm{A}{t,s}=\mathrm{Re}\begin{bmatrix}{\sum{n=0}^{d/2-1}\tilde{q}_t^{\tiny(n)}{\tilde{k}_s^{\tiny(n)}}^{*}e^{i(t-s)\theta_n}}\end{bmatrix}\quad\theta_n={\left(10 000\cdot\alpha\right)}^{-2n/d} \tag{7}\end{equation} </math>At,s=Re[∑n=0d/2−1q~t(n)k~s(n)∗ei(t−s)θn]θn=(10000⋅α)−2n/d(7)
如公式7所示,这里base放大的系数是固定的,随着推理上下文的增长,可以通过动态放大base,让RoPE不断适应新的上下文长度,如公式8所示,这就是dynamic NTK。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> α t = max ( 1 , 2 ⌈ log 2 t T train ⌉ + 1 − 1 ) . \begin{equation}\alpha_t=\max\left(1, 2^{\left\lceil\log_2{\frac{t}{T_\text{train}}}\right\rceil+1}-1\right) \text{.}\tag{8}\end{equation} </math>αt=max(1,2⌈log2Ttraint⌉+1−1).(8)
两种NTK方法由于其校正方式适配原始RoPE的 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn取值,无需续训即可以外推,目前已集成进Huggingface的LLaMA实现,并且在Qwen等模型外推中得到了应用;公式8中的dynamic NTK实现就是Qwen的实现版本,其优点是不需要额外超参数。NTK方法深刻影响了后续的RoPE外推研究,例如Code LLaMA直接把base设为1000000,然后在16K长代码上续训实现128K外推,再例如近期的LLaMA2 Long,也是直接把base设成500000,然后在16K长上续训实现32K外推。
1.4 思考:base 与 外推
在第二个流派的研究中,base是一个关键的超参数 。一方面,几乎所有的调整旋转角的工作,除了Giraffe,都是通过调整base来实现调整每个维度上的旋转角 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn的效果;即使是Giraffe对 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn的调整,也是以自适应缩小和裁剪,即设为0,的形式。由此可以说明,模型本身对 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn的分布已有一定的感知了,过度的调整 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn,只会导致模型的学习效果崩坏;对此笔者曾经尝试过,采用调整 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn分布的方法,在微调阶段改进外推效果,发现最终结果在简单的下游任务上都差强人意。
另一方面,目前所有的调整base实现外推的方案都是放大base ,即通过缩小 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn实现外推。这些工作都声称,较小的 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn有利于捕捉长上下文对应的低频特征,这时候给予更长长度的续训既可以让模型拥有更强的外推能力。对此笔者有两个疑问,一者,更小 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn 和 更长的续训语料,谁才是让模型外推能力更强的主要因素;在已有的研究中,可以看到,fixed/dynamic-NTK,不需要续训既可以外推,这里 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn的调整是主要因素,而YaRN在调整base之后还需要64K长的续训,由此实现128K的外推,这里面谁是主要因素就有些模棱两可了。
再者,如果 放大base 缩小 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn,有利于模型捕捉低频特征,但是 那些对应周期远超出模型训练的上下文长度的低频特征,真的可以无师自通感知到那些低频特征对应的上下文关系吗?反之,为了让每个维度都感知到完整的周期信息,缩小base 放大 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn,岂不是更加有利于模型的外推?而这就与当今主流的研究趋势相矛盾了,那么 缩小base 放大 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn,又能取得怎么样的外推效果呢?
此外,上述外推研究,虽然都是基于RoPE外推问题提出的改进,但是对于外推问题的源头缺少进一步的探讨,理论上RoPE可以胜任任意长度序列的位置编码,但是其外推能力的不足究竟来源于何方?虽然在 线性内插、LM-Infinite 等研究中,提到了 RoPE存在超过训练长度attention score异常的现象,但这些工作并没有说明这些现象来源于 RoPE-based LLM 的哪一部分参数。是哪些参数导致了attention score异常,如何定位这些不良的参数,并把它们转化为可靠的参数,这些都是值得思考的问题。
2. 惊奇发现:缩放base都能外推
在本节中,笔者针对上述思考展开深入实验。模型方面,使用 LLaMA2 7B/13B 进行微调,硬件使用 32卡A100,并行策略使用 ZeRO3。优化器方面,使用 AdamW 优化器, <math xmlns="http://www.w3.org/1998/Math/MathML"> β 1 = 0.9 \beta_1=0.9 </math>β1=0.9, <math xmlns="http://www.w3.org/1998/Math/MathML"> β 2 = 0.999 \beta_2=0.999 </math>β2=0.999,即Pytorch中的默认参数。数据方面,本文使用the_pile数据集,以语言建模任务进行训练,使用测试数据中的books3数据集,使用语言建模困惑度进行测试。框架方面,使用实验室开源的CoLLiE作为训练框架,OpenCompass作为测试框架;使用FlashAttention2加速,在进行位置编码的过程中使用fp32精度。
对于7B模型,具体采用以下超参数微调,
'batch_size': 512K tokens, 'learning_rate': 2e-5, 'weight_decay': 0, 'max_grad_norm': 2.5,
'lr_scheduler_type': 'none', 'bf16': True, 'bf16_full_eval': True, 'train_epoch': 1024,
对于13B模型,具体则采用以下超参数微调:
'batch_size': 512K tokens, 'learning_rate': 2e-5, 'weight_decay': 0, 'max_grad_norm': 1.0,
'lr_scheduler_type': 'none', 'bf16': True, 'bf16_full_eval': True, 'train_epoch': 1024,
其中,默认续训长度和原始训练长度一致,保持 4096 tokens,下文中涉及线性内插以及16K续训的内容,会对应调整续训长度,保持batch中token数量不变。根据上述实验设定,笔者进行了如下的实验,并且得到了如 图2、图3 所示的最终结果。
2.1 放大base 可以外推
首先笔者进行的是放大base的实验,笔者发现使用更大的base,即使在原始长度上续训,也能显著的改进模型的外推效果,最终得到的结果如图2所示,这样的改进有以下一些特点:
一,模型的外推效果很好,甚至可以直接外推超过续训长度;这个效果和Code LLaMA使用16K续训但取得了100K长度的外推是相一致的。二,模型外推存在一个明显的上界,在这个范围内以内,模型语言建模的困惑度和准确率基本保持在一个稳定的范围内。但是一旦超过一个界限以后,模型的外推表现会严重的退化,表现在困惑度出现急剧的上升。三,模型的外推效果随base的变化稳步提升,并且随base变化比较均匀,随着base增大,模型能够稳定地外推到更长的上下文。四,相较于dynamic NTK,base放大并续训的方法,在超过外推上界后的崩坏趋势是远远超过dyanmic NTK的退化速度的;因此,对于放大base并续训,超过外推上界后的效果总是会落后于dynamic NTK的。但是在外推上界之内,该方案的效果是远远好于dynamic NTK的。
图2 更大base的RoPE在预训练长度文本上续训后的语言建模困惑度。
2.2 缩小base 也可以外推
接着笔者进行的是缩小base的实验,需要注意的是,尽管这和当前已有的外推研究方向相反,但是缩小base,并在原始训练长度上续训,仍然取得了显著的外推效果提升,如图3所示,能够将最大上下文窗口拓展至训练长度之外,并且相比放大base的外推,展现出以下的一些特点:
一,模型可以外推,但外推效果有限,超过训练长度后,模型的困惑度仍然会上升,但是上升的会比较平坦;并且base越小上升越慢,曲线越平缓。二,base缩小后的模型外推并不存在一个明显的上界,模型语言建模的困惑度和准确率始终随上下文长度增加稳步退化。三,缩小base取改善外推并不是一个均匀的过程,当base在10000至8000的范围内,模型外推效果变化非常的小,在8000到2608之间,模型外推效果有了一定程度的提升,在2608至1304的区间,提升加速,在1304至652区间,提升进一步加速,并且在base取500时最终超过dynamic NTK方案。四,对比dynamic NTK,虽然 10000至652的绝大多数base 都无法超过 dynamic NTK,但是当base取到足够小之后,其外推曲线会足够平缓以至于在48K长度上一致优于dynamic NTK。
图3 更小base的RoPE在预训练长度文本上续训后的语言建模困惑度。
2.3 小结:四个问题的提出
如果把上述两个结果放在一起,就会得到一个很有趣的结果。如图4折线图所示:最中间的base=10000 ,其外推效果 ,至少从微调角度 ,是最差的,但是往两头走,效果都会得到一定的提升。更为有趣的一点在于,两者的改进方式各有不同:对于放大base,虽然外推效果是稳步提升的,但是对于不同base存在明显的外推上界;对于缩小base,虽然外推效果的提升不是均匀的,但是获得的外推曲线却是没有明显崩坏拐点的。如果以 base和上下文长度为轴,如图4热力图所示,就会发现,对于更大的base,外推存在一条明显的连续的分界线,对于更小的base,模型外推效果的改善存在一个明显的跃迁阶段。总结上文的分析,笔者提出以下的四个问题:
图4 不同base的RoPE在预训练长度文本上续训后的外推效果汇总。在(a)和(b)中,第一行显示困惑度,而第二行显示准确率。 第一列显示了不同base在不同上下文长度上的外推表现,第二列以base和上下文长度为轴使用热力图将困惑度和准确率的走势进一步可视化。
- Q0 :从模型参数角度,RoPE外推问题的根源在哪?
- Q1 :就外推问题而言,10000是最差的base取值吗?
- Q2 :对于更大base的外推,是否存在一个base、训练长度、外推上界之间的数学关系式?
- Q3 :对于更小base的外推,base缩小时产生外推能力跃迁的取值是多少,base足够小之后能否实现模型的无限外推?
3. 理论解释:RoPE外推的缩放法则
从上述四个问题出发,笔者展开了本节四个小节的论述,分别对应上述四个问题的答案:临界维度 <math xmlns="http://www.w3.org/1998/Math/MathML"> d extra d_\text{extra} </math>dextra(对应3.2小节)、临界base <math xmlns="http://www.w3.org/1998/Math/MathML"> β 0 = 10000 log T train 2 π T tune 2 π \beta_0={10000}^{\log_{\frac{T_\text{train}}{2\pi}}{\frac{T_\text{tune}}{2\pi}}} </math>β0=10000log2πTtrain2πTtune(对应3.4小节)、RoPE外推的缩放法则2(对应3.3小节)、RoPE外推的缩放法则1(对应3.1小节);如果把上述四个答案可以组合在一起,就可以得到 扩展的RoPE外推缩放法则(对应3.4小节)。
声明:这里的缩放法则和著名的scaling law没有关系 ,只是base放大和缩小过程中外推效果的变化规律;在位置编码研究中,调整base和 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn大小往往也被称作scale,因此这里不幸重名了。
3.1 缩小base时的缩放法则
论及缩小base为什么能够增加模型的外推能力,首先需要思考base变化对于RoPE意味着什么。回看RoPE的公式,可以发现,base缩小意味着 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn变大,对应到RoPE刻画绝对位置和相对位置的三角函数周期变短。如果把不同维度的余弦波分别呈现出来,例如图5所示,即可发现,对于小的base,例如500,每个theta对应的周期都会被限制在4096,即训练长度以内;而大的base,例如10000,会存在 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn对应的周期超出训练长度。也因此,小的base可以在训练或者续训的时候见过所有的cos/sin内部的取值 ,无论是绝对位置的还是相对位置的,也因此q和k的不同维度都得到了充分的学习,让他们见过了完整的cos/sin值域,相较于原始的base设为10000的情形,每个维度不会出现在测试时没有见到的超出训练范围的位置编码,由此就实现了外推能力的提升。
图5 周期和训练长度的可视化。 假设一个基于RoPE的大语言模型,注意力头维度32,即有16个旋转角。每个子图将这些维度对应的cos周期(平行的紫色平面)与 4096训练长度(蓝色框)对比。(a)表示base=500时,所有周期都小于训练长度;(b)表示base=1000000时,部分周期超过训长度(红色)。
在这个过程当中,有三个节点非常的关键,分别是 <math xmlns="http://www.w3.org/1998/Math/MathML"> π / 2 \pi/2 </math>π/2, <math xmlns="http://www.w3.org/1998/Math/MathML"> π \pi </math>π, <math xmlns="http://www.w3.org/1998/Math/MathML"> 2 π 2\pi </math>2π,只有当cos/sin内部的取值在训练阶段遍历0到 <math xmlns="http://www.w3.org/1998/Math/MathML"> π / 2 \pi/2 </math>π/2,模型才会意识到cos为负、sin不单调;只有当cos/sin内部的取值达到 <math xmlns="http://www.w3.org/1998/Math/MathML"> π \pi </math>π,模型才会意识到cos不单调、sin为负;只有当cos/sin内部的取值达到 <math xmlns="http://www.w3.org/1998/Math/MathML"> 2 π 2\pi </math>2π,模型才能感知到所有的可能的cos/sin取值,进而可能意识到每个维度位置编码的周期性表示(真正意识到周期性表示可能需要在训练长度范围内经历一轮以上的周期,这一点无法具体量化)。由此笔者得到了 RoPE外推的缩放法则1(Scaling Law for Smaller bases)。
定理 1. (base缩小时RoPE外推的缩放法则) 对于基于RoPE的大语言模型(RoPE-based LLMs),假设其预训练文本长度为 <math xmlns="http://www.w3.org/1998/Math/MathML"> T train T_\text{train} </math>Ttrain,如果在微调阶段将base调整为 <math xmlns="http://www.w3.org/1998/Math/MathML"> β < 10000 \beta<10000 </math>β<10000,并且使用预训练文本长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> T train T_\text{train} </math>Ttrain续训,那么模型的外推能力将会提升。如果base缩小到如下的 <math xmlns="http://www.w3.org/1998/Math/MathML"> β 1 , β 2 , β 3 \beta_1,\beta_2,\beta_3 </math>β1,β2,β3,
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> β 1 = 2 T train π , β 2 = T train π , β 3 = T train 2 π \begin{equation}\beta_1 = \frac{2 T_\text{train}}{\pi}, \quad\beta_2 = \frac{T_\text{train}}{\pi}, \quad\beta_3 = \frac{T_\text{train}}{2\pi}\tag{9}\end{equation} </math>β1=π2Ttrain,β2=πTtrain,β3=2πTtrain(9)
那么每个维度位置编码中,如下所示,cos/sin内部会依次实现从0到 <math xmlns="http://www.w3.org/1998/Math/MathML"> π / 2 , π , 2 π \pi/2,\pi,2\pi </math>π/2,π,2π的遍历,由此实现外推能力进一步的提升。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> π 2 ≤ T train θ n = T train β 1 − 2 n d π ≤ T train θ n = T train β 2 − 2 n d 2 π ≤ T train θ n = T train β 3 − 2 n d for n = 1 , ⋯ , d / 2 − 1 , \begin{equation}\begin{gathered}\frac{\pi}{2}\leq T_\text{train}\theta_n=T_\text{train}\beta_1^{-\frac{2n}{d}} \\ \pi\leq T_\text{train}\theta_n=T_\text{train}\beta_2^{-\frac{2n}{d}} \\ 2\pi\leq T_\text{train}\theta_n=T_\text{train}\beta_3^{-\frac{2n}{d}} \end{gathered}\ \text{for } n=1,\cdots,d/2-1\text{,}\tag{10}\end{equation} </math>2π≤Ttrainθn=Ttrainβ1−d2nπ≤Ttrainθn=Ttrainβ2−d2n2π≤Ttrainθn=Ttrainβ3−d2n for n=1,⋯,d/2−1,(10)
对于LLaMA2,训练长度为4096,可以依次求得这三个节点为 <math xmlns="http://www.w3.org/1998/Math/MathML"> β 1 = 2608 \beta_1=2608 </math>β1=2608、 <math xmlns="http://www.w3.org/1998/Math/MathML"> β 2 = 1304 \beta_2=1304 </math>β2=1304、 <math xmlns="http://www.w3.org/1998/Math/MathML"> β 3 = 652 \beta_3=652 </math>β3=652。而正是这三个点也对应了base缩小过程中,效果加速攀升的三个拐点。由此,笔者回答了Q3 的前半部分,找到了base缩小产生外推能力跃迁的取值,关于base缩小能否导致无限外推,这点则留置下文3.4小节的实验验证。
3.2 周期 与 临界维度
根据上述从周期角度得出的分析可以发现,对于RoPE外推,每个维度对应的旋转角是否在训练阶段就已经完成一个周期的旋转是一个非常关键的问题。而根据RoPE公式,以及如图5所示,维度越靠前对应的theta取值越大,周期越短,在训练阶段就可以见过全周期的信息;相反,在原始RoPE中,最靠后的一些维度并不会在训练时见过完整的cos/sin值域。因此对于RoPE-based LLMs,存在这样一个特征维度:这个维度之前的维度对应的 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn对应的cos/sin周期 <math xmlns="http://www.w3.org/1998/Math/MathML"> T n T_n </math>Tn,能够被涵盖在训练长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> T train T_\text{train} </math>Ttrain范围内,而在这之后的维度则是周期长于训练长度。由此,对于后续的维度,当超过训练长度时,新加入token编码的绝对位置以及与前面token产生的相对位置信息是 out-of-distribution(OOD)的,也因此,如公式11所示,这些维度对应的attenion score分量是out-of-distribution的,也由此使得整个模型的attention score在超出训练长度之后产生显著崩坏。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> A t , s = R e [ ∑ n = 0 d / 2 − 1 q ~ t ( n ) k ~ s ( n ) ∗ e i ( t − s ) θ n ⏟ full attention scores in RoPE ] = R e [ ∑ n = 0 d extra / 2 − 1 q ~ t ( n ) k ~ s ( n ) ∗ e i ( t − s ) θ n ⏟ reliable part for extrapolation + ∑ n = d extra / 2 d / 2 − 1 q ~ t ( n ) k ~ s ( n ) ∗ e i ( t − s ) θ n ⏟ OOD part for extrapolation ] . \begin{equation}\begin{aligned} \bm{A}{t,s}&=\mathrm{Re}\begin{bmatrix}\underbrace{\color{purple}{\sum{n=0}^{d/2-1}\tilde{q}t^{\tiny(n)}\tilde{k}s^{\tiny(n)}{}^{*}e^{i(t-s)\theta_n}}}\text{full attention scores in RoPE}\end{bmatrix} \\ &=\mathrm{Re}\begin{bmatrix}{\underbrace{\color{blue}{\sum{n=0}^{d_\text{extra}/2-1}\tilde{q}t^{\tiny(n)}\tilde{k}s^{\tiny(n)}{}^{*}e^{i(t-s)\theta_n}}}\text{reliable part for extrapolation}+\underbrace{\color{red}{\sum{n=d_\text{extra}/2}^{d/2-1}\tilde{q}_t^{\tiny(n)}\tilde{k}s^{\tiny(n)}{}^{*}e^{i(t-s)\theta_n}}}\text{OOD part for extrapolation}}\end{bmatrix} \end{aligned}\text{.}\tag{11}\end{equation} </math>At,s=Re⎣ ⎡full attention scores in RoPE n=0∑d/2−1q~t(n)k~s(n)∗ei(t−s)θn⎦ ⎤=Re⎣ ⎡reliable part for extrapolation n=0∑dextra/2−1q~t(n)k~s(n)∗ei(t−s)θn+OOD part for extrapolation n=dextra/2∑d/2−1q~t(n)k~s(n)∗ei(t−s)θn⎦ ⎤.(11)
因此,笔者将这个维度称为RoPE外推的临界维度,其形式定义以及计算方式如引理1所示。
引理 1. (临界维度的定义) 对于基于RoPE的大语言模型(RoPE-based LLMs),假设其预训练文本长度为 <math xmlns="http://www.w3.org/1998/Math/MathML"> T train T_\text{train} </math>Ttrain,自注意力头维度数量为 <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d,即 <math xmlns="http://www.w3.org/1998/Math/MathML"> q t , k s ∈ R d \bm{q}t,\bm{k}s\in\mathbb{R}^d </math>qt,ks∈Rd。那么存在这样一个维度, <math xmlns="http://www.w3.org/1998/Math/MathML"> d extra d\text{extra} </math>dextra:前 <math xmlns="http://www.w3.org/1998/Math/MathML"> d extra d\text{extra} </math>dextra个维度 感知了对应维度上全周期的位置编码,后 <math xmlns="http://www.w3.org/1998/Math/MathML"> d − d extra d-d_\text{extra} </math>d−dextra个维度 只感知了对应维度上一个周期内的部分编码,如公式12所示。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> T n = 2 π θ n = 2 π ⋅ 10000 2 n d ≤ T train , for n = 0 , ⋯ , d extra / 2 − 1 , T n = 2 π θ n = 2 π ⋅ 10000 2 n d > T train , for n = d extra / 2 , ⋯ , d / 2 − 1 . \begin{equation}\begin{aligned} T_n=\frac{2\pi}{\theta_n}=2\pi\cdot{10000}^{\frac{2n}{d}}\leq T_\text{train},\quad&\text{ for }n=0,\cdots,d_\text{extra}/2-1\text{,} \\ T_n=\frac{2\pi}{\theta_n}=2\pi\cdot{10000}^{\frac{2n}{d}}>T_\text{train},\quad&\text{ for }n=d_\text{extra}/2,\cdots,d/2-1\text{.} \end{aligned}\tag{12}\end{equation} </math>Tn=θn2π=2π⋅10000d2n≤Ttrain,Tn=θn2π=2π⋅10000d2n>Ttrain, for n=0,⋯,dextra/2−1, for n=dextra/2,⋯,d/2−1.(12)
因此,对于基于RoPE的大语言模型,我们将 <math xmlns="http://www.w3.org/1998/Math/MathML"> d extra d_\text{extra} </math>dextra,即 <math xmlns="http://www.w3.org/1998/Math/MathML"> q t , k s \bm{q}t,\bm{k}s </math>qt,ks中感知了全周期位置编码的维度的数量,称作 RoPE外推的临界维度 (critical dimension for RoPE-based extrapolation),其计算方式如公式13所示。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> d extra = 2 ⌈ d 2 log 10000 T train 2 π ⌉ . \begin{equation}d\text{extra}=2\left\lceil{\dfrac{d}{2}}\log{10000}{\dfrac{T_\text{train}}{2\pi}}\right\rceil \text{.}\tag{13}\end{equation} </math>dextra=2⌈2dlog100002πTtrain⌉.(13)
这时回看base缩小改进RoPE外推的过程可以发现,实际上这就是临界维度更新一个的过程。随着base缩小为 <math xmlns="http://www.w3.org/1998/Math/MathML"> β < 10000 \beta<10000 </math>β<10000, <math xmlns="http://www.w3.org/1998/Math/MathML"> d extra d_\text{extra} </math>dextra的底数随之缩小,由此 <math xmlns="http://www.w3.org/1998/Math/MathML"> d extra = 2 ⌈ d 2 log β T train 2 π ⌉ d_\text{extra}=2\left\lceil{\dfrac{d}{2}}\log_\beta{\dfrac{T_\text{train}}{2\pi}}\right\rceil </math>dextra=2⌈2dlogβ2πTtrain⌉ 单调不减。因此,每缩小一段base,对应 <math xmlns="http://www.w3.org/1998/Math/MathML"> d extra d_\text{extra} </math>dextra加2(RoPE两维度一组,一次更新两个维度),从而让两个维度感知到完整的位置信息;最终,base缩小到 <math xmlns="http://www.w3.org/1998/Math/MathML"> β 3 \beta_3 </math>β3, <math xmlns="http://www.w3.org/1998/Math/MathML"> d extra = d d_\text{extra}=d </math>dextra=d,每个维度都感知到全部的位置信息,模型外推能力取得飞跃式提升。本文中,临界维度是RoPE外推中最关键的概念 ;对于LLaMA2,根据其训练长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> T train = 4096 T_\text{train}=4096 </math>Ttrain=4096,注意力头维度 <math xmlns="http://www.w3.org/1998/Math/MathML"> d = 128 d=128 </math>d=128,可以得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> d extra = 92 d_\text{extra}=92 </math>dextra=92,即 LLaMA2前92维度 都感知了完整的位置信息 ,在外推时是比较可靠的,后36维度 由于没有感知完整的位置信息 是外推问题的根源 ,由此就回答了Q0,其直观效果留待后文4.1小节的验证。
3.3 放大base时的缩放法则
对于LLaMA2,由于前92个维度的周期都被涵盖在训练长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> T train T_\text{train} </math>Ttrain以内,这就为续训提供了一个很好的初始化参数。前92个维度能够适应每个对应维度位置编码的周期变化 ,也由此在更长的周期上微调的时候,虽然这些维度没有见过完整的周期,但是他仍然可以表征这个周期内的位置信息;或者说,放大base虽然放大了周期,但所得到的 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n ( t − s ) \theta_n(t-s) </math>θn(t−s)仍然在原先预训练所见过的范围内。而后36个维度,不仅本身在训练过程中就没有见过全部的周期,存在参数学习过拟合的问题,而且在放大base放大周期后,更加无法感受到一个完整周期内的位置信息。因此我们可以根据临界维度对应 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ n \theta_n </math>θn在更新base后的周期 <math xmlns="http://www.w3.org/1998/Math/MathML"> T n T_n </math>Tn,求出模型外推的上限。由此,笔者就得到了 RoPE外推的缩放法则2(Scaling Law of Larger Bases)。
定理 2.(base放大时RoPE外推的缩放法则) 对于基于RoPE的大语言模型(RoPE-based LLMs),假设其预训练文本长度为 <math xmlns="http://www.w3.org/1998/Math/MathML"> T train T_\text{train} </math>Ttrain,如果在微调阶段将base调整为 <math xmlns="http://www.w3.org/1998/Math/MathML"> β > 10000 \beta>10000 </math>β>10000,并且使用预训练文本长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> T train T_\text{train} </math>Ttrain续训,那么模型的外推能力将会提升。其外推上界(extrapolation upper bound),即最大支持的上下文长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> T extra T_\text{extra} </math>Textra,可以表示为 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β的一个函数,如公式14所示:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> T extra = 2 π ⋅ β d extra ⋅ 1 d = 2 π ⋅ β ⌈ d 2 log 10000 T train 2 π ⌉ ⋅ 2 d . \begin{equation}T_\text{extra}=2\pi\cdot\beta^{d_\text{extra}\cdot\frac{1}{d}}=2\pi\cdot\beta^{\left\lceil{\frac{d}{2}}\log_{10000}{\frac{T_\text{train}}{2\pi}}\right\rceil\cdot{\frac{2}{d}}} \text{.}\tag{14}\end{equation} </math>Textra=2π⋅βdextra⋅d1=2π⋅β⌈2dlog100002πTtrain⌉⋅d2.(14)
相反,如果为了让模型支持 <math xmlns="http://www.w3.org/1998/Math/MathML"> T extra T_\text{extra} </math>Textra的上下文长度,那么存在一个最小的base <math xmlns="http://www.w3.org/1998/Math/MathML"> β 0 \beta_0 </math>β0,如公式15所示。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> β 0 = 10000 log T train 2 π T ~ extra 2 π . \begin{equation}\beta_0={10000}^{\log_{\frac{T_\text{train}}{2\pi}}{\frac{\tilde{T}_\text{extra}}{2\pi}}} \text{.}\tag{15}\end{equation} </math>β0=10000log2πTtrain2πT~extra.(15)
由此笔者给出了外推上限和base之间的数学关系,根据这个关系,我们可以根据base确定外推上界,以及根据期望的上下文长度算出对应的base值。为了验证定理2的正确性,笔者记录了不同base取值下实际支持的最大上下文长度,对比基于定理2得到的理论外推上界,如图6所示,发现两者呈现惊人的重合,由此就回答了Q2。
图6 更大base的RoPE在原始长度上续训后得到的最大支持上下文长度(蓝线)和理论外推上界(红线)的比较;至base=800000,实际最大支持长度已超过测试序列长度100K。
3.4 扩展 与 临界base
定理1和定理2已经完整地表述了原始长度续训时,基于RoPE的大语言模型的外推效果,由于目前已有的微调阶段的外推工作,普遍涉及到更长长度的续训,因此作者也在这里给出更长长度续训时,扩展的RoPE外推的缩放法则(Extended Scaling Law for RoPE-based Extrapolation)。
定理 3. (扩展的RoPE外推的缩放法则) 对于基于RoPE的大语言模型(RoPE-based LLMs),假设其预训练文本长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> T train T_\text{train} </math>Ttrain,对应临界维度 <math xmlns="http://www.w3.org/1998/Math/MathML"> d extra d_\text{extra} </math>dextra,如果在微调阶段将base调整为 <math xmlns="http://www.w3.org/1998/Math/MathML"> β > 1 \beta>1 </math>β>1,并且使用更长长度长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> T tune ≥ T train T_\text{tune}\geq T_\text{train} </math>Ttune≥Ttrain的文本续训,那么模型的外推能力不降;当且仅当 <math xmlns="http://www.w3.org/1998/Math/MathML"> β = 10000 \beta=10000 </math>β=10000且 <math xmlns="http://www.w3.org/1998/Math/MathML"> T tune = T train T_\text{tune}=T_\text{train} </math>Ttune=Ttrain时,外推效果不变。此外,存在一个 临界base <math xmlns="http://www.w3.org/1998/Math/MathML"> β 0 \beta_0 </math>β0,根据 续训文本长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> T tune T_\text{tune} </math>Ttune 和 预训练文本长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> T train T_\text{train} </math>Ttrain 决定:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> β 0 = 10000 log T train 2 π T tune 2 π . \begin{equation}\beta_0={10000}^{\log_{\frac{T_\text{train}}{2\pi}}{\frac{T_\text{tune}}{2\pi}}}\text{.}\tag{16a}\end{equation} </math>β0=10000log2πTtrain2πTtune.(16a)
如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> β > β 0 \beta>\beta_0 </math>β>β0,外推上界根据 base取值 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β 和 临界维度 <math xmlns="http://www.w3.org/1998/Math/MathML"> d extra d_\text{extra} </math>dextra 决定:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> T extra = 2 π ⋅ β d extra ⋅ 1 d = 2 π ⋅ β ⌈ d 2 log 10000 T train 2 π ⌉ ⋅ 2 d . \begin{equation}T_\text{extra}=2\pi\cdot\beta^{d_\text{extra}\cdot\frac{1}{d}}= 2\pi\cdot\beta^{\left\lceil{\frac{d}{2}}\log_{10000}{\frac{T_\text{train}}{2\pi}}\right\rceil\cdot{\frac{2}{d}}}\text{.}\tag{16b}\end{equation} </math>Textra=2π⋅βdextra⋅d1=2π⋅β⌈2dlog100002πTtrain⌉⋅d2.(16b)
如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> β ≤ β 0 \beta\leq\beta_0 </math>β≤β0,外推上界就是续训长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> T tune T_\text{tune} </math>Ttune,但是 临界维度会更新如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> d extra ′ = 2 ⌈ d 2 log β T tune 2 π ⌉ ≥ 2 ⌈ d 2 log 10000 T train 2 π ⌉ = d extra . \begin{equation}d'\text{extra}=2\left\lceil{\frac{d}{2}}\log{\beta}{\frac{T_\text{tune}}{2\pi}}\right\rceil\geq2\left\lceil{\frac{d}{2}}\log_{10000}{\frac{T_\text{train}}{2\pi}}\right\rceil=d_\text{extra}\text{.}\tag{16c}\end{equation} </math>dextra′=2⌈2dlogβ2πTtune⌉≥2⌈2dlog100002πTtrain⌉=dextra.(16c)
虽然如此,如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β 足够小,模型还是可以外推超过 <math xmlns="http://www.w3.org/1998/Math/MathML"> T tune T_\text{tune} </math>Ttune;特别地,如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β 小于如下的 <math xmlns="http://www.w3.org/1998/Math/MathML"> β 1 , β 2 , β 3 \beta_1,\beta_2,\beta_3 </math>β1,β2,β3,外推效果会得到显著提升。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> β 1 = 2 T tune π , β 2 = T tune π , β 3 = T tune 2 π . \begin{equation}\beta_1 = \frac{2 T_\text{tune}}{\pi}, \quad\beta_2 = \frac{T_\text{tune}}{\pi}, \quad\beta_3 = \frac{T_\text{tune}}{2\pi}\text{.}\tag{16d}\end{equation} </math>β1=π2Ttune,β2=πTtune,β3=2πTtune.(16d)
定理3 可以看做 定理1、引理1、定理2 的扩展和延伸,根据定理3 可以评估出 任意base的RoPE在任意长度续训时的外推表现。这其中,临界base是外推效果最差的base ,由此就回答了Q1 :取base为 <math xmlns="http://www.w3.org/1998/Math/MathML"> β 0 \beta_0 </math>β0时,模型仅根据原始临界维度 <math xmlns="http://www.w3.org/1998/Math/MathML"> d extra d_\text{extra} </math>dextra以内的维度去感知恰好是续训长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> T tune T_\text{tune} </math>Ttune内的位置信息;base更小 更多的维度感知到位置信息 ,base更大 更长的位置信息可以被表示。
如果令 <math xmlns="http://www.w3.org/1998/Math/MathML"> T tune = T train T_\text{tune}=T_\text{train} </math>Ttune=Ttrain,那么根据公式16a,临界base恰好等于10000,对应2.3节所论述的,base=10000在原始长度续训时外推效果最差的现象。此时,公式16b对应3.3小节关于更大base的结果,公式16c对应3.2小节关于临界维度的探讨,公式16d对应3.1小节关于更小base的结果。如果令 <math xmlns="http://www.w3.org/1998/Math/MathML"> T tune > T train T_\text{tune}>T_\text{train} </math>Ttune>Ttrain,从外推效果角度,首先模型能够适应续训长度的更长词窗大小,模型外推能力一定上升。如图7所示,在16K上下文长度上上续训LLaMA2 7B/13B,无论base如何取值,支持上下文都大于等于16K,超过原始LLaMA2 7B/13B。
图7 不同base的RoPE在16K长文本上续训后的语言建模困惑度。
从周期感知角度,由于base和训练长度都改变了,因此首先需要确认是否有更多位置信息纳入到了超过临界维度的特征维度的训练过程中。参考临界维度的定义,可以根据更新后的base,计算出有多少维度的正余弦波在更新的上下文长度上遍历过一个周期内的取值,也就得到了临界base。如果base大于等于临界base,那么微调阶段能遍历周期的维度在训练阶段就已经可以遍历了一整个周期,因此模型外推的临界维度不变,参考定理2,外推上界取决于更新的base和原始的临界维度。对于在16K长续训的LLaMA2,如图7所示,临界base=71738,对于大于71738的base,例如80000、120000、1000000,外推上界都超过16K,并且base越大支持上下文越长。
如果base小于临界base,那么微调阶段遍历周期的维度超过原始临界维度,临界维度更新,
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> d ~ extra = 2 ⌈ d 2 log β T tune 2 π ⌉ ≤ 2 ⌈ d 2 log β 0 T tune 2 π ⌉ = 2 ⌈ d 2 log 10000 log T train 2 π T tune 2 π T tune 2 π ⌉ = 2 ⌈ d 2 1 log T train 2 π T tune 2 π log 10000 T tune 2 π ⌉ = 2 ⌈ d 2 log T tune 2 π T train 2 π log T tune 2 π 10000 ⌉ = 2 ⌈ d 2 log 10000 T train 2 π ⌉ = d extra . \begin{equation}\begin{aligned} \tilde{d}\text{extra}=2\left\lceil{\frac{d}{2}}\log{\beta}{\frac{T_\text{tune}}{2\pi}}\right\rceil&\leq2\left\lceil{\frac{d}{2}}\log_{\beta_0}{\frac{T_\text{tune}}{2\pi}}\right\rceil=2\left\lceil{\frac{d}{2}}\log_{{10000}^{\log_{\frac{T_\text{train}}{2\pi}}{\frac{T_\text{tune}}{2\pi}}}}{\frac{T_\text{tune}}{2\pi}}\right\rceil \\ &=2\left\lceil{\frac{d}{2}}\frac{1}{{\log_{\frac{T_\text{train}}{2\pi}}{\frac{T_\text{tune}}{2\pi}}}}\log_{10000}{\frac{T_\text{tune}}{2\pi}}\right\rceil=2\left\lceil{\frac{d}{2}}\frac{{\log_{\frac{T_\text{tune}}{2\pi}}{\frac{T_\text{train}}{2\pi}}}}{\log_{\frac{T_\text{tune}}{2\pi}}{10000}}\right\rceil \\ &=2\left\lceil{\frac{d}{2}}\log_{10000}{\frac{T_\text{train}}{2\pi}}\right\rceil=d_\text{extra} \end{aligned}\text{.}\tag{17}\end{equation} </math>d~extra=2⌈2dlogβ2πTtune⌉≤2⌈2dlogβ02πTtune⌉=2⌈2dlog10000log2πTtrain2πTtune2πTtune⌉=2⌈2dlog2πTtrain2πTtune1log100002πTtune⌉=2⌈2dlog2πTtune10000log2πTtune2πTtrain⌉=2⌈2dlog100002πTtrain⌉=dextra.(17)
但是由于该维度取决于续训长度,因此模型的外推上限仍然受续训长度限制。虽然如此,如果base足够小,使得模型的每个维度,入职前的定理1所示,在续训长度内遍历0到 <math xmlns="http://www.w3.org/1998/Math/MathML"> π / 2 \pi/2 </math>π/2或 <math xmlns="http://www.w3.org/1998/Math/MathML"> π \pi </math>π或 <math xmlns="http://www.w3.org/1998/Math/MathML"> 2 π 2\pi </math>2π的取值,那么模型的外推效果又会进一步提升,具体表现为困惑度上升曲线越来越平缓。对于在16K上下文续训的LLaMA2,如图7所示,对于小于71738的base,例如60000、20000、10000、500,曲线以16K位置为端点逐渐平坦。其中,虽然base=10000在原始长度续训上表现差强人意,但是由于在16K长度上续训cos/sin内输入已遍历至 <math xmlns="http://www.w3.org/1998/Math/MathML"> π / 2 \pi/2 </math>π/2,效果取得明显提升;至base=500时,模型已取得了和base=1000000一样平坦的困惑度曲线,实现了100K上下文的外推。
对于base=500在16K长度上续训得结果,由于其困惑度曲线足够平坦,笔者认为这在某种程度上已经打破 LM-Infinite 和 熵不变性 等研究中提出的,RoPE自注意力分布熵增导致外推效果变差的"诅咒"。笔者得到了一个大胆的猜想,当base足够小并且训练长度足够长, <math xmlns="http://www.w3.org/1998/Math/MathML"> q t , k s \bm{q}_t,\bm{k}_s </math>qt,ks的每个维度在训练阶段就已经感知到了位置编码的周期性变化,由此可以胜任近乎无限的外推。进一步使用更长的序列(通过多本书籍拼接)测试 base=500/1000000 在 128K、256k、512k、1M 长文档上的语言建模困惑度,得到如下表所示的结果。发现base=500显著优于base=1000000,并且结合后文提到的 对数校正,可以一直保持在相对低位,进一步回答了Q3 ,给出了几乎可以无限外推的方案:即将base缩小(或者结合续训长度放大),从而让自注意力头的每个维度在续训时经历多个周期,充分意识到位置编码的周期性(注:虽然每个维度的位置编码遍历了一个周期 ,出现了重复,但是整体并没有重复,而根据旋转角的取值,出现整体重复的公共周期非常大)。
4. 实验验证:RoPE外推的根本问题
4.1 临界维度的验证
为了深入分析临界维度和外推问题之间的因果关系,我们以下的实验。这些实验都是基于预训练或微调的LLaMA2 7B模型进行的,笔者提取了模型最后一层每个注意力头中 <math xmlns="http://www.w3.org/1998/Math/MathML"> q t , k 0 \bm{q}_t,\bm{k}_0 </math>qt,k0(对应attention score矩阵中的第一列),计算两个向量内积时的前92维和后36维的attention score随 <math xmlns="http://www.w3.org/1998/Math/MathML"> q t \bm{q}_t </math>qt下标 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t的变化规律(这里选择 <math xmlns="http://www.w3.org/1998/Math/MathML"> q t , k 0 \bm{q}_t,\bm{k}_0 </math>qt,k0,因为下标 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t 恰好 既是q的绝对位置 也是qk的相对位置 同时是推理长度,并且前4096范围内的绝对位置和相对位置都是训练过程中见过的)。通过对比attention score的变化规律和语言建模困惑度的变化以及临界维度前后cos曲线的变化,笔者意图建立基于RoPE的大语言模型中 周期、临界维度、外推问题 之间的关联。
具体地,笔者首先比较了预训练的LLaMA2,如图8第1列所示,LLaMA2前92维的attention score无论在训练长度范围4096内外都波动都非常有限,而后36维的attention score在超出4096后立刻呈现出和4096以内不同的取值波动,与此同时困惑度曲线急速攀升;同时观察cos曲线也可以发现从第0/1维到第90/91维对应的周期保持整个周期在4096之内,而第92/93维到第126/127维对应的周期超出了4096。在此基础上如果不微调模型,仅将LLaMA2后36维的超过4096的位置编码固定于4096位置,即可以得到如图8第2列所示的结果,发现除了某个自注意力头数值异常外,attention score波动减弱,外推效果取得一定改进。而图8第3列则比较了LLaMA2使用dynamic NTK外推过程中,前92维和后36维attention score的变化情况,发现前92维数值保持稳定,后36维数值相较于直接外推变化进一步趋缓,attention score波动和外推效果变化情况一致;由此可以发现,临界维度和外推效果之间存在相关性。
图8 base=10000的RoPE,困惑度、attention score、临界维度前后旋转角周期之间的关系。
在此基础上,笔者继续比较了保留base=100000,但在原始训练长度上微调LLaMA2的结果。首先如图8第4列所示,直接微调的效果,除了attention score适应了目标domain的数据分布外,相较于无改进的预训练LLaMA2一样,前92维波动前后一致,后36维,在4096之外的波动和4096以内的波动呈现不同规律。在此基础上,笔者将LLaMA2每一层的qk后36维直接砍掉,再使用同样设定微调(注:attention温度系数,从 <math xmlns="http://www.w3.org/1998/Math/MathML"> 128 \sqrt{128} </math>128 变成 <math xmlns="http://www.w3.org/1998/Math/MathML"> 92 \sqrt{92} </math>92 )得到结果如图8第5列所示。可以发现缺少了后36维微调的外推效果,显著优于直接微调的外推效果,模型直接达到了16K以上的外推长度;该结果更加有力地证明了,临界维度和外推效果之间存在因果性,临界维度的存在,导致推理长度超过训练长度时的超出临界维度部分的attention score波动,限制了模型的外推上限,也证明了,从周期角度解释并改进基于RoPE的大语言模型外推表现是合理、正确、有效的。
在此基础上,为了分析调整base的条件下,临界维度和外推上限之间的因果关系,笔者观察了不同base取值下,前92维和后36维的attention score波动情况,如图9所示。对于缩小base,RoPE前92和后36维attension score随相对位置的变化情况,缩小base的attention score在训练长度范围内就学习到了来自cos/sin的波动;也正是由于这些波动在训练长度中就已经感知过了,模型在外推的过程中不会产生OOD的问题;并且base越小感知越充分,对应外推曲线越平坦。对于较大的base,可以发现,无论base如何取值前92维的attention score并不会随相对位置有明显的波动,但是后36维的波动是很明显的。通过比较困惑度上升曲线可以发现,一旦上下文长度超过临界维度规定的外推上限,后36维就会面对未曾见过的位置信息,对应到attention score上就是产生OOD的数值,同时困惑度开始急速攀升,模型外推失效。
图9 不同base在预训练长度上续训的RoPE,困惑度、attention score、临界维度前后周期之间的关系。
类似的规律也出现在16K长文本续训的场景下。笔者通过相同的方法,如图10所示,将不同base在16K长微调下,attention score的波动直观呈现。发现对于base=500,由于已经在训练阶段经历了足够多的波动,因此困惑度曲线非常平坦;对于base=10000,虽然在原始长度上续训会取得糟糕的表现,但由于更长长度的续训使得每个维度的cos/sin取值都可以含盖0到 <math xmlns="http://www.w3.org/1998/Math/MathML"> π / 2 \pi/2 </math>π/2的范围,因此后36维的attention score的波动也会有一定的改观,由此导致外推效果取得一些改进;对于base=40000,由于逐步靠近临界base,后36维微调阶段见过的位置信息随着base的增大进一步收窄;对于base=120000,临界维度降到92维,外推效果由前92维决定,根据公式16预测的外推上界与最大支持上下文长度一致;对于base=1000000,前92维周期进一步拉长,对应上下文长度进一步扩张至100K以外,但如3.4小节末尾的表格所示,仍然会在128K之后出现困惑度的快速上升,和理论预期一致。
图10 不同base在16K长度上续训的RoPE,困惑度、attention score、临界维度前后周期之间的关系。
4.2 相关工作的结合
本小节在验证了本文提出理论的正确性的基础上,探讨该理论对于测试阶段外推方法的指导价值。鉴于以下两点原因,测试阶段的外推方法仍然是必要的。一方面,与较大base的RoPE(例如base=1000000)相比,较小base的RoPE(例如base=500)的性能存在一定的落后,如图7所示。另一方面,如图2所示,对于base不够大的RoPE(例如base=320000),它仍然无法外推到100K或更长的上下文。为了进一步增强RoPE对更长上下文的适应性,以及不同base外推的特点,笔者在这里探讨对数校正、xPos(代表滑动窗口方法)、dynamic NTK(代表调整旋转角的方法)等对不同base的RoPE的外推改善效果。
对数校正是一种独立于第1节中两种外推流派的外推方法,在 从熵不变性看Attention的Scale操作 中被提出用来改进外推问题。对数校正的实现很简单,如公式18所示,直接将原始的attention score乘以当前推理长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t的对数。在已有研究中,对数的底一般取训练长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> T train T_\text{train} </math>Ttrain。然而,根据RoPE缩放法则,最大支持上下文长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> T extra T_\text{extra} </math>Textra内的attention score都是可靠的。所以笔者提出以 <math xmlns="http://www.w3.org/1998/Math/MathML"> T extra T_\text{extra} </math>Textra为对数底,设校正值的下限为1,意味着在外推上限内不需要额外的对数校正。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> A t , s = R e [ ∑ n = 0 d / 2 − 1 q ~ t ( n ) k ~ s ( n ) ∗ e i ( t − s ) θ n ] ⋅ p t , p t = max ( 1 , log T extra t ) . \begin{equation}\bm{A}{t,s}=\mathrm{Re}\begin{bmatrix}{\sum\limits{n=0}^{d/2-1}\tilde{q}t^{\tiny(n)}\tilde{k}s^{\tiny(n)}{}^{*}e^{i(t-s)\theta_n}}\end{bmatrix}\cdot{p_t}, \quad p_t=\max\left(1, \log{T\text{extra}}{t}\right) \text{.}\tag{18}\end{equation} </math>At,s=Re[n=0∑d/2−1q~t(n)k~s(n)∗ei(t−s)θn]⋅pt,pt=max(1,logTextrat).(18)
如前文所述,滑动窗口及其变体,是一种广泛应用的外推策略;笔者这里使用xPos方法作为词窗方法的一个代表。相比于xPos最早用于预训练阶段,笔者在这里将其看做一个测试阶段的软词窗,讨论本文提出方案对词窗方法的指导意义(更主要的原因是,在本文的撰写过程中,FlashAttention2尚未兼容词窗方法)。需要注意的是,本文在沿用了 原始论文 中对于xPos的定义外,使用外推上限 <math xmlns="http://www.w3.org/1998/Math/MathML"> T extra T_\text{extra} </math>Textra作为指数校正的分母而不是原始分母512,公式如下。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> A t , s = R e [ ∑ n = 0 d / 2 − 1 q ~ t ( n ) k ~ s ( n ) ∗ ζ n t − s T extra e i ( t − s ) θ n ] , ζ n = γ + 2 n / d γ + 1 , γ = 0.4 . \begin{equation}\bm{A}{t,s}=\mathrm{Re}\begin{bmatrix}{\sum\limits{n=0}^{d/2-1}\tilde{q}_t^{\tiny(n)}\tilde{k}s^{\tiny(n)}{}^{*}\zeta_n^{\frac{t-s}{T\text{extra}}}e^{i(t-s)\theta_n}}\end{bmatrix}, \quad \zeta_n=\frac{\gamma+2n/d}{\gamma+1}, \ \gamma=0.4 \text{.}\tag{19}\end{equation} </math>At,s=Re[n=0∑d/2−1q~t(n)k~s(n)∗ζnTextrat−sei(t−s)θn],ζn=γ+1γ+2n/d, γ=0.4.(19)
如前文所述,dynamic NTK是当前非常重要的一个测试阶段的外推方法,笔者这里也讨论了dynamic NTK与本文方法的结合。笔者对此同样做出一些修改:先将原始RoPE中的base=10000替换为微调阶段的 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β,再使用外推上限 <math xmlns="http://www.w3.org/1998/Math/MathML"> T extra T_\text{extra} </math>Textra取代原始的训练长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> T train T_\text{train} </math>Ttrain,公式如下。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> A t , s = R e [ ∑ n = 0 d / 2 − 1 q ~ t ( n ) k ~ s ( n ) ∗ e i ( t − s ) θ n ] , θ n = ( β ⋅ α t ) − 2 n / d α t = max ( 1 , 2 ⌈ log 2 t T extra ⌉ + 1 − 1 ) . \begin{equation}\bm{A}{t,s}=\mathrm{Re}\begin{bmatrix}{\sum\limits{n=0}^{d/2-1}\tilde{q}_t^{\tiny(n)}\tilde{k}s^{\tiny(n)}{}^{*}e^{i(t-s)\theta_n}}\end{bmatrix}, \quad \begin{gathered} \theta_n={\left(\beta\cdot\alpha_t\right)}^{-2n/d} \\[1ex] \alpha_t=\max\left(1, 2^{\left\lceil\log_2{\frac{t}{T\text{extra}}}\right\rceil+1}-1\right) \end{gathered} \text{ .}\tag{20}\end{equation} </math>At,s=Re[n=0∑d/2−1q~t(n)k~s(n)∗ei(t−s)θn],θn=(β⋅αt)−2n/dαt=max(1,2⌈log2Textrat⌉+1−1) .(20)
笔者在LLaMA2 7B上对这上述方法进行实验,得到如图11所示的结果。首先看基于原始RoPE的LLaMA2的结果,如第1行第1列所示。很明显,此时对数校正对于LLaMA2几乎没有任何作用。第1行剩余的子图展示了base小于10000的RoPE的结果。可以发现,随着base的减小,对数校正带来的改进更占主导地位,而dynamic NTK的改进逐渐收窄。对于base=500的RoPE,对数校正的困惑度曲线足够平坦,一方面,表明其支持100K上下文长度的外推能力,另一方面,也证明了 <math xmlns="http://www.w3.org/1998/Math/MathML"> q t , k s \bm{q}_t,\bm{k}_s </math>qt,ks中临界维度的存在RoPE外推问题的根本原因(此时临界维度=自注意力头维度)。 相反,此时dynamic NTK呈现出明显的副作用;因此,只要在训练阶段学到的位置信息足够可靠,可供基于RoPE的大语言模型进一步外推,无需额外的内部信息重组去感知训练长度之外的位置信息。
图11 不同base的RoPE,对数校正、xPos、dynamic NTK对其在测试阶段的改进效果。第2行图中,dynamic NTK (wrong) 表示使用训练长度进行NTK的结果,dynamic NTK (correct) 表示使用本文方法求出外推上界后进行NTK的结果。
接着看图11的第2行,这里展示了微调时base大于10000的RoPE的结果,例如40000、160000、400000和600000。笔者没有在base=1000000的RoPE上测试上述方法的性能,因为它已经达到了100K的上下文长度。对于其他base较大的RoPE,dynamic NTK获得的外推性能的提升更为显着,相关原理已经在图8中进行了可视化,并在4.1小节中进行了讨论。除此之外,需要额外注意的一点是,用外推上限 <math xmlns="http://www.w3.org/1998/Math/MathML"> d extra d_\text{extra} </math>dextra替换 <math xmlns="http://www.w3.org/1998/Math/MathML"> T train T_\text{train} </math>Ttrain对于具有较大base的RoPE非常重要。 例如,如果基于 <math xmlns="http://www.w3.org/1998/Math/MathML"> T train T_\text{train} </math>Ttrain进行dynamic NTK,当base足够大(如400000和600000)时,改进将受到限制,甚至被破坏。这一现象再次证明了本文对其他工作的指导价值。总之,对于小于缩放法则中定义的 <math xmlns="http://www.w3.org/1998/Math/MathML"> β 3 \beta_3 </math>β3的base,由于每个维度都学习了完整的位置信息,对数校正即足以增强外推;对于大于缩放法则中 <math xmlns="http://www.w3.org/1998/Math/MathML"> β 0 \beta_0 </math>β0的base(如果是原始长度上续训,即是10000),dynamic NTK则是是外推到更长上下文的好帮手;此外,词窗方法对于两者都可以起到很好的辅助作用。
4.3 生成内容的验证
最后笔者进行了文本续写任务的测试。针对不同base和续训长度下微调的LLaMA2 7B,输入32K长度的上文让其通过贪心方法生成续写,得到结果去除特殊字符后详见论文附录E(篇幅原因不列入知乎博客中)。可以发现无论base取500还是100万,都可以生成很流畅的文本,没有语法错误、具有一定条理。对于base=500,如果使用更长长度续训或者在测试阶段加入如前文所述的对数校正,会取得更好的续写效果。由于base=1000000时外推存在一个严格的上界,因此有理由相信,base=500的微调模型,具有可以无限外推的潜力。
5. 写在最后
在本文的最后,笔者简要回顾一下 自己撰写这篇文章的最终目的和额外思考,以及工作本身的不足和后续改进的方向。这篇文章是笔者的第一篇论文、第一篇正式的工作(抛开之前在知乎上发表的 傅里叶可外推位置编码 FEPE),同时完成也比较仓促,不周之处,还望诸位看官海涵。
如前文所述:与以往的外推研究不同,本文并没有聚焦一个具体的外推方案,而是给出了一套改进RoPE外推能力的框架 (从周期角度出发,给出外推问题的来源以及改进方法),及其对应的数学解释(对应第3节,重点3.2小节,临界维度)、实验验证(对应第4节,重点4.1小节,观察 attention score);在这个框架下,本文不仅给出了 任意base 任意续训长度时 模型外推表现如何 (对应3.4小节 扩展的RoPE外推缩放法则 定理3),同时给出了 给定期望上下文长度时应该如何调整RoPE实现定长外推 (对应3.3小节,根据公式15确定更大的base),没有给定期望上下文长度时应该如何调整RoPE实现不定长外推(对应3.1小节,根据公式9确定更小的base)。
在本文中,放大和缩小base都能改进RoPE外推效果是一个非常有意思的现象,进一步地,16K续训base=500的RoPE虽然在100K以内困惑度略高于base=1000000,但是在128K以上的上下文内困惑度都低于base=1000000;小base虽然整体效果差,但是没有效果显著的崩坏,并且结合对数校正还能显著改进,大base虽然一时效果好,但是存在严格的外推上界。在和论文指导老师讨论的过程中,老师提到一个有意思的点,即机器学习基础中的偏差方差分解:base=500 对应 高方差-低偏差,泛化能力更强 ,base=1000000 对应 低方差-高偏差,拟合效果更好。当然这目前只是一个想法,后续笔者也会探讨这其中是否存在一个这样的关联。
论及本文的不足,虽然本文确实主打研究RoPE外推的普遍规律,但是缺少下游任务的详尽测试是不可避免的话题;虽然论文中给出了短上下文任务的一些验证,但是长上下文任务验证显然还是需要考虑的。对此笔者还会通过改进微调语料配比,探究放大缩小base后外推,在下游任务上效果如何,是否存在一些规律。此外,本文中笔者使用 <math xmlns="http://www.w3.org/1998/Math/MathML"> q t T k 0 \bm{q}_t^T\bm{k}_0 </math>qtTk0来衡量推理过程中attention score的变化情况,应该说这个切入口选得很好:这列数值正好对应了箭形词窗中保留的attention矩阵的第一列,同时对于这一列,语言建模长度、相对位置、绝对位置 ,三者在此时恰好都是 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t。因此,本文缺少对于其他位置的attention score,以及这些score和外推效果之间关联性的讨论。此外直觉上,base不能无限缩小和放大,太小不同维度的周期会出现一个公共周期,让模型无法正确刻画不同位置,太大旋转角趋近于0相当于没有进行位置编码,这些也都需要后续实验的不断探索。
本文著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。