Datawhale 夏令营 Task1:跑通YOLO方案baseline!

YOLO数据处理

一.YOLO数据格式

YOLO数据格式为 <class> <x_center> <y_center> <width> <height>

二.制作数据集

1.新建文件夹及配置文件

python 复制代码
if not os.path.exists('yolo-dataset/'):
    os.mkdir('yolo-dataset/')
if not os.path.exists('yolo-dataset/train'):
    os.mkdir('yolo-dataset/train')
if not os.path.exists('yolo-dataset/val'):
    os.mkdir('yolo-dataset/val')

dir_path = os.path.abspath('./') + '/'

# 需要按照你的修改path
with open('yolo-dataset/yolo.yaml', 'w', encoding='utf-8') as up:
    up.write(f'''
path: {dir_path}/yolo-dataset/
train: train/
val: val/

names:
    0: 非机动车违停
    1: 机动车违停
    2: 垃圾桶满溢
    3: 违法经营
''')

2.数据转化

(1) 原始数据集

视频数据为mp4格式,标注文件为json格式,每个视频对应一个json文件。

json文件的内容是每帧检测到的违规行为,包括以下字段:

  • frame_id:违规行为出现的帧编号
  • event_id:违规行为ID
  • category:违规行为类别
  • bbox:检测到的违规行为矩形框的坐标,[xmin,ymin,xmax,ymax]形式

标注示例如下:

json 复制代码
[
  {
   "frame_id": 20,
   "event_id": 1,
   "category": "机动车违停",
   "bbox": [200, 300, 280, 400]
  },
  {
   "frame_id": 20,
   "event_id": 2,
   "category": "机动车违停",
   "bbox": [600, 500, 720, 560]
  },
  {
   "frame_id": 30,
   "event_id": 3,
   "category": "垃圾桶满溢",
   "bbox": [400, 500, 600, 660]
  }
 ]

(2) 数据格式转化

遍历读取每个视频的每一帧,保存视频的每一个帧及根据帧的id找出对应的标签写入对应的txt文件。

json文件标注[xmin,ymin,xmax,ymax],而YOLO所需格式为【x_center,y_center,width,height】格式,因此在写入txt文件前需要进行格式转化

python 复制代码
train_annos = glob.glob('训练集(有标注第一批)/标注/*.json')
train_videos = glob.glob('训练集(有标注第一批)/视频/*.mp4')
train_annos.sort(); train_videos.sort()

category_labels = ["非机动车违停", "机动车违停", "垃圾桶满溢", "违法经营"]

for anno_path, video_path in zip(train_annos[:5], train_videos[:5]):
    print(video_path)
    anno_df = pd.read_json(anno_path)
    cap = cv2.VideoCapture(video_path)
    frame_idx = 0 
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        img_height, img_width = frame.shape[:2]
        
        frame_anno = anno_df[anno_df['frame_id'] == frame_idx]
        cv2.imwrite('./yolo-dataset/train/' + anno_path.split('/')[-1][:-5] + '_' + str(frame_idx) + '.jpg', frame)

        if len(frame_anno) != 0:
            with open('./yolo-dataset/train/' + anno_path.split('/')[-1][:-5] + '_' + str(frame_idx) + '.txt', 'w') as up:
                for category, bbox in zip(frame_anno['category'].values, frame_anno['bbox'].values):
                    category_idx = category_labels.index(category)
                    
                    x_min, y_min, x_max, y_max = bbox
                    x_center = (x_min + x_max) / 2 / img_width
                    y_center = (y_min + y_max) / 2 / img_height
                    width = (x_max - x_min) / img_width
                    height = (y_max - y_min) / img_height

                    if x_center > 1:
                        print(bbox)
                    up.write(f'{category_idx} {x_center} {y_center} {width} {height}\n')
        
        frame_idx += 1

三. 模型训练

python 复制代码
from ultralytics import YOLO
model = YOLO("yolov8n.pt")
results = model.train(data="yolo-dataset/yolo.yaml", epochs=2, imgsz=1080, batch=16)

四. 模型输出

根据result.boxes.xyxy 的格式为【x_min,y_min,x_max,y_max】,因此保存json时无须转换。

python 复制代码
from ultralytics import YOLO
model = YOLO("runs/detect/train/weights/best.pt")
import glob

for path in glob.glob('测试集/*.mp4'):
    submit_json = []
    results = model(path, conf=0.05, imgsz=1080,  verbose=False)
    for idx, result in enumerate(results):
        boxes = result.boxes  # Boxes object for bounding box outputs
        masks = result.masks  # Masks object for segmentation masks outputs
        keypoints = result.keypoints  # Keypoints object for pose outputs
        probs = result.probs  # Probs object for classification outputs
        obb = result.obb  # Oriented boxes object for OBB outputs

        if len(boxes.cls) == 0:
            continue
        
        xyxy = boxes.xyxy.data.cpu().numpy().round()
        cls = boxes.cls.data.cpu().numpy().round()
        conf = boxes.conf.data.cpu().numpy()
        for i, (ci, xy, confi) in enumerate(zip(cls, xyxy, conf)):
            submit_json.append(
                {
                    'frame_id': idx,
                    'event_id': i+1,
                    'category': category_labels[int(ci)],
                    'bbox': list([int(x) for x in xy]),
                    "confidence": float(confi)
                }
            )

    with open('./result/' + path.split('/')[-1][:-4] + '.json', 'w', encoding='utf-8') as up:
        json.dump(submit_json, up, indent=4, ensure_ascii=False)
相关推荐
傻啦嘿哟39 分钟前
如何使用 Python 开发一个简单的文本数据转换为 Excel 工具
开发语言·python·excel
B站计算机毕业设计超人1 小时前
计算机毕业设计SparkStreaming+Kafka旅游推荐系统 旅游景点客流量预测 旅游可视化 旅游大数据 Hive数据仓库 机器学习 深度学习
大数据·数据仓库·hadoop·python·kafka·课程设计·数据可视化
GOTXX1 小时前
基于Opencv的图像处理软件
图像处理·人工智能·深度学习·opencv·卷积神经网络
IT古董1 小时前
【人工智能】Python在机器学习与人工智能中的应用
开发语言·人工智能·python·机器学习
湫ccc2 小时前
《Python基础》之pip换国内镜像源
开发语言·python·pip
hakesashou2 小时前
Python中常用的函数介绍
java·网络·python
菜鸟的人工智能之路2 小时前
极坐标气泡图:医学数据分析的可视化新视角
python·数据分析·健康医疗
菜鸟学Python2 小时前
Python 数据分析核心库大全!
开发语言·python·数据挖掘·数据分析
小白不太白9502 小时前
设计模式之 责任链模式
python·设计模式·责任链模式
喜欢猪猪2 小时前
Django:从入门到精通
后端·python·django