Knowledge distillation from ResNet50 to VGG16 on CIFAR100 with PyTorch: experiment notes

The concept of knowledge distillation

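The concept of knowledge distillation is introduced in "Distilling the Knowledge in a Neural Network" by Hinton, Vinyals, and Dean (NIPS 2014 Deep Learning Workshop; arXiv:1503.02531, 2015).
In the narrow sense, knowledge distillation transfers knowledge from a complex model to improve the performance of a simpler one. The complex model is called the teacher model and the simple one the student model. I recently revisited the idea and ran a validation experiment on the CIFAR100 dataset.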
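Logits, hard targets, and soft targets: the logits are the raw, unnormalized outputs of the network's last layer (scores before softmax, not probabilities); the hard target is the one-hot encoding of the ground-truth label; the soft target is the probability distribution obtained by applying softmax to the logits.
Soft targets with temperature: to soften the distribution after softmax further, Hinton et al. divide the logits by a temperature parameter before applying softmax:

$$q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

Here $z_i$ is the logit of class $i$ and $T$ is the temperature. The larger $T$ is, the flatter the probability distribution; $T = 1$ recovers the ordinary softmax.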
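
To make the effect of the temperature concrete, here is a minimal sketch in plain Python (no PyTorch needed). The helper name `softmax_with_temperature` and the three logits are made up for illustration; the function simply divides the logits by T before a standard softmax.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """Return softmax(logits / T) as a list of probabilities."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [5.0, 2.0, 1.0]  # hypothetical logits for three classes
for T in (1.0, 4.0, 20.0):
    probs = softmax_with_temperature(logits, T)
    print(T, [round(p, 3) for p in probs])
```

As T grows, the probability of the top class shrinks and the other classes receive more mass, which is what exposes the teacher's relative preferences among the non-target classes to the student.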

Dataset: CIFAR100, a classic image classification dataset with 100 classes of 32×32 color images

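The dataset is loaded directly through the official dataset class provided by torchvision.
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader  # needed for the loaders below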

CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

train_dataset = torchvision.datasets.cifar.CIFAR100(
    root = "./dataset/",
    train=True,
    transform=transform_train,
    download=True
)
test_dataset = torchvision.datasets.cifar.CIFAR100(
    root = "./dataset/",
    train = False,
    transform=transform_test,
    download=True
)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, num_workers=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, num_workers=4, shuffle=False)

Classification models: ResNet50 is the teacher model and VGG16 (the batch-normalized variant, vgg16_bn) is the student model.

VGG16 network definition
"""vgg in pytorch


[1] Karen Simonyan, Andrew Zisserman

    Very Deep Convolutional Networks for Large-Scale Image Recognition.
    https://arxiv.org/abs/1409.1556v6
"""
'''VGG11/13/16/19 in Pytorch.'''

import torch
import torch.nn as nn

cfg = {
    'A' : [64,     'M', 128,      'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],
    'B' : [64, 64, 'M', 128, 128, 'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],
    'D' : [64, 64, 'M', 128, 128, 'M', 256, 256, 256,      'M', 512, 512, 512,      'M', 512, 512, 512,      'M'],
    'E' : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
}

class VGG(nn.Module):

    def __init__(self, features, num_class=100):
        super().__init__()
        self.features = features

        self.classifier = nn.Sequential(
            nn.Linear(512, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_class)
        )

    def forward(self, x):
        output = self.features(x)
        output = output.view(output.size()[0], -1)
        output = self.classifier(output)

        return output

def make_layers(cfg, batch_norm=False):
    layers = []

    input_channel = 3
    for l in cfg:
        if l == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            continue

        layers += [nn.Conv2d(input_channel, l, kernel_size=3, padding=1)]

        if batch_norm:
            layers += [nn.BatchNorm2d(l)]

        layers += [nn.ReLU(inplace=True)]
        input_channel = l

    return nn.Sequential(*layers)

def vgg16_bn():
    return VGG(make_layers(cfg['D'], batch_norm=True))
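
The first classifier layer is `nn.Linear(512, 4096)` rather than the 25088 inputs of the ImageNet VGG. The short sketch below (plain Python; `feature_shape` is a helper introduced here for illustration, not part of the model code) traces a 32×32 CIFAR image through configuration 'D' to show where the 512 comes from.

```python
# Configuration 'D' (VGG16), copied from the cfg dictionary above.
cfg_d = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
         512, 512, 512, 'M', 512, 512, 512, 'M']

def feature_shape(cfg, input_size=32, input_channels=3):
    """Track (channels, height, width) through the conv/pool stack."""
    channels, size = input_channels, input_size
    for layer in cfg:
        if layer == 'M':
            size //= 2          # 2x2 max pool with stride 2 halves H and W
        else:
            channels = layer    # 3x3 conv with padding=1 keeps H and W
    return channels, size, size

c, h, w = feature_shape(cfg_d)
num_conv = sum(1 for layer in cfg_d if layer != 'M')

print(c, h, w, c * h * w)   # 512 1 1 512
print(num_conv + 3)         # 13 conv layers + 3 linear layers = 16
```

Five pooling stages reduce 32×32 to 1×1, so the flattened feature vector has exactly 512 elements.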
ResNet50 network definition
"""resnet in pytorch



[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.

    Deep Residual Learning for Image Recognition
    https://arxiv.org/abs/1512.03385v1
"""

import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    """Basic Block for resnet 18 and resnet 34

    """

    # BasicBlock and BottleNeck produce different numbers of output channels;
    # the class attribute `expansion` records the multiplier for each block type
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        #residual function
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels * BasicBlock.expansion)
        )

        #shortcut
        self.shortcut = nn.Sequential()

        #the shortcut output dimension is not the same with residual function
        #use 1*1 convolution to match the dimension
        if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * BasicBlock.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))

class BottleNeck(nn.Module):
    """Residual block for resnet over 50 layers

    """
    expansion = 4
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels * BottleNeck.expansion),
        )

        self.shortcut = nn.Sequential()

        if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels * BottleNeck.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))

class ResNet(nn.Module):

    def __init__(self, block, num_block, num_classes=100):
        super().__init__()

        self.in_channels = 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        #we use a different inputsize than the original paper
        #so conv2_x's stride is 1
        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """make resnet layers(by layer i didnt mean this 'layer' was the
        same as a neural network layer, e.g. a conv layer), one layer may
        contain more than one residual block

        Args:
            block: block type, basic block or bottle neck block
            out_channels: output depth channel number of this layer
            num_blocks: how many blocks per layer
            stride: the stride of the first block of this layer

        Return:
            return a resnet layer
        """

        # each layer has num_blocks blocks; the first block's stride
        # may be 1 or 2, and all the other blocks use stride 1
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        output = self.conv1(x)
        output = self.conv2_x(output)
        output = self.conv3_x(output)
        output = self.conv4_x(output)
        output = self.conv5_x(output)
        output = self.avg_pool(output)
        output = output.view(output.size(0), -1)
        output = self.fc(output)

        return output

def resnet50():
    """ return a ResNet 50 object
    """
    return ResNet(BottleNeck, [3, 4, 6, 3])
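
As a quick sanity check on this CIFAR variant, the following arithmetic sketch (plain Python, values read off the definition above) counts the weighted layers and derives the input width of `self.fc`. It follows the usual convention of not counting the 1×1 shortcut projections.

```python
num_block = [3, 4, 6, 3]      # blocks per stage, as passed to ResNet(...)
convs_per_bottleneck = 3      # 1x1, 3x3, 1x1 convolutions in residual_function
expansion = 4                 # BottleNeck.expansion

# conv1 + all bottleneck convolutions + the final fully connected layer
weighted_layers = 1 + sum(num_block) * convs_per_bottleneck + 1
fc_in_features = 512 * expansion

# Spatial size for a 32x32 input: conv1 and conv2_x use stride 1,
# while conv3_x, conv4_x and conv5_x each halve the resolution.
size = 32
for stride in (1, 1, 2, 2, 2):
    size //= stride

print(weighted_layers, fc_in_features, size)   # 50 2048 4
```

The 4×4×2048 feature map is then reduced to a 2048-dimensional vector by the adaptive average pool before the 100-way classifier.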

First, train the teacher model and the student model separately and record the accuracy of each.

Loss function: nn.CrossEntropyLoss()
Optimizer: torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.9, weight_decay=5e-4)
Learning rate schedule: torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)
Epochs: 200
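
To see what these settings mean in practice, here is a small plain-Python restatement of the multi-step schedule. It is a sketch that assumes the scheduler advances once per epoch; `multistep_lr` is a hypothetical helper, not a PyTorch function.

```python
from bisect import bisect_right

def multistep_lr(epoch, base_lr=0.02, milestones=(60, 120, 160), gamma=0.2):
    """Learning rate used during the given 0-based epoch."""
    # Each milestone that has been reached multiplies the rate by gamma.
    return base_lr * gamma ** bisect_right(milestones, epoch)

for epoch in (0, 59, 60, 119, 120, 159, 160, 199):
    print(epoch, multistep_lr(epoch))
```

The rate stays at 0.02 for epochs 0 to 59 and is then multiplied by 0.2 at epochs 60, 120 and 160, ending at 0.02 × 0.2³ = 0.00016 for the last 40 epochs.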
Teacher model training code
import torch
from torch import nn
from tqdm import tqdm
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from my_resnet import resnet50


def TeacherModel():
    """ return a ResNet 50 object
    """
    model = resnet50()
    return model


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])


train_dataset = torchvision.datasets.cifar.CIFAR100(
    root = "./dataset/",
    train=True,
    transform=transform_train,
    download=True
)
test_dataset = torchvision.datasets.cifar.CIFAR100(
    root = "./dataset/",
    train = False,
    transform=transform_test,
    download=True
)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, num_workers=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, num_workers=4, shuffle=False)

if __name__ == "__main__":
    
    """
    Train the teacher model from scratch
    """
    model = TeacherModel().to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.9, weight_decay=5e-4)
    train_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2) #learning rate decay
    
   
    iter_per_epoch = len(train_loader)
    
    epochs = 200
    best_acc = 0.0
    global_step = 0
    for epoch in range(epochs):

        model.train()

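        for data, targets in tqdm(train_loader):

            data = data.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()

            prediction = model(data)
            loss = criterion(prediction, targets)

            loss.backward()
            optimizer.step()
            global_step += 1

        # Step the scheduler once per epoch, after the optimizer updates.
        # This gives the same schedule as the deprecated step(epoch) call at the start of each epoch.
        train_scheduler.step()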

        model.eval()
        num_correct = 0
        num_samples = 0

        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)

                prediction = model(x)
                prediction = prediction.max(1).indices
                num_correct += (prediction == y).sum()
                num_samples += prediction.size(0)
            
            acc = (num_correct/num_samples).item()


        
        if acc > best_acc:
            torch.save(model.state_dict(), './weights/teacher_cifar100/teacher_{}.pth'.format(acc))
            best_acc = acc

        print("Epoch {}: best accuracy so far: {:.4f}".format(epoch, best_acc))

    """
    Teacher model
    Epoch 199: best accuracy so far: 0.7840
    """
The teacher model reaches a best test accuracy of 78.40%.
Student model training code
import torch
from torch import nn
from tqdm import tqdm
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from my_vgg import vgg16_bn


def StudentModel():
    model = vgg16_bn()
    return model


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])


train_dataset = torchvision.datasets.cifar.CIFAR100(
    root = "./dataset/",
    train=True,
    transform=transform_train,
    download=True
)
test_dataset = torchvision.datasets.cifar.CIFAR100(
    root = "./dataset/",
    train = False,
    transform=transform_test,
    download=True
)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, num_workers=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, num_workers=4, shuffle=False)


if __name__ == "__main__":
    
    """
    Train the student model from scratch
    """
    model = StudentModel().to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.9, weight_decay=5e-4)
    train_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2) #learning rate decay

    iter_per_epoch = len(train_loader)
    
    epochs = 200
    best_acc = 0.0
    global_step = 0
    for epoch in range(epochs):

        model.train()

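        for data, targets in tqdm(train_loader):

            data = data.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()

            prediction = model(data)
            loss = criterion(prediction, targets)

            loss.backward()
            optimizer.step()
            global_step += 1

        # Step the scheduler once per epoch, after the optimizer updates.
        # This gives the same schedule as the deprecated step(epoch) call at the start of each epoch.
        train_scheduler.step()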

        model.eval()
        num_correct = 0
        num_samples = 0

        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)

                prediction = model(x)
                prediction = prediction.max(1).indices
                num_correct += (prediction == y).sum()
                num_samples += prediction.size(0)
            
            acc = (num_correct/num_samples).item()


        
        if acc > best_acc:
            torch.save(model.state_dict(), './weights/student_cifar100_vgg16/student_{}.pth'.format(acc))
            best_acc = acc

        print("Epoch {}: best accuracy so far: {:.4f}".format(epoch, best_acc))

    """
    Student model VGG16
    Epoch 199: best accuracy so far: 0.7121
    """
The student model trained on its own reaches a best test accuracy of 71.21%.

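Teacher-student distillation training: the student loss is the cross-entropy (CE) loss on the hard labels, and the distillation loss is the KL divergence between the softened student and teacher distributions.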

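Key point 1: the total loss for the distilled student is loss = (1 - alpha) * T * T * soft_loss + alpha * hard_loss, where alpha is the weight of the hard-label loss and T is the temperature used to soften the targets.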
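
Written out as a formula, the same expression reads as follows. The symbols are introduced here for illustration: $v$ and $z$ are the teacher and student logits, $y$ is the hard label, and the superscript marks the temperature used in the softmax.

```latex
\mathcal{L} = (1-\alpha)\,T^{2}\,\mathrm{KL}\!\left(p^{(T)} \,\middle\|\, q^{(T)}\right)
            + \alpha\,\mathrm{CE}\!\left(y,\ q^{(1)}\right),
\qquad
p^{(T)}_i = \frac{\exp(v_i/T)}{\sum_j \exp(v_j/T)},
\quad
q^{(T)}_i = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}
```

The factor $T^2$ is there because, as Hinton et al. note, the gradients produced by the soft targets scale as $1/T^2$; multiplying by $T^2$ keeps the relative contribution of the two terms roughly stable when the temperature is changed.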

See the bilibili video [3] for details.

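Key point 2: how the distillation loss is computed. student_predictions must be divided by the temperature and passed through F.log_softmax, because nn.KLDivLoss expects log-probabilities as its input; teacher_predictions must be divided by the temperature and passed through F.softmax to give the soft targets.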

distillation_loss = soft_loss(F.log_softmax(student_predictions / Temp, dim=1), F.softmax(teacher_predictions / Temp, dim=1))
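
For a single sample, the quantity that this call computes can be reproduced by hand. The sketch below uses plain Python with made-up logits; `softmax`, `kl_divergence` and `soft_loss_single` are illustrative helpers, not PyTorch APIs. It evaluates KL(teacher ‖ student) on the temperature-softened distributions, which is what nn.KLDivLoss returns per sample when given the student's log-probabilities as input and the teacher's probabilities as target.

```python
import math

def softmax(logits, T=1.0):
    """Softmax of logits / T."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) = sum_i p_i * (log p_i - log q_i)."""
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)

def soft_loss_single(student_logits, teacher_logits, T):
    teacher_soft = softmax(teacher_logits, T)   # soft targets
    student_soft = softmax(student_logits, T)   # softened student output
    return kl_divergence(teacher_soft, student_soft)

teacher_logits = [6.0, 2.0, 1.0]   # hypothetical values
student_logits = [3.0, 2.5, 0.5]   # hypothetical values
T = 4.0

loss = soft_loss_single(student_logits, teacher_logits, T)
print(loss, T * T * loss)   # raw KL term and the T^2-scaled term used in the total loss
```

The loss is zero when the student's logits match the teacher's and positive otherwise; with reduction='batchmean', PyTorch averages this per-sample value over the batch.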

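Key point 3: the teacher model must be in eval() mode, and its outputs should be computed under torch.no_grad() and detached with .detach() so that no gradients flow into the teacher.
with torch.no_grad():
    teacher_predictions = teacher_model(data)
    teacher_predictions = teacher_predictions.detach()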
Key point 4: choosing the loss weight alpha and the temperature T. Following the bilibili video, I set alpha to 0.3 and T to 4.
Distillation training code
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from teacher_cifar100 import TeacherModel
from vgg_student_cifar100 import StudentModel

torch.manual_seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.backends.cudnn.benchmark = True






CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])


# load the CIFAR100 datasets
train_dataset = torchvision.datasets.cifar.CIFAR100(
    root = "./dataset/",
    train=True,
    transform=transform_train,
    download=True
)
test_dataset = torchvision.datasets.cifar.CIFAR100(
    root = "./dataset/",
    train = False,
    transform=transform_test,
    download=True
)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, num_workers=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, num_workers=4, shuffle=False)


if __name__ == "__main__":
    
    """
    Load the pretrained teacher model and train the student model by distillation
    """
    teacher_model = TeacherModel().to(device).eval()
    teacher_model.load_state_dict(torch.load("./weights/teacher_cifar100/teacher_0.7839999794960022.pth"))
    student_model = StudentModel().to(device)

    Temp = 4
    alpha = 0.3

    hard_loss = nn.CrossEntropyLoss()

    soft_loss = nn.KLDivLoss(reduction='batchmean')

    optimizer = torch.optim.SGD(student_model.parameters(), lr=0.02, momentum=0.9, weight_decay=5e-4)
    train_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2) #learning rate decay

    iter_per_epoch = len(train_loader)
    
    epochs = 200

    best_acc = 0.0
    global_step = 0
    for epoch in range(epochs):
        
        student_model.train()

        
        for data, targets in tqdm(train_loader):

            data = data.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()

            # teacher predictions: the frozen teacher needs no gradients
            with torch.no_grad():
                teacher_predictions = teacher_model(data)
                teacher_predictions = teacher_predictions.detach() # see https://www.bilibili.com/video/BV1Go4y1u72L/?spm_id_from=333.337.search-card.all.click&vd_source=e71c4eae27444c44f2de6239f04c4757

            student_predictions = student_model(data)

            # hard-label cross-entropy loss
            student_loss = hard_loss(student_predictions, targets)

            # KL divergence between the temperature-softened student and teacher distributions
            distillation_loss = soft_loss(
                F.log_softmax(student_predictions / Temp, dim=1),  # see https://www.bilibili.com/video/BV1Go4y1u72L/?spm_id_from=333.337.search-card.all.click&vd_source=e71c4eae27444c44f2de6239f04c4757
                F.softmax(teacher_predictions / Temp, dim=1)
            )

            loss = (1 - alpha) * Temp * Temp * distillation_loss + alpha * student_loss # T^2 factor, see https://www.bilibili.com/video/BV1Go4y1u72L/?spm_id_from=333.337.search-card.all.click&vd_source=e71c4eae27444c44f2de6239f04c4757
            loss.backward()
            optimizer.step()

            global_step += 1

        # Step the scheduler once per epoch, after the optimizer updates.
        # This gives the same schedule as the deprecated step(epoch) call at the start of each epoch.
        train_scheduler.step()
            

        student_model.eval()
        num_correct = 0
        num_samples = 0

        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)

                prediction = student_model(x)
                prediction = prediction.max(1).indices
                num_correct += (prediction == y).sum()
                num_samples += prediction.size(0)
            
            acc = (num_correct/num_samples).item()
        
            
        if acc > best_acc:
            torch.save(student_model.state_dict(), './weights/knowledge_distillation_cifar100_vgg16/student_{}.pth'.format(acc))
            best_acc = acc

        print("Epoch {}: best accuracy so far: {:.4f}".format(epoch, best_acc))

    """
    Distilled student model  ResNet50 --> VGG16
    ResNet50  best accuracy: 0.7840
    VGG16     best accuracy: 0.7121
    Temp = 4  alpha = 0.3   Acc  Epoch 199: best accuracy so far: 0.7388
    """

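Knowledge distillation results

| Model | Architecture | Test accuracy |
| --- | --- | --- |
| Student model | VGG16 | 71.21% |
| Teacher model | ResNet50 | 78.40% |
| Distilled student model | VGG16 | 73.88% |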
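
The three accuracies can be turned into a gain and a "fraction of the gap closed" figure with a few lines of arithmetic. This sketch uses only the numbers reported above; the variable names are chosen here for illustration.

```python
baseline_student = 71.21   # VGG16 trained alone (%)
teacher = 78.40            # ResNet50 (%)
distilled_student = 73.88  # VGG16 trained with distillation (%)

gain = distilled_student - baseline_student
gap_before = teacher - baseline_student
gap_after = teacher - distilled_student
gap_closed = gain / gap_before

print(f"gain from distillation: {gain:.2f} points")
print(f"teacher-student gap: {gap_before:.2f} -> {gap_after:.2f} points")
print(f"fraction of gap closed: {gap_closed:.1%}")
```

Distillation adds 2.67 points to the student and closes a bit over a third of the 7.19-point gap to the teacher.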

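Summary and analysis

The ResNet50-to-VGG16 teacher-student distillation experiment on CIFAR100 supports the effectiveness of the knowledge distillation proposed by Hinton et al.: the distilled VGG16 reaches 73.88%, which is 2.67 points above the 71.21% of the same network trained on its own, although it stays below the 78.40% of the teacher. The experiment also highlighted several details that need care: soft_loss takes F.log_softmax for the student and F.softmax for the teacher; the teacher must be in eval() mode with its outputs detached so that no gradients reach it; the temperature T and the loss balance coefficient alpha have to be chosen; and soft_loss must be multiplied by $T^2$.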

Acknowledgments

[1] Geoffrey Hinton, Oriol Vinyals, and Jeff Dean, "Distilling the Knowledge in a Neural Network," NIPS 2014 Deep Learning Workshop, arXiv:1503.02531, 2015.

[2] https://github.com/weiaicunzai/pytorch-cifar100

[3] https://www.bilibili.com/video/BV1Go4y1u72L/?spm_id_from=333.337.search-card.all.click&vd_source=e71c4eae27444c44f2de6239f04c4757
