AI精炼术:利用PyTorch实现MNIST数据集上的知识蒸馏

本文为稀土掘金技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!

beginning

上期给大家介绍了知识蒸馏的核心原理,时间有点久不知道大家还记不记得。一句话概括就是------将教师模型的知识通过soft targets传递给轻量化的学生模型,从而提升学生模型性能,减少计算需求 。还没看过或者忘记的小伙伴赶紧来看看叭➡从教师到学生:神奇的"知识蒸馏"之旅------原理详解篇。明白了原理之后,咱们今天就来实战一下,看看教师模型、学生模型是怎样用代码构建的,学习如何用知识蒸馏来提高学生模型的性能,让大家对蒸馏有一个更加直观的感受。除此之外,上期还有一些额外的知识蒸馏知识点没讲完,这次也一口气介绍完叭。废话不多说啦,如果你也对此感兴趣,想动手实现知识蒸馏看看效果,让我们一起愉快的学习叭🎈🎈🎈

1.知识蒸馏代码实战

在介绍代码之前呢,给大家分享两个好用的知识蒸馏代码库:

第一个开源库包括剪枝、蒸馏、神经架构搜索和量化 ;第二个是大神发表的RepDistiller,里面有12种用pytorch实现的流行知识蒸馏算法。都对知识蒸馏的学习很有帮助滴🌈🌈🌈

1.1不同温度下softmax可视化

通过上期的学习,咱们知道了蒸馏温度T越高,soft targets就越soft,所以温度是至关重要滴,那首先咱们就来学着画一下不同温度对于softmax的影响叭

  1. 导入工具包:🎈🎈🎈
python 复制代码
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

%config InlineBackend.figure_format = 'retina'
  1. 输入各类别的logits:🎈🎈🎈
python 复制代码
logits = np.array([-5,2,7,9])

4个类别的logit,你可以理解成是神经网络最后一层的线性分类层输出的4个类别的logit,它们有正有负有大有小。

  1. 普通softmax( T=1 ):🎈🎈🎈
python 复制代码
    softmax_1=np.exp(logits) / sum(np.exp(logits))
    softmax_1
    plt.plot(softmax_1,label='softmax_1')
    plt.legend()
    plt.show()

普通的softmax是蒸馏温度等于1,softmax_1=np.exp(logits) / sum(np.exp(logits))代表着把 <math xmlns="http://www.w3.org/1998/Math/MathML"> e − 5 + e 2 + e 7 + e 9 e^{-5}+e^2+e^7+e^9 </math>e−5+e2+e7+e9作为分母, <math xmlns="http://www.w3.org/1998/Math/MathML"> e − 5 、 e 2 、 e 7 、 e 9 e^{-5}、e^2、e^7、e^9 </math>e−5、e2、e7、e9分别作为分子算出来的各个数值,其代表了每一个softmax的后验概率。此时画出的图如下

  1. 知识蒸馏softmax( T=3 ):🎈🎈🎈
python 复制代码
plt.plot(softmax_1,label='T=1')

T=3
softmax_3 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_3,label='T=3')

T=5
softmax_5 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_5,label='T=5')

T=10
softmax_10 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_10,label='T=10')

T=100
softmax_100 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_100,label='T=100')

plt.xticks(np.arange(4), ['Cat', 'Dog','Donkey','Horse'])
plt.legend()
plt.show()

分别尝试温度T=3,T=5,T=10,T=100,画出它们的图如下所示,可以发现T越大,soft targets越soft,贫富差距就越小;T越小,两极分化就越大。所以T的选取很重要,若是过小的话就和没有蒸馏是一样的,过大又会陷入平均主义。

1.2载入数据集

下面就以MNIST数据集为例,利用pytorch从头训练教师网络、从头训练学生网络,并用知识蒸馏训练学生网络比较性能

  1. 导入工具包:🎈🎈🎈
python 复制代码
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transform
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm
python 复制代码
#设置随机种子,便于复现
torch.manual_seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
python 复制代码
#使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark = True
  1. 载入MNIST数据集:🎈🎈🎈
python 复制代码
#载入数据集
train_dataset = torchvision.datasets.MNIST(
    root="dataset/",
    train=True,
    transform=transforms.ToTensor(),
    download=True
)

#载入测试集
test_dataset = torchvision.datasets.MNIST(
    root="dataset/",
    train=False,
    transform=transforms.ToTensor(),
    download=True
)

#生成dataloader
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

先导入工具包,一般的代码是要放在云gpu上的(最好是);然后载入训练集和测试集,生成训练集的DataLoader和测试集的DataLoader。

