论文翻译 | LazyLLM: DYNAMIC TOKEN PRUNING FOR EFFICIENT LONG CONTEXT LLM INFERENCE

摘要

基于transformer的大型语言模型的推理包括两个顺序阶段:1)预填充阶段,用于计算提示的KV缓存并生成第一个令牌;2)解码阶段,用于生成后续令牌。对于长提示,在预填充阶段必须计算所有令牌的KV缓存,这可能会显著增加生成第一个令牌所需的时间。因此,预填充阶段可能成为生成过程中的瓶颈。一个悬而未决的问题是,是否所有提示令牌对于生成第一个令牌都是必要的。为了回答这个问题,我们引入了一种新颖的方法------LazyLLM,它在预填充和解码阶段选择性地为对下一个令牌预测重要的令牌计算KV。与一次性剪枝的静态剪枝方法相反,LazyLLM允许语言模型在不同的生成步骤中动态选择上下文中的不同子集令牌,即使它们在之前的步骤中被剪枝。在各种任务的标准数据集上的广泛实验表明,LazyLLM是一种通用方法,可以与现有语言模型无缝集成,从而显著加快生成速度,而无需微调。例如,在多文档问答任务中,LazyLLM将LLama 2 7B模型的预填充阶段加速了2.34倍,同时保持了准确性。

1 引言

基于标准提示的大型语言模型(LLM)推理有两个顺序阶段:预填充和解码,如图1所示。在预填充阶段,模型计算并保存提示中每个令牌的KV缓存,并预测第一个令牌。我们将预填充阶段所需的时间称为"第一个令牌生成时间"(TTFT)。预填充阶段之后是解码阶段,模型在此阶段重复使用缓存的KV来迭代解码下一个令牌,直到满足停止条件。

在预填充阶段,提示中的所有令牌都被所有变换器层使用。对于长提示,TTFT可能会很慢,因为最先进的基于变换器的LLM既深又宽(Pope等人,2023;Kim等人,2023;Aminabadi等人,2022),计算注意力的成本与提示中的令牌数量成二次方增加。例如,Llama 2(Touvron等人,2023)拥有70亿参数,堆叠了32个变换器层,模型维度为4096。在这种情况下,TTFT需要每个后续解码步骤的挂钟时间的21倍,大约占总生成时间的23%,这是在LongBench基准测试(Bai等人,2023)上的结果。因此,优化TTFT是提高LLM推理效率的关键途径(NVIDIA,2024)。

尽管优化LLM推理是一个活跃的研究领域,许多方法(Leviathan等人,2023;Cai等人,2024;Zhang等人,2024;Bhendawade等人,2024;Li等人,2024)都集中在提高解码阶段的推理速度。然而,很少有研究关注改善TTFT。我们注意到,一些基于压缩的工作通过减小LLM的大小间接改善了TTFT(Frantar等人,2022;Sun等人,2023;Ma等人,2023)。然而,一项正交研究(Li等人,2023;Jiang等人,2023;Dao等人,2022)探讨了在静态变换器架构下如何改善TTFT。在这一研究领域中,自然会产生一个问题:提示中的所有令牌对于生成第一个令牌都是必要的吗?

LongBench基准测试 (Bai等人,2023)上的LLM分析如图2所示,输入令牌相对于第一个生成令牌的关注分数非常稀疏,这表明输入提示中的许多令牌是冗余的,可以在不影响下一个令牌预测的情况下移除。为此,我们提出了LazyLLM,这是一种新颖、简单但有效的技术,专门用于加快预填充速度。如图3所示,在每一个生成步骤中,LazyLLM选择性地为对下一个令牌预测重要的令牌计算KV,并将剩余令牌的计算"懒惰"地推迟到它们变得相关时的后续步骤。我们提出使用前一个变换器层的关注分数来衡量令牌的重要性,并沿着变换器的深度逐步剪枝令牌。与永久减少提示的提示压缩工作(Li等人,2023;Jiang等人,2023;Xu等人,2023)不同,我们的方法允许模型恢复之前剪枝的令牌,我们发现这对于保持准确性至关重要。将渐进式令牌剪枝扩展到所有生成步骤并非易事。具体来说,如果一个令牌在生成步骤t被剪枝,并在生成步骤t'(t' > t)时被恢复,那么在步骤t'中需要重新计算一些隐藏状态。为了避免这种重复计算,我们采用了一个额外的缓存机制,Aux Cache,来缓存被剪枝令牌的隐藏状态。这为恢复剪枝令牌提供了一个计算效率高的途径,并确保LazyLLM的最差运行时间永远不会慢于基线。

