CANN 模型蒸馏实战:大模型知识迁移到小模型


一、什么是知识蒸馏

1.1 核心思想

大模型(Teacher)的"软标签"包含类间关系等丰富信息,小模型(Student)通过学习这些软标签,可以获得接近大模型的效果:

复制代码
硬标签 (Hard Label):
  猫: 1.0, 狗: 0.0, 鸟: 0.0
  → 只知道"是猫"

软标签 (Soft Label, Teacher 输出):
  猫: 0.85, 狗: 0.10, 鸟: 0.05
  → 知道"像猫,但和狗有点像"

蒸馏: Student 学习软标签,获得 Teacher 的"知识"

1.2 蒸馏损失函数

复制代码
总损失 = α × 蒸馏损失 + (1-α) × 硬标签损失

蒸馏损失 = KL散度(Softmax(Teacher/T), Softmax(Student/T)) × T²
硬标签损失 = CrossEntropy(Student, HardLabel)

其中:
  T: 温度参数 (Temperature),控制软标签的平滑程度
  α: 蒸馏损失权重,通常 0.5-0.9
  T²: 梯度缩放因子

二、温度参数的作用

2.1 温度对软标签的影响

python 复制代码
import torch
import torch.nn.functional as F

def visualize_temperature_effect(logits, temperatures=[1, 2, 5, 10, 20]):
    """可视化温度对软标签的影响"""
    
    print(f"原始 logits: {logits}")
    print()
    
    for T in temperatures:
        soft_labels = F.softmax(logits / T, dim=0)
        print(f"T={T:2d}: {soft_labels.detach().numpy()}")

# 示例
logits = torch.tensor([5.0, 3.0, 1.0, 0.5])
visualize_temperature_effect(logits)

# 输出:
# T= 1: [0.8429 0.1142 0.0155 0.0085]  ← 非常尖锐
# T= 2: [0.6895 0.1841 0.0686 0.0578]  ← 稍微平滑
# T= 5: [0.4747 0.2319 0.1470 0.1464]  ← 更平滑
# T=10: [0.3775 0.2447 0.1889 0.1889]  ← 接近均匀
# T=20: [0.3183 0.2527 0.2145 0.2145]  ← 非常平滑

2.2 温度选择建议

温度 效果 适用场景
T=1 软标签接近硬标签 Teacher 置信度高
T=3-5 适度平滑 通用场景,推荐
T=10-20 高度平滑 类间关系复杂
T→∞ 接近均匀分布 几乎不用

三、昇腾蒸馏训练实现

3.1 Teacher-Student 架构

python 复制代码
import torch
import torch.nn as nn
import torch.npu

class DistillationTrainer:
    def __init__(self, teacher, student, temperature=5.0, alpha=0.7):
        self.teacher = teacher
        self.student = student
        self.temperature = temperature
        self.alpha = alpha
        
        # Teacher 冻结
        self.teacher.eval()
        for param in self.teacher.parameters():
            param.requires_grad = False
        
        # Student 可训练
        self.student.train()
    
    def train_step(self, input_data, labels):
        """单步蒸馏训练"""
        
        # Teacher 前向 (不计算梯度)
        with torch.no_grad():
            teacher_logits = self.teacher(input_data)
        
        # Student 前向
        student_logits = self.student(input_data)
        
        # 蒸馏损失
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)
        soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
        
        distillation_loss = F.kl_div(
            soft_student, soft_teacher,
            reduction='batchmean'
        ) * (self.temperature ** 2)
        
        # 硬标签损失
        hard_loss = F.cross_entropy(student_logits, labels)
        
        # 总损失
        total_loss = self.alpha * distillation_loss + (1 - self.alpha) * hard_loss
        
        return total_loss, {
            'distillation_loss': distillation_loss.item(),
            'hard_loss': hard_loss.item(),
            'total_loss': total_loss.item()
        }
    
    def train_epoch(self, dataloader, optimizer, epoch):
        """训练一个 epoch"""
        self.student.train()
        
        total_loss = 0
        for batch_idx, (data, target) in enumerate(dataloader):
            data, target = data.npu(), target.npu()
            
            optimizer.zero_grad()
            loss, metrics = self.train_step(data, target)
            loss.backward()
            
            optimizer.step()
            total_loss += loss.item()
            
            if batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}: "
                      f"Loss={metrics['total_loss']:.4f}, "
                      f"Distill={metrics['distillation_loss']:.4f}, "
                      f"Hard={metrics['hard_loss']:.4f}")
        
        return total_loss / len(dataloader)

3.2 完整训练流程

python 复制代码
def run_distillation():
    """执行蒸馏训练"""
    
    # 1. 加载 Teacher
    teacher = load_pretrained_model("resnet50_teacher.pt")
    teacher = teacher.npu()
    
    # 2. 创建 Student
    student = ResNet18(num_classes=100)
    student = student.npu()
    
    # 3. 初始化蒸馏训练器
    trainer = DistillationTrainer(
        teacher=teacher,
        student=student,
        temperature=5.0,
        alpha=0.7
    )
    
    # 4. 数据集
    train_dataset = CIFAR100(root='./data', train=True, download=True)
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    
    # 5. 优化器
    optimizer = torch.optim.SGD(
        student.parameters(),
        lr=0.01,
        momentum=0.9,
        weight_decay=5e-4
    )
    
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
    
    # 6. 训练循环
    for epoch in range(50):
        avg_loss = trainer.train_epoch(train_loader, optimizer, epoch)
        scheduler.step()
        
        print(f"Epoch {epoch+1}/50, Avg Loss: {avg_loss:.4f}")
        
        # 评估
        if (epoch + 1) % 10 == 0:
            accuracy = evaluate(student, test_loader)
            print(f"Student Accuracy: {accuracy:.2f}%")
    
    # 7. 保存 Student
    torch.save(student.state_dict(), "resnet18_distilled.pt")
    print("蒸馏完成,模型已保存")

