Knowledge Distillation from ResNet50 to VGG16 on CIFAR100 with PyTorch: Experiment Notes

The Concept of Knowledge Distillation

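For the concept of knowledge distillation, see the paper by Hinton et al., "Distilling the Knowledge in a Neural Network" (NIPS 2014 Deep Learning Workshop, arXiv 2015).
In the narrow sense, knowledge distillation means transferring knowledge from a complex model to improve the performance of a simpler one. The complex model is called the teacher model and the simple one the student model. I recently revisited the concept and tested it experimentally on the CIFAR100 dataset.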
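Logits, hard targets, and soft targets:

- **Logits** are the raw, unnormalized outputs of the network's last layer. They are not probabilities.
- **Hard targets** are the one-hot encodings of the ground-truth labels.
- **Soft targets** are the probabilities obtained by applying softmax to the (teacher's) logits.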
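Here is a minimal pure-Python sketch of the three terms. It assumes a hypothetical 4-class teacher output; the logits are made up for illustration. The one-hot vector is the hard target, and the softmax of the logits is the soft target.

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def one_hot(label, num_classes):
    """Hard target: 1.0 at the ground-truth index, 0.0 elsewhere."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


# Hypothetical teacher logits for a 4-class problem (raw scores, not probabilities)
teacher_logits = [5.0, 2.0, 1.0, -1.0]

hard_target = one_hot(0, 4)
soft_target = softmax(teacher_logits)

print("hard:", hard_target)
print("soft:", [round(p, 3) for p in soft_target])
```

The soft target keeps the teacher's ranking of the wrong classes (class 1 is judged more similar to class 0 than class 3 is), which the hard target discards.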
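Soft targets with temperature: to make the distribution after softmax softer, Hinton et al. proposed dividing the logits z_i by a temperature T before applying softmax:

$$q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

Here T is the temperature. The larger T is, the flatter the resulting probability distribution. T = 1 recovers the ordinary softmax.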
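To see the effect of T, here is a small pure-Python sketch that applies the formula above for several temperatures. It uses hypothetical logits and only the standard library, and it reports the largest probability and the entropy of each result. As T grows, the top probability falls and the entropy rises toward log(number of classes).

```python
import math


def softmax_with_temperature(logits, T):
    """q_i = exp(z_i / T) / sum_j exp(z_j / T), computed stably."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy in nats; higher means a flatter distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)


logits = [5.0, 2.0, 1.0, -1.0]  # hypothetical logits
for T in (1, 2, 4, 10):
    q = softmax_with_temperature(logits, T)
    print(f"T={T:>2}  max prob={max(q):.3f}  entropy={entropy(q):.3f}")
```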

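Dataset: CIFAR100 is a classic image classification dataset with 100 classes (50,000 training and 10,000 test images, each a 32×32 color image).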

The dataset is loaded directly through the official dataset class provided by torchvision:
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

train_dataset = torchvision.datasets.cifar.CIFAR100(
    root = "./dataset/",
    train=True,
    transform=transform_train,
    download=True
)
test_dataset = torchvision.datasets.cifar.CIFAR100(
    root = "./dataset/",
    train = False,
    transform=transform_test,
    download=True
)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, num_workers=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, num_workers=4, shuffle=False)
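For reference, transforms.ToTensor() scales 8-bit pixels to [0, 1]. transforms.Normalize(mean, std) then computes (x - mean[c]) / std[c] for each channel c. The pure-Python sketch below mimics both steps for a single hypothetical RGB pixel and then inverts them. The helper names are illustrative and are not torchvision functions.

```python
# Constants copied from the data-loading code above
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)


def normalize_pixel(rgb_uint8, mean, std):
    """Mimic ToTensor() (divide by 255) followed by Normalize() for one RGB pixel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb_uint8, mean, std))


def denormalize_pixel(rgb_norm, mean, std):
    """Invert the normalization back to 0..255 integers, e.g. to visualize augmented images."""
    return tuple(round((x * s + m) * 255.0) for x, m, s in zip(rgb_norm, mean, std))


pixel = (255, 128, 0)  # hypothetical pixel value
normalized = normalize_pixel(pixel, CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
print([round(x, 3) for x in normalized])
print(denormalize_pixel(normalized, CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD))
```

After normalization, a pixel equal to the training-set mean maps to 0 in every channel. The test transform applies only this step, with no augmentation.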

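Classification models: ResNet50 serves as the teacher model, and VGG16 with batch normalization (vgg16_bn) serves as the student model.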

VGG16 network definition (the student script imports it as `my_vgg`):
"""vgg in pytorch


[1] Karen Simonyan, Andrew Zisserman

    Very Deep Convolutional Networks for Large-Scale Image Recognition.
    https://arxiv.org/abs/1409.1556v6
"""
'''VGG11/13/16/19 in Pytorch.'''

import torch
import torch.nn as nn

cfg = {
    'A' : [64,     'M', 128,      'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],
    'B' : [64, 64, 'M', 128, 128, 'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],
    'D' : [64, 64, 'M', 128, 128, 'M', 256, 256, 256,      'M', 512, 512, 512,      'M', 512, 512, 512,      'M'],
    'E' : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
}

class VGG(nn.Module):

    def __init__(self, features, num_class=100):
        super().__init__()
        self.features = features

        self.classifier = nn.Sequential(
            nn.Linear(512, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_class)
        )

    def forward(self, x):
        output = self.features(x)
        output = output.view(output.size()[0], -1)
        output = self.classifier(output)

        return output

def make_layers(cfg, batch_norm=False):
    layers = []

    input_channel = 3
    for l in cfg:
        if l == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            continue

        layers += [nn.Conv2d(input_channel, l, kernel_size=3, padding=1)]

        if batch_norm:
            layers += [nn.BatchNorm2d(l)]

        layers += [nn.ReLU(inplace=True)]
        input_channel = l

    return nn.Sequential(*layers)

def vgg16_bn():
    return VGG(make_layers(cfg['D'], batch_norm=True))
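Why nn.Linear(512, 4096) fits CIFAR100: every 3×3 convolution in make_layers uses padding=1 and keeps the spatial size, while each 'M' halves it. The pure-Python trace of cfg['D'] below is a sketch; trace_vgg is a hypothetical helper, not part of the code above. It shows that a 32×32 input ends as a 1×1×512 map. It also shows that the configuration has 13 convolutions, which together with the 3 linear layers give the 16 layers of VGG16. On 224×224 ImageNet inputs the same trace ends at 7×7, which is why the original VGG classifier takes 7×7×512 features.

```python
# cfg['D'] (VGG16) copied from the definition above
CFG_D = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
         512, 512, 512, 'M', 512, 512, 512, 'M']


def trace_vgg(cfg, input_size=32, in_channels=3):
    """Track spatial size and channel count through make_layers(cfg).

    3x3 convs with padding=1 keep the spatial size; each 'M'
    (MaxPool2d(kernel_size=2, stride=2)) halves it.
    """
    size, channels, num_convs = input_size, in_channels, 0
    for v in cfg:
        if v == 'M':
            size //= 2
        else:
            channels = v
            num_convs += 1
    return size, channels, num_convs


for input_size in (32, 224):
    size, channels, num_convs = trace_vgg(CFG_D, input_size)
    print(f"input {input_size}x{input_size} -> {size}x{size}x{channels} "
          f"= {size * size * channels} features, {num_convs} conv layers")
```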
ResNet50 network definition (the teacher script imports it as `my_resnet`):
"""resnet in pytorch



[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.

    Deep Residual Learning for Image Recognition
    https://arxiv.org/abs/1512.03385v1
"""

import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    """Basic Block for resnet 18 and resnet 34

    """

    #BasicBlock and BottleNeck block
    #have different output size
    #we use class attribute expansion
    #to distinguish them
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        #residual function
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels * BasicBlock.expansion)
        )

        #shortcut
        self.shortcut = nn.Sequential()

        #the shortcut output dimension is not the same as the residual function's
        #use 1*1 convolution to match the dimension
        if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * BasicBlock.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))

class BottleNeck(nn.Module):
    """Residual block for resnet over 50 layers

    """
    expansion = 4
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels * BottleNeck.expansion),
        )

        self.shortcut = nn.Sequential()

        if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels * BottleNeck.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))

class ResNet(nn.Module):

    def __init__(self, block, num_block, num_classes=100):
        super().__init__()

        self.in_channels = 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        #we use a different input size than the original paper
        #so conv2_x's stride is 1
        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """make resnet layers (by layer I don't mean that this 'layer' is the
        same as a neural network layer, e.g. a conv layer); one layer may
        contain more than one residual block

        Args:
            block: block type, basic block or bottle neck block
            out_channels: output depth channel number of this layer
            num_blocks: how many blocks per layer
            stride: the stride of the first block of this layer

        Return:
            return a resnet layer
        """

        # each layer has num_blocks blocks; the first block's stride
        # can be 1 or 2, the remaining blocks always use stride 1
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        output = self.conv1(x)
        output = self.conv2_x(output)
        output = self.conv3_x(output)
        output = self.conv4_x(output)
        output = self.conv5_x(output)
        output = self.avg_pool(output)
        output = output.view(output.size(0), -1)
        output = self.fc(output)

        return output

def resnet50():
    """ return a ResNet 50 object
    """
    return ResNet(BottleNeck, [3, 4, 6, 3])
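A similar pure-Python sketch confirms where the 50 comes from; resnet_summary is a hypothetical helper. The count is one stem convolution, plus 3 convolutions in each of the 3 + 4 + 6 + 3 = 16 bottleneck blocks, plus the final fully connected layer. By convention, the 1×1 shortcut projections are not counted. The sketch also traces the CIFAR-specific spatial sizes: conv1 is a 3×3 stride-1 stem with no max pooling, and conv2_x keeps stride 1. The BasicBlock rows assume the usual [2, 2, 2, 2] and [3, 4, 6, 3] layouts for ResNet18 and ResNet34, which the code above does not define.

```python
def resnet_summary(num_block, expansion, convs_per_block, input_size=32):
    """Count weighted layers and trace shapes for the CIFAR ResNet defined above.

    Stem: one 3x3 conv, stride 1, no max pooling, so 32x32 stays 32x32.
    Stages conv2_x..conv5_x use stride 1, 2, 2, 2 for their first block.
    """
    stage_strides = (1, 2, 2, 2)
    stage_channels = (64, 128, 256, 512)
    depth = 1                      # conv1 stem
    size = input_size
    for n, stride in zip(num_block, stage_strides):
        depth += n * convs_per_block  # 1x1 shortcut projections are not counted
        size //= stride
    depth += 1                     # final nn.Linear (self.fc)
    features = stage_channels[-1] * expansion  # input width of self.fc
    return depth, size, features


print(resnet_summary([3, 4, 6, 3], expansion=4, convs_per_block=3))  # ResNet50 (BottleNeck)
print(resnet_summary([2, 2, 2, 2], expansion=1, convs_per_block=2))  # ResNet18-style (BasicBlock)
```

The 2048 features entering self.fc correspond to nn.Linear(512 * block.expansion, num_classes) with expansion = 4.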

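First, train the teacher model and the student model separately from scratch and record the accuracy of each. Both use the same training configuration: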

Loss function: nn.CrossEntropyLoss()
Optimizer: torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.9, weight_decay=5e-4)
Learning-rate schedule: torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)
epochs = 200
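When scheduler.step() is called once per epoch, MultiStepLR multiplies the learning rate by gamma at each milestone. Below is a small pure-Python sketch of the resulting schedule. multistep_lr is a hypothetical helper written for illustration, not a PyTorch function.

```python
def multistep_lr(epoch, base_lr=0.02, milestones=(60, 120, 160), gamma=0.2):
    """LR in effect during 0-based `epoch` when scheduler.step() runs once per epoch."""
    passed = sum(1 for m in milestones if epoch >= m)  # milestones already reached
    return base_lr * gamma ** passed


for epoch in (0, 59, 60, 119, 120, 159, 160, 199):
    print(f"epoch {epoch:>3}: lr = {multistep_lr(epoch):.5f}")
```

With these settings the rate steps down from 0.02 to 0.004, then 0.0008, and finally 0.00016 for the last 40 epochs.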
Teacher model training code (the distillation script later imports `TeacherModel` from `teacher_cifar100`):
import torch
from torch import nn
from tqdm import tqdm
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from my_resnet import resnet50


def TeacherModel():
    """ return a ResNet 50 object
    """
    model = resnet50()
    return model


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])


train_dataset = torchvision.datasets.cifar.CIFAR100(
    root = "./dataset/",
    train=True,
    transform=transform_train,
    download=True
)
test_dataset = torchvision.datasets.cifar.CIFAR100(
    root = "./dataset/",
    train = False,
    transform=transform_test,
    download=True
)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, num_workers=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, num_workers=4, shuffle=False)

if __name__ == "__main__":

    """
    Train the teacher model from scratch
    """
    model = TeacherModel().to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.9, weight_decay=5e-4)
    train_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2) #learning rate decay
    
   
    iter_per_epoch = len(train_loader)
    
    epochs = 200
    best_acc = 0.0
    global_step = 0
    for epoch in range(epochs):

        model.train()

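        for data, targets in tqdm(train_loader):
            data = data.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()

            prediction = model(data)
            loss = criterion(prediction, targets)

            loss.backward()
            optimizer.step()
            global_step += 1

        # Step the LR scheduler once per epoch, after the optimizer updates
        # (passing `epoch` to step() is deprecated in recent PyTorch versions)
        train_scheduler.step()
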
        model.eval()
        num_correct = 0
        num_samples = 0

        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)

                prediction = model(x)
                prediction = prediction.max(1).indices
                num_correct += (prediction == y).sum()
                num_samples += prediction.size(0)
            
            acc = (num_correct/num_samples).item()


        
        if acc > best_acc:
            torch.save(model.state_dict(), './weights/teacher_cifar100/teacher_{}.pth'.format(acc))
            best_acc = acc

        print("Epoch {}: best accuracy so far: {:.4f}".format(epoch, best_acc))

    """
    Teacher model
    Epoch 199: best accuracy so far: 0.7840
    """
The teacher model reaches a best test accuracy of 78.40%.
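Student model training code (the distillation script later imports `StudentModel` from `vgg_student_cifar100`):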
import torch
from torch import nn
from tqdm import tqdm
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from my_vgg import vgg16_bn


def StudentModel():
    model = vgg16_bn()
    return model


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])


train_dataset = torchvision.datasets.cifar.CIFAR100(
    root = "./dataset/",
    train=True,
    transform=transform_train,
    download=True
)
test_dataset = torchvision.datasets.cifar.CIFAR100(
    root = "./dataset/",
    train = False,
    transform=transform_test,
    download=True
)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, num_workers=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, num_workers=4, shuffle=False)


if __name__ == "__main__":

    """
    Train the student model from scratch
    """
    model = StudentModel().to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.9, weight_decay=5e-4)
    train_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2) #learning rate decay

    iter_per_epoch = len(train_loader)
    
    epochs = 200
    best_acc = 0.0
    global_step = 0
    for epoch in range(epochs):

        model.train()

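        for data, targets in tqdm(train_loader):
            data = data.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()

            prediction = model(data)
            loss = criterion(prediction, targets)

            loss.backward()
            optimizer.step()
            global_step += 1

        # Step the LR scheduler once per epoch, after the optimizer updates
        # (passing `epoch` to step() is deprecated in recent PyTorch versions)
        train_scheduler.step()
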
        model.eval()
        num_correct = 0
        num_samples = 0

        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)

                prediction = model(x)
                prediction = prediction.max(1).indices
                num_correct += (prediction == y).sum()
                num_samples += prediction.size(0)
            
            acc = (num_correct/num_samples).item()


        
        if acc > best_acc:
            torch.save(model.state_dict(), './weights/student_cifar100_vgg16/student_{}.pth'.format(acc))
            best_acc = acc

        print("Epoch {}: best accuracy so far: {:.4f}".format(epoch, best_acc))

    """
    Student model VGG16
    Epoch 199: best accuracy so far: 0.7121
    """
The student model reaches a best test accuracy of 71.21%.

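Teacher-student distillation training uses two losses:

- The student loss is the cross-entropy (CE) loss against the ground-truth labels.
- The distillation loss is the KL-divergence loss between the temperature-softened teacher and student distributions.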

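Key point 1: the total loss for the distilled student is loss = (1 - alpha) * T * T * soft_loss + alpha * hard_loss. Here alpha is the weighting parameter and T is the temperature used to soften the targets. As Hinton et al. point out, the gradients produced by the soft targets scale as 1/T². The T * T factor compensates for this; without it, raising T would shrink the soft loss's contribution relative to the hard loss.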

For details, see the bilibili video listed in the acknowledgements.
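To make the formula concrete, here is a single-sample, pure-Python sketch of the combined loss. kd_loss and the example logits are hypothetical; the training script in this post works on batches with nn.CrossEntropyLoss and nn.KLDivLoss instead. The sketch computes the soft term as KL(p_T || q_T) between the softened teacher and student distributions, and the hard term as the ordinary cross-entropy at T = 1.

```python
import math


def log_softmax(logits, T=1.0):
    """log of softmax(logits / T), computed with the log-sum-exp trick."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_sum for s in scaled]


def kd_loss(student_logits, teacher_logits, label, T=4.0, alpha=0.3):
    """Single-sample loss = (1 - alpha) * T^2 * KL(p_T || q_T) + alpha * CE(q, y)."""
    log_q_T = log_softmax(student_logits, T)   # softened student log-probs
    log_p_T = log_softmax(teacher_logits, T)   # softened teacher log-probs
    soft = sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p_T, log_q_T))
    hard = -log_softmax(student_logits)[label]  # ordinary cross-entropy, T = 1
    total = (1 - alpha) * T * T * soft + alpha * hard
    return total, soft, hard


student = [2.0, 1.0, 0.1]   # hypothetical student logits
teacher = [3.0, 0.5, -1.0]  # hypothetical teacher logits
total, soft, hard = kd_loss(student, teacher, label=0)
print(f"soft (KL) = {soft:.4f}, hard (CE) = {hard:.4f}, total = {total:.4f}")
```

Two edge cases behave as expected. If the student matches the teacher exactly, the soft term vanishes and only alpha * CE remains. With alpha = 1 the loss reduces to plain cross-entropy.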

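Key point 2: how the distillation loss is computed. nn.KLDivLoss expects its input as log-probabilities and its target as probabilities. Therefore student_predictions are divided by the temperature and passed through F.log_softmax, while teacher_predictions are divided by the temperature and passed through F.softmax: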

distillation_loss = soft_loss(F.log_softmax(student_predictions / Temp, dim=1), F.softmax(teacher_predictions / Temp, dim=1))
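Here is a quick PyTorch check of this convention, using random logits for a hypothetical batch of 8 samples and 100 classes. nn.KLDivLoss with reduction='batchmean' matches a hand-written sum of p · (log p - log q) averaged over the batch. Passing probabilities instead of log-probabilities as the input produces a value that is not a KL divergence at all; here it comes out negative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
T = 4.0
student_logits = torch.randn(8, 100)  # hypothetical batch: 8 samples, 100 classes
teacher_logits = torch.randn(8, 100)

soft_loss = nn.KLDivLoss(reduction='batchmean')

# Correct usage: input = student log-probabilities, target = teacher probabilities
kd = soft_loss(F.log_softmax(student_logits / T, dim=1),
               F.softmax(teacher_logits / T, dim=1))

# Hand-written KL(p || q): sum over classes, then average over the batch
log_q = F.log_softmax(student_logits / T, dim=1)
log_p = F.log_softmax(teacher_logits / T, dim=1)
manual = (log_p.exp() * (log_p - log_q)).sum(dim=1).mean()

# Common mistake: feeding probabilities instead of log-probabilities as the input
wrong = soft_loss(F.softmax(student_logits / T, dim=1),
                  F.softmax(teacher_logits / T, dim=1))

print(f"KLDivLoss: {kd.item():.6f}  manual: {manual.item():.6f}  wrong input: {wrong.item():.4f}")
```

A KL divergence is never negative, so a negative soft loss during training is a quick sign that the log_softmax/softmax pairing has been mixed up.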

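Key point 3: the teacher model must be in eval() mode, so that BatchNorm uses its running statistics and Dropout is disabled. Its outputs must also be computed under torch.no_grad() so that no gradients flow into it. The extra .detach() is redundant under no_grad but harmless: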
with torch.no_grad():
   teacher_predictions = teacher_model(data)
   teacher_predictions = teacher_predictions.detach() 
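The small PyTorch sketch below illustrates why both eval() and no_grad() matter; a hypothetical toy network stands in for ResNet50. It checks three things:

1. Under no_grad, the outputs carry no autograd graph.
2. In eval mode, a sample's output does not depend on the other samples in the batch.
3. In train mode, every forward pass updates the BatchNorm running statistics, even inside no_grad.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy "teacher" with BatchNorm and Dropout, standing in for ResNet50
teacher = nn.Sequential(
    nn.Linear(10, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Dropout(0.5), nn.Linear(16, 5)
)
x = torch.randn(4, 10)

teacher.eval()  # BatchNorm uses running statistics, Dropout is a no-op
with torch.no_grad():
    batch_out = teacher(x)
    single_out = teacher(x[:1])  # same first sample, evaluated alone

print(batch_out.requires_grad)                               # no autograd graph is recorded
print(torch.allclose(batch_out[:1], single_out, atol=1e-6))  # output independent of batch

# In train mode, each forward pass updates BatchNorm's running statistics,
# even inside torch.no_grad(), because they are buffers rather than gradients
running_mean_before = teacher[1].running_mean.clone()
teacher.train()
with torch.no_grad():
    teacher(x)
print(torch.equal(running_mean_before, teacher[1].running_mean))
```

In other words, no_grad alone does not protect a teacher left in train mode: its BatchNorm statistics would drift toward the student's training batches.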
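Key point 4: choosing the loss weight alpha and the temperature T. Following the bilibili video, I set alpha = 0.3 and T = 4. This gives the soft loss an effective weight of (1 - 0.3) × 4 × 4 = 11.2, versus 0.3 for the hard loss.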
Distillation training code:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from teacher_cifar100 import TeacherModel
from vgg_student_cifar100 import StudentModel

torch.manual_seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.backends.cudnn.benchmark = True






CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])


# load the CIFAR100 datasets
train_dataset = torchvision.datasets.cifar.CIFAR100(
    root = "./dataset/",
    train=True,
    transform=transform_train,
    download=True
)
test_dataset = torchvision.datasets.cifar.CIFAR100(
    root = "./dataset/",
    train = False,
    transform=transform_test,
    download=True
)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, num_workers=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, num_workers=4, shuffle=False)


if __name__ == "__main__":

    """
    Distill the pretrained teacher model into a student trained from scratch
    """
    teacher_model = TeacherModel().to(device).eval()
    teacher_model.load_state_dict(torch.load("./weights/teacher_cifar100/teacher_0.7839999794960022.pth", map_location=device))
    student_model = StudentModel().to(device)

    Temp = 4
    alpha = 0.3

    hard_loss = nn.CrossEntropyLoss()

    soft_loss = nn.KLDivLoss(reduction='batchmean')

    optimizer = torch.optim.SGD(student_model.parameters(), lr=0.02, momentum=0.9, weight_decay=5e-4)
    train_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2) #learning rate decay

    iter_per_epoch = len(train_loader)
    
    epochs = 200

    best_acc = 0.0
    global_step = 0
    for epoch in range(epochs):
        
        student_model.train()

        
        for data, targets in tqdm(train_loader):
            data = data.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()

            # teacher predictions
            with torch.no_grad():
                teacher_predictions = teacher_model(data)
                teacher_predictions = teacher_predictions.detach() # see https://www.bilibili.com/video/BV1Go4y1u72L/?spm_id_from=333.337.search-card.all.click&vd_source=e71c4eae27444c44f2de6239f04c4757

            student_predictions = student_model(data)

            student_loss = hard_loss(student_predictions, targets)

            distillation_loss = soft_loss(
                F.log_softmax(student_predictions / Temp, dim=1),  # see https://www.bilibili.com/video/BV1Go4y1u72L/?spm_id_from=333.337.search-card.all.click&vd_source=e71c4eae27444c44f2de6239f04c4757
                F.softmax(teacher_predictions / Temp, dim=1)
            )

            loss = (1 - alpha) * Temp * Temp * distillation_loss + alpha * student_loss # T^2 scaling, see https://www.bilibili.com/video/BV1Go4y1u72L/?spm_id_from=333.337.search-card.all.click&vd_source=e71c4eae27444c44f2de6239f04c4757
            loss.backward()
            optimizer.step()

            global_step += 1

        # Step the LR scheduler once per epoch, after the optimizer updates
        # (passing `epoch` to step() is deprecated in recent PyTorch versions)
        train_scheduler.step()

        student_model.eval()
        num_correct = 0
        num_samples = 0

        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)

                prediction = student_model(x)
                prediction = prediction.max(1).indices
                num_correct += (prediction == y).sum()
                num_samples += prediction.size(0)
            
            acc = (num_correct/num_samples).item()
        
            
        if acc > best_acc:
            torch.save(student_model.state_dict(), './weights/knowledge_distillation_cifar100_vgg16/student_{}.pth'.format(acc))
            best_acc = acc

        print("Epoch {}: best accuracy so far: {:.4f}".format(epoch, best_acc))

    """
    Distilled student model: ResNet50 --> VGG16
    ResNet50  best accuracy: 0.7840
    VGG16     best accuracy: 0.7121
    Temp = 4  alpha = 0.3   Acc  Epoch 199: best accuracy so far: 0.7388
    """

Knowledge distillation results

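| Model | Architecture | Best test accuracy |
|---|---|---|
| Student model (trained alone) | VGG16 | 71.21% |
| Teacher model | ResNet50 | 78.40% |
| Distilled student model | VGG16 | 73.88% |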
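The relative size of the improvement can be worked out from the table. The sketch below computes the absolute gain and the fraction of the teacher-student gap recovered by distillation. The "gap closed" ratio is a derived figure introduced here for illustration, not a number reported by the experiment.

```python
# Best test accuracies (%) from the results table
baseline_student = 71.21   # VGG16 trained alone
teacher = 78.40            # ResNet50
distilled_student = 73.88  # VGG16 trained with distillation

gain = distilled_student - baseline_student   # absolute improvement, in points
gap = teacher - baseline_student              # teacher-student gap before distillation
gap_closed = gain / gap                       # fraction of that gap recovered

print(f"gain = {gain:.2f} points, gap = {gap:.2f} points, gap closed = {gap_closed:.1%}")
```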

Summary and analysis

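This ResNet50 → VGG16 teacher-student distillation experiment on CIFAR100 confirms the effectiveness of the knowledge distillation method proposed by Hinton et al. The distilled VGG16 reaches 73.88%, 2.67 points above the 71.21% of the same network trained alone, though still 4.52 points below the 78.40% teacher. The experiment also highlights several implementation details:

- The soft loss uses F.log_softmax for the student and F.softmax for the teacher.
- The teacher must be put in eval mode and run without gradients (no_grad/detach).
- The temperature T and the loss-balancing coefficient alpha must be chosen.
- The soft loss must be multiplied by T².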

Acknowledgements

[1] [Geoffrey Hinton, Oriol Vinyals, and Jeff Dean, "Distilling the Knowledge in a Neural Network," NIPS 2014 Deep Learning Workshop (arXiv:1503.02531, 2015).](https://arxiv.org/abs/1503.02531)

[2] [https://www.bilibili.com/video/BV1Go4y1u72L/?spm_id_from=333.337.search-card.all.click\&vd_source=e71c4eae27444c44f2de6239f04c4757](https://www.bilibili.com/video/BV1Go4y1u72L/?spm_id_from=333.337.search-card.all.click&vd_source=e71c4eae27444c44f2de6239f04c4757)