1.3构建并训练教师模型

构建教师模型:🎈🎈🎈

python 复制代码
class TeacherModel(nn.Module):
    def __init__(self, in_channels=1,num_classes=10):
        super(TeacherModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784,1200)
        self.fc2 = nn.Linear(1200,1200)
        self.fc3 = nn.Linear(1200,num_classes)
        self.dropout = nn.Dropout(p=0.5)
        
    def forward(self, x):
        x = x.view(-1,784)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)
        
        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)
        
        x = self.fc3(x)
        
        return x

构造一个教师网络,这个教师网络有三层隐含层,每一层都加了dropout,防止过拟合。第一层是把输入的MNIST中784个像素映射到1200个神经元,第二层是把1200个神经元映射成1200个神经元,第三层是把1200个神经元映射成10个类别。

从头训练教师模型:🎈🎈🎈

python 复制代码
model = TeacherModel()
model = model.to(device)
summary(model)
python 复制代码
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
python 复制代码
epochs = 6
for epoch in range(epochs):
    model.train()
    
    #训练集上训练模型权重
    for data, targets in tqdm(train_loader):
        data = data.to(device)
        targets = targets.to(device)
        
        #前向预测
        preds = model(data)
        loss = criterion(preds, targets)
        
        #反向传播,优化权重
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    #测试集上评估模型性能
    model.eval()
    num_correct = 0
    num_samples = 0
    
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)
            
            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct/num_samples).item()
        
    model.train()
    print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch+1, acc))
python 复制代码
teacher_model = model

首先指定一个交叉熵分类损失函数CrossEntropyLoss,指定优化器和学习率,开始训练6轮,每一次训练都是先前向再反向,每一轮之后再在测试集上评估模型的性能。运行之后看到如下所示的结果,准确率为0.9762(PS:其实这些代码都是很简单的基础知识,在前面的学习中也详解讲过啦,这里就不再细说辽)🌞🌞🌞

1.4构建并训练学生模型

构建学生模型:🎈🎈🎈

python 复制代码
class StudentModel(nn.Module):
    def __init__(self, in_channels=1,num_classes=10):
        super(StudentModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784,20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, num_classes)
        
    def forward(self, x):
        x = x.view(-1,784)
        x = self.fc1(x)
        x = self.relu(x)
        
        x = self.fc2(x)
        x = self.relu(x)
        
        x = self.fc3(x)
        
        return x

构建的学生模型就要小得多啦,它的每一层只有20个神经元,构建方法和上面的一样。

从头训练学生模型:🎈🎈🎈

python 复制代码
model = StudentModel()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
python 复制代码
epochs = 3
for epoch in range(epochs):
    model.train()
    
    #训练集上训练模型权重
    for data, targets in tqdm(train_loader):
        data = data.to(device)
        targets = targets.to(device)
        
        #前向预测
        preds = model(data)
        loss = criterion(preds, targets)
        
        #反向传播,优化权重
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    #测试集上评估模型性能
    model.eval()
    num_correct = 0
    num_samples = 0
    
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)
            
            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct/num_samples).item()
        
    model.train()
    print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch+1, acc))
python 复制代码
student_model_scratch = model

从头训练学生模型和上面训练教师模型也是一样的,最后运行得到的结果如下,准确率只有0.8986,所以我们要用知识蒸馏来训练学生模型,提高它的性能。

1.5知识蒸馏训练学生模型

知识蒸馏训练学生模型:🎈🎈🎈

python 复制代码
#准备预训练好的教师模型
teacher_model.eval()

#准备新的学生模型
model = StudentModel()
model = model.to(device)
model.train()

#蒸馏温度
temp = 7
python 复制代码
#hard_loss
hard_loss = nn.CrossEntropyLoss()
#hard_loss 权重
alpha = 0.3

# soft_loss
soft_loss = nn.KLDivLoss(reduction="batchmean")
optimizer = torch.optim.Adam(model.paramaters(), lr=1e-4)
python 复制代码
epochs = 3
for epoch in range(epochs):
    
    #训练集上训练模型权重
    for data, targets in tqdm(train_loader):
        data = data.to(device)
        targets = targets.to(device)
        
        #教师模型预测
        with torch.no_grad():
            teacher_preds = teacher_model(data)
        
        #学生模型预测
        student_preds = model(data)
        #计算hard_loss
        student_loss = hard_loss(student_preds, targets)
        
        #计算蒸馏后的预测结果及soft_loss
        ditillation_loss = soft_loss(
            F.softmax(student_preds / temp, dim=1),
            F.softmax(teacher_preds / temp, dim=1)
        )
        #将hard_loss和soft_loss加权求和
        loss = alpha * student_loss + (1-alpha) * ditillation_loss
        
        #反向传播,优化权重
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    #测试集上评估模型性能
    model.eval()
    num_correct = 0
    num_samples = 0
    
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)
            
            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct/num_samples).item()
        
    model.train()
    print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch+1, acc))

