目录
- 第一部分:知识蒸馏基础理论
- 第一章:绪论------模型压缩的第三条路
- 第二章:知识蒸馏的数学基础------散度、温度与信息论
- 第三章:Hinton 蒸馏的理论分析
- 第二部分:经典蒸馏方法
- 第四章:响应蒸馏(Response-based Distillation)
- 第五章:特征蒸馏(Feature-based Distillation)
- 第六章:关系蒸馏(Relation-based Distillation)
- 第三部分:自蒸馏与无教师蒸馏
- 第七章:自蒸馏(Self-Distillation)
- 第八章:无教师蒸馏与数据蒸馏
- 第四部分:大语言模型的知识蒸馏
- 第九章:LLM 蒸馏的特殊挑战
- 第十章:白盒蒸馏------逐层与逐 token 对齐
- 第十一章:黑盒蒸馏------从 API 到合成数据
- 第十二章:思维链蒸馏与推理能力迁移
- 第五部分:完整可运行代码实现
- 第十三章:从零实现经典知识蒸馏
- 第十四章:从零实现特征蒸馏
- 第十五章:从零实现自蒸馏
- 第十六章:完整蒸馏 Pipeline 与精度对比
- 附录
第一部分:知识蒸馏基础理论
第一章:绪论------模型压缩的第三条路
1.1 模型压缩的三大范式
在前两篇文档中,我们分别讨论了量化 和剪枝。知识蒸馏是模型压缩的第三大范式:
| 范式 | 核心思想 | 压缩方式 | 是否需要训练 |
|---|---|---|---|
| 量化 | 降低参数精度 | 减少每个参数的比特数 | 可选(PTQ/QAT) |
| 剪枝 | 移除冗余参数 | 减少非零参数数量 | 可选 |
| 蒸馏 | 知识迁移 | 训练更小的模型 | 必须 |
1.1.1 知识蒸馏的核心思想
知识蒸馏(Knowledge Distillation, KD) (Hinton et al., 2015)的核心思想是:用一个大的教师模型(Teacher) 来指导一个小的学生模型(Student) 训练,使学生模型学习到教师模型的"知识"。
关键洞察 :教师模型的软标签(soft labels)------即 softmax 输出的概率分布------包含了比硬标签(one-hot)更丰富的信息。
例子:对于图像分类任务,一张"猫"的图片:
- 硬标签 :0,0,1,0,00, 0, 1, 0, 00,0,1,0,0(只有猫是 1)
- 软标签 :0.01,0.02,0.90,0.05,0.020.01, 0.02, 0.90, 0.05, 0.020.01,0.02,0.90,0.05,0.02(猫 0.90,但也有一定概率是其他动物)
软标签告诉学生:"这张图片虽然是猫,但和狗也有一定相似性"------这是硬标签无法传达的。
1.1.2 蒸馏 vs 量化 vs 剪枝
| 特性 | 量化 | 剪枝 | 蒸馏 |
|---|---|---|---|
| 压缩比 | 2x-4x | 2x-10x | 5x-100x |
| 精度保持 | 好 | 中等 | 好 |
| 计算加速 | 好 | 中等-好 | 最好 |
| 实现复杂度 | 低 | 中等 | 高 |
| 是否改变架构 | 否 | 可选 | 是 |
| 是否需要训练 | 可选 | 可选 | 必须 |
蒸馏的独特优势:可以改变模型架构------不仅减少参数量,还可以简化模型结构(如减少层数、隐藏维度等)。
1.2 蒸馏的历史
1.2.1 模型压缩的早期工作
模型集成(Model Ensemble) :多个模型的预测平均通常优于单个模型。但集成模型的计算成本是单模型的 KKK 倍。
Buciluǎ et al., 2006:首次提出将集成模型的知识压缩到单个模型中。
1.2.2 Hinton 的贡献
Hinton et al., 2015:提出了温度缩放的 softmax 和蒸馏损失函数,奠定了现代知识蒸馏的基础。
关键创新:
- 温度缩放 :使用温度参数 TTT 来"软化"softmax 输出
- 蒸馏损失:学生模型同时学习硬标签和软标签
- 暗知识(Dark Knowledge):软标签中包含的类间关系信息
第二章:知识蒸馏的数学基础------散度、温度与信息论
2.1 概率分布的距离度量
2.1.1 KL 散度
定义 2.1(Kullback-Leibler 散度) :对于两个离散概率分布 PPP 和 QQQ,KL 散度定义为:
DKL(P∥Q)=∑iP(i)logP(i)Q(i)D_{\text{KL}}(P \| Q) = \sum_i P(i) \log \frac{P(i)}{Q(i)}DKL(P∥Q)=i∑P(i)logQ(i)P(i)
性质:
- DKL(P∥Q)≥0D_{\text{KL}}(P \| Q) \geq 0DKL(P∥Q)≥0(非负性)
- DKL(P∥Q)=0 ⟺ P=QD_{\text{KL}}(P \| Q) = 0 \iff P = QDKL(P∥Q)=0⟺P=Q(同一性)
- DKL(P∥Q)≠DKL(Q∥P)D_{\text{KL}}(P \| Q) \neq D_{\text{KL}}(Q \| P)DKL(P∥Q)=DKL(Q∥P)(不对称性)
注意:KL 散度不是真正的距离度量(因为不对称)。
2.1.2 交叉熵
定义 2.2(交叉熵):
H(P,Q)=−∑iP(i)logQ(i)H(P, Q) = -\sum_i P(i) \log Q(i)H(P,Q)=−i∑P(i)logQ(i)
与 KL 散度的关系:
DKL(P∥Q)=H(P,Q)−H(P)D_{\text{KL}}(P \| Q) = H(P, Q) - H(P)DKL(P∥Q)=H(P,Q)−H(P)
其中 H(P)=−∑iP(i)logP(i)H(P) = -\sum_i P(i) \log P(i)H(P)=−∑iP(i)logP(i) 是 PPP 的熵。
推论 :最小化 DKL(P∥Q)D_{\text{KL}}(P \| Q)DKL(P∥Q) 等价于最小化 H(P,Q)H(P, Q)H(P,Q)(因为 H(P)H(P)H(P) 是常数)。
2.1.3 JS 散度
定义 2.3(Jensen-Shannon 散度):
DJS(P∥Q)=12DKL(P∥M)+12DKL(Q∥M)D_{\text{JS}}(P \| Q) = \frac{1}{2} D_{\text{KL}}(P \| M) + \frac{1}{2} D_{\text{KL}}(Q \| M)DJS(P∥Q)=21DKL(P∥M)+21DKL(Q∥M)
其中 M=(P+Q)/2M = (P + Q) / 2M=(P+Q)/2。
性质:
- DJS(P∥Q)≥0D_{\text{JS}}(P \| Q) \geq 0DJS(P∥Q)≥0
- DJS(P∥Q)=DJS(Q∥P)D_{\text{JS}}(P \| Q) = D_{\text{JS}}(Q \| P)DJS(P∥Q)=DJS(Q∥P)(对称性)
- DJS(P∥Q)≤log2D_{\text{JS}}(P \| Q) \leq \log 2DJS(P∥Q)≤log2(有界性)
2.1.4 散度选择的理论依据
定理 2.1(前向 vs 反向 KL):
在知识蒸馏中,使用前向 KL (DKL(PT∥PS)D_{\text{KL}}(P_T \| P_S)DKL(PT∥PS))还是反向 KL (DKL(PS∥PT)D_{\text{KL}}(P_S \| P_T)DKL(PS∥PT))会导致不同的行为:
- 前向 KL (DKL(PT∥PS)D_{\text{KL}}(P_T \| P_S)DKL(PT∥PS)):学生需要覆盖教师的所有模式(mode-covering)
- 反向 KL (DKL(PS∥PT)D_{\text{KL}}(P_S \| P_T)DKL(PS∥PT)):学生倾向于集中在教师的一个模式上(mode-seeking)
证明 :设教师分布 PTP_TPT 是双峰的,PT=0.5⋅N(μ1,σ2)+0.5⋅N(μ2,σ2)P_T = 0.5 \cdot \mathcal{N}(\mu_1, \sigma^2) + 0.5 \cdot \mathcal{N}(\mu_2, \sigma^2)PT=0.5⋅N(μ1,σ2)+0.5⋅N(μ2,σ2)。
前向 KL:DKL(PT∥PS)=−H(PT)+H(PT,PS)D_{\text{KL}}(P_T \| P_S) = -H(P_T) + H(P_T, P_S)DKL(PT∥PS)=−H(PT)+H(PT,PS)
当 PSP_SPS 是单峰高斯时,为了最小化 H(PT,PS)H(P_T, P_S)H(PT,PS),PSP_SPS 需要覆盖两个峰------导致 PSP_SPS 被"拉宽"。
反向 KL:DKL(PS∥PT)=−H(PS)+H(PS,PT)D_{\text{KL}}(P_S \| P_T) = -H(P_S) + H(P_S, P_T)DKL(PS∥PT)=−H(PS)+H(PS,PT)
当 PSP_SPS 是单峰高斯时,为了最小化 H(PS,PT)H(P_S, P_T)H(PS,PT),PSP_SPS 会"锁定"到其中一个峰。□\square□
实际意义 :在知识蒸馏中,通常使用反向 KL (即最小化 DKL(PS∥PT)D_{\text{KL}}(P_S \| P_T)DKL(PS∥PT)),因为这等价于最小化学生分布相对于教师分布的交叉熵。
2.2 温度缩放的 Softmax
2.2.1 标准 Softmax
定义 2.4(Softmax 函数) :对于 logits z=(z1,...,zK)\mathbf{z} = (z_1, \dots, z_K)z=(z1,...,zK):
pi=softmax(zi)=ezi∑j=1Kezjp_i = \text{softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^K e^{z_j}}pi=softmax(zi)=∑j=1Kezjezi
2.2.2 温度缩放
定义 2.5(温度缩放的 Softmax):
pi(T)=softmax(zi/T)=ezi/T∑j=1Kezj/Tp_i^{(T)} = \text{softmax}(z_i / T) = \frac{e^{z_i / T}}{\sum_{j=1}^K e^{z_j / T}}pi(T)=softmax(zi/T)=∑j=1Kezj/Tezi/T
其中 T>0T > 0T>0 是温度(temperature) 参数。
温度的影响:
- T=1T = 1T=1:标准 softmax
- T>1T > 1T>1:分布更"软"(更均匀),类间差异被平滑
- T<1T < 1T<1:分布更"硬"(更尖锐),类间差异被放大
- T→∞T \to \inftyT→∞:均匀分布 pi=1/Kp_i = 1/Kpi=1/K
- T→0T \to 0T→0:one-hot 分布(argmax)
2.2.3 温度的信息论解释
定理 2.2(温度与熵的关系):温度缩放的 softmax 输出的熵为:
H(T)=−∑ipi(T)logpi(T)H(T) = -\sum_i p_i^{(T)} \log p_i^{(T)}H(T)=−i∑pi(T)logpi(T)
且:
- H(T)H(T)H(T) 是 TTT 的单调递增函数
- H(0)=0H(0) = 0H(0)=0(one-hot,最小熵)
- H(∞)=logKH(\infty) = \log KH(∞)=logK(均匀分布,最大熵)
证明 :对 H(T)H(T)H(T) 关于 TTT 求导:
dHdT=1T2Varp(T)z≥0\frac{dH}{dT} = \frac{1}{T^2} \text{Var}_{p^{(T)}}z \geq 0dTdH=T21Varp(T)z≥0
其中 Varp(T)z=∑ipi(T)zi2−(∑ipi(T)zi)2\text{Var}_{p^{(T)}}z = \sum_i p_i^{(T)} z_i^2 - (\sum_i p_i^{(T)} z_i)^2Varp(T)z=∑ipi(T)zi2−(∑ipi(T)zi)2 是 logits 在分布 p(T)p^{(T)}p(T) 下的方差。□\square□
推论 2.1 :温度 TTT 控制了软标签中包含的"信息量"------TTT 越大,软标签越均匀,包含的信息越少(熵越大);TTT 越小,软标签越尖锐,包含的信息越多。
2.2.4 最优温度的选择
问题 :如何选择最优的温度 T∗T^*T∗?
定理 2.3(最优温度的信息论准则) :最优温度 T∗T^*T∗ 应使教师软标签的熵处于一个"合适"的范围:
Htarget≤H(T∗)≤Htarget+ΔH_{\text{target}} \leq H(T^*) \leq H_{\text{target}} + \DeltaHtarget≤H(T∗)≤Htarget+Δ
其中 HtargetH_{\text{target}}Htarget 是目标熵,Δ\DeltaΔ 是容许的熵范围。
直觉:
- TTT 太小:软标签接近硬标签,"暗知识"丢失
- TTT 太大:软标签过于均匀,有用信息被稀释
- 最优 TTT 通常在 2-20 之间
2.3 知识蒸馏的信息论框架
2.3.1 蒸馏作为信息传递
定理 2.4(蒸馏的信息瓶颈) :设教师模型的输出分布为 PTP_TPT,学生模型的输出分布为 PSP_SPS。蒸馏过程可以看作是最大化 PSP_SPS 和 PTP_TPT 之间的互信息:
maxPSI(PS;PT)=H(PS)−H(PS∣PT)\max_{P_S} I(P_S; P_T) = H(P_S) - H(P_S | P_T)PSmaxI(PS;PT)=H(PS)−H(PS∣PT)
在完美蒸馏的情况下,H(PS∣PT)=0H(P_S | P_T) = 0H(PS∣PT)=0,即 PSP_SPS 完全由 PTP_TPT 决定。
2.3.2 蒸馏的率失真分析
定理 2.5(蒸馏的率失真下界) :设教师模型有 KKK 个类别,学生模型使用 bbb 比特编码每个样本的预测。则蒸馏后的学生模型与教师模型之间的 KL 散度满足:
DKL(PT∥PS)≥12log22πeK−bD_{\text{KL}}(P_T \| P_S) \geq \frac{1}{2} \log_2 \frac{2\pi e}{K} - bDKL(PT∥PS)≥21log2K2πe−b
推论 :要达到很小的 KL 散度,学生模型需要足够大的"容量"(由 bbb 或等价地由模型参数量决定)。
第三章:Hinton 蒸馏的理论分析
3.1 Hinton 蒸馏的损失函数
3.1.1 标准蒸馏损失
Hinton et al., 2015 提出的蒸馏损失函数为:
L=α⋅Lhard+(1−α)⋅Lsoft\mathcal{L} = \alpha \cdot \mathcal{L}{\text{hard}} + (1 - \alpha) \cdot \mathcal{L}{\text{soft}}L=α⋅Lhard+(1−α)⋅Lsoft
其中:
硬标签损失(交叉熵):
Lhard=H(y,PS(1))=−∑iyilogPS,i(1)\mathcal{L}{\text{hard}} = H(\mathbf{y}, P_S^{(1)}) = -\sum_i y_i \log P{S,i}^{(1)}Lhard=H(y,PS(1))=−i∑yilogPS,i(1)
软标签损失 (KL 散度,温度 TTT):
Lsoft=T2⋅DKL(PT(T)∥PS(T))=T2⋅∑iPT,i(T)logPT,i(T)PS,i(T)\mathcal{L}{\text{soft}} = T^2 \cdot D{\text{KL}}(P_T^{(T)} \| P_S^{(T)}) = T^2 \cdot \sum_i P_{T,i}^{(T)} \log \frac{P_{T,i}^{(T)}}{P_{S,i}^{(T)}}Lsoft=T2⋅DKL(PT(T)∥PS(T))=T2⋅i∑PT,i(T)logPS,i(T)PT,i(T)
其中 y\mathbf{y}y 是真实标签(one-hot),PT(T)P_T^{(T)}PT(T) 和 PS(T)P_S^{(T)}PS(T) 分别是温度 TTT 下教师和学生的 softmax 输出。
3.1.2 为什么乘以 T2T^2T2?
定理 3.1(T2T^2T2 缩放的必要性) :在温度 TTT 下,softmax 输出的梯度量级约为 1/T1/T1/T。因此 KL 散度的梯度量级约为 1/T21/T^21/T2。乘以 T2T^2T2 可以使软标签损失的梯度量级与硬标签损失一致。
证明 :设 logits 为 z\mathbf{z}z,温度缩放的 softmax 为 pi=ezi/T/∑jezj/Tp_i = e^{z_i/T} / \sum_j e^{z_j/T}pi=ezi/T/∑jezj/T。
对 ziz_izi 求导:
∂pi∂zi=1Tpi(1−pi)\frac{\partial p_i}{\partial z_i} = \frac{1}{T} p_i (1 - p_i)∂zi∂pi=T1pi(1−pi)
∂pj∂zi=−1Tpipj,j≠i\frac{\partial p_j}{\partial z_i} = -\frac{1}{T} p_i p_j, \quad j \neq i∂zi∂pj=−T1pipj,j=i
因此 ∂p/∂z=O(1/T)\partial p / \partial z = O(1/T)∂p/∂z=O(1/T)。
KL 散度的梯度:∂DKL∂z=O(1/T)\frac{\partial D_{\text{KL}}}{\partial z} = O(1/T)∂z∂DKL=O(1/T)。
乘以 T2T^2T2 后:T2⋅∂DKL∂z=O(T)T^2 \cdot \frac{\partial D_{\text{KL}}}{\partial z} = O(T)T2⋅∂z∂DKL=O(T)------当 TTT 较大时,梯度量级约为 TTT。
实际上,更精确的分析表明 T2T^2T2 缩放使得软标签损失的梯度在 TTT 变化时保持相对稳定。□\square□
3.1.3 α\alphaα 的选择
经验规则:
- α=0.1\alpha = 0.1α=0.1:更重视软标签(常见选择)
- α=0.5\alpha = 0.5α=0.5:软标签和硬标签同等重要
- α=0.9\alpha = 0.9α=0.9:更重视硬标签
定理 3.2(α\alphaα 的最优值) :在一定假设下,最优的 α\alphaα 满足:
α∗=H(PT(T))H(PT(T))+H(y)\alpha^* = \frac{H(P_T^{(T)})}{H(P_T^{(T)}) + H(\mathbf{y})}α∗=H(PT(T))+H(y)H(PT(T))
其中 H(PT(T))H(P_T^{(T)})H(PT(T)) 是教师软标签的熵,H(y)=0H(\mathbf{y}) = 0H(y)=0(硬标签的熵为零)。
因此 α∗→0\alpha^* \to 0α∗→0------理论上应该完全使用软标签。但实践中,保留一定比例的硬标签有助于稳定训练。
3.2 蒸馏的梯度分析
3.2.1 软标签损失的梯度
定理 3.3(软标签梯度的分解) :软标签损失关于学生 logits ziz_izi 的梯度为:
∂Lsoft∂zi=T⋅(PS,i(T)−PT,i(T))\frac{\partial \mathcal{L}{\text{soft}}}{\partial z_i} = T \cdot (P{S,i}^{(T)} - P_{T,i}^{(T)})∂zi∂Lsoft=T⋅(PS,i(T)−PT,i(T))
证明:
Lsoft=T2∑jPT,j(T)logPT,j(T)PS,j(T)=T2(∑jPT,j(T)logPT,j(T)−∑jPT,j(T)logPS,j(T))\mathcal{L}{\text{soft}} = T^2 \sum_j P{T,j}^{(T)} \log \frac{P_{T,j}^{(T)}}{P_{S,j}^{(T)}} = T^2 \left(\sum_j P_{T,j}^{(T)} \log P_{T,j}^{(T)} - \sum_j P_{T,j}^{(T)} \log P_{S,j}^{(T)}\right)Lsoft=T2j∑PT,j(T)logPS,j(T)PT,j(T)=T2(j∑PT,j(T)logPT,j(T)−j∑PT,j(T)logPS,j(T))
第一项与学生参数无关。对第二项求导:
∂∂zi(−T2∑jPT,j(T)logPS,j(T))=−T2∑jPT,j(T)1PS,j(T)∂PS,j(T)∂zi\frac{\partial}{\partial z_i} \left(-T^2 \sum_j P_{T,j}^{(T)} \log P_{S,j}^{(T)}\right) = -T^2 \sum_j P_{T,j}^{(T)} \frac{1}{P_{S,j}^{(T)}} \frac{\partial P_{S,j}^{(T)}}{\partial z_i}∂zi∂(−T2j∑PT,j(T)logPS,j(T))=−T2j∑PT,j(T)PS,j(T)1∂zi∂PS,j(T)
使用 ∂PS,j(T)∂zi=1TPS,j(T)(δij−PS,i(T))\frac{\partial P_{S,j}^{(T)}}{\partial z_i} = \frac{1}{T} P_{S,j}^{(T)} (\delta_{ij} - P_{S,i}^{(T)})∂zi∂PS,j(T)=T1PS,j(T)(δij−PS,i(T)):
=−T2∑jPT,j(T)1PS,j(T)⋅1TPS,j(T)(δij−PS,i(T))= -T^2 \sum_j P_{T,j}^{(T)} \frac{1}{P_{S,j}^{(T)}} \cdot \frac{1}{T} P_{S,j}^{(T)} (\delta_{ij} - P_{S,i}^{(T)})=−T2j∑PT,j(T)PS,j(T)1⋅T1PS,j(T)(δij−PS,i(T))
=−T∑jPT,j(T)(δij−PS,i(T))=−T(PT,i(T)−PS,i(T)∑jPT,j(T))= -T \sum_j P_{T,j}^{(T)} (\delta_{ij} - P_{S,i}^{(T)}) = -T (P_{T,i}^{(T)} - P_{S,i}^{(T)} \sum_j P_{T,j}^{(T)})=−Tj∑PT,j(T)(δij−PS,i(T))=−T(PT,i(T)−PS,i(T)j∑PT,j(T))
=T(PS,i(T)−PT,i(T))= T (P_{S,i}^{(T)} - P_{T,i}^{(T)})=T(PS,i(T)−PT,i(T))
□\square□
3.2.2 软标签梯度的物理含义
推论 3.1 :软标签损失的梯度与学生和教师输出的差异成正比------差异越大,梯度越大,学生更新越快。
推论 3.2 :当 TTT 很大时,PS,i(T)≈PT,i(T)≈1/KP_{S,i}^{(T)} \approx P_{T,i}^{(T)} \approx 1/KPS,i(T)≈PT,i(T)≈1/K,梯度接近零------高温下蒸馏的"学习信号"很弱。
推论 3.3 :当 TTT 很小时,PS,i(T)P_{S,i}^{(T)}PS,i(T) 和 PT,i(T)P_{T,i}^{(T)}PT,i(T) 接近 one-hot,梯度变得稀疏------只有被错误分类的样本有非零梯度。
最优温度:在"信息量"和"梯度强度"之间取得平衡。
3.3 蒸馏的泛化分析
3.3.1 蒸馏作为正则化
定理 3.4(蒸馏的正则化效应):蒸馏损失等价于在标准交叉熵损失上添加一个正则化项:
L=H(y,PS)+λ⋅DKL(PT∥PS)\mathcal{L} = H(\mathbf{y}, P_S) + \lambda \cdot D_{\text{KL}}(P_T \| P_S)L=H(y,PS)+λ⋅DKL(PT∥PS)
其中 λ=(1−α)T2/α\lambda = (1 - \alpha) T^2 / \alphaλ=(1−α)T2/α。
推论 :蒸馏的正则化效应类似于标签平滑(Label Smoothing)------将硬标签替换为软标签。
3.3.2 标签平滑与蒸馏的关系
定义 3.1(标签平滑):
y~i=(1−ϵ)yi+ϵK\tilde{y}_i = (1 - \epsilon) y_i + \frac{\epsilon}{K}y~i=(1−ϵ)yi+Kϵ
其中 ϵ\epsilonϵ 是平滑参数。
定理 3.5(蒸馏与标签平滑的等价性):当教师模型的输出是均匀分布时,蒸馏退化为标签平滑:
PT,i(T)=1K ⟹ Lsoft=T2DKL(1K∥PS(T))=T2(logK+∑i1KlogPS,i(T))P_{T,i}^{(T)} = \frac{1}{K} \implies \mathcal{L}{\text{soft}} = T^2 D{\text{KL}}\left(\frac{1}{K} \| P_S^{(T)}\right) = T^2 \left(\log K + \sum_i \frac{1}{K} \log P_{S,i}^{(T)}\right)PT,i(T)=K1⟹Lsoft=T2DKL(K1∥PS(T))=T2(logK+i∑K1logPS,i(T))
这与标签平滑的损失形式一致。□\square□
关键区别 :教师模型的输出不是均匀分布------它包含了类间的结构信息。这正是蒸馏比标签平滑更有效的原因。
第二部分:经典蒸馏方法
第四章:响应蒸馏(Response-based Distillation)
4.1 基本框架
4.1.1 问题形式化
给定教师模型 fTf_TfT 和学生模型 fSf_SfS,输入 x\mathbf{x}x,类别数 KKK:
- 教师 logits:zT=fT(x)∈RK\mathbf{z}_T = f_T(\mathbf{x}) \in \mathbb{R}^KzT=fT(x)∈RK
- 学生 logits:zS=fS(x)∈RK\mathbf{z}_S = f_S(\mathbf{x}) \in \mathbb{R}^KzS=fS(x)∈RK
- 教师软标签:pT=softmax(zT/T)\mathbf{p}_T = \text{softmax}(\mathbf{z}_T / T)pT=softmax(zT/T)
- 学生软标签:pS=softmax(zS/T)\mathbf{p}_S = \text{softmax}(\mathbf{z}_S / T)pS=softmax(zS/T)
4.1.2 蒸馏损失
标准 KL 蒸馏:
LKD=T2⋅DKL(pT∥pS)\mathcal{L}{\text{KD}} = T^2 \cdot D{\text{KL}}(\mathbf{p}_T \| \mathbf{p}_S)LKD=T2⋅DKL(pT∥pS)
交叉熵蒸馏(等价):
LCE=−T2∑ipT,ilogpS,i\mathcal{L}{\text{CE}} = -T^2 \sum_i p{T,i} \log p_{S,i}LCE=−T2i∑pT,ilogpS,i
(因为 DKL(P∥Q)=H(P,Q)−H(P)D_{\text{KL}}(P \| Q) = H(P, Q) - H(P)DKL(P∥Q)=H(P,Q)−H(P),H(P)H(P)H(P) 是常数。)
4.1.3 总损失
Ltotal=α⋅H(y,pS(1))+(1−α)⋅T2⋅DKL(pT(T)∥pS(T))\mathcal{L}_{\text{total}} = \alpha \cdot H(\mathbf{y}, \mathbf{p}S^{(1)}) + (1 - \alpha) \cdot T^2 \cdot D{\text{KL}}(\mathbf{p}_T^{(T)} \| \mathbf{p}_S^{(T)})Ltotal=α⋅H(y,pS(1))+(1−α)⋅T2⋅DKL(pT(T)∥pS(T))
4.2 响应蒸馏的变体
4.2.1 MSE 蒸馏
定义:直接最小化 logits 之间的均方误差:
LMSE=∥zT−zS∥2\mathcal{L}_{\text{MSE}} = \|\mathbf{z}_T - \mathbf{z}_S\|^2LMSE=∥zT−zS∥2
与 KL 蒸馏的比较:
定理 4.1(MSE 与 KL 的关系):在 logits 的小扰动假设下:
DKL(pT∥pS)≈12T2∑ipT,i(zT,i−zS,i)2−12T2(∑ipT,i(zT,i−zS,i))2D_{\text{KL}}(\mathbf{p}T \| \mathbf{p}S) \approx \frac{1}{2T^2} \sum_i p{T,i} (z{T,i} - z_{S,i})^2 - \frac{1}{2T^2} \left(\sum_i p_{T,i} (z_{T,i} - z_{S,i})\right)^2DKL(pT∥pS)≈2T21i∑pT,i(zT,i−zS,i)2−2T21(i∑pT,i(zT,i−zS,i))2
当 pT\mathbf{p}TpT 接近均匀分布时,DKL≈12T2VarpTzT−zSD{\text{KL}} \approx \frac{1}{2T^2} \text{Var}_{\mathbf{p}_T}\\mathbf{z}_T - \\mathbf{z}_SDKL≈2T21VarpTzT−zS。
推论:MSE 蒸馏和 KL 蒸馏在某些条件下近似等价,但 KL 蒸馏更鲁棒。
4.2.2 残差蒸馏
定义 :学生学习教师的残差而非绝对值:
Lres=∥(zT−zS)−r(zS)∥2\mathcal{L}_{\text{res}} = \|(\mathbf{z}_T - \mathbf{z}_S) - r(\mathbf{z}_S)\|^2Lres=∥(zT−zS)−r(zS)∥2
其中 rrr 是一个可学习的残差网络。
4.2.3 对抗蒸馏
定义:使用对抗训练来增强蒸馏:
minfSmaxDLKD(fT,fS)+λLadv(fS,D)\min_{f_S} \max_{D} \mathcal{L}{\text{KD}}(f_T, f_S) + \lambda \mathcal{L}{\text{adv}}(f_S, D)fSminDmaxLKD(fT,fS)+λLadv(fS,D)
其中 DDD 是判别器,试图区分教师和学生的输出。
第五章:特征蒸馏(Feature-based Distillation)
5.1 中间层特征对齐
5.1.1 动机
问题:响应蒸馏只利用了教师模型的最终输出,忽略了中间层的丰富信息。
特征蒸馏:对齐教师和学生中间层的特征表示。
5.1.2 FitNets
Romero et al., 2015 提出的 FitNets 是最早的特征蒸馏方法。
定义 5.1(FitNet 损失):
LFitNet=∥fT(l)(x)−r(fS(l′)(x))∥2\mathcal{L}_{\text{FitNet}} = \|f_T^{(l)}(\mathbf{x}) - r(f_S^{(l')}(\mathbf{x}))\|^2LFitNet=∥fT(l)(x)−r(fS(l′)(x))∥2
其中:
- fT(l)f_T^{(l)}fT(l) 是教师第 lll 层的特征
- fS(l′)f_S^{(l')}fS(l′) 是学生第 l′l'l′ 层的特征
- rrr 是一个可学习的回归网络(用于维度匹配)
问题 :教师和学生的特征维度可能不同,需要 rrr 来匹配维度。
5.1.3 特征蒸馏的理论分析
定理 5.1(特征蒸馏的信息论解释):特征蒸馏等价于最大化教师特征和学生特征之间的互信息:
maxfSI(fT(l)(x);fS(l′)(x))\max_{f_S} I(f_T^{(l)}(\mathbf{x}); f_S^{(l')}(\mathbf{x}))fSmaxI(fT(l)(x);fS(l′)(x))
证明 :设 hT=fT(l)(x)\mathbf{h}_T = f_T^{(l)}(\mathbf{x})hT=fT(l)(x),hS=fS(l′)(x)\mathbf{h}_S = f_S^{(l')}(\mathbf{x})hS=fS(l′)(x)。在高斯假设下:
I(hT;hS)=12log∣Cov(hT)∣∣Cov(hT∣hS)∣I(\mathbf{h}_T; \mathbf{h}_S) = \frac{1}{2} \log \frac{|\text{Cov}(\mathbf{h}_T)|}{|\text{Cov}(\mathbf{h}_T | \mathbf{h}_S)|}I(hT;hS)=21log∣Cov(hT∣hS)∣∣Cov(hT)∣
最小化 MSE ∥hT−hS∥2\|\mathbf{h}_T - \mathbf{h}_S\|^2∥hT−hS∥2 等价于最大化条件概率 p(hT∣hS)p(\mathbf{h}_T | \mathbf{h}_S)p(hT∣hS) 的集中度,即最大化互信息。□\square□
5.2 特征蒸馏的变体
5.2.1 注意力转移(Attention Transfer)
Zagoruyko & Komodakis, 2017 提出对齐注意力图:
定义 5.2(注意力图) :对于特征图 A∈RC×H×W\mathbf{A} \in \mathbb{R}^{C \times H \times W}A∈RC×H×W,注意力图定义为:
A(A)=1C∑c=1C∣Ac∣p\mathcal{A}(\mathbf{A}) = \frac{1}{C} \sum_{c=1}^{C} |A_c|^pA(A)=C1c=1∑C∣Ac∣p
其中 ppp 是幂次参数(通常 p=2p = 2p=2)。
注意力转移损失:
LAT=∑l∥A(AT(l))∥A(AT(l))∥2−A(AS(l))∥A(AS(l))∥2∥22\mathcal{L}_{\text{AT}} = \sum_l \left\|\frac{\mathcal{A}(\mathbf{A}_T^{(l)})}{\|\mathcal{A}(\mathbf{A}_T^{(l)})\|_2} - \frac{\mathcal{A}(\mathbf{A}_S^{(l)})}{\|\mathcal{A}(\mathbf{A}_S^{(l)})\|_2}\right\|_2^2LAT=l∑ ∥A(AT(l))∥2A(AT(l))−∥A(AS(l))∥2A(AS(l)) 22
5.2.2 特征分布对齐
定义 5.3(分布对齐损失):
Ldist=∑lDKL(p(hT(l))∥p(hS(l)))\mathcal{L}{\text{dist}} = \sum_l D{\text{KL}}(p(\mathbf{h}_T^{(l)}) \| p(\mathbf{h}_S^{(l)}))Ldist=l∑DKL(p(hT(l))∥p(hS(l)))
在高斯假设下,p(h)=N(μ,Σ)p(\mathbf{h}) = \mathcal{N}(\mu, \Sigma)p(h)=N(μ,Σ),KL 散度有解析形式:
DKL(NT∥NS)=12tr(ΣS−1ΣT)−k+(μS−μT)TΣS−1(μS−μT)+log∣ΣS∣∣ΣT∣D_{\text{KL}}(\mathcal{N}_T \| \mathcal{N}_S) = \frac{1}{2}\left\\text{tr}(\\Sigma_S\^{-1} \\Sigma_T) - k + (\\mu_S - \\mu_T)\^T \\Sigma_S\^{-1} (\\mu_S - \\mu_T) + \\log \\frac{\|\\Sigma_S\|}{\|\\Sigma_T\|}\\rightDKL(NT∥NS)=21tr(ΣS−1ΣT)−k+(μS−μT)TΣS−1(μS−μT)+log∣ΣT∣∣ΣS∣
5.2.3 Gram 矩阵匹配
定义 5.4(Gram 矩阵) :对于特征图 A∈RC×HW\mathbf{A} \in \mathbb{R}^{C \times HW}A∈RC×HW,Gram 矩阵定义为:
G(A)=AAT∈RC×CG(\mathbf{A}) = \mathbf{A} \mathbf{A}^T \in \mathbb{R}^{C \times C}G(A)=AAT∈RC×C
Gram 矩阵匹配损失:
LGram=∑l∥G(AT(l))−G(AS(l))∥F2\mathcal{L}_{\text{Gram}} = \sum_l \|G(\mathbf{A}_T^{(l)}) - G(\mathbf{A}_S^{(l)})\|_F^2LGram=l∑∥G(AT(l))−G(AS(l))∥F2
定理 5.2(Gram 矩阵的信息含量) :Gram 矩阵 G(A)G(\mathbf{A})G(A) 编码了特征通道之间的二阶统计关系------即哪些通道倾向于同时被激活。
第六章:关系蒸馏(Relation-based Distillation)
6.1 样本间关系
6.1.1 动机
问题 :响应蒸馏和特征蒸馏都是逐样本对齐的------每个样本独立处理。但教师模型的"知识"也包含在样本间的关系中。
6.1.2 结构化蒸馏
Park et al., 2019 提出结构化知识蒸馏(Structured Knowledge Distillation):
定义 6.1(样本间距离矩阵):
Dij=∥f(xi)−f(xj)∥2D_{ij} = \|f(\mathbf{x}_i) - f(\mathbf{x}_j)\|_2Dij=∥f(xi)−f(xj)∥2
结构化蒸馏损失:
Lstruct=∥DT−DS∥F2\mathcal{L}_{\text{struct}} = \|D_T - D_S\|_F^2Lstruct=∥DT−DS∥F2
其中 DTD_TDT 和 DSD_SDS 分别是教师和学生的距离矩阵。
6.1.3 关系蒸馏的理论分析
定理 6.1(距离矩阵保留类间结构) :设教师模型将样本映射到特征空间,同类样本聚集、异类样本分散。则距离矩阵 DTD_TDT 编码了类间和类内的结构信息。
证明 :设 xi\mathbf{x}ixi 和 xj\mathbf{x}jxj 属于同一类,则 DT,ijD{T,ij}DT,ij 较小;若属于不同类,则 DT,ijD{T,ij}DT,ij 较大。学生对齐 DTD_TDT 等价于学习这种类间/类内结构。□\square□
6.2 图蒸馏
6.2.1 图结构的关系
定义 6.2(关系图) :构建一个图 G=(V,E)G = (V, E)G=(V,E),其中节点是样本,边权重是样本间的相似度:
wij=exp(−∥f(xi)−f(xj)∥22σ2)w_{ij} = \exp\left(-\frac{\|f(\mathbf{x}_i) - f(\mathbf{x}_j)\|^2}{2\sigma^2}\right)wij=exp(−2σ2∥f(xi)−f(xj)∥2)
图蒸馏损失:
Lgraph=∥WT−WS∥F2\mathcal{L}_{\text{graph}} = \|W_T - W_S\|_F^2Lgraph=∥WT−WS∥F2
其中 WTW_TWT 和 WSW_SWS 分别是教师和学生的邻接矩阵。
6.2.2 图蒸馏的理论分析
定理 6.2(图蒸馏与核方法的关系) :图蒸馏等价于对齐教师和学生的核矩阵:
KT(xi,xj)=⟨ϕT(xi),ϕT(xj)⟩K_T(\mathbf{x}_i, \mathbf{x}_j) = \langle\phi_T(\mathbf{x}_i), \phi_T(\mathbf{x}_j)\rangleKT(xi,xj)=⟨ϕT(xi),ϕT(xj)⟩
其中 ϕT\phi_TϕT 是教师的特征映射。
推论 :对齐核矩阵等价于对齐特征空间的几何结构------距离、角度、相对位置。
第三部分:自蒸馏与无教师蒸馏
第七章:自蒸馏(Self-Distillation)
7.1 核心思想
7.1.1 什么是自蒸馏?
自蒸馏(Self-Distillation):模型从自身的"旧版本"或"不同视角"学习,不需要外部教师模型。
7.1.2 自蒸馏的动机
问题:知识蒸馏需要一个大的教师模型,增加了计算成本和部署复杂度。
自蒸馏的优势:
- 不需要额外的教师模型
- 训练成本与标准训练相当
- 可以在训练过程中持续改进
7.2 Born-Again Networks
7.2.1 方法
Furlanello et al., 2018 提出的 Born-Again Networks(BAN):
算法 7.1(Born-Again Networks):
1. 训练教师模型 f_T(标准训练)
2. 训练学生模型 f_S(与 f_T 架构相同),使用 f_T 作为教师
3. 将 f_S 作为新的教师,训练新的学生
4. 重复 K 次
关键发现:即使学生和教师架构相同,蒸馏后的学生模型通常优于教师模型!
7.2.2 Born-Again Networks 的理论分析
定理 7.1(BAN 的收敛性) :设第 kkk 代模型的测试准确率为 Acck\text{Acc}_kAcck。在一定假设下:
Acck+1≥Acck−ϵk\text{Acc}_{k+1} \geq \text{Acc}_k - \epsilon_kAcck+1≥Acck−ϵk
其中 ϵk→0\epsilon_k \to 0ϵk→0(随着代数增加,改善量减小)。
直觉:每一代蒸馏相当于一次"软标签增强"的训练------软标签提供了比硬标签更丰富的监督信号。
7.3 自蒸馏的变体
7.3.1 在线自蒸馏
定义 :在训练过程中,使用模型的指数移动平均(EMA) 作为教师:
θEMA(t)=β⋅θEMA(t−1)+(1−β)⋅θ(t)\theta_{\text{EMA}}^{(t)} = \beta \cdot \theta_{\text{EMA}}^{(t-1)} + (1 - \beta) \cdot \theta^{(t)}θEMA(t)=β⋅θEMA(t−1)+(1−β)⋅θ(t)
L=H(y,PS)+λ⋅DKL(PEMA∥PS)\mathcal{L} = H(\mathbf{y}, P_S) + \lambda \cdot D_{\text{KL}}(P_{\text{EMA}} \| P_S)L=H(y,PS)+λ⋅DKL(PEMA∥PS)
7.3.2 多视角自蒸馏
定义:对同一输入施加不同的增强,让不同增强的输出相互对齐:
Lself=∑i≠jDKL(PS(x(i))∥PS(x(j)))\mathcal{L}{\text{self}} = \sum{i \neq j} D_{\text{KL}}(P_S(\mathbf{x}^{(i)}) \| P_S(\mathbf{x}^{(j)}))Lself=i=j∑DKL(PS(x(i))∥PS(x(j)))
其中 x(i)\mathbf{x}^{(i)}x(i) 和 x(j)\mathbf{x}^{(j)}x(j) 是同一输入的不同增强版本。
第八章:无教师蒸馏与数据蒸馏
8.1 无教师蒸馏
8.1.1 动机
问题:在某些场景下,没有现成的教师模型------例如:
- 模型太大无法部署为教师
- 数据隐私限制
- 需要从头训练
8.1.2 虚拟教师
思想 :使用集成模型 或多个检查点作为虚拟教师。
PT=1M∑m=1MPS(m)P_T = \frac{1}{M} \sum_{m=1}^{M} P_{S}^{(m)}PT=M1m=1∑MPS(m)
其中 PS(m)P_S^{(m)}PS(m) 是训练过程中不同检查点的输出。
8.2 数据蒸馏
8.2.1 核心思想
数据蒸馏(Dataset Distillation) :将大数据集压缩为一个小的合成数据集,使在小数据集上训练的模型能达到接近在大数据集上训练的精度。
8.2.2 形式化
问题:
minD~Ex∼DL(fθ(D\~)(x))\min_{\tilde{\mathcal{D}}} \mathbb{E}_{\mathbf{x} \sim \mathcal{D}} \left\\mathcal{L}(f_{\\theta(\\tilde{\\mathcal{D}})}(\\mathbf{x}))\\rightD~minEx∼DL(fθ(D\~)(x))
其中 θ(D~)=argminθ∑x~∈D~L(fθ(x~))\theta(\tilde{\mathcal{D}}) = \arg\min_\theta \sum_{\tilde{\mathbf{x}} \in \tilde{\mathcal{D}}} \mathcal{L}(f_\theta(\tilde{\mathbf{x}}))θ(D~)=argminθ∑x~∈D~L(fθ(x~)) 是在合成数据集上训练的模型参数。
8.2.3 数据蒸馏的理论分析
定理 8.1(数据蒸馏的信息瓶颈) :合成数据集 D~\tilde{\mathcal{D}}D~ 的信息量受限于其大小 ∣D~∣|\tilde{\mathcal{D}}|∣D~∣:
I(D~;D)≤∣D~∣⋅log∣X∣I(\tilde{\mathcal{D}}; \mathcal{D}) \leq |\tilde{\mathcal{D}}| \cdot \log |\mathcal{X}|I(D~;D)≤∣D~∣⋅log∣X∣
其中 ∣X∣|\mathcal{X}|∣X∣ 是输入空间的大小。
推论:合成数据集的大小必须足够大,才能保留原始数据集的信息。
第四部分:大语言模型的知识蒸馏
第九章:LLM 蒸馏的特殊挑战
9.1 LLM 蒸馏 vs 传统蒸馏
| 挑战 | 传统蒸馏 | LLM 蒸馏 |
|---|---|---|
| 模型规模 | 百万-千万参数 | 十亿-千亿参数 |
| 输出空间 | 分类(几千类) | 生成(词汇表 32K-100K) |
| 序列长度 | 固定 | 可变(数百-数千) |
| 计算成本 | 低 | 极高 |
| 训练数据 | 充足 | 有限或受限 |
9.2 LLM 蒸馏的分类
9.2.1 白盒蒸馏
白盒蒸馏 :可以访问教师模型的内部表示(logits、特征、注意力权重等)。
优势:可以利用丰富的中间信息。
代表方法:DistilBERT、TinyBERT、MiniLLM。
9.2.2 黑盒蒸馏
黑盒蒸馏 :只能访问教师模型的输出文本(通过 API)。
优势:不需要部署教师模型,适用于闭源模型。
代表方法:Alpaca、Vicuna、WizardLM。
9.3 序列级蒸馏的挑战
9.3.1 自回归生成的特殊性
问题:LLM 是自回归模型------每个 token 的生成依赖于之前的 token。这导致:
- 暴露偏差(Exposure Bias):训练时使用真实 token,推理时使用模型生成的 token
- 长程依赖:早期 token 的误差会传播到后续 token
- 序列长度不固定:不同样本的序列长度不同
9.3.2 序列级 KL 散度
定义 9.1(序列级 KL 散度) :对于长度为 LLL 的序列 y=(y1,...,yL)\mathbf{y} = (y_1, \dots, y_L)y=(y1,...,yL):
DKLseq(PT∥PS)=∑t=1LDKL(PT(⋅∣y<t)∥PS(⋅∣y<t))D_{\text{KL}}^{\text{seq}}(P_T \| P_S) = \sum_{t=1}^{L} D_{\text{KL}}(P_T(\cdot | y_{<t}) \| P_S(\cdot | y_{<t}))DKLseq(PT∥PS)=t=1∑LDKL(PT(⋅∣y<t)∥PS(⋅∣y<t))
其中 PT(⋅∣y<t)P_T(\cdot | y_{<t})PT(⋅∣y<t) 是教师在给定前缀 y<ty_{<t}y<t 下的下一个 token 分布。
定理 9.1(序列级 KL 的分解):
DKLseq=∑t=1L∑v∈VPT(v∣y<t)logPT(v∣y<t)PS(v∣y<t)D_{\text{KL}}^{\text{seq}} = \sum_{t=1}^{L} \sum_{v \in \mathcal{V}} P_T(v | y_{<t}) \log \frac{P_T(v | y_{<t})}{P_S(v | y_{<t})}DKLseq=t=1∑Lv∈V∑PT(v∣y<t)logPS(v∣y<t)PT(v∣y<t)
计算挑战 :词汇表 V\mathcal{V}V 的大小通常为 32K-100K,每个时间步都需要计算整个词汇表的分布。
第十章:白盒蒸馏------逐层与逐 token 对齐
10.1 DistilBERT
10.1.1 方法
DistilBERT(Sanh et al., 2019)将 BERT-base(12 层)蒸馏到一个 6 层的学生模型。
损失函数:
L=αLCE+βLMLM+γLcos\mathcal{L} = \alpha \mathcal{L}{\text{CE}} + \beta \mathcal{L}{\text{MLM}} + \gamma \mathcal{L}_{\cos}L=αLCE+βLMLM+γLcos
其中:
- LCE\mathcal{L}_{\text{CE}}LCE:软标签交叉熵(蒸馏损失)
- LMLM\mathcal{L}_{\text{MLM}}LMLM:掩码语言模型损失(任务损失)
- Lcos\mathcal{L}_{\cos}Lcos:教师和学生最后一层隐藏状态的余弦相似度
10.1.2 层选择策略
问题:12 层教师如何映射到 6 层学生?
策略 1:均匀间隔 ------学生第 lll 层学习教师第 2l2l2l 层
策略 2:首尾共享------学生第 1 层和最后一层分别学习教师的第 1 层和最后一层,中间层均匀间隔
策略 3:可学习映射------学习一个映射函数,自动选择教师层
10.2 TinyBERT
10.2.1 方法
TinyBERT(Jiao et al., 2020)提出了更精细的蒸馏策略:
三阶段蒸馏:
- 嵌入层蒸馏:对齐嵌入层的输出
- Transformer 层蒸馏:对齐注意力权重和隐藏状态
- 预测层蒸馏:对齐最终输出
损失函数:
L=∑l∈SλlLlayer(fT(m(l)),fS(l))\mathcal{L} = \sum_{l \in \mathcal{S}} \lambda_l \mathcal{L}_{\text{layer}}(f_T^{(m(l))}, f_S^{(l)})L=l∈S∑λlLlayer(fT(m(l)),fS(l))
其中 S\mathcal{S}S 是选中的层集合,m(l)m(l)m(l) 是学生第 lll 层对应的教师层。
10.2.2 注意力蒸馏
定义 10.1(注意力蒸馏损失):
Lattn=1h∑i=1h∥AT(i)−AS(i)∥F2\mathcal{L}{\text{attn}} = \frac{1}{h} \sum{i=1}^{h} \|A_T^{(i)} - A_S^{(i)}\|_F^2Lattn=h1i=1∑h∥AT(i)−AS(i)∥F2
其中 AT(i)A_T^{(i)}AT(i) 和 AS(i)A_S^{(i)}AS(i) 分别是教师和学生第 iii 个注意力头的注意力权重矩阵。
10.2.3 隐藏状态蒸馏
定义 10.2(隐藏状态蒸馏损失):
Lhidden=∥M⋅HS−HT∥F2\mathcal{L}_{\text{hidden}} = \|M \cdot H_S - H_T\|_F^2Lhidden=∥M⋅HS−HT∥F2
其中 MMM 是一个可学习的线性变换,用于维度匹配。
10.3 MiniLLM
10.3.1 动机
问题 :标准 KL 蒸馏在 LLM 中会导致长度偏差------学生倾向于生成更短的序列。
10.3.2 反向 KL 蒸馏
MiniLLM (Gu et al., 2024)提出使用反向 KL 散度:
LMiniLLM=DKL(PS∥PT)=∑t∑vPS(v∣y<t)logPS(v∣y<t)PT(v∣y<t)\mathcal{L}{\text{MiniLLM}} = D{\text{KL}}(P_S \| P_T) = \sum_t \sum_v P_S(v | y_{<t}) \log \frac{P_S(v | y_{<t})}{P_T(v | y_{<t})}LMiniLLM=DKL(PS∥PT)=t∑v∑PS(v∣y<t)logPT(v∣y<t)PS(v∣y<t)
定理 10.1(反向 KL 避免长度偏差) :前向 KL DKL(PT∥PS)D_{\text{KL}}(P_T \| P_S)DKL(PT∥PS) 倾向于覆盖教师分布的所有模式(mode-covering),导致学生分配概率给低概率 token,从而在生成时倾向于选择"安全"的短序列。
反向 KL DKL(PS∥PT)D_{\text{KL}}(P_S \| P_T)DKL(PS∥PT) 倾向于集中在教师分布的一个模式上(mode-seeking),鼓励学生生成更确定性的输出,减少长度偏差。
第十一章:黑盒蒸馏------从 API 到合成数据
11.1 黑盒蒸馏的动机
11.1.1 闭源模型的挑战
问题:许多强大的 LLM(如 GPT-4、Claude)只提供 API 访问,无法获取内部表示。
黑盒蒸馏 :只使用教师模型的文本输出来进行蒸馏。
11.1.2 Alpaca 方法
Alpaca(Taori et al., 2023)是最早的黑盒蒸馏工作之一:
- 使用 GPT-3.5 生成 52K 条指令-回复对
- 使用这些数据微调 LLaMA-7B
损失函数:标准的自回归交叉熵(只在回复部分计算):
L=−∑t∈responselogPS(yt∣y<t,x)\mathcal{L} = -\sum_{t \in \text{response}} \log P_S(y_t | y_{<t}, \mathbf{x})L=−t∈response∑logPS(yt∣y<t,x)
11.1.3 Vicuna 方法
Vicuna(Chiang et al., 2023)使用 ShareGPT 数据(用户与 ChatGPT 的真实对话)来微调 LLaMA。
关键差异 :Alpaca 使用单轮 指令数据,Vicuna 使用多轮对话数据。
11.2 合成数据生成
11.2.1 Self-Instruct
算法 11.1(Self-Instruct):
1. 种子任务集:S = {少量人工编写的指令-回复对}
2. 重复 N 次:
a. 从 S 中采样几个示例
b. 使用 LLM 生成新的指令
c. 使用 LLM 生成对应的回复
d. 过滤低质量数据
e. 将新数据加入 S
3. 使用 S 微调学生模型
11.2.2 Evol-Instruct
WizardLM (Xu et al., 2023)提出进化指令:
算法 11.2(Evol-Instruct):
1. 初始指令集:I = {简单指令}
2. 重复:
a. 从 I 中选择一个指令
b. 使用 LLM 将指令"进化"为更复杂的版本:
- 增加约束条件
- 增加推理步骤
- 增加领域知识
c. 使用 LLM 生成回复
d. 将新的指令-回复对加入 I
11.3 黑盒蒸馏的理论分析
11.3.1 数据质量 vs 数量
定理 11.1(数据效率) :设教师模型的输出质量为 QTQ_TQT,学生模型在 NNN 条数据上训练后的性能为:
Acc(N)=QT⋅(1−e−λN)\text{Acc}(N) = Q_T \cdot (1 - e^{-\lambda N})Acc(N)=QT⋅(1−e−λN)
其中 λ\lambdaλ 是数据效率参数。
推论:少量高质量数据可能优于大量低质量数据。
11.3.2 教师能力的上界
定理 11.2(蒸馏的上界):学生模型的性能不可能超过教师模型(在相同任务上):
AccS≤AccT\text{Acc}_S \leq \text{Acc}_TAccS≤AccT
但实践中的例外:
- 学生模型可能在某些特定子任务上超越教师
- 数据增强和正则化可以弥补容量差距
- 教师的"错误"可能被学生修正
第十二章:思维链蒸馏与推理能力迁移
12.1 思维链蒸馏
12.1.1 动机
问题:标准蒸馏只迁移了教师的"答案",但没有迁移教师的"推理过程"。
思维链(Chain-of-Thought, CoT):教师模型在生成答案之前,先生成推理步骤。
12.1.2 CoT 蒸馏方法
算法 12.1(CoT 蒸馏):
1. 使用教师模型生成 (问题, 思维链, 答案) 三元组
2. 训练学生模型学习生成完整的思维链:
P_S(CoT, answer | question)
3. 损失函数:
L = -Σ_t log P_S(token_t | token_{<t}, question)
(对思维链和答案的所有 token 计算)
12.1.3 CoT 蒸馏的理论分析
定理 12.1(CoT 蒸馏的信息增益):相比只学习答案,CoT 蒸馏提供了额外的监督信号:
ICoT=I(answer;CoT∣question)I_{\text{CoT}} = I(\text{answer}; \text{CoT} | \text{question})ICoT=I(answer;CoT∣question)
这个互信息衡量了思维链中包含的关于答案的额外信息。
推论:对于需要推理的任务(如数学、逻辑),CoT 蒸馏的效果显著优于标准蒸馏。
12.2 推理能力迁移
12.2.1 迁移的挑战
问题:大模型的推理能力(如多步推理、类比推理)是否可以迁移到小模型?
实验观察:
- 标准蒸馏:小模型可以学习到教师的"知识",但推理能力有限
- CoT 蒸馏:小模型可以学习到教师的"推理模式",但泛化能力有限
- 迭代蒸馏:通过多轮蒸馏,可以逐步提升小模型的推理能力
12.2.2 迭代 CoT 蒸馏
算法 12.2(迭代 CoT 蒸馏):
for round = 1, ..., R:
1. 使用教师生成 CoT 数据
2. 使用 CoT 数据训练学生
3. 使用学生作为新的教师(或使用学生生成新的 CoT 数据)
4. 评估学生的推理能力
第五部分:完整可运行代码实现
第十三章:从零实现经典知识蒸馏
python
"""
经典知识蒸馏的完整实现。
包含:Hinton 蒸馏、温度缩放、各种损失函数。
"""
import numpy as np
from typing import Tuple, Optional
def softmax(z: np.ndarray, temperature: float = 1.0) -> np.ndarray:
"""温度缩放的 softmax。
Args:
z: logits (..., K)
temperature: 温度参数
Returns:
p: 概率分布 (..., K)
"""
z_scaled = z / temperature
z_max = np.max(z_scaled, axis=-1, keepdims=True)
exp_z = np.exp(z_scaled - z_max)
return exp_z / np.sum(exp_z, axis=-1, keepdims=True)
def kl_divergence(p: np.ndarray, q: np.ndarray) -> float:
"""KL 散度 D_KL(p || q)。
Args:
p: 概率分布 (..., K)
q: 概率分布 (..., K)
Returns:
kl: KL 散度
"""
# 避免 log(0)
p = np.clip(p, 1e-10, 1.0)
q = np.clip(q, 1e-10, 1.0)
return np.sum(p * np.log(p / q), axis=-1)
def cross_entropy(p: np.ndarray, q: np.ndarray) -> float:
"""交叉熵 H(p, q)。
Args:
p: 目标分布 (..., K)
q: 预测分布 (..., K)
Returns:
ce: 交叉熵
"""
q = np.clip(q, 1e-10, 1.0)
return -np.sum(p * np.log(q), axis=-1)
def hinton_distillation_loss(
teacher_logits: np.ndarray,
student_logits: np.ndarray,
labels: np.ndarray,
temperature: float = 4.0,
alpha: float = 0.1,
) -> Tuple[float, dict]:
"""Hinton 蒸馏损失。
L = α * L_hard + (1-α) * T^2 * L_soft
Args:
teacher_logits: 教师 logits (B, K)
student_logits: 学生 logits (B, K)
labels: 真实标签 (B,)
temperature: 温度参数
alpha: 硬标签权重
Returns:
loss: 总损失
info: 损失详情
"""
B, K = teacher_logits.shape
# 软标签
p_teacher = softmax(teacher_logits, temperature)
p_student = softmax(student_logits, temperature)
# 硬标签
y_onehot = np.zeros((B, K))
y_onehot[np.arange(B), labels] = 1
p_student_hard = softmax(student_logits, temperature=1.0)
# 软标签损失(KL 散度)
L_soft = np.mean(kl_divergence(p_teacher, p_student))
# 硬标签损失(交叉熵)
L_hard = np.mean(cross_entropy(y_onehot, p_student_hard))
# 总损失
loss = alpha * L_hard + (1 - alpha) * temperature**2 * L_soft
info = {
"L_hard": L_hard,
"L_soft": L_soft,
"total": loss,
"teacher_entropy": np.mean(-np.sum(p_teacher * np.log(np.clip(p_teacher, 1e-10, 1.0)), axis=-1)),
"student_entropy": np.mean(-np.sum(p_student * np.log(np.clip(p_student, 1e-10, 1.0)), axis=-1)),
}
return loss, info
def mae_distillation_loss(
teacher_logits: np.ndarray,
student_logits: np.ndarray,
) -> float:
"""MSE 蒸馏损失(logits 级别)。
Args:
teacher_logits: 教师 logits
student_logits: 学生 logits
Returns:
loss: MSE 损失
"""
return np.mean((teacher_logits - student_logits) ** 2)
def cosine_distillation_loss(
teacher_features: np.ndarray,
student_features: np.ndarray,
) -> float:
"""余弦相似度蒸馏损失。
Args:
teacher_features: 教师特征 (B, D)
student_features: 学生特征 (B, D)
Returns:
loss: 1 - 余弦相似度
"""
# 归一化
t_norm = teacher_features / (np.linalg.norm(teacher_features, axis=1, keepdims=True) + 1e-8)
s_norm = student_features / (np.linalg.norm(student_features, axis=1, keepdims=True) + 1e-8)
# 余弦相似度
cos_sim = np.sum(t_norm * s_norm, axis=1)
return np.mean(1 - cos_sim)
def demonstrate_distillation():
"""演示知识蒸馏。"""
np.random.seed(42)
print("=" * 70)
print("知识蒸馏基础演示")
print("=" * 70)
# 设置
B = 32 # batch size
K = 10 # 类别数
# 生成模拟数据
teacher_logits = np.random.randn(B, K) * 2
student_logits = np.random.randn(B, K) * 1.5
labels = np.random.randint(0, K, B)
# 温度的影响
print("\n 1. 温度对软标签的影响")
print(" " + "-" * 40)
for T in [1, 2, 4, 8, 16]:
p_teacher = softmax(teacher_logits, T)
entropy = np.mean(-np.sum(p_teacher * np.log(np.clip(p_teacher, 1e-10, 1.0)), axis=-1))
max_prob = np.mean(np.max(p_teacher, axis=1))
print(f" T={T:2d}: 熵={entropy:.4f}, 最大概率={max_prob:.4f}")
# 蒸馏损失
print("\n 2. 不同温度下的蒸馏损失")
print(" " + "-" * 40)
for T in [1, 2, 4, 8, 16]:
loss, info = hinton_distillation_loss(teacher_logits, student_logits, labels, temperature=T)
print(f" T={T:2d}: 总损失={info['total']:.4f}, "
f"硬标签={info['L_hard']:.4f}, 软标签={info['L_soft']:.4f}")
# alpha 的影响
print("\n 3. 不同 α 值的影响")
print(" " + "-" * 40)
T = 4
for alpha in [0.0, 0.1, 0.3, 0.5, 0.7, 0.9]:
loss, info = hinton_distillation_loss(teacher_logits, student_logits, labels, temperature=T, alpha=alpha)
print(f" α={alpha:.1f}: 总损失={info['total']:.4f}")
# 各种蒸馏损失对比
print("\n 4. 各种蒸馏损失对比")
print(" " + "-" * 40)
# Hinton 蒸馏
loss_hinton, _ = hinton_distillation_loss(teacher_logits, student_logits, labels, temperature=4, alpha=0.1)
# MSE 蒸馏
loss_mse = mae_distillation_loss(teacher_logits, student_logits)
# 余弦蒸馏
loss_cos = cosine_distillation_loss(teacher_logits, student_logits)
print(f" Hinton (T=4, α=0.1): {loss_hinton:.4f}")
print(f" MSE (logits): {loss_mse:.4f}")
print(f" Cosine: {loss_cos:.4f}")
# 教师-学生差异分析
print("\n 5. 教师-学生差异分析")
print(" " + "-" * 40)
p_t = softmax(teacher_logits, T=4)
p_s = softmax(student_logits, T=4)
# KL 散度
kl_ts = np.mean(kl_divergence(p_t, p_s))
kl_st = np.mean(kl_divergence(p_s, p_t))
js = np.mean(0.5 * kl_divergence(p_t, 0.5 * (p_t + p_s)) + 0.5 * kl_divergence(p_s, 0.5 * (p_t + p_s)))
print(f" D_KL(teacher||student): {kl_ts:.4f}")
print(f" D_KL(student||teacher): {kl_st:.4f}")
print(f" D_JS(teacher||student): {js:.4f}")
# 预测一致性
teacher_pred = np.argmax(teacher_logits, axis=1)
student_pred = np.argmax(student_logits, axis=1)
agreement = np.mean(teacher_pred == student_pred)
print(f" 预测一致率: {agreement:.2%}")
if __name__ == "__main__":
demonstrate_distillation()
第十四章:从零实现特征蒸馏
python
"""
特征蒸馏的完整实现。
包含:FitNets、注意力转移、特征分布对齐。
"""
import numpy as np
from typing import List, Tuple
def feature_distillation_loss(
teacher_features: np.ndarray,
student_features: np.ndarray,
method: str = "mse",
) -> float:
"""特征蒸馏损失。
Args:
teacher_features: 教师特征 (B, D_t)
student_features: 学生特征 (B, D_s)
method: 方法 ("mse", "cosine", "normalize")
Returns:
loss: 蒸馏损失
"""
if method == "mse":
# MSE 损失
return np.mean((teacher_features - student_features) ** 2)
elif method == "cosine":
# 余弦相似度损失
t_norm = teacher_features / (np.linalg.norm(teacher_features, axis=1, keepdims=True) + 1e-8)
s_norm = student_features / (np.linalg.norm(student_features, axis=1, keepdims=True) + 1e-8)
cos_sim = np.sum(t_norm * s_norm, axis=1)
return np.mean(1 - cos_sim)
elif method == "normalize":
# 归一化 MSE
t_mean, t_std = teacher_features.mean(axis=0), teacher_features.std(axis=0) + 1e-8
s_mean, s_std = student_features.mean(axis=0), student_features.std(axis=0) + 1e-8
t_norm = (teacher_features - t_mean) / t_std
s_norm = (student_features - s_mean) / s_std
return np.mean((t_norm - s_norm) ** 2)
else:
raise ValueError(f"Unknown method: {method}")
def attention_transfer_loss(
teacher_attention: np.ndarray,
student_attention: np.ndarray,
) -> float:
"""注意力转移损失。
Args:
teacher_attention: 教师注意力图 (B, H, S, S)
student_attention: 学生注意力图 (B, H', S, S)
Returns:
loss: 注意力转移损失
"""
# 计算注意力图(对 head 维度取平均)
A_teacher = np.mean(np.abs(teacher_attention), axis=1) # (B, S, S)
A_student = np.mean(np.abs(student_attention), axis=1) # (B, S, S)
# 归一化
A_teacher_norm = A_teacher / (np.linalg.norm(A_teacher.reshape(A_teacher.shape[0], -1), axis=1, keepdims=True).reshape(-1, 1, 1) + 1e-8)
A_student_norm = A_student / (np.linalg.norm(A_student.reshape(A_student.shape[0], -1), axis=1, keepdims=True).reshape(-1, 1, 1) + 1e-8)
# MSE 损失
loss = np.mean((A_teacher_norm - A_student_norm) ** 2)
return loss
def gram_matrix(features: np.ndarray) -> np.ndarray:
"""计算 Gram 矩阵。
Args:
features: 特征图 (B, C, H, W)
Returns:
G: Gram 矩阵 (B, C, C)
"""
B, C, H, W = features.shape
F = features.reshape(B, C, H * W)
G = np.matmul(F, F.transpose(0, 2, 1))
return G / (C * H * W)
def gram_distillation_loss(
teacher_features: np.ndarray,
student_features: np.ndarray,
) -> float:
"""Gram 矩阵蒸馏损失。
Args:
teacher_features: 教师特征图 (B, C_t, H, W)
student_features: 学生特征图 (B, C_s, H, W)
Returns:
loss: Gram 矩阵匹配损失
"""
G_teacher = gram_matrix(teacher_features)
G_student = gram_matrix(student_features)
# 如果通道数不同,需要匹配
if G_teacher.shape[1] != G_student.shape[1]:
min_channels = min(G_teacher.shape[1], G_student.shape[1])
G_teacher = G_teacher[:, :min_channels, :min_channels]
G_student = G_student[:, :min_channels, :min_channels]
return np.mean((G_teacher - G_student) ** 2)
def distribution_alignment_loss(
teacher_features: np.ndarray,
student_features: np.ndarray,
) -> float:
"""特征分布对齐损失。
对齐均值和方差。
Args:
teacher_features: 教师特征 (B, D)
student_features: 学生特征 (B, D)
Returns:
loss: 分布对齐损失
"""
# 均值对齐
t_mean = np.mean(teacher_features, axis=0)
s_mean = np.mean(student_features, axis=0)
mean_loss = np.mean((t_mean - s_mean) ** 2)
# 方差对齐
t_var = np.var(teacher_features, axis=0) + 1e-8
s_var = np.var(student_features, axis=0) + 1e-8
var_loss = np.mean((np.sqrt(t_var) - np.sqrt(s_var)) ** 2)
return mean_loss + var_loss
def demonstrate_feature_distillation():
"""演示特征蒸馏。"""
np.random.seed(42)
print("=" * 70)
print("特征蒸馏演示")
print("=" * 70)
B = 16
D_t = 256 # 教师特征维度
D_s = 128 # 学生特征维度
# 生成特征
teacher_features = np.random.randn(B, D_t) * 0.5
student_features = np.random.randn(B, D_s) * 0.5
# 不同蒸馏方法
print("\n 1. 特征蒸馏损失对比")
print(" " + "-" * 40)
# MSE(需要维度匹配)
teacher_matched = teacher_features[:, :D_s] # 简单截断
loss_mse = feature_distillation_loss(teacher_matched, student_features, "mse")
loss_cos = feature_distillation_loss(teacher_matched, student_features, "cosine")
loss_norm = feature_distillation_loss(teacher_matched, student_features, "normalize")
print(f" MSE: {loss_mse:.6f}")
print(f" Cosine: {loss_cos:.6f}")
print(f" Normalize MSE: {loss_norm:.6f}")
# 注意力转移
print("\n 2. 注意力转移损失")
print(" " + "-" * 40)
S = 32 # 序列长度
H_t = 8 # 教师头数
H_s = 4 # 学生头数
teacher_attn = np.random.randn(B, H_t, S, S) * 0.1
student_attn = np.random.randn(B, H_s, S, S) * 0.1
# 应用 softmax 使其成为有效的注意力权重
teacher_attn = np.exp(teacher_attn) / np.sum(np.exp(teacher_attn), axis=-1, keepdims=True)
student_attn = np.exp(student_attn) / np.sum(np.exp(student_attn), axis=-1, keepdims=True)
loss_attn = attention_transfer_loss(teacher_attn, student_attn)
print(f" 注意力转移损失: {loss_attn:.6f}")
# Gram 矩阵蒸馏
print("\n 3. Gram 矩阵蒸馏损失")
print(" " + "-" * 40)
C_t = 64
C_s = 32
H, W = 8, 8
teacher_feat_map = np.random.randn(B, C_t, H, W) * 0.5
student_feat_map = np.random.randn(B, C_s, H, W) * 0.5
loss_gram = gram_distillation_loss(teacher_feat_map, student_feat_map)
print(f" Gram 矩阵损失: {loss_gram:.6f}")
# 分布对齐
print("\n 4. 分布对齐损失")
print(" " + "-" * 40)
loss_dist = distribution_alignment_loss(teacher_matched, student_features)
print(f" 分布对齐损失: {loss_dist:.6f}")
if __name__ == "__main__":
demonstrate_feature_distillation()
第十五章:从零实现自蒸馏
python
"""
自蒸馏的完整实现。
包含:Born-Again Networks、在线自蒸馏、多视角自蒸馏。
"""
import numpy as np
from typing import List, Tuple
class SimpleModel:
"""简单的线性模型,用于演示。"""
def __init__(self, input_dim: int, output_dim: int, lr: float = 0.01):
self.W = np.random.randn(input_dim, output_dim) * 0.01
self.b = np.zeros(output_dim)
self.lr = lr
def forward(self, X: np.ndarray) -> np.ndarray:
"""前向传播。"""
return X @ self.W + self.b
def softmax(self, z: np.ndarray, temperature: float = 1.0) -> np.ndarray:
"""温度缩放的 softmax。"""
z_scaled = z / temperature
z_max = np.max(z_scaled, axis=-1, keepdims=True)
exp_z = np.exp(z_scaled - z_max)
return exp_z / np.sum(exp_z, axis=-1, keepdims=True)
def train_step(
self,
X: np.ndarray,
y: np.ndarray,
teacher_logits: np.ndarray = None,
temperature: float = 4.0,
alpha: float = 0.5,
) -> float:
"""训练步骤。
Args:
X: 输入 (B, D)
y: 标签 (B,)
teacher_logits: 教师 logits (如果有的话)
temperature: 温度
alpha: 硬标签权重
Returns:
loss: 损失
"""
B = X.shape[0]
K = self.W.shape[1]
# 前向传播
logits = self.forward(X)
# 硬标签
y_onehot = np.zeros((B, K))
y_onehot[np.arange(B), y] = 1
p_student_hard = self.softmax(logits, 1.0)
# 硬标签梯度
grad_hard = (p_student_hard - y_onehot) / B
if teacher_logits is not None:
# 软标签
p_teacher = self.softmax(teacher_logits, temperature)
p_student = self.softmax(logits, temperature)
# 软标签梯度(近似)
grad_soft = temperature * (p_student - p_teacher) / B
# 组合梯度
grad = alpha * grad_hard + (1 - alpha) * grad_soft
else:
grad = grad_hard
# 更新参数
self.W -= self.lr * X.T @ grad
self.b -= self.lr * grad.mean(axis=0)
# 计算损失
p = self.softmax(logits, 1.0)
loss = -np.mean(np.sum(y_onehot * np.log(np.clip(p, 1e-10, 1.0)), axis=-1))
return loss
def predict(self, X: np.ndarray) -> np.ndarray:
"""预测。"""
logits = self.forward(X)
return np.argmax(logits, axis=1)
def born_again_networks(
X_train: np.ndarray,
y_train: np.ndarray,
X_test: np.ndarray,
y_test: np.ndarray,
n_generations: int = 3,
n_epochs: int = 50,
temperature: float = 4.0,
alpha: float = 0.1,
) -> List[dict]:
"""Born-Again Networks 蒸馏。
Args:
X_train, y_train: 训练数据
X_test, y_test: 测试数据
n_generations: 代数
n_epochs: 每代的训练轮数
temperature: 温度
alpha: 硬标签权重
Returns:
results: 每代的结果
"""
D = X_train.shape[1]
K = len(np.unique(y_train))
results = []
teacher_logits = None
for gen in range(n_generations):
print(f"\n 第 {gen + 1} 代:")
# 创建学生模型
student = SimpleModel(D, K, lr=0.01)
# 训练
for epoch in range(n_epochs):
loss = student.train_step(
X_train, y_train,
teacher_logits=teacher_logits,
temperature=temperature,
alpha=alpha if teacher_logits is not None else 1.0,
)
# 评估
train_pred = student.predict(X_train)
test_pred = student.predict(X_test)
train_acc = np.mean(train_pred == y_train)
test_acc = np.mean(test_pred == y_test)
print(f" 训练准确率: {train_acc:.2%}")
print(f" 测试准确率: {test_acc:.2%}")
results.append({
"generation": gen + 1,
"train_acc": train_acc,
"test_acc": test_acc,
})
# 将学生作为下一代的教师
teacher_logits = student.forward(X_train)
return results
def online_self_distillation(
X_train: np.ndarray,
y_train: np.ndarray,
X_test: np.ndarray,
y_test: np.ndarray,
n_epochs: int = 100,
ema_decay: float = 0.999,
temperature: float = 4.0,
distill_weight: float = 0.5,
) -> dict:
"""在线自蒸馏(EMA 教师)。
Args:
X_train, y_train: 训练数据
X_test, y_test: 测试数据
n_epochs: 训练轮数
ema_decay: EMA 衰减率
temperature: 温度
distill_weight: 蒸馏权重
Returns:
results: 训练结果
"""
D = X_train.shape[1]
K = len(np.unique(y_train))
# 主模型
model = SimpleModel(D, K, lr=0.01)
# EMA 模型(教师)
ema_model = SimpleModel(D, K, lr=0.0) # 不更新
ema_model.W = model.W.copy()
ema_model.b = model.b.copy()
train_accs = []
test_accs = []
for epoch in range(n_epochs):
# 训练主模型(使用 EMA 模型作为教师)
ema_logits = ema_model.forward(X_train)
loss = model.train_step(
X_train, y_train,
teacher_logits=ema_logits,
temperature=temperature,
alpha=1 - distill_weight,
)
# 更新 EMA 模型
ema_model.W = ema_decay * ema_model.W + (1 - ema_decay) * model.W
ema_model.b = ema_decay * ema_model.b + (1 - ema_decay) * model.b
# 评估
if (epoch + 1) % 10 == 0:
train_pred = model.predict(X_train)
test_pred = model.predict(X_test)
train_acc = np.mean(train_pred == y_train)
test_acc = np.mean(test_pred == y_test)
train_accs.append(train_acc)
test_accs.append(test_acc)
print(f" Epoch {epoch + 1}: 训练={train_acc:.2%}, 测试={test_acc:.2%}")
return {
"train_accs": train_accs,
"test_accs": test_accs,
"final_train_acc": train_accs[-1],
"final_test_acc": test_accs[-1],
}
def demonstrate_self_distillation():
"""演示自蒸馏。"""
np.random.seed(42)
print("=" * 70)
print("自蒸馏演示")
print("=" * 70)
# 生成数据
n_train = 500
n_test = 100
D = 32
K = 5
X_train = np.random.randn(n_train, D)
y_train = np.random.randint(0, K, n_train)
X_test = np.random.randn(n_test, D)
y_test = np.random.randint(0, K, n_test)
# Born-Again Networks
print("\n 1. Born-Again Networks")
print(" " + "-" * 40)
ban_results = born_again_networks(
X_train, y_train, X_test, y_test,
n_generations=3, n_epochs=50, temperature=4.0, alpha=0.1,
)
print(f"\n BAN 结果汇总:")
for r in ban_results:
print(f" 第 {r['generation']} 代: 测试准确率 = {r['test_acc']:.2%}")
# 在线自蒸馏
print("\n\n 2. 在线自蒸馏 (EMA 教师)")
print(" " + "-" * 40)
online_results = online_self_distillation(
X_train, y_train, X_test, y_test,
n_epochs=100, ema_decay=0.99, temperature=4.0, distill_weight=0.3,
)
print(f"\n 在线自蒸馏最终结果:")
print(f" 训练准确率: {online_results['final_train_acc']:.2%}")
print(f" 测试准确率: {online_results['final_test_acc']:.2%}")
# 基线对比
print("\n\n 3. 基线对比(无蒸馏)")
print(" " + "-" * 40)
baseline = SimpleModel(D, K, lr=0.01)
for epoch in range(100):
baseline.train_step(X_train, y_train)
train_pred = baseline.predict(X_train)
test_pred = baseline.predict(X_test)
print(f" 训练准确率: {np.mean(train_pred == y_train):.2%}")
print(f" 测试准确率: {np.mean(test_pred == y_test):.2%}")
if __name__ == "__main__":
demonstrate_self_distillation()
第十六章:完整蒸馏 Pipeline 与精度对比
python
"""
完整的蒸馏 Pipeline。
对比各种蒸馏方法在不同设置下的效果。
"""
import numpy as np
from typing import Dict, List, Tuple
def softmax(z, T=1.0):
z = z / T
z_max = np.max(z, axis=-1, keepdims=True)
exp_z = np.exp(z - z_max)
return exp_z / np.sum(exp_z, axis=-1, keepdims=True)
def run_full_distillation_comparison():
"""运行完整的蒸馏方法对比。"""
np.random.seed(42)
print("=" * 70)
print("知识蒸馏方法综合对比")
print("=" * 70)
# 生成数据
n_train = 1000
n_test = 200
D = 64
K = 10
X_train = np.random.randn(n_train, D)
y_train = np.random.randint(0, K, n_train)
X_test = np.random.randn(n_test, D)
y_test = np.random.randint(0, K, n_test)
# 教师模型(更大的模型)
teacher_W = np.random.randn(D, K) * 0.1
teacher_b = np.zeros(K)
# 教师预测
teacher_logits_train = X_train @ teacher_W + teacher_b
teacher_logits_test = X_test @ teacher_W + teacher_b
teacher_pred = np.argmax(teacher_logits_test, axis=1)
teacher_acc = np.mean(teacher_pred == y_test)
print(f"\n 教师模型准确率: {teacher_acc:.2%}")
print(f" 教师参数量: {D * K + K}")
# 定义训练函数
def train_student(
X_train, y_train, teacher_logits=None,
temperature=4.0, alpha=0.5, n_epochs=100, lr=0.01
):
D = X_train.shape[1]
K = len(np.unique(y_train))
W = np.random.randn(D, K) * 0.01
b = np.zeros(K)
for epoch in range(n_epochs):
logits = X_train @ W + b
B = X_train.shape[0]
# 硬标签
y_onehot = np.zeros((B, K))
y_onehot[np.arange(B), y_train] = 1
p_hard = softmax(logits, 1.0)
grad_hard = (p_hard - y_onehot) / B
if teacher_logits is not None:
p_teacher = softmax(teacher_logits, temperature)
p_student = softmax(logits, temperature)
grad_soft = temperature * (p_student - p_teacher) / B
grad = alpha * grad_hard + (1 - alpha) * grad_soft
else:
grad = grad_hard
W -= lr * X_train.T @ grad
b -= lr * grad.mean(axis=0)
return W, b
# 测试不同方法
methods = {}
# 1. 无蒸馏(基线)
W, b = train_student(X_train, y_train, n_epochs=100)
pred = np.argmax(X_test @ W + b, axis=1)
methods["无蒸馏(基线)"] = np.mean(pred == y_test)
# 2. 标准蒸馏
for T in [2, 4, 8]:
for alpha in [0.1, 0.5]:
W, b = train_student(
X_train, y_train,
teacher_logits=teacher_logits_train,
temperature=T, alpha=alpha, n_epochs=100
)
pred = np.argmax(X_test @ W + b, axis=1)
name = f"蒸馏 T={T}, α={alpha}"
methods[name] = np.mean(pred == y_test)
# 3. MSE 蒸馏
def train_mse(X_train, y_train, teacher_logits, n_epochs=100, lr=0.01, distill_weight=0.5):
D = X_train.shape[1]
K = teacher_logits.shape[1]
W = np.random.randn(D, K) * 0.01
b = np.zeros(K)
for epoch in range(n_epochs):
logits = X_train @ W + b
B = X_train.shape[0]
y_onehot = np.zeros((B, K))
y_onehot[np.arange(B), y_train] = 1
p_hard = softmax(logits, 1.0)
grad_hard = (p_hard - y_onehot) / B
grad_mse = 2 * (logits - teacher_logits) / B
grad = (1 - distill_weight) * grad_hard + distill_weight * grad_mse
W -= lr * X_train.T @ grad
b -= lr * grad.mean(axis=0)
return W, b
W, b = train_mse(X_train, y_train, teacher_logits_train, distill_weight=0.5)
pred = np.argmax(X_test @ W + b, axis=1)
methods["MSE 蒸馏"] = np.mean(pred == y_test)
# 打印结果
print(f"\n {'方法':>25} {'测试准确率':>12} {'相对提升':>12}")
print(f" {'-'*25} {'-'*12} {'-'*12}")
baseline_acc = methods["无蒸馏(基线)"]
for name, acc in methods.items():
improvement = (acc - baseline_acc) / baseline_acc * 100
print(f" {name:>25} {acc:>12.2%} {improvement:>+11.1f}%")
# 最佳方法
best_method = max(methods.items(), key=lambda x: x[1])
print(f"\n 最佳方法: {best_method[0]} (准确率: {best_method[1]:.2%})")
# 温度敏感性分析
print(f"\n 温度敏感性分析 (α=0.1):")
print(f" {'温度':>8} {'准确率':>12} {'教师熵':>12} {'学生熵':>12}")
print(f" {'-'*8} {'-'*12} {'-'*12} {'-'*12}")
for T in [1, 2, 4, 8, 16, 32]:
W, b = train_student(
X_train, y_train,
teacher_logits=teacher_logits_train,
temperature=T, alpha=0.1, n_epochs=100
)
pred = np.argmax(X_test @ W + b, axis=1)
acc = np.mean(pred == y_test)
p_t = softmax(teacher_logits_train, T)
p_s = softmax(X_train @ W + b, T)
t_entropy = np.mean(-np.sum(p_t * np.log(np.clip(p_t, 1e-10, 1.0)), axis=-1))
s_entropy = np.mean(-np.sum(p_s * np.log(np.clip(p_s, 1e-10, 1.0)), axis=-1))
print(f" {T:>8} {acc:>12.2%} {t_entropy:>12.4f} {s_entropy:>12.4f}")
if __name__ == "__main__":
run_full_distillation_comparison()
附录:关键公式汇总
A.1 散度与距离
| 公式 | 表达式 |
|---|---|
| KL 散度 | DKL(P∣Q)=∑iP(i)logP(i)Q(i)D_{\text{KL}}(P | Q) = \sum_i P(i) \log \frac{P(i)}{Q(i)}DKL(P∣Q)=∑iP(i)logQ(i)P(i) |
| 交叉熵 | H(P,Q)=−∑iP(i)logQ(i)H(P, Q) = -\sum_i P(i) \log Q(i)H(P,Q)=−∑iP(i)logQ(i) |
| JS 散度 | DJS(P∣Q)=12DKL(P∣M)+12DKL(Q∣M)D_{\text{JS}}(P | Q) = \frac{1}{2} D_{\text{KL}}(P | M) + \frac{1}{2} D_{\text{KL}}(Q | M)DJS(P∣Q)=21DKL(P∣M)+21DKL(Q∣M) |
| 关系 | DKL(P∣Q)=H(P,Q)−H(P)D_{\text{KL}}(P | Q) = H(P, Q) - H(P)DKL(P∣Q)=H(P,Q)−H(P) |
A.2 温度缩放
| 公式 | 表达式 |
|---|---|
| 温度 softmax | pi(T)=ezi/T∑jezj/Tp_i^{(T)} = \frac{e^{z_i/T}}{\sum_j e^{z_j/T}}pi(T)=∑jezj/Tezi/T |
| 熵与温度 | H(T)H(T)H(T) 单调递增,H(0)=0H(0) = 0H(0)=0,H(∞)=logKH(\infty) = \log KH(∞)=logK |
| 梯度 | ∂pi∂zi=1Tpi(1−pi)\frac{\partial p_i}{\partial z_i} = \frac{1}{T} p_i (1 - p_i)∂zi∂pi=T1pi(1−pi) |
A.3 蒸馏损失
| 方法 | 损失函数 |
|---|---|
| Hinton 蒸馏 | L=αH(y,PS(1))+(1−α)T2DKL(PT(T)∣PS(T))\mathcal{L} = \alpha H(\mathbf{y}, P_S^{(1)}) + (1-\alpha) T^2 D_{\text{KL}}(P_T^{(T)} | P_S^{(T)})L=αH(y,PS(1))+(1−α)T2DKL(PT(T)∣PS(T)) |
| MSE 蒸馏 | L=∣zT−zS∣2\mathcal{L} = |\mathbf{z}_T - \mathbf{z}_S|^2L=∣zT−zS∣2 |
| 余弦蒸馏 | L=1−cos(hT,hS)\mathcal{L} = 1 - \cos(\mathbf{h}_T, \mathbf{h}_S)L=1−cos(hT,hS) |
| 注意力转移 | L=∣AT−AS∣F2\mathcal{L} = |A_T - A_S|_F^2L=∣AT−AS∣F2 |
| Gram 矩阵 | L=∣G(AT)−G(AS)∣F2\mathcal{L} = |G(\mathbf{A}_T) - G(\mathbf{A}_S)|_F^2L=∣G(AT)−G(AS)∣F2 |
A.4 序列级蒸馏
| 公式 | 表达式 |
|---|---|
| 序列级 KL | $D_{\text{KL}}^{\text{seq}} = \sum_{t=1}^{L} D_{\text{KL}}(P_T(\cdot |
| 反向 KL | $D_{\text{KL}}(P_S | P_T) = \sum_t \sum_v P_S(v |
A.5 自蒸馏
| 方法 | 说明 |
|---|---|
| Born-Again Networks | 多代蒸馏,每代用上一代作为教师 |
| EMA 自蒸馏 | 使用指数移动平均作为教师 |
| 多视角自蒸馏 | 不同增强版本相互对齐 |
参考文献
- Hinton, G., Vinyals, O., & Dean, J. (2015). Distilling the knowledge in a neural network. NeurIPS Workshop.
- Romero, A., et al. (2015). FitNets: Hints for thin deep nets. ICLR.
- Zagoruyko, S., & Komodakis, N. (2017). Paying more attention to attention: Improving the performance of convolutional neural networks via attention transfer. ICLR.
- Furlanello, T., et al. (2018). Born-again neural networks. ICML.
- Park, W., et al. (2019). Relational knowledge distillation. CVPR.
- Sanh, V., et al. (2019). DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter. NeurIPS Workshop.
- Jiao, X., et al. (2020). TinyBERT: Distilling BERT for natural language understanding. EMNLP.
- Gu, Y., et al. (2024). MiniLLM: Knowledge distillation of large language models. ICLR.
- Taori, R., et al. (2023). Stanford Alpaca: An instruction-following LLaMA model. GitHub.
- Chiang, W., et al. (2023). Vicuna: An open-source chatbot impressing GPT-4 with 90% ChatGPT quality. Blog.
- Xu, C., et al. (2023). WizardLM: Empowering large language models to follow complex instructions. arXiv.
- Wei, J., et al. (2022). Chain-of-thought prompting elicits reasoning in large language models. NeurIPS.