通过位置插值增加 transformer context 长度

模型 context 长度的问题

如何高效的增加 transformer 的 context 长度,使得模型能够处理超长文本是 LLM 训练和推理过程中十分关注的问题。因为 2K 的长度很难满足一些场景的需求。虽然可以使用 map-reduce 等方法在有限的长度下处理超长的文本,但是最理想的情况肯定是模型能够直接处理超长文本。Extending Context Window of Large Language Models via Positional Interpolation 通过位置插值(Positional Interpolation,PI)的方式解决了扩展模型 context 长度的问题,并且取得了不错的效果。具体的实现可以参考 rotary-embedding-torch 的代码。

常见的编码方式

对于 NLP 任务,如果想要获得理想的结果,我们需要让 transformer 感知 token 的位置信息,常见方式为一下两种:

  1. RoPE: Rotary Position Embedding
  2. Absolute Position Encodings

RoPE

RoPE 简单来说是希望捕捉 token 之间的相对距离关系,例如下面的三个 token 组成的句子: A B C 假设所有的 token 的 embedding 都相同,也就是先不考虑 token embedding 之间的相似度,只考虑距离关系,希望获得下面的矩阵:

0 1 2
-1 0 1
-2 -1 0

也就是 A 和 C 的距离是 2,B 和 C 的距离是 1。为了获得相对距离关系的表示,RoPE: Rotary Position Embedding 提出使用 Rotary Position 。考虑还是一个由三个 token 组成的句子:A A A (句子中的词语都相同,全部是 A)。并且 A 的 embedding 是向量 <1, 1> ,那么每次将 A 旋转 90 度(和旋转矩阵相乘),就得到了 A 在不同位置的表示:

A <1, 1>

A <-1, 1>

A <-1, -1>

那么点积后就得到了捕捉了相对距离关系的相似度矩阵(如果调整角度为非 90 度,就可以得到上面的相似度矩阵)

1 0 -1
0 1 0
-1 0 1

Absolute Position

Absolute Position 是 transformer 最开始采用的方法。简单来说是取几条频率不同的 sin 曲线,根据不同的位置选取对应的值就得到了位置编码。通过将位置编码和 embedding 相加,希望模型在计算的过程中能够根据不同 sin 曲线的取值去捕捉 token 的位置信息。下面图中不同颜色的曲线代表频率不同的 sin 曲线,不同颜色的点分别代表在不同位置采样到的值。

编码存在的问题与解决方案

编码问题

如果训练的长度为 3 那么模型只见过长度在 3 以内的 Position embedding。长度超过 3 的位置信息,模型就不知道如何处理了。可以看论文中的图片 Figure 2。右边的图片表明,在长度小于等于 2048 的时候,attention score 的得分基本正常(可以看左侧的图片,注意力得分的值域在 [-3, 3] ,在模型可以处理的正常范围);但是当长度超过 2048 后,注意力得分开始剧烈波动,远远超过了正常的范围,模型的表现也必然受到较大的负面影响。

解决方案

如果训练的长度是有限的,但是希望模型能够在预测阶段处理更长文本,那么可以考虑在已有的范围内进行插值。看论文中的图片 Figure 1,如果训练的长度只有 2048 ,那么通过对采样的位置进行"压缩",将超过 2048 的点压进来,就可以得到在训练范围内,长度超过训练范围的位置编码表示。这样 context window 的长度就实现了翻倍。也就是可以通过插值(Interpolation)利用已经训练的编码参数,而不是对编码参数进行推断(Extrapolation)。

具体实现

如果当前模型,训练的 context 最大长度为 L,现在希望模型可以处理最大长度为 L' 的输入。对于位置在 m 的 token,可以通过下面的方式获得位置 m 的插值编码方式:

<math xmlns="http://www.w3.org/1998/Math/MathML"> f ′ ( x , m ) = f ( x , m L L ′ ) f^{'}(x, m) = f(x, \frac{mL}{L^{'}}) </math>f′(x,m)=f(x,L′mL)

以 Rotary Position 为例,如果当前的模型最大的长度依旧为 3 ,之前的三个位置分别是旋转 90 度得到的。现在如果想将长度扩展为 6 ,那么通过修改旋转矩阵的参数,每次的旋转 45 度即可,也就是对于之前的句子 A A A 在不同位置的表示变为:

A <1, 1>

A <0, 1>

A <-1, 1>

效果

论文中使用用于评估模型在长文本上的表现的 PG-19 数据集来测评位置插值的效果。评估结果可以参考论文 Table 1。如果不做任何操作(也就是 None 的那一行)由于模型能使用的 context 长度只有 2048 ,对于更长的位置上的信息无法利用,导致效果非常的差。而使用位置插值(PI 行)不仅效果优于 finetune(FT 那一行) ,并且能够处理更长的序列长度。finetune 和 位置插值 做法的差别为:

  • finetune:在训练好的 context 长度为 L 的模型基础上,使用少量 context 长度更长的文本来微调长度为 L' 的模型,使用方法是 extrapolation。
  • 位置插值:在训练好的 context 长度为 L 的模型基础上,使用少量 context 长度更长的文本来训练长度为 L' 的模型,使用方法是 interpolation。在实际应用中,选取适合的数据对模型进行 PI finetune 也十分的重要。

参考资料

www.together.ai/blog/llama-...
www.youtube.com/watch?v=oyX...
erdem.pl/2021/05/und...

相关推荐
我爱学Python!20 小时前
基于 LangChain 的自动化测试用例的生成与执行
人工智能·自然语言处理·langchain·自动化·llm·测试用例·大语言模型
牛右刀薛面1 天前
launcher.py: error: the following arguments are required: --output_dir
llm·sft·llamafactory
JasonLiu19192 天前
论文推荐 |【Agent】自动化Agent设计系统
人工智能·自动化·llm·agent·智能体
ulimpid2 天前
LLM | Xinference 安装使用(支持CPU、Metal、CUDA推理和分布式部署)
llm·xinference
伊织code2 天前
GraphRAG-Local-UI - 基于 GraphRAG 支持本地的聊天UI
ui·llm·rag·graphrag·local-ui
AI_小站3 天前
图解大模型计算加速系列:vLLM源码解析1,整体架构
人工智能·深度学习·架构·llm·大语言模型·ai大模型·vllm
强哥之神4 天前
一文了解:最新版本 Llama 3.2
人工智能·深度学习·机器学习·计算机视觉·语言模型·llm·llama
少喝冰美式4 天前
深度学习 Transformer 的标签平滑(Label Smoothing)
人工智能·深度学习·llm·transformer·大语言模型·ai大模型·计算机技术
xiaohezi5 天前
如何用 30秒和 5 行代码写个 RAG 应用?
人工智能·llm·aigc