蒸馏温度选为7,hard_loss是一个普通的分类交叉熵损失函数,而soft_loss是一个KL散度(差不多也是交叉熵损失函数)。训练时也是先前向后反向,前向时先获取教师网络的预测结果,对教师网络预测结果进行蒸馏和softmax,然后把学生网络温度为temp和教师网络温度为temp时分别算出来softmax,一起作为soft_loss,算出来一个总的损失函数 loss = alpha * student_loss + (1-alpha) * ditillation_loss。其余反向传播、评估性能等步骤和刚刚是一样的。运行得到的结果如下,可以看到准确率相比没蒸馏前有了提升(虽然提升不大,但这只是一个小demo,具体的还要进行调参优化)🌟🌟🌟

其实,我们并不能用最后的分数来衡量它知识蒸馏是好是坏,因为知识蒸馏并不只是能涨点,并不只是能压缩模型提高性能,它还有很多潜在的好处,比如我们可以用海量的无监督的大数据集,可以防止过拟合,可以实现知识从大模型到小模型的迁移,这才是关于知识蒸馏我们要把握的点✨✨✨

2.知识蒸馏的补充知识

<math xmlns="http://www.w3.org/1998/Math/MathML"> 知识蒸馏为什么 w o r k ??? \color{blue}{知识蒸馏为什么work???} </math>知识蒸馏为什么work???

学完了知识蒸馏的原理与代码实现之后,小伙伴们有没有认真想一想知识蒸馏为什么会有用呢🧐🧐🧐比较让人信服的一个机理解释是:

如上图,绿色是教师网络的求解空间,因为教师网络比较大嘛,所以它的表达能力和拟合能力比较强;学生网络是比较小的蓝色区域,它的表达能力比较差,求解空间比较小。训练教师网络之后,假如教师网络收敛到了红圈里面,如果我们单独训练学生网络(不蒸馏,直接用原来的数据集和标签),那么学生网络会收敛到黄色区域。毫无疑问,此时的学生网络和教师网络是有一定距离的,如果单纯的用hard label来训练学生网络,它是没法达到教师网络的水平滴 ;但我们加上知识蒸馏(橙色区域)之后,教师网络就会引导这个黄圈,告诉它怎么去收敛,那么它最终会收敛到这个橙圈里,而橙圈是教师网络的一个子集,它离原生的学生网络的收敛空间更接近,离教师网络越近,效果就越好😁😁😁

<math xmlns="http://www.w3.org/1998/Math/MathML"> 知识蒸馏与迁移学习??? \color{blue}{知识蒸馏与迁移学习???} </math>知识蒸馏与迁移学习???

学完知识蒸馏后,小伙伴们有没有感觉和迁移学习很像,毕竟知识蒸馏是从教师模型迁移到学生模型上的,那它俩到底是一个什么关系腻 🧐🧐🧐其实,知识蒸馏和迁移学习是没关系滴 ,它俩的概念是正交的,迁移学习指的是把一个领域训练的模型,让其泛化到另一个领域,比如说用X胸片的数据集去训练一个原本识别猫狗的模型,然后猫狗模型就慢慢学会去分辨x光胸片的各种病,这种把猫狗域迁移到了医疗域属于迁移学习(侧重于领域的迁移 );而知识蒸馏是把一个模型的知识迁移到另一个模型上,通常是大模型迁移到小模型(侧重于模型的迁移)。所以这俩是可以交叉的,可以用知识蒸馏实现迁移学习......也可以完全没有任何关系


ending

看到这里相信盆友们都对如何用代码实现知识蒸馏有了一个全面深入的了解啦,小伙伴们学废了没呀👀很开心能把学到的知识以文章的形式分享给大家🌴🌴🌴如果你也觉得我的分享对你有所帮助,please一键三连嗷!!!下期见

相关推荐
NAGNIP10 小时前
一文搞懂深度学习中的通用逼近定理!
人工智能·算法·面试
冬奇Lab11 小时前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab11 小时前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP15 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年15 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼15 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS15 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区16 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈16 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang17 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx