通过位置插值增加 transformer context 长度

模型 context 长度的问题

如何高效的增加 transformer 的 context 长度,使得模型能够处理超长文本是 LLM 训练和推理过程中十分关注的问题。因为 2K 的长度很难满足一些场景的需求。虽然可以使用 map-reduce 等方法在有限的长度下处理超长的文本,但是最理想的情况肯定是模型能够直接处理超长文本。Extending Context Window of Large Language Models via Positional Interpolation 通过位置插值(Positional Interpolation,PI)的方式解决了扩展模型 context 长度的问题,并且取得了不错的效果。具体的实现可以参考 rotary-embedding-torch 的代码。

常见的编码方式

对于 NLP 任务,如果想要获得理想的结果,我们需要让 transformer 感知 token 的位置信息,常见方式为一下两种:

  1. RoPE: Rotary Position Embedding
  2. Absolute Position Encodings

RoPE

RoPE 简单来说是希望捕捉 token 之间的相对距离关系,例如下面的三个 token 组成的句子: A B C 假设所有的 token 的 embedding 都相同,也就是先不考虑 token embedding 之间的相似度,只考虑距离关系,希望获得下面的矩阵:

0 1 2
-1 0 1
-2 -1 0

也就是 A 和 C 的距离是 2,B 和 C 的距离是 1。为了获得相对距离关系的表示,RoPE: Rotary Position Embedding 提出使用 Rotary Position 。考虑还是一个由三个 token 组成的句子:A A A (句子中的词语都相同,全部是 A)。并且 A 的 embedding 是向量 <1, 1> ,那么每次将 A 旋转 90 度(和旋转矩阵相乘),就得到了 A 在不同位置的表示:

A <1, 1>

A <-1, 1>

A <-1, -1>

那么点积后就得到了捕捉了相对距离关系的相似度矩阵(如果调整角度为非 90 度,就可以得到上面的相似度矩阵)

1 0 -1
0 1 0
-1 0 1

Absolute Position

Absolute Position 是 transformer 最开始采用的方法。简单来说是取几条频率不同的 sin 曲线,根据不同的位置选取对应的值就得到了位置编码。通过将位置编码和 embedding 相加,希望模型在计算的过程中能够根据不同 sin 曲线的取值去捕捉 token 的位置信息。下面图中不同颜色的曲线代表频率不同的 sin 曲线,不同颜色的点分别代表在不同位置采样到的值。

编码存在的问题与解决方案

编码问题

如果训练的长度为 3 那么模型只见过长度在 3 以内的 Position embedding。长度超过 3 的位置信息,模型就不知道如何处理了。可以看论文中的图片 Figure 2。右边的图片表明,在长度小于等于 2048 的时候,attention score 的得分基本正常(可以看左侧的图片,注意力得分的值域在 [-3, 3] ,在模型可以处理的正常范围);但是当长度超过 2048 后,注意力得分开始剧烈波动,远远超过了正常的范围,模型的表现也必然受到较大的负面影响。

解决方案

如果训练的长度是有限的,但是希望模型能够在预测阶段处理更长文本,那么可以考虑在已有的范围内进行插值。看论文中的图片 Figure 1,如果训练的长度只有 2048 ,那么通过对采样的位置进行"压缩",将超过 2048 的点压进来,就可以得到在训练范围内,长度超过训练范围的位置编码表示。这样 context window 的长度就实现了翻倍。也就是可以通过插值(Interpolation)利用已经训练的编码参数,而不是对编码参数进行推断(Extrapolation)。

具体实现

如果当前模型,训练的 context 最大长度为 L,现在希望模型可以处理最大长度为 L' 的输入。对于位置在 m 的 token,可以通过下面的方式获得位置 m 的插值编码方式:

<math xmlns="http://www.w3.org/1998/Math/MathML"> f ′ ( x , m ) = f ( x , m L L ′ ) f^{'}(x, m) = f(x, \frac{mL}{L^{'}}) </math>f′(x,m)=f(x,L′mL)

以 Rotary Position 为例,如果当前的模型最大的长度依旧为 3 ,之前的三个位置分别是旋转 90 度得到的。现在如果想将长度扩展为 6 ,那么通过修改旋转矩阵的参数,每次的旋转 45 度即可,也就是对于之前的句子 A A A 在不同位置的表示变为:

A <1, 1>

A <0, 1>

A <-1, 1>

效果

论文中使用用于评估模型在长文本上的表现的 PG-19 数据集来测评位置插值的效果。评估结果可以参考论文 Table 1。如果不做任何操作(也就是 None 的那一行)由于模型能使用的 context 长度只有 2048 ,对于更长的位置上的信息无法利用,导致效果非常的差。而使用位置插值(PI 行)不仅效果优于 finetune(FT 那一行) ,并且能够处理更长的序列长度。finetune 和 位置插值 做法的差别为:

  • finetune:在训练好的 context 长度为 L 的模型基础上,使用少量 context 长度更长的文本来微调长度为 L' 的模型,使用方法是 extrapolation。
  • 位置插值:在训练好的 context 长度为 L 的模型基础上,使用少量 context 长度更长的文本来训练长度为 L' 的模型,使用方法是 interpolation。在实际应用中,选取适合的数据对模型进行 PI finetune 也十分的重要。

参考资料

www.together.ai/blog/llama-...
www.youtube.com/watch?v=oyX...
erdem.pl/2021/05/und...

相关推荐
真忒修斯之船3 小时前
大模型分布式训练并行技术(三)流水线并行
面试·llm·aigc
SpikeKing4 小时前
LLM - 使用 LLaMA-Factory 微调大模型 环境配置与训练推理 教程 (1)
人工智能·llm·大语言模型·llama·环境配置·llamafactory·训练框架
数据智能老司机1 天前
LLM工程师手册——监督微调
深度学习·架构·llm
AI_小站1 天前
LLM——10个大型语言模型(LLM)常见面试题以及答案解析
人工智能·程序人生·语言模型·自然语言处理·大模型·llm·大模型面试
waiting不是违停1 天前
LangChain Ollama实战文献检索助手(二)少样本提示FewShotPromptTemplate示例选择器
langchain·llm·ollama
我爱学Python!1 天前
AI Prompt如何帮你提升论文中的逻辑推理部分?
人工智能·程序人生·自然语言处理·chatgpt·llm·prompt·提示词
AI_小站2 天前
多模态大模型微调实践!PAI+LLaMA Factory搭建AI导游
人工智能·程序人生·语言模型·大模型·llm·产品经理·多模态大模型
AI_小站2 天前
【AI工作流】FastGPT - 深入解析FastGPT工作流编排:从基础到高级应用的全面指南
人工智能·程序人生·语言模型·大模型·llm·fastgpt·大模型应用
蚝油菜花3 天前
MeetingMind:AI 会议助手,支持自动转录音频并提取会议中的关键信息
人工智能·开源·llm
Agile.Zhou4 天前
给 Ollama 穿上 GPT 的外衣
llm·ollama