总结来说,LazyLLM的优点是:(1)通用性:LazyLLM可以与任何现有的基于变换器的LLM无缝集成,提高推理速度;(2)无需训练:LazyLLM不需要任何微调,可以直接集成而无需任何参数修改;(3)有效:在6个不同语言任务的16个标准数据集上的实证结果显示,LazyLLM可以在LLM的预填充和解码阶段提高推理速度。

2 相关工作

大型语言模型(LLMs)规模的增加极大地提升了它们的性能,但同时也带来了关于推理效率的挑战。如图1所示,生成式LLMs的推理包括两个不同的阶段。特别是在长上下文场景下,预填充阶段需要大量的计算来计算完整的KV缓存,导致第一个令牌生成时间(TTFT)很长。这种延迟使得用户在提交提示后需要等待数秒钟才能从代理那里得到任何响应,从而导致用户体验不佳。

高效的长上下文推理。 大量的工作(Merth等人,2024;Chen等人,2023;Beltagy等人,2020;Kitaev等人,2020)已经被提出,通过减少内存占用和总计算量来提高长上下文应用的推理效率。一些研究专注于为长上下文输入定制变换器架构。例如,(Beltagy等人,2020)引入了一种标准自注意力的替代品,将局部窗口注意力与任务驱动的全局注意力结合起来。同时,Reformer(Kitaev等人,2020)通过使用局部敏感哈希来替换点积注意力,以减少其计算复杂度。尽管上述方法可以加快长上下文推理,但它们需要显著改变模型架构并重新训练。这个缺点使得它们不适用于现有的预训练LLMs。与我们的工作更接近的是优化KV缓存的效率技术(Zhang等人,2024;Li等人,2024;Anagnostidis等人,2024;Nawrot等人,2024),通过最小化KV缓存大小和数据传输。然而,这些工作只专注于加速解码步骤,不适用于减少TTFT。

**令牌剪枝。**在句子分类任务上的先前研究(Kim等人,2022;Anagnostidis等人,2024;He等人,2021)表明,输入序列中的所有令牌(即单词)并不是成功预测所必需的。这为令牌剪枝提供了几种可能性,通过在推理过程中选择性地移除不太重要的令牌,以最小化计算需求。例如,(Kim等人,2022)提出了学习型令牌剪枝,它通过变换器层自适应地移除不重要的令牌。同时,(He等人,2021)提出了通过令牌剪枝减少基于变换器的模型(如BERT(Devlin等人,2018))的宽度方向计算。上述方法是为只需要单次处理迭代任务设计的,如文本分类。在这项工作中,我们将令牌剪枝的概念扩展到生成式LLMs。具体来说,我们的方法允许模型在每个生成步骤中动态选择不同的令牌集,这对于保持性能至关重要。此外,我们还引入了Aux Cache,以确保每个令牌在整个生成过程中最多计算一次,并确保我们方法的最差运行时间不会慢于基线。

3 LazyLLM

3.1 LLM推理的背景

生成式LLM推理包括两个阶段:预填充和解码(见图1)。在预填充阶段,模型接收长度为N的提示(令牌序列)T = {ti}^N_{i=1},其中ti表示一个令牌,N表示提示的长度,然后计算并保存每个令牌的KV缓存,并生成第一个令牌tn+1。LLMs中常用的变换器架构是由多层组成的堆叠,每层具有相同的架构,包括多头自注意力机制和一个多层感知机(MLP)。预填充的时间被称为第一个令牌生成时间(即TTFT)。预填充之后是解码步骤,模型将生成的令牌tn+1附加到输入中,然后解码接下来的令牌。解码步骤会重复执行,直到满足停止条件。尽管每个解码步骤的公式与预填充相似,但由于KV缓存的存在,其计算量显著降低。具体来说,有了来自预填充的保存KV缓存,所有之前的令牌都不需要通过模型中的任何线性层。

3.2 使用LazyLLM进行推理

