使用YOLOv8对新的数据集进行自动标注

项目地址

https://github.com/ultralytics/ultralytics

极简运行效果

获取模型bbox的极简demo

有时我们只是想获取YOLO检测出的bbox框。

python 复制代码
import random
import cv2 as cv
from ultralytics import YOLO

# Detector used by this demo. Alternatives that were tried:
#   YOLO("yolov8m.yaml"), YOLO("yolov8m.pt")
model = YOLO("yolov8x.pt")

# Class names, indexed by the integer class id the detector reports (80 entries).
coco_label = [
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",
    "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack",
    "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange",
    "broccoli", "carrot", "hot dog", "pizza", "donut",
    "cake", "chair", "couch", "potted plant", "bed",
    "dining table", "toilet", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven",
    "toaster", "sink", "refrigerator", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
]

def generate_colors(num_colors):
    '''
    Create `num_colors` random colors.
    Args:
        num_colors: how many colors to produce.
    Returns:
        A list of (r, g, b) tuples, each channel an int in [0, 255].
    '''
    def _channel():
        return random.randint(0, 255)

    # Tuple elements are evaluated left to right, so the RNG is consumed
    # in the same r, g, b order for every color.
    return [(_channel(), _channel(), _channel()) for _ in range(num_colors)]

coco_colors = generate_colors(len(coco_label))

