TokenSkip:让大模型「跳步骤」推理,速度翻倍

一、问题:为什么大模型「想太多」会变慢?

想象一下,你让ChatGPT解一道数学题,它会在脑海里「自言自语」:

"小明有5个苹果,先买了3个,现在有8个;然后吃掉2个,剩下6个。所以答案是6。"

这个过程叫思维链(CoT)------模型通过一步步推导得出答案。但问题来了:

  • 步骤越长,速度越慢 : 传统CoT生成的延迟与序列长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T呈线性关系: <math xmlns="http://www.w3.org/1998/Math/MathML"> Latency ∝ T ⋅ ( L attn + L FFN ) \text{Latency} \propto T \cdot (L_{\text{attn}} + L_{\text{FFN}}) </math>Latency∝T⋅(Lattn+LFFN),即生成100个token(词)比生成10个token慢10倍!
  • 废话太多:像"首先""然后""所以"这些词,对解题帮助不大。

这就好比:你写作文时,如果必须把"嗯...这里应该...对吧?"之类的内心活动全写出来,交卷时间肯定来不及


二、核心思想:让模型学会「划重点」

TokenSkip的灵感很简单:不是所有token都值得生成

1. token的重要性天差地别

  • 学霸token :数字(53)、公式(5+3=8)、答案(6)。
  • 学渣token:连接词("所以""然后")、重复描述("我们仔细计算一下")。

举个栗子🌰

原始CoT:

"首先,小明有5个苹果。接着他买了3个,所以现在总共有5+3=8个。然后他吃掉2个,最后剩下6个。"

关键token5, 3, 5+3=8, 2, 8-2=6, 答案6
冗余token首先, 接着, 所以, 然后, 最后

2. TokenSkip的终极目标
保留学霸token,跳过学渣token!从而让模型生成的CoT更精炼,推理速度更快,同时保持正确率。


三、实现方法:三步让模型学会「跳步骤」

Step 1:给每个token「打分」------谁是学霸?

用一个小型模型LLMLingua-2当「判卷老师」,给CoT中的每个token打分(重要性分数)。

  • 训练方式:使用GPT-4为token标注一个二分类标签(重要/不重要),以此训练出一个打分模型,模型输出的概率就是token的得分

  • 如何打分

    • 高分token:对答案影响大(如数字、公式)。
    • 低分token:可跳过(如连接词)。
Step 2:动态压缩------按需「删废话」

用户指定一个压缩比例γ(比如γ=0.6,保留60%的token),TokenSkip会:

  1. 按分数从高到低排序所有token。
  2. 保留前60%的高分token,剩下的直接跳过。

压缩过程演示

  • 原始CoT(10个token)
    [首先][小明][有][5][苹果][然后][买][3][所以][总数8]
  • 压缩后(6个token,γ=0.6)
    [小明][5][苹果][买][3][总数8]

为什么有效

  • 删掉了首先然后所以等低分token。
  • 保留了关键数字和动作(买3)。
Step 3:训练模型------教会它「走捷径」

我们的最终目的是要让LLM学会自动跳token,而现在我们需要使用压缩后的COT来微调(Fine-tuning) 模型。但全量微调成本太高,TokenSkip用了LoRA

  1. 数据准备

    • 收集大量原始CoT(比如数学题的解题过程)。
    • 对原始训练集 <math xmlns="http://www.w3.org/1998/Math/MathML"> D \mathcal{D} </math>D中的每个样本 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( x , c , a ) (\mathbf{x}, \mathbf{c}, \mathbf{a}) </math>(x,c,a),用step1-2生成多组压缩样本 <math xmlns="http://www.w3.org/1998/Math/MathML"> { ( x , γ , c ~ , a ) } \{(\mathbf{x}, \gamma, \tilde{\mathbf{c}}, \mathbf{a})\} </math>{(x,γ,c~,a)},其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ从预设集合 <math xmlns="http://www.w3.org/1998/Math/MathML"> { 0.5 , 0.6 , . . . , 1.0 } \{0.5, 0.6, ..., 1.0\} </math>{0.5,0.6,...,1.0}随机采样。
  2. 输入格式

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Input = [ x ; EOS ; γ ; EOS ; C o m p r e s s e d C o T ; A n s w e r ] \text{Input} = [\mathbf{x}; \text{EOS}; \gamma; \text{EOS};Compressed CoT; Answer] </math>Input=[x;EOS;γ;EOS;CompressedCoT;Answer]

即[问题、分隔符、压缩比率、分隔符、压缩后的思维链、答案], <math xmlns="http://www.w3.org/1998/Math/MathML"> EOS \text{EOS} </math>EOS为序列结束符, <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ以数值形式嵌入[问题] [EOS] 压缩比例0.6 [EOS]

  1. 损失函数
    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L = − ∑ t = 1 ∣ c ~ ∣ + ∣ a ∣ log ⁡ P ( y t ∣ x , γ , y < t ; θ ) \mathcal{L} = -\sum_{t=1}^{|\tilde{\mathbf{c}}|+|\mathbf{a}|} \log P(y_t \mid \mathbf{x}, \gamma, \mathbf{y}_{<t}; \theta) </math>L=−t=1∑∣c~∣+∣a∣logP(yt∣x,γ,y<t;θ)

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> y = [ c ~ ; a ] \mathbf{y} = [\tilde{\mathbf{c}}; \mathbf{a}] </math>y=[c~;a] 4. LoRA微调

采用LoRA(Low-Rank Adaptation),仅更新权重矩阵的低秩增量:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> W ′ = W + Δ W = W + B ⋅ A , B ∈ R d × r , A ∈ R r × k W' = W + \Delta W = W + B \cdot A, \quad B \in \mathbb{R}^{d \times r}, A \in \mathbb{R}^{r \times k} </math>W′=W+ΔW=W+B⋅A,B∈Rd×r,A∈Rr×k

超参数设置:秩 <math xmlns="http://www.w3.org/1998/Math/MathML"> r = 8 r=8 </math>r=8,缩放因子 <math xmlns="http://www.w3.org/1998/Math/MathML"> α = 16 \alpha=16 </math>α=16,仅调整0.2%的模型参数。

训练成本

  • 7B模型:2小时(2块3090显卡)
  • 14B模型:2.5小时
    (相当于刷两集《繁花》的时间~)

四、效果实测:速度翻倍,答案几乎全对

1. 评估指标
  • 压缩效率
    • 实际压缩比 : <math xmlns="http://www.w3.org/1998/Math/MathML"> Actual Ratio = ∣ c ~ ∣ ∣ c ∣ \text{Actual Ratio} = \frac{|\tilde{\mathbf{c}}|}{|\mathbf{c}|} </math>Actual Ratio=∣c∣∣c~∣
    • 加速比 : <math xmlns="http://www.w3.org/1998/Math/MathML"> Speedup = T original T compressed \text{Speedup} = \frac{T_{\text{original}}}{T_{\text{compressed}}} </math>Speedup=TcompressedToriginal
  • 性能保留度
    • 准确率相对下降 : <math xmlns="http://www.w3.org/1998/Math/MathML"> Δ Acc = Acc original − Acc compressed \Delta \text{Acc} = \text{Acc}{\text{original}} - \text{Acc}{\text{compressed}} </math>ΔAcc=Accoriginal−Acccompressed
2. 理论分析

命题1 (压缩稳定性):

若令牌重要性度量 <math xmlns="http://www.w3.org/1998/Math/MathML"> I ( c i ) I(c_i) </math>I(ci)与答案正确性强相关,则存在压缩比 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ使得:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Δ Acc ≤ ϵ 且 Speedup ≈ 1 1 − γ \Delta \text{Acc} \leq \epsilon \quad \text{且} \quad \text{Speedup} \approx \frac{1}{1-\gamma} </math>ΔAcc≤ϵ且Speedup≈1−γ1

3. 实验结果
模型 数据集 γ 压缩比 ΔAcc (%) Speedup
Qwen2.5-14B GSM8K 0.6 40% 0.4 1.67x
LLaMA-3.1-8B MATH-500 0.7 30% 3.9 1.43x
  • 幂律现象 :模型规模越大, <math xmlns="http://www.w3.org/1998/Math/MathML"> Δ Acc \Delta \text{Acc} </math>ΔAcc对 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ的敏感度越低(见图5)。
  • 注意力稀疏性:压缩后的CoT序列中,注意力权重更集中于关键令牌(可视化见论文图2)。

六、局限性与未来

  • 未覆盖更大模型:比如Qwen-72B,推测效果会更好。
  • 数学符号优化不足:公式压缩还有提升空间。
  • 极端压缩会翻车:比如压缩70%,可能漏掉关键计算。

论文地址2502.12067v1.pdf
代码开源github.com/hemingkx/To...

如果有哪里没看懂,欢迎评论区提问!👇

相关推荐
梓羽玩Python5 分钟前
这款开源AI神器,让视频创作进入"一句话生成大片"时代。
人工智能·开源·github
果冻人工智能6 分钟前
利用GPT-4.5打造一个AI猫咪表情包生成器,并从中获利润
人工智能
zozowind10 分钟前
OpenManus系列(5):3月11日更新分支分拆,MCP闪亮登场
人工智能
Baihai_IDP12 分钟前
为什么说JSON不一定是LLM结构化输出的最佳选择?
人工智能·llm·aigc
cdut_suye14 分钟前
全面剖析 Linux 进程管理与 PCB 机制
java·linux·运维·服务器·c++·人工智能·python
新加坡内哥谈技术14 分钟前
CoreWeave:从“微软专供”到OpenAI的座上宾
人工智能
@心都20 分钟前
机器学习数学基础:45.多重响应分析
人工智能·机器学习
进阶的小蜉蝣20 分钟前
[machine learning] DP(Data Parallel) vs DDP(Distributed Data Parallel)
人工智能·机器学习
YuhsiHu36 分钟前
【论文精读】ACE-Zero
人工智能·深度学习·计算机视觉·3d·机器人
声网39 分钟前
Tavus 发布对话轮次控制模型:能理解对话节奏和意图;百度推出 AI 情感陪伴应用月匣,整合 MiniMax 等模型丨日报
人工智能