yolo 获取异常样本 yolo 异常

用于筛选 YOLO 检测结果与 LabelMe 标注不一致(多检、漏检、低 IoU)的异常样本的脚本:

python 复制代码
import os
import shutil
import json
import time

import torch
import cv2
from PIL import Image
import numpy as np

from ultralytics import YOLO
from torchvision.ops import batched_nms

class YOLO_Class():
    """Find images where YOLO detections disagree with LabelMe annotations.

    Each image is run through the model, the detections are matched against
    the boxes of the LabelMe JSON file stored next to the image, and images
    with extra, missing or poorly overlapping boxes are exported into a
    per-category sub-directory together with a short text report.
    """

    def __init__(self, model_path, device="cuda:0"):
        """Load the YOLO weights from ``model_path`` and move the model to ``device``."""
        self.model = YOLO(model_path, verbose=False).to(device)
        # The attributes below are not read by any method of this class; they
        # are kept so that external code accessing them keeps working.
        self.last_log_time = time.time()
        self.frame_counter = 0
        self.conf_threshold = 0.5

    def iou(self, box1, box2):
        """Return the intersection-over-union of two boxes.

        Args:
            box1: ``[x_min, y_min, x_max, y_max]``.
            box2: ``[x_min, y_min, x_max, y_max]``.

        Returns:
            The IoU as a float, or ``0.0`` when the union area is not positive
            (e.g. two zero-sized boxes).
        """
        x1_min, y1_min, x1_max, y1_max = box1
        x2_min, y2_min, x2_max, y2_max = box2

        # Clamp at zero so disjoint boxes yield an empty intersection.
        inter_w = max(0, min(x1_max, x2_max) - max(x1_min, x2_min))
        inter_h = max(0, min(y1_max, y2_max) - max(y1_min, y2_min))
        inter_area = inter_w * inter_h

        area1 = (x1_max - x1_min) * (y1_max - y1_min)
        area2 = (x2_max - x2_min) * (y2_max - y2_min)
        union_area = area1 + area2 - inter_area

        if union_area <= 0:
            return 0.0
        return inter_area / union_area

    def get_normalized_box(self, points):
        """Convert LabelMe points to an axis-aligned ``[x_min, y_min, x_max, y_max]`` box.

        Raises:
            ValueError: if ``points`` is empty.
        """
        x_coords = [point[0] for point in points]
        y_coords = [point[1] for point in points]
        return [min(x_coords), min(y_coords), max(x_coords), max(y_coords)]

    def load_labelme_annotations(self, json_path):
        """Load the shapes of a LabelMe JSON file.

        Returns:
            A list of dicts with keys ``'label'``, ``'box'`` (the bounding box
            of the shape's points) and ``'points'`` (the raw points). Shapes
            without points are skipped because they have no bounding box.
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        annotations = []
        # ``or []`` also covers a file that stores ``"shapes": null``.
        for shape in data.get('shapes') or []:
            points = shape.get('points')
            if not points:
                continue
            annotations.append({
                'label': shape['label'],
                'box': self.get_normalized_box(points),
                'points': points
            })
        return annotations

    def detect_img(self, image, score_threshold=0.3, nms_iou_threshold=0.5):
        """Run YOLO inference on one image and return the filtered detections.

        Args:
            image: Any input accepted by the loaded model.
            score_threshold: Detections with a confidence not greater than
                this value are dropped.
            nms_iou_threshold: IoU threshold passed to class-wise NMS.

        Returns:
            A list of dicts with keys ``'label'`` (class name), ``'box'``
            (``[x_min, y_min, x_max, y_max]`` as floats) and ``'score'``.
            The list is empty when nothing passes the score threshold.
        """
        with torch.no_grad():
            results = self.model(image, verbose=False)

        result_boxes = results[0].boxes
        labels = result_boxes.cls
        boxes = result_boxes.xyxy
        scores = result_boxes.conf

        confident = scores > score_threshold
        labels = labels[confident]
        boxes = boxes[confident]
        scores = scores[confident]

        if len(boxes) == 0:
            return []

        # Class-wise NMS: boxes of different classes never suppress each other.
        keep_indices = batched_nms(boxes, scores, labels, nms_iou_threshold)

        # Move each tensor to the CPU once instead of once per detection.
        kept_labels = labels[keep_indices].cpu().tolist()
        kept_boxes = boxes[keep_indices].cpu().tolist()
        kept_scores = scores[keep_indices].cpu().tolist()

        names = self.model.names
        return [
            {'label': names[int(class_id)], 'box': box, 'score': score}
            for class_id, box, score in zip(kept_labels, kept_boxes, kept_scores)
        ]

    def draw_detections_on_image(self, image_path, detections, output_path):
        """Draw the detections as green boxes on the image and save the result.

        Args:
            image_path: Image to read.
            detections: Dicts with ``'box'``, ``'label'`` and ``'score'`` keys.
            output_path: Where the rendered image is written.

        Returns:
            ``True`` when the image was read and written successfully,
            ``False`` otherwise.
        """
        image = cv2.imread(image_path)
        if image is None:
            print(f"无法读取图片: {image_path}")
            return False

        font = cv2.FONT_HERSHEY_SIMPLEX
        green = (0, 255, 0)

        for det in detections:
            x1, y1, x2, y2 = map(int, det['box'])
            cv2.rectangle(image, (x1, y1), (x2, y2), green, 2)

            label_text = f"{det['label']}: {det['score']:.2f}"
            text_w, text_h = cv2.getTextSize(label_text, font, 0.5, 2)[0]

            # The label normally sits on top of the box. For boxes touching
            # the top edge it would fall outside the image, so it is moved
            # just inside the box instead.
            label_top = y1 - text_h - 10
            if label_top < 0:
                label_top = y1
            label_bottom = label_top + text_h + 10

            # Filled background first, then black text on top of it.
            cv2.rectangle(image, (x1, label_top), (x1 + text_w, label_bottom), green, -1)
            cv2.putText(image, label_text, (x1, label_bottom - 5),
                        font, 0.5, (0, 0, 0), 2)

        # cv2.imwrite reports failure through its return value.
        saved = bool(cv2.imwrite(output_path, image))
        if not saved:
            print(f"无法保存图片: {output_path}")
        return saved

    def _match_boxes(self, detections, annotations, iou_threshold):
        """Match detections to annotations one-to-one, best IoU first.

        A pair is a candidate when the labels are equal and the IoU is at
        least ``iou_threshold``. Candidates are accepted in order of
        decreasing IoU, and every detection and every annotation is used at
        most once.

        Returns:
            ``(matched_pairs, unmatched_detections, unmatched_annotations)``
            where ``matched_pairs`` holds ``(detection, annotation, iou)``.
        """
        candidates = []
        for det_idx, det in enumerate(detections):
            for ann_idx, ann in enumerate(annotations):
                if det['label'] != ann['label']:
                    continue
                overlap = self.iou(det['box'], ann['box'])
                if overlap >= iou_threshold:
                    candidates.append((overlap, det_idx, ann_idx))

        # Indices break ties so the result does not depend on sort stability.
        candidates.sort(key=lambda item: (-item[0], item[1], item[2]))

        used_detections = set()
        used_annotations = set()
        matched_pairs = []
        for overlap, det_idx, ann_idx in candidates:
            if det_idx in used_detections or ann_idx in used_annotations:
                continue
            used_detections.add(det_idx)
            used_annotations.add(ann_idx)
            matched_pairs.append((detections[det_idx], annotations[ann_idx], overlap))

        # Track by index, not by value, so equal-looking boxes stay distinct.
        unmatched_detections = [
            det for idx, det in enumerate(detections) if idx not in used_detections
        ]
        unmatched_annotations = [
            ann for idx, ann in enumerate(annotations) if idx not in used_annotations
        ]
        return matched_pairs, unmatched_detections, unmatched_annotations

    def _classify_sample(self, matched_pairs, unmatched_detections,
                         unmatched_annotations, low_iou_threshold=0.5):
        """Return the abnormal category of a sample, or ``None`` if it is normal.

        * ``"low_iou"``: a matched pair overlaps less than
          ``low_iou_threshold``, or the box counts are equal but some boxes
          could not be paired (overlap or label mismatch).
        * ``"more"``: more detections than annotations.
        * ``"miss"``: fewer detections than annotations.
        """
        has_low_iou = any(overlap < low_iou_threshold for _, _, overlap in matched_pairs)
        if has_low_iou:
            return "low_iou"
        if not unmatched_detections and not unmatched_annotations:
            return None

        # Both totals share the matched pairs, so comparing the leftovers is
        # the same as comparing the total box counts.
        if len(unmatched_detections) > len(unmatched_annotations):
            return "more"
        if len(unmatched_detections) < len(unmatched_annotations):
            return "miss"
        return "low_iou"

    def compare_detections(self, image_path, output_base_dir, iou_threshold=0.5,
                           low_iou_threshold=0.5):
        """Compare YOLO detections with the LabelMe annotation of one image.

        The annotation is expected next to the image, with the same base name
        and a ``.json`` extension. Images whose detections all match are
        skipped. Abnormal images are exported to
        ``output_base_dir/<category>`` together with their JSON file and an
        ``*_info.txt`` report.

        Args:
            image_path: Path of the image to evaluate.
            output_base_dir: Root directory of the exported samples.
            iou_threshold: Minimum IoU for a detection to match an annotation.
            low_iou_threshold: Matched pairs below this IoU mark the image as
                ``low_iou``.
        """
        base_name = os.path.splitext(image_path)[0]
        json_path = base_name + '.json'

        if not os.path.exists(json_path):
            print(f"未找到标注文件: {json_path}")
            return

        # The context manager releases the file handle once inference is done.
        with Image.open(image_path) as image:
            yolo_detections = self.detect_img(image)
        labelme_annotations = self.load_labelme_annotations(json_path)

        matched_pairs, unmatched_detections, unmatched_annotations = self._match_boxes(
            yolo_detections, labelme_annotations, iou_threshold)

        category = self._classify_sample(
            matched_pairs, unmatched_detections, unmatched_annotations, low_iou_threshold)
        if category is None:
            # Every box is matched well enough: not an abnormal sample.
            print(f"处理完成: {os.path.basename(image_path)} -> 正常, 未导出")
            return

        output_dir = os.path.join(output_base_dir, category)
        os.makedirs(output_dir, exist_ok=True)

        shutil.copy2(image_path, output_dir)
        shutil.copy2(json_path, output_dir)

        # NOTE(review): this path equals the image copy made just above, so
        # the rendered image replaces the copied original; the copy only
        # survives when rendering fails. Confirm this is intended.
        image_filename = os.path.basename(image_path)
        result_image_path = os.path.join(output_dir, image_filename)
        self.draw_detections_on_image(image_path, yolo_detections, result_image_path)

        info_file = os.path.join(output_dir, f"{os.path.basename(base_name)}_info.txt")
        with open(info_file, 'w', encoding='utf-8') as f:
            f.write(f"图片: {os.path.basename(image_path)}\n")
            f.write(f"分类: {category}\n")
            f.write(f"YOLO检测数: {len(yolo_detections)}\n")
            f.write(f"LabelMe标注数: {len(labelme_annotations)}\n")
            f.write(f"匹配数: {len(matched_pairs)}\n")
            f.write(f"多检数: {len(unmatched_detections)}\n")
            f.write(f"漏检数: {len(unmatched_annotations)}\n")
            f.write("\n匹配详情:\n")
            for det, ann, iou_score in matched_pairs:
                f.write(f"  {det['label']}: IOU={iou_score:.3f}\n")
            if unmatched_detections:
                f.write("\n多检框:\n")
                for det in unmatched_detections:
                    f.write(f"  {det['label']}: score={det['score']:.3f}\n")
            if unmatched_annotations:
                f.write("\n漏检框:\n")
                for ann in unmatched_annotations:
                    f.write(f"  {ann['label']}\n")

        print(f"处理完成: {os.path.basename(image_path)} -> {category}")

    def evaluate_dataset(self, image_dir, output_base_dir, iou_threshold=0.5):
        """Run :meth:`compare_detections` on every image below ``image_dir``.

        The directory is walked recursively; files ending in ``.jpg``,
        ``.jpeg``, ``.png`` or ``.bmp`` (case-insensitive) are evaluated.
        """
        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')

        # The list is built before any image is processed, so files exported
        # during this run are never picked up by the walk.
        img_files = [
            '%s/%s' % (root.replace("\\", "/"), name)
            for root, _dirs, names in os.walk(image_dir)
            for name in names
            if name.lower().endswith(image_extensions)
        ]

        for img_path in img_files:
            self.compare_detections(img_path, output_base_dir, iou_threshold)


# Usage example: evaluate one validation folder with a trained model.
if __name__ == '__main__':
    # Trained weights, the image folder to scan, and the export root.
    weights_path = r'D:\project\yolov12-main_new\weights\chan4\weights\best.pt'
    dataset_root = r"D:\data\det_chantu\val_penzi"
    report_root = r"D:\data\det_chantu\evaluation_results2"

    evaluator = YOLO_Class(weights_path, device="cuda:0")
    evaluator.evaluate_dataset(dataset_root, report_root, iou_threshold=0.5)
相关推荐
散峰而望2 小时前
C++入门(二) (算法竞赛)
开发语言·c++·算法·github
程序员爱钓鱼3 小时前
Python编程实战 面向对象与进阶语法 迭代器与生成器
后端·python·ipython
程序员爱钓鱼3 小时前
Python编程实战 面向对象与进阶语法 JSON数据读写
后端·python·ipython
沐知全栈开发3 小时前
CSS Float(浮动)详解
开发语言
Cx330❀3 小时前
《C++ 搜索二叉树》深入理解 C++ 搜索二叉树:特性、实现与应用
java·开发语言·数据结构·c++·算法·面试
TH88863 小时前
一体化负氧离子监测站:实时、精准监测空气中负氧离子浓度及其他环境参数
python
阿猿收手吧!3 小时前
【C语言】localtime和localtime_r;strftime和strftime_l
linux·c语言·开发语言
苏打水com3 小时前
0基础学前端:100天拿offer实战课(第3天)—— CSS基础美化:给网页“精装修”的5大核心技巧
人工智能·python·tensorflow
不染尘.3 小时前
2025_11_5_刷题
开发语言·c++·vscode·算法·贪心算法·动态规划