一、问题:为什么大模型「想太多」会变慢?
想象一下,你让ChatGPT解一道数学题,它会在脑海里「自言自语」:
"小明有5个苹果,先买了3个,现在有8个;然后吃掉2个,剩下6个。所以答案是6。"
这个过程叫思维链(CoT)------模型通过一步步推导得出答案。但问题来了:
- 步骤越长,速度越慢 : 传统CoT生成的延迟与序列长度 T呈线性关系: Latency∝T⋅(Lattn+LFFN),即生成100个token(词)比生成10个token慢10倍!
- 废话太多:像"首先""然后""所以"这些词,对解题帮助不大。
这就好比:你写作文时,如果必须把"嗯...这里应该...对吧?"之类的内心活动全写出来,交卷时间肯定来不及
二、核心思想:让模型学会「划重点」
TokenSkip的灵感很简单:不是所有token都值得生成!
1. token的重要性天差地别
- 学霸token :数字(
5、3)、公式(5+3=8)、答案(6)。 - 学渣token:连接词("所以""然后")、重复描述("我们仔细计算一下")。
举个栗子🌰 :
原始CoT:
"首先,小明有5个苹果。接着他买了3个,所以现在总共有5+3=8个。然后他吃掉2个,最后剩下6个。"
关键token :5, 3, 5+3=8, 2, 8-2=6, 答案6
冗余token :首先, 接着, 所以, 然后, 最后
2. TokenSkip的终极目标 :
保留学霸token,跳过学渣token!从而让模型生成的CoT更精炼,推理速度更快,同时保持正确率。
三、实现方法:三步让模型学会「跳步骤」
Step 1:给每个token「打分」------谁是学霸?
用一个小型模型LLMLingua-2当「判卷老师」,给CoT中的每个token打分(重要性分数)。
-
训练方式:使用GPT-4为token标注一个二分类标签(重要/不重要),以此训练出一个打分模型,模型输出的概率就是token的得分
-
如何打分:
- 高分token:对答案影响大(如数字、公式)。
- 低分token:可跳过(如连接词)。
Step 2:动态压缩------按需「删废话」
用户指定一个压缩比例γ(比如γ=0.6,保留60%的token),TokenSkip会:
- 按分数从高到低排序所有token。
- 保留前60%的高分token,剩下的直接跳过。
压缩过程演示:
- 原始CoT(10个token) :
[首先][小明][有][5][苹果][然后][买][3][所以][总数8] - 压缩后(6个token,γ=0.6) :
[小明][5][苹果][买][3][总数8]
为什么有效:
- 删掉了
首先、然后、所以等低分token。 - 保留了关键数字和动作(
买3)。
Step 3:训练模型------教会它「走捷径」
我们的最终目的是要让LLM学会自动跳token,而现在我们需要使用压缩后的COT来微调(Fine-tuning) 模型。但全量微调成本太高,TokenSkip用了LoRA:
-
数据准备:
- 收集大量原始CoT(比如数学题的解题过程)。
- 对原始训练集 D中的每个样本 (x,c,a),用step1-2生成多组压缩样本 {(x,γ,c~,a)},其中 γ从预设集合 {0.5,0.6,...,1.0}随机采样。
-
输入格式:
Input=x;EOS;γ;EOS;CompressedCoT;Answer
即问题、分隔符、压缩比率、分隔符、压缩后的思维链、答案, EOS为序列结束符, γ以数值形式嵌入问题 EOS 压缩比例0.6 EOS
- 损失函数:
L=−t=1∑∣c~∣+∣a∣logP(yt∣x,γ,y<t;θ)
其中 y=c\~;a 4. LoRA微调 :
采用LoRA(Low-Rank Adaptation),仅更新权重矩阵的低秩增量:
W′=W+ΔW=W+B⋅A,B∈Rd×r,A∈Rr×k
超参数设置:秩 r=8,缩放因子 α=16,仅调整0.2%的模型参数。
训练成本:
- 7B模型:2小时(2块3090显卡)
- 14B模型:2.5小时
(相当于刷两集《繁花》的时间~)
四、效果实测:速度翻倍,答案几乎全对
1. 评估指标
- 压缩效率 :
- 实际压缩比 : Actual Ratio=∣c∣∣c~∣
- 加速比 : Speedup=TcompressedToriginal
- 性能保留度 :
- 准确率相对下降 : ΔAcc=Accoriginal−Acccompressed
2. 理论分析
命题1 (压缩稳定性):
若令牌重要性度量 I(ci)与答案正确性强相关,则存在压缩比 γ使得:
ΔAcc≤ϵ且Speedup≈1−γ1
3. 实验结果
| 模型 | 数据集 | γ | 压缩比 | ΔAcc (%) | Speedup |
|---|---|---|---|---|---|
| Qwen2.5-14B | GSM8K | 0.6 | 40% | 0.4 | 1.67x |
| LLaMA-3.1-8B | MATH-500 | 0.7 | 30% | 3.9 | 1.43x |
- 幂律现象 :模型规模越大, ΔAcc对 γ的敏感度越低(见图5)。
- 注意力稀疏性:压缩后的CoT序列中,注意力权重更集中于关键令牌(可视化见论文图2)。
六、局限性与未来
- 未覆盖更大模型:比如Qwen-72B,推测效果会更好。
- 数学符号优化不足:公式压缩还有提升空间。
- 极端压缩会翻车:比如压缩70%,可能漏掉关键计算。
论文地址 :2502.12067v1.pdf
代码开源 :github.com/hemingkx/To...
如果有哪里没看懂,欢迎评论区提问!👇