本文在已有的 YOLOv8 Docker 训练环境上,实现一个完整的"模型蒸馏(Knowledge Distillation)"训练流程。
适合人群:
- 想学习 AI 模型压缩
- 想让小模型速度更快
- 想部署到边缘设备
- 已经能正常训练 YOLOv8
- 希望进一步学习蒸馏、量化、剪枝等优化技术
本文使用:
- Ubuntu 22.04
- Docker
- GTX1660S(6GB)
- PyTorch
- YOLOv8
- coco128 数据集
- Ultralytics YOLOv8 框架
一、什么是模型蒸馏?
模型蒸馏(Knowledge Distillation)本质上是:
用一个"大模型"指导"小模型"学习。
通常:
| 角色 | 作用 |
|---|---|
| Teacher(教师模型) | 精度高,但体积大 |
| Student(学生模型) | 更小、更快 |
例如:
| 模型 | 参数量 | 推理速度 |
|---|---|---|
| YOLOv8m | 较大 | 较慢 |
| YOLOv8n | 很小 | 很快 |
我们可以:
- 先训练一个 YOLOv8m
- 再让 YOLOv8n 学习 YOLOv8m 的"知识"
这样:
- 小模型精度会提升
- 推理速度仍然很快
- 更适合边缘部署
二、为什么需要蒸馏?
很多时候:
- GPU 显存不足
- 边缘设备算力差
- Jetson / 工控机 / IPC 部署资源有限
例如 GTX1660S:
- 只有 6GB 显存
- 跑 yolov8x 很吃力
- batch size 容易 OOM
社区里也有很多 YOLOv8 显存不足问题。
因此实际生产中常见方案:
大模型训练
↓
模型蒸馏
↓
得到高精度小模型
↓
TensorRT / ONNX 部署
三、实验环境
1. 主机环境
Ubuntu 22.04
Docker
NVIDIA Driver
CUDA
GTX1660S
2. Docker 镜像
nvcr.io/nvidia/pytorch:24.12-py3
该镜像:
- 已内置 CUDA
- 已内置 PyTorch
- 非常适合训练 YOLOv8
PyTorch 本身对 GPU 加速支持非常完善。
3. 启动容器
docker run -it --gpus all \
--shm-size=16g \
--name yolov8-distill \
-v /home/workspace:/workspace \
nvcr.io/nvidia/pytorch:24.12-py3
参数说明:
| 参数 | 作用 |
|---|---|
| --gpus all | 使用 GPU |
| --shm-size | 增加共享内存 |
| -v | 挂载代码目录 |
四、安装 YOLOv8(已安装可跳过)
1. 克隆源码
cd /workspace
git clone https://github.com/ultralytics/ultralytics.git
cd ultralytics
官方项目:Ultralytics YOLOv8 Github
2. 安装依赖
pip install -e .
检查:
yolo version
五、准备数据集
这里使用官方 coco128。
yolo detect train \
model=yolov8n.pt \
data=coco128.yaml \
epochs=1
首次运行会自动下载数据集。
六、模型蒸馏原理
普通训练:
图片 → Student → Loss → 更新参数
蒸馏训练:
图片 → Teacher
↓
图片 → Student
↓
让 Student 模仿 Teacher 输出
蒸馏核心思想:
Student 不仅学习:
- 标签(Ground Truth)
还学习:
- Teacher 的特征
- Teacher 的输出分布
- Teacher 的分类概率
七、先训练 Teacher 模型
蒸馏之前:
必须先有教师模型。
这里使用:
Teacher : YOLOv8m
Student : YOLOv8n
1.获取Teacher模型
获取Teacher模型有两种方式,一种是直接使用官网的yolov8m.pt,一种是自己训练yolov8m.pt作为Teacher模型,这里为例演示方便直接使用官网上的yolov8m.pt模型。
八、YOLOv8 蒸馏实现
YOLOv8 官方没有直接提供蒸馏接口。
因此我们自己实现。
1.创建蒸馏训练脚本
创建:
touch distill_train.py
2.完整蒸馏代码
from ultralytics import YOLO
import torch
import torch.nn as nn
import torch.nn.functional as F
# Teacher
teacher = YOLO("runs/detect/train/weights/best.pt")
# Student
student = YOLO("yolov8n.pt")
teacher_model = teacher.model
student_model = student.model
teacher_model.eval()
# 冻结 Teacher
for p in teacher_model.parameters():
p.requires_grad = False
optimizer = torch.optim.Adam(
student_model.parameters(),
lr=1e-4
)
temperature = 4.0
alpha = 0.7
device = "cuda"
teacher_model.to(device)
student_model.to(device)
for epoch in range(10):
imgs = torch.randn(4, 3, 640, 640).to(device)
with torch.no_grad():
teacher_out = teacher_model(imgs)
student_out = student_model(imgs)
# 蒸馏 Loss
kd_loss = F.mse_loss(
student_out[0],
teacher_out[0]
)
# 普通 Loss(示例)
gt_loss = student_out[0].mean()
loss = alpha * kd_loss + (1 - alpha) * gt_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"epoch={epoch} loss={loss.item()}")
九、蒸馏中的关键参数
1. Temperature(温度)
用于软化概率分布。
常见:
temperature = 2~8
越大:
- Teacher 输出越平滑
- Student 更容易学习
2. Alpha
控制:
- Teacher Loss
- Ground Truth Loss
占比。
loss = alpha * kd_loss + (1-alpha) * gt_loss
常见:
alpha = 0.5~0.9
十、蒸馏效果
典型情况:
| 模型 | mAP | FPS |
|---|---|---|
| YOLOv8m | 高 | 较低 |
| YOLOv8n | 较低 | 高 |
| 蒸馏后的 YOLOv8n | 接近 mAP | 高 FPS |
即:
用小模型获得接近大模型的效果。
十一、为什么蒸馏有效?
因为 Teacher 学到的不仅是标签。
还包括:
- 类别间关系
- 特征表达
- 深层语义信息
例如:
Teacher 可能知道:
狗 ≈ 狼
汽车 ≈ 卡车
这些信息:
普通标签学不到。
十二、进阶蒸馏方式
除了输出蒸馏:
还可以:
1. Feature Distillation
让 Student 学习中间特征图。
Backbone Feature
Neck Feature
Head Feature
2. Attention Distillation
学习注意力图。
3. Logit Distillation
最经典方式。
学习分类 logits。
十三、蒸馏后的部署优化
蒸馏完成后:
推荐继续:
蒸馏
→ ONNX
→ TensorRT
→ FP16
→ INT8
最终:
- 更小
- 更快
- 更低功耗
十四、YOLOv8 为什么适合蒸馏?
YOLOv8 本身:
- Anchor-Free
- 网络结构清晰
- Backbone/Neck 解耦明显
因此非常适合:
- 剪枝
- 量化
- 蒸馏
YOLOv8 在检测任务上具有较好的速度与精度平衡。
十五、Docker 训练的优势
使用 Docker 的好处:
| 优势 | 说明 |
|---|---|
| 环境隔离 | 不污染宿主机 |
| CUDA统一 | 避免版本冲突 |
| 快速迁移 | 可复制到其他服务器 |
| 易部署 | CI/CD方便 |
很多 AI 项目已经大量采用 Docker 化训练环境。
十六、GTX1660S 训练建议
1660S 只有 6GB 显存。
建议:
| 参数 | 推荐 |
|---|---|
| imgsz | 640 |
| batch | 4~8 |
| model | yolov8n/s |
| AMP | 开启 |
| cache | 禁用 |
后续进阶方向
你后面可以继续学习:
- Feature Map 蒸馏
- YOLOv8 剪枝
- TensorRT INT8 校准
- PTQ / QAT 量化
- DeepSparse
- ONNX Runtime
- NCNN 部署
- Jetson 边缘部署