一、什么是知识蒸馏
1.1 核心思想
大模型(Teacher)的"软标签"包含类间关系等丰富信息,小模型(Student)通过学习这些软标签,可以获得接近大模型的效果:
复制代码
硬标签 (Hard Label):
猫: 1.0, 狗: 0.0, 鸟: 0.0
→ 只知道"是猫"
软标签 (Soft Label, Teacher 输出):
猫: 0.85, 狗: 0.10, 鸟: 0.05
→ 知道"像猫,但和狗有点像"
蒸馏: Student 学习软标签,获得 Teacher 的"知识"
1.2 蒸馏损失函数
复制代码
总损失 = α × 蒸馏损失 + (1-α) × 硬标签损失
蒸馏损失 = KL散度(Softmax(Teacher/T), Softmax(Student/T)) × T²
硬标签损失 = CrossEntropy(Student, HardLabel)
其中:
T: 温度参数 (Temperature),控制软标签的平滑程度
α: 蒸馏损失权重,通常 0.5-0.9
T²: 梯度缩放因子
二、温度参数的作用
2.1 温度对软标签的影响
python
复制代码
import torch
import torch.nn.functional as F
def visualize_temperature_effect(logits, temperatures=[1, 2, 5, 10, 20]):
"""可视化温度对软标签的影响"""
print(f"原始 logits: {logits}")
print()
for T in temperatures:
soft_labels = F.softmax(logits / T, dim=0)
print(f"T={T:2d}: {soft_labels.detach().numpy()}")
# 示例
logits = torch.tensor([5.0, 3.0, 1.0, 0.5])
visualize_temperature_effect(logits)
# 输出:
# T= 1: [0.8429 0.1142 0.0155 0.0085] ← 非常尖锐
# T= 2: [0.6895 0.1841 0.0686 0.0578] ← 稍微平滑
# T= 5: [0.4747 0.2319 0.1470 0.1464] ← 更平滑
# T=10: [0.3775 0.2447 0.1889 0.1889] ← 接近均匀
# T=20: [0.3183 0.2527 0.2145 0.2145] ← 非常平滑
2.2 温度选择建议
| 温度 |
效果 |
适用场景 |
| T=1 |
软标签接近硬标签 |
Teacher 置信度高 |
| T=3-5 |
适度平滑 |
通用场景,推荐 |
| T=10-20 |
高度平滑 |
类间关系复杂 |
| T→∞ |
接近均匀分布 |
几乎不用 |
三、昇腾蒸馏训练实现
3.1 Teacher-Student 架构
python
复制代码
import torch
import torch.nn as nn
import torch.npu
class DistillationTrainer:
def __init__(self, teacher, student, temperature=5.0, alpha=0.7):
self.teacher = teacher
self.student = student
self.temperature = temperature
self.alpha = alpha
# Teacher 冻结
self.teacher.eval()
for param in self.teacher.parameters():
param.requires_grad = False
# Student 可训练
self.student.train()
def train_step(self, input_data, labels):
"""单步蒸馏训练"""
# Teacher 前向 (不计算梯度)
with torch.no_grad():
teacher_logits = self.teacher(input_data)
# Student 前向
student_logits = self.student(input_data)
# 蒸馏损失
soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)
soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
distillation_loss = F.kl_div(
soft_student, soft_teacher,
reduction='batchmean'
) * (self.temperature ** 2)
# 硬标签损失
hard_loss = F.cross_entropy(student_logits, labels)
# 总损失
total_loss = self.alpha * distillation_loss + (1 - self.alpha) * hard_loss
return total_loss, {
'distillation_loss': distillation_loss.item(),
'hard_loss': hard_loss.item(),
'total_loss': total_loss.item()
}
def train_epoch(self, dataloader, optimizer, epoch):
"""训练一个 epoch"""
self.student.train()
total_loss = 0
for batch_idx, (data, target) in enumerate(dataloader):
data, target = data.npu(), target.npu()
optimizer.zero_grad()
loss, metrics = self.train_step(data, target)
loss.backward()
optimizer.step()
total_loss += loss.item()
if batch_idx % 100 == 0:
print(f"Epoch {epoch}, Batch {batch_idx}: "
f"Loss={metrics['total_loss']:.4f}, "
f"Distill={metrics['distillation_loss']:.4f}, "
f"Hard={metrics['hard_loss']:.4f}")
return total_loss / len(dataloader)
3.2 完整训练流程
python
复制代码
def run_distillation():
"""执行蒸馏训练"""
# 1. 加载 Teacher
teacher = load_pretrained_model("resnet50_teacher.pt")
teacher = teacher.npu()
# 2. 创建 Student
student = ResNet18(num_classes=100)
student = student.npu()
# 3. 初始化蒸馏训练器
trainer = DistillationTrainer(
teacher=teacher,
student=student,
temperature=5.0,
alpha=0.7
)
# 4. 数据集
train_dataset = CIFAR100(root='./data', train=True, download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
# 5. 优化器
optimizer = torch.optim.SGD(
student.parameters(),
lr=0.01,
momentum=0.9,
weight_decay=5e-4
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
# 6. 训练循环
for epoch in range(50):
avg_loss = trainer.train_epoch(train_loader, optimizer, epoch)
scheduler.step()
print(f"Epoch {epoch+1}/50, Avg Loss: {avg_loss:.4f}")
# 评估
if (epoch + 1) % 10 == 0:
accuracy = evaluate(student, test_loader)
print(f"Student Accuracy: {accuracy:.2f}%")
# 7. 保存 Student
torch.save(student.state_dict(), "resnet18_distilled.pt")
print("蒸馏完成,模型已保存")
run_distillation()
四、多 Teacher 蒸馏
4.1 多 Teacher 架构
python
复制代码
class MultiTeacherDistillation:
def __init__(self, teachers, student, temperature=5.0, alpha=0.7):
self.teachers = teachers # 多个 Teacher
self.student = student
self.temperature = temperature
self.alpha = alpha
# 冻结所有 Teacher
for teacher in self.teachers:
teacher.eval()
for param in teacher.parameters():
param.requires_grad = False
def train_step(self, input_data, labels):
"""多 Teacher 蒸馏"""
# 收集所有 Teacher 的输出
teacher_logits_list = []
with torch.no_grad():
for teacher in self.teachers:
logits = teacher(input_data)
teacher_logits_list.append(logits)
# 聚合 Teacher 输出 (平均)
teacher_logits = torch.stack(teacher_logits_list).mean(dim=0)
# Student 前向
student_logits = self.student(input_data)
# 蒸馏损失
soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)
soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
distillation_loss = F.kl_div(
soft_student, soft_teacher,
reduction='batchmean'
) * (self.temperature ** 2)
# 硬标签损失
hard_loss = F.cross_entropy(student_logits, labels)
# 总损失
total_loss = self.alpha * distillation_loss + (1 - self.alpha) * hard_loss
return total_loss
# 使用示例
teachers = [
load_model("resnet50_v1.pt"),
load_model("resnet50_v2.pt"),
load_model("resnet101.pt"),
]
student = ResNet18(num_classes=100)
trainer = MultiTeacherDistillation(teachers, student)
五、蒸馏效果评估
5.1 评估指标
python
复制代码
def evaluate_distillation(teacher, student, test_loader):
"""评估蒸馏效果"""
teacher_acc = evaluate_accuracy(teacher, test_loader)
student_acc = evaluate_accuracy(student, test_loader)
# 蒸馏效率
efficiency = student_acc / teacher_acc * 100
# 模型大小对比
teacher_params = sum(p.numel() for p in teacher.parameters())
student_params = sum(p.numel() for p in student.parameters())
compression_ratio = teacher_params / student_params
print(f"Teacher 准确率: {teacher_acc:.2f}%")
print(f"Student 准确率: {student_acc:.2f}%")
print(f"蒸馏效率: {efficiency:.1f}%")
print(f"参数压缩比: {compression_ratio:.1f}x")
print(f"Teacher 参数量: {teacher_params:,}")
print(f"Student 参数量: {student_params:,}")
return {
'teacher_acc': teacher_acc,
'student_acc': student_acc,
'efficiency': efficiency,
'compression_ratio': compression_ratio
}
# 示例输出:
# Teacher 准确率: 95.20%
# Student 准确率: 91.80%
# 蒸馏效率: 96.4%
# 参数压缩比: 25.5x
# Teacher 参数量: 25,557,032
# Student 参数量: 1,002,812
5.2 与直接训练对比
python
复制代码
def compare_training_methods():
"""对比不同训练方法"""
results = {
'direct_train': {'accuracy': 89.5, 'params': '1M'},
'distillation': {'accuracy': 91.8, 'params': '1M'},
'pretrain_finetune': {'accuracy': 92.1, 'params': '1M'},
}
print("训练方法对比:")
print(f"{'方法':<20} {'准确率':<10} {'参数量':<10}")
print("-" * 40)
for method, result in results.items():
print(f"{method:<20} {result['accuracy']:<10} {result['params']:<10}")
# 输出:
# 训练方法对比:
# 方法 准确率 参数量
# ---------------------------------------
# direct_train 89.5 1M
# distillation 91.8 1M
# pretrain_finetune 92.1 1M
compare_training_methods()
六、常见问题
| 问题 |
原因 |
解决方案 |
| Student 效果差 |
温度不合适 |
调整温度参数 |
| 蒸馏损失不下降 |
Teacher 太强 |
降低 Teacher 复杂度 |
| 训练不稳定 |
学习率太高 |
降低学习率 |
| 过拟合 |
数据量不足 |
增加数据增强 |
| Teacher 和 Student 差异大 |
架构差异大 |
选择相近的 Teacher |
相关仓库