results = model("/media/xp/data/image/sample/person2.jpg")
for r in results:
    img = cv.imread(r.path)
    for box in r.boxes:
        # Use the named Boxes attributes instead of unpacking box.data, which
        # gains an extra track-id column when tracking is enabled.
        x1, y1, x2, y2 = (int(v) for v in box.xyxy[0].tolist())
        class_id = int(box.cls[0])
        color = coco_colors[class_id]
        cv.rectangle(img, (x1, y1), (x2, y2), color, 2)
        # putText's origin is the text baseline; keep it ~15 px down so labels of
        # boxes touching the top edge stay inside the image.
        cv.putText(img, coco_label[class_id], (x1, max(y1, 15)),
                   cv.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    cv.imshow("img", img)
    key = cv.waitKey(0)
    if key == 27:  # ESC stops the preview
        break
cv.destroyAllWindows()

用YOLOv8进行半自动标注

  • auto_label.py
python 复制代码
import cv2 as cv
from ultralytics import YOLO
import os
import random

# Class names, indexed by the integer class id the detector reports (80 entries).
coco_label = [
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",
    "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack",
    "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange",
    "broccoli", "carrot", "hot dog", "pizza", "donut",
    "cake", "chair", "couch", "potted plant", "bed",
    "dining table", "toilet", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven",
    "toaster", "sink", "refrigerator", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
]

def generate_colors(num_colors):
    '''
    Create `num_colors` random colors.
    Args:
        num_colors: how many colors to produce.
    Returns:
        A list of (r, g, b) tuples, each channel an int in [0, 255].
    '''
    def _channel():
        return random.randint(0, 255)

    # Tuple elements are evaluated left to right, so the RNG is consumed
    # in the same r, g, b order for every color.
    return [(_channel(), _channel(), _channel()) for _ in range(num_colors)]

# One random drawing color per class, indexed by class id like coco_label.
coco_colors = generate_colors(num_colors=len(coco_label))

def get_all_image_files(path, image_path_list, recursive=False):
    '''
    Collect the image files (.jpg / .png / .jpeg, any letter case) under a directory.
    Args:
        path: the directory to search.
        image_path_list: list that matching file paths are appended to (modified in place).
        recursive: whether to descend into sub-directories.
    Returns:
        The same `image_path_list` object, for convenience. Nothing is appended
        when `path` does not exist.
    '''
    if not os.path.exists(path):
        print("The path does not exist.")
        return image_path_list
    image_exts = ('.jpg', '.png', '.jpeg')
    for root, dirs, files in os.walk(path):
        # Sorting dirs in place fixes os.walk's traversal order; together with
        # sorted files this makes the collection order reproducible.
        dirs.sort()
        for file in sorted(files):
            # Lower-case first so camera-style names like "IMG_0001.JPG" are not skipped.
            if file.lower().endswith(image_exts):
                image_path_list.append(os.path.join(root, file))
        if not recursive:
            break
    return image_path_list

def predict_image(model, image_path, score_threshold=0.5):
    '''
    Run the detector on one image and keep the confident detections.
    Args:
        model: the model to predict the image (called as `model(image_path)`).
        image_path: the path of the image.
        score_threshold: detections whose confidence is below this value are dropped.

    Returns:
        bbox: a list of [x1, y1, x2, y2, class_name]. Coordinates are plain Python
        floats; class_name is looked up in `coco_label` by the predicted class id.
    '''
    bbox = []
    results = model(image_path)
    for r in results:
        for box in r.boxes:
            score = float(box.conf[0])
            if score < score_threshold:
                continue
            # Named Boxes attributes are used instead of unpacking box.data, whose
            # column count changes when tracking ids are present. tolist() turns
            # the (possibly GPU) tensor into plain floats for the caller.
            x1, y1, x2, y2 = box.xyxy[0].tolist()
            class_id = int(box.cls[0])
            bbox.append([x1, y1, x2, y2, coco_label[class_id]])
    return bbox

def autolabel(src_image_root, save_root, calss_names=None, score_threshold=0.5):
    '''
    Auto-label every image under `src_image_root` with a YOLOv8 detector.

    Each readable image is re-saved as `<save_root>/images/NNNNNN.jpg`. Its boxes
    are written to `<save_root>/labels/NNNNNN.txt`, one line per box in the form
    `class_name cx cy w h`. The box center and size are normalized by the image
    width/height. The first field is the class *name*, not a numeric id.

    Args:
        src_image_root: the root of the source images (searched recursively).
        save_root: the root to save images and labels to. Images already inside
            `save_root` are ignored, so a `save_root` nested in the source tree
            does not get its own output re-labelled on later runs.
        calss_names: class names to keep. None (the default) keeps every class;
            an explicit empty list keeps none.
        score_threshold: detections with a lower confidence are dropped.
    '''
    model = YOLO("yolov8x.pt")
    image_save_root = os.path.join(save_root, "images")
    label_save_root = os.path.join(save_root, "labels")
    os.makedirs(image_save_root, exist_ok=True)
    os.makedirs(label_save_root, exist_ok=True)

    image_path_list = []
    get_all_image_files(src_image_root, image_path_list, recursive=True)
    save_root_prefix = os.path.abspath(save_root) + os.sep
    image_path_list = [p for p in image_path_list
                       if not os.path.abspath(p).startswith(save_root_prefix)]
    # Set lookup instead of scanning the list for every box.
    wanted = None if calss_names is None else set(calss_names)

    current_num = 0
    total_num = len(image_path_list)
    for image_path in image_path_list:
        # Read first so unreadable files are skipped before paying for inference.
        img = cv.imread(image_path)
        if img is None:
            print("The image is None.")
            continue
        bbox = predict_image(model, image_path, score_threshold=score_threshold)
        img_h, img_w = img.shape[:2]

        file_name = f"{current_num:06d}"
        image_save_path = os.path.join(image_save_root, file_name + ".jpg")
        label_save_path = os.path.join(label_save_root, file_name + ".txt")
        cv.imwrite(image_save_path, img)
        with open(label_save_path, "w", encoding="utf-8") as f:
            for x1, y1, x2, y2, class_name in bbox:
                if wanted is not None and class_name not in wanted:
                    continue
                # True division: floor division would bias the center by up to half a pixel.
                cx = (x1 + x2) / 2 / img_w
                cy = (y1 + y2) / 2 / img_h
                w = (x2 - x1) / img_w
                h = (y2 - y1) / img_h
                f.write(f"{class_name} {cx} {cy} {w} {h}\n")

        current_num += 1

        print(f"{current_num}/{total_num}")
        print(f"image_save_path: {image_save_path} \n label_save_path: {label_save_path}")
        


def demo_of_autolabel():
    # Example run: keep only people, dogs and cats with confidence >= 0.5, and
    # write the resulting dataset into a "dataset" sub-folder of the source folder.
    src_dir = "/media/xp/data/image/sentinel/raw/test_data/del/"
    out_dir = "/media/xp/data/image/sentinel/raw/test_data/del/dataset"
    autolabel(src_dir, out_dir, ["person", "dog", "cat"], 0.5)

def simple_demo():
    '''
    Run the detector on one sample image and display the detections.
    A key press closes the view (ESC stops the loop); the window is destroyed on exit.
    '''
    model = YOLO("yolov8x.pt")
    results = model("/media/xp/data/image/sample/person2.jpg")
    for r in results:
        img = cv.imread(r.path)
        for box in r.boxes:
            # Use the named Boxes attributes instead of unpacking box.data, which
            # gains an extra track-id column when tracking is enabled.
            x1, y1, x2, y2 = (int(v) for v in box.xyxy[0].tolist())
            class_id = int(box.cls[0])
            color = coco_colors[class_id]
            cv.rectangle(img, (x1, y1), (x2, y2), color, 2)
            # putText's origin is the text baseline; keep it ~15 px down so labels
            # of boxes touching the top edge stay inside the image.
            cv.putText(img, coco_label[class_id], (x1, max(y1, 15)),
                       cv.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        cv.imshow("img", img)
        key = cv.waitKey(0)
        if key == 27:
            break
    cv.destroyAllWindows()

if __name__ == "__main__":
    # Runs the auto-labelling demo. Call simple_demo() instead to only preview
    # detections on a single image.
    demo_of_autolabel()
    
  • visualize_yolo_dataset.py
python 复制代码
import os
import cv2 as cv
import numpy as np
import json


def get_all_image_files(path, image_path_list, recursive=False):
    '''
    Collect the image files (.jpg / .png / .jpeg, any letter case) under a directory.
    Args:
        path: the directory to search.
        image_path_list: list that matching file paths are appended to (modified in place).
        recursive: whether to descend into sub-directories.
    Returns:
        The same `image_path_list` object, for convenience. Nothing is appended
        when `path` does not exist.
    '''
    if not os.path.exists(path):
        print("The path does not exist.")
        return image_path_list
    image_exts = ('.jpg', '.png', '.jpeg')
    for root, dirs, files in os.walk(path):
        # Sorting dirs in place fixes os.walk's traversal order; together with
        # sorted files this makes the collection order reproducible.
        dirs.sort()
        for file in sorted(files):
            # Lower-case first so camera-style names like "IMG_0001.JPG" are not skipped.
            if file.lower().endswith(image_exts):
                image_path_list.append(os.path.join(root, file))
        if not recursive:
            break
    return image_path_list


def get_all_labels_from_image_lists(image_path_list, label_path_list):
    '''
    Find the YOLO label file for each image and collect the ones that exist.
    An image at xx/yy/<dir>/name.ext is expected to have its label at
    xx/yy/labels/name.txt (a "labels" folder next to the image's folder).
    Missing labels are reported and skipped, so label_path_list can end up
    shorter than image_path_list.
    Args:
        image_path_list: image paths to look up.
        label_path_list: list that existing label paths are appended to.
    '''
    for image_path in image_path_list:
        images_dir, image_name = os.path.split(image_path)
        labels_dir = os.path.join(os.path.dirname(images_dir), "labels")
        stem, _ext = os.path.splitext(image_name)
        label_path = os.path.join(labels_dir, stem + ".txt")
        if os.path.exists(label_path):
            label_path_list.append(label_path)
        else:
            print(f"The label path does not exist, path:{label_path}.")

def get_bbox_from_yolo_txt(yolo_txt_path, image_path=None):
    '''
    Parse a YOLO-style label txt file.

    Each non-empty line is `label cx cy w h`, with the box center and size
    normalized to the image size. The label may itself contain spaces (e.g.
    "traffic light"), so only the last four whitespace-separated fields are
    parsed as numbers.

    Args:
        yolo_txt_path: the path of the yolo txt file.
        image_path: optional image the labels belong to. When given, boxes are
            scaled to its pixel size; when None, boxes stay in normalized units.
    Returns:
        bbox_list: the list of bbox. bbox is a list of [x1, y1, x2, y2, label].
        None if the txt file or the image is missing or the image is unreadable.
    Raises:
        ValueError: a non-empty line does not have at least 5 fields or holds
            non-numeric coordinates.
    '''
    if not os.path.exists(yolo_txt_path):
        print(f"The yolo txt path does not exist, path:{yolo_txt_path}.")
        return
    if image_path is not None and not os.path.exists(image_path):
        print(f"The image path does not exist, path:{image_path}.")
        return
    # Scale of 1 keeps boxes normalized when no image is supplied.
    image_w, image_h = 1.0, 1.0
    if image_path is not None:
        image = cv.imread(image_path)
        if image is None:
            print(f"The image cannot be read, path:{image_path}.")
            return
        image_h, image_w = image.shape[:2]
    bbox_list = []
    with open(yolo_txt_path, 'r') as f:
        for line_no, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            # rsplit from the right keeps multi-word labels intact.
            parts = line.rsplit(maxsplit=4)
            if len(parts) != 5:
                raise ValueError(f"Malformed line {line_no} in {yolo_txt_path}: {line!r}")
            label = parts[0]
            cx, cy, w, h = (float(v) for v in parts[1:])
            x1 = (cx - w / 2) * image_w
            y1 = (cy - h / 2) * image_h
            x2 = (cx + w / 2) * image_w
            y2 = (cy + h / 2) * image_h
            bbox_list.append([x1, y1, x2, y2, label])
    return bbox_list


def draw_bbox(image, bbox_list, color=(0, 255, 0)):
    '''
    Draw boxes and their labels onto an image.
    Args:
        image: the image array to draw on; OpenCV draws into it in place.
        bbox_list: the list of bbox. bbox is a list of [x1, y1, x2, y2, label].
            Coordinates are pixels and are truncated to int. None or an empty
            list draws nothing.
        color: color tuple used for both the rectangle and the label text.
    Returns:
        The same `image` object.
    '''
    for x1, y1, x2, y2, label in bbox_list or []:
        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
        cv.rectangle(image, (x1, y1), (x2, y2), color, 2)
        # Offset places the text just inside the box's top-left corner.
        cv.putText(image, str(label), (x1 + 5, y1 + 10), cv.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    return image

def visualize_dataset(dataset_path):
    '''
    Step through a YOLO dataset, showing each image with its labelled boxes.
    Images are searched recursively under `dataset_path`. Each image is paired
    with its own label file. Images without a label file, or whose label or
    image cannot be read, are skipped instead of aborting the whole run.
    ESC stops; any other key shows the next image.
    Args:
        dataset_path: root directory of the dataset.
    '''
    image_path_list = []
    get_all_image_files(dataset_path, image_path_list, recursive=True)
    for image_path in image_path_list:
        # Resolve labels one image at a time so images and labels cannot get out
        # of step when some label files are missing.
        label_paths = []
        get_all_labels_from_image_lists([image_path], label_paths)
        if not label_paths:
            continue
        label_path = label_paths[0]
        bbox_list = get_bbox_from_yolo_txt(label_path, image_path)
        print(f"image_path:{image_path}, label_path:{label_path}, bbox_list:{bbox_list}")
        if bbox_list is None:
            continue
        image = cv.imread(image_path)
        if image is None:
            continue
        view = draw_bbox(image, bbox_list)
        cv.imshow("view", view)
        key = cv.waitKey(0)
        if key == 27:
            break
    cv.destroyAllWindows()

def demo_of_visualize_dataset():
    # Browse the dataset folder that the auto-labelling demo writes to.
    visualize_dataset("/media/xp/data/image/sentinel/raw/test_data/del/dataset")

if __name__ == "__main__":
    # Script entry point: open the interactive dataset viewer.
    demo_of_visualize_dataset()
                

使用流程：先运行auto_label.py，把所有图片及预测结果保存到同一路径下（目前图片统一保存为jpg格式）；然后用visualize_yolo_dataset.py检查自动标注的结果。从结果看，bbox基本还是比较准的，但会有漏检，这部分需要手动调整。另外，这里txt中的label用的是类别名字符串，而不是0、1、2、3……这样的数字id，后面写个脚本将其转换为自己数据集的class_id即可。

下面是自动标注后的可视化结果。

相关推荐
枫叶丹48 分钟前
【Qt开发】多元素类控件(二)-> QTableWidget
开发语言·qt
bin915310 分钟前
当AI开始‘映射‘用户数据:初级Python开发者的创意‘高阶函数‘如何避免被‘化简‘?—— 老码农的函数式幽默
开发语言·人工智能·python·工具·ai工具
Nebula_g32 分钟前
Java哈希表入门详解(Hash)
java·开发语言·学习·算法·哈希算法·初学者
努力努力再努力wz34 分钟前
【C++进阶系列】:万字详解unordered_set和unordered_map,带你手搓一个哈希表!(附模拟实现unordered_set和unordered_map的源码)
java·linux·开发语言·数据结构·数据库·c++·散列表
励志不掉头发的内向程序员38 分钟前
【STL库】哈希表的原理 | 哈希表模拟实现
开发语言·c++·学习·散列表
万粉变现经纪人44 分钟前
如何解决 pip install -r requirements.txt 私有仓库认证失败 401 Unauthorized 问题
开发语言·python·scrapy·flask·beautifulsoup·pandas·pip
量子炒饭大师1 小时前
收集飞花令碎片——C语言字符函数与字符串函数
c语言·开发语言
懂得节能嘛.1 小时前
【设计模式】Java规则树重构复杂业务逻辑
java·开发语言·设计模式
syt_biancheng1 小时前
Qt--命名,快捷键及坐标系
开发语言·qt