知识蒸馏实战代码教学二(代码实战部分)

一、上章原理回顾

具体过程:

(1)首先我们要先训练出较大模型既teacher模型。(在图中没有出现)

(2)再对teacher模型进行蒸馏,此时我们已经有一个训练好的teacher模型,所以我们能很容易知道teacher模型输入特征x之后,预测出来的结果teacher_preds标签。

(3)此时,求到老师预测结果之后,我们需要求解学生在训练过程中的每一次结果student_preds标签。

(4)先求hard_loss,也就是学生模型的预测student_preds与真实标签targets之间的损失。

(5)再求soft_loss,也就是学生模型的预测student_preds与教师模型teacher_preds的预测之间的损失。

(6)求出hard_loss与soft_loss之后,求和总loss=a*hard_loss + (1-a)soft_loss,a是一个自己设置的权重参数,我在代码中设置为a=0.3。

(7)最后反向传播继续迭代。

二、代码实现

1、数据集

数据集采用的是手写数字的数据集mnist数据集,如果没有下载,代码部分中会进行下载,只需要把download改成True,然后就会保存在当前目录中。该数据集将其分成80%的训练集,20%的测试集,最后返回train_dataset和test_datatset。

python 复制代码
class MyDataset(Dataset):
    def __init__(self,opt):
        self.opt = opt

    def MyData(self):
        ## mnist数据集下载0
        mnist = datasets.MNIST(
            root='../datasets/', train=True, download=False, transform=transforms.Compose(
                [transforms.Resize(self.opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
            ),
        )

        dataset_size = len(mnist)
        train_size = int(0.8 * dataset_size)
        test_size = dataset_size - train_size

        train_dataset, test_dataset = random_split(mnist, [train_size, test_size])

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.opt.batch_size,
            shuffle=True,
        )

        test_dataloader = DataLoader(
            test_dataset,
            batch_size=self.opt.batch_size,
            shuffle=False,  # 在测试集上不需要打乱顺序
        )
        return train_dataloader,test_dataloader

2、teacher模型和训练实现

(1) 首先是teacher模型构造,经过三次线性层。

python 复制代码
import torch.nn as nn
import torch

img_area = 784

class TeacherModel(nn.Module):
    def __init__(self,in_channel=1,num_classes=10):
        super(TeacherModel,self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(img_area,1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = x.view(-1, img_area)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc3(x)

        return x

(2)训练teacher模型

老师模型训练完成后其权重参数会保存在teacher.pth当中,为以后调用。

python 复制代码
import torch.nn as nn
import torch


## 创建文件夹
from tqdm import tqdm

from dist.TeacherModel import TeacherModel

weight_path = 'C:/Users/26394/PycharmProjects/untitled1/dist/params/teacher.pth'
## 设置cuda:(cuda:0)
cuda = True if torch.cuda.is_available() else False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

torch.backends.cudnn.benchmark = True #使用卷积cuDNN加速


class TeacherTrainer():
    def __init__(self,opt,train_dataloader,test_dataloader):
        self.opt = opt
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader

    def trainer(self):
        # 老师模型
        opt = self.opt
        train_dataloader = self.train_dataloader
        test_dataloader = self.test_dataloader

        teacher_model = TeacherModel()
        teacher_model = teacher_model.to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer_teacher = torch.optim.Adam(teacher_model.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

        for epoch in range(opt.n_epochs):  ## epoch:50
            teacher_model.train()

            for data, targets in tqdm(train_dataloader):
                data = data.to(device)
                targets = targets.to(device)

                preds = teacher_model(data)
                loss = criterion(preds, targets)

                optimizer_teacher.zero_grad()
                loss = criterion(preds, targets)
                loss.backward()
                optimizer_teacher.step()

            teacher_model.eval()
            num_correct = 0
            num_samples = 0
            with torch.no_grad():
                for x, y in test_dataloader:
                    x = x.to(device)
                    y = y.to(device)

                    preds = teacher_model(x)

                    predictions = preds.max(1).indices
                    num_correct += (predictions == y).sum()
                    num_samples += predictions.size(0)
                acc = (num_correct / num_samples).item()

            torch.save(teacher_model.state_dict(), weight_path)

        teacher_model.train()
        print('teacher: Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))

(3)训练teacher模型

模型参数都在paras()当中设置好了,直接调用teacher_model就行,然后将其权重参数会保存在teacher.pth当中。

python 复制代码
import argparse

import torch

from dist.DistillationTrainer import DistillationTrainer
from dist.MyDateLoader import MyDataset
from dist.TeacherTrainer import TeacherTrainer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def paras():
    ## 超参数配置
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epochs", type=int, default=5, help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--n_cpu", type=int, default=2, help="number of cpu threads to use during batch generation")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
    parser.add_argument("--channels", type=int, default=1, help="number of image channels")
    parser.add_argument("--sample_interval", type=int, default=500, help="interval betwen image samples")
    opt = parser.parse_args()
    ## opt = parser.parse_args(args=[])                 ## 在colab中运行时,换为此行
    print(opt)
    return opt


if __name__ == '__main__':
    opt = paras()
    data = MyDataset(opt)
    train_dataloader, test_dataloader = data.MyData()

    # 训练Teacher模型
    teacher_trainer = TeacherTrainer(opt,train_dataloader,test_dataloader)
    teacher_trainer.trainer()

3、学生模型的构建

学生模型也是经过了三次线性层,但是神经元没有teacher当中多。所以student模型会比teacher模型小很多。

python 复制代码
import torch.nn as nn
import torch

img_area = 784

class StudentModel(nn.Module):
    def __init__(self,in_channel=1,num_classes=10):
        super(StudentModel,self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(img_area,20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, num_classes)

    def forward(self, x):
        x = x.view(-1, img_area)
        x = self.fc1(x)
        # x = self.dropout(x)
        x = self.relu(x)

        x = self.fc2(x)
        # x = self.dropout(x)
        x = self.relu(x)

        x = self.fc3(x)

        return x

4、知识蒸馏训练

(1)首先读取teacher模型。

将teacher模型中的权重参数teacher.pth放入模型当中。

python 复制代码
 #拿取训练好的模型
        teacher_model = TeacherModel()
        if os.path.exists(weights):
            teacher_model.load_state_dict(torch.load(weights))
            print('successfully')
        else:
            print('not loading')
        teacher_model = teacher_model.to(device)

(2)设置损失求解的函数

hard_loss用的就是普通的交叉熵损失函数,而soft_loss就是用的KL散度。

python 复制代码
        # hard_loss
        hard_loss = nn.CrossEntropyLoss()
        # hard_loss权重
        alpha = 0.3

        # soft_loss
        soft_loss = nn.KLDivLoss(reduction="batchmean")

(3)之后再进行蒸馏训练,温度为7

  • 先求得hard_loss就是用学生模型预测的标签和真实标签进行求得损失。
  • 再求soft_loss就是用学生模型预测的标签和老师模型预测的标签进行求得损失。使用softmax时候还需要进行除以温度temp。
  • 最后反向传播,求解模型
python 复制代码
       for epoch in range(opt.n_epochs):  ## epoch:5

            for data, targets in tqdm(train_dataloader):
                data = data.to(device)
                targets = targets.to(device)

                # 老师模型预测
                with torch.no_grad():
                    teacher_preds = teacher_model(data)

                # 学生模型预测
                student_preds = model(data)
                # 计算hard_loss
                student_loss = hard_loss(student_preds, targets)

                # 计算蒸馏后的预测损失
                ditillation_loss = soft_loss(
                    F.softmax(student_preds / temp, dim=1),
                    F.softmax(teacher_preds / temp, dim=1)
                )

                loss = alpha * student_loss + (1 - alpha) * ditillation_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            model.eval()
            num_correct = 0
            num_samples = 0
            with torch.no_grad():
                for x, y in test_dataloader:
                    x = x.to(device)
                    y = y.to(device)

                    preds = model(x)

                    predictions = preds.max(1).indices
                    num_correct += (predictions == y).sum()
                    num_samples += predictions.size(0)
                acc = (num_correct / num_samples).item()

        model.train()
        print('distillation: Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))

(4)整个蒸馏训练代码

python 复制代码
import torch.nn as nn
import torch
import torch.nn.functional as F
import os
from tqdm import tqdm

from dist.StudentModel import StudentModel
from dist.TeacherModel import TeacherModel

weights = 'C:/Users/26394/PycharmProjects/untitled1//dist/params/teacher.pth'

# D_weight_path = 'C:/Users/26394/PycharmProjects/untitled1/dist/params/distillation.pth'
## 设置cuda:(cuda:0)
cuda = True if torch.cuda.is_available() else False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

torch.backends.cudnn.benchmark = True #使用卷积cuDNN加速


class DistillationTrainer():
    def __init__(self,opt,train_dataloader,test_dataloader):
        self.opt = opt
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader


    def trainer(self):
        opt = self.opt
        train_dataloader = self.train_dataloader
        test_dataloader = self.test_dataloader

        #拿取训练好的模型
        teacher_model = TeacherModel()
        if os.path.exists(weights):
            teacher_model.load_state_dict(torch.load(weights))
            print('successfully')
        else:
            print('not loading')
        teacher_model = teacher_model.to(device)
        teacher_model.eval()

        model = StudentModel()
        model = model.to(device)

        temp = 7

        # hard_loss
        hard_loss = nn.CrossEntropyLoss()
        # hard_loss权重
        alpha = 0.3

        # soft_loss
        soft_loss = nn.KLDivLoss(reduction="batchmean")

        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

        for epoch in range(opt.n_epochs):  ## epoch:5

            for data, targets in tqdm(train_dataloader):
                data = data.to(device)
                targets = targets.to(device)

                # 老师模型预测
                with torch.no_grad():
                    teacher_preds = teacher_model(data)

                # 学生模型预测
                student_preds = model(data)
                # 计算hard_loss
                student_loss = hard_loss(student_preds, targets)

                # 计算蒸馏后的预测损失
                ditillation_loss = soft_loss(
                    F.softmax(student_preds / temp, dim=1),
                    F.softmax(teacher_preds / temp, dim=1)
                )

                loss = alpha * student_loss + (1 - alpha) * ditillation_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            model.eval()
            num_correct = 0
            num_samples = 0
            with torch.no_grad():
                for x, y in test_dataloader:
                    x = x.to(device)
                    y = y.to(device)

                    preds = model(x)

                    predictions = preds.max(1).indices
                    num_correct += (predictions == y).sum()
                    num_samples += predictions.size(0)
                acc = (num_correct / num_samples).item()

        model.train()
        print('distillation: Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))

(5)蒸馏训练的主函数

该部分大致与teacher模型训练类似,只是调用不同。

python 复制代码
import argparse

import torch

from dist.DistillationTrainer import DistillationTrainer
from dist.MyDateLoader import MyDataset
from dist.TeacherTrainer import TeacherTrainer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def paras():
    ## 超参数配置
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epochs", type=int, default=5, help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--n_cpu", type=int, default=2, help="number of cpu threads to use during batch generation")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
    parser.add_argument("--channels", type=int, default=1, help="number of image channels")
    parser.add_argument("--sample_interval", type=int, default=500, help="interval betwen image samples")
    opt = parser.parse_args()
    ## opt = parser.parse_args(args=[])                 ## 在colab中运行时,换为此行
    print(opt)
    return opt


if __name__ == '__main__':
    opt = paras()
    data = MyDataset(opt)
    train_dataloader, test_dataloader = data.MyData()

    # 训练Teacher模型
    # teacher_trainer = TeacherTrainer(opt,train_dataloader,test_dataloader)
    # teacher_trainer.trainer()

    distillation_trainer = DistillationTrainer(opt,train_dataloader,test_dataloader)
    distillation_trainer.trainer()

三、总结

总的来说,知识蒸馏是一种有效的模型压缩技术,可以通过在模型训练过程中引入额外的监督信号来训练简化的模型,从而获得与大型复杂模型相近的性能,但具有更小的模型尺寸和计算开销。

相关推荐
斯多葛的信徒3 分钟前
看看你的电脑可以跑 AI 模型吗?
人工智能·语言模型·电脑·llama
正在走向自律3 分钟前
AI 写作(六):核心技术与多元应用(6/10)
人工智能·aigc·ai写作
AI科技大本营4 分钟前
Anthropic四大专家“会诊”:实现深度思考不一定需要多智能体,AI完美对齐比失控更可怕!...
人工智能·深度学习
Cc不爱吃洋葱4 分钟前
如何本地部署AI智能体平台,带你手搓一个AI Agent
人工智能·大语言模型·agent·ai大模型·ai agent·智能体·ai智能体
网安打工仔4 分钟前
斯坦福李飞飞最新巨著《AI Agent综述》
人工智能·自然语言处理·大模型·llm·agent·ai大模型·大模型入门
AGI学习社5 分钟前
2024中国排名前十AI大模型进展、应用案例与发展趋势
linux·服务器·人工智能·华为·llama
AI_Tool5 分钟前
纳米AI搜索官网 - 新一代智能答案引擎
人工智能·搜索引擎
Damon小智5 分钟前
合合信息DocFlow产品解析与体验:人人可搭建的AI自动化单据处理工作流
图像处理·人工智能·深度学习·机器学习·ai·自动化·docflow
小虚竹6 分钟前
用AI辅导侄女大学物理的质点运动学问题
人工智能·chatgpt
猿类崛起@7 分钟前
百度千帆大模型实战:AI大模型开发的调用指南
人工智能·学习·百度·大模型·产品经理·大模型学习·大模型教程