POP Prefill-Only Pruning for Efficient Large Model Inference

POP: Prefill-Only Pruning for Efficient Large Model Inference

Authors: Junhui He, Zhihui Fu, Jun Wang, Qingan Li

Deep-Dive Summary:

POP: 面向高效大模型推理的 Prefill-Only 剪枝 (Prefill-Only Pruning)

摘要 (Abstract)

大型语言模型 (LLMs) 和视觉语言模型 (VLMs) 展现了卓越的能力,但其巨大的计算开销阻碍了部署。现有的硬件友好型结构化剪枝方法往往会导致显著的精度下降。本文指出,这种失效源于忽略了预填充 (prefill)解码 (decode) 阶段之间不对称作用的"阶段不可知"剪枝策略。通过引入虚拟门控 (virtual gate) 机制进行重要性分析,研究发现深层网络对于下文预测(解码)至关重要,但对于上下文编码(预填充)则存在大量冗余。

基于此,本文提出了 Prefill-Only Pruning (POP) :在计算密集型的预填充阶段安全地省略深层网络,而在敏感的解码阶段保留完整模型。为了实现阶段间的无缝切换,POP 引入了独立的 Key-Value (KV) 投影 以维持缓存完整性,并采用边界处理策略 确保首个生成 token 的准确性。实验表明,POP 在 Llama-3.1、Qwen3-VL 和 Gemma-3 上实现了高达 1.37 × 1.37 \times 1.37× 的预填充加速,且性能损失极小。

1. 引言 (Introduction)

大模型的推理过程由两个截然不同的阶段组成:

  1. 预填充 (Prefill):旨在将输入历史编码到 KV 缓存中,为后续生成提供上下文。
  2. 解码 (Decode):具有双重作用,既要将当前 token 编码进缓存,又要建模下一个 token 的概率分布。

现有的结构化剪枝(如 SliceGPT, ShortGPT)采用"一刀切"的方案,忽略了这两个阶段的功能不对称性。本文提出的 POP 策略通过在预填充阶段剪枝深层网络来加速推理,同时在解码阶段恢复完整能力。

2. 预备知识 (Preliminary)

在标准的 Transformer 推理中,第 l l l 层的计算可表示为:
y l ≔ x l + A t t n ( x l , K l p a s t , V l p a s t ) y_{l} \coloneqq x_{l} + \mathrm{Attn}(x_{l}, K_{l}^{\mathrm{past}}, V_{l}^{\mathrm{past}}) yl:=xl+Attn(xl,Klpast,Vlpast)
x l + 1 ≔ y l + F F N ( y l ) x_{l+1} \coloneqq y_{l} + \mathrm{FFN}(y_{l}) xl+1:=yl+FFN(yl)

KV 缓存的生成过程如下:
k l n e w ≔ R o P E ( L N ( x l ) W l K ) k_{l}^{\mathrm{new}} \coloneqq \mathrm{RoPE}(\mathrm{LN}(x_{l})W_{l}^{K}) klnew:=RoPE(LN(xl)WlK)
v l n e w ≔ L N ( x l ) W l V v_{l}^{\mathrm{new}} \coloneqq \mathrm{LN}(x_{l})W_{l}^{V} vlnew:=LN(xl)WlV

传统的层剪枝(Layer Pruning)会将整层替换为恒等映射(Identity mapping): x ^ l + 1 ≔ x l \hat{x}{l+1} \coloneqq x{l} x^l+1:=xl。

3. 方法 (Method)

3.1 使用虚拟门控估计层重要性

为了量化每一层对模型性能的贡献,本文引入了虚拟标量参数 g l g_l gl:
y ^ l ≔ x l + A t t n ( ⋅ ) ⊙ g l \hat{y}{l} \coloneqq x{l} + \mathrm{Attn}(\cdot) \odot g_{l} y^l:=xl+Attn(⋅)⊙gl
x ^ l + 1 ≔ y ^ l + F F N ( y ^ l ) ⊙ g l \hat{x}{l+1} \coloneqq \hat{y}{l} + \mathrm{FFN}(\hat{y}{l}) \odot g{l} x^l+1:=y^l+FFN(y^l)⊙gl

当 g l = 1 g_l = 1 gl=1 时为原模型,当 g l = 0 g_l = 0 gl=0 时该层被剪枝。通过二阶泰勒展开并利用 Fisher 信息简化计算,得出重要性得分估计公式:
I ^ l = E [ ( ∂ L ∂ g l ) 2 ] \hat{I}{l} = \mathbb{E}\left[\left(\frac{\partial\mathcal{L}}{\partial g{l}}\right)^{2}\right] I^l=E[(∂gl∂L)2]

3.2 阶段感知的关键分析

通过对预填充和解码阶段分别计算重要性得分,研究发现(如图 1 所示):

  • 阶段间的不对称性:解码阶段(橙线)对剪枝的敏感度远高于预填充阶段(蓝线)。
  • 深层的冗余性:在预填充阶段,深层网络的重要性得分接近于零,表明其在编码上下文时是冗余的。而在解码阶段,深层网络对预测下一个 token 至关重要。

(a) Llama-3.1-8B-Instruct, WizardLM-V2-196K

3.3 Prefill-Only Pruning (POP) 实现

基于上述发现,POP 采取以下核心措施:

  1. 静态剪枝策略:在预填充阶段剪掉最后 1/3 的层。
  2. 独立 KV 投影 :为了确保解码阶段有完整的 KV 缓存,即便在预填充阶段跳过了 Attention 和 FFN 的重度计算,仍执行轻量级的 KV 投影操作( W K , W V W^K, W^V WK,WV),此开销通常小于全层计算的 5 % 5\% 5%。
  3. 边界处理 :将输入序列的最后一个 token x N x_N xN 视为解码阶段的第一步,使用完整模型处理,以确保第一个生成的 token 准确。

4. 实验 (Experiments)

本文在 Llama-3.1、Qwen3-VL 和 Gemma-3 等多个模型上进行了评估,涵盖常识推理、数学代码、长文本 QA 和多模态任务。

4.2 下游任务准确率

实验结果如表 1 所示:

Table 1: Accuracy comparison across different models and tasks. "Avg" denotes the average score across all tasks. The pruning ratios are indicated in parentheses. † \dagger † denotes likelihood-based tasks; ‡ \ddagger ‡ denotes open-ended generation tasks. Bold indicates the best results for all structured pruning methods. Italic indicates unstructured pruning methods (Wanda).

Method Common Sense † \dagger † Math & Code † \dagger † Long Context QA † \dagger † Multi-Modal ‡ \ddagger ‡ Avg
MMLU HellaSwag WinoG PIQA GSM8K HumanEval MultiFieldQA HotpotQA MMMU RealWorldQA TextVQA ScreenSpot
Llama-3.1-8B-Instruct
Full Model 68.33 79.50 74.40 81.12 79.68 68.29 54.57 55.66 - - - - 70.19
Wanda (30%) 65.87 78.96 74.59 80.74 76.42 65.84 52.80 53.03 - - - - 68.53
SliceGPT (25%) 34.97 51.19 66.54 63.87 0.91 0.00 12.35 8.71 - - - - 29.82
ShortGPT (25%) 65.80 61.93 69.77 70.51 0.38 0.00 6.80 3.81 - - - - 34.88
POP (31.25%) 67.43 78.29 73.40 80.36 77.26 64.63 52.88 53.48 - - - - 68.47
Qwen3-VL-8B-Instruct
Full Model 74.95 76.60 73.72 79.92 81.50 92.07 53.53 65.49 51.33 69.67 82.24 87.03 74.00
Wanda (30%) 73.78 75.22 72.45 80.47 83.32 90.85 52.87 63.19 52.00 67.45 81.08 85.22 73.16
SliceGPT (25%) 39.16 44.50 57.93 67.25 13.95 17.68 40.76 38.33 28.00 32.55 13.54 0.24 32.82
ShortGPT (25%) 33.85 48.24 61.56 64.96 0.83 0.00 21.44 16.37 32.22 53.07 33.69 0.86 30.59
POP (33.3%) 75.05 76.44 73.88 80.14 80.21 89.63 52.34 63.13 50.67 69.28 80.73 86.40 73.16
Gemma-3-12B-It
Full Model 71.46 81.96 74.35 78.07 73.62 82.32 55.90 59.62 46.78 54.64 67.02 11.08 63.07
Wanda (30%) 69.70 80.82 73.64 77.42 75.13 83.54 55.28 58.78 45.89 55.29 64.67 10.38 62.55
SliceGPT (25%) 22.95 34.12 54.14 55.93 1.67 0.00 10.83 4.18 25.56 5.23 2.59 0.24 18.12
ShortGPT (25%) 23.81 30.32 48.70 53.70 0.91 0.00 1.58 0.34 25.00 0.39 0.00 0.24 15.42
POP (33.3%) 71.37 81.96 74.59 79.76 73.16 81.10 57.33 59.11 46.78 55.42 63.71 11.08 62.95

