蒸馏知识点笔记

蒸馏(Distillation)

模型蒸馏是一种通过将大模型(教师模型)的知识传递给小模型(学生模型)来优化小模型性能的方法。蒸馏通常包括以下几种形式:

1. 软标签蒸馏(Soft Label Distillation)

通过教师模型的软标签(soft labels)来训练学生模型,使学生模型学习教师模型的输出分布。

python 复制代码
import torch
import torch.nn as nn

# 定义教师模型和学生模型
teacher_model = ...
student_model = ...

# 定义损失函数
criterion = nn.KLDivLoss(reduction='batchmean')

# 教师模型生成软标签
teacher_model.eval()
with torch.no_grad():
    teacher_outputs = teacher_model(inputs)
soft_labels = torch.softmax(teacher_outputs / temperature, dim=1)

# 学生模型预测
student_outputs = student_model(inputs)
loss = criterion(torch.log_softmax(student_outputs / temperature, dim=1), soft_labels)

# 反向传播和优化
loss.backward()
optimizer.step()

2. 特征蒸馏(Feature Distillation)

通过让学生模型学习教师模型中间层的特征表示来优化学生模型性能。

python 复制代码
class FeatureExtractor(nn.Module):
    def __init__(self, model):
        super(FeatureExtractor, self).__init__()
        self.features = nn.Sequential(*list(model.children())[:-1])
    
    def forward(self, x):
        return self.features(x)

teacher_feature_extractor = FeatureExtractor(teacher_model)
student_feature_extractor = FeatureExtractor(student_model)

# 获取特征表示
teacher_features = teacher_feature_extractor(inputs)
student_features = student_feature_extractor(inputs)

# 定义特征蒸馏损失
feature_distillation_loss = nn.MSELoss()(student_features, teacher_features)

# 反向传播和优化
feature_distillation_loss.backward()
optimizer.step()

3. 组合蒸馏(Combined Distillation)

结合软标签蒸馏和特征蒸馏,利用教师模型的输出分布和特征表示来训练学生模型。

python 复制代码
# 定义损失函数
criterion = nn.KLDivLoss(reduction='batchmean')
mse_loss = nn.MSELoss()

# 教师模型生成软标签
teacher_model.eval()
with torch.no_grad():
    teacher_outputs = teacher_model(inputs)
soft_labels = torch.softmax(teacher_outputs / temperature, dim=1)

# 学生模型预测
student_outputs = student_model(inputs)
soft_label_loss = criterion(torch.log_softmax(student_outputs / temperature, dim=1), soft_labels)

# 获取特征表示
teacher_features = teacher_feature_extractor(inputs)
student_features = student_feature_extractor(inputs)
feature_loss = mse_loss(student_features, teacher_features)

# 组合损失
total_loss = soft_label_loss + alpha * feature_loss

# 反向传播和优化
total_loss.backward()
optimizer.step()

通过上述蒸馏技术,可以有效地优化模型结构,减少计算开销,并在保持模型性能的前提下,提高模型的推理速度和部署效率。

相关推荐
AA陈超几秒前
ASC学习笔记0020:用于定义角色或Actor的默认属性值
c++·笔记·学习·ue5·虚幻引擎
无敌最俊朗@1 小时前
力扣hot100-206反转链表
算法·leetcode·链表
TsingtaoAI1 小时前
企业实训|自动驾驶中的图像处理与感知技术——某央企汽车集团
图像处理·人工智能·自动驾驶·集成学习
Kuo-Teng1 小时前
LeetCode 279: Perfect Squares
java·数据结构·算法·leetcode·职场和发展
王哈哈^_^1 小时前
YOLO11实例分割训练任务——从构建数据集到训练的完整教程
人工智能·深度学习·算法·yolo·目标检测·机器学习·计算机视觉
IMPYLH1 小时前
Lua 的 collectgarbage 函数
开发语言·笔记·junit·单元测试·lua
百锦再1 小时前
第18章 高级特征
android·java·开发语言·后端·python·rust·django
檐下翻书1732 小时前
从入门到精通:流程图制作学习路径规划
论文阅读·人工智能·学习·算法·流程图·论文笔记
源码之家2 小时前
基于Python房价预测系统 数据分析 Flask框架 爬虫 随机森林回归预测模型、链家二手房 可视化大屏 大数据毕业设计(附源码)✅
大数据·爬虫·python·随机森林·数据分析·spark·flask
CoderYanger2 小时前
B.双指针——3194. 最小元素和最大元素的最小平均值
java·开发语言·数据结构·算法·leetcode·职场和发展·1024程序员节