所提出的LazyLLM框架的概览如图4所示。LazyLLM从完整上下文开始,逐步剪枝令牌,以逐渐减少模型末端的计算数量。注意,LazyLLM允许模型在不同的生成步骤中从上下文中选择不同的令牌子集,即使其中一些在之前的步骤中被剪枝。与一次性剪除所有令牌的静态剪枝相比,动态剪枝优化了每个生成步骤中的下一个令牌预测,这对于保持性能至关重要。

**渐进式令牌剪枝。**在此之前,令牌剪枝已成功应用于优化LLM推理(Zhang等人,2024;Li等人,2024;Adnan等人,2024;Nawrot等人,2024)。然而,这些方法需要累积预测前几个令牌的完整注意力图,以在开始剪枝之前分析提示令牌的重要性。因此,它们不适用于减少TTFT,因为它们仍然需要在预填充阶段计算所有的KV缓存。相比之下,LazyLLM仅在推理的第一个迭代(预填充步骤)开始时"懒惰地"计算对预测下一个令牌重要的令牌。在第一次迭代中剪枝令牌的一个关键挑战是确定它们的重要性。受到早期退出工作(Elhoushi等人,2024)的启发,该工作表明令牌的隐藏状态通过变换器层逐渐演变,我们在每个生成步骤中应用逐层令牌剪枝。具体来说,我们使用层的注意力图Al ∈ RH×N×N来确定输入令牌ti相对于将要预测的下一个令牌的重要性。

其中H表示注意力头的数量,N是序列长度,Ah,i,j是在第h个头上令牌tj关注令牌ti的注意力概率。

计算出令牌的置信度分数后,确定剪枝令牌的阈值值是一个挑战。具体来说,阈值可能会随着注意力分数分布在不同层和不同任务之间的变化而变化。我们通过使用top-k百分位选择策略来剪枝令牌来应对这个挑战。具体来说,如果令牌ti在第l+1层的置信分数sli小于输入令牌中的k_l百分位,则将其剪枝。一旦令牌被剪枝,它就会从所有后续层的计算中排除。换句话说,后面层中使用的令牌将是前面层的子集。

我们的研究在5.4节展示了在不同剪枝层位置和剪枝令牌数量时的性能变化。特别是,当在同一变换器层剪枝时,保留的令牌越少,模型的性能逐渐下降。我们还发现,在较后的变换器层剪枝通常比在较早层剪枝具有更好的性能,这表明较后的层对令牌剪枝不太敏感。为了在加速和准确性之间取得更好的平衡,如图4所示,我们应用渐进式剪枝,在较早的变换器层保留更多令牌,并逐渐在变换器末端减少令牌数量。

辅助缓存(Aux Cache)。在预填充阶段,没有KV缓存,每个令牌都由隐藏状态表示。因此,可以通过移除剪枝令牌的隐藏状态来实现渐进式令牌剪枝。然而,将渐进式令牌剪枝扩展到后续的解码步骤并不是一件简单的事。这是因为每个解码步骤都利用预填充阶段计算的KV缓存来计算注意力。由于LazyLLM在预填充阶段执行渐进式令牌剪枝,因此在第l层剪枝的令牌的KV(例如图4中的T4)将不会存在于第l+1层的KV缓存中。

提醒一下,LazyLLM框架允许每个生成步骤从完整的输入令牌序列中挑选不同的子集令牌,无论它们是否在前面的生成步骤中被剪枝。例如,在随后的解码步骤中,那些在第l+1层的KV缓存中不存在的剪枝令牌(例如T4)可能会被重新选择来计算注意力。在这种情况下,模型无法检索这些令牌的KV缓存。一个直观的解决方案是从变换器的开始再次传递这些令牌。然而,这将导致对相同令牌的重复计算,最终减慢整个生成过程。

为了解决这个挑战,我们除了原始的KV缓存之外,还引入了辅助缓存(Aux Cache),它存储那些在后续层的KV缓存中不存在其KV的剪枝令牌的隐藏状态(例如图4中的T4和T7),这些隐藏状态可能会在后续迭代中被检索。如图4所示,在每个解码步骤中,每个变换器层(例如第l+1层)首先检索过去令牌的KV缓存(如果存在,例如T1和T8)。对于那些在KV缓存中不存在的令牌(例如T3),我们可以直接从其前一个层的Aux Cache中检索它们的隐藏状态,而不是再次通过前面的层。引入Aux Cache确保了每个令牌在每个变换器层中最多计算一次,并确保LazyLLM的最差运行时间不会慢于基线。

4 实现细节

