Beginner Notes on YOLOv8 Object Detection, Part 2: Using the Classic Model Zoo MMDetection for Comparison Experiments (Training on Your Own Dataset and Converting to YOLO-Format Evaluation Metrics)

1. Environment Setup, Step by Step

```bash
pip install timm==1.0.7 thop efficientnet_pytorch==0.7.1 einops grad-cam==1.4.8 dill==0.3.6 albumentations==1.4.11 pytorch_wavelets==1.3.0 tidecv PyWavelets -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install -U openmim -i https://pypi.tuna.tsinghua.edu.cn/simple
mim install mmengine -i https://pypi.tuna.tsinghua.edu.cn/simple
mim install "mmcv==2.1.0" -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install ultralytics
# run from the MMDetection source root: installs mmdet itself in editable mode
pip install -v -e .
```
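Note that the mmcv 2.1.0 wheel has to match the PyTorch/CUDA build on your machine; mim normally resolves the correct wheel automatically, but if the install fails, check the mmcv compatibility table for your PyTorch version first.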

2. Placing the Custom Dataset

I have already split the dataset into training, validation and test sets at a ratio of 8:1:1; the directory structure is as follows. The test2017, train2017 and val2017 folders hold the images, and testlabels, trainlabels and vallabels hold the .txt annotation files; after the format conversion, the COCO annotations live in the annotations folder.
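Concretely, the layout I am assuming under data/coco/ looks like this (the JSON file names follow the config in section 3):

```text
data/coco/
├── train2017/      # training images
├── val2017/        # validation images
├── test2017/       # test images
├── trainlabels/    # YOLO .txt labels for train2017
├── vallabels/      # YOLO .txt labels for val2017
├── testlabels/     # YOLO .txt labels for test2017
└── annotations/    # converted COCO JSON (instances_train2017.json, instances_val2017.json, instances_test2017.json)
```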

The code for converting YOLO format to COCO format is as follows:

```python
import argparse
import json
import os

import cv2
from tqdm import tqdm

# class names of the custom dataset (order must match the YOLO class ids)
classes = ['beibie1',
           'beibie2',
           'beibie3'
           ]

parser = argparse.ArgumentParser()
parser.add_argument('--image_path', default=r'E:\mmde\mmdetection-3.0.0\mmdetection-3.0.0\data\coco\val2017', type=str, help="path of images")
parser.add_argument('--label_path', default=r'E:\mmde\mmdetection-3.0.0\mmdetection-3.0.0\data\coco\vallabels', type=str, help="path of labels .txt")
parser.add_argument('--save_path', default='val.json', type=str,
                    help="if not split the dataset, give a path to a json file")
arg = parser.parse_args()


def yolo2coco(arg):
    print("Loading data from ", arg.image_path, arg.label_path)

    assert os.path.exists(arg.image_path)
    assert os.path.exists(arg.label_path)

    originImagesDir = arg.image_path
    originLabelsDir = arg.label_path
    # all image file names in the folder
    indexes = os.listdir(originImagesDir)

    dataset = {'categories': [], 'annotations': [], 'images': []}
    for i, cls in enumerate(classes, 0):
        dataset['categories'].append({'id': i, 'name': cls, 'supercategory': 'mark'})

    # running annotation id
    ann_id_cnt = 0
    for index in tqdm(indexes):
        # the label file shares the image's stem (works for png, jpg, ...)
        stem = index[:index.rfind('.')]
        txtFile = f'{stem}.txt'
        # read the image to get its width and height
        im = cv2.imread(os.path.join(originImagesDir, index))
        if im is None:
            print(f'{os.path.join(originImagesDir, index)} read error, skipped.')
            continue
        height, width, _ = im.shape
        # skip images that have no label file
        if not os.path.exists(os.path.join(originLabelsDir, txtFile)):
            continue
        # add the image's info
        dataset['images'].append({'file_name': index,
                                  'id': stem,
                                  'width': width,
                                  'height': height})
        with open(os.path.join(originLabelsDir, txtFile), 'r') as fr:
            labelList = fr.readlines()
            for label in labelList:
                label = label.strip().split()
                x = float(label[1])
                y = float(label[2])
                w = float(label[3])
                h = float(label[4])

                # convert normalized x,y,w,h to absolute x1,y1,x2,y2
                x1 = (x - w / 2) * width
                y1 = (y - h / 2) * height
                x2 = (x + w / 2) * width
                y2 = (y + h / 2) * height
                # class ids start from 0 (the official coco2017 numbering is
                # inconsistent anyway, so it is ignored here)
                cls_id = int(label[0])
                box_w = max(0, x2 - x1)
                box_h = max(0, y2 - y1)
                dataset['annotations'].append({
                    'area': box_w * box_h,
                    'bbox': [x1, y1, box_w, box_h],
                    'category_id': cls_id,
                    'id': ann_id_cnt,
                    'image_id': stem,
                    'iscrowd': 0,
                    # mask: the rectangle's four corners, clockwise from top-left
                    'segmentation': [[x1, y1, x2, y1, x2, y2, x1, y2]]
                })
                ann_id_cnt += 1

    # save the result
    with open(arg.save_path, 'w') as f:
        json.dump(dataset, f)
        print('Save annotation to {}'.format(arg.save_path))


if __name__ == "__main__":
    yolo2coco(arg)
```
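The script converts one split at a time, so run it three times, pointing --image_path, --label_path and --save_path at each split in turn, e.g. `python yolo2coco.py --image_path data/coco/train2017 --label_path data/coco/trainlabels --save_path train.json` (yolo2coco.py being whatever you saved the script as). As a quick sanity check, the generated JSON loads directly with pycocotools; a minimal sketch, assuming pycocotools is installed:

```python
# minimal sanity check of the converted annotations (assumes pycocotools)
from pycocotools.coco import COCO

coco = COCO('val.json')  # the file written by the script above
print(f'{len(coco.imgs)} images, {len(coco.anns)} annotations')
print('categories:', [c['name'] for c in coco.loadCats(coco.getCatIds())])
```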

3. Modifying the Parameters

Taking Faster R-CNN as an example, the file configs/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py looks like this:

```python
_base_ = [
    '../_base_/models/faster-rcnn_r50_fpn.py',
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
```
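These four files are merged by MMEngine's config inheritance: anything re-declared in the child file overrides the corresponding key in the base. The steps below simply edit the base files in place, which is fine for a single experiment; overriding the same keys from the child config works equally well.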

Based on that list, modify each referenced config. ('../_base_/default_runtime.py' needs no changes.)

(1) In '../_base_/models/faster-rcnn_r50_fpn.py', set num_classes to the actual number of classes in your dataset.
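For orientation, in the stock file this setting lives in roi_head.bbox_head; a sketch of the one change, using the 3 classes from the conversion script above (all other keys stay as shipped):

```python
# abridged from _base_/models/faster-rcnn_r50_fpn.py: only num_classes changes
model = dict(
    roi_head=dict(
        bbox_head=dict(
            num_classes=3)))  # was 80 for COCO; set to your own class count
```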

(2) In '../_base_/datasets/coco_detection.py', modify

```python
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
```

Since I use the COCO dataset format, data_root just needs to point at my dataset's location. Also change scale to your image size (mine is scale=(640, 640)), then adjust ann_file and data_prefix to match the dataset layout shown above. My complete file after the changes:

```python
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'

# Example to use different file client
# Method 1: simply set the data root and let the file I/O module
# automatically infer from prefix (not support LMDB and Memcache yet)

# data_root = 's3://openmmlab/datasets/detection/coco/'

# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6
# backend_args = dict(
#     backend='petrel',
#     path_mapping=dict({
#         './data/': 's3://openmmlab/datasets/detection/',
#         'data/': 's3://openmmlab/datasets/detection/'
#     }))
backend_args = None

train_pipeline = [
    dict(type='LoadImageFromFile', backend_args=backend_args),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', scale=(640, 640), keep_ratio=True),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PackDetInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile', backend_args=backend_args),
    dict(type='Resize', scale=(640, 640), keep_ratio=True),
    # If you don't have a gt annotation, delete the pipeline
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    batch_sampler=dict(type='AspectRatioBatchSampler'),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_train2017.json',
        data_prefix=dict(img='train2017/'),
        filter_cfg=dict(filter_empty_gt=True, min_size=32),
        pipeline=train_pipeline,
        backend_args=backend_args))
val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_val2017.json',
        data_prefix=dict(img='val2017/'),
        test_mode=True,
        pipeline=test_pipeline,
        backend_args=backend_args))
# test_dataloader = val_dataloader

val_evaluator = dict(
    type='CocoMetric',
    ann_file=data_root + 'annotations/instances_val2017.json',
    metric='bbox',
    format_only=False,
    backend_args=backend_args)
# test_evaluator = val_evaluator

# inference on test dataset and
# format the output results for submission.
test_dataloader = dict(
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=data_root + 'annotations/instances_test2017.json',
        data_prefix=dict(img='test2017/'),
        test_mode=True,
        pipeline=test_pipeline))
test_evaluator = dict(
    type='CocoMetric',
    metric='bbox',
    format_only=True,
    ann_file=data_root + 'annotations/instances_test2017.json',
    outfile_prefix='./work_dirs/coco_detection/test')
```
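One caveat worth checking: in MMDetection 3.x, CocoDataset defaults to the 80 COCO category names, and annotations whose categories do not match can be filtered out silently. With custom class names like the ones above, it is usually also necessary to declare metainfo in each dataset dict; a sketch of the assumed addition (not part of the stock file):

```python
# assumed addition for custom classes: tell CocoDataset the category names
metainfo = dict(classes=('beibie1', 'beibie2', 'beibie3'))

train_dataloader = dict(
    dataset=dict(
        metainfo=metainfo,
        # ...all other keys exactly as in the file above
    ))
# and likewise for the val_dataloader / test_dataloader dataset dicts
```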

(3) In '../_base_/schedules/schedule_1x.py', set max_epochs to the maximum number of training epochs you want. The other settings (val_interval, lr, momentum, weight_decay) can be left alone unless you have particular requirements. Oh, and depending on your machine, don't forget to change base_batch_size. The whole file:

```python
# training schedule for 1x
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100, val_interval=10)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# learning rate
param_scheduler = [
    dict(
        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
    dict(
        type='MultiStepLR',
        begin=0,
        end=100,
        by_epoch=True,
        milestones=[67, 92],
        gamma=0.1)
]

# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.01, momentum=0.937, weight_decay=0.0001))

# Default setting for scaling LR automatically
#   - `enable` means enable scaling LR automatically
#       or not by default.
#   - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=16)
```
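If you enable auto_scale_lr (or pass --auto-scale-lr to tools/train.py), MMEngine applies the linear scaling rule to the optimizer's lr. A minimal illustration of the arithmetic, assuming a single GPU with the batch_size=2 from the dataset config above:

```python
# linear scaling rule applied by auto_scale_lr (illustration only)
base_lr, base_batch_size = 0.01, 16  # values from the schedule above
num_gpus, samples_per_gpu = 1, 2     # assumed single-GPU setup
effective_lr = base_lr * (num_gpus * samples_per_gpu) / base_batch_size
print(effective_lr)  # 0.00125
```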

4. Training from the Command Line

Train with the following command:

```bash
python tools/train.py configs/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py
```

If training is interrupted, resume it with --resume (note that '--work-dir work_dirs/faster-rcnn_r50_fpn_1x_coco' is where the training output goes, and epoch_21.pth is the last .pth written before the interruption; adjust both to your actual situation):

```bash
python tools/train.py configs/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py --work-dir work_dirs/faster-rcnn_r50_fpn_1x_coco --resume work_dirs/faster-rcnn_r50_fpn_1x_coco/epoch_21.pth
```
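Depending on the MMDetection 3.x version, a bare --resume (with no checkpoint path) also works and auto-resumes from the latest checkpoint in the work dir; giving the .pth explicitly, as above, is the unambiguous form.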

5. Converting to YOLO-Format Evaluation Metrics

(1) Find the best epoch

```python
import os
import re
import subprocess

from prettytable import PrettyTable
from tqdm import tqdm

# working directory and config file path
work_dir = "work_dirs/faster-rcnn_r50_fpn_1x_coco"
config_file = "configs/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py"  # path to your config file

# the checkpoint files (epoch_1.pth ... epoch_100.pth) sit directly under work_dir
checkpoint_dir = work_dir

# collect the checkpoint files (epoch_1.pth, epoch_2.pth, ..., epoch_100.pth);
# the prefix filter skips files like best_*.pth whose names would not parse
checkpoint_files = [f for f in os.listdir(checkpoint_dir)
                    if f.startswith('epoch_') and f.endswith('.pth')]
checkpoint_files.sort(key=lambda x: int(x.split('_')[1].split('.')[0]))  # sort by epoch number

# store the evaluation results
results = []

# evaluate each checkpoint in turn
for checkpoint_file in tqdm(checkpoint_files, desc="Evaluating"):
    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)

    # path of the output pkl file for this checkpoint
    output_pkl = f"res_{checkpoint_file.split('.')[0]}.pkl"

    # run tools/test.py to evaluate this checkpoint
    command = [
        "python", "tools/test.py", config_file, checkpoint_path,
        "--out", output_pkl  # dump the raw predictions to a pkl file
    ]

    # run the command with subprocess and capture its output
    result = subprocess.run(command, capture_output=True, text=True)

    # debug output
    print(f"Evaluating {checkpoint_file}...")
    print(result.stdout)

    # parse the mAP from stdout; the evaluator prints a line containing
    # something like "coco/bbox_mAP: 0.4500"
    for line in result.stdout.splitlines():
        if "bbox_mAP" in line:
            match = re.search(r'bbox_mAP:\s*([0-9.]+)', line)
            if match is None:
                print(f"Error parsing bbox_mAP for {checkpoint_file}: {line}")
                continue
            bbox_mAP = float(match.group(1))
            epoch = int(checkpoint_file.split("_")[1].split(".")[0])  # epoch number
            results.append((epoch, bbox_mAP, output_pkl))  # (epoch, mAP, pkl file)
            break

# report the best epoch and mAP
if results:
    # pick the epoch with the highest mAP
    best_epoch, best_mAP, best_pkl = max(results, key=lambda x: x[1])
    print(f"Best Epoch: {best_epoch}, Best bbox_mAP: {best_mAP:.4f}, Best Output pkl: {best_pkl}")
else:
    print("No valid results found. Please check the log outputs.")

# tabulate the mAP of every epoch
table = PrettyTable()
table.title = "Evaluation Metrics"
table.field_names = ["Epoch", "bbox_mAP", "Output pkl"]

# add one row per epoch
for epoch, mAP, pkl_file in results:
    table.add_row([epoch, f"{mAP:.4f}", pkl_file])

print(table)
```
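An alternative that avoids re-testing every checkpoint: if the checkpoint hook in default_runtime.py is configured with save_best (for example save_best='auto' on the CheckpointHook), MMDetection keeps a best_coco_bbox_mAP_epoch_*.pth during training, which is where the file name used in step (2) comes from.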

(2) Generate the pkl file from the best epoch

```bash
python tools/test.py work_dirs/faster-rcnn_r50_fpn_1x_coco/faster-rcnn_r50_fpn_1x_coco.py work_dirs/faster-rcnn_r50_fpn_1x_coco/best_coco_bbox_mAP_epoch_90.pth --out res90.pkl
```
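The resulting res90.pkl is a pickled list with one entry per test image; each entry carries img_path plus a pred_instances dict of bboxes, scores and labels, which is exactly the structure the script in step (3) below parses.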

(3) Output the comparison metrics from the pkl file. The part to modify is in 'def parse_opt():'; change the defaults to your actual paths.

The full code follows:

```python
import argparse
import json
import os
import pickle

import numpy as np
import torch
import tqdm
from prettytable import PrettyTable


def clip_boxes(boxes, shape):
    # Clip boxes (xyxy) to image shape (height, width)
    if isinstance(boxes, torch.Tensor):  # faster individually
        boxes[..., 0].clamp_(0, shape[1])  # x1
        boxes[..., 1].clamp_(0, shape[0])  # y1
        boxes[..., 2].clamp_(0, shape[1])  # x2
        boxes[..., 3].clamp_(0, shape[0])  # y2
    else:  # np.array (faster grouped)
        boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])  # x1, x2
        boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2


def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
    # Rescale boxes (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain  = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    boxes[..., [0, 2]] -= pad[0]  # x padding
    boxes[..., [1, 3]] -= pad[1]  # y padding
    boxes[..., :4] /= gain
    clip_boxes(boxes, img0_shape)
    return boxes


def box_iou(box1, box2, eps=1e-7):
    """
    Calculate intersection-over-union (IoU) of boxes. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    Args:
        box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
        box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes.
        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
    Returns:
        (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
    """

    # NOTE: Need .float() to get accurate iou values
    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    (a1, a2), (b1, b2) = box1.float().unsqueeze(1).chunk(2, 2), box2.float().unsqueeze(0).chunk(2, 2)
    inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp_(0).prod(2)

    # IoU = inter / (area1 + area2 - inter)
    return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)


def process_batch(detections, labels, iouv):
    """
    Return correct prediction matrix
    Arguments:
        detections (array[N, 6]), x1, y1, x2, y2, conf, class
        labels (array[M, 5]), class, x1, y1, x2, y2
    Returns:
        correct (array[N, 10]), for 10 IoU levels
    """
    correct = np.zeros((detections.shape[0], iouv.shape[0])).astype(bool)
    iou = box_iou(labels[:, 1:], detections[:, :4])
    correct_class = labels[:, 0:1] == detections[:, 5]
    for i in range(len(iouv)):
        x = torch.where((iou >= iouv[i]) & correct_class)  # IoU > threshold and classes match
        if x[0].shape[0]:
            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()  # [label, detect, iou]
            if x[0].shape[0] > 1:
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                # matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
            correct[matches[:, 1].astype(int), i] = True
    return torch.tensor(correct, dtype=torch.bool, device=iouv.device)


def smooth(y, f=0.05):
    # Box filter of fraction f
    nf = round(len(y) * f * 2) // 2 + 1  # number of filter elements (must be odd)
    p = np.ones(nf // 2)  # ones padding
    yp = np.concatenate((p * y[0], y, p * y[-1]), 0)  # y padded
    return np.convolve(yp, np.ones(nf) / nf, mode='valid')  # y-smoothed


def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=(), eps=1e-16, prefix=''):
    """ Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
        tp:  True positives (nparray, nx1 or nx10).
        conf:  Objectness value from 0-1 (nparray).
        pred_cls:  Predicted object classes (nparray).
        target_cls:  True object classes (nparray).
        plot:  Plot precision-recall curve at mAP@0.5
        save_dir:  Plot save directory
    # Returns
        The average precision as computed in py-faster-rcnn.
    """

    # Sort by objectness
    i = np.argsort(-conf)
    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]

    # Find unique classes
    unique_classes, nt = np.unique(target_cls, return_counts=True)
    nc = unique_classes.shape[0]  # number of classes, number of detections

    # Create Precision-Recall curve and compute AP for each class
    px, py = np.linspace(0, 1, 1000), []  # for plotting
    ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
    for ci, c in enumerate(unique_classes):
        i = pred_cls == c
        n_l = nt[ci]  # number of labels
        n_p = i.sum()  # number of predictions
        if n_p == 0 or n_l == 0:
            continue

        # Accumulate FPs and TPs
        fpc = (1 - tp[i]).cumsum(0)
        tpc = tp[i].cumsum(0)

        # Recall
        recall = tpc / (n_l + eps)  # recall curve
        r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0)  # negative x, xp because xp decreases

        # Precision
        precision = tpc / (tpc + fpc)  # precision curve
        p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1)  # p at pr_score

        # AP from recall-precision curve
        for j in range(tp.shape[1]):
            ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
            if plot and j == 0:
                py.append(np.interp(px, mrec, mpre))  # precision at mAP@0.5

    # Compute F1 (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + eps)

    i = smooth(f1.mean(0), 0.1).argmax()  # max F1 index
    p, r, f1 = p[:, i], r[:, i], f1[:, i]
    tp = (r * nt).round()  # true positives
    fp = (tp / (p + eps) - tp).round()  # false positives
    return tp, fp, p, r, f1, ap, unique_classes.astype(int)


def compute_ap(recall, precision):
    """ Compute the average precision, given the recall and precision curves
    # Arguments
        recall:    The recall curve (list)
        precision: The precision curve (list)
    # Returns
        Average precision, precision curve, recall curve
    """

    # Append sentinel values to beginning and end
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([1.0], precision, [0.0]))

    # Compute the precision envelope
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))

    # Integrate area under curve
    method = 'interp'  # methods: 'continuous', 'interp'
    if method == 'interp':
        x = np.linspace(0, 1, 101)  # 101-point interp (COCO)
        ap = np.trapz(np.interp(x, mrec, mpre), x)  # integrate
    else:  # 'continuous'
        i = np.where(mrec[1:] != mrec[:-1])[0]  # points where x axis (recall) changes
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])  # area under curve

    return ap, mpre, mrec


def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--label_coco', type=str, default='E:/mmde/mmdetection-3.0.0/mmdetection-3.0.0/test.json',
                        help='label coco path')
    # parser.add_argument('--pred_coco', type=str, default='runs/val/exp/predictions.json', help='pred coco path')
    parser.add_argument('--pred_coco', type=str, default='E:/mmde/mmdetection-3.0.0/mmdetection-3.0.0/res90.pkl', help='pred coco path')
    parser.add_argument('--iou', type=float, default=0.7, help='iou threshold')
    parser.add_argument('--conf', type=float, default=0.001, help='conf threshold')
    opt = parser.parse_known_args()[0]
    return opt


if __name__ == '__main__':
    opt = parse_opt()

    iouv = torch.linspace(0.5, 0.95, 10)  # iou vector for mAP@0.5:0.95
    niou = iouv.numel()
    stats = []

    label_coco_json_path, pred_coco_json_path = opt.label_coco, opt.pred_coco
    with open(label_coco_json_path) as f:
        label = json.load(f)

    classes = []
    for data in label['categories']:
        classes.append(data['name'])

    image_id_hw_dict = {}
    for data in label['images']:
        image_id_hw_dict[data['id']] = [data['height'], data['width']]

    label_id_dict = {}
    for data in tqdm.tqdm(label['annotations'], desc='Process label...'):
        if data['image_id'] not in label_id_dict:
            label_id_dict[data['image_id']] = []

        category_id = data['category_id']
        x_min, y_min, w, h = data['bbox'][0], data['bbox'][1], data['bbox'][2], data['bbox'][3]
        x_max, y_max = x_min + w, y_min + h
        label_id_dict[data['image_id']].append(np.array([int(category_id), x_min, y_min, x_max, y_max]))

    if pred_coco_json_path.endswith('json'):
        with open(pred_coco_json_path) as f:
            pred = json.load(f)
        pred_id_dict = {}
        for data in tqdm.tqdm(pred, desc='Process pred...'):
            if data['image_id'] not in pred_id_dict:
                pred_id_dict[data['image_id']] = []

            score = data['score']
            category_id = data['category_id']
            x_min, y_min, w, h = data['bbox'][0], data['bbox'][1], data['bbox'][2], data['bbox'][3]
            x_max, y_max = x_min + w, y_min + h

            pred_id_dict[data['image_id']].append(
                np.array([x_min, y_min, x_max, y_max, float(score), int(category_id)]))
    else:
        with open(pred_coco_json_path, 'rb') as f:
            pred = pickle.load(f)
        pred_id_dict = {}
        for data in tqdm.tqdm(pred, desc='Process pred...'):
            image_id = os.path.splitext(os.path.basename(data['img_path']))[0]
            if image_id not in pred_id_dict:
                pred_id_dict[image_id] = []

            for i in range(data['pred_instances']['labels'].size(0)):
                score = data['pred_instances']['scores'][i]
                category_id = data['pred_instances']['labels'][i]
                bboxes = data['pred_instances']['bboxes'][i]

                x_min, y_min, x_max, y_max = bboxes.cpu().detach().numpy()
                # x_min, x_max = x_min / data['scale_factor'][0], x_max / data['scale_factor'][0]
                # y_min, y_max = y_min / data['scale_factor'][1], y_max / data['scale_factor'][1]

                pred_id_dict[image_id].append(np.array([x_min, y_min, x_max, y_max, float(score), int(category_id)]))

    for idx, image_id in enumerate(tqdm.tqdm(list(image_id_hw_dict.keys()), desc="Cal mAP...")):
        # images whose label file existed but was empty have no annotation entries
        label = np.array(label_id_dict.get(image_id, np.zeros((0, 5))))

        if image_id not in pred_id_dict:
            pred = torch.zeros((0, 6))
        else:
            pred = torch.from_numpy(np.array(pred_id_dict[image_id]))

        nl, npr = label.shape[0], pred.shape[0]
        correct = torch.zeros(npr, niou, dtype=torch.bool)
        if npr == 0:
            if nl:
                stats.append((correct, *torch.zeros((2, 0)), torch.from_numpy(label[:, 0])))
            continue

        if nl:
            correct = process_batch(pred, torch.from_numpy(label), iouv)
        stats.append((correct, pred[:, 4], pred[:, 5], torch.from_numpy(label[:, 0])))

    stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]
    tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats)
    print(f'precision:{p}')
    print(f'recall:{r}')
    print(f'mAP@0.5:{ap[:, 0]}')

    table = PrettyTable()
    table.title = "Metrics"
    table.field_names = ["Classes", 'Precision', 'Recall', 'mAP50', 'mAP50-95']
    table.add_row(['all', f'{np.mean(p):.3f}', f'{np.mean(r):.3f}', f'{np.mean(ap[:, 0]):.3f}', f'{np.mean(ap):.3f}'])
    # index per-class metrics through ap_class: classes absent from the labels
    # have no row in p/r/ap, so enumerating `classes` directly could misalign
    for i, c in enumerate(ap_class):
        table.add_row([classes[c], f'{p[i]:.3f}', f'{r[i]:.3f}', f'{ap[i, 0]:.3f}',
                       f'{ap[i, :].mean():.3f}'])
    print(table)
```
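Run it as, for example, `python yolo_metrics.py --label_coco test.json --pred_coco res90.pkl` (yolo_metrics.py being whatever you saved the code as). Predictions and labels are matched through the image id, which the converter in section 2 set to the file-name stem, exactly what this script derives from each img_path. Note that, as written, --iou and --conf are parsed but never applied; the table simply mirrors the Precision/Recall/mAP50/mAP50-95 layout that Ultralytics prints after validation, so the numbers can be compared side by side.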

6. Output Results
