深度学习进阶(六)归纳偏置与蒸馏

上一篇,我们已经完成了 Vision Transformer的完整逻辑:把图像切成 patch 当作 token,送入 Transformer Encoder 做全局建模。

但我们也提到了, ViT 存在一个绕不开的痛点:

没有足够大的数据规模,ViT 往往很难训练得好。

而用范式角度来说,这是因为ViT 本质上是一种"弱先验、强数据驱动"的建模方式。

再展开一些,对于这个问题:

为什么 ViT 需要大量数据才能表现良好?而 CNN 在小数据下却依然有效?

我们在高光谱成像的内容中已经展开过先验相关的内容,知道卷积网络把大量视觉先验写进了结构里,而 ViT 更偏"数据驱动",空间关系主要靠统计学习出来,因此对数据规模和训练配方高度敏感。

但单靠这些笼统的概念就直接进入 ViT 的下一步改进还是略显单薄。

所以本篇主要介绍 DL 中的两个重要概念:

  1. 归纳偏置(Inductive Bias)
  2. 蒸馏(Distillation)

在了解这些内容后就可以较通畅地进入 ViT 的其中一种改进逻辑。

1. 先验信息和归纳偏置

首先,我们需要明白两个高度相关但并不等价的概念:先验信息(Prior Knowledge) 与 归纳偏置(Inductive Bias)

1.1 什么是先验信息?

再简单复述一遍,用一句话定义:

先验信息是我们在学习之前就已经知道的"关于世界的规律"。

它并不来源于数据,而是来源于经验或认知

例如在视觉任务中,我们天然知道:图像是连续的、相邻像素之间更相关、物体具有结构等等基本认知,简单展开一些如下:

  1. 结构先验:
    人脸中眼睛在鼻子上方,嘴巴在鼻子下方,整体呈现稳定的空间排列关系。
  2. 局部相关性先验:
    一张图像中,相邻像素通常属于同一物体,例如一片天空区域,其颜色和纹理在局部范围内是平滑且相似的,而不会突然剧烈变化。
  3. 连续性先验:
    图像中的边缘和轮廓通常是连续的,比如一条道路或物体边界,不会在相邻位置随机中断或跳跃。

这些都属于对真实世界的描述,也就是先验信息。

1.2 什么是归纳偏置?

相比之下,归纳偏置是一个更"模型视角"的概念:

归纳偏置是模型在学习过程中"更倾向于选择某一类解"的机制,它通常来源于先验信息。

它并不是具体的知识,而是我们根据先验在模型中进行的设计 ,从而使得:模型更容易学到什么、模型更不容易学到什么。

总结来说就是:归纳偏置决定了模型的"学习方向"。

到这里,我们就可以较完善地解释最初的问题:为什么 ViT 需要大量数据才能表现良好?而 CNN 在小数据下却依然有效?

1.3 CNN 和 ViT

先用刚刚的概念来总结下: CNN 和 ViT,本质上是在"是否引入归纳偏置"这个问题上的两种不同选择。

这里要先强调一下,由归纳偏置引起的数据依赖性只是相对而言,基于 DL 的方法本身都是数据驱动的。

先说 CNN ,我们对 CNN 的建模逻辑本身就是在做一件事:把先验信息写进模型结构,而这就是归纳偏置的体现。

具体展开对照:

  1. 因为局部性先验,我们设计了卷积核,来限制模型只能关注局部区域。
  2. 我们使用网络层级结构,从局部逐步组合成全局,其实也是结构性先验的体现。

你会发现:CNN 在训练开始之前,就已经被"规定"了如何理解图像。

再从数学角度展开:CNN 的学习并不是在一个完全自由的空间中进行,而是在一个被强约束的函数空间中寻找解。

这就是强归纳偏置,直接结果就是即使数据不多,模型也能较快学到合理结构更容易收敛,不容易学偏。

但代价就是模型的表达能力被结构限制 ,这其实限制了模型的上限,因为模型的学习逻辑不一定非要按照我们人类理解的逻辑来进行。

(有段时间没用 GPT 生成中文配图了,大概四个月前生成的图中的中文还是有很多错乱的,真的是在不停地进化。)

而在 ViT 中,我们几乎做了相反的选择:尽量不把先验写进结构,而是交给数据去学习。

体现就是不使用卷积,所有 patch token 通过注意力直接全局交互,这就是在空间层次上更弱的归纳配置。

也就是说: 模型一开始并不知道"什么是局部结构",也不知道"什么是空间关系"。

因此,ViT 的学习过程是在一个几乎不受约束的巨大函数空间中搜索解。

这带来的是表达能力更强,但训练难度显著增加对数据规模和训练策略高度敏感。

久违地打个比方:

CNN 像是"带着地图找路",有大致 的方向,可以更稳定地训练寻路能力。

而 ViT 就像是"在未知环境中反复试错",大量训练后,就拥有了更强的寻路能力。
而"地图"就是归纳偏置。

1.4 小结

把这部分的内容总结为规律就是:

归纳偏置越弱,模型对数据的依赖就越强。

到这里,自然就有了下一个问题:

如果我们不想改变 ViT 的结构,又希望它在小数据上表现更好,该怎么办?

答案就是下一部分的内容:蒸馏

2. 蒸馏

如果用一句话来概括:蒸馏就是让一个"小模型"去模仿一个"大模型"的输出,从而学到更好的决策能力。

单看这句话,你可能会联想到我们之前介绍过的迁移学习。它们看起来都是在"借助一个强模型的能力",但本质逻辑是不同的:

  • 迁移学习:把"已经学到的参数"拿过来用。
  • 蒸馏 :不直接用参数,而是模仿模型的行为(输出分布)

在详细展开前,我们先统一两个核心角色:

  1. Teacher(教师模型):通常是一个性能较强、已经训练好的模型。
  2. Student(学生模型):我们真正要训练的目标模型。

这种命名也是相关领域文献内的主流称呼,下面就来展开蒸馏的思路。
要提前说明的是,蒸馏技术的逻辑是相通的,但其本身存在多种形式,在不同类型的任务中的实现方式也不同,这里我们使用最基础的分类任务来演示:

2.1 准备 Teacher

要进行蒸馏,首先要完成对 Teacher 的准备

Teacher 通常是一个已经训练好的强模型 ,最常见情况就是直接用预训练好的 Teacher

当然,如果在某些研究或特殊任务中,没有现成的强模型可用,那这时的流程就是先训练一个性能尽可能好的 Teacher,再用它去蒸馏 Student 。

你可能会觉得这种方式有些没必要,但这是因为我们通常不是在追求最强模型,而是在追求"足够强 + 足够便宜"的模型。

总之,这里的一个重要原则就是:

Teacher 不一定要"大",但一定要"比 Student 更可靠"。

不然 Student 学到的是错误知识,蒸馏反而拖累性能。

2.2 软标签 soft label

准备好 Teacher 后,第二步就是在 Teacher 上运行数据,获取软标签

不难理解,它的操作是这样的:

对于同一张输入图像, Teacher 不仅给出最终类别。还会输出一个完整的概率分布:

\[p_t = [p_1, p_2, ..., p_C] \]

这个分布包含了类别之间的相似关系信息

比如:

\[\text{cat} = 0.7,\quad \text{dog} = 0.2,\quad \text{car} = 0.1 \]

在这你会发现,这里其实就是直接获取经过输出层 softmax 得到的概率分布

还没完,在实际蒸馏中,通常会引入一个温度参数 \(T\) 来"软化"软标签和 Students 输出后再进行相关计算

\[p_i = \frac{e^{z_i / T}}{\sum_j e^{z_j / T}} \]

这样,当设计 \(T > 1\) 时,概率分布就会变得更平滑、小概率类别被"放大",让类比间的差距更明显。

实际上,它是在真实运行中经常调试的超参数

总之,我们通过 Teacher 获取了对数据集的软标签,从而得到了哪些类别"接近正确"、模型的"犹豫程度"、决策边界的大致结构这类更细节的信息。

2.3 Student 模仿 Teacher

这里就是蒸馏的核心逻辑,其实用简单的话来解释就是:

设计 Student 的损失函数,让 Student 的训练拟合目标除了学习真正标签,还学习 Teacher 输入的软标签,即拟合 Teacher 输出的概率分布。

下面的数学公式较为繁琐,先理解了这部分的逻辑后,就问题不大了:

首先要介绍的是 KL 散度 (KL divergence,全称 Kullback--Leibler divergence)。

对于离散分布,KL散度定义为:

\[D_{KL}(P \parallel Q) = \sum_{i} P(i)\log \frac{P(i)}{Q(i)} \]

而在实际实现中为了简化运算,其等价形式是:

\[\mathcal{L}_{KD} = - \sum p_t \log p_s \]

整体在蒸馏语境里通常写成:

\[D_{KL}(p_t \parallel p_s) \]

其中:

  • \(p_t\):Teacher 的分布(真实参考)。
  • \(p_s\):Student 的分布(要学习的对象)。

语义上说,KL在这里的作用就是衡量 Student 的预测分布和 Teacher 的分布之间的差异有多大。

看个实例:

类别 \(p_t\) (Teacher) \(p_s\) (Student) 比值 \(\frac{p_t}{p_s}\) \(\log(\frac{p_t}{p_s})\) \(p_t \cdot \log(\frac{p_t}{p_s})\)
Cat 0.7 0.5 1.4 0.336 0.235
Dog 0.2 0.4 0.5 -0.693 -0.139
Car 0.1 0.1 1.0 0.000 0.000

最终:

\[D_{KL}(p_t \parallel p_s) = 0.235 - 0.139 + 0 = 0.096 \]

这个结果的组成逻辑是这样的:

  1. Cat:Student 给的概率 偏低(0.5 < 0.7), 有误差
  2. Dog:Student 给的概率 偏高(0.4 > 0.2),有误差
  3. Car:没问题。

显然,结果越小,就代表两种分布越接近。

而普通分类任务中,交叉熵损失函数如下:

\[\mathcal{L}{CE} = -\sum{i=1}^{C} y_i \log p_i \]

最终,Student 的损失函数就是二者的组合:

\[\mathcal{L} = \alpha \mathcal{L}{CE} + (1 - \alpha)\mathcal{L}{KD} \]

其中,\(\alpha\) 为调节权重的超参数,两项损失分布代表:

  1. \(\mathcal{L}_{CE}\):告诉 Student "标准答案是什么"。
  2. \(\mathcal{L}_{KD}\):告诉 Student "一个更强模型是怎么思考的"。

在原始蒸馏方法中,为了补偿前面的 \(T\) 对梯度的缩放,损失还会引入 \(T^2\) 进行修正,但在现代实践中影响较小,常常省略,了解即可。

如此开始训练传播,最终便可得到蒸馏模型 Student 。

2.4 小结

其实你会发现蒸馏是一种取巧的逻辑:对一个强大的模型,我直接去学习你的答案分布。

但实际上,蒸馏确实有其理论支持和实际价值,而这里展示的也只是一种较原始的逻辑,之后我们再详细展开。

回到 ViT,我们已经知道了它的问题在于"搜索空间太自由"。那么蒸馏在这里的作用就是:

人为引入一个"软约束",缩小搜索空间,使优化更稳定,从而减少数据依赖。

这种逻辑实际上仍然是在利用 Teacher 的归纳配置。

同样的,当 Teacher 本身存在偏差时,这种约束也会间接限制性能上限,因此,对 \(\alpha\) 的调试也至关重要。

了解完两个概念后,就可以继续 ViT 的下一步改进了。

相关推荐
魏杨杨33 分钟前
一个程序员眼中的 AI 核心概念,讲透 LLM 、Agent 、MCP 、Skill 、RAG...
ai·.net·agent·claude code
RyFit1 小时前
SpringAI 常见问题及解决方案大全
java·ai
元拓数智1 小时前
智能分析落地卡壳?先补好「数据关系+语义治理」这层技术基建
大数据·分布式·ai·spark·数据关系·语义治理
企学宝2 小时前
企学宝5月专题课程丨《OpenClaw AI 智能体实战营:从零基础部署到全场景自动化落地》
人工智能·ai·企业培训
AI算法沐枫3 小时前
深度学习python代码处理科研测序数据
数据结构·人工智能·python·深度学习·决策树·机器学习·线性回归
哥布林学者4 小时前
高光谱拼接算法(一)扫推式成像和航带拼接算法
机器学习·高光谱成像
malog_5 小时前
大语言模型后训练全解析
人工智能·深度学习·机器学习·ai·语言模型
枫叶林FYL6 小时前
【强化学习】3 双系统持续强化学习:快速迁移与元知识整合架构手册
人工智能·机器学习·架构
低代码行业资讯6 小时前
五大实锤证据:AI不会终结低代码,只会倒逼技术进化
低代码·ai
神秘的土鸡6 小时前
Agent 落地:贴合健身真实场景的 AI 人物跟练方案
ai·语言模型·agent