我们在Llama 2(Touvron等人,2023年)和XGen(Nijkamp等人,2023年)上实现了LazyLLM,并使用HuggingFace2在LongBench(Bai等人,2023年)上进行评估。在所有实验中,我们遵循LongBench的官方GitHub仓库3进行数据预处理和提示。LongBench基准测试包含不同任务中的多个数据集,每个任务可能有不同的评价指标,包括ROUGE-L、F1、准确率和编辑相似度。按照官方评估流程,我们通过计算宏平均分来对所有主要任务类别进行分类。

如前所述,提出的LazyLLM不需要任何训练。因此,LazyLLM对所有模型使用与基线完全相同的现有检查点。对于推理,我们在NVIDIA A100 GPUs上进行了所有实验。我们测量并报告基于实证墙钟时间改进的加速比。具体来说,对于TTFT加速比,我们测量从提示输入模型到模型生成第一个令牌的实证墙钟时间。对于生成加速比,我们测量从提示输入模型到模型完成生成所有输出令牌的实证墙钟时间。在开始时间测量之前,我们为每个实验增加了5次热身运行,以消除加载模型参数等噪声。

5 实验

我们使用两个大型语言模型:Llama 2 7B和XGen 7B来检验我们的方法。我们与使用相同公开发布的预训练检查点的基线进行比较,而不进行任何额外的训练。我们使用LongBench进行实验,LongBench是一个用于长内容理解的多任务基准测试。LongBench包含16个数据集,涵盖6个任务,包括单文档QA、多文档QA、摘要、少样本学习、合成任务和代码补全。

对于评价指标,我们主要评估每种方法在TTFT加速比与准确性权衡中的有效性和效率。遵循LongBench的评价标准,准确性(得分)表示每个任务中跨数据集的宏平均分。TTFT加速比测量相对于基线生成第一个令牌的墙钟时间改进。在分析中,我们还评估了我们的方法对提示令牌计算百分比和生成加速比的影响。提示令牌计算百分比测量在生成结束时计算出的提示令牌的累计百分比,这表明了总计算量的节省。生成加速比测量相对于基线完成整个生成过程的墙钟时间变化。

5.1 结果

表1给出了在LazyLLM、标准LLM和其他基线之间TTFT加速与准确性的比较。在表中,"基线"指的是标准的LLM推断。"随机令牌掉落"基线基于(Yao等人,2022),即在将提示令牌提供给llm之前随机修剪它们。我们报告了5次运行的"随机令牌掉落"基线的平均指标。我们的"静态令牌修剪"基线根据预填充阶段前几个变压器层的注意力得分立即修剪输入令牌。我们还比较了提示压缩方法(Li et al ., 2023),该方法使用llm在输入上下文中修剪冗余。

表1显示,在多个任务中,LazyLLM始终如一地实现了更好的TTFT加速,精度下降可以忽略不计。值得注意的是,运行llm压缩提示符的开销在计算上非常昂贵。尽管对减少的提示的推断更快,但是"提示压缩"基线的实际TTFT要比基线长。

5.2 TTFT加速vs.准确性

LazyLLM的推理效率通过三个参数进行控制:1)剪枝层的数量,2)这些剪枝层的位置,以及3)在这些层中剪枝的令牌数量。增加剪枝层的数量和剪枝更多的令牌可以通过处理更少的令牌来优化计算,而在较早的层中剪枝令牌可以为后续层节省计算。调整这些因素将带来更多的整体计算减少,并提供更好的TTFT加速比。然而,副作用是过度剪枝令牌可能会导致信息丢失,最终导致性能下降。同样,基线的TTFT加速比和准确性也会随着不同超参数的变化而变化。

我们在图5中比较了不同超参数下的TTFT加速比与准确性。可视化显示,在没有进行任何训练的情况下,提出的LazyLLM在相同的TTFT加速比下比基线更好地保持了准确性。例如,我们的方法可以在多文档问答任务中提供2.34倍的TTFT加速比,而性能损失可以忽略不计(≤ 1%)。通过控制剪枝参数,LazyLLM在准确性和推理速度之间提供了相对于基线方法的良好权衡。例如,LazyLLM可以在多文档问答任务中实现3.0倍的TTFT加速比,而准确性下降≤ 10%。另一方面,基线方法在类似的TTFT加速比下准确性显著下降。请注意,提示压缩方法由于压缩开销而未能提高TTFT。

