【nlp】知识蒸馏Distilling

一、知识蒸馏介绍

1. 什么是知识蒸馏?

知识蒸馏(Knowledge Distillation) 是一种用于模型压缩的技术,通过让小模型(称为学生模型,student model)从大模型(称为教师模型,teacher model)中学习,从而提高小模型的性能,同时保留大模型的一部分知识。知识蒸馏常用于深度学习中,以减少计算资源和内存需求,使得模型可以在资源受限的设备上运行,比如移动设备和嵌入式系统。

2. 轻量化网络方式有哪些?

1. 压缩已训练好的模型

  • 知识蒸馏:将大模型的知识传递给小模型,通过模仿大模型的输出提高小模型的性能。
  • 权值量化:将浮点数表示的权重转换为低精度整数(如 INT8)表示,减少模型体积和计算量。
  • 权重剪枝:移除不重要的权重或神经元,减少参数量和计算开销。
  • 通道剪枝:剪掉卷积层的某些通道,降低卷积计算的复杂性。
  • 注意力迁移:通过让小模型学习大模型的注意力机制,使其更好地关注重要的特征。

2. 直接训练轻量化网络

  • SqueezeNet:使用较少的参数量进行等效卷积操作。
  • MobileNetv1/v2/v3:引入深度可分离卷积(depthwise separable convolution)和倒残差结构,显著减少计算量。
  • MnasNet:通过神经架构搜索(NAS)设计的轻量化网络。
  • ShuffleNet:通过通道洗牌来优化组卷积的性能。
  • Xception:一种极度优化的深度可分离卷积网络。
  • EfficientNet:通过复合缩放(compound scaling)策略优化网络深度、宽度和分辨率。
  • EfficientDet:专门针对目标检测任务的轻量化网络,基于 EfficientNet 设计。

3. 加速卷积运算

  • im2col + GEMM:通过将卷积运算转换为矩阵乘法(General Matrix Multiplication)来加速计算。
  • Winograd 算法:用于减少卷积计算中的乘法操作,提升速度。
  • 低秩分解:将卷积核进行分解,减少参数量和计算量。

4. 硬件部署

  • TensorRT:NVIDIA 的深度学习推理库,通过优化模型来加速推理。
  • Jetson:NVIDIA 的嵌入式 AI 计算平台,适合低功耗场景。
  • TensorFlow-Slim:TensorFlow 中的轻量化网络构建工具,用于快速构建轻量模型。
  • OpenVINO:Intel 的推理工具套件,专注于边缘设备上的高效推理。
  • FPGA 集成电路:通过定制的集成电路实现高效的并行化计算,加速推理。

这些技术方法组合使用,可以在保持模型性能的同时大幅减少计算资源和存储需求,适合资源受限的应用场景如移动设备和嵌入式系统。

3. 软标签 vs 硬标签

  • 硬标签(hard targets):通常是训练数据的真实标签,通常采用 one-hot 编码。例如,图片分类任务中,图片所属的正确类别的概率为 1,其他类别的概率为 0。

  • 软标签(soft targets):通过教师模型的输出概率分布得到的标签。与硬标签不同,软标签是一个概率分布,包含了教师模型在所有类别上的预测概率。即使是错误类别,教师模型也会分配一个非零的概率。这些概率可以反映类别之间的相似性。

例如,对于一张图片,教师模型可能给出以下预测分布:

类别 A:70%
类别 B:20%
类别 C:5%
类别 D:5%

这表示该图片最有可能属于类别 A,但类别 B 也有一定的可能性。这样的概率分布提供了比硬标签(如 100% 属于类别 A)更多的细粒度信息。

4. 蒸馏温度 T T T

知识蒸馏中的温度作用

在标准的分类任务中,模型输出的是每个类别的预测概率,这些概率通常通过 Softmax 函数计算得到。Softmax 函数的定义如下:

P ( y = i ) = exp ⁡ ( z i / T ) ∑ j exp ⁡ ( z j / T ) P(y=i) = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} P(y=i)=∑jexp(zj/T)exp(zi/T)

