YOLOv8猫狗检测:从SwanLab可视化训练到Gradio Demo网站

基于YOLO模型在自定义数据 上做训练,实现对特定目标的识别和检测,是CV领域非常经典的任务,也是AI项目落地最热门的方向之一。

这篇文章我将带大家使用Ultralytics、SwanLab、Gradio 这两个开源工具,完成从数据集准备、代码编写、可视化训练推理Demo的全过程。

观察了一下,中文互联网上似乎很少有针对自定义数据的,能直接跑起来的YOLO训练代码和教程,所以也希望这篇文章可以帮到在科研一线的大家~

1.环境安装

我们需要安装以下这4个Python库:

txt 复制代码
ultralytics
swanlab>=0.3.6
gradio

一键安装命令:

pip install ultralytics swanlab gradio

他们的作用分别是:

  1. Ultralytics:YOLO官方团队推出的CV训练与推理框架,不仅支持目标检测任务,还支持分割、姿态识别、分类等更多任务。本项目用Ultralytics作为训练框架。
  1. swanlab : 一个深度学习实验管理与训练可视化工具,由西安电子科技大学团队打造,官网, 融合了Weights & Biases与Tensorboard的特点,可以记录整个实验的超参数、指标、训练环境、Python版本等,并可视化图表,帮助你分析训练的表现。本项目用swanlab主要用于记录指标和可视化。
  1. gradio: HuggingFace出品的推理Demo构建工具,是深度学习最流行的Demo框架之一,可以用python代码轻松搭建网页。本项目用Gradio作为推理Demo框架。

整个项目的目录结构如下:

2.准备数据集

数据集这里我使用的是Kaggle上的Dog and Cat Detection数据集,包含3686张带标注的猫狗图像。

这里除了下载数据集以外,还需要对格式做处理,所以我把做好的数据集放到百度云(提取码: f238)里了,推荐大家直接下载。

下面重点介绍一下怎么让你的自定义数据集适配Ultralytics,掌握之后,几乎所有自定义数据集的处理就都学会了。

首先,Ultralytics推荐的数据集结构是这样的:

txt 复制代码
datasets
├── images
│   ├── train
│      ├── 00001.jpg
│      ├── ...
│   ├── val
│   ├── test
├── labels
│   ├── train
│      ├── 00001.txt
│      ├── ...
│   ├── val
│   ├── test
├── data.yaml

这里面是一个数据集文件夹,包含imageslabels两个文件夹和一个data.yaml配置文件:

  • images文件夹放图像,labels文件夹放标注文件,图像和标注文件的名称要一一对应
  • imageslabels文件夹下分别放train、val、test三个子文件夹,作为训练集、验证集和测试集
  • data.yaml的格式如下:
yaml 复制代码
path:  path/to/datasets # 这里填写你数据集所在的绝对路径
train: images/train
val: images/train
test: images/test

# 标签和对应的类别
names:
  0: cat
  1: dog
  • 标注文件的格式如下:
txt 复制代码
0 0.618 0.127 0.299 0.226
1 0.491 0.333 0.506 0.545

每一行的第一个数字代表标签,后续的四个数字是标注框的<x1> <y1> <x2> <y2>相对于图像shape的归一化值(或者说比例)。

标签文件中出现多行则代表着图像中有多个检测到的目标。

由于Kaggle的数据集并不是按YOLO的标注格式来的,所以需要写脚本进行处理,推荐直接从百度云(提取码: f238)下载。

这里我们只做train和val。最后处理好的数据集格式如下:

3.开始训练-train.py

在准备好数据集之后,最艰难的步骤就结束了------训练代码非常的简短,如下所示:

python 复制代码
from ultralytics import YOLO
from swanlab.integration.ultralytics import add_swanlab_callback
import swanlab

def main():
    swanlab.init(project="Cats_Dogs_Detection", experiment_name="YOLOv8n",)
    model = YOLO("yolov8n.pt")
    add_swanlab_callback(model)
    # 将下面的路径替换成你的绝对路径
    model.train(data="path/to/cats_dogs_dataset/data.yaml", epochs=5, imgsz=320, batch=32)