主要结论:

  1. 现有结构化剪枝导致生成能力崩溃:SliceGPT 和 ShortGPT 在开放式生成任务(如 GSM8K, HumanEval)中几乎完全失效,精度接近零。
  2. POP 保持了极高的准确率 :即使剪枝比例更高(约 33%),POP 在所有模型和任务中均表现稳健,性能接近完整模型(如 Llama-3.1 上的 GSM8K 保持了 97 % 97\% 97% 的性能)。
  3. 超越非结构化剪枝的实用性 :POP 的准确率与非结构化剪枝方法 Wanda 相当,但由于是结构化移除层,它在标准硬件上更易实现加速。
    以下是该学术论文相关部分的中文摘要:

4.3 推理加速

我们在 NVIDIA A100 GPU 上通过测量首字生成时间(Time-to-First-Token, TTFT)评估了 POP 的推理加速效果。实验设置 Batch Size 为 8,文本输入长度从 32 到 2048 token,图像分辨率从 640 × 480 640 \times 480 640×480 到 2560 × 1440 2560 \times 1440 2560×1440。实验结果如表 2 所示。

非结构化剪枝的硬件限制 :虽然 Wanda 在下游任务上实现了高精度,但在使用密集内核的 A100 GPU 上没有产生实际的墙钟时间加速( 1.0 × 1.0 \times 1.0×)。这证实了非结构化剪枝虽然理论上减少了 FLOPs,但需要专门的硬件和稀疏内核才能实现效率收益。

文本输入序列长度的影响 :对于文本输入,POP 的效率收益高度依赖于输入序列长度。在短上下文长度(如 32 token)下,POP 的加速效果有限(Llama-3.1 为 1.22 × 1.22 \times 1.22×,Gemma-3 为 1.02 × 1.02 \times 1.02×)。这主要是由于边界处理策略导致的:短输入的前缀处理是内存受限(memory-bound)过程,受模型权重访问主导。由于处理最后一个输入 token 需要使用全量模型,POP 无法减少这些内存访问开销,从而限制了性能增益。

然而,随着序列长度增加,由剪枝模型处理的前 N − 1 N - 1 N−1 个 token 的计算成本成为 TTFT 的主导因素。因此,POP 表现出显著加速。在输入长度为 2048 时,POP 在 Llama-3.1 上实现了 1.36 × 1.36 \times 1.36× 的加速,在 Gemma-3 上实现了 1.37 × 1.37 \times 1.37× 的加速,优于 SliceGPT 和 ShortGPT。

多模态任务效率 :对于视觉输入,POP 提供了 1.16 × 1.16 \times 1.16× 到 1.19 × 1.19 \times 1.19× 的加速,在所有图像分辨率下均持续超越 SliceGPT 和 ShortGPT,同时提供更好的精度。

