YOLO分类任务训练教程:从数据准备到模型部署全流程

YOLO分类任务训练教程

前言

在深度学习领域,YOLO(You Only Look Once)系列模型以其高效的目标检测能力闻名于世。然而,很多人可能不知道,YOLO同样可以出色地完成图像分类任务。从YOLOv8开始,Ultralytics官方正式集成了分类模型的训练支持,使得我们可以用统一的框架、熟悉的代码风格来完成分类任务的训练。

图像分类是计算机视觉中最基础的任务之一,它的目标是将输入图像归入预定义的类别中。与目标检测需要标注边界框不同,分类任务只需要为每张图片标注一个类别标签,数据准备的门槛更低,非常适合作为入门深度学习的起点。

本文将以YOLOv8为例,从环境搭建、数据准备、模型训练、验证评估到推理部署,手把手带你完成一个完整的YOLO分类任务训练流程。


一、环境搭建

1.1 创建虚拟环境

建议使用Conda创建独立的Python虚拟环境,避免依赖冲突:

bash 复制代码
conda create -n yolo_cls python=3.10 -y
conda activate yolo_cls

1.2 安装Ultralytics

Ultralytics是YOLO的官方库,安装非常简单:

bash 复制代码
pip install ultralytics

安装完成后,可以验证是否安装成功:

bash 复制代码
yolo version

提示:如果你有GPU,请确保已正确安装CUDA和cuDNN,Ultralytics会自动检测并使用GPU加速训练。


二、数据集准备

YOLO分类任务的数据集组织方式非常直观,按照如下目录结构摆放即可:

复制代码
dataset/
├── train/
│   ├── cat/
│   │   ├── cat_001.jpg
│   │   ├── cat_002.jpg
│   │   └── ...
│   ├── dog/
│   │   ├── dog_001.jpg
│   │   ├── dog_002.jpg
│   │   └── ...
│   └── bird/
│       ├── bird_001.jpg
│       └── ...
├── val/
│   ├── cat/
│   │   └── ...
│   ├── dog/
│   │   └── ...
│   └── bird/
│       └── ...
└── test/      # 可选
    ├── cat/
    ├── dog/
    └── bird/

核心规则:

  • 每个类别对应一个文件夹,文件夹名即为类别名
  • 至少需要trainval两个子集
  • 支持jpg、jpeg、png、bmp等常见图片格式
  • 建议训练集与验证集的比例为8:2或7:3

2.1 快速准备示例数据集

如果你手头没有合适的数据,可以使用Ultralytics自带的示例数据快速体验:

python 复制代码
from ultralytics import YOLO

# 使用内置的ImageNet子集进行测试
model = YOLO("yolov8n-cls.pt")
model.train(data="imagenet10", epochs=5, imgsz=224)

三、模型选择

YOLOv8提供了多种规模的分类模型,可以根据你的硬件条件和精度需求选择:

模型 大小(MB) ImageNet Top-1 Acc 参数量
yolov8n-cls.pt 3.7 69.0% 2.7M
yolov8s-cls.pt 11.1 74.2% 6.5M
yolov8m-cls.pt 27.3 77.5% 17.4M
yolov8l-cls.pt 56.2 79.0% 37.0M
yolov8x-cls.pt 91.4 79.6% 57.4M
  • n (Nano):适合快速验证和资源受限场景
  • s (Small):在精度与速度间取得平衡
  • m/l/x:追求更高精度,需要更多计算资源

建议 :初次尝试建议从yolov8s-cls.pt开始,它在精度和训练速度之间有较好的平衡。


四、开始训练

4.1 命令行方式

最简单的训练方式是通过命令行:

bash 复制代码
yolo classify train data=./dataset model=yolov8s-cls.pt epochs=100 imgsz=224 batch=32

4.2 Python脚本方式

对于更灵活的控制,推荐使用Python脚本:

python 复制代码
from ultralytics import YOLO

# 加载预训练模型
model = YOLO("yolov8s-cls.pt")

# 开始训练
results = model.train(
    data="./dataset",       # 数据集路径
    epochs=100,             # 训练轮数
    imgsz=224,              # 输入图像尺寸
    batch=32,               # 批大小
    lr0=0.001,              # 初始学习率
    lrf=0.01,               # 最终学习率因子
    optimizer="AdamW",      # 优化器
    device=0,               # GPU设备编号,"cpu"表示使用CPU
    workers=8,              # 数据加载线程数
    project="runs/classify",# 输出目录
    name="my_exp",          # 实验名称
    pretrained=True,        # 使用预训练权重
    augment=True,           # 数据增强
)

4.3 关键训练参数详解

参数 说明 推荐值
epochs 训练轮数 50~200
imgsz 输入图像尺寸 224或640
batch 批大小 根据GPU显存调整,16/32/64
lr0 初始学习率 0.001
optimizer 优化器 AdamW / SGD
patience 早停耐心值 20~50
augment 是否启用数据增强 True

小技巧:如果训练过程中验证集loss连续多个epoch不再下降,模型会自动触发早停(Early Stopping),无需手动停止。


五、训练过程监控

训练启动后,Ultralytics会将日志和指标保存到runs/classify/my_exp/目录下。

5.1 终端输出

训练过程中终端会实时打印每个epoch的指标:

复制代码
Epoch    GPU_mem   box_loss   cls_loss   Instances  Size
  1/100     3.2G     0.8234     0.5432         128  224:  ████████████████ 100%
  2/100     3.2G     0.7621     0.4891         128  224:  ████████████████ 100%
  ...

5.2 使用TensorBoard可视化

bash 复制代码
tensorboard --logdir runs/classify

在浏览器中打开http://localhost:6006,即可查看loss曲线、准确率曲线等。

