知识蒸馏:理论、算法与可运行实现

目录

  • 第一部分:知识蒸馏基础理论
    • 第一章:绪论------模型压缩的第三条路
    • 第二章:知识蒸馏的数学基础------散度、温度与信息论
    • 第三章:Hinton 蒸馏的理论分析
  • 第二部分:经典蒸馏方法
    • 第四章:响应蒸馏(Response-based Distillation)
    • 第五章:特征蒸馏(Feature-based Distillation)
    • 第六章:关系蒸馏(Relation-based Distillation)
  • 第三部分:自蒸馏与无教师蒸馏
    • 第七章:自蒸馏(Self-Distillation)
    • 第八章:无教师蒸馏与数据蒸馏
  • 第四部分:大语言模型的知识蒸馏
    • 第九章:LLM 蒸馏的特殊挑战
    • 第十章:白盒蒸馏------逐层与逐 token 对齐
    • 第十一章:黑盒蒸馏------从 API 到合成数据
    • 第十二章:思维链蒸馏与推理能力迁移
  • 第五部分:完整可运行代码实现
    • 第十三章:从零实现经典知识蒸馏
    • 第十四章:从零实现特征蒸馏
    • 第十五章:从零实现自蒸馏
    • 第十六章:完整蒸馏 Pipeline 与精度对比
  • 附录

第一部分:知识蒸馏基础理论


第一章:绪论------模型压缩的第三条路

1.1 模型压缩的三大范式

在前两篇文档中,我们分别讨论了量化剪枝。知识蒸馏是模型压缩的第三大范式:

范式 核心思想 压缩方式 是否需要训练
量化 降低参数精度 减少每个参数的比特数 可选(PTQ/QAT)
剪枝 移除冗余参数 减少非零参数数量 可选
蒸馏 知识迁移 训练更小的模型 必须

1.1.1 知识蒸馏的核心思想

知识蒸馏(Knowledge Distillation, KD) (Hinton et al., 2015)的核心思想是:用一个大的教师模型(Teacher) 来指导一个小的学生模型(Student) 训练,使学生模型学习到教师模型的"知识"。

关键洞察 :教师模型的软标签(soft labels)------即 softmax 输出的概率分布------包含了比硬标签(one-hot)更丰富的信息。

例子:对于图像分类任务,一张"猫"的图片:

  • 硬标签0,0,1,0,00, 0, 1, 0, 00,0,1,0,0(只有猫是 1)
  • 软标签0.01,0.02,0.90,0.05,0.020.01, 0.02, 0.90, 0.05, 0.020.01,0.02,0.90,0.05,0.02(猫 0.90,但也有一定概率是其他动物)

软标签告诉学生:"这张图片虽然是猫,但和狗也有一定相似性"------这是硬标签无法传达的。

1.1.2 蒸馏 vs 量化 vs 剪枝

特性 量化 剪枝 蒸馏
压缩比 2x-4x 2x-10x 5x-100x
精度保持 中等
计算加速 中等-好 最好
实现复杂度 中等
是否改变架构 可选
是否需要训练 可选 可选 必须

蒸馏的独特优势:可以改变模型架构------不仅减少参数量,还可以简化模型结构(如减少层数、隐藏维度等)。

1.2 蒸馏的历史

1.2.1 模型压缩的早期工作

模型集成(Model Ensemble) :多个模型的预测平均通常优于单个模型。但集成模型的计算成本是单模型的 KKK 倍。

Buciluǎ et al., 2006:首次提出将集成模型的知识压缩到单个模型中。

1.2.2 Hinton 的贡献

Hinton et al., 2015:提出了温度缩放的 softmax 和蒸馏损失函数,奠定了现代知识蒸馏的基础。

关键创新

  1. 温度缩放 :使用温度参数 TTT 来"软化"softmax 输出
  2. 蒸馏损失:学生模型同时学习硬标签和软标签
  3. 暗知识(Dark Knowledge):软标签中包含的类间关系信息

第二章:知识蒸馏的数学基础------散度、温度与信息论

2.1 概率分布的距离度量

2.1.1 KL 散度

定义 2.1(Kullback-Leibler 散度) :对于两个离散概率分布 PPP 和 QQQ,KL 散度定义为:

DKL(P∥Q)=∑iP(i)log⁡P(i)Q(i)D_{\text{KL}}(P \| Q) = \sum_i P(i) \log \frac{P(i)}{Q(i)}DKL(P∥Q)=i∑P(i)logQ(i)P(i)

性质

  1. DKL(P∥Q)≥0D_{\text{KL}}(P \| Q) \geq 0DKL(P∥Q)≥0(非负性)
  2. DKL(P∥Q)=0  ⟺  P=QD_{\text{KL}}(P \| Q) = 0 \iff P = QDKL(P∥Q)=0⟺P=Q(同一性)
  3. DKL(P∥Q)≠DKL(Q∥P)D_{\text{KL}}(P \| Q) \neq D_{\text{KL}}(Q \| P)DKL(P∥Q)=DKL(Q∥P)(不对称性)

注意:KL 散度不是真正的距离度量(因为不对称)。

2.1.2 交叉熵

定义 2.2(交叉熵)

H(P,Q)=−∑iP(i)log⁡Q(i)H(P, Q) = -\sum_i P(i) \log Q(i)H(P,Q)=−i∑P(i)logQ(i)

与 KL 散度的关系

DKL(P∥Q)=H(P,Q)−H(P)D_{\text{KL}}(P \| Q) = H(P, Q) - H(P)DKL(P∥Q)=H(P,Q)−H(P)

其中 H(P)=−∑iP(i)log⁡P(i)H(P) = -\sum_i P(i) \log P(i)H(P)=−∑iP(i)logP(i) 是 PPP 的熵。

推论 :最小化 DKL(P∥Q)D_{\text{KL}}(P \| Q)DKL(P∥Q) 等价于最小化 H(P,Q)H(P, Q)H(P,Q)(因为 H(P)H(P)H(P) 是常数)。

2.1.3 JS 散度

定义 2.3(Jensen-Shannon 散度)

DJS(P∥Q)=12DKL(P∥M)+12DKL(Q∥M)D_{\text{JS}}(P \| Q) = \frac{1}{2} D_{\text{KL}}(P \| M) + \frac{1}{2} D_{\text{KL}}(Q \| M)DJS(P∥Q)=21DKL(P∥M)+21DKL(Q∥M)

其中 M=(P+Q)/2M = (P + Q) / 2M=(P+Q)/2。

性质

  1. DJS(P∥Q)≥0D_{\text{JS}}(P \| Q) \geq 0DJS(P∥Q)≥0
  2. DJS(P∥Q)=DJS(Q∥P)D_{\text{JS}}(P \| Q) = D_{\text{JS}}(Q \| P)DJS(P∥Q)=DJS(Q∥P)(对称性)
  3. DJS(P∥Q)≤log⁡2D_{\text{JS}}(P \| Q) \leq \log 2DJS(P∥Q)≤log2(有界性)

2.1.4 散度选择的理论依据

定理 2.1(前向 vs 反向 KL)

在知识蒸馏中,使用前向 KL (DKL(PT∥PS)D_{\text{KL}}(P_T \| P_S)DKL(PT∥PS))还是反向 KL (DKL(PS∥PT)D_{\text{KL}}(P_S \| P_T)DKL(PS∥PT))会导致不同的行为:

  • 前向 KL (DKL(PT∥PS)D_{\text{KL}}(P_T \| P_S)DKL(PT∥PS)):学生需要覆盖教师的所有模式(mode-covering)
  • 反向 KL (DKL(PS∥PT)D_{\text{KL}}(P_S \| P_T)DKL(PS∥PT)):学生倾向于集中在教师的一个模式上(mode-seeking)

证明 :设教师分布 PTP_TPT 是双峰的,PT=0.5⋅N(μ1,σ2)+0.5⋅N(μ2,σ2)P_T = 0.5 \cdot \mathcal{N}(\mu_1, \sigma^2) + 0.5 \cdot \mathcal{N}(\mu_2, \sigma^2)PT=0.5⋅N(μ1,σ2)+0.5⋅N(μ2,σ2)。

前向 KL:DKL(PT∥PS)=−H(PT)+H(PT,PS)D_{\text{KL}}(P_T \| P_S) = -H(P_T) + H(P_T, P_S)DKL(PT∥PS)=−H(PT)+H(PT,PS)

当 PSP_SPS 是单峰高斯时,为了最小化 H(PT,PS)H(P_T, P_S)H(PT,PS),PSP_SPS 需要覆盖两个峰------导致 PSP_SPS 被"拉宽"。

反向 KL:DKL(PS∥PT)=−H(PS)+H(PS,PT)D_{\text{KL}}(P_S \| P_T) = -H(P_S) + H(P_S, P_T)DKL(PS∥PT)=−H(PS)+H(PS,PT)

当 PSP_SPS 是单峰高斯时,为了最小化 H(PS,PT)H(P_S, P_T)H(PS,PT),PSP_SPS 会"锁定"到其中一个峰。□\square□

实际意义 :在知识蒸馏中,通常使用反向 KL (即最小化 DKL(PS∥PT)D_{\text{KL}}(P_S \| P_T)DKL(PS∥PT)),因为这等价于最小化学生分布相对于教师分布的交叉熵。

2.2 温度缩放的 Softmax

2.2.1 标准 Softmax

定义 2.4(Softmax 函数) :对于 logits z=(z1,...,zK)\mathbf{z} = (z_1, \dots, z_K)z=(z1,...,zK):

pi=softmax(zi)=ezi∑j=1Kezjp_i = \text{softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^K e^{z_j}}pi=softmax(zi)=∑j=1Kezjezi

2.2.2 温度缩放

定义 2.5(温度缩放的 Softmax)

pi(T)=softmax(zi/T)=ezi/T∑j=1Kezj/Tp_i^{(T)} = \text{softmax}(z_i / T) = \frac{e^{z_i / T}}{\sum_{j=1}^K e^{z_j / T}}pi(T)=softmax(zi/T)=∑j=1Kezj/Tezi/T

其中 T>0T > 0T>0 是温度(temperature) 参数。

温度的影响

  • T=1T = 1T=1:标准 softmax
  • T>1T > 1T>1:分布更"软"(更均匀),类间差异被平滑
  • T<1T < 1T<1:分布更"硬"(更尖锐),类间差异被放大
  • T→∞T \to \inftyT→∞:均匀分布 pi=1/Kp_i = 1/Kpi=1/K
  • T→0T \to 0T→0:one-hot 分布(argmax)

2.2.3 温度的信息论解释

定理 2.2(温度与熵的关系):温度缩放的 softmax 输出的熵为:

H(T)=−∑ipi(T)log⁡pi(T)H(T) = -\sum_i p_i^{(T)} \log p_i^{(T)}H(T)=−i∑pi(T)logpi(T)

且:

  • H(T)H(T)H(T) 是 TTT 的单调递增函数
  • H(0)=0H(0) = 0H(0)=0(one-hot,最小熵)
  • H(∞)=log⁡KH(\infty) = \log KH(∞)=logK(均匀分布,最大熵)

证明 :对 H(T)H(T)H(T) 关于 TTT 求导:

dHdT=1T2Varp(T)z≥0\frac{dH}{dT} = \frac{1}{T^2} \text{Var}_{p^{(T)}}z \geq 0dTdH=T21Varp(T)z≥0