if __name__ == "__main__":
    main()

请将model.train中data参数的路径替换成你的data.yaml的绝对路径。

这里我们使用了yolov8n模型(6MB左右)训练5个epoch,batchsize为32,并使用swanlab进行训练可视化。

在运行训练脚本的时候,如果你是第一次使用swanlab,那么需要去swanlab官网注册一个账号,然后在用户设置界面复制API Key,然后在命令行输入swanlab login,粘贴API Key即可完成登录。

训练过程(可访问猫狗检测-SwanLab查看):

我这里做了个两个实验,分别使用yolov8n和yolov8s两个模型训练100个epoch,可以看到最终的结果,指标非常的不错:

训练好的模型会存放在Ultralytics自动生成的runs文件夹下。

4.推理代码-predict.py

python 复制代码
from ultralytics import YOLO

# 载入训练好的模型
model = YOLO("path/to/model.pt")
# 推理多张图像
results = model(["img1.jpg", "img2,jpg"], device="cpu")

# 保存结果图
for iter, result in enumerate(results):
    result.save(filename=f"result{iter}.jpg")

5.Gradio推理Demo-app.py

python 复制代码
import gradio as gr
from ultralytics import YOLO
from PIL import Image

# 加载预训练的 YOLO 模型
model = YOLO('path/to/best.pt')

def predict_image(image, conf_threshold, iou_threshold):
    # 使用模型进行推理
    results = model.predict(
        source=image, 
        conf=conf_threshold,
        iou=iou_threshold,
        show_labels=True,
        show_conf=True,
        imgsz=640,)
    
    # 提取结果
    for r in results:
        im_array = r.plot()
        im = Image.fromarray(im_array[..., ::-1])
    
    return im

# 定义 Gradio 接口
demo = gr.Interface(
    fn=predict_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence threshold"),
        gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU threshold"),
    ],
    outputs=gr.Image(type="pil", label="Result"),
    title="猫狗检测Demo",
    description="传一张带有猫狗的图像来进行推理。",
)

# 启动 Gradio 应用
if __name__ == "__main__":
    demo.launch()

运行程序后,会出现以下输出:

点开链接,出现猫狗检测的Demo网页:

用猫和狗的图片试试:

效果很完美!

至此,我们完成了用Ultralytics、SwanLab、Gradio三个开源工具训练1个猫狗检测模型的全部过程,更多想了解的可以参考相关链接或评论此文章。

如果有帮助,请点个赞和收藏吧~

5. 相关链接

相关推荐
XianxinMao8 分钟前
2024大模型双向突破:MoE架构创新与小模型崛起
人工智能·架构
Francek Chen19 分钟前
【深度学习基础】多层感知机 | 模型选择、欠拟合和过拟合
人工智能·pytorch·深度学习·神经网络·多层感知机·过拟合
pchmi1 小时前
C# OpenCV机器视觉:红外体温检测
人工智能·数码相机·opencv·计算机视觉·c#·机器视觉·opencvsharp
认知作战壳吉桔1 小时前
中国认知作战研究中心:从认知战角度分析2007年iPhone发布
大数据·人工智能·新质生产力·认知战·认知战研究中心
软件公司.乐学2 小时前
安全生产算法一体机定制
人工智能·安全
好评笔记2 小时前
AIGC视频扩散模型新星:Video 版本的SD模型
论文阅读·深度学习·机器学习·计算机视觉·面试·aigc·transformer
kcarly2 小时前
知识图谱都有哪些常见算法
人工智能·算法·知识图谱
dddcyy2 小时前
利用现有模型处理面部视频获取特征向量(3)
人工智能·深度学习
Fxrain2 小时前
[Computer Vision]实验三:图像拼接
人工智能·计算机视觉
2301_780356702 小时前
为医院量身定制做“旧改”| 全视通物联网智慧病房
大数据·人工智能·科技·健康医疗