总体而言,实验结果验证了 POP 是一种实用的"即插即用"加速方案,无需模型重训或专用硬件,在预填充延迟至关重要的长上下文和高分辨率多模态处理中具有显著优势。

表 2:不同模型和输入长度下的预填充加速对比。所有实验均在 Batch Size 为 8 的条件下进行。数值代表相对于全量模型( 1.0 × 1.0 \times 1.0×)的加速比。

|--------------------------|-----------|------------|-------------|-------------|
| Llama-3.1-8B-Instruct |||||
| Method \\ Input Length | 32 | 128 | 512 | 2048 |
| Wanda | 1.00 | 1.00 | 1.00 | 1.00 |
| SliceGPT | 1.22 | 1.31 | 1.29 | 1.31 |
| ShortGPT | 1.30 | 1.29 | 1.31 | 1.30 |
| POP | 1.22 | 1.27 | 1.34 | 1.36 |
| Gemma-3-12B-It |||||
| Method \\ Input Length | 32 | 128 | 512 | 2048 |
| Wanda | 1.00 | 1.00 | 1.00 | 1.00 |
| SliceGPT | 1.10 | 1.29 | 1.27 | 1.29 |
| ShortGPT | 1.25 | 1.29 | 1.31 | 1.31 |
| POP | 1.02 | 1.27 | 1.34 | 1.37 |
| Qwen3-VL-8B-Instruct |||||
| Method \\ Resolution | 640 × 480 | 1280 × 720 | 1920 × 1080 | 2560 × 1440 |
| Wanda | 1.00 | 1.00 | 1.00 | 1.00 |
| SliceGPT | 1.14 | 1.16 | 1.15 | 1.14 |
| ShortGPT | 1.18 | 1.17 | 1.15 | 1.13 |
| POP | 1.19 | 1.19 | 1.18 | 1.16 |

4.4 消融实验

我们使用 Qwen3-VL 进行了消融实验,以验证 POP 的设计选择和参数敏感性。

设计选择的有效性:我们验证了三个关键组件的必要性:(1)针对深层进行剪枝;(2)独立的 KV 投影;(3)边界处理。结果如表 3 所示。

剪枝率的敏感性 :我们研究了剪枝率从 20 % 20\% 20% 到 60 % 60\% 60% 之间的权衡。结果如表 4 所示。在较低剪枝率(20%-25%)下,模型精度甚至略高于全量模型,这可能是轻微剪枝起到了正则化作用,过滤了深层的噪声。我们的默认比例 33% 在精度损失极小的情况下实现了显著加速( 1.37 × 1.37 \times 1.37×)。当剪枝率超过 50 % 50\% 50% 时,性能会大幅下降,表明过度剪枝会损害模型编码复杂上下文的能力。

表 3:设计选择的消融实验。我们在 Qwen3-VL-8B-Instruct 上对比了不同的层选择策略和组件移除情况。"w/o Indep. KV"表示移除剪枝层的独立 KV 投影。"w/o Boundary"表示移除最后一个输入 token 的边界处理。

|--------------------------|-------|----------|
| Method Variants | GSM8K | HotpotQA |
| Full Model | 81.50 | 65.49 |
| POP | 80.21 | 63.13 |
| Layer Selection Strategy | | |
| Shallow Pruning | 0.15 | 0.00 |
| Interleaved Pruning | 56.48 | 6.81 |
| Component Necessity | | |
| w/o Indep. KV Proj. | 2.05 | 1.18 |
| w/o Boundary Handling | 77.33 | 11.45 |

表 4:剪枝率的影响。Qwen3-VL-8B-Instruct 在不同剪枝率下的性能与加速权衡。

