Knowledge Distillation in Practice, Part 2 (Hands-on Code)

I. Recap of the Principles from the Previous Chapter

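The procedure:

(1) First, train the larger model, i.e. the teacher model. (This step is not shown in the figure.)

(2) Then distill the teacher model. Since we already have a trained teacher, we can easily obtain its predictions teacher_preds for any input feature x.

(3) With the teacher's predictions in hand, compute the student's predictions student_preds at every training step.

(4) First compute hard_loss, the loss between the student's predictions student_preds and the ground-truth labels targets.

(5) Then compute soft_loss, the loss between the student's predictions student_preds and the teacher's predictions teacher_preds (both softened by the temperature T before they are compared).

(6) Combine the two into the total loss, loss = a*hard_loss + (1-a)*soft_loss, where a is a weight you choose yourself; I set a = 0.3 in the code.

(7) Finally, backpropagate and repeat.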
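To make steps (4) to (6) concrete, here is a minimal pure-Python sketch of the loss for a single sample, with no PyTorch involved. The helper names (softmax, cross_entropy, kl_divergence, distillation_loss) and the logits are hypothetical. soft_loss is taken as KL(teacher || student) on temperature-softened probabilities, which is what nn.KLDivLoss computes when it receives the student's log-probabilities and the teacher's probabilities.

```python
import math


def softmax(logits, T=1.0):
    """Temperature-scaled softmax of a list of logits."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """hard_loss for one sample: -log p(target) at temperature 1."""
    return -math.log(softmax(logits)[target])


def kl_divergence(p, q):
    """KL(p || q) = sum_i p_i * log(p_i / q_i); terms with p_i = 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distillation_loss(student_logits, teacher_logits, target, alpha=0.3, T=7.0):
    """Total loss from step (6): alpha * hard_loss + (1 - alpha) * soft_loss."""
    hard = cross_entropy(student_logits, target)
    soft = kl_divergence(softmax(teacher_logits, T), softmax(student_logits, T))
    return alpha * hard + (1 - alpha) * soft


student = [2.0, 0.5, -1.0]  # made-up student logits for one image
teacher = [4.0, 1.0, -2.0]  # made-up teacher logits for the same image
print(distillation_loss(student, teacher, target=0))
```

When the student's logits equal the teacher's, soft_loss is zero and only the weighted hard_loss remains.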

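II. Code Implementation

1. Dataset

We use the MNIST handwritten-digit dataset. If it is not on disk yet, set download=True in the code and torchvision will download it into the root directory given there (../datasets/). Only the official MNIST training split is loaded (train=True); it is randomly divided into 80% training data and 20% test data, and the method returns train_dataloader and test_dataloader.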

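python
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms


class MyDataset:
    """Builds the MNIST train/test DataLoaders (a factory, not a torch Dataset itself)."""

    def __init__(self, opt):
        self.opt = opt

    def MyData(self):
        # Load MNIST; set download=True on the first run to fetch it into root
        mnist = datasets.MNIST(
            root='../datasets/', train=True, download=False, transform=transforms.Compose(
                [transforms.Resize(self.opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
            ),
        )

        # 80% / 20% random split of the training images
        dataset_size = len(mnist)
        train_size = int(0.8 * dataset_size)
        test_size = dataset_size - train_size

        # Fixed seed so the teacher run and the distillation run get the same split
        train_dataset, test_dataset = random_split(
            mnist, [train_size, test_size], generator=torch.Generator().manual_seed(42)
        )

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.opt.batch_size,
            shuffle=True,
        )

        test_dataloader = DataLoader(
            test_dataset,
            batch_size=self.opt.batch_size,
            shuffle=False,  # no need to shuffle the test set
        )
        return train_dataloader, test_dataloader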
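As a quick sanity check of what MyData() produces with the default batch_size = 64, the sketch below repeats its split arithmetic in plain Python. It assumes the 60,000-image MNIST training split and DataLoader's default drop_last=False; split_sizes and batches_per_epoch are illustrative helpers, not part of the project.

```python
import math

MNIST_TRAIN_SIZE = 60_000  # images in the official MNIST training split (train=True)


def split_sizes(dataset_size, train_ratio=0.8):
    """Same arithmetic as MyData(): int() truncates, the remainder goes to the test split."""
    train_size = int(train_ratio * dataset_size)
    return train_size, dataset_size - train_size


def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Batches a DataLoader yields; the final partial batch is kept unless drop_last=True."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


train_size, test_size = split_sizes(MNIST_TRAIN_SIZE)
print(train_size, test_size)              # 48000 12000
print(batches_per_epoch(train_size, 64))  # 750
print(batches_per_epoch(test_size, 64))   # 188
```

So each epoch runs 750 training steps, and the last test batch is a partial one.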

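2. Teacher Model and Training

(1) First, the teacher model: three fully connected layers (784 -> 1200 -> 1200 -> 10), with dropout (p = 0.5) and ReLU after each of the first two.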

python
import torch.nn as nn

img_area = 784  # 28 * 28 pixels, must match opt.img_size = 28


class TeacherModel(nn.Module):
    def __init__(self, in_channel=1, num_classes=10):  # in_channel is unused
        super(TeacherModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(img_area, 1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = x.view(-1, img_area)  # flatten (N, 1, 28, 28) images to (N, 784)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc3(x)  # raw logits; the softmax is applied inside the loss

        return x
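The teacher's dropout layers behave differently in training and evaluation mode, which is why the teacher should be switched to eval() before it produces soft labels. The plain-Python sketch below mimics what nn.Dropout does (inverted dropout: during training each element is zeroed with probability p and the survivors are scaled by 1/(1-p); in eval mode it is the identity). inverted_dropout is a hypothetical helper, not PyTorch's implementation.

```python
import random


def inverted_dropout(xs, p=0.5, training=True, rng=random):
    """Mimic nn.Dropout on a list of floats.

    Training: each element is zeroed with probability p and survivors are
    scaled by 1 / (1 - p), so the expected value is unchanged.
    Evaluation: the input is returned unchanged.
    """
    if not training:
        return list(xs)
    keep = 1.0 - p
    return [x / keep if rng.random() < keep else 0.0 for x in xs]


rng = random.Random(0)
activations = [1.0] * 1000
train_out = inverted_dropout(activations, p=0.5, training=True, rng=rng)
eval_out = inverted_dropout(activations, p=0.5, training=False)
print(sum(train_out) / len(train_out))  # close to 1.0, varies with the seed
print(eval_out == activations)          # True
```

A teacher left in training mode would therefore give a different, noisy soft label for the same image on every call.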

(2) Teacher training loop

After training, the teacher's weights are saved to teacher.pth for later use.

python
import os

import torch
import torch.nn as nn
from tqdm import tqdm

from dist.TeacherModel import TeacherModel

# Where the teacher weights are saved; adjust to your own project path
weight_path = 'C:/Users/26394/PycharmProjects/untitled1/dist/params/teacher.pth'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Let cuDNN benchmark convolution algorithms (only affects convolutions, which these models do not use)
torch.backends.cudnn.benchmark = True


class TeacherTrainer():
    def __init__(self, opt, train_dataloader, test_dataloader):
        self.opt = opt
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader

    def trainer(self):
        opt = self.opt
        train_dataloader = self.train_dataloader
        test_dataloader = self.test_dataloader

        # Teacher model
        teacher_model = TeacherModel().to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer_teacher = torch.optim.Adam(teacher_model.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

        # Create the params/ folder if needed, otherwise torch.save fails
        os.makedirs(os.path.dirname(weight_path), exist_ok=True)

        for epoch in range(opt.n_epochs):  ## default: 5 epochs
            teacher_model.train()

            for data, targets in tqdm(train_dataloader):
                data = data.to(device)
                targets = targets.to(device)

                preds = teacher_model(data)
                loss = criterion(preds, targets)

                optimizer_teacher.zero_grad()
                loss.backward()
                optimizer_teacher.step()

            # Evaluate on the held-out 20% split
            teacher_model.eval()
            num_correct = 0
            num_samples = 0
            with torch.no_grad():
                for x, y in test_dataloader:
                    x = x.to(device)
                    y = y.to(device)

                    preds = teacher_model(x)

                    predictions = preds.max(1).indices
                    num_correct += (predictions == y).sum().item()
                    num_samples += predictions.size(0)
            acc = num_correct / num_samples

            # Overwrite teacher.pth after every epoch (the last epoch's weights are kept)
            torch.save(teacher_model.state_dict(), weight_path)
            print('teacher: Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))
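The evaluation loop counts correct predictions and samples over all batches and divides once at the end. That matters because the batches are not all the same size: with 12,000 test images and batch_size = 64, the last batch holds only 32 images, so averaging per-batch accuracies would over-weight it. Below is a plain-Python sketch of the same bookkeeping; argmax, accuracy and the toy logits are hypothetical.

```python
def argmax(row):
    """Index of the largest logit, like preds.max(1).indices for a single sample."""
    return max(range(len(row)), key=row.__getitem__)


def accuracy(batches):
    """Accumulate correct predictions and sample counts over (logits, labels) batches."""
    num_correct = 0
    num_samples = 0
    for logits, labels in batches:
        predictions = [argmax(row) for row in logits]
        num_correct += sum(int(p == y) for p, y in zip(predictions, labels))
        num_samples += len(predictions)
    return num_correct / num_samples


# Toy data: a full batch of 3 samples (2 correct) and a short last batch of 1 (wrong)
toy_batches = [
    ([[0.1, 2.0], [3.0, 0.0], [0.2, 0.9]], [1, 0, 0]),
    ([[0.5, 0.2]], [1]),
]
print(accuracy(toy_batches))  # 0.5
```

Averaging the two per-batch accuracies (2/3 and 0) would instead give about 0.33.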

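(3) Main script for teacher training

All hyperparameters are set in paras(); just create a TeacherTrainer and call its trainer() method, and the teacher's weights will be saved to teacher.pth.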

python
import argparse

from dist.MyDateLoader import MyDataset
from dist.TeacherTrainer import TeacherTrainer


def paras():
    ## Hyperparameter configuration
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epochs", type=int, default=5, help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
    parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
    # The options below are not used by this tutorial's code
    parser.add_argument("--n_cpu", type=int, default=2, help="number of cpu threads to use during batch generation")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument("--channels", type=int, default=1, help="number of image channels")
    parser.add_argument("--sample_interval", type=int, default=500, help="interval between image samples")
    opt = parser.parse_args()
    ## opt = parser.parse_args(args=[])                 ## use this line instead when running in Colab/Jupyter
    print(opt)
    return opt


if __name__ == '__main__':
    opt = paras()
    data = MyDataset(opt)
    train_dataloader, test_dataloader = data.MyData()

    # Train the teacher model
    teacher_trainer = TeacherTrainer(opt, train_dataloader, test_dataloader)
    teacher_trainer.trainer()
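The commented-out parse_args(args=[]) line exists because inside Jupyter or Colab, sys.argv holds the kernel's own arguments, which argparse rejects as unrecognized. Passing an explicit list sidesteps that, and the same mechanism lets you override hyperparameters programmatically. The reduced sketch below keeps three of the options from paras(); build_parser is a hypothetical helper.

```python
import argparse


def build_parser():
    """Three of the options from paras(), enough to show the behavior."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epochs", type=int, default=5, help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    return parser


parser = build_parser()
opt_default = parser.parse_args(args=[])  # ignore sys.argv: all defaults, as in the Colab line
opt_custom = parser.parse_args(["--n_epochs", "10", "--lr", "1e-3"])  # explicit overrides
print(opt_default.n_epochs, opt_default.batch_size)  # 5 64
print(opt_custom.n_epochs, opt_custom.lr)            # 10 0.001
```

From a terminal, the equivalent override is passing --n_epochs 10 --lr 1e-3 after the script name.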

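3. Student Model

The student model also has three linear layers, but far fewer neurons than the teacher (20 hidden units per layer instead of 1200) and no dropout, so it is much smaller than the teacher.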

python
import torch.nn as nn

img_area = 784  # 28 * 28 pixels


class StudentModel(nn.Module):
    def __init__(self, in_channel=1, num_classes=10):  # in_channel is unused
        super(StudentModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(img_area, 20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, num_classes)

    def forward(self, x):
        x = x.view(-1, img_area)  # flatten to (N, 784)
        x = self.fc1(x)
        x = self.relu(x)

        x = self.fc2(x)
        x = self.relu(x)

        x = self.fc3(x)  # raw logits

        return x
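To put "much smaller" into numbers, we can count the parameters of both networks by hand: each nn.Linear(in, out) holds in*out weights plus out biases, while ReLU and Dropout hold none. The plain-Python sketch below does this; in PyTorch, sum(p.numel() for p in model.parameters()) should give the same totals. mlp_params is an illustrative helper.

```python
def mlp_params(layer_sizes):
    """Parameters of stacked nn.Linear layers: in*out weights plus out biases per layer."""
    return sum(a * b + b for a, b in zip(layer_sizes, layer_sizes[1:]))


teacher_params = mlp_params([784, 1200, 1200, 10])
student_params = mlp_params([784, 20, 20, 10])
print(teacher_params, student_params)             # 2395210 16330
print(round(teacher_params / student_params, 1))  # 146.7
print(teacher_params * 4, student_params * 4)     # bytes as float32: 9580840 65320
```

The student has roughly 147 times fewer parameters than the teacher.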

4. Knowledge Distillation Training

(1) First load the teacher model.

Load the trained weights in teacher.pth into the teacher model.

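python
import os

import torch
from dist.TeacherModel import TeacherModel

weights = 'C:/Users/26394/PycharmProjects/untitled1/dist/params/teacher.pth'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the trained teacher model
teacher_model = TeacherModel()
if os.path.exists(weights):
    # map_location lets GPU-saved weights load on a CPU-only machine
    teacher_model.load_state_dict(torch.load(weights, map_location=device))
    print('teacher weights loaded successfully')
else:
    print('teacher weights not found; the teacher is untrained')
teacher_model = teacher_model.to(device)
teacher_model.eval()  # turn off dropout so the teacher's soft labels are deterministic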
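If teacher.pth is missing, the loading code only prints a message and carries on with a randomly initialized teacher, so the student would be distilled from noise without any error. A stricter option is to fail early. The helper below, resolve_teacher_weights, is hypothetical and uses only the standard library; its result could then be passed to torch.load(..., map_location=device).

```python
import tempfile
from pathlib import Path


def resolve_teacher_weights(path):
    """Return the checkpoint path as a Path, or raise if the file does not exist.

    Hypothetical helper: raising here stops distillation from silently
    running with an untrained teacher.
    """
    p = Path(path)
    if not p.is_file():
        raise FileNotFoundError(f"teacher checkpoint not found: {p} (train the teacher first)")
    return p


# Usage demo with a temporary directory standing in for dist/params/
with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "teacher.pth"
    ckpt.write_bytes(b"")  # empty placeholder file, not real weights
    found_ok = resolve_teacher_weights(ckpt) == ckpt
    try:
        resolve_teacher_weights(Path(tmp) / "missing.pth")
        missing_raised = False
    except FileNotFoundError:
        missing_raised = True

print(found_ok, missing_raised)  # True True
```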

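(2) Define the loss functions

hard_loss is the ordinary cross-entropy loss, while soft_loss uses the KL divergence. Note that nn.KLDivLoss expects its first argument (the student side) to be log-probabilities and its second argument (the teacher side) to be probabilities.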

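python
import torch.nn as nn

# hard_loss
hard_loss = nn.CrossEntropyLoss()
# weight of hard_loss
alpha = 0.3

# soft_loss; "batchmean" divides the summed KL by the batch size,
# which matches the mathematical definition of the KL divergence
soft_loss = nn.KLDivLoss(reduction="batchmean")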
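The argument convention of nn.KLDivLoss is easy to get wrong: pointwise it computes target * (log(target) - input), so the input must already be log-probabilities. The plain-Python reference below (kl_div_batchmean is a hypothetical helper written from that formula, with made-up probabilities) shows that feeding the student's probabilities instead of its log-probabilities yields a value that is not a KL divergence at all.

```python
import math


def kl_div_batchmean(inputs, targets):
    """Reference for nn.KLDivLoss(reduction="batchmean") on plain lists.

    Pointwise loss is target * (log(target) - input); the total is summed
    and divided by the batch size. `inputs` must be LOG-probabilities,
    `targets` ordinary probabilities. Zero targets contribute 0.
    """
    total = 0.0
    for inp_row, tgt_row in zip(inputs, targets):
        total += sum(t * (math.log(t) - x) for x, t in zip(inp_row, tgt_row) if t > 0)
    return total / len(inputs)


p_student = [0.7, 0.2, 0.1]  # student probabilities (after softmax)
p_teacher = [0.6, 0.3, 0.1]  # teacher probabilities (after softmax)

# Correct: student side as log-probabilities (what F.log_softmax produces)
correct = kl_div_batchmean([[math.log(q) for q in p_student]], [p_teacher])
# Incorrect: student side as probabilities (what F.softmax produces)
wrong = kl_div_batchmean([p_student], [p_teacher])

print(correct)  # about 0.029, a valid KL(teacher || student)
print(wrong)    # negative, which no KL divergence can be
```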

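(3) Then run the distillation training with temperature temp = 7.

  • First compute hard_loss between the student's predictions and the ground-truth labels.
  • Then compute soft_loss between the student's and the teacher's predictions. Both sets of logits are divided by the temperature temp before the softmax, and the student side must go through log_softmax because nn.KLDivLoss expects log-probabilities.
  • Finally, backpropagate and update the student model.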
python
# Excerpt from DistillationTrainer.trainer(); teacher_model, model, optimizer,
# hard_loss, soft_loss, alpha and temp are defined earlier in that method.
for epoch in range(opt.n_epochs):  ## default: 5 epochs
    model.train()

    for data, targets in tqdm(train_dataloader):
        data = data.to(device)
        targets = targets.to(device)

        # Teacher predictions (no gradients needed)
        with torch.no_grad():
            teacher_preds = teacher_model(data)

        # Student predictions
        student_preds = model(data)
        # hard_loss against the ground-truth labels
        student_loss = hard_loss(student_preds, targets)

        # soft_loss: log-probabilities for the student, probabilities
        # for the teacher, both softened by temp
        distillation_loss = soft_loss(
            F.log_softmax(student_preds / temp, dim=1),
            F.softmax(teacher_preds / temp, dim=1)
        )

        loss = alpha * student_loss + (1 - alpha) * distillation_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for x, y in test_dataloader:
            x = x.to(device)
            y = y.to(device)

            preds = model(x)

            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum().item()
            num_samples += predictions.size(0)
    acc = num_correct / num_samples
    print('distillation: Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))
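Why divide by temp? A high temperature flattens the teacher's distribution, so the small probabilities it assigns to the wrong classes (which digits it considers similar) become large enough to carry a training signal. The plain-Python sketch below compares T = 1 with the T = 7 used here, on made-up logits for a confident teacher; softmax_with_temperature is an illustrative helper.

```python
import math


def softmax_with_temperature(logits, T):
    """softmax(z / T), computed stably by subtracting the max logit first."""
    m = max(logits)
    exps = [math.exp((z - m) / T) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


teacher_logits = [8.0, 2.0, 0.0]  # made-up logits: class 0 clearly wins

p_t1 = softmax_with_temperature(teacher_logits, T=1)
p_t7 = softmax_with_temperature(teacher_logits, T=7)

for T, probs in ((1, p_t1), (7, p_t7)):
    print(T, [round(p, 3) for p in probs])
```

The ranking of the classes is unchanged; only the distribution becomes softer, so the wrong classes get visibly non-zero probabilities.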

(4) Complete distillation training code

python
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from dist.StudentModel import StudentModel
from dist.TeacherModel import TeacherModel

# Teacher weights produced by TeacherTrainer; adjust to your own project path
weights = 'C:/Users/26394/PycharmProjects/untitled1/dist/params/teacher.pth'

# D_weight_path = 'C:/Users/26394/PycharmProjects/untitled1/dist/params/distillation.pth'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Let cuDNN benchmark convolution algorithms (only affects convolutions, which these models do not use)
torch.backends.cudnn.benchmark = True


class DistillationTrainer():
    def __init__(self, opt, train_dataloader, test_dataloader):
        self.opt = opt
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader

    def trainer(self):
        opt = self.opt
        train_dataloader = self.train_dataloader
        test_dataloader = self.test_dataloader

        # Load the trained teacher model
        teacher_model = TeacherModel()
        if os.path.exists(weights):
            teacher_model.load_state_dict(torch.load(weights, map_location=device))
            print('teacher weights loaded successfully')
        else:
            print('teacher weights not found; the teacher is untrained')
        teacher_model = teacher_model.to(device)
        teacher_model.eval()  # disable dropout in the teacher

        model = StudentModel()
        model = model.to(device)

        temp = 7  # distillation temperature

        # hard_loss
        hard_loss = nn.CrossEntropyLoss()
        # weight of hard_loss
        alpha = 0.3

        # soft_loss
        soft_loss = nn.KLDivLoss(reduction="batchmean")

        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

        for epoch in range(opt.n_epochs):  ## default: 5 epochs
            model.train()

            for data, targets in tqdm(train_dataloader):
                data = data.to(device)
                targets = targets.to(device)

                # Teacher predictions
                with torch.no_grad():
                    teacher_preds = teacher_model(data)

                # Student predictions
                student_preds = model(data)
                # hard_loss
                student_loss = hard_loss(student_preds, targets)

                # soft_loss on temperature-softened distributions
                distillation_loss = soft_loss(
                    F.log_softmax(student_preds / temp, dim=1),
                    F.softmax(teacher_preds / temp, dim=1)
                )

                loss = alpha * student_loss + (1 - alpha) * distillation_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            model.eval()
            num_correct = 0
            num_samples = 0
            with torch.no_grad():
                for x, y in test_dataloader:
                    x = x.to(device)
                    y = y.to(device)

                    preds = model(x)

                    predictions = preds.max(1).indices
                    num_correct += (predictions == y).sum().item()
                    num_samples += predictions.size(0)
            acc = num_correct / num_samples
            print('distillation: Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))
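One design choice worth knowing about: the loss here is alpha * hard + (1 - alpha) * soft with no extra factor. Hinton et al. (2015), the paper that introduced this form of distillation, note that the gradients of the soft term shrink roughly as 1/T², and multiply the soft term by T² to compensate. The plain-Python sketch below checks that scaling on made-up logits, using the analytic gradient of KL(softmax(t/T) || softmax(s/T)) with respect to the student logits s, which is (softmax(s/T) - softmax(t/T)) / T. In the trainer, this optional variant would read (1 - alpha) * temp ** 2 * distillation_loss; that is not what the code above does.

```python
import math


def softmax(logits, T):
    """softmax(z / T), computed stably."""
    m = max(logits)
    exps = [math.exp((z - m) / T) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def soft_loss_grad(student_logits, teacher_logits, T):
    """Gradient of KL(softmax(t/T) || softmax(s/T)) w.r.t. the student logits s."""
    ps = softmax(student_logits, T)
    pt = softmax(teacher_logits, T)
    return [(a - b) / T for a, b in zip(ps, pt)]


def norm(v):
    return math.sqrt(sum(x * x for x in v))


student = [1.0, 0.5, -0.5]  # made-up student logits
teacher = [2.0, 0.0, -1.0]  # made-up teacher logits

grad_norms = {T: norm(soft_loss_grad(student, teacher, T)) for T in (1, 7, 20)}
for T, g in grad_norms.items():
    # the raw gradient shrinks quickly with T; g * T**2 stays roughly constant
    print(T, round(g, 5), round(g * T ** 2, 3))
```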

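(5) Main function for distillation training

This part is largely the same as the teacher's main script; only the trainer being called differs.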

python
import argparse

from dist.DistillationTrainer import DistillationTrainer
from dist.MyDateLoader import MyDataset
from dist.TeacherTrainer import TeacherTrainer


def paras():
    ## Hyperparameter configuration
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epochs", type=int, default=5, help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
    parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
    # The options below are not used by this tutorial's code
    parser.add_argument("--n_cpu", type=int, default=2, help="number of cpu threads to use during batch generation")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument("--channels", type=int, default=1, help="number of image channels")
    parser.add_argument("--sample_interval", type=int, default=500, help="interval between image samples")
    opt = parser.parse_args()
    ## opt = parser.parse_args(args=[])                 ## use this line instead when running in Colab/Jupyter
    print(opt)
    return opt


if __name__ == '__main__':
    opt = paras()
    data = MyDataset(opt)
    train_dataloader, test_dataloader = data.MyData()

    # Train the teacher model (already done; uncomment to retrain)
    # teacher_trainer = TeacherTrainer(opt, train_dataloader, test_dataloader)
    # teacher_trainer.trainer()

    # Distill the teacher into the student model
    distillation_trainer = DistillationTrainer(opt, train_dataloader, test_dataloader)
    distillation_trainer.trainer()
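Switching between the two stages by commenting lines in and out is easy to get wrong. One alternative, sketched below, is a command-line option; the --stage flag and the helpers parse_stage and stages_to_run are hypothetical, not part of the original project. Because DistillationTrainer loads teacher.pth, the "all" choice runs teacher training first.

```python
import argparse


def parse_stage(argv):
    """Parse a hypothetical --stage option (teacher / distill / all)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--stage", choices=["teacher", "distill", "all"], default="distill",
                        help="which training stage(s) to run")
    return parser.parse_args(argv).stage


def stages_to_run(stage):
    """Teacher training must precede distillation, since distillation loads teacher.pth."""
    return {"teacher": ["teacher"], "distill": ["distill"], "all": ["teacher", "distill"]}[stage]


# In the real main script the loop body would call the trainers, e.g.
#   TeacherTrainer(opt, train_dataloader, test_dataloader).trainer()
#   DistillationTrainer(opt, train_dataloader, test_dataloader).trainer()
for name in stages_to_run(parse_stage(["--stage", "all"])):
    print("running stage:", name)
```

In practice the --stage argument would be added to paras() next to the other options rather than parsed separately.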

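III. Summary

In short, knowledge distillation is an effective model compression technique: by adding an extra supervision signal (the teacher's softened predictions) during training, it trains a simplified model that can approach the performance of a large, complex model while having a much smaller size and lower computational cost.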
