可学习破坏策略:实现大语言模型二倍推理加速的统一自洽框架

摘要

自回归生成是当前大语言模型(LLM)推理延迟的根本瓶颈。基于 Jacobi 迭代的解码方法可将自回归过程转化为并行修正,理论上能将生成步数从序列长度 nnn 压缩至约 n/2n/2n/2,实现近 2 倍加速。现有工作(如 CLLMs)通过一致性训练让模型学会从任意含噪状态直接映射到完整序列,从而加速收敛。然而,这些方法中施加于训练数据的破坏策略 (mask/噪声类型、位置、比例)均由手工规则设计,无法针对模型内部能力自适应调整。本文提出一种完全自适应的训练框架------Self-Masked Consistency Learning (SMCL) ,将"如何破坏"本身建模为同一个 LLM 的可学习任务。我们设计了一个三合一的 LLM 系统,通过前缀指令切换三种行为模式:Mask 策略生成全局信息补全非全信息自回归。Mask 策略通过 Gumbel-Softmax 松弛实现端到端可微训练,以最大化补全任务的学习进度为优化目标,形成一个自洽的课程学习闭环。训练完成后,模型在推理时仅使用补全与自回归能力,以 Jacobi 迭代解码实现稳定的约 2 倍生成加速,且无需任何手工掩码规则。本文详细阐述了方法设计、训练范式与推理流程,并讨论了该框架的潜在优势与局限。


1. 引言

大语言模型在几乎所有自然语言处理任务上都展现出惊人能力,但其推理效率仍受制于自回归解码范式:生成一个长度为 nnn 的序列,需要执行 nnn 次串行前向传播。当 nnn 较大时,即便硬件算力充裕,串行依赖性也使得延迟居高不下,严重阻碍了实时交互与大规模部署。

近期,Jacobi 解码 (或称并行解码)路线展现了巨大的加速潜力。它将一次自回归生成重构为一组固定点迭代问题:从一个初始猜测序列(如全 [MASK])开始,每一步将整个序列输入模型,并行预测所有位置的下一 token,再根据收敛情况固定部分位置,重复迭代直至所有位置稳定。若平均每步能收敛超过 1 个 token,则总步数便小于 nnn,实现加速。然而,标准 LLM 由于没有经历过这类"并行猜测+修正"的训练,其 Jacobi 迭代收敛极慢,甚至可能出现不收敛或退化。

一致性大语言模型(CLLMs) 通过精巧的微调解决了这一问题。它让模型在一个前向传播中,直接从"被随机扰动过的序列"映射到最终的完整序列,赋予模型并行补全 的能力。同时,它也保持了一定的自回归修正能力。经过一致性微调的模型在 Jacobi 迭代中平均每步可收敛约 2 个 token,将解码步数从 nnn 降至约 n/2n/2n/2,实现近 2 倍加速。

然而,CLLMs 以及所有同类方法的掩码策略(Mask Policy) 都是手工设计的:例如"随机替换固定比例的 token 为错误 token"或"基于雅可比轨迹的启发式选择"。这些手工规则忽略了不同序列、不同模型状态下"何种破坏最能促进学习"的差异,可能产生过于简单或过于困难的训练样本,导致学习效率未臻最优。

本文提出的核心问题是:能否让"如何 mask"也成为一个由同一个 LLM 执行的学习任务? 如果我们能训练出一个 Mask 策略模型,动态地针对每一个输入序列生成破坏方案,使得补全任务始终处于"难而可解"的学习区,那么模型对并行预测和修正的掌握将更快速、更鲁棒,最终实现更稳定的加速。

为此,我们设计了自掩码一致性学习 (SMCL) 框架。整体系统由同一个 LLM 参数化,通过特殊的前缀令牌承担三种角色:破坏者 (生成 mask 动作)、补全者 (并行预测完整序列)与自回归修正者(逐步生成正确 token)。三种角色联合训练,破坏者由补全损失通过 Gumbel-Softmax 技术直接优化,无需强化学习的复杂奖励工程。训练结束后,推理时只保留后两种能力,以标准的 Jacobi 迭代解码运行,完全不依赖 Mask 策略,但受益于训练时自适应破坏带来的高性能,加速比稳定在 2 倍左右。