5.3 对整体生成速度的影响

为了评估所提出的方法对整个生成过程的影响,我们还在表2中分析了提示令牌计算的百分比和生成加速。我们可以发现,LazyLLM的Token Computed的%小于100%,这表明在生成结束时,并不是提示符中的所有Token都被LazyLLM选中,尽管理论上模型可以使用所有Token。FFN层的计算量呈线性增长,而注意力层的计算量随着Token计算的百分比呈二次增长。Token Computed的百分比越低,表明LazyLLM减少了总计算量,从而为跨不同任务的整体生成过程提供了额外的加速

5.4 各层掉落率

在本节中,我们分析了剪枝层位置和剪枝令牌数量对整体生成速度的影响。特别是,我们报告了一系列使用LazyLLM简化版本的实验,该版本仅在变换器内进行一次令牌剪枝。每次试验中,我们将剪枝层放置在变换器堆栈的不同层级,并应用不同的剪枝比例。我们对Llama 2和XGen都进行了实验,并在图6中可视化了结果。

结果显示两个模型呈现出相似的趋势。正如预期的那样,当在同一变换器层进行剪枝时,随着保留的令牌数量减少,模型的性能逐渐下降。此外,与在较早的层进行剪枝相比,在较后的变换器层进行剪枝始终能获得更好的性能,这表明较后的层对令牌剪枝的敏感性较低。基于这些观察,我们在第3.2节提出了渐进式令牌剪枝,该策略在后期层中剪枝更多令牌,而在早期层中保留更多令牌,从而优化了效率与性能保持之间的平衡。

5.5 渐进式 KV 增长

在本节中,我们将使用令牌修剪逻辑描述模型的内部结构。具体来说,我们试图了解提示令牌的哪些部分是累积使用的,哪些部分是不使用的。这种"累积令牌使用量"可以等效地定义为每个给定步骤的KV缓存大小。图7显示了LazyLLM每个阶段的累积提示令牌使用数量。

我们的分析支持这样一个假设,即许多令牌从未被模型选中(尽管理论上模型可以在提示符中使用所有令牌)。由于该模型保留了任务的准确性,我们可以得出结论,该模型有效地删除了不影响输出质量的令牌。

6 结论

在这项工作中,我们提出了一种新颖的LazyLLM技术,用于高效的LLM推理,特别是在长上下文场景下。LazyLLM选择性地计算对下一个令牌预测重要的令牌的KV,并"懒惰地"将剩余令牌的计算推迟到稍后的步骤,当它们变得相关时。我们仔细检查了LazyLLM在各种任务上的表现,我们观察到所提出的方法有效地减少了TTFT,而性能损失可以忽略不计。值得注意的是,我们的方法可以与现有的基于 transformer的LLM无缝集成,以提高其推理速度,而无需任何微调。

相关推荐
**之火几秒前
(五)机器学习 - 数据分布
人工智能·机器学习
martian6652 分钟前
人工智能机器学习基本概念详解
人工智能·机器学习
coldstarry20 分钟前
sheng的学习笔记-AI-自然语言处理(NLP),机器翻译,情感分类,词嵌入
人工智能·深度学习·自然语言处理·机器翻译
小雄abc27 分钟前
决定系数R2 浅谈三 : 决定系数R2与相关系数r的关系、决定系数R2是否等于相关系数r的平方
经验分享·笔记·深度学习·算法·机器学习·学习方法·论文笔记
uyeonashi1 小时前
【C++】刷题强训(day14)--乒乓球匡、组队竞赛、删除相邻数字的最大分数
开发语言·c++·算法·哈希算法
机器学习之心1 小时前
一区正弦余弦算法!SCA-SVM正弦余弦算法优化支持向量机多特征分类预测
算法·支持向量机·分类·sca-svm·正弦余弦算法优化
a栋栋栋2 小时前
刷算法心得
算法
知来者逆2 小时前
Layer-Condensed KV——利用跨层注意(CLA)减少 KV 缓存中的内存保持 Transformer 1B 和 3B 参数模型的准确性
人工智能·深度学习·机器学习·transformer
華華3552 小时前
读程序题...
开发语言·c++·算法
宸码3 小时前
【机器学习】手写数字识别的最优解:CNN+Softmax、Sigmoid与SVM的对比实战
人工智能·python·神经网络·算法·机器学习·支持向量机·cnn