|-----------------|---------|-------|----------|
| Pruning Ratio | Speedup | GSM8K | HotpotQA |
| 0% (Full Model) | 1.00× | 81.50 | 65.49 |
| 20% | 1.19× | 83.09 | 65.46 |
| 25% | 1.25× | 82.34 | 65.81 |
| 33% (Default) | 1.37× | 80.21 | 63.13 |
| 40% | 1.46× | 80.82 | 61.69 |
| 50% | 1.67× | 78.54 | 34.69 |
| 60% | 1.96× | 38.51 | 5.45 |

5 相关工作

  • Token 剪枝与压缩:如 LLMLingua、PyramidInfer 和 FastV 等方法,通过减少序列长度来加速推理。POP 可以与这些方法结合使用。
  • 稀疏注意力:针对计算受限的预填充阶段(如 MInference)或内存受限的解码阶段(如 Quest)优化注意力机制。POP 与这些方法在提升效率方面具有互补性。

6 结论

我们发现并利用了预填充和解码阶段对模型剪枝的不对称敏感性。分析表明,虽然深层对于生成(解码)必不可少,但对上下文编码(预填充)贡献较小。基于此,我们提出了 POP,通过在预填充阶段剪枝深层而在解码阶段保留全量模型来加速推理。POP 实现了高达 1.37 × 1.37 \times 1.37× 的预填充加速,同时保持了与全量模型相当的精度。

7 局限性

首先,POP 在解码阶段仍需加载全量模型权重,因此无法减少峰值显存(VRAM)占用,更适合计算受限而非容量受限的场景。其次,目前的实现基于 Transformers 库的单体推理管线,未来可进一步集成到预填充与解码分离的分布式推理系统中。

Original Abstract: Large Language Models (LLMs) and Vision-Language Models (VLMs) have demonstrated remarkable capabilities. However, their deployment is hindered by significant computational costs. Existing structured pruning methods, while hardware-efficient, often suffer from significant accuracy degradation. In this paper, we argue that this failure stems from a stage-agnostic pruning approach that overlooks the asymmetric roles between the prefill and decode stages. By introducing a virtual gate mechanism, our importance analysis reveals that deep layers are critical for next-token prediction (decode) but largely redundant for context encoding (prefill). Leveraging this insight, we propose Prefill-Only Pruning (POP), a stage-aware inference strategy that safely omits deep layers during the computationally intensive prefill stage while retaining the full model for the sensitive decode stage. To enable the transition between stages, we introduce independent Key-Value (KV) projections to maintain cache integrity, and a boundary handling strategy to ensure the accuracy of the first generated token. Extensive experiments on Llama-3.1, Qwen3-VL, and Gemma-3 across diverse modalities demonstrate that POP achieves up to 1.37 × \times × speedup in prefill latency with minimal performance loss, effectively overcoming the accuracy-efficiency trade-off limitations of existing structured pruning methods.

PDF Link: 2602.03295v1

部分平台可能图片显示异常,请以我的博客内容为准

相关推荐
近津薪荼2 小时前
dfs专题——二叉树的深搜3(二叉树剪枝)
c++·学习·算法·深度优先
拼好饭和她皆失2 小时前
数学知识:约数的详细解析
算法·数论
伯明翰java2 小时前
排序算法(1)
算法·排序算法
啊阿狸不会拉杆2 小时前
《机器学习导论》第 2 章-监督学习
数据结构·人工智能·python·学习·算法·机器学习·监督学习
乌萨奇也要立志学C++2 小时前
【洛谷】记忆化搜索 原理剖析与经典例题详解
算法·深度优先
Code920072 小时前
洛谷P3514 [POI 2011] LIZ-Lollipop(思维题)
算法
colus_SEU2 小时前
【论文精读】Instance-Dependent Partial Label Learning
人工智能·深度学习·机器学习·pll·部分标签学习
m0_706653232 小时前
C++中的解释器模式
开发语言·c++·算法
We་ct2 小时前
LeetCode 202. 快乐数:题解+思路拆解
前端·算法·leetcode·typescript