其中 Varp(T)z=∑ipi(T)zi2−(∑ipi(T)zi)2\text{Var}_{p^{(T)}}z = \sum_i p_i^{(T)} z_i^2 - (\sum_i p_i^{(T)} z_i)^2Varp(T)z=∑ipi(T)zi2−(∑ipi(T)zi)2 是 logits 在分布 p(T)p^{(T)}p(T) 下的方差。□\square□

推论 2.1 :温度 TTT 控制了软标签中包含的"信息量"------TTT 越大,软标签越均匀,包含的信息越少(熵越大);TTT 越小,软标签越尖锐,包含的信息越多。

2.2.4 最优温度的选择

问题 :如何选择最优的温度 T∗T^*T∗?

定理 2.3(最优温度的信息论准则) :最优温度 T∗T^*T∗ 应使教师软标签的熵处于一个"合适"的范围:

Htarget≤H(T∗)≤Htarget+ΔH_{\text{target}} \leq H(T^*) \leq H_{\text{target}} + \DeltaHtarget≤H(T∗)≤Htarget+Δ

其中 HtargetH_{\text{target}}Htarget 是目标熵,Δ\DeltaΔ 是容许的熵范围。

直觉

  • TTT 太小:软标签接近硬标签,"暗知识"丢失
  • TTT 太大:软标签过于均匀,有用信息被稀释
  • 最优 TTT 通常在 2-20 之间

2.3 知识蒸馏的信息论框架

2.3.1 蒸馏作为信息传递

定理 2.4(蒸馏的信息瓶颈) :设教师模型的输出分布为 PTP_TPT,学生模型的输出分布为 PSP_SPS。蒸馏过程可以看作是最大化 PSP_SPS 和 PTP_TPT 之间的互信息:

max⁡PSI(PS;PT)=H(PS)−H(PS∣PT)\max_{P_S} I(P_S; P_T) = H(P_S) - H(P_S | P_T)PSmaxI(PS;PT)=H(PS)−H(PS∣PT)

在完美蒸馏的情况下,H(PS∣PT)=0H(P_S | P_T) = 0H(PS∣PT)=0,即 PSP_SPS 完全由 PTP_TPT 决定。

2.3.2 蒸馏的率失真分析

定理 2.5(蒸馏的率失真下界) :设教师模型有 KKK 个类别,学生模型使用 bbb 比特编码每个样本的预测。则蒸馏后的学生模型与教师模型之间的 KL 散度满足:

DKL(PT∥PS)≥12log⁡22πeK−bD_{\text{KL}}(P_T \| P_S) \geq \frac{1}{2} \log_2 \frac{2\pi e}{K} - bDKL(PT∥PS)≥21log2K2πe−b

推论 :要达到很小的 KL 散度,学生模型需要足够大的"容量"(由 bbb 或等价地由模型参数量决定)。


第三章:Hinton 蒸馏的理论分析

3.1 Hinton 蒸馏的损失函数

3.1.1 标准蒸馏损失

Hinton et al., 2015 提出的蒸馏损失函数为:

L=α⋅Lhard+(1−α)⋅Lsoft\mathcal{L} = \alpha \cdot \mathcal{L}{\text{hard}} + (1 - \alpha) \cdot \mathcal{L}{\text{soft}}L=α⋅Lhard+(1−α)⋅Lsoft

其中:

硬标签损失(交叉熵):

Lhard=H(y,PS(1))=−∑iyilog⁡PS,i(1)\mathcal{L}{\text{hard}} = H(\mathbf{y}, P_S^{(1)}) = -\sum_i y_i \log P{S,i}^{(1)}Lhard=H(y,PS(1))=−i∑yilogPS,i(1)

软标签损失 (KL 散度,温度 TTT):

Lsoft=T2⋅DKL(PT(T)∥PS(T))=T2⋅∑iPT,i(T)log⁡PT,i(T)PS,i(T)\mathcal{L}{\text{soft}} = T^2 \cdot D{\text{KL}}(P_T^{(T)} \| P_S^{(T)}) = T^2 \cdot \sum_i P_{T,i}^{(T)} \log \frac{P_{T,i}^{(T)}}{P_{S,i}^{(T)}}Lsoft=T2⋅DKL(PT(T)∥PS(T))=T2⋅i∑PT,i(T)logPS,i(T)PT,i(T)

其中 y\mathbf{y}y 是真实标签(one-hot),PT(T)P_T^{(T)}PT(T) 和 PS(T)P_S^{(T)}PS(T) 分别是温度 TTT 下教师和学生的 softmax 输出。

3.1.2 为什么乘以 T2T^2T2?

定理 3.1(T2T^2T2 缩放的必要性) :在温度 TTT 下,softmax 输出的梯度量级约为 1/T1/T1/T。因此 KL 散度的梯度量级约为 1/T21/T^21/T2。乘以 T2T^2T2 可以使软标签损失的梯度量级与硬标签损失一致。

证明 :设 logits 为 z\mathbf{z}z,温度缩放的 softmax 为 pi=ezi/T/∑jezj/Tp_i = e^{z_i/T} / \sum_j e^{z_j/T}pi=ezi/T/∑jezj/T。

对 ziz_izi 求导:

∂pi∂zi=1Tpi(1−pi)\frac{\partial p_i}{\partial z_i} = \frac{1}{T} p_i (1 - p_i)∂zi∂pi=T1pi(1−pi)

∂pj∂zi=−1Tpipj,j≠i\frac{\partial p_j}{\partial z_i} = -\frac{1}{T} p_i p_j, \quad j \neq i∂zi∂pj=−T1pipj,j=i

因此 ∂p/∂z=O(1/T)\partial p / \partial z = O(1/T)∂p/∂z=O(1/T)。

KL 散度的梯度:∂DKL∂z=O(1/T)\frac{\partial D_{\text{KL}}}{\partial z} = O(1/T)∂z∂DKL=O(1/T)。

乘以 T2T^2T2 后:T2⋅∂DKL∂z=O(T)T^2 \cdot \frac{\partial D_{\text{KL}}}{\partial z} = O(T)T2⋅∂z∂DKL=O(T)------当 TTT 较大时,梯度量级约为 TTT。

实际上,更精确的分析表明 T2T^2T2 缩放使得软标签损失的梯度在 TTT 变化时保持相对稳定。□\square□

3.1.3 α\alphaα 的选择

经验规则

  • α=0.1\alpha = 0.1α=0.1:更重视软标签(常见选择)
  • α=0.5\alpha = 0.5α=0.5:软标签和硬标签同等重要
  • α=0.9\alpha = 0.9α=0.9:更重视硬标签

定理 3.2(α\alphaα 的最优值) :在一定假设下,最优的 α\alphaα 满足:

α∗=H(PT(T))H(PT(T))+H(y)\alpha^* = \frac{H(P_T^{(T)})}{H(P_T^{(T)}) + H(\mathbf{y})}α∗=H(PT(T))+H(y)H(PT(T))

其中 H(PT(T))H(P_T^{(T)})H(PT(T)) 是教师软标签的熵,H(y)=0H(\mathbf{y}) = 0H(y)=0(硬标签的熵为零)。

因此 α∗→0\alpha^* \to 0α∗→0------理论上应该完全使用软标签。但实践中,保留一定比例的硬标签有助于稳定训练。

3.2 蒸馏的梯度分析

3.2.1 软标签损失的梯度

定理 3.3(软标签梯度的分解) :软标签损失关于学生 logits ziz_izi 的梯度为:

∂Lsoft∂zi=T⋅(PS,i(T)−PT,i(T))\frac{\partial \mathcal{L}{\text{soft}}}{\partial z_i} = T \cdot (P{S,i}^{(T)} - P_{T,i}^{(T)})∂zi∂Lsoft=T⋅(PS,i(T)−PT,i(T))

证明

Lsoft=T2∑jPT,j(T)log⁡PT,j(T)PS,j(T)=T2(∑jPT,j(T)log⁡PT,j(T)−∑jPT,j(T)log⁡PS,j(T))\mathcal{L}{\text{soft}} = T^2 \sum_j P{T,j}^{(T)} \log \frac{P_{T,j}^{(T)}}{P_{S,j}^{(T)}} = T^2 \left(\sum_j P_{T,j}^{(T)} \log P_{T,j}^{(T)} - \sum_j P_{T,j}^{(T)} \log P_{S,j}^{(T)}\right)Lsoft=T2j∑PT,j(T)logPS,j(T)PT,j(T)=T2(j∑PT,j(T)logPT,j(T)−j∑PT,j(T)logPS,j(T))

第一项与学生参数无关。对第二项求导:

∂∂zi(−T2∑jPT,j(T)log⁡PS,j(T))=−T2∑jPT,j(T)1PS,j(T)∂PS,j(T)∂zi\frac{\partial}{\partial z_i} \left(-T^2 \sum_j P_{T,j}^{(T)} \log P_{S,j}^{(T)}\right) = -T^2 \sum_j P_{T,j}^{(T)} \frac{1}{P_{S,j}^{(T)}} \frac{\partial P_{S,j}^{(T)}}{\partial z_i}∂zi∂(−T2j∑PT,j(T)logPS,j(T))=−T2j∑PT,j(T)PS,j(T)1∂zi∂PS,j(T)

使用 ∂PS,j(T)∂zi=1TPS,j(T)(δij−PS,i(T))\frac{\partial P_{S,j}^{(T)}}{\partial z_i} = \frac{1}{T} P_{S,j}^{(T)} (\delta_{ij} - P_{S,i}^{(T)})∂zi∂PS,j(T)=T1PS,j(T)(δij−PS,i(T)):

=−T2∑jPT,j(T)1PS,j(T)⋅1TPS,j(T)(δij−PS,i(T))= -T^2 \sum_j P_{T,j}^{(T)} \frac{1}{P_{S,j}^{(T)}} \cdot \frac{1}{T} P_{S,j}^{(T)} (\delta_{ij} - P_{S,i}^{(T)})=−T2j∑PT,j(T)PS,j(T)1⋅T1PS,j(T)(δij−PS,i(T))

=−T∑jPT,j(T)(δij−PS,i(T))=−T(PT,i(T)−PS,i(T)∑jPT,j(T))= -T \sum_j P_{T,j}^{(T)} (\delta_{ij} - P_{S,i}^{(T)}) = -T (P_{T,i}^{(T)} - P_{S,i}^{(T)} \sum_j P_{T,j}^{(T)})=−Tj∑PT,j(T)(δij−PS,i(T))=−T(PT,i(T)−PS,i(T)j∑PT,j(T))

=T(PS,i(T)−PT,i(T))= T (P_{S,i}^{(T)} - P_{T,i}^{(T)})=T(PS,i(T)−PT,i(T))

□\square□

3.2.2 软标签梯度的物理含义

推论 3.1 :软标签损失的梯度与学生和教师输出的差异成正比------差异越大,梯度越大,学生更新越快。

推论 3.2 :当 TTT 很大时,PS,i(T)≈PT,i(T)≈1/KP_{S,i}^{(T)} \approx P_{T,i}^{(T)} \approx 1/KPS,i(T)≈PT,i(T)≈1/K,梯度接近零------高温下蒸馏的"学习信号"很弱。

推论 3.3 :当 TTT 很小时,PS,i(T)P_{S,i}^{(T)}PS,i(T) 和 PT,i(T)P_{T,i}^{(T)}PT,i(T) 接近 one-hot,梯度变得稀疏------只有被错误分类的样本有非零梯度。

最优温度:在"信息量"和"梯度强度"之间取得平衡。

3.3 蒸馏的泛化分析

3.3.1 蒸馏作为正则化

定理 3.4(蒸馏的正则化效应):蒸馏损失等价于在标准交叉熵损失上添加一个正则化项:

L=H(y,PS)+λ⋅DKL(PT∥PS)\mathcal{L} = H(\mathbf{y}, P_S) + \lambda \cdot D_{\text{KL}}(P_T \| P_S)L=H(y,PS)+λ⋅DKL(PT∥PS)

其中 λ=(1−α)T2/α\lambda = (1 - \alpha) T^2 / \alphaλ=(1−α)T2/α。

推论 :蒸馏的正则化效应类似于标签平滑(Label Smoothing)------将硬标签替换为软标签。

3.3.2 标签平滑与蒸馏的关系

定义 3.1(标签平滑)

y~i=(1−ϵ)yi+ϵK\tilde{y}_i = (1 - \epsilon) y_i + \frac{\epsilon}{K}y~i=(1−ϵ)yi+Kϵ

其中 ϵ\epsilonϵ 是平滑参数。

定理 3.5(蒸馏与标签平滑的等价性):当教师模型的输出是均匀分布时,蒸馏退化为标签平滑:

PT,i(T)=1K  ⟹  Lsoft=T2DKL(1K∥PS(T))=T2(log⁡K+∑i1Klog⁡PS,i(T))P_{T,i}^{(T)} = \frac{1}{K} \implies \mathcal{L}{\text{soft}} = T^2 D{\text{KL}}\left(\frac{1}{K} \| P_S^{(T)}\right) = T^2 \left(\log K + \sum_i \frac{1}{K} \log P_{S,i}^{(T)}\right)PT,i(T)=K1⟹Lsoft=T2DKL(K1∥PS(T))=T2(logK+i∑K1logPS,i(T))

这与标签平滑的损失形式一致。□\square□

关键区别 :教师模型的输出不是均匀分布------它包含了类间的结构信息。这正是蒸馏比标签平滑更有效的原因。


第二部分:经典蒸馏方法


第四章:响应蒸馏(Response-based Distillation)

4.1 基本框架

4.1.1 问题形式化

给定教师模型 fTf_TfT 和学生模型 fSf_SfS,输入 x\mathbf{x}x,类别数 KKK:

  • 教师 logits:zT=fT(x)∈RK\mathbf{z}_T = f_T(\mathbf{x}) \in \mathbb{R}^KzT=fT(x)∈RK
  • 学生 logits:zS=fS(x)∈RK\mathbf{z}_S = f_S(\mathbf{x}) \in \mathbb{R}^KzS=fS(x)∈RK
  • 教师软标签:pT=softmax(zT/T)\mathbf{p}_T = \text{softmax}(\mathbf{z}_T / T)pT=softmax(zT/T)
  • 学生软标签:pS=softmax(zS/T)\mathbf{p}_S = \text{softmax}(\mathbf{z}_S / T)pS=softmax(zS/T)

4.1.2 蒸馏损失

标准 KL 蒸馏

LKD=T2⋅DKL(pT∥pS)\mathcal{L}{\text{KD}} = T^2 \cdot D{\text{KL}}(\mathbf{p}_T \| \mathbf{p}_S)LKD=T2⋅DKL(pT∥pS)

交叉熵蒸馏(等价):

LCE=−T2∑ipT,ilog⁡pS,i\mathcal{L}{\text{CE}} = -T^2 \sum_i p{T,i} \log p_{S,i}LCE=−T2i∑pT,ilogpS,i

(因为 DKL(P∥Q)=H(P,Q)−H(P)D_{\text{KL}}(P \| Q) = H(P, Q) - H(P)DKL(P∥Q)=H(P,Q)−H(P),H(P)H(P)H(P) 是常数。)

4.1.3 总损失

Ltotal=α⋅H(y,pS(1))+(1−α)⋅T2⋅DKL(pT(T)∥pS(T))\mathcal{L}_{\text{total}} = \alpha \cdot H(\mathbf{y}, \mathbf{p}S^{(1)}) + (1 - \alpha) \cdot T^2 \cdot D{\text{KL}}(\mathbf{p}_T^{(T)} \| \mathbf{p}_S^{(T)})Ltotal=α⋅H(y,pS(1))+(1−α)⋅T2⋅DKL(pT(T)∥pS(T))

4.2 响应蒸馏的变体

4.2.1 MSE 蒸馏

定义:直接最小化 logits 之间的均方误差:

LMSE=∥zT−zS∥2\mathcal{L}_{\text{MSE}} = \|\mathbf{z}_T - \mathbf{z}_S\|^2LMSE=∥zT−zS∥2

与 KL 蒸馏的比较

定理 4.1(MSE 与 KL 的关系):在 logits 的小扰动假设下:

DKL(pT∥pS)≈12T2∑ipT,i(zT,i−zS,i)2−12T2(∑ipT,i(zT,i−zS,i))2D_{\text{KL}}(\mathbf{p}T \| \mathbf{p}S) \approx \frac{1}{2T^2} \sum_i p{T,i} (z{T,i} - z_{S,i})^2 - \frac{1}{2T^2} \left(\sum_i p_{T,i} (z_{T,i} - z_{S,i})\right)^2DKL(pT∥pS)≈2T21i∑pT,i(zT,i−zS,i)2−2T21(i∑pT,i(zT,i−zS,i))2

当 pT\mathbf{p}TpT 接近均匀分布时,DKL≈12T2VarpTzT−zSD{\text{KL}} \approx \frac{1}{2T^2} \text{Var}_{\mathbf{p}_T}\\mathbf{z}_T - \\mathbf{z}_SDKL≈2T21VarpTzT−zS

推论:MSE 蒸馏和 KL 蒸馏在某些条件下近似等价,但 KL 蒸馏更鲁棒。

4.2.2 残差蒸馏

定义 :学生学习教师的残差而非绝对值:

Lres=∥(zT−zS)−r(zS)∥2\mathcal{L}_{\text{res}} = \|(\mathbf{z}_T - \mathbf{z}_S) - r(\mathbf{z}_S)\|^2Lres=∥(zT−zS)−r(zS)∥2

其中 rrr 是一个可学习的残差网络。

4.2.3 对抗蒸馏

定义:使用对抗训练来增强蒸馏:

min⁡fSmax⁡DLKD(fT,fS)+λLadv(fS,D)\min_{f_S} \max_{D} \mathcal{L}{\text{KD}}(f_T, f_S) + \lambda \mathcal{L}{\text{adv}}(f_S, D)fSminDmaxLKD(fT,fS)+λLadv(fS,D)

其中 DDD 是判别器,试图区分教师和学生的输出。


第五章:特征蒸馏(Feature-based Distillation)

5.1 中间层特征对齐

5.1.1 动机

问题:响应蒸馏只利用了教师模型的最终输出,忽略了中间层的丰富信息。

特征蒸馏:对齐教师和学生中间层的特征表示。

5.1.2 FitNets

Romero et al., 2015 提出的 FitNets 是最早的特征蒸馏方法。

定义 5.1(FitNet 损失)