其中, z i z_i zi 是第 i i i 类的 logits(即模型输出的未归一化分数), T T T 是温度参数。当 T = 1 T=1 T=1 时,Softmax 函数表现为标准的形式。如果温度 T > 1 T > 1 T>1,Softmax 的输出将变得"更软",即各类之间的概率分布更加均匀;如果 T < 1 T < 1 T<1,Softmax 输出将变得更加尖锐,接近 one-hot 分布。

在知识蒸馏过程中,使用较高的温度(通常 T > 1 T>1 T>1)可以使教师模型输出的概率分布变得更加平滑,突出各类别之间的相对差异,而不是仅仅关注最高的概率类别。学生模型可以从这种软标签中学习到更多关于各个类别之间关系的信息,而不仅仅是从硬标签中学到的正确类别。

举例说明

假设我们有一个图像分类任务,教师模型是一种复杂的深度卷积神经网络(例如 ResNet-50),而学生模型是一个较小的模型(例如一个简单的卷积神经网络)。在知识蒸馏过程中,我们用教师模型的输出作为指导,来帮助学生模型学习。

  1. 无温度调整的硬标签训练: 学生模型仅从每个输入样本的真实类别标签(即 one-hot 编码)中学习,这些标签并没有包含类别之间的相关性或其他信息。

  2. 知识蒸馏中的软标签: 使用知识蒸馏时,首先通过调整温度参数 T > 1 T>1 T>1,教师模型输出的类别分布会变得更平滑。例如,对于某个输入图像,教师模型可能预测类别 A 的概率是 0.9,类别 B 的概率是 0.05,类别 C 的概率是 0.03,类别 D 的概率是 0.02。在温度调整后(例如 T = 5 T=5 T=5),这个分布可能会变为类别 A 的概率是 0.4,类别 B 的概率是 0.3,类别 C 的概率是 0.2,类别 D 的概率是 0.1。这个平滑后的分布反映了不同类别之间的相似性。

  3. 学生模型学习: 学生模型从这个更平滑的概率分布中学习,不仅学到了类别 A 的重要性,还学习到了类别 B 和类别 C 与类别 A 的相关性。这样可以帮助学生模型更好地理解数据之间的模式,从而提高泛化性能。

温度选择

温度 T T T 的选择非常关键,它决定了知识蒸馏的效果。较高的温度使得概率分布更平滑(矮胖),能够传递更多的类别信息,但也可能导致过度平滑,使得学生模型难以捕捉有用的信息。通常需要通过实验来确定最适合的温度值。

二、知识蒸馏过程

1. 知识蒸馏的过程

1. 输入数据

设输入数据为 x x x,同时输入给教师模型和学生模型。

2. 教师模型输出(Teacher Model Output)

教师模型是一个较复杂的神经网络,其通过 softmax 函数生成软标签 。softmax 函数使用温度参数 T T T 来控制输出概率的平滑度:

q i teacher = exp ⁡ ( z i teacher / T ) ∑ j exp ⁡ ( z j teacher / T ) q_i^{\text{teacher}} = \frac{\exp(z_i^{\text{teacher}} / T)}{\sum_j \exp(z_j^{\text{teacher}} / T)} qiteacher=∑jexp(zjteacher/T)exp(ziteacher/T)

其中:

  • q i teacher q_i^{\text{teacher}} qiteacher 是教师模型生成的类别 i i i 的概率。
  • z i teacher z_i^{\text{teacher}} ziteacher 是教师模型的第 i i i 类别的 logit 值。
  • T T T 是温度参数,当 T > 1 T > 1 T>1 时,输出概率分布更加平滑,有助于学生模型学习类别间的相似性。

3. 学生模型输出(Student Model Output)

学生模型是一个较小的模型,它通过学习教师模型的软标签和真实标签来提高性能。学生模型的输出也通过 softmax 函数生成。

软预测(Soft Predictions):

学生模型生成的软预测 是通过与教师模型相同温度 T T T 的 softmax 函数计算的:

q i student = exp ⁡ ( z i student / T ) ∑ j exp ⁡ ( z j student / T ) q_i^{\text{student}} = \frac{\exp(z_i^{\text{student}} / T)}{\sum_j \exp(z_j^{\text{student}} / T)} qistudent=∑jexp(zjstudent/T)exp(zistudent/T)

硬预测(Hard Predictions):

学生模型还生成硬预测 ,即通过正常的 softmax(温度 T = 1 T = 1 T=1)生成的标准输出,用于匹配真实标签:

q i hard = exp ⁡ ( z i student ) ∑ j exp ⁡ ( z j student ) q_i^{\text{hard}} = \frac{\exp(z_i^{\text{student}})}{\sum_j \exp(z_j^{\text{student}})} qihard=∑jexp(zjstudent)exp(zistudent)

4. 损失函数

为了训练学生模型,我们引入两个损失函数:

4.1 蒸馏损失(Distillation Loss)

蒸馏损失用于衡量学生模型的软预测和教师模型的软标签之间的差异。它通过使用**Kullback-Leibler 散度(KL 散度)**来度量这两个概率分布之间的距离:

L distill = KL ( q teacher , q student ) = ∑ i q i teacher log ⁡ ( q i teacher q i student ) L_{\text{distill}} = \text{KL}(q^{\text{teacher}}, q^{\text{student}}) = \sum_i q_i^{\text{teacher}} \log\left(\frac{q_i^{\text{teacher}}}{q_i^{\text{student}}}\right) Ldistill=KL(qteacher,qstudent)=i∑qiteacherlog(qistudentqiteacher)

其中:

  • q teacher q^{\text{teacher}} qteacher 是教师模型生成的软标签。
  • q student q^{\text{student}} qstudent 是学生模型生成的软预测。
4.2 学生损失(Student Loss)

学生损失是学生模型的硬预测与真实标签(硬标签)之间的差异,通常使用交叉熵损失计算:

L student = − ∑ i y i log ⁡ ( q i hard ) L_{\text{student}} = - \sum_i y_i \log(q_i^{\text{hard}}) Lstudent=−i∑yilog(qihard)

其中:

  • y i y_i yi 是真实标签的 one-hot 编码。
  • q i hard q_i^{\text{hard}} qihard 是学生模型对类别 i i i 的硬预测概率。

5. 总损失函数

最终的总损失函数是蒸馏损失学生损失的加权和:

L total = α L student + β L distill L_{\text{total}} = \alpha L_{\text{student}} + \beta L_{\text{distill}} Ltotal=αLstudent+βLdistill

其中:

  • α \alpha α 和 β \beta β 是权重系数,控制学生损失和蒸馏损失的相对重要性。通常 α \alpha α 可以设置为 1, β \beta β 可以调整以控制蒸馏的影响。
  • 为了确保梯度的缩放一致性,蒸馏损失部分的梯度通常会乘以 T 2 T^2 T2,因为软标签的梯度会随温度 T T T 缩放。

6. 温度参数 T T T 的影响

温度参数 T T T 控制了 softmax 函数的平滑程度。较高的 T T T 会使教师模型的输出概率分布更加平滑,从而让学生模型能够学习到类别间的相对关系。这些信息可以帮助学生模型提高泛化能力。

  • 当 T = 1 T = 1 T=1 时,softmax 输出接近 one-hot 编码,类别之间的相对信息较少。
  • 当 T > 1 T > 1 T>1 时,类别之间的概率差异缩小,学生模型 可以从这些更平滑的概率中学习到更多的信息。

前边的两个图是训练过程,后边一个图是预测过程。

2. 知识蒸馏发展趋势

  1. 教学助长
  2. 助教、多个老师、多个同学
  3. 知识的表示(中间层)、数据集蒸馏、对比学习
  4. 多模态、知识图谱、预训练大模型的知识蒸馏

论文:

Attention Transfer

channel-wise knowledge distillation for dense prediction

contrastive representation Distillation

Distill BERT

3. 实现代码

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


# 1. 定义教师模型(较大的模型)
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# 2. 定义学生模型(较小的模型)
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 800)
        self.fc2 = nn.Linear(800, 800)
        self.fc3 = nn.Linear(800, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# 3. 定义蒸馏损失函数
def distillation_loss(student_outputs, teacher_outputs, labels, T, alpha):
    """
    :param student_outputs: 学生模型的输出
    :param teacher_outputs: 教师模型的输出
    :param labels: 真实标签
    :param T: 温度参数
    :param alpha: 学生损失与蒸馏损失的权重
    :return: 总损失
    """
    # 计算学生模型的硬标签损失(交叉熵损失)
    hard_loss = F.cross_entropy(student_outputs, labels)

    # 计算软标签损失(KL 散度)
    soft_student = F.log_softmax(student_outputs / T, dim=1)
    soft_teacher = F.softmax(teacher_outputs / T, dim=1)
    soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T * T)

    # 总损失 = α * 硬损失 + (1 - α) * 软损失
    return alpha * hard_loss + (1 - alpha) * soft_loss


# 4. 数据加载
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)


# 5. 训练过程
def train_student(teacher_model, student_model, train_loader, optimizer, T, alpha, epochs=5):
    teacher_model.eval()  # 教师模型是预训练的,设置为 eval 模式
    student_model.train()  # 学生模型将要训练

    for epoch in range(epochs):
        total_loss = 0.0
        for images, labels in train_loader:
            # images, labels = images.cuda(), labels.cuda()
            # 教师模型预测
            with torch.no_grad():
                teacher_outputs = teacher_model(images)

            # 学生模型预测
            student_outputs = student_model(images)

            # 计算蒸馏损失
            loss = distillation_loss(student_outputs, teacher_outputs, labels, T, alpha)

            # 优化器更新
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {total_loss / len(train_loader):.4f}")


# 6. 测试过程
def test_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            # images, labels = images.cuda(), labels.cuda()
            outputs = model(images)
            _, predicted = torch.max(outputs, 1) # 返回每行(每个样本)的最大值和对应的索引
            total += labels.size(0) #更新样本计数
            correct += (predicted == labels).sum().item() # 更新正确预测计数

    print(f'Test Accuracy: {100 * correct / total:.2f}%')


# 7. 实例化模型并启动训练
# teacher_model = TeacherModel().cuda()
# student_model = StudentModel().cuda()
teacher_model = TeacherModel()
student_model = StudentModel()

# 假设教师模型已经预训练过
# 这里可以加载预训练的教师模型权重
# torch.load('teacher_model.pth', teacher_model)

optimizer = optim.Adam(student_model.parameters(), lr=0.001)

# 设定温度 T 和 α 参数
T = 20  # 温度
alpha = 0.7  # 学生损失与蒸馏损失的权重

# 训练学生模型
train_student(teacher_model, student_model, train_loader, optimizer, T, alpha, epochs=5)

# 测试学生模型
test_model(student_model, test_loader)

# 保存学生模型权重
torch.save(student_model.state_dict(),'student_model.pth')
# 保存教师模型权重
torch.save(teacher_model.state_dict(), 'teacher_model.pth')

输出:

python 复制代码
Epoch [1/5], Loss: 0.4672
Epoch [2/5], Loss: 0.3833
Epoch [3/5], Loss: 0.3671
Epoch [4/5], Loss: 0.3566
Epoch [5/5], Loss: 0.3509
Test Accuracy: 98.44%

Process finished with exit code 0
相关推荐
小陈phd2 小时前
OpenCV从入门到精通实战(九)——基于dlib的疲劳监测 ear计算
人工智能·opencv·计算机视觉
Guofu_Liao3 小时前
大语言模型---LoRA简介;LoRA的优势;LoRA训练步骤;总结
人工智能·语言模型·自然语言处理·矩阵·llama
ZHOU_WUYI7 小时前
3.langchain中的prompt模板 (few shot examples in chat models)
人工智能·langchain·prompt
如若1237 小时前
主要用于图像的颜色提取、替换以及区域修改
人工智能·opencv·计算机视觉
老艾的AI世界8 小时前
AI翻唱神器,一键用你喜欢的歌手翻唱他人的曲目(附下载链接)
人工智能·深度学习·神经网络·机器学习·ai·ai翻唱·ai唱歌·ai歌曲
DK221518 小时前
机器学习系列----关联分析
人工智能·机器学习
Robot2518 小时前
Figure 02迎重大升级!!人形机器人独角兽[Figure AI]商业化加速
人工智能·机器人·微信公众平台
浊酒南街9 小时前
Statsmodels之OLS回归
人工智能·数据挖掘·回归
畅联云平台9 小时前
美畅物联丨智能分析,安全管控:视频汇聚平台助力智慧工地建设
人工智能·物联网
加密新世界9 小时前
优化 Solana 程序
人工智能·算法·计算机视觉