No-brainer 014: on Linux, build a COCO (JSON) format dataset and train mmdetection on your own data

Machine: Linux, RTX 3090, CUDA 11.2

1. Create a COCO (JSON) format dataset

The annotation tool used here is labelImg.

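Annotate in VOC format, then use the following script to convert the VOC annotations to COCO format. Pay attention to the default paths at the end of the script.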

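The script writes one JSON file per run (train.json by default). Label a separate set of validation images and run it again with the output path set to val.json; once both files exist you can start training.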

```python
import os
import json
import glob
import xml.etree.ElementTree as ET

START_BOUNDING_BOX_ID = 1
# Category name -> id. The names must match the classes configured in mmdet (step 3.3).
# Set to None to build the mapping automatically from the XML files.
PRE_DEFINE_CATEGORIES = {"dingzi": 1, "heidong": 2}
# If necessary, pre-define category and its id
#  PRE_DEFINE_CATEGORIES = {"aeroplane": 1, "bicycle": 2, "bird": 3, "boat": 4,
#  "bottle":5, "bus": 6, "car": 7, "cat": 8, "chair": 9,
#  "cow": 10, "diningtable": 11, "dog": 12, "horse": 13,
#  "motorbike": 14, "person": 15, "pottedplant": 16,
#  "sheep": 17, "sofa": 18, "train": 19, "tvmonitor": 20}


def get(root, name):
    return root.findall(name)


def get_and_check(root, name, length):
    found = root.findall(name)
    if len(found) == 0:
        raise ValueError("Can not find %s in %s." % (name, root.tag))
    if length > 0 and len(found) != length:
        raise ValueError(
            "The size of %s is supposed to be %d, but is %d."
            % (name, length, len(found))
        )
    if length == 1:
        found = found[0]
    return found


def get_filename_as_int(filename):
    stem = os.path.splitext(os.path.basename(filename.replace("\\", "/")))[0]
    try:
        return int(stem)
    except ValueError:
        raise ValueError("Filename %s is supposed to be an integer." % filename)


def get_categories(xml_files):
    """Generate category name to id mapping from a list of xml files.

    Arguments:
        xml_files {list} -- A list of xml file paths.

    Returns:
        dict -- category name to id mapping (ids start at 1, as in COCO).
    """
    classes_names = set()
    for xml_file in xml_files:
        root = ET.parse(xml_file).getroot()
        for member in root.findall("object"):
            classes_names.add(get_and_check(member, "name", 1).text)
    return {name: i for i, name in enumerate(sorted(classes_names), start=1)}


def convert(xml_files, json_file):
    json_dict = {"images": [], "type": "instances", "annotations": [], "categories": []}
    if PRE_DEFINE_CATEGORIES is not None:
        categories = dict(PRE_DEFINE_CATEGORIES)
    else:
        categories = get_categories(xml_files)
    bnd_id = START_BOUNDING_BOX_ID
    for xml_file in xml_files:
        root = ET.parse(xml_file).getroot()
        path = get(root, "path")
        if len(path) == 1:
            # Normalize backslashes so Windows paths written by labelImg also work on Linux.
            filename = os.path.basename(path[0].text.replace("\\", "/"))
        elif len(path) == 0:
            filename = get_and_check(root, "filename", 1).text
        else:
            raise ValueError("%d paths found in %s" % (len(path), xml_file))
        # The image filename must be a number (e.g. 1.jpg); it becomes the COCO image id.
        image_id = get_filename_as_int(filename)
        size = get_and_check(root, "size", 1)
        width = int(get_and_check(size, "width", 1).text)
        height = int(get_and_check(size, "height", 1).text)
        json_dict["images"].append(
            {"file_name": filename, "height": height, "width": width, "id": image_id}
        )
        # Segmentation is not supported; only bounding boxes are converted.
        for obj in get(root, "object"):
            category = get_and_check(obj, "name", 1).text
            if category not in categories:
                # Take the next free id; len(categories) could collide with an existing id.
                categories[category] = max(categories.values(), default=0) + 1
            category_id = categories[category]
            bndbox = get_and_check(obj, "bndbox", 1)
            # Pascal VOC coordinates are 1-based; COCO uses [x, y, width, height].
            xmin = int(float(get_and_check(bndbox, "xmin", 1).text)) - 1
            ymin = int(float(get_and_check(bndbox, "ymin", 1).text)) - 1
            xmax = int(float(get_and_check(bndbox, "xmax", 1).text))
            ymax = int(float(get_and_check(bndbox, "ymax", 1).text))
            assert xmax > xmin and ymax > ymin, "invalid box in %s" % xml_file
            o_width = xmax - xmin
            o_height = ymax - ymin
            json_dict["annotations"].append({
                "area": o_width * o_height,
                "iscrowd": 0,
                "image_id": image_id,
                "bbox": [xmin, ymin, o_width, o_height],
                "category_id": category_id,
                "id": bnd_id,
                "ignore": 0,
                "segmentation": [],
            })
            bnd_id += 1

    for cate, cid in categories.items():
        json_dict["categories"].append({"supercategory": "none", "id": cid, "name": cate})

    out_dir = os.path.dirname(json_file)
    if out_dir:  # os.makedirs("") would raise for a bare filename
        os.makedirs(out_dir, exist_ok=True)
    with open(json_file, "w") as json_fp:
        json.dump(json_dict, json_fp)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert Pascal VOC annotation to COCO format."
    )
    # Change both defaults to your own directories (these are the author's Windows paths).
    parser.add_argument("xml_dir", nargs="?", default=r"G:\bsh\dataset\dingzi\Annotations",
                        help="Directory path to xml files.", type=str)
    parser.add_argument("json_file", nargs="?", default=r"G:\bsh\dataset\dingzi\coco_json\train.json",
                        help="Output COCO format json file.", type=str)
    args = parser.parse_args()
    xml_files = sorted(glob.glob(os.path.join(args.xml_dir, "*.xml")))

    # If you want to do train/test split, you can pass a subset of xml files to convert function.
    print("Number of xml files: {}".format(len(xml_files)))
    convert(xml_files, args.json_file)
    print("Success: {}".format(args.json_file))
```
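The script's own comment suggests passing a subset of XML files to convert() for a train/val split. Instead of labeling a second batch only for validation, one option is to shuffle a single annotated set and split it. A minimal sketch, assuming all XML files sit in one directory; the function name split_xml_files, the 20% ratio and the fixed seed are illustrative choices, not part of the original post:

```python
import random


def split_xml_files(xml_files, val_ratio=0.2, seed=0):
    """Split a list of annotation paths into (train, val) lists.

    The list is sorted first so the split depends only on the file names and
    the seed, not on the order glob() happened to return them in.
    """
    files = sorted(xml_files)
    rng = random.Random(seed)
    rng.shuffle(files)
    # Keep at least one validation file whenever there are two or more files.
    n_val = max(1, int(len(files) * val_ratio)) if len(files) > 1 else 0
    return files[n_val:], files[:n_val]


if __name__ == "__main__":
    names = ["%d.xml" % i for i in range(1, 11)]
    train, val = split_xml_files(names)
    print(len(train), len(val))  # 8 2
```

With the conversion script's convert() in scope, calling it once with the train list and a train.json path, and once with the val list and a val.json path, would produce both files in a single run.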
    

2. Set up the mmdet environment

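Download the zip from the mmdetection GitHub page and unzip it. The archive unpacks to a folder named after the branch (e.g. mmdetection-main), so rename it or adjust the cd command below.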

https://github.com/open-mmlab/mmdetection

```bash
conda create -n xcb_mmdet31 python=3.8 -y
conda activate xcb_mmdet31
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
pip install -U openmim       # ended up with 0.3.9
mim install mmengine         # ended up with 0.7.1
mim install "mmcv>=2.0.0"    # 2.0.0rc4
cd mmdetection
pip install -v -e .          # ended up with mmdet 3.1.0
```

You can check whether the environment was installed correctly with:

```bash
mim download mmdet --config rtmdet_tiny_8xb32-300e_coco --dest .

python demo/image_demo.py demo/demo.jpg rtmdet_tiny_8xb32-300e_coco.py --weights rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth --device cpu
```
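Besides the demo, a quick way to confirm which versions actually ended up installed, and where mmdet is imported from, is to query package metadata. A small sketch using only the standard library (importlib.metadata, available from Python 3.8); the package list mirrors the install commands above, and a missing package is reported as None instead of raising:

```python
from importlib import metadata
import importlib.util


def installed_versions(packages=("torch", "torchvision", "mmengine", "mmcv", "mmdet")):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def module_location(module_name):
    """Return the file a top-level module would be imported from, or None."""
    spec = importlib.util.find_spec(module_name)
    return spec.origin if spec is not None else None


if __name__ == "__main__":
    for pkg, ver in installed_versions().items():
        print("%-12s %s" % (pkg, ver))
    # After `pip install -v -e .` this should point into the mmdetection folder.
    print("mmdet imported from:", module_location("mmdet"))
```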

Next, modify the config files. This time the RTMDet model is used.

3. Modify the config files

3.1 Modify file 1

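On line 32 of configs/rtmdet/rtmdet_l_8xb32-300e_coco.py, set num_classes to the number of classes in the dataset. It has to equal the number of class names set in 3.3, which is 2 for 'dingzi' and 'heidong':

```python
        num_classes=2,
```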
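A mismatch between num_classes, the class tuple edited in 3.3 and the category names in train.json is easy to introduce. A minimal consistency check, assuming the JSON was produced by the conversion script in step 1; the function name check_classes is hypothetical:

```python
import json


def check_classes(json_path, classes, num_classes):
    """Compare a COCO annotation file against the class settings in the config.

    Returns a list of human-readable problems (empty if everything matches).
    """
    with open(json_path) as f:
        names = {c["name"] for c in json.load(f)["categories"]}
    problems = []
    if num_classes != len(classes):
        problems.append("num_classes=%d but %d class names" % (num_classes, len(classes)))
    missing = names - set(classes)
    if missing:
        problems.append("categories in JSON but not in classes: %s" % sorted(missing))
    unused = set(classes) - names
    if unused:
        problems.append("classes with no category in JSON: %s" % sorted(unused))
    return problems
```

For the setup in this post, check_classes('data/dingzi/coco_json/train.json', ('dingzi', 'heidong'), 2) should return an empty list.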

3.2 Modify file 2

In configs/_base_/datasets/coco_detection.py, change:

Line 3:

```python
data_root = 'data/dingzi/'
```

Lines 46 and 47 (training set):

```python
        ann_file='coco_json/train.json',
        data_prefix=dict(img='images/train/'),
```

Lines 60 and 61 (validation set):

```python
        ann_file='coco_json/val.json',
        data_prefix=dict(img='images/val/'),
```

Line 69 (annotation file used by the validation evaluator):

```python
    ann_file=data_root + 'coco_json/val.json',
```
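With these settings the dataset joins data_root with ann_file to find the JSON, and data_root plus the img prefix plus each image's file_name to find the pictures. A small pre-flight check that every file_name in a JSON file has a matching image on disk, assuming that layout (data/dingzi/images/train/ and images/val/); the helper name missing_images is hypothetical:

```python
import json
import os


def missing_images(data_root, ann_file, img_prefix):
    """List file_name entries in a COCO json that have no matching image on disk.

    The json is read from data_root/ann_file; images are looked up in
    data_root/img_prefix/<file_name>.
    """
    with open(os.path.join(data_root, ann_file)) as f:
        images = json.load(f)["images"]
    img_dir = os.path.join(data_root, img_prefix)
    return [im["file_name"] for im in images
            if not os.path.isfile(os.path.join(img_dir, im["file_name"]))]
```

For example, missing_images('data/dingzi/', 'coco_json/val.json', 'images/val/') should return an empty list before training starts.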

3.3 Modify file 3

In mmdet/datasets/coco.py, replace the classes tuple in METAINFO (starting at line 19) with:

```python
        ('dingzi', 'heidong'),
```

3.4 Modify file 4

In mmdet/evaluation/functional/class_names.py, line 75:

```python
        'dingzi', 'heidong'
```
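Editing mmdet's source files works, but mmdet 3.x configs can also carry the class names through a dataset metainfo dict, which leaves the installed package untouched. A sketch of a separate config file under that approach, assuming it is saved next to the RTMDet config under a hypothetical name such as configs/rtmdet/rtmdet_l_dingzi.py and reusing the paths from 3.2; it has not been tested on this dataset:

```python
# Hypothetical config file: configs/rtmdet/rtmdet_l_dingzi.py
_base_ = './rtmdet_l_8xb32-300e_coco.py'

data_root = 'data/dingzi/'
# Class names must match the "name" fields of the categories in train.json / val.json.
metainfo = dict(classes=('dingzi', 'heidong'))

# Must equal the number of classes in metainfo.
model = dict(bbox_head=dict(num_classes=2))

train_dataloader = dict(
    dataset=dict(
        data_root=data_root,
        metainfo=metainfo,
        ann_file='coco_json/train.json',
        data_prefix=dict(img='images/train/')))
val_dataloader = dict(
    dataset=dict(
        data_root=data_root,
        metainfo=metainfo,
        ann_file='coco_json/val.json',
        data_prefix=dict(img='images/val/')))
# The base file's "test = val" aliases are resolved before these overrides, so repeat them here.
test_dataloader = val_dataloader

val_evaluator = dict(ann_file=data_root + 'coco_json/val.json')
test_evaluator = val_evaluator
```

Training would then point tools/train.py at this file instead of the original RTMDet config.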

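3.5 Reinstall after the changes

mmdet was installed in editable mode (pip install -v -e .), so edits to its Python files should take effect without this step, but I reran the install anyway:

```bash
python setup.py install
```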

4. Start training

```bash
python tools/train.py configs/rtmdet/rtmdet_l_8xb32-300e_coco.py --work-dir output
```

When training finishes, the results are printed:

```
09/08 17:45:03 - mmengine - INFO - Epoch(val) [300][1/1]  
coco/bbox_mAP: 0.3930  coco/bbox_mAP_50: 0.8810  coco/bbox_mAP_75: 0.2470  
coco/bbox_mAP_s: 0.3030  coco/bbox_mAP_m: 0.4620  coco/bbox_mAP_l: -1.0000  
data_time: 0.0520  time: 0.1187
```
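To compare runs without reading the log by eye, the coco/bbox_* values can be pulled out of the validation output with a regular expression. A small sketch that assumes the "key: value" layout shown above; the function name parse_bbox_metrics is hypothetical:

```python
import re

# Matches e.g. "coco/bbox_mAP_50: 0.8810" and "coco/bbox_mAP_l: -1.0000".
METRIC_RE = re.compile(r"(coco/bbox_\w+):\s*(-?\d+(?:\.\d+)?)")


def parse_bbox_metrics(log_text):
    """Return {metric_name: value} for every coco/bbox_* entry in the text."""
    return {name: float(value) for name, value in METRIC_RE.findall(log_text)}


if __name__ == "__main__":
    line = ("09/08 17:45:03 - mmengine - INFO - Epoch(val) [300][1/1]  "
            "coco/bbox_mAP: 0.3930  coco/bbox_mAP_50: 0.8810  coco/bbox_mAP_75: 0.2470  "
            "coco/bbox_mAP_s: 0.3030  coco/bbox_mAP_m: 0.4620  coco/bbox_mAP_l: -1.0000  "
            "data_time: 0.0520  time: 0.1187")
    print(parse_bbox_metrics(line)["coco/bbox_mAP_50"])  # 0.881
```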

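I only used 3 images, so the metrics are poor. (bbox_mAP_l is -1.0 because the validation set contains no large objects, so COCO evaluation cannot compute that value.) Next, run an inference test.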

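5. Inference test

Run the trained model (the config copy saved in the work dir output/ plus the final checkpoint) on the training images:

```bash
python demo/image_demo.py data/dingzi/images/train/ output/rtmdet_l_8xb32-300e_coco.py --weights output/epoch_300.pth
```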

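The results met my expectations.
Next, I want to work out how to save the trained model's predictions as label files, to use it for automatic annotation.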
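One way to approach that auto-annotation step is to turn each image's predicted boxes back into a labelImg-style Pascal VOC XML file, so the predictions can be opened and corrected in labelImg. The sketch below only covers the writing side and assumes you already have, per image, boxes as 0-based pixel [xmin, ymin, xmax, ymax], class names and scores (extracting them from mmdet is not shown). The function name boxes_to_voc_xml, the 0.3 score threshold and the +1 coordinate shift (the inverse of the "- 1" in the step 1 converter) are assumptions:

```python
import xml.etree.ElementTree as ET


def boxes_to_voc_xml(filename, width, height, boxes, labels, scores,
                     score_thr=0.3, depth=3):
    """Build a Pascal VOC annotation tree for one image.

    boxes: [xmin, ymin, xmax, ymax] in 0-based pixel coordinates.
    Boxes scoring below score_thr are dropped; xmin/ymin are shifted to VOC's
    1-based convention and all coordinates are clipped to the image size.
    """
    root = ET.Element("annotation")
    ET.SubElement(root, "filename").text = filename
    size = ET.SubElement(root, "size")
    for tag, val in (("width", width), ("height", height), ("depth", depth)):
        ET.SubElement(size, tag).text = str(val)
    ET.SubElement(root, "segmented").text = "0"
    for (x1, y1, x2, y2), label, score in zip(boxes, labels, scores):
        if score < score_thr:
            continue
        xmin = max(1, int(round(x1)) + 1)
        ymin = max(1, int(round(y1)) + 1)
        xmax = min(width, int(round(x2)))
        ymax = min(height, int(round(y2)))
        if xmax <= xmin or ymax <= ymin:
            continue  # box became empty after rounding/clipping
        obj = ET.SubElement(root, "object")
        ET.SubElement(obj, "name").text = label
        ET.SubElement(obj, "pose").text = "Unspecified"
        ET.SubElement(obj, "truncated").text = "0"
        ET.SubElement(obj, "difficult").text = "0"
        bndbox = ET.SubElement(obj, "bndbox")
        for tag, val in (("xmin", xmin), ("ymin", ymin), ("xmax", xmax), ("ymax", ymax)):
            ET.SubElement(bndbox, tag).text = str(val)
    return ET.ElementTree(root)
```

For example, boxes_to_voc_xml("1.jpg", 640, 480, [[10.2, 20.7, 100.0, 200.4]], ["dingzi"], [0.9]).write("1.xml") writes one labelImg-readable box.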

Off to eat first.