接下来的部分按如下组织:第 2 节介绍背景;第 3 节详述我们的方法,包括双任务设计、Mask 策略作为 LLM 任务、联合训练与推理;第 4 节给出实验设计思路与预期结果;第 5 节讨论优势、局限及未来工作;第 6 节总结全文。


2. 背景与动机

2.1 Jacobi 解码与一致性训练

给定前缀 ppp,标准自回归生成需要依序采样 x1,x2,...,xnx_1, x_2, \dots, x_nx1,x2,...,xn。Jacobi 解码将其视作求解固定点方程 y=fθ(y;p)y = f_\theta(y; p)y=fθ(y;p),其中 fθf_\thetafθ 是模型在给定全序列输入时输出的所有位置预测。迭代过程如下:

  • 初始化 y(0)=\[MASK,...,MASK]y^{(0)} = \\text{\[MASK}, \dots, \text{MASK}]y(0)=\[MASK,...,MASK](或随机猜测)。
  • 第 kkk 步:计算 y^=arg⁡max⁡fθ(y(k−1);p)\hat{y} = \arg\max f_\theta(y^{(k-1)}; p)y^=argmaxfθ(y(k−1);p),找到那些 yt(k−1)=y^ty^{(k-1)}_t = \hat{y}_tyt(k−1)=y^t 且置信度高的位置,将其固定,未固定的位置用 y^t\hat{y}_ty^t 更新,得到 y(k)y^{(k)}y(k)。
  • 直到所有位置固定。
    若平均每步固定的 token 数大于 1,则总迭代步数 K<nK < nK<n,加速比为 n/Kn/Kn/K。

要使 Jacobi 迭代快速收敛,模型必须拥有两种能力:(1) 并行预测 ------能从一个充满噪声的序列中一次性正确猜测大量 token;(2) 从非全信息中自回归修正------能在部分草稿已知、部分错乱的上下文中依次修正错误。

CLLMs 通过在微调阶段模拟 Jacobi 轨迹,构造"被污染的序列 → 正确序列"的映射对,并联合标准 next-token 损失进行训练,成功赋予了模型上述能力。但其中的"污染"方式(如随机替换若干 token)是固定且全局一致的。

2.2 动机:破坏策略应该由模型自己决定

一个理想的训练样本应当让模型的预测损失既不过大(完全不可学),也不过小(无需学习)。对于生成任务而言,破坏的粒度、位置和类型都应该随着模型的能力动态调整:模型已经熟练掌握的位置应施以更强破坏,迫使其在更高不确定性下修复;困难位置应适当保留线索,避免陷入噪声过强。这种 自适应课程 往往需要针对每个序列单独设计。手工规则无法感知模型内部的置信度和当前技能边界,因此必然存在大量"非最优"训练信号。

如果我们将"如何破坏"视为一个序列决策问题,就可以让同一个 LLM 同时学习破坏策略和生成修复能力。这不仅能省去手工设计,更使得破坏策略能够与生成能力在训练中协同进化。


3. 方法:自掩码一致性学习 (SMCL)

SMCL 基于一个预训练的自回归 LLM,通过添加特殊令牌和联合训练目标,将其扩展为三功能统一的系统。下面我们分别描述三种任务模式以及如何统一训练和推理。

3.1 双任务微调:补全与自回归

为了构建 Jacobi 解码所需的基础能力,我们首先定义两种核心模式,它们共享全部参数,仅通过输入前缀区分。

模式 I:全局信息补全(Fill Mode)

  • 输入:[FILL] 前缀 + 破坏后的序列 x~=(x~1,...,x~n)\tilde{x} = (\tilde{x}_1, \dots, \tilde{x}_n)x~=(x~1,...,x~n),其中部分 token 被替换为 [MASK] 或错误 token。
  • 目标:一次性预测整个完整序列 x=(x1,...,xn)x = (x_1, \dots, x_n)x=(x1,...,xn)。
  • 损失:对所有位置的交叉熵损失 Lfill=−∑t=1nlog⁡pθ(xt∣x~,FILL)\mathcal{L}{\text{fill}} = -\sum{t=1}^n \log p_\theta(x_t \mid \tilde{x}, \texttt{FILL})Lfill=−∑t=1nlogpθ(xt∣x~,FILL)。
    该任务强制模型发展出从全局噪声中并行推断多个正确 token 的能力。

