模型 context 长度的问题
如何高效的增加 transformer 的 context 长度,使得模型能够处理超长文本是 LLM 训练和推理过程中十分关注的问题。因为 2K 的长度很难满足一些场景的需求。虽然可以使用 map-reduce 等方法在有限的长度下处理超长的文本,但是最理想的情况肯定是模型能够直接处理超长文本。Extending Context Window of Large Language Models via Positional Interpolation 通过位置插值(Positional Interpolation,PI)的方式解决了扩展模型 context 长度的问题,并且取得了不错的效果。具体的实现可以参考 rotary-embedding-torch 的代码。
常见的编码方式
对于 NLP 任务,如果想要获得理想的结果,我们需要让 transformer 感知 token 的位置信息,常见方式为一下两种:
RoPE
RoPE 简单来说是希望捕捉 token 之间的相对距离关系,例如下面的三个 token 组成的句子: A B C 假设所有的 token 的 embedding 都相同,也就是先不考虑 token embedding 之间的相似度,只考虑距离关系,希望获得下面的矩阵:
0 | 1 | 2 |
---|---|---|
-1 | 0 | 1 |
-2 | -1 | 0 |
也就是 A 和 C 的距离是 2,B 和 C 的距离是 1。为了获得相对距离关系的表示,RoPE: Rotary Position Embedding 提出使用 Rotary Position 。考虑还是一个由三个 token 组成的句子:A A A (句子中的词语都相同,全部是 A)。并且 A 的 embedding 是向量 <1, 1> ,那么每次将 A 旋转 90 度(和旋转矩阵相乘),就得到了 A 在不同位置的表示:
A <1, 1>
A <-1, 1>
A <-1, -1>
那么点积后就得到了捕捉了相对距离关系的相似度矩阵(如果调整角度为非 90 度,就可以得到上面的相似度矩阵)
1 | 0 | -1 |
---|---|---|
0 | 1 | 0 |
-1 | 0 | 1 |
Absolute Position
Absolute Position 是 transformer 最开始采用的方法。简单来说是取几条频率不同的 sin 曲线,根据不同的位置选取对应的值就得到了位置编码。通过将位置编码和 embedding 相加,希望模型在计算的过程中能够根据不同 sin 曲线的取值去捕捉 token 的位置信息。下面图中不同颜色的曲线代表频率不同的 sin 曲线,不同颜色的点分别代表在不同位置采样到的值。
编码存在的问题与解决方案
编码问题
如果训练的长度为 3 那么模型只见过长度在 3 以内的 Position embedding。长度超过 3 的位置信息,模型就不知道如何处理了。可以看论文中的图片 Figure 2。右边的图片表明,在长度小于等于 2048 的时候,attention score 的得分基本正常(可以看左侧的图片,注意力得分的值域在 [-3, 3] ,在模型可以处理的正常范围);但是当长度超过 2048 后,注意力得分开始剧烈波动,远远超过了正常的范围,模型的表现也必然受到较大的负面影响。
解决方案
如果训练的长度是有限的,但是希望模型能够在预测阶段处理更长文本,那么可以考虑在已有的范围内进行插值。看论文中的图片 Figure 1,如果训练的长度只有 2048 ,那么通过对采样的位置进行"压缩",将超过 2048 的点压进来,就可以得到在训练范围内,长度超过训练范围的位置编码表示。这样 context window 的长度就实现了翻倍。也就是可以通过插值(Interpolation)利用已经训练的编码参数,而不是对编码参数进行推断(Extrapolation)。
具体实现
如果当前模型,训练的 context 最大长度为 L,现在希望模型可以处理最大长度为 L' 的输入。对于位置在 m 的 token,可以通过下面的方式获得位置 m 的插值编码方式:
<math xmlns="http://www.w3.org/1998/Math/MathML"> f ′ ( x , m ) = f ( x , m L L ′ ) f^{'}(x, m) = f(x, \frac{mL}{L^{'}}) </math>f′(x,m)=f(x,L′mL)
以 Rotary Position 为例,如果当前的模型最大的长度依旧为 3 ,之前的三个位置分别是旋转 90 度得到的。现在如果想将长度扩展为 6 ,那么通过修改旋转矩阵的参数,每次的旋转 45 度即可,也就是对于之前的句子 A A A 在不同位置的表示变为:
A <1, 1>
A <0, 1>
A <-1, 1>
效果
论文中使用用于评估模型在长文本上的表现的 PG-19 数据集来测评位置插值的效果。评估结果可以参考论文 Table 1。如果不做任何操作(也就是 None 的那一行)由于模型能使用的 context 长度只有 2048 ,对于更长的位置上的信息无法利用,导致效果非常的差。而使用位置插值(PI 行)不仅效果优于 finetune(FT 那一行) ,并且能够处理更长的序列长度。finetune 和 位置插值 做法的差别为:
- finetune:在训练好的 context 长度为 L 的模型基础上,使用少量 context 长度更长的文本来微调长度为 L' 的模型,使用方法是 extrapolation。
- 位置插值:在训练好的 context 长度为 L 的模型基础上,使用少量 context 长度更长的文本来训练长度为 L' 的模型,使用方法是 interpolation。在实际应用中,选取适合的数据对模型进行 PI finetune 也十分的重要。
参考资料
www.together.ai/blog/llama-...
www.youtube.com/watch?v=oyX...
erdem.pl/2021/05/und...