5.3 训练结果文件

训练完成后,输出目录结构如下:

复制代码
runs/classify/my_exp/
├── weights/
│   ├── best.pt        # 最佳模型权重
│   └── last.pt        # 最后一轮权重
├── results.csv        # 训练指标记录
├── results.png        # 训练曲线图
├── confusion_matrix.png  # 混淆矩阵
└── args.yaml          # 训练参数配置

六、模型验证与评估

训练完成后,使用验证集评估模型性能:

python 复制代码
from ultralytics import YOLO

# 加载最佳模型
model = YOLO("runs/classify/my_exp/weights/best.pt")

# 在验证集上评估
metrics = model.val(data="./dataset")

# 打印关键指标
print(f"Top-1 Accuracy: {metrics.top1:.4f}")
print(f"Top-5 Accuracy: {metrics.top5:.4f}")

关键指标说明:

  • Top-1 Accuracy:预测概率最高的类别与真实标签一致的比率
  • Top-5 Accuracy:预测概率前5的类别中包含真实标签的比率
  • 混淆矩阵:直观展示各类别的分类情况及混淆关系

七、模型推理

7.1 单张图片推理

python 复制代码
from ultralytics import YOLO

model = YOLO("runs/classify/my_exp/weights/best.pt")

# 单张图片推理
results = model("test_image.jpg")

# 打印结果
for result in results:
    probs = result.probs
    print(f"预测类别: {probs.top1}")
    print(f"类别名称: {result.names[probs.top1]}")
    print(f"置信度: {probs.top1conf:.4f}")
    # Top-5预测
    top5_indices = probs.top5
    top5_confs = probs.top5conf
    for idx, conf in zip(top5_indices, top5_confs):
        print(f"  {result.names[idx]}: {conf:.4f}")

7.2 批量推理

python 复制代码
# 对整个文件夹的图片进行推理
results = model.predict(source="./test_images/", save=True)

# 或者对视频流推理
results = model.predict(source=0, show=True)  # 使用摄像头

7.3 命令行推理

bash 复制代码
yolo classify predict model=runs/classify/my_exp/weights/best.pt source=./test_images/ save=True

八、模型导出与部署

训练好的模型可以导出为多种格式,方便在不同平台上部署:

python 复制代码
from ultralytics import YOLO

model = YOLO("runs/classify/my_exp/weights/best.pt")

# 导出为ONNX格式
model.export(format="onnx")

# 导出为TensorRT格式(需要GPU环境)
# model.export(format="engine")

# 导出为CoreML格式(iOS部署)
# model.export(format="coreml")

支持的导出格式:

格式 参数 适用场景
ONNX onnx 通用部署、跨平台
TensorRT engine NVIDIA GPU高性能推理
CoreML coreml iOS/macOS设备
OpenVINO openvino Intel硬件加速
TFLite tflite 移动端/嵌入式

九、常见问题与调优建议

9.1 过拟合

现象:训练集准确率很高,验证集准确率上不去。

解决方案

  • 增加训练数据量
  • 启用更强的数据增强
  • 使用更小的模型
  • 增加Dropout
  • 减小学习率

9.2 欠拟合

现象:训练集和验证集准确率都很低。

解决方案

  • 使用更大的模型
  • 延长训练epoch数
  • 减小正则化强度
  • 增大学习率

9.3 类别不均衡

解决方案

  • 采样平衡:对少数类过采样,多数类欠采样
  • 数据增强:对少数类应用更多增强策略
  • 使用focal loss等改进损失函数

9.4 GPU显存不足

解决方案

  • 减小batch_size
  • 减小imgsz
  • 使用混合精度训练(amp=True
  • 使用梯度累积(accumulate参数)

总结

本文完整介绍了使用YOLO进行图像分类任务的训练流程:

  1. 环境搭建:一行pip命令即可安装Ultralytics
  2. 数据准备:按类别文件夹组织数据,结构清晰直观
  3. 模型选择:从Nano到XLarge,覆盖不同资源场景
  4. 训练配置:命令行或Python脚本两种方式,灵活便捷
  5. 监控评估:TensorBoard可视化 + 混淆矩阵分析
  6. 推理部署:多格式导出,覆盖云端到移动端

YOLO的分类功能继承了其检测任务的设计哲学------简洁高效。无论你是深度学习初学者还是经验丰富的开发者,都可以快速上手。只需准备好按类别组织的数据集,几行代码就能启动训练,非常适合快速验证想法和构建分类原型。

相关推荐
下午写HelloWorld1 小时前
同态加密(Homomorphic Encryption, HE)
人工智能·算法·密码学·同态加密
尚可签1 小时前
小烟改写工具:让文字表达更自然,让文档改写更高效
人工智能
小何code1 小时前
【Python零基础入门】第10篇:Python列表方法与应用实例
数据库·人工智能·python
神仙别闹1 小时前
基于 Python 实现 ANN 与 KNN 的图像分类
开发语言·python·分类
聚名网1 小时前
中文域名深化实体应用,稳步对接智能交互场景
人工智能·经验分享
决战灬1 小时前
llamaIndex BatchEvalRunner(一)
人工智能
fl1768311 小时前
智慧医疗腹腔断层CT脏器识别分割数据集labelme格式1030张4类别
人工智能
数智工坊1 小时前
周志华《Machine Learning》学习笔记--第七章--贝叶斯分类器
人工智能·笔记·神经网络·学习·机器学习
酉鬼女又兒1 小时前
零基础入门计算机网络:物理层核心知识全解——传输方式分类、编码调制原理与信道极限容量计算
网络·计算机网络·考研·职场和发展·分类·数据挖掘·php