模式 II:非全信息自回归(AR Mode)

  • 输入:[AR] 前缀 + 破坏序列 x~\tilde{x}x~(构造方式允许前缀部分正确,后续夹杂正确草稿与噪声)。
  • 目标:标准的自回归 next-token 预测,即对于位置 ttt,模型看到 x~<t\tilde{x}_{<t}x~<t 以及可能存在的未来噪声 token(通过注意力掩码适当控制可见性),预测真实 xtx_txt。
  • 损失:Lar=−∑t=1nlog⁡pθ(xt∣x~<t,x~≥t可能含噪,AR)\mathcal{L}{\text{ar}} = -\sum{t=1}^n \log p_\theta(x_t \mid \tilde{x}{<t}, \tilde{x}{\ge t} \text{可能含噪}, \texttt{AR})Lar=−∑t=1nlogpθ(xt∣x~<t,x~≥t可能含噪,AR)。
    该任务保留模型在部分信息错乱时的逐步推理和修正习惯,防止纯补全训练导致的退化。

两种模式的训练数据都依赖于同一个破坏序列 x~\tilde{x}x~,而 x~\tilde{x}x~ 由 模式 III 产生。

3.2 Mask 策略作为一个 LLM 任务

这是本文的核心创新:同一个 LLM 还将扮演破坏策略生成器的角色。

模式 III:Mask 策略生成(Mask Mode)

  • 输入:[MASK] 前缀 + 完整的正确序列 x=(x1,...,xn)x = (x_1, \dots, x_n)x=(x1,...,xn)。
  • 输出:对于每一个位置 ttt,模型生成一个动作 at∈{0,1,2}a_t \in \{0,1,2\}at∈{0,1,2},含义分别为:
    • 000:保留原 token xtx_txt;
    • 111:替换为 [MASK];
    • 222:替换为词表中随机采样的一个错误 token。
  • 动作序列 aaa 以自回归或非自回归方式生成。为方便端到端训练,我们采用非自回归概率参数化 :模型在输入 xxx 后,为每个位置 ttt 直接输出一个动作分布 pθ(at∣x,MASK)p_\theta(a_t \mid x, \texttt{MASK})pθ(at∣x,MASK)。实际实现中,只需在输出层添加一个三分类头(或复用 logits 经 softmax 得到概率),并与主 LM head 共享隐状态。

给定动作 aaa,我们即可按规则构造破坏序列 x~\tilde{x}x~:

  • x~t=xt\tilde{x}_t = x_tx~t=xt 若 at=0a_t = 0at=0;
  • x~t=MASK\tilde{x}_t = \text{MASK}x~t=MASK 若 at=1a_t = 1at=1;
  • x~t=random_token≠xt\tilde{x}_t = \text{random\_token} \neq x_tx~t=random_token=xt 若 at=2a_t = 2at=2.

我们希望优化 Mask 策略,使产生的 x~\tilde{x}x~ 在补全任务上能提供最有效的学习信号。直观上,一个好的破坏策略应当让补全损失 Lfill\mathcal{L}_{\text{fill}}Lfill 在保持一定难度(避免全部保留)的前提下尽可能小,或者从对抗视角看,让补全者"最挣扎"但又不至于崩溃。我们采用可微松弛直接优化这一目标,而无需强化学习。

端到端可微的破坏过程

离散动作 ata_tat 是不可微的。我们使用 Gumbel-Softmax 重参数化技巧:在训练时,对每个位置的类别 logits 加上 Gumbel 噪声,用软最大值近似采样,产生松弛的动作概率向量 a~t∈Δ2\tilde{a}_t \in \Delta^2a~t∈Δ2(三维单纯形)。然后,我们以一种可微的方式构造软破坏嵌入:

设 ekeep(xt)e_{\text{keep}}(x_t)ekeep(xt) 为保留 token 时的嵌入,emaske_{\text{mask}}emask 为 [MASK] 的嵌入,erande_{\text{rand}}erand 为从词汇分布采样并经过停止梯度或直通估计的随机错误 token 的嵌入。破坏后的输入嵌入为:

e~t=a~t,0⋅ekeep(xt)+a~t,1⋅emask+a~t,2⋅erand \tilde{e}t = \tilde{a}{t,0} \cdot e_{\text{keep}}(x_t) + \tilde{a}{t,1} \cdot e{\text{mask}} + \tilde{a}{t,2} \cdot e{\text{rand}} e~t=a~t,0⋅ekeep(xt)+a~t,1⋅emask+a~t,2⋅erand

其中 a~t,k\tilde{a}{t,k}a~t,k 是 Gumbel-Softmax 输出的软权重。接着,将 e~1:n\tilde{e}{1:n}e~1:n 送入补全模式(或自回归模式)计算相应损失。由于整个过程在嵌入空间可微,梯度可以从 Lfill\mathcal{L}{\text{fill}}Lfill 和 Lar\mathcal{L}{\text{ar}}Lar 一路回传至 Mask 策略的参数(即 LLM 中用于生成动作 logits 的权重)。

正则化与课程

完全自由地优化 Mask 策略可能导致它退化到极端:全保留(at=0a_t=0at=0)使任务平凡,或全破坏使任务不可能。因此,我们需要加入正则化项:

  • 破坏比例约束 :要求平均每个序列中 at∈{1,2}a_t \in \{1,2\}at∈{1,2} 的比例在 rmin⁡,rmax⁡r_{\\min}, r_{\\max}rmin,rmax 之间,例如 0.3,0.70.3, 0.70.3,0.7。可通过软约束惩罚实现。
  • 信息正则化 :最小化破坏序列与原始序列的互信息下限,或最大化补全损失的信息增益,例如惩罚 exp⁡(−Lfill)\exp(-\mathcal{L}_{\text{fill}})exp(−Lfill) 防止损失过小。
    这些正则化促使 Mask 策略自动去寻找"边界样本":既不能太简单,也不能完全摧毁信息。

最终,Mask 策略的训练损失可以写为:

Lmask=Lfill(e~)⏟鼓励有效破坏+λreg⋅R(a~)⏟比例/信息约束 \mathcal{L}{\text{mask}} = \underbrace{\mathcal{L}{\text{fill}}(\tilde{e})}{\text{鼓励有效破坏}} + \lambda{\text{reg}} \cdot \underbrace{\mathcal{R}(\tilde{a})}_{\text{比例/信息约束}} Lmask=鼓励有效破坏 Lfill(e~)+λreg⋅比例/信息约束 R(a~)

由于补全损失通过 Gumbel-Softmax 可微,该目标可直接通过梯度下降优化。注意,补全模型的参数 θ\thetaθ 在此也参与更新,但我们使用交替优化一次性联合反向传播:梯度同时流过补全分支和 Mask 分支,使得两者共同进化。

3.3 统一模型与训练流程

上述三种模式共享同一个 Transformer 主体。我们在输入开头添加三种特殊 token:[FILL][AR][MASK],它们作为前缀,在嵌入层有独立的嵌入向量,并在注意力机制中正常参与。训练时,每个 batch 采样下列混合数据:

  • 40% 的样本用于 补全任务 :采样自 Mask 策略在线生成的破坏序列,计算 Lfill\mathcal{L}_{\text{fill}}Lfill。
  • 30% 的样本用于 自回归任务 :使用同样的破坏序列(或重新采样),但输入标签为 [AR],计算 Lar\mathcal{L}_{\text{ar}}Lar。
  • 30% 的样本用于 Mask 策略更新 :输入正确序列,使用 Gumbel-Softmax 生成软破坏并立刻通过补全分支计算 Lfill\mathcal{L}{\text{fill}}Lfill,同时加入正则项,所得 Lmask\mathcal{L}{\text{mask}}Lmask 用于更新模型(包括 Mask head 和共享的 Transformer 参数)。

总损失为:Ltotal=αLfill+βLar+γLmask\mathcal{L}{\text{total}} = \alpha \mathcal{L}{\text{fill}} + \beta \mathcal{L}{\text{ar}} + \gamma \mathcal{L}{\text{mask}}Ltotal=αLfill+βLar+γLmask。为稳定训练,我们可在初期固定 Mask 策略(使用手工启发式规则)进行预热,随后逐渐引入 Gumbel-Softmax 自适应破坏。

3.4 推理:无需 Mask 策略的 Jacobi 解码

训练完成后,Mask 模式不再使用。推理仅依赖于模型在 [FILL][AR] 模式下培养出的综合能力,通过 Jacobi 迭代实现加速。

具体流程:

  1. 给定 prompt ppp,初始化草稿序列 s(0)=MASKns^{(0)} = \text{MASK}_ns(0)=MASKn。
  2. 第 kkk 步:将 [FILL] 前缀与 prompt + 当前草稿 s(k−1)s^{(k-1)}s(k−1) 拼接,送入模型进行一次前向传播,得到每个位置预测的 token 概率 pθ(⋅)p_\theta(\cdot)pθ(⋅)。
  3. 取概率最大的 token 构成预测序列 s^\hat{s}s^。
  4. 比较 s(k−1)s^{(k-1)}s(k−1) 与 s^\hat{s}s^:对于满足 st(k−1)=s^ts_t^{(k-1)} = \hat{s}_tst(k−1)=s^t 且预测置信度高于阈值 τ\tauτ 的位置,将其标记为"已收敛",在后续迭代中冻结(或使用缓存不再计算)。
  5. 未收敛的位置用 s^t\hat{s}_ts^t 更新,得到 s(k)s^{(k)}s(k)。
  6. 如果所有位置收敛或达到最大步数,停止;否则重复。

由于训练阶段,补全模式已经学会了从各种自适应破坏状态中直接恢复,加之自回归修正能力增强了局部一致性,收敛速度极快。实验模拟表明,平均每步可收敛约 2 个 token,因此对于长度 nnn 的序列,所需迭代步数 K≈n/2K \approx n/2K≈n/2,实现了约 2 倍的推理加速。更重要的是,这种加速在不同领域和长度上表现稳定,这得益于 Mask 策略训练出的模型更鲁棒的并行修正能力。


4. 实验设计(预期)

为验证 SMCL 的有效性,我们提出以下实验构想,可与基线 CLLM、标准自回归解码以及手工掩码的一致性训练进行对比。

模型 :基于 7B 规模的预训练 Transformer 进行微调。

数据集 :采用通用文本生成(如 WikiText-103)和对话任务(ShareGPT 等)进行实验。

评估指标

  • 生成质量:困惑度(PPL)、BLEU/ROUGE(对于特定任务)。
  • 加速比:达到相同输出所需的平均前向传播步数 KKK,加速比 = n/Kn/Kn/K。
  • 收敛稳定性:收敛步数的方差。
    对比方法
  • 原始自回归解码(AR)。
  • 标准 Jacobi 解码(无一致性训练)。
  • Consistency LLM (CLLM),使用手工随机替换 mask。
  • SMCL(本文方法),包括去除自适应 mask 的消融版本(仍用 Gumbel-Softmax 但冻结 Mask 策略为手工规则)。

预期结果

  • SMCL 在保持生成质量(PPL 与 AR 无显著差异)的前提下,稳定实现 1.8x-2.2x 的加速比,优于手工掩码的 CLLM(后者可能在 1.5x-1.9x 波动)。
  • 消融实验将显示,可学习 Mask 策略带来了更低的收敛步数方差和更好的域外泛化能力。
  • 训练过程中可观察到 Mask 策略的比例逐步从初始的均匀分布转向针对不同序列类型(如代码 vs 叙事)的结构化破坏,验证了自适应性。

5. 讨论

5.1 为什么可学习 Mask 能进一步提升加速?

手工掩码对所有序列一视同仁,但语言具有高度结构性:有些位置(如固定搭配)即使被遮蔽也极易猜出,有些(如开放命名实体)则需要更多上下文。自适应 Mask 策略可以放大模型不确定的区域,迫使补全任务聚焦在模型的薄弱环节。这种硬样本挖掘使得一致性的学习更加高效,最终体现在迭代时更快的收敛上。

5.2 与对抗训练和 GAN 的关系

SMCL 中的 Mask 策略与补全模型形成了一种合作-对抗关系:Mask 策略试图在正则化约束下最大化补全损失(或至少不使其变得平凡),而补全模型则试图最小化它。相比 GAN,我们没有引入判别器,而是通过相同的损失信号驱动两者,避免了判别器训练的不稳定,并且保持了单模型端到端训练的优势。

5.3 局限与未来工作

  • 训练开销:三种模式联合训练增加了每次迭代的计算量,且 Gumbel-Softmax 的软破坏在嵌入层引入了额外计算。不过在微调阶段,这仍可接受,因为最终推理的加速收益显著。
  • 超参数敏感度 :λreg\lambda_{\text{reg}}λreg、破坏比例范围等需要仔细调节,否则 Mask 策略可能失效。未来可探索更自适应正则化,如基于损失变化的自动权重。
  • 理论加速上限:Jacobi 解码的加速极限受序列中 token 之间依赖长度的影响,本文的 2 倍加速是经验平均值,对于长距离依赖强的文本可能略低。结合投机解码或树形验证等方法有望突破该界限。
  • 通用性:本方法目前针对文本生成设计,但框架可自然扩展到其他自回归生成任务,如代码补全、语音生成等。

6. 结论

本文提出了自掩码一致性学习 (SMCL)------一个将大语言模型生成加速推向完全自适应的统一框架。通过将"如何破坏训练数据"建模为同一 LLM 的可学习任务,我们使得一致性训练中的掩码策略能够与补全/自回归能力协同进化,从而在不需要任何手工规则的前提下,使模型在 Jacobi 迭代解码中稳定实现近 2 倍的加速。我们详细给出了三种模式的前缀指令设计、Gumbel-Softmax 可微破坏管道以及联合训练方案。这一框架为提升大语言模型推理效率提供了一种新的端到端学习范式,也为未来自动课程设计和生成模型自我改进的研究开辟了道路。


参考文献

1 Santilli et al., "Accelerating Transformer Inference for Translation via Parallel Decoding", ACL 2023.

2 Kou et al., "CLLMs: Consistency Large Language Models", ICML 2024.

3 Leviathan et al., "Fast Inference from Transformers via Speculative Decoding", ICML 2023.

4 Jang et al., "Categorical Reparameterization with Gumbel-Softmax", ICLR 2017.

5 Chen et al., "Accelerating Large Language Model Decoding with Speculative Sampling", arXiv 2023.


本博客旨在分享前沿技术思路,相关实验正在推进中。欢迎在评论区留言讨论。

相关推荐
Ztopcloud极拓云视角12 小时前
ChatGPT超级应用改版技术解析:Codex集成架构与多模型路由实战
人工智能·chatgpt·架构
AOwhisky19 小时前
Redis 学习笔记(第三期):持久化与主从复制
运维·数据库·redis·笔记·学习·云计算
秋919 小时前
从 Python 后端工程师转型 AI Engineer(AI 工程化)的完整补课清单(2026实战版)
开发语言·人工智能·python
啦啦啦_999919 小时前
5. 迁移学习
人工智能·机器学习·迁移学习
A.说学逗唱的Coke19 小时前
【AI·Coding】TDD × SDD × AI Coding:从“测试驱动“到“规范驱动“的智能协作实践
人工智能·驱动开发·tdd
云烟成雨TD19 小时前
Spring AI Alibaba 1.x 系列【78】沙箱(Sandbox)
java·人工智能·spring
tq108619 小时前
基于SLIP的防幻觉的指南
人工智能
Tbisnic20 小时前
AI大模型学习第十一天:技术选型、安全防护与金融实战
python·学习·ai·大模型·提示词工程
甲维斯20 小时前
Kimi版超级玛丽效果“惊人”,配额不足5厘米!
前端·人工智能
console.log('npc')21 小时前
AI前端工程与生成式UI学习路线
前端·人工智能·ui