run_distillation()

四、多 Teacher 蒸馏

4.1 多 Teacher 架构

python 复制代码
class MultiTeacherDistillation:
    def __init__(self, teachers, student, temperature=5.0, alpha=0.7):
        self.teachers = teachers  # 多个 Teacher
        self.student = student
        self.temperature = temperature
        self.alpha = alpha
        
        # 冻结所有 Teacher
        for teacher in self.teachers:
            teacher.eval()
            for param in teacher.parameters():
                param.requires_grad = False
    
    def train_step(self, input_data, labels):
        """多 Teacher 蒸馏"""
        
        # 收集所有 Teacher 的输出
        teacher_logits_list = []
        with torch.no_grad():
            for teacher in self.teachers:
                logits = teacher(input_data)
                teacher_logits_list.append(logits)
        
        # 聚合 Teacher 输出 (平均)
        teacher_logits = torch.stack(teacher_logits_list).mean(dim=0)
        
        # Student 前向
        student_logits = self.student(input_data)
        
        # 蒸馏损失
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)
        soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
        
        distillation_loss = F.kl_div(
            soft_student, soft_teacher,
            reduction='batchmean'
        ) * (self.temperature ** 2)
        
        # 硬标签损失
        hard_loss = F.cross_entropy(student_logits, labels)
        
        # 总损失
        total_loss = self.alpha * distillation_loss + (1 - self.alpha) * hard_loss
        
        return total_loss

# 使用示例
teachers = [
    load_model("resnet50_v1.pt"),
    load_model("resnet50_v2.pt"),
    load_model("resnet101.pt"),
]

student = ResNet18(num_classes=100)

trainer = MultiTeacherDistillation(teachers, student)

五、蒸馏效果评估

5.1 评估指标

python 复制代码
def evaluate_distillation(teacher, student, test_loader):
    """评估蒸馏效果"""
    
    teacher_acc = evaluate_accuracy(teacher, test_loader)
    student_acc = evaluate_accuracy(student, test_loader)
    
    # 蒸馏效率
    efficiency = student_acc / teacher_acc * 100
    
    # 模型大小对比
    teacher_params = sum(p.numel() for p in teacher.parameters())
    student_params = sum(p.numel() for p in student.parameters())
    compression_ratio = teacher_params / student_params
    
    print(f"Teacher 准确率: {teacher_acc:.2f}%")
    print(f"Student 准确率: {student_acc:.2f}%")
    print(f"蒸馏效率: {efficiency:.1f}%")
    print(f"参数压缩比: {compression_ratio:.1f}x")
    print(f"Teacher 参数量: {teacher_params:,}")
    print(f"Student 参数量: {student_params:,}")
    
    return {
        'teacher_acc': teacher_acc,
        'student_acc': student_acc,
        'efficiency': efficiency,
        'compression_ratio': compression_ratio
    }

# 示例输出:
# Teacher 准确率: 95.20%
# Student 准确率: 91.80%
# 蒸馏效率: 96.4%
# 参数压缩比: 25.5x
# Teacher 参数量: 25,557,032
# Student 参数量: 1,002,812

5.2 与直接训练对比

python 复制代码
def compare_training_methods():
    """对比不同训练方法"""
    
    results = {
        'direct_train': {'accuracy': 89.5, 'params': '1M'},
        'distillation': {'accuracy': 91.8, 'params': '1M'},
        'pretrain_finetune': {'accuracy': 92.1, 'params': '1M'},
    }
    
    print("训练方法对比:")
    print(f"{'方法':<20} {'准确率':<10} {'参数量':<10}")
    print("-" * 40)
    for method, result in results.items():
        print(f"{method:<20} {result['accuracy']:<10} {result['params']:<10}")
    
    # 输出:
    # 训练方法对比:
    # 方法                 准确率      参数量
    # ---------------------------------------
    # direct_train         89.5       1M
    # distillation         91.8       1M
    # pretrain_finetune    92.1       1M

compare_training_methods()

六、常见问题

问题 原因 解决方案
Student 效果差 温度不合适 调整温度参数
蒸馏损失不下降 Teacher 太强 降低 Teacher 复杂度
训练不稳定 学习率太高 降低学习率
过拟合 数据量不足 增加数据增强
Teacher 和 Student 差异大 架构差异大 选择相近的 Teacher

相关仓库

相关推荐
俊哥工具1 小时前
解决网速卡顿、断网、网络报错,万能网络修复工具教程
网络·python·django·计算机外设·智能路由器·pygame
WL_Aurora1 小时前
Python爬虫实战(九):百度百聘招聘数据采集
爬虫·python·百度
lili00121 小时前
Gemini 3.5发布后的AI格局:谷歌重新定义行业标准
java·人工智能·python·ai编程
JunLa1 小时前
Java语法糖
java·python·哈希算法
财经资讯数据_灵砚智能1 小时前
基于全球经济类多源新闻的NLP情感分析与数据可视化(夜间-次晨)2026年5月21日
大数据·人工智能·python·信息可视化·自然语言处理
水木流年追梦1 小时前
大模型入门-RL基础
开发语言·python·算法·leetcode·正则表达式
Cthy_hy1 小时前
基于首届中国互联网数据挖掘竞赛数据集的行为相似网络分析
python·信息可视化·数据挖掘
AI玫瑰助手1 小时前
Python运算符:逻辑运算符(and/or/not)的短路特性
开发语言·python·信息可视化
是梦终空1 小时前
计算机源码274—基于深度学习的中医舌象智能识别与健康管理系统(源代码+数据库+12000字论文)
人工智能·python·深度学习·opencv·django·vue·springboot