LFitNet=∥fT(l)(x)−r(fS(l′)(x))∥2\mathcal{L}_{\text{FitNet}} = \|f_T^{(l)}(\mathbf{x}) - r(f_S^{(l')}(\mathbf{x}))\|^2LFitNet=∥fT(l)(x)−r(fS(l′)(x))∥2

其中:

  • fT(l)f_T^{(l)}fT(l) 是教师第 lll 层的特征
  • fS(l′)f_S^{(l')}fS(l′) 是学生第 l′l'l′ 层的特征
  • rrr 是一个可学习的回归网络(用于维度匹配)

问题 :教师和学生的特征维度可能不同,需要 rrr 来匹配维度。

5.1.3 特征蒸馏的理论分析

定理 5.1(特征蒸馏的信息论解释):特征蒸馏等价于最大化教师特征和学生特征之间的互信息:

max⁡fSI(fT(l)(x);fS(l′)(x))\max_{f_S} I(f_T^{(l)}(\mathbf{x}); f_S^{(l')}(\mathbf{x}))fSmaxI(fT(l)(x);fS(l′)(x))

证明 :设 hT=fT(l)(x)\mathbf{h}_T = f_T^{(l)}(\mathbf{x})hT=fT(l)(x),hS=fS(l′)(x)\mathbf{h}_S = f_S^{(l')}(\mathbf{x})hS=fS(l′)(x)。在高斯假设下:

I(hT;hS)=12log⁡∣Cov(hT)∣∣Cov(hT∣hS)∣I(\mathbf{h}_T; \mathbf{h}_S) = \frac{1}{2} \log \frac{|\text{Cov}(\mathbf{h}_T)|}{|\text{Cov}(\mathbf{h}_T | \mathbf{h}_S)|}I(hT;hS)=21log∣Cov(hT∣hS)∣∣Cov(hT)∣

最小化 MSE ∥hT−hS∥2\|\mathbf{h}_T - \mathbf{h}_S\|^2∥hT−hS∥2 等价于最大化条件概率 p(hT∣hS)p(\mathbf{h}_T | \mathbf{h}_S)p(hT∣hS) 的集中度,即最大化互信息。□\square□

5.2 特征蒸馏的变体

5.2.1 注意力转移(Attention Transfer)

Zagoruyko & Komodakis, 2017 提出对齐注意力图:

定义 5.2(注意力图) :对于特征图 A∈RC×H×W\mathbf{A} \in \mathbb{R}^{C \times H \times W}A∈RC×H×W,注意力图定义为:

A(A)=1C∑c=1C∣Ac∣p\mathcal{A}(\mathbf{A}) = \frac{1}{C} \sum_{c=1}^{C} |A_c|^pA(A)=C1c=1∑C∣Ac∣p

其中 ppp 是幂次参数(通常 p=2p = 2p=2)。

注意力转移损失

LAT=∑l∥A(AT(l))∥A(AT(l))∥2−A(AS(l))∥A(AS(l))∥2∥22\mathcal{L}_{\text{AT}} = \sum_l \left\|\frac{\mathcal{A}(\mathbf{A}_T^{(l)})}{\|\mathcal{A}(\mathbf{A}_T^{(l)})\|_2} - \frac{\mathcal{A}(\mathbf{A}_S^{(l)})}{\|\mathcal{A}(\mathbf{A}_S^{(l)})\|_2}\right\|_2^2LAT=l∑ ∥A(AT(l))∥2A(AT(l))−∥A(AS(l))∥2A(AS(l)) 22

5.2.2 特征分布对齐

定义 5.3(分布对齐损失)

Ldist=∑lDKL(p(hT(l))∥p(hS(l)))\mathcal{L}{\text{dist}} = \sum_l D{\text{KL}}(p(\mathbf{h}_T^{(l)}) \| p(\mathbf{h}_S^{(l)}))Ldist=l∑DKL(p(hT(l))∥p(hS(l)))

在高斯假设下,p(h)=N(μ,Σ)p(\mathbf{h}) = \mathcal{N}(\mu, \Sigma)p(h)=N(μ,Σ),KL 散度有解析形式:

DKL(NT∥NS)=12tr(ΣS−1ΣT)−k+(μS−μT)TΣS−1(μS−μT)+log⁡∣ΣS∣∣ΣT∣D_{\text{KL}}(\mathcal{N}_T \| \mathcal{N}_S) = \frac{1}{2}\left\\text{tr}(\\Sigma_S\^{-1} \\Sigma_T) - k + (\\mu_S - \\mu_T)\^T \\Sigma_S\^{-1} (\\mu_S - \\mu_T) + \\log \\frac{\|\\Sigma_S\|}{\|\\Sigma_T\|}\\rightDKL(NT∥NS)=21tr(ΣS−1ΣT)−k+(μS−μT)TΣS−1(μS−μT)+log∣ΣT∣∣ΣS∣

5.2.3 Gram 矩阵匹配

定义 5.4(Gram 矩阵) :对于特征图 A∈RC×HW\mathbf{A} \in \mathbb{R}^{C \times HW}A∈RC×HW,Gram 矩阵定义为:

G(A)=AAT∈RC×CG(\mathbf{A}) = \mathbf{A} \mathbf{A}^T \in \mathbb{R}^{C \times C}G(A)=AAT∈RC×C

Gram 矩阵匹配损失

LGram=∑l∥G(AT(l))−G(AS(l))∥F2\mathcal{L}_{\text{Gram}} = \sum_l \|G(\mathbf{A}_T^{(l)}) - G(\mathbf{A}_S^{(l)})\|_F^2LGram=l∑∥G(AT(l))−G(AS(l))∥F2

定理 5.2(Gram 矩阵的信息含量) :Gram 矩阵 G(A)G(\mathbf{A})G(A) 编码了特征通道之间的二阶统计关系------即哪些通道倾向于同时被激活。


第六章:关系蒸馏(Relation-based Distillation)

6.1 样本间关系

6.1.1 动机

问题 :响应蒸馏和特征蒸馏都是逐样本对齐的------每个样本独立处理。但教师模型的"知识"也包含在样本间的关系中。

6.1.2 结构化蒸馏

Park et al., 2019 提出结构化知识蒸馏(Structured Knowledge Distillation)

定义 6.1(样本间距离矩阵)

Dij=∥f(xi)−f(xj)∥2D_{ij} = \|f(\mathbf{x}_i) - f(\mathbf{x}_j)\|_2Dij=∥f(xi)−f(xj)∥2

结构化蒸馏损失

Lstruct=∥DT−DS∥F2\mathcal{L}_{\text{struct}} = \|D_T - D_S\|_F^2Lstruct=∥DT−DS∥F2

其中 DTD_TDT 和 DSD_SDS 分别是教师和学生的距离矩阵。

6.1.3 关系蒸馏的理论分析

定理 6.1(距离矩阵保留类间结构) :设教师模型将样本映射到特征空间,同类样本聚集、异类样本分散。则距离矩阵 DTD_TDT 编码了类间和类内的结构信息。

证明 :设 xi\mathbf{x}ixi 和 xj\mathbf{x}jxj 属于同一类,则 DT,ijD{T,ij}DT,ij 较小;若属于不同类,则 DT,ijD{T,ij}DT,ij 较大。学生对齐 DTD_TDT 等价于学习这种类间/类内结构。□\square□

6.2 图蒸馏

6.2.1 图结构的关系

定义 6.2(关系图) :构建一个图 G=(V,E)G = (V, E)G=(V,E),其中节点是样本,边权重是样本间的相似度:

wij=exp⁡(−∥f(xi)−f(xj)∥22σ2)w_{ij} = \exp\left(-\frac{\|f(\mathbf{x}_i) - f(\mathbf{x}_j)\|^2}{2\sigma^2}\right)wij=exp(−2σ2∥f(xi)−f(xj)∥2)

图蒸馏损失

Lgraph=∥WT−WS∥F2\mathcal{L}_{\text{graph}} = \|W_T - W_S\|_F^2Lgraph=∥WT−WS∥F2

其中 WTW_TWT 和 WSW_SWS 分别是教师和学生的邻接矩阵。

6.2.2 图蒸馏的理论分析

定理 6.2(图蒸馏与核方法的关系) :图蒸馏等价于对齐教师和学生的核矩阵

KT(xi,xj)=⟨ϕT(xi),ϕT(xj)⟩K_T(\mathbf{x}_i, \mathbf{x}_j) = \langle\phi_T(\mathbf{x}_i), \phi_T(\mathbf{x}_j)\rangleKT(xi,xj)=⟨ϕT(xi),ϕT(xj)⟩

其中 ϕT\phi_TϕT 是教师的特征映射。

推论 :对齐核矩阵等价于对齐特征空间的几何结构------距离、角度、相对位置。


第三部分:自蒸馏与无教师蒸馏


第七章:自蒸馏(Self-Distillation)

7.1 核心思想

7.1.1 什么是自蒸馏?

自蒸馏(Self-Distillation):模型从自身的"旧版本"或"不同视角"学习,不需要外部教师模型。

7.1.2 自蒸馏的动机

问题:知识蒸馏需要一个大的教师模型,增加了计算成本和部署复杂度。

自蒸馏的优势

  • 不需要额外的教师模型
  • 训练成本与标准训练相当
  • 可以在训练过程中持续改进

7.2 Born-Again Networks

7.2.1 方法

Furlanello et al., 2018 提出的 Born-Again Networks(BAN)

算法 7.1(Born-Again Networks)

复制代码
1. 训练教师模型 f_T(标准训练)
2. 训练学生模型 f_S(与 f_T 架构相同),使用 f_T 作为教师
3. 将 f_S 作为新的教师,训练新的学生
4. 重复 K 次

关键发现:即使学生和教师架构相同,蒸馏后的学生模型通常优于教师模型!

7.2.2 Born-Again Networks 的理论分析

定理 7.1(BAN 的收敛性) :设第 kkk 代模型的测试准确率为 Acck\text{Acc}_kAcck。在一定假设下:

Acck+1≥Acck−ϵk\text{Acc}_{k+1} \geq \text{Acc}_k - \epsilon_kAcck+1≥Acck−ϵk

其中 ϵk→0\epsilon_k \to 0ϵk→0(随着代数增加,改善量减小)。

直觉:每一代蒸馏相当于一次"软标签增强"的训练------软标签提供了比硬标签更丰富的监督信号。

7.3 自蒸馏的变体

7.3.1 在线自蒸馏

定义 :在训练过程中,使用模型的指数移动平均(EMA) 作为教师:

θEMA(t)=β⋅θEMA(t−1)+(1−β)⋅θ(t)\theta_{\text{EMA}}^{(t)} = \beta \cdot \theta_{\text{EMA}}^{(t-1)} + (1 - \beta) \cdot \theta^{(t)}θEMA(t)=β⋅θEMA(t−1)+(1−β)⋅θ(t)

L=H(y,PS)+λ⋅DKL(PEMA∥PS)\mathcal{L} = H(\mathbf{y}, P_S) + \lambda \cdot D_{\text{KL}}(P_{\text{EMA}} \| P_S)L=H(y,PS)+λ⋅DKL(PEMA∥PS)

7.3.2 多视角自蒸馏

定义:对同一输入施加不同的增强,让不同增强的输出相互对齐:

Lself=∑i≠jDKL(PS(x(i))∥PS(x(j)))\mathcal{L}{\text{self}} = \sum{i \neq j} D_{\text{KL}}(P_S(\mathbf{x}^{(i)}) \| P_S(\mathbf{x}^{(j)}))Lself=i=j∑DKL(PS(x(i))∥PS(x(j)))

其中 x(i)\mathbf{x}^{(i)}x(i) 和 x(j)\mathbf{x}^{(j)}x(j) 是同一输入的不同增强版本。


第八章:无教师蒸馏与数据蒸馏

8.1 无教师蒸馏

8.1.1 动机

问题:在某些场景下,没有现成的教师模型------例如:

  • 模型太大无法部署为教师
  • 数据隐私限制
  • 需要从头训练

8.1.2 虚拟教师

思想 :使用集成模型多个检查点作为虚拟教师。

PT=1M∑m=1MPS(m)P_T = \frac{1}{M} \sum_{m=1}^{M} P_{S}^{(m)}PT=M1m=1∑MPS(m)

其中 PS(m)P_S^{(m)}PS(m) 是训练过程中不同检查点的输出。

8.2 数据蒸馏

8.2.1 核心思想

数据蒸馏(Dataset Distillation) :将大数据集压缩为一个小的合成数据集,使在小数据集上训练的模型能达到接近在大数据集上训练的精度。

8.2.2 形式化

问题

min⁡D~Ex∼DL(fθ(D\~)(x))\min_{\tilde{\mathcal{D}}} \mathbb{E}_{\mathbf{x} \sim \mathcal{D}} \left\\mathcal{L}(f_{\\theta(\\tilde{\\mathcal{D}})}(\\mathbf{x}))\\rightD~minEx∼DL(fθ(D\~)(x))

其中 θ(D~)=arg⁡min⁡θ∑x~∈D~L(fθ(x~))\theta(\tilde{\mathcal{D}}) = \arg\min_\theta \sum_{\tilde{\mathbf{x}} \in \tilde{\mathcal{D}}} \mathcal{L}(f_\theta(\tilde{\mathbf{x}}))θ(D~)=argminθ∑x~∈D~L(fθ(x~)) 是在合成数据集上训练的模型参数。

8.2.3 数据蒸馏的理论分析

定理 8.1(数据蒸馏的信息瓶颈) :合成数据集 D~\tilde{\mathcal{D}}D~ 的信息量受限于其大小 ∣D~∣|\tilde{\mathcal{D}}|∣D~∣:

I(D~;D)≤∣D~∣⋅log⁡∣X∣I(\tilde{\mathcal{D}}; \mathcal{D}) \leq |\tilde{\mathcal{D}}| \cdot \log |\mathcal{X}|I(D~;D)≤∣D~∣⋅log∣X∣

其中 ∣X∣|\mathcal{X}|∣X∣ 是输入空间的大小。

推论:合成数据集的大小必须足够大,才能保留原始数据集的信息。


第四部分:大语言模型的知识蒸馏


第九章:LLM 蒸馏的特殊挑战

9.1 LLM 蒸馏 vs 传统蒸馏

挑战 传统蒸馏 LLM 蒸馏
模型规模 百万-千万参数 十亿-千亿参数
输出空间 分类(几千类) 生成(词汇表 32K-100K)
序列长度 固定 可变(数百-数千)
计算成本 极高
训练数据 充足 有限或受限

9.2 LLM 蒸馏的分类

9.2.1 白盒蒸馏

白盒蒸馏 :可以访问教师模型的内部表示(logits、特征、注意力权重等)。

优势:可以利用丰富的中间信息。

代表方法:DistilBERT、TinyBERT、MiniLLM。

9.2.2 黑盒蒸馏

黑盒蒸馏 :只能访问教师模型的输出文本(通过 API)。

优势:不需要部署教师模型,适用于闭源模型。

代表方法:Alpaca、Vicuna、WizardLM。

9.3 序列级蒸馏的挑战

9.3.1 自回归生成的特殊性

问题:LLM 是自回归模型------每个 token 的生成依赖于之前的 token。这导致:

  1. 暴露偏差(Exposure Bias):训练时使用真实 token,推理时使用模型生成的 token
  2. 长程依赖:早期 token 的误差会传播到后续 token
  3. 序列长度不固定:不同样本的序列长度不同

9.3.2 序列级 KL 散度

定义 9.1(序列级 KL 散度) :对于长度为 LLL 的序列 y=(y1,...,yL)\mathbf{y} = (y_1, \dots, y_L)y=(y1,...,yL):

DKLseq(PT∥PS)=∑t=1LDKL(PT(⋅∣y<t)∥PS(⋅∣y<t))D_{\text{KL}}^{\text{seq}}(P_T \| P_S) = \sum_{t=1}^{L} D_{\text{KL}}(P_T(\cdot | y_{<t}) \| P_S(\cdot | y_{<t}))DKLseq(PT∥PS)=t=1∑LDKL(PT(⋅∣y<t)∥PS(⋅∣y<t))

其中 PT(⋅∣y<t)P_T(\cdot | y_{<t})PT(⋅∣y<t) 是教师在给定前缀 y<ty_{<t}y<t 下的下一个 token 分布。

定理 9.1(序列级 KL 的分解)

DKLseq=∑t=1L∑v∈VPT(v∣y<t)log⁡PT(v∣y<t)PS(v∣y<t)D_{\text{KL}}^{\text{seq}} = \sum_{t=1}^{L} \sum_{v \in \mathcal{V}} P_T(v | y_{<t}) \log \frac{P_T(v | y_{<t})}{P_S(v | y_{<t})}DKLseq=t=1∑Lv∈V∑PT(v∣y<t)logPS(v∣y<t)PT(v∣y<t)

计算挑战 :词汇表 V\mathcal{V}V 的大小通常为 32K-100K,每个时间步都需要计算整个词汇表的分布。


第十章:白盒蒸馏------逐层与逐 token 对齐

10.1 DistilBERT

10.1.1 方法

DistilBERT(Sanh et al., 2019)将 BERT-base(12 层)蒸馏到一个 6 层的学生模型。

损失函数

L=αLCE+βLMLM+γLcos⁡\mathcal{L} = \alpha \mathcal{L}{\text{CE}} + \beta \mathcal{L}{\text{MLM}} + \gamma \mathcal{L}_{\cos}L=αLCE+βLMLM+γLcos

其中:

  • LCE\mathcal{L}_{\text{CE}}LCE:软标签交叉熵(蒸馏损失)
  • LMLM\mathcal{L}_{\text{MLM}}LMLM:掩码语言模型损失(任务损失)
  • Lcos⁡\mathcal{L}_{\cos}Lcos:教师和学生最后一层隐藏状态的余弦相似度

10.1.2 层选择策略

问题:12 层教师如何映射到 6 层学生?

策略 1:均匀间隔 ------学生第 lll 层学习教师第 2l2l2l 层

策略 2:首尾共享------学生第 1 层和最后一层分别学习教师的第 1 层和最后一层,中间层均匀间隔

策略 3:可学习映射------学习一个映射函数,自动选择教师层

10.2 TinyBERT

10.2.1 方法

TinyBERT(Jiao et al., 2020)提出了更精细的蒸馏策略:

三阶段蒸馏

  1. 嵌入层蒸馏:对齐嵌入层的输出
  2. Transformer 层蒸馏:对齐注意力权重和隐藏状态
  3. 预测层蒸馏:对齐最终输出

损失函数

L=∑l∈SλlLlayer(fT(m(l)),fS(l))\mathcal{L} = \sum_{l \in \mathcal{S}} \lambda_l \mathcal{L}_{\text{layer}}(f_T^{(m(l))}, f_S^{(l)})L=l∈S∑λlLlayer(fT(m(l)),fS(l))

其中 S\mathcal{S}S 是选中的层集合,m(l)m(l)m(l) 是学生第 lll 层对应的教师层。

10.2.2 注意力蒸馏

定义 10.1(注意力蒸馏损失)

Lattn=1h∑i=1h∥AT(i)−AS(i)∥F2\mathcal{L}{\text{attn}} = \frac{1}{h} \sum{i=1}^{h} \|A_T^{(i)} - A_S^{(i)}\|_F^2Lattn=h1i=1∑h∥AT(i)−AS(i)∥F2

其中 AT(i)A_T^{(i)}AT(i) 和 AS(i)A_S^{(i)}AS(i) 分别是教师和学生第 iii 个注意力头的注意力权重矩阵。

10.2.3 隐藏状态蒸馏

定义 10.2(隐藏状态蒸馏损失)

Lhidden=∥M⋅HS−HT∥F2\mathcal{L}_{\text{hidden}} = \|M \cdot H_S - H_T\|_F^2Lhidden=∥M⋅HS−HT∥F2

其中 MMM 是一个可学习的线性变换,用于维度匹配。

10.3 MiniLLM

10.3.1 动机

问题 :标准 KL 蒸馏在 LLM 中会导致长度偏差------学生倾向于生成更短的序列。

10.3.2 反向 KL 蒸馏

MiniLLM (Gu et al., 2024)提出使用反向 KL 散度

LMiniLLM=DKL(PS∥PT)=∑t∑vPS(v∣y<t)log⁡PS(v∣y<t)PT(v∣y<t)\mathcal{L}{\text{MiniLLM}} = D{\text{KL}}(P_S \| P_T) = \sum_t \sum_v P_S(v | y_{<t}) \log \frac{P_S(v | y_{<t})}{P_T(v | y_{<t})}LMiniLLM=DKL(PS∥PT)=t∑v∑PS(v∣y<t)logPT(v∣y<t)PS(v∣y<t)

定理 10.1(反向 KL 避免长度偏差) :前向 KL DKL(PT∥PS)D_{\text{KL}}(P_T \| P_S)DKL(PT∥PS) 倾向于覆盖教师分布的所有模式(mode-covering),导致学生分配概率给低概率 token,从而在生成时倾向于选择"安全"的短序列。

反向 KL DKL(PS∥PT)D_{\text{KL}}(P_S \| P_T)DKL(PS∥PT) 倾向于集中在教师分布的一个模式上(mode-seeking),鼓励学生生成更确定性的输出,减少长度偏差。


第十一章:黑盒蒸馏------从 API 到合成数据

11.1 黑盒蒸馏的动机

11.1.1 闭源模型的挑战

问题:许多强大的 LLM(如 GPT-4、Claude)只提供 API 访问,无法获取内部表示。

黑盒蒸馏 :只使用教师模型的文本输出来进行蒸馏。

11.1.2 Alpaca 方法

Alpaca(Taori et al., 2023)是最早的黑盒蒸馏工作之一:

  1. 使用 GPT-3.5 生成 52K 条指令-回复对
  2. 使用这些数据微调 LLaMA-7B

损失函数:标准的自回归交叉熵(只在回复部分计算):

L=−∑t∈responselog⁡PS(yt∣y<t,x)\mathcal{L} = -\sum_{t \in \text{response}} \log P_S(y_t | y_{<t}, \mathbf{x})L=−t∈response∑logPS(yt∣y<t,x)

11.1.3 Vicuna 方法

Vicuna(Chiang et al., 2023)使用 ShareGPT 数据(用户与 ChatGPT 的真实对话)来微调 LLaMA。

关键差异 :Alpaca 使用单轮 指令数据,Vicuna 使用多轮对话数据。

11.2 合成数据生成

11.2.1 Self-Instruct

算法 11.1(Self-Instruct)

复制代码
1. 种子任务集:S = {少量人工编写的指令-回复对}
2. 重复 N 次:
   a. 从 S 中采样几个示例
   b. 使用 LLM 生成新的指令
   c. 使用 LLM 生成对应的回复
   d. 过滤低质量数据
   e. 将新数据加入 S
3. 使用 S 微调学生模型

11.2.2 Evol-Instruct

WizardLM (Xu et al., 2023)提出进化指令

算法 11.2(Evol-Instruct)

复制代码
1. 初始指令集:I = {简单指令}
2. 重复:
   a. 从 I 中选择一个指令
   b. 使用 LLM 将指令"进化"为更复杂的版本:
      - 增加约束条件
      - 增加推理步骤
      - 增加领域知识
   c. 使用 LLM 生成回复
   d. 将新的指令-回复对加入 I

11.3 黑盒蒸馏的理论分析

11.3.1 数据质量 vs 数量

定理 11.1(数据效率) :设教师模型的输出质量为 QTQ_TQT,学生模型在 NNN 条数据上训练后的性能为:

Acc(N)=QT⋅(1−e−λN)\text{Acc}(N) = Q_T \cdot (1 - e^{-\lambda N})Acc(N)=QT⋅(1−e−λN)

其中 λ\lambdaλ 是数据效率参数。

推论:少量高质量数据可能优于大量低质量数据。

11.3.2 教师能力的上界

定理 11.2(蒸馏的上界):学生模型的性能不可能超过教师模型(在相同任务上):

AccS≤AccT\text{Acc}_S \leq \text{Acc}_TAccS≤AccT

但实践中的例外

  • 学生模型可能在某些特定子任务上超越教师
  • 数据增强和正则化可以弥补容量差距
  • 教师的"错误"可能被学生修正

第十二章:思维链蒸馏与推理能力迁移

12.1 思维链蒸馏

12.1.1 动机

问题:标准蒸馏只迁移了教师的"答案",但没有迁移教师的"推理过程"。

思维链(Chain-of-Thought, CoT):教师模型在生成答案之前,先生成推理步骤。

12.1.2 CoT 蒸馏方法

算法 12.1(CoT 蒸馏)

复制代码
1. 使用教师模型生成 (问题, 思维链, 答案) 三元组
2. 训练学生模型学习生成完整的思维链:
   P_S(CoT, answer | question)
3. 损失函数:
   L = -Σ_t log P_S(token_t | token_{<t}, question)
   (对思维链和答案的所有 token 计算)

12.1.3 CoT 蒸馏的理论分析

定理 12.1(CoT 蒸馏的信息增益):相比只学习答案,CoT 蒸馏提供了额外的监督信号:

ICoT=I(answer;CoT∣question)I_{\text{CoT}} = I(\text{answer}; \text{CoT} | \text{question})ICoT=I(answer;CoT∣question)

这个互信息衡量了思维链中包含的关于答案的额外信息。

推论:对于需要推理的任务(如数学、逻辑),CoT 蒸馏的效果显著优于标准蒸馏。

12.2 推理能力迁移

12.2.1 迁移的挑战

问题:大模型的推理能力(如多步推理、类比推理)是否可以迁移到小模型?

实验观察

  • 标准蒸馏:小模型可以学习到教师的"知识",但推理能力有限
  • CoT 蒸馏:小模型可以学习到教师的"推理模式",但泛化能力有限
  • 迭代蒸馏:通过多轮蒸馏,可以逐步提升小模型的推理能力

12.2.2 迭代 CoT 蒸馏

算法 12.2(迭代 CoT 蒸馏)

复制代码
for round = 1, ..., R:
    1. 使用教师生成 CoT 数据
    2. 使用 CoT 数据训练学生
    3. 使用学生作为新的教师(或使用学生生成新的 CoT 数据)
    4. 评估学生的推理能力

第五部分:完整可运行代码实现


第十三章:从零实现经典知识蒸馏

python 复制代码
"""
经典知识蒸馏的完整实现。
包含:Hinton 蒸馏、温度缩放、各种损失函数。
"""

import numpy as np
from typing import Tuple, Optional


def softmax(z: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """温度缩放的 softmax。

    Args:
        z: logits (..., K)
        temperature: 温度参数

    Returns:
        p: 概率分布 (..., K)
    """
    z_scaled = z / temperature
    z_max = np.max(z_scaled, axis=-1, keepdims=True)
    exp_z = np.exp(z_scaled - z_max)
    return exp_z / np.sum(exp_z, axis=-1, keepdims=True)


def kl_divergence(p: np.ndarray, q: np.ndarray) -> float:
    """KL 散度 D_KL(p || q)。

    Args:
        p: 概率分布 (..., K)
        q: 概率分布 (..., K)

    Returns:
        kl: KL 散度
    """
    # 避免 log(0)
    p = np.clip(p, 1e-10, 1.0)
    q = np.clip(q, 1e-10, 1.0)
    return np.sum(p * np.log(p / q), axis=-1)


def cross_entropy(p: np.ndarray, q: np.ndarray) -> float:
    """交叉熵 H(p, q)。

    Args:
        p: 目标分布 (..., K)
        q: 预测分布 (..., K)

    Returns:
        ce: 交叉熵
    """
    q = np.clip(q, 1e-10, 1.0)
    return -np.sum(p * np.log(q), axis=-1)


def hinton_distillation_loss(
    teacher_logits: np.ndarray,
    student_logits: np.ndarray,
    labels: np.ndarray,
    temperature: float = 4.0,
    alpha: float = 0.1,
) -> Tuple[float, dict]:
    """Hinton 蒸馏损失。

    L = α * L_hard + (1-α) * T^2 * L_soft

    Args:
        teacher_logits: 教师 logits (B, K)
        student_logits: 学生 logits (B, K)
        labels: 真实标签 (B,)
        temperature: 温度参数
        alpha: 硬标签权重

    Returns:
        loss: 总损失
        info: 损失详情
    """
    B, K = teacher_logits.shape

    # 软标签
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)

    # 硬标签
    y_onehot = np.zeros((B, K))
    y_onehot[np.arange(B), labels] = 1
    p_student_hard = softmax(student_logits, temperature=1.0)

    # 软标签损失(KL 散度)
    L_soft = np.mean(kl_divergence(p_teacher, p_student))

    # 硬标签损失(交叉熵)
    L_hard = np.mean(cross_entropy(y_onehot, p_student_hard))

    # 总损失
    loss = alpha * L_hard + (1 - alpha) * temperature**2 * L_soft

    info = {
        "L_hard": L_hard,
        "L_soft": L_soft,
        "total": loss,
        "teacher_entropy": np.mean(-np.sum(p_teacher * np.log(np.clip(p_teacher, 1e-10, 1.0)), axis=-1)),
        "student_entropy": np.mean(-np.sum(p_student * np.log(np.clip(p_student, 1e-10, 1.0)), axis=-1)),
    }

    return loss, info


def mae_distillation_loss(
    teacher_logits: np.ndarray,
    student_logits: np.ndarray,
) -> float:
    """MSE 蒸馏损失(logits 级别)。

    Args:
        teacher_logits: 教师 logits
        student_logits: 学生 logits

    Returns:
        loss: MSE 损失
    """
    return np.mean((teacher_logits - student_logits) ** 2)


def cosine_distillation_loss(
    teacher_features: np.ndarray,
    student_features: np.ndarray,
) -> float:
    """余弦相似度蒸馏损失。

    Args:
        teacher_features: 教师特征 (B, D)
        student_features: 学生特征 (B, D)

    Returns:
        loss: 1 - 余弦相似度
    """
    # 归一化
    t_norm = teacher_features / (np.linalg.norm(teacher_features, axis=1, keepdims=True) + 1e-8)
    s_norm = student_features / (np.linalg.norm(student_features, axis=1, keepdims=True) + 1e-8)

    # 余弦相似度
    cos_sim = np.sum(t_norm * s_norm, axis=1)

    return np.mean(1 - cos_sim)


def demonstrate_distillation():
    """演示知识蒸馏。"""
    np.random.seed(42)

    print("=" * 70)
    print("知识蒸馏基础演示")
    print("=" * 70)

    # 设置
    B = 32  # batch size
    K = 10  # 类别数

    # 生成模拟数据
    teacher_logits = np.random.randn(B, K) * 2
    student_logits = np.random.randn(B, K) * 1.5
    labels = np.random.randint(0, K, B)

    # 温度的影响
    print("\n  1. 温度对软标签的影响")
    print("  " + "-" * 40)

    for T in [1, 2, 4, 8, 16]:
        p_teacher = softmax(teacher_logits, T)
        entropy = np.mean(-np.sum(p_teacher * np.log(np.clip(p_teacher, 1e-10, 1.0)), axis=-1))
        max_prob = np.mean(np.max(p_teacher, axis=1))
        print(f"    T={T:2d}: 熵={entropy:.4f}, 最大概率={max_prob:.4f}")

    # 蒸馏损失
    print("\n  2. 不同温度下的蒸馏损失")
    print("  " + "-" * 40)

    for T in [1, 2, 4, 8, 16]:
        loss, info = hinton_distillation_loss(teacher_logits, student_logits, labels, temperature=T)
        print(f"    T={T:2d}: 总损失={info['total']:.4f}, "
              f"硬标签={info['L_hard']:.4f}, 软标签={info['L_soft']:.4f}")

    # alpha 的影响
    print("\n  3. 不同 α 值的影响")
    print("  " + "-" * 40)

    T = 4
    for alpha in [0.0, 0.1, 0.3, 0.5, 0.7, 0.9]:
        loss, info = hinton_distillation_loss(teacher_logits, student_logits, labels, temperature=T, alpha=alpha)
        print(f"    α={alpha:.1f}: 总损失={info['total']:.4f}")

    # 各种蒸馏损失对比
    print("\n  4. 各种蒸馏损失对比")
    print("  " + "-" * 40)

    # Hinton 蒸馏
    loss_hinton, _ = hinton_distillation_loss(teacher_logits, student_logits, labels, temperature=4, alpha=0.1)

    # MSE 蒸馏
    loss_mse = mae_distillation_loss(teacher_logits, student_logits)

    # 余弦蒸馏
    loss_cos = cosine_distillation_loss(teacher_logits, student_logits)

    print(f"    Hinton (T=4, α=0.1): {loss_hinton:.4f}")
    print(f"    MSE (logits):         {loss_mse:.4f}")
    print(f"    Cosine:               {loss_cos:.4f}")

    # 教师-学生差异分析
    print("\n  5. 教师-学生差异分析")
    print("  " + "-" * 40)

    p_t = softmax(teacher_logits, T=4)
    p_s = softmax(student_logits, T=4)

    # KL 散度
    kl_ts = np.mean(kl_divergence(p_t, p_s))
    kl_st = np.mean(kl_divergence(p_s, p_t))
    js = np.mean(0.5 * kl_divergence(p_t, 0.5 * (p_t + p_s)) + 0.5 * kl_divergence(p_s, 0.5 * (p_t + p_s)))

    print(f"    D_KL(teacher||student): {kl_ts:.4f}")
    print(f"    D_KL(student||teacher): {kl_st:.4f}")
    print(f"    D_JS(teacher||student): {js:.4f}")

    # 预测一致性
    teacher_pred = np.argmax(teacher_logits, axis=1)
    student_pred = np.argmax(student_logits, axis=1)
    agreement = np.mean(teacher_pred == student_pred)

    print(f"    预测一致率: {agreement:.2%}")


if __name__ == "__main__":
    demonstrate_distillation()

第十四章:从零实现特征蒸馏

python 复制代码
"""
特征蒸馏的完整实现。
包含:FitNets、注意力转移、特征分布对齐。
"""

import numpy as np
from typing import List, Tuple


def feature_distillation_loss(
    teacher_features: np.ndarray,
    student_features: np.ndarray,
    method: str = "mse",
) -> float:
    """特征蒸馏损失。

    Args:
        teacher_features: 教师特征 (B, D_t)
        student_features: 学生特征 (B, D_s)
        method: 方法 ("mse", "cosine", "normalize")

    Returns:
        loss: 蒸馏损失
    """
    if method == "mse":
        # MSE 损失
        return np.mean((teacher_features - student_features) ** 2)

    elif method == "cosine":
        # 余弦相似度损失
        t_norm = teacher_features / (np.linalg.norm(teacher_features, axis=1, keepdims=True) + 1e-8)
        s_norm = student_features / (np.linalg.norm(student_features, axis=1, keepdims=True) + 1e-8)
        cos_sim = np.sum(t_norm * s_norm, axis=1)
        return np.mean(1 - cos_sim)

    elif method == "normalize":
        # 归一化 MSE
        t_mean, t_std = teacher_features.mean(axis=0), teacher_features.std(axis=0) + 1e-8
        s_mean, s_std = student_features.mean(axis=0), student_features.std(axis=0) + 1e-8

        t_norm = (teacher_features - t_mean) / t_std
        s_norm = (student_features - s_mean) / s_std

        return np.mean((t_norm - s_norm) ** 2)

    else:
        raise ValueError(f"Unknown method: {method}")


def attention_transfer_loss(
    teacher_attention: np.ndarray,
    student_attention: np.ndarray,
) -> float:
    """注意力转移损失。

    Args:
        teacher_attention: 教师注意力图 (B, H, S, S)
        student_attention: 学生注意力图 (B, H', S, S)

    Returns:
        loss: 注意力转移损失
    """
    # 计算注意力图(对 head 维度取平均)
    A_teacher = np.mean(np.abs(teacher_attention), axis=1)  # (B, S, S)
    A_student = np.mean(np.abs(student_attention), axis=1)  # (B, S, S)

    # 归一化
    A_teacher_norm = A_teacher / (np.linalg.norm(A_teacher.reshape(A_teacher.shape[0], -1), axis=1, keepdims=True).reshape(-1, 1, 1) + 1e-8)
    A_student_norm = A_student / (np.linalg.norm(A_student.reshape(A_student.shape[0], -1), axis=1, keepdims=True).reshape(-1, 1, 1) + 1e-8)

    # MSE 损失
    loss = np.mean((A_teacher_norm - A_student_norm) ** 2)

    return loss


def gram_matrix(features: np.ndarray) -> np.ndarray:
    """计算 Gram 矩阵。

    Args:
        features: 特征图 (B, C, H, W)

    Returns:
        G: Gram 矩阵 (B, C, C)
    """
    B, C, H, W = features.shape
    F = features.reshape(B, C, H * W)
    G = np.matmul(F, F.transpose(0, 2, 1))
    return G / (C * H * W)


def gram_distillation_loss(
    teacher_features: np.ndarray,
    student_features: np.ndarray,
) -> float:
    """Gram 矩阵蒸馏损失。

    Args:
        teacher_features: 教师特征图 (B, C_t, H, W)
        student_features: 学生特征图 (B, C_s, H, W)

    Returns:
        loss: Gram 矩阵匹配损失
    """
    G_teacher = gram_matrix(teacher_features)
    G_student = gram_matrix(student_features)

    # 如果通道数不同,需要匹配
    if G_teacher.shape[1] != G_student.shape[1]:
        min_channels = min(G_teacher.shape[1], G_student.shape[1])
        G_teacher = G_teacher[:, :min_channels, :min_channels]
        G_student = G_student[:, :min_channels, :min_channels]

    return np.mean((G_teacher - G_student) ** 2)


def distribution_alignment_loss(
    teacher_features: np.ndarray,
    student_features: np.ndarray,
) -> float:
    """特征分布对齐损失。

    对齐均值和方差。

    Args:
        teacher_features: 教师特征 (B, D)
        student_features: 学生特征 (B, D)

    Returns:
        loss: 分布对齐损失
    """
    # 均值对齐
    t_mean = np.mean(teacher_features, axis=0)
    s_mean = np.mean(student_features, axis=0)
    mean_loss = np.mean((t_mean - s_mean) ** 2)

    # 方差对齐
    t_var = np.var(teacher_features, axis=0) + 1e-8
    s_var = np.var(student_features, axis=0) + 1e-8
    var_loss = np.mean((np.sqrt(t_var) - np.sqrt(s_var)) ** 2)

    return mean_loss + var_loss


def demonstrate_feature_distillation():
    """演示特征蒸馏。"""
    np.random.seed(42)

    print("=" * 70)
    print("特征蒸馏演示")
    print("=" * 70)

    B = 16
    D_t = 256  # 教师特征维度
    D_s = 128  # 学生特征维度

    # 生成特征
    teacher_features = np.random.randn(B, D_t) * 0.5
    student_features = np.random.randn(B, D_s) * 0.5

    # 不同蒸馏方法
    print("\n  1. 特征蒸馏损失对比")
    print("  " + "-" * 40)

    # MSE(需要维度匹配)
    teacher_matched = teacher_features[:, :D_s]  # 简单截断
    loss_mse = feature_distillation_loss(teacher_matched, student_features, "mse")
    loss_cos = feature_distillation_loss(teacher_matched, student_features, "cosine")
    loss_norm = feature_distillation_loss(teacher_matched, student_features, "normalize")

    print(f"    MSE:           {loss_mse:.6f}")
    print(f"    Cosine:        {loss_cos:.6f}")
    print(f"    Normalize MSE: {loss_norm:.6f}")

    # 注意力转移
    print("\n  2. 注意力转移损失")
    print("  " + "-" * 40)

    S = 32  # 序列长度
    H_t = 8  # 教师头数
    H_s = 4  # 学生头数

    teacher_attn = np.random.randn(B, H_t, S, S) * 0.1
    student_attn = np.random.randn(B, H_s, S, S) * 0.1

    # 应用 softmax 使其成为有效的注意力权重
    teacher_attn = np.exp(teacher_attn) / np.sum(np.exp(teacher_attn), axis=-1, keepdims=True)
    student_attn = np.exp(student_attn) / np.sum(np.exp(student_attn), axis=-1, keepdims=True)

    loss_attn = attention_transfer_loss(teacher_attn, student_attn)
    print(f"    注意力转移损失: {loss_attn:.6f}")

    # Gram 矩阵蒸馏
    print("\n  3. Gram 矩阵蒸馏损失")
    print("  " + "-" * 40)

    C_t = 64
    C_s = 32
    H, W = 8, 8

    teacher_feat_map = np.random.randn(B, C_t, H, W) * 0.5
    student_feat_map = np.random.randn(B, C_s, H, W) * 0.5

    loss_gram = gram_distillation_loss(teacher_feat_map, student_feat_map)
    print(f"    Gram 矩阵损失: {loss_gram:.6f}")

    # 分布对齐
    print("\n  4. 分布对齐损失")
    print("  " + "-" * 40)

    loss_dist = distribution_alignment_loss(teacher_matched, student_features)
    print(f"    分布对齐损失: {loss_dist:.6f}")


if __name__ == "__main__":
    demonstrate_feature_distillation()

第十五章:从零实现自蒸馏

python 复制代码
"""
自蒸馏的完整实现。
包含:Born-Again Networks、在线自蒸馏、多视角自蒸馏。
"""

import numpy as np
from typing import List, Tuple


class SimpleModel:
    """简单的线性模型,用于演示。"""

    def __init__(self, input_dim: int, output_dim: int, lr: float = 0.01):
        self.W = np.random.randn(input_dim, output_dim) * 0.01
        self.b = np.zeros(output_dim)
        self.lr = lr

    def forward(self, X: np.ndarray) -> np.ndarray:
        """前向传播。"""
        return X @ self.W + self.b

    def softmax(self, z: np.ndarray, temperature: float = 1.0) -> np.ndarray:
        """温度缩放的 softmax。"""
        z_scaled = z / temperature
        z_max = np.max(z_scaled, axis=-1, keepdims=True)
        exp_z = np.exp(z_scaled - z_max)
        return exp_z / np.sum(exp_z, axis=-1, keepdims=True)

    def train_step(
        self,
        X: np.ndarray,
        y: np.ndarray,
        teacher_logits: np.ndarray = None,
        temperature: float = 4.0,
        alpha: float = 0.5,
    ) -> float:
        """训练步骤。

        Args:
            X: 输入 (B, D)
            y: 标签 (B,)
            teacher_logits: 教师 logits (如果有的话)
            temperature: 温度
            alpha: 硬标签权重

        Returns:
            loss: 损失
        """
        B = X.shape[0]
        K = self.W.shape[1]

        # 前向传播
        logits = self.forward(X)

        # 硬标签
        y_onehot = np.zeros((B, K))
        y_onehot[np.arange(B), y] = 1
        p_student_hard = self.softmax(logits, 1.0)

        # 硬标签梯度
        grad_hard = (p_student_hard - y_onehot) / B

        if teacher_logits is not None:
            # 软标签
            p_teacher = self.softmax(teacher_logits, temperature)
            p_student = self.softmax(logits, temperature)

            # 软标签梯度(近似)
            grad_soft = temperature * (p_student - p_teacher) / B

            # 组合梯度
            grad = alpha * grad_hard + (1 - alpha) * grad_soft
        else:
            grad = grad_hard

        # 更新参数
        self.W -= self.lr * X.T @ grad
        self.b -= self.lr * grad.mean(axis=0)

        # 计算损失
        p = self.softmax(logits, 1.0)
        loss = -np.mean(np.sum(y_onehot * np.log(np.clip(p, 1e-10, 1.0)), axis=-1))

        return loss

    def predict(self, X: np.ndarray) -> np.ndarray:
        """预测。"""
        logits = self.forward(X)
        return np.argmax(logits, axis=1)


def born_again_networks(
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_test: np.ndarray,
    y_test: np.ndarray,
    n_generations: int = 3,
    n_epochs: int = 50,
    temperature: float = 4.0,
    alpha: float = 0.1,
) -> List[dict]:
    """Born-Again Networks 蒸馏。

    Args:
        X_train, y_train: 训练数据
        X_test, y_test: 测试数据
        n_generations: 代数
        n_epochs: 每代的训练轮数
        temperature: 温度
        alpha: 硬标签权重

    Returns:
        results: 每代的结果
    """
    D = X_train.shape[1]
    K = len(np.unique(y_train))
    results = []

    teacher_logits = None

    for gen in range(n_generations):
        print(f"\n  第 {gen + 1} 代:")

        # 创建学生模型
        student = SimpleModel(D, K, lr=0.01)

        # 训练
        for epoch in range(n_epochs):
            loss = student.train_step(
                X_train, y_train,
                teacher_logits=teacher_logits,
                temperature=temperature,
                alpha=alpha if teacher_logits is not None else 1.0,
            )

        # 评估
        train_pred = student.predict(X_train)
        test_pred = student.predict(X_test)
        train_acc = np.mean(train_pred == y_train)
        test_acc = np.mean(test_pred == y_test)

        print(f"    训练准确率: {train_acc:.2%}")
        print(f"    测试准确率: {test_acc:.2%}")

        results.append({
            "generation": gen + 1,
            "train_acc": train_acc,
            "test_acc": test_acc,
        })

        # 将学生作为下一代的教师
        teacher_logits = student.forward(X_train)

    return results


def online_self_distillation(
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_test: np.ndarray,
    y_test: np.ndarray,
    n_epochs: int = 100,
    ema_decay: float = 0.999,
    temperature: float = 4.0,
    distill_weight: float = 0.5,
) -> dict:
    """在线自蒸馏(EMA 教师)。

    Args:
        X_train, y_train: 训练数据
        X_test, y_test: 测试数据
        n_epochs: 训练轮数
        ema_decay: EMA 衰减率
        temperature: 温度
        distill_weight: 蒸馏权重

    Returns:
        results: 训练结果
    """
    D = X_train.shape[1]
    K = len(np.unique(y_train))

    # 主模型
    model = SimpleModel(D, K, lr=0.01)

    # EMA 模型(教师)
    ema_model = SimpleModel(D, K, lr=0.0)  # 不更新
    ema_model.W = model.W.copy()
    ema_model.b = model.b.copy()

    train_accs = []
    test_accs = []

    for epoch in range(n_epochs):
        # 训练主模型(使用 EMA 模型作为教师)
        ema_logits = ema_model.forward(X_train)
        loss = model.train_step(
            X_train, y_train,
            teacher_logits=ema_logits,
            temperature=temperature,
            alpha=1 - distill_weight,
        )

        # 更新 EMA 模型
        ema_model.W = ema_decay * ema_model.W + (1 - ema_decay) * model.W
        ema_model.b = ema_decay * ema_model.b + (1 - ema_decay) * model.b

        # 评估
        if (epoch + 1) % 10 == 0:
            train_pred = model.predict(X_train)
            test_pred = model.predict(X_test)
            train_acc = np.mean(train_pred == y_train)
            test_acc = np.mean(test_pred == y_test)
            train_accs.append(train_acc)
            test_accs.append(test_acc)

            print(f"  Epoch {epoch + 1}: 训练={train_acc:.2%}, 测试={test_acc:.2%}")

    return {
        "train_accs": train_accs,
        "test_accs": test_accs,
        "final_train_acc": train_accs[-1],
        "final_test_acc": test_accs[-1],
    }


def demonstrate_self_distillation():
    """演示自蒸馏。"""
    np.random.seed(42)

    print("=" * 70)
    print("自蒸馏演示")
    print("=" * 70)

    # 生成数据
    n_train = 500
    n_test = 100
    D = 32
    K = 5

    X_train = np.random.randn(n_train, D)
    y_train = np.random.randint(0, K, n_train)
    X_test = np.random.randn(n_test, D)
    y_test = np.random.randint(0, K, n_test)

    # Born-Again Networks
    print("\n  1. Born-Again Networks")
    print("  " + "-" * 40)

    ban_results = born_again_networks(
        X_train, y_train, X_test, y_test,
        n_generations=3, n_epochs=50, temperature=4.0, alpha=0.1,
    )

    print(f"\n  BAN 结果汇总:")
    for r in ban_results:
        print(f"    第 {r['generation']} 代: 测试准确率 = {r['test_acc']:.2%}")

    # 在线自蒸馏
    print("\n\n  2. 在线自蒸馏 (EMA 教师)")
    print("  " + "-" * 40)

    online_results = online_self_distillation(
        X_train, y_train, X_test, y_test,
        n_epochs=100, ema_decay=0.99, temperature=4.0, distill_weight=0.3,
    )

    print(f"\n  在线自蒸馏最终结果:")
    print(f"    训练准确率: {online_results['final_train_acc']:.2%}")
    print(f"    测试准确率: {online_results['final_test_acc']:.2%}")

    # 基线对比
    print("\n\n  3. 基线对比(无蒸馏)")
    print("  " + "-" * 40)

    baseline = SimpleModel(D, K, lr=0.01)
    for epoch in range(100):
        baseline.train_step(X_train, y_train)

    train_pred = baseline.predict(X_train)
    test_pred = baseline.predict(X_test)
    print(f"    训练准确率: {np.mean(train_pred == y_train):.2%}")
    print(f"    测试准确率: {np.mean(test_pred == y_test):.2%}")


if __name__ == "__main__":
    demonstrate_self_distillation()

第十六章:完整蒸馏 Pipeline 与精度对比

python 复制代码
"""
完整的蒸馏 Pipeline。
对比各种蒸馏方法在不同设置下的效果。
"""

import numpy as np
from typing import Dict, List, Tuple


def softmax(z, T=1.0):
    z = z / T
    z_max = np.max(z, axis=-1, keepdims=True)
    exp_z = np.exp(z - z_max)
    return exp_z / np.sum(exp_z, axis=-1, keepdims=True)


def run_full_distillation_comparison():
    """运行完整的蒸馏方法对比。"""
    np.random.seed(42)

    print("=" * 70)
    print("知识蒸馏方法综合对比")
    print("=" * 70)

    # 生成数据
    n_train = 1000
    n_test = 200
    D = 64
    K = 10

    X_train = np.random.randn(n_train, D)
    y_train = np.random.randint(0, K, n_train)
    X_test = np.random.randn(n_test, D)
    y_test = np.random.randint(0, K, n_test)

    # 教师模型(更大的模型)
    teacher_W = np.random.randn(D, K) * 0.1
    teacher_b = np.zeros(K)

    # 教师预测
    teacher_logits_train = X_train @ teacher_W + teacher_b
    teacher_logits_test = X_test @ teacher_W + teacher_b
    teacher_pred = np.argmax(teacher_logits_test, axis=1)
    teacher_acc = np.mean(teacher_pred == y_test)

    print(f"\n  教师模型准确率: {teacher_acc:.2%}")
    print(f"  教师参数量: {D * K + K}")

    # 定义训练函数
    def train_student(
        X_train, y_train, teacher_logits=None,
        temperature=4.0, alpha=0.5, n_epochs=100, lr=0.01
    ):
        D = X_train.shape[1]
        K = len(np.unique(y_train))
        W = np.random.randn(D, K) * 0.01
        b = np.zeros(K)

        for epoch in range(n_epochs):
            logits = X_train @ W + b
            B = X_train.shape[0]

            # 硬标签
            y_onehot = np.zeros((B, K))
            y_onehot[np.arange(B), y_train] = 1
            p_hard = softmax(logits, 1.0)
            grad_hard = (p_hard - y_onehot) / B

            if teacher_logits is not None:
                p_teacher = softmax(teacher_logits, temperature)
                p_student = softmax(logits, temperature)
                grad_soft = temperature * (p_student - p_teacher) / B
                grad = alpha * grad_hard + (1 - alpha) * grad_soft
            else:
                grad = grad_hard

            W -= lr * X_train.T @ grad
            b -= lr * grad.mean(axis=0)

        return W, b

    # 测试不同方法
    methods = {}

    # 1. 无蒸馏(基线)
    W, b = train_student(X_train, y_train, n_epochs=100)
    pred = np.argmax(X_test @ W + b, axis=1)
    methods["无蒸馏(基线)"] = np.mean(pred == y_test)

    # 2. 标准蒸馏
    for T in [2, 4, 8]:
        for alpha in [0.1, 0.5]:
            W, b = train_student(
                X_train, y_train,
                teacher_logits=teacher_logits_train,
                temperature=T, alpha=alpha, n_epochs=100
            )
            pred = np.argmax(X_test @ W + b, axis=1)
            name = f"蒸馏 T={T}, α={alpha}"
            methods[name] = np.mean(pred == y_test)

    # 3. MSE 蒸馏
    def train_mse(X_train, y_train, teacher_logits, n_epochs=100, lr=0.01, distill_weight=0.5):
        D = X_train.shape[1]
        K = teacher_logits.shape[1]
        W = np.random.randn(D, K) * 0.01
        b = np.zeros(K)

        for epoch in range(n_epochs):
            logits = X_train @ W + b
            B = X_train.shape[0]

            y_onehot = np.zeros((B, K))
            y_onehot[np.arange(B), y_train] = 1
            p_hard = softmax(logits, 1.0)
            grad_hard = (p_hard - y_onehot) / B

            grad_mse = 2 * (logits - teacher_logits) / B
            grad = (1 - distill_weight) * grad_hard + distill_weight * grad_mse

            W -= lr * X_train.T @ grad
            b -= lr * grad.mean(axis=0)

        return W, b

    W, b = train_mse(X_train, y_train, teacher_logits_train, distill_weight=0.5)
    pred = np.argmax(X_test @ W + b, axis=1)
    methods["MSE 蒸馏"] = np.mean(pred == y_test)

    # 打印结果
    print(f"\n  {'方法':>25} {'测试准确率':>12} {'相对提升':>12}")
    print(f"  {'-'*25} {'-'*12} {'-'*12}")

    baseline_acc = methods["无蒸馏(基线)"]
    for name, acc in methods.items():
        improvement = (acc - baseline_acc) / baseline_acc * 100
        print(f"  {name:>25} {acc:>12.2%} {improvement:>+11.1f}%")

    # 最佳方法
    best_method = max(methods.items(), key=lambda x: x[1])
    print(f"\n  最佳方法: {best_method[0]} (准确率: {best_method[1]:.2%})")

    # 温度敏感性分析
    print(f"\n  温度敏感性分析 (α=0.1):")
    print(f"  {'温度':>8} {'准确率':>12} {'教师熵':>12} {'学生熵':>12}")
    print(f"  {'-'*8} {'-'*12} {'-'*12} {'-'*12}")

    for T in [1, 2, 4, 8, 16, 32]:
        W, b = train_student(
            X_train, y_train,
            teacher_logits=teacher_logits_train,
            temperature=T, alpha=0.1, n_epochs=100
        )
        pred = np.argmax(X_test @ W + b, axis=1)
        acc = np.mean(pred == y_test)

        p_t = softmax(teacher_logits_train, T)
        p_s = softmax(X_train @ W + b, T)
        t_entropy = np.mean(-np.sum(p_t * np.log(np.clip(p_t, 1e-10, 1.0)), axis=-1))
        s_entropy = np.mean(-np.sum(p_s * np.log(np.clip(p_s, 1e-10, 1.0)), axis=-1))

        print(f"  {T:>8} {acc:>12.2%} {t_entropy:>12.4f} {s_entropy:>12.4f}")


if __name__ == "__main__":
    run_full_distillation_comparison()

附录:关键公式汇总

A.1 散度与距离

公式 表达式
KL 散度 DKL(P∣Q)=∑iP(i)log⁡P(i)Q(i)D_{\text{KL}}(P | Q) = \sum_i P(i) \log \frac{P(i)}{Q(i)}DKL(P∣Q)=∑iP(i)logQ(i)P(i)
交叉熵 H(P,Q)=−∑iP(i)log⁡Q(i)H(P, Q) = -\sum_i P(i) \log Q(i)H(P,Q)=−∑iP(i)logQ(i)
JS 散度 DJS(P∣Q)=12DKL(P∣M)+12DKL(Q∣M)D_{\text{JS}}(P | Q) = \frac{1}{2} D_{\text{KL}}(P | M) + \frac{1}{2} D_{\text{KL}}(Q | M)DJS(P∣Q)=21DKL(P∣M)+21DKL(Q∣M)
关系 DKL(P∣Q)=H(P,Q)−H(P)D_{\text{KL}}(P | Q) = H(P, Q) - H(P)DKL(P∣Q)=H(P,Q)−H(P)

A.2 温度缩放

公式 表达式
温度 softmax pi(T)=ezi/T∑jezj/Tp_i^{(T)} = \frac{e^{z_i/T}}{\sum_j e^{z_j/T}}pi(T)=∑jezj/Tezi/T
熵与温度 H(T)H(T)H(T) 单调递增,H(0)=0H(0) = 0H(0)=0,H(∞)=log⁡KH(\infty) = \log KH(∞)=logK
梯度 ∂pi∂zi=1Tpi(1−pi)\frac{\partial p_i}{\partial z_i} = \frac{1}{T} p_i (1 - p_i)∂zi∂pi=T1pi(1−pi)

A.3 蒸馏损失

方法 损失函数
Hinton 蒸馏 L=αH(y,PS(1))+(1−α)T2DKL(PT(T)∣PS(T))\mathcal{L} = \alpha H(\mathbf{y}, P_S^{(1)}) + (1-\alpha) T^2 D_{\text{KL}}(P_T^{(T)} | P_S^{(T)})L=αH(y,PS(1))+(1−α)T2DKL(PT(T)∣PS(T))
MSE 蒸馏 L=∣zT−zS∣2\mathcal{L} = |\mathbf{z}_T - \mathbf{z}_S|^2L=∣zT−zS∣2
余弦蒸馏 L=1−cos⁡(hT,hS)\mathcal{L} = 1 - \cos(\mathbf{h}_T, \mathbf{h}_S)L=1−cos(hT,hS)
注意力转移 L=∣AT−AS∣F2\mathcal{L} = |A_T - A_S|_F^2L=∣AT−AS∣F2
Gram 矩阵 L=∣G(AT)−G(AS)∣F2\mathcal{L} = |G(\mathbf{A}_T) - G(\mathbf{A}_S)|_F^2L=∣G(AT)−G(AS)∣F2

A.4 序列级蒸馏

公式 表达式
序列级 KL $D_{\text{KL}}^{\text{seq}} = \sum_{t=1}^{L} D_{\text{KL}}(P_T(\cdot
反向 KL $D_{\text{KL}}(P_S | P_T) = \sum_t \sum_v P_S(v

A.5 自蒸馏

方法 说明
Born-Again Networks 多代蒸馏,每代用上一代作为教师
EMA 自蒸馏 使用指数移动平均作为教师
多视角自蒸馏 不同增强版本相互对齐

参考文献

  1. Hinton, G., Vinyals, O., & Dean, J. (2015). Distilling the knowledge in a neural network. NeurIPS Workshop.
  2. Romero, A., et al. (2015). FitNets: Hints for thin deep nets. ICLR.
  3. Zagoruyko, S., & Komodakis, N. (2017). Paying more attention to attention: Improving the performance of convolutional neural networks via attention transfer. ICLR.
  4. Furlanello, T., et al. (2018). Born-again neural networks. ICML.
  5. Park, W., et al. (2019). Relational knowledge distillation. CVPR.
  6. Sanh, V., et al. (2019). DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter. NeurIPS Workshop.
  7. Jiao, X., et al. (2020). TinyBERT: Distilling BERT for natural language understanding. EMNLP.
  8. Gu, Y., et al. (2024). MiniLLM: Knowledge distillation of large language models. ICLR.
  9. Taori, R., et al. (2023). Stanford Alpaca: An instruction-following LLaMA model. GitHub.
  10. Chiang, W., et al. (2023). Vicuna: An open-source chatbot impressing GPT-4 with 90% ChatGPT quality. Blog.
  11. Xu, C., et al. (2023). WizardLM: Empowering large language models to follow complex instructions. arXiv.
  12. Wei, J., et al. (2022). Chain-of-thought prompting elicits reasoning in large language models. NeurIPS.
相关推荐
小丶舟1 小时前
6GB显卡跑Hermes Agent!开源AI自学习编程Agent实测
人工智能·学习·开源
haina20191 小时前
深圳市人工智能产业协会赴京考察海纳AI,共谋AI人才测评新生态
人工智能·ai面试·ai招聘
冷de猫1 小时前
从个人中转站到企业级 AI 网关:Aegisy 实践背后的基础设施演进逻辑
人工智能
穗余1 小时前
2026 AI x Web3 School共学营笔记-Day10-Women Builders in AI × Web3
人工智能·笔记·web3
wasp5201 小时前
# 推荐透明桌面 Widget 生产力工具 —— 待办、便签、AI常驻桌面:忙蜂了(BitzBee Todos)
人工智能·开源·gtd
2601_957879331 小时前
矩阵系统在企业数字化获客中的实践与价值分析
人工智能·数字营销·矩阵系统·企业运营
水上冰石1 小时前
comfui的sd1.5模型,有多少采样算法,详解每一个采样算法
人工智能·算法
Rocky Ding*1 小时前
一文读懂HiDream-I1稀疏 DiT 图像生成基础模型
论文阅读·人工智能·深度学习·机器学习·ai作画·aigc·ai-native
标书畅畅行1 小时前
2026 年 AI 标书工具市场观察:技术迭代与选型指南
大数据·人工智能