从教师到学生:神奇的“知识蒸馏”之旅——原理详解篇

本文为稀土掘金技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!

beginning

在深度学习模型训练与部署的过程中,小伙伴们有没有遇到过因模型太大而导致训练部署困难呢🧐🧐🧐所以今儿就给大家介绍一个人工智能领域重要的模型压缩方法------知识蒸馏。知识蒸馏基于"教师-学生网络思想"的训练方法,通过师徒传授,将大规模教师模型的知识传递给轻量化学生网络,实现模型压缩和部署。它不仅在计算机视觉上有所利用,而且在NLP、多模态学习、预训练大模型领域都有着广泛应用。大家都熟悉蒸馏是把不纯净的水通过加热、冷凝变成纯净水的过程,那如何对知识进行蒸馏萃取呢?别急别急,如果你也对知识蒸馏有点子感兴趣,想透视知识蒸馏机理,一睹它的风采,让我们一起愉快的学习叭🎈🎈🎈

1.知识蒸馏简介

知识蒸馏其实水很深,因为它是一个人工智能的通用方法,所以需要我们具备一些机器学习、深度学习的基础知识,比如什么是交叉熵损失函数、softmax操作、模型如何通过梯度下降和反向传播来训练的 等,如果你对这些知识并不是很熟悉的话,这里推荐一个讲的超棒的视频喔➡[双语字幕]吴恩达深度学习🍭🍭🍭

大家都知道蒸馏是把水里的杂质去掉,变成清纯的蒸馏水,那知识蒸馏也是干的同一件事儿,只是把一件大的模型,我们称之为教师模型,将里面的知识给萃取蒸馏出来,浓缩到一个小的学生模型上。你可以理解为一个大的、臃肿的、性能高的教师神经网络,把他的知识教给了一个小的、轻量化的学生神经网络,这里有一个知识的迁移,知识就从教师网络迁移到学生网络上🧸🧸🧸

就像上图所示,教师网络把知识传递给学生网络,这一过程称之为蒸馏Distill或者迁移Transfer。假设这个老师教的很好,学生掌握了老师的各种安身立命的技能,那这个学生网络就可以用轻量化的身段来取代这个老师(俗话说教会徒弟饿死师父不是没有道理滴👀)。

接下来可能就有小伙伴要问了:为啥要把网络弄得那么小嘞?为啥要训练一个学生网络,不直接用教师网络呢?答案是大模型太臃肿 。在现代的人工智能社会,各种AI算法(计算机视觉、语音识别、预训练大模型等等)都是很大的,而真正落地应用终端的算力又是非常有限的,比如手机移动端,智能穿戴,电脑电视,无人驾驶汽车,安防监控等等,你想想一个手机的算力能有多大腻。教师网络也许是用海量的算力能源,在数据中心花了大量的电力和成本资源才训练出来的臃肿模型,现在呢我们要把它部署在算力有限的终端设备上,在这种场景下,所以我们需要把大模型变成小模型,把小模型部署在移动终端上,这也是知识蒸馏的核心目的🌞🌞🌞

2.知识蒸馏核心原理

知识蒸馏的这篇论文大家一定要读一读,不仅开创了知识蒸馏这一新领域,而且因为是几位巨佬写的(其中就有深度学习之父Hinton),所以写作手法遣词造句都很值得我们借鉴。就比如在论文的开头先写昆虫在不同时期的需求是不一样的,进而引到机器学习领域------我们训练和部署往往用的是同一套模型,而训练的目标是从数据集中学习到海量的规律,部署是要足够的快足够的轻量化,所以这二者会有一个矛盾。作者从大自然延伸到机器学习领域,思路非常巧妙哈✨✨✨下面就详细介绍一下知识蒸馏的核心思想,如何让教师网络把知识教给学生网络。

2.1知识的表示与迁移

假如我们把上图马的照片喂给一个神经网络或者一个图像分类模型,那么结果就会有很多类别,每个类别都会给出一个是马、是猫、是狗、是驴、是车等的概率。我们训练网络的时候只是告诉网络这是一匹马,至于说它是不是驴、是不是汽车这些概率一律为0,这被称为hard targets (如下左图)我们是用hard targets来训练网络的。但是小伙伴们仔细想一下,这足够科学吗?这其实是不够科学的,这样的标签就等同于告诉网络这就是一匹马,这不是驴不是车,而且它不是驴不是马的概率是相等的。我们肉眼明明能够看出来它跟驴子是有些相似性的,它更像驴子更不像汽车,所以无论从图像分类还是生物学的角度看这总归是不科学的🌻🌻🌻

与hard targets不同,当我们把图片送入神经网络后,soft targets会给出不一样的结果:马 0.7,驴 0.25,汽车 0.05,这不仅告诉我们它是马的概率是最大的0.7,是汽车的概率是最小的0.05,还有0.25的可能性是驴呢,说明soft targets包含了更多的知识 。所以我们在训练教师网络的时候可以用hard targets训练,训练出来教师网络之后,教师网络对这张图片的预测结果,即soft targets能够传递出更多的信息,那么我们就可以用soft targets去训练学生网络

就像上图中的右图所示,这个数字为2的可能性无疑是最大的,但又有点像3,其他类别的概率比较小,同时包含了非正确类别概率的相对大小。如果单纯用hard targets的话,就把它像谁、不像谁、有多像、有多不像的信息给抹去了,所以hard targets是不科学的,soft targets是比较科学的⛳⛳⛳

现在我们的任务就明确了,我们要用教师网络预测出的soft targets作为训练学生网络的标签 。这时候还要进行一个操作------现在的soft targets我觉得它还不够soft,我想让它更soft,想让其他类别的概率也变大,把它们的相对大小充分的暴露出来,夸张出来,放大出来,让学生网络对这些非正确类别概率的信息有更强烈的信号。所以这时候我们要引入一个蒸馏温度T,温度T越高,soft targets就越soft✨✨✨

2.2蒸馏温度T

怎么引入T呢?其实就是在原来的softmax操作基础上除T: <math xmlns="http://www.w3.org/1998/Math/MathML"> q i = exp ⁡ ( z i / T ) ∑ j exp ⁡ ( z j / T ) {q_i} = \frac{{\exp ({z_i}/T)}}{{\sum\nolimits_j {\exp ({z_j}/T)} }} </math>qi=∑jexp(zj/T)exp(zi/T)

如果T等于1的话就是原始的softmax,若T变大的话,就变得越soft,即高的概率给它变低,低的概率给它变高,但是它们的相对大小仍然是固定的,如下图所示。温度T越高,非正确类别概率的相对大小信息就暴露的更加明显,但是如果T过大,就像下图中的紫色折线,它们的差距就没有了,变成平均主义一条横线。

我们知道softmax操作本身就是把每个类别的logit强行的变成0-1之间的概率并且求和为1,所以它是有放大差异的功能。我们可以通过增加T来把原来比较hard的标签变得更soft,而更软的soft targets去训练学生网络可以获取更丰富的知识 。听到这儿小伙伴们是不是有些不理解腻🧐🧐🧐那来举个例子辅助我们理解(图片来自b站学习up主同济子豪兄):

比如上图左侧学生网络是一个神经网络,最后一层的线性分类层给出猫、狗、驴、马的logit分数分别为-5、2、7和9,那原来的softmax是怎么操作的呢?它是把 <math xmlns="http://www.w3.org/1998/Math/MathML"> e − 5 + e 2 + e 7 + e 9 e^{-5}+e^2+e^7+e^9 </math>e−5+e2+e7+e9作为分母, <math xmlns="http://www.w3.org/1998/Math/MathML"> e − 5 、 e 2 、 e 7 、 e 9 e^{-5}、e^2、e^7、e^9 </math>e−5、e2、e7、e9分别作为分子算出来四个数值,因为T=1,所以它两极分化比较严重,四个数值的差距非常大,此时如果我们把T=1变成T=3呢,即把所有的分母都添上一个3,经过这样一个操作的话,得出来的数值差距明显不那么大了,变得更软(soft)了,教师网络也是同理🌈🌈🌈

2.3知识蒸馏过程

目前最难理解的部分已经介绍完啦,下面就是一个知识蒸馏的过程,那到底教师和学生网络是怎么样来进行蒸馏学习的呢?

  1. 首先我们有一个已经训练好的教师网络,我们把很多数据喂给教师网络,针对每个数据,网络会给一个温度为t时的softmax
  2. 然后我们再把数据喂给学生网络,这个学生网络可能是还没有训练或可能训练一半,也给学生网络一个softmax(T=t)
  3. 接着我们计算教师网络softmax(T=t)和学生网络softmax(T=t)时的损失函数,希望让两者越接近越好(就是学生在模拟老师的预测结果嘛);
  4. 那么学生网络自己经过一个温度为1的softmax,即softmax(T=1),和ground truth(即hard label)再做一个损失函数,希望让它俩也接近。就是说学生网络既要兼顾T=t时,它的预测结果要和教师网络尽可能接近,也要兼顾T=1时,预测结果要和标准答案接近。
  5. 所以最终的损失函数就是由distillation loss(也称soft loss)和student loss(也称hard loss)这两项加权求和 。distillation loss相当于是有一个师父在手把手教你,告诉你这是一匹马,它更像驴更不像车,驴和车有多像有多不像;而hard loss就好比有一个课本,课本上画着马的插图,从而告诉你这就是马,这不是别的东西,所以这两项相当于有师父带和从课本看,也很符合普通人学习的过程腻(很好理解叭)🌟🌟🌟

过程就是这么个过程,接下来让我们代入数据深入理解一下叭🌈🌈🌈

首先算一下hard loss,即学生网络在T=1时的softmax和hard label做一个比较,就是一个传统的交叉熵,结果等于 <math xmlns="http://www.w3.org/1998/Math/MathML"> − l o g ( 0.88 ) -log(0.88) </math>−log(0.88),因为马这个类别的hard label为1,其他类别都为0,所以我们只需要看马类别的后验概率,取对数再加个负号就行啦。下面来看看soft loss的计算,即把学生网络T=t时候的softmax和教师网络T=t时候的softmax求一个交叉熵,计算结果就是 <math xmlns="http://www.w3.org/1998/Math/MathML"> − [ 0.0006 × l o g ( 0.0058 ) + 0.0171 × l o g ( 0.0599 ) + 0.0466 × l o g ( 0.3170 ) + 0.9375 × l o g ( 0.6174 ) ] -[0.0006×log(0.0058)+0.0171×log(0.0599)+0.0466×log(0.3170)+0.9375×log(0.6174)] </math>−[0.0006×log(0.0058)+0.0171×log(0.0599)+0.0466×log(0.3170)+0.9375×log(0.6174)],别看这么复杂,其实这也是交叉熵损失函数,只不过刚刚只有马一个类别有标签,而现在四个类别都有标签,所以就要把标签乘在log前面,表示的是第i个样本教师网络预测出第j类的概率作为学生网络第j类的soft targets。最后的最后,我们的目标就是微调学生网络中的权重,使得最终的损失函数最小化,那具体是通过什么方法呢------通过梯度下降和反向传播。

好啦,这就是知识蒸馏的整个过程,看到这儿小伙伴们是不是恍然大悟了腻😁😁😁


ending

看到这里相信盆友们都对知识蒸馏有了一个全面深入的了解啦!原理明白了之后,下一期将开启知识蒸馏的代码篇并补充一些本期还没讲完的知识,如果感兴趣的话,请多多关注我叭🌴🌴🌴很开心能把学到的知识以文章的形式分享给大家。如果你也觉得我的分享对你有所帮助,please一键三连嗷!!!下期见

相关推荐
余生H29 分钟前
transformer.js(三):底层架构及性能优化指南
javascript·深度学习·架构·transformer
果冻人工智能1 小时前
2025 年将颠覆商业的 8 大 AI 应用场景
人工智能·ai员工
代码不行的搬运工1 小时前
神经网络12-Time-Series Transformer (TST)模型
人工智能·神经网络·transformer
石小石Orz1 小时前
Three.js + AI:AI 算法生成 3D 萤火虫飞舞效果~
javascript·人工智能·算法
罗小罗同学1 小时前
医工交叉入门书籍分享:Transformer模型在机器学习领域的应用|个人观点·24-11-22
深度学习·机器学习·transformer
孤独且没人爱的纸鹤1 小时前
【深度学习】:从人工神经网络的基础原理到循环神经网络的先进技术,跨越智能算法的关键发展阶段及其未来趋势,探索技术进步与应用挑战
人工智能·python·深度学习·机器学习·ai
阿_旭1 小时前
TensorFlow构建CNN卷积神经网络模型的基本步骤:数据处理、模型构建、模型训练
人工智能·深度学习·cnn·tensorflow
羊小猪~~1 小时前
tensorflow案例7--数据增强与测试集, 训练集, 验证集的构建
人工智能·python·深度学习·机器学习·cnn·tensorflow·neo4j
极客代码1 小时前
【Python TensorFlow】进阶指南(续篇三)
开发语言·人工智能·python·深度学习·tensorflow
zhangfeng11331 小时前
pytorch 的交叉熵函数,多分类,二分类
人工智能·pytorch·分类