如何扩展大模型的上下文长度

一、背景

大模型的上下文长度是指我们在使用大模型的时候,给大模型的输入加上输出的字符(Token)总数,这个数字会被限制,如果超过这个长度的字符会被大模型丢弃。目前开源的大模型上下文长度一般不长,比如 Llama 2 只有 4K,Code-Llama 系列因为需要输入代码,扩展到了 16K。闭源系列模型的提供了更长的上下文长度,比如 OpenAI 在其最新模型 GPT-4 Turbo 中提供了 128K 的上下文长度,Anthropic 的 Claude 2.1 模型提供了 200K 上下文长度。

一些场景需要较长上下文,比如,文档翻译需要将整篇文档输入给大模型进行翻译,长文档内容抽取需要大模型读取整篇长文档进行内容抽取,会议内容总结则需要给大模型输入会议聊天记录进行总结等。

想要得到一个长上下文的大模型,一般有两种途径。一种是大模型在初始阶段被设置为长上下文,然后经过预训练,指令微调,对齐训练等方式得到一个长上下文大模型。另外一种方式是选择已经训练好的大模型,通过技术改造扩展其上下文长度,然后再进行微调训练得到长上下文模型。

本文将基于比较火的 Llama 2 大模型的结构[1]介绍上下文长度的方法与挑战,然后探讨一些业界流行的上下文长度扩展的技术,最后给大家推荐下 KubeAI 大模型训练推理平台可以上手实验。

二、LLAMA的结构

Transformer的结构

通常所说的大模型是指大语言模型(Large Language Model,LLM),其模型结构一般基于 Transformer 进行改进而来。Transformer 源自于 2017 年 Google 发表的著名论文"Attention Is All You Need[2]"。论文中的 Transformer 结构如下,下图源自论文[2]。

它的结构包括两部分:Encoder(编码器)和Decoder(解码器)。Encoder 与 Decoder 大致都包含以下层,每一层都有特定的功能,下面为 Encoder(编码器)各层的简单介绍:

  • 输入嵌入层(Input Embedding Layer):将输入文本的词或标记转换为向量表示,以便模型能够理解它们。
  • 多头自注意力层(Multi-Head Self-Attention Layer):帮助模型捕捉输入序列中词与词之间的关系,使模型能够了解上下文信息。
  • 前馈神经网络层(Feed-Forward Neural Network Layer):对多头自注意力的输出进行进一步的特征提取和变换,以增加模型的表示能力。
  • 归一化层(Layer Normalization Layer):规范化每一层的输出,有助于训练过程的稳定性。

总的来说,Transformer 是一种强大的模型,它可以捕捉文本和序列数据中的长距离依赖关系,使其在翻译、对话、摘要生成等自然语言处理任务中表现出色。这个模型已经在各种应用中取得了显著的成功。感兴趣的同学可以自行去网上搜索下 Transformer 的结构,深入了解。

LLAMA的结构

目前大多数生成式语言模型如 Llama 系列,仅仅采用了 Transformer 的 Decoder 模块结构。在 Huggingface 中,这种结构通常被称为 CausalLM,即因果语言模型(Causal Language Model)。下面我们来具体看一下 Llama 2 模型系列的结构,Llama 2 相关的论文[1]。

(下图是基于 Transfomers 代码中的 LlamaModel 绘制而成,具体代码参考 Transfomer 中的 modeling_llama.py[3])

我们来解读下上面 Llama 各层的结构与作用,首先从输入文本开始。会经过下面各层:

  • Input Embedding:将 Input 文本转化为向量表,通过 nn.Embedding 实现。
  • Llama Decoder Layer:Decoder 采用多层 Llama Decoder Layer。每一层包括自注意力(Llama Attention)和前馈网络(Llama MLP)。自注意力用于捕捉文本中的长程依赖关系。前馈网络进行非线性映射。
  • Llama RMSNorm:一种规范化方式,用于正则化每层的输,起到预处理的作用。
  • lm_head:一个线性层,将 Decoder 最后一层的输出映射到词典大小的维,以进行后续的语言模型 Logits 计算。
  • Llama Attention:多头自注意力机制,用于建模文本中的依赖关系。将输入表示切分为多个头,然后在每个头内做点积注意力运算。
  • Llama MLP:采用 Gated Linear Units 的多层前馈网络。进行非线性变换来捕捉复杂模式。

总体上,Llama 通过堆叠多层自注意力和前馈网络来表示文本语义,然后预测后续词元。lm_head 负责将语义表示映射为具体的词典 Logits。整个模型端到端通过语言模型目标进行训练。

上面各层中比较核心的是 Llama Attention 层,该层的结构如下:

论文"Attention Is All You Need"[2]中描述的 Attention 的计算公式如下:

Llama 的 Attention 计算过程如下:

  • 输入会经过线性变换,得到 Query(Q)、Key(K)和 Value(V)矩阵。
  • 对 Q 和 K 应用 RoPE 位置编码。RoPE 包含旋转的 Sin 和 Cos 编码,会根据每个 Token 的位置对其表示进行旋转。
  • 用旋转后的 Q 和 K 计算点积,得到注意力权重 Attention Score,经过 Softmax 计算后得到 Normalized Attention Weight。
  • 再把 Attention Weight 与 V 相乘,并进行加权求和,得到 Attention 的输出。
  • 输出再经过一个线性变换,继续输出给下一层作为输入。

这样通过 RoPE 位置编码、加权平均,Llama 的 Attention 可以高效稳定地提取文本序列的上下文语义信息。

三、扩展方案与挑战

位置编码层(RoPE)

通过上面对 Llama 结构的解析,我们看到在 Llama Attention 中有一个叫做 RoPE(旋转位置编码)的层,主要用于对输入进行位置编码,让模型学到输入文本中每个 Token 的位置关系,从而更好地理解输入。RoPE 层能处理序列的长度决定了 Llama 的上下文长度,要扩展 Llama 的上下文长度,需要对 RoPE 层进行改造和扩展。下面我们先简单介绍下 Llama RoPE 层的工作原理。

RoPE 旋转位置编码,最早来自 RoFormer:Enhanced Transformer with Rotary Position Embedding[4]这篇论文。下图源自论文[4]。

RoPE 层是一种相对位置编码方法,它给输入的每个 Token 编码一个向量,向量中的每个值表示该 Token 与其他 Token 的相对距离。论文以二维向量为例,解释了这种位置编码为什么叫做旋转位置编码。如上图所示,在二维平面,相当于把向量旋转了一个 Q 的角度。

论文中证明,在进行旋转位置编码之后,可以从新的编码向量中获取原向量的相对位置信息,即论文中下面的公式中的 m-(位置 m 减去位置 n)。下图源自论文[4]。

而我们只需要理解旋转位置编码的最终计算公式如下,对于一个输入向量 X,直接与一个 COS 矩阵和 SIN 矩阵进行内积后求和:下图源自论文[4]。

其中 WCOS 和 WSIN 分别是一个预先固定的 COS 矩阵和 SIN 矩阵。

下面我们展示下 Huggingface 的 Transformer 对应的计算 COS 和 SIN 矩阵的计算代码:

扩展位置编码层(RoPE)支持长上下文

上面介绍了 Llama 结构中的旋转位置编码层 RoPE。要扩展大模型的上下文长度,就需要扩展 RoPE 层,也就是扩展其 COS 和 SIN 矩阵,让RoPE支持更长序列的输入。

RoPE 中的 COS 和 SIN 矩阵维度(seq_length, embed_dim),其中 seq_length 就是模型支持的最大序列长度,embed_dim 是词嵌入维度。矩阵中的每个值表示一个位置上的正弦或余弦编码。为了支持更长的上下文,需要重新计算更大尺寸的 COS 和 SIN 矩阵。

对于未训练过的大模型,只需要直接更改其配置文件中的 max_position_embeddings 即可实现 RoPE 层的扩展,然后再进行训练。但是对于已经训练过的模型,如果直接修改其配置,会导致模型的效果急剧下降,后面第四部分我们会介绍一些基于已有模型进行改造扩展的上下文的方法。max_position_embeddings 的配置如下:

超长上下文面临的挑战

超长上下文的大模型部署推理的时候,往往会面临如下性能挑战。

  • 推理时间变长

从上面的 Attention 的计算公式可以看出,Attention 进行了点积的运算,其时间复杂度为 L(序列长度)的平方。也就是说大模型在推理的时候,输入的序列长度越长推理时间越多。所以超长上下文的大模型需要更多的推理时间,这会带来用户体验上的损失。

  • 推理显存空间变大

大模型在持续推理的过程中,需要缓存一个叫做 KV Cache 的数据快,KV Cache 的大小也与序列长度成正比。以 Llama 2 13B 大模型为例,一个 4K 长的序列大约需要 3G 的显存去缓存 KV Cache,16K 的序列则需要 12G,128K 的序列则需要 100G 显存。

超长上下文的大模型需要更多的 KV Cache 存储空间,但是 GPU 显存非常珍贵,比如 A100 也只有 40G 或 80G 显存两个版本,这对本来就比较紧张的 GPU 显存来说是一个很大的挑战。

大模型上下文扩展的思路

综上所述,扩展大模型的上下文长度,一般思路如下:

  • 首先通过对位置编码层进行改造,使其支持更长的上下文。
  • 为了取得更好的推理性能,还需要对 Attention 计算进行优化。
  • 进行微调训练,让大模型适应新的模型结构。

四、位置编码层改造扩展上下文的案例

上面我们讲到,从模型物理结构上扩展上下文长度,需要直接修改 RoPE 层,即直接扩展其 SIN 和 COS 矩阵。但是大模型都是基于大量短序列数据训练得到的。如果直接强行扩展,会导致模型困惑度提高。所谓困惑度是模型对下一个词的预测困难程度的量化指标,直观意义是大模型的输出是否能够更容易被人类所理解。

因此,我们需要更好的方法来扩展预训练模型的上下文长度,既要兼顾模型性能,又要控制困惑度。下面我们概括几种业界常用的上下文长度扩展方法。

线性位置插值法

线性插值法的思想最早来自于这篇文章,Extending Context Window of Large Language Models via Positional Interpolation[5],该方法已经被 Huggingface 的 Transformer 中 Llama 模型代码集成。

下图源自论文:

思路:通过线性缩小输入位置索引以匹配原始上下文窗口大小,而不是超出训练上下文长度进行外推,这样可以减小注意力机制中相对位置的影响,帮助模型更容易适应扩展后的上下文窗口。

效果:在从 LLaMA 7B 到 65B 模型上,通过位置插值扩展上下文窗口到 32768(4k扩展到32K),仅需微调 1000 步,就能在包括语言建模、长文档摘要撰写等任务上取得良好效果。

优点:位置插值不仅能有效扩展上下文窗口,提高模型在长上下文任务上的性能,还能在原有上下文窗口大小的任务上保持模型质量,且不需要额外的权重或修改模型架构。

缺点:需要重新训练,有时候扩充后会导致模型困惑度上升。

动态插值法(NTK-awared)

动态插值法是在位置插值法的基础上演变而来的,最早提出文章 NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning[6]。现在也被 Huggingface 的 Transformer 中 Llama 模型代码集成。

下面是 Chinese Llama 基于位置插值法与动态插值法进行的比较,数据来自 Extend Context Size Without Fine-Tuning[7]。

数据显示,相比与位置插值法,NTK 动态插值法不会显著增加大模型的困惑度。

思路:利用神经正切核 (NTK) 理论,设计非线性位置编码插值方案,改变基数而不是缩放比例,使不同位置可区分,避免线性插值的问题。

效果:与线性插值相比,在大幅扩展上下文(如8000+)时,无需微调就可以使困惑度下降极小。

优点:微调成本极低,上下文窗口可以扩展很大,困惑度变化小。

总体来说,动态插值法通过考虑模型特性,设计更优化的插值方案,能够在不增加训练成本的条件下,获得接近无损的上下文窗口扩展效果。这为进一步扩展和优化大语言模型提供了新的思路。

Yarn(NTK升级)

Yarn 扩展上下文的方法来自于文章 YaRN: Efficient Context Window Extension of Large Language Models[8]。代码参考Yarn[9]。并给出基于 Llama2 的 128K 上下文扩展。

下图源自论文[8]:

数据显示,在 128K 的 Proof-Pile 数据集上评测,Yarn-Llama-2-7b-128K/64K 模型的困惑度仍然保持良好下降。

相比于仅进行简单线性插值或动态插值的方法,Yarn 方法更全面地考虑了不同频率 RoPE 维度的作用,避免了信息损失和外推问题。这使得 Yarn 方法在不 Fine-Tuning 的情况下,以及 Fine-Tuning 的数据量很少的情况下,都能更好地扩展上下文窗口。

RoPE 中的每个维度对应着不同的正弦波频率。高频的正弦波编码了位置信息的细微变化,低频的正弦波编码了位置信息的整体趋势。

如果我们简单地进行线性插值,会把所有频率的正弦波都等比例地拉伸。这会导致两个问题:

  • 高频正弦波被过度拉伸,导致代表细微位置变化的信息丢失。这个会影响模型区分很接近的词的能力。
  • 低频正弦波被拉伸,不同位置之间的相对距离变小。这会导致模型判断近距离词的先后顺序变得困难。

为了解决这个问题,Yarn 方法对不同频率的正弦波进行不同程度的插值:

  • 对高频正弦波几乎不进行插值,保留细微位置信息。
  • 对低频正弦波进行接近线性的插值,保留位置大体信息。
  • 中频正弦波进行渐变的插值。

这样既保留了高频表示细微位置变化的信息,也保留了低频表示位置整体关系的信息,避免了简单线性插值的问题。

五、优化Attention扩展上下文的案例

上面我们提到,长上下文对大模型的正向与反向传播的性能来说是个挑战。其主要原因是 Attention(注意力)的计算复杂度比较高,为了解决这个问题,业界提出了很多优化 Attention 计算的方法。

LongLoRA方法

LongLoRA 是香港中文大学联合 MIT 提出的一种模型微调方法,其论文 LONGLORA: EFFICIENT FINE-TUNING OF LONG-CONTEXT LARGE LANGUAGE MODELS[10]。下图是论文[10]中描述的方法:

LongLoRA 提出了一种移位稀疏注意力(Shifted Sparse Attention,S2-Attn)来近似标准的自注意力。

在传统的自注意力(self-attention)中,模型需要计算输入序列中所有元素对之间的注意力权重,这在处理长序列时会导致计算复杂度呈二次方增长。S2-Attn 在训练时,将输入序列划分成若干个组,在每个组内部进行自注意力计算。为了使不同组之间有信息流通,在一半的注意力头内,向其中一个组的 Tokens 做平(Shift)操作,平移的长度为组的一半。这样就引入了不同组之间的信息交换,又不增加计算量。

S2-Attn 的设计使得大型语言模型能够在处理长序列时保持较高的性能,同时显著降低了训练和推理时的计算资源需求。

实验结果如下,用了一个 8 个 A100 的机器微调,将 Llama2 7B/13B/70B 模型分别扩展到 100K,64K, 32K 长度,而大模型的困惑度并没有明显变化。

六、业界更多扩展上下文的方法

最近 Technology Innovation Institute(TII)发表了一篇论文综述The What, Why, and How of Context Length Extension Techniques in Large Language Models -- A Detailed Survey[11],调研了业界的扩展大模型的技术,下图源自论文。

论文中介绍了更多的业界上下文扩展的方法,大致可以简单分为一下几大主要方式。

七、Kubeai大模型训练推理平台

上面我们分别讲解了 Llama 的结构,然后基于 Llama 的结构去讲解了业界最新扩展大模型上下文长度的方法与效果。

在 KubeAI 训练推理平台上,用户只需要上传数据、选择大模型,就可以完成一次训练和推理部署。如果想了解详细使用方法,可以参考我们之前发表的关于得物大模型平台的系列介绍的文章。

KubeAI 平台为用户提供了非常便捷的大模型训练和部署功能。用户无需关注底层基础设施,就可以通过简单的步骤上传数据、配置参数、选择模型,从而获得针对自己业务自定义的大模型。

八、总结与展望

本文从 Llama 大模型的结构入手,介绍了其模块结构,重点解析了 Attention 机制中的 RoPE 层。要实现 Llama 模型上下文长度的扩展,需要对应扩展 RoPE 位置编码层。但是直接扩展会导致模型困惑度上升,针对这个问题,我们介绍了业界常见的几种上下文扩展方法,包括位置查找法、动态插值法和 Yarn 方法等。

长下文推理对性能要求比较高,为此我们也介绍了一些为了提升性能而优化 Attention 的方法,比如 LongLoRA[10] 这篇论文的 S2-Atten 的方法。有兴趣的同学可以阅读相关论文了解细节。

本文通过剖析 Llama 模型结构,解析上下文扩展的关键层 RoPE,并概述各种扩展方法的原理,希望能够帮助大家对大模型上下文扩展有一个系统的了解。后续如果有机会,我们会继续分享更多大模型的核心技术,让更多人对大模型的内在机制有更深的认识。欢迎持续关注我们的内容和分享!

参考资料

*文/linggong

本文属得物技术原创,更多精彩文章请看:得物技术官网

未经得物技术许可严禁转载,否则依法追究法律责任!

相关推荐
查理零世3 小时前
【算法】数论基础——约数个数定理、约数和定理 python
python·算法·数论
汉克老师4 小时前
GESP2024年3月认证C++六级( 第三部分编程题(1)游戏)
c++·学习·算法·游戏·动态规划·gesp6级
闻缺陷则喜何志丹4 小时前
【C++图论】2685. 统计完全连通分量的数量|1769
c++·算法·力扣·图论·数量·完全·连通分量
利刃大大4 小时前
【二叉树深搜】二叉搜索树中第K小的元素 && 二叉树的所有路径
c++·算法·二叉树·深度优先·dfs
CaptainDrake4 小时前
力扣 Hot 100 题解 (js版)更新ing
javascript·算法·leetcode
一缕叶5 小时前
洛谷P9420 [蓝桥杯 2023 国 B] 子 2023 / 双子数
算法·蓝桥杯
甜甜向上呀5 小时前
【数据结构】空间复杂度
数据结构·算法
Great Bruce Young5 小时前
GPS信号生成:C/A码序列生成【MATLAB实现】
算法·matlab·自动驾驶·信息与通信·信号处理
Mryan20055 小时前
LeetCode | 不同路径
数据结构·c++·算法·leetcode
qy发大财6 小时前
验证二叉搜索树(力扣98)
数据结构·算法·leetcode·职场和发展