# python
# 复制代码 ("copy code" - residue from the web page this snippet was pasted from)
import os
import shutil
import json
import time
import torch
import cv2
from PIL import Image
import numpy as np
from ultralytics import YOLO
from torchvision.ops import batched_nms
class YOLO_Class():
    """Run a YOLO model on images and audit its output against LabelMe files.

    For every image that has a sibling ``<name>.json`` LabelMe file, the
    detections are matched to the annotated boxes and the image is filed into
    a sub-directory of the output folder named after the outcome
    (``more``, ``miss``, ``low_iou`` or ``correct``).
    """

    # Matched pairs whose IoU is below this value put the image in "low_iou".
    # With the default iou_threshold of 0.5 no matched pair can be below it,
    # so this only takes effect when a caller matches with a lower threshold.
    _LOW_IOU_LIMIT = 0.5

    # File suffixes (compared case-insensitively) treated as images.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

    def __init__(self, model_path, device="cuda:0"):
        """Load the YOLO weights at *model_path* and move the model to *device*."""
        self.model = YOLO(model_path, verbose=False).to(device)
        # Bookkeeping attributes; they are not read by any method of this class.
        self.last_log_time = time.time()
        self.frame_counter = 0
        self.conf_threshold = 0.5

    def iou(self, box1, box2):
        """Return the IoU of two ``[x_min, y_min, x_max, y_max]`` boxes.

        Returns 0.0 when the union has no area (e.g. two degenerate boxes).
        """
        x1_min, y1_min, x1_max, y1_max = box1
        x2_min, y2_min, x2_max, y2_max = box2

        # Overlap extents are clamped at zero for disjoint boxes.
        inter_w = max(0, min(x1_max, x2_max) - max(x1_min, x2_min))
        inter_h = max(0, min(y1_max, y2_max) - max(y1_min, y2_min))
        inter_area = inter_w * inter_h

        area1 = (x1_max - x1_min) * (y1_max - y1_min)
        area2 = (x2_max - x2_min) * (y2_max - y2_min)
        union_area = area1 + area2 - inter_area
        if union_area <= 0:
            return 0.0
        return inter_area / union_area

    def get_normalized_box(self, points):
        """Return the bounding box ``[x_min, y_min, x_max, y_max]`` of *points*.

        *points* is a sequence of ``(x, y)`` pairs as stored in a LabelMe shape.

        Raises:
            ValueError: if *points* is empty.
        """
        if not points:
            raise ValueError("points must contain at least one (x, y) pair")
        x_coords = [point[0] for point in points]
        y_coords = [point[1] for point in points]
        return [min(x_coords), min(y_coords), max(x_coords), max(y_coords)]

    def load_labelme_annotations(self, json_path):
        """Load the shapes of a LabelMe JSON file.

        Returns a list of dicts with keys ``label``, ``box`` (bounding box of
        the shape's points) and ``points`` (the raw points). Shapes without
        any points are skipped with a printed warning instead of crashing.
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        annotations = []
        for shape in data.get('shapes', []):
            points = shape.get('points') or []
            if not points:
                print(f"跳过没有坐标点的标注: {json_path}")
                continue
            annotations.append({
                'label': shape['label'],
                'box': self.get_normalized_box(points),
                'points': points
            })
        return annotations

    def detect_img(self, image, score_threshold=0.3, nms_iou_threshold=0.5):
        """Run the model on *image* and return the filtered detections.

        Detections scoring at most *score_threshold* are dropped, then
        class-wise NMS with *nms_iou_threshold* is applied.

        Returns a list of dicts with keys ``label`` (class name looked up in
        ``self.model.names``), ``box`` (``[x1, y1, x2, y2]`` floats) and
        ``score`` (float). The list is empty when nothing survives.
        """
        with torch.no_grad():
            results = self.model(image, verbose=False)

        predicted = results[0].boxes
        confident = predicted.conf > score_threshold
        boxes = predicted.xyxy[confident]
        scores = predicted.conf[confident]
        class_ids = predicted.cls[confident].long()
        if boxes.shape[0] == 0:
            return []

        # NMS is applied per class: boxes of different classes never
        # suppress each other.
        kept = batched_nms(boxes, scores, class_ids, nms_iou_threshold)

        names = self.model.names
        # One bulk tolist() per tensor instead of a device round-trip per box.
        return [
            {'label': names[class_id], 'box': box, 'score': score}
            for class_id, box, score in zip(
                class_ids[kept].tolist(),
                boxes[kept].tolist(),
                scores[kept].tolist(),
            )
        ]

    def draw_detections_on_image(self, image_path, detections, output_path):
        """Draw *detections* as green labelled boxes and save to *output_path*.

        Returns True when the image was read and written successfully,
        False otherwise (a message is printed in that case).
        """
        image = cv2.imread(image_path)
        if image is None:
            print(f"无法读取图片: {image_path}")
            return False

        green = (0, 255, 0)
        font = cv2.FONT_HERSHEY_SIMPLEX
        for det in detections:
            x1, y1, x2, y2 = map(int, det['box'])
            cv2.rectangle(image, (x1, y1), (x2, y2), green, 2)

            label_text = f"{det['label']}: {det['score']:.2f}"
            text_w, text_h = cv2.getTextSize(label_text, font, 0.5, 2)[0]
            tag_height = text_h + 10
            # The tag normally sits above the box; when that would fall
            # outside the image, place it just inside the box instead.
            tag_bottom = y1 if y1 - tag_height >= 0 else y1 + tag_height
            cv2.rectangle(image, (x1, tag_bottom - tag_height),
                          (x1 + text_w, tag_bottom), green, -1)
            cv2.putText(image, label_text, (x1, tag_bottom - 5),
                        font, 0.5, (0, 0, 0), 2)

        if not cv2.imwrite(output_path, image):
            print(f"无法保存图片: {output_path}")
            return False
        return True

    def _match_detections(self, detections, annotations, iou_threshold):
        """Match detections to annotations one-to-one.

        Detections are processed in the order given. Each takes the still
        unmatched annotation with the same label and the highest IoU,
        provided that IoU is at least *iou_threshold*.

        Returns ``(matched_pairs, unmatched_detections, unmatched_annotations)``
        where *matched_pairs* holds ``(detection, annotation, iou)`` tuples.
        """
        matched_pairs = []
        unmatched_detections = []
        taken = set()  # indices of annotations that already have a partner

        for det in detections:
            best_index = None
            best_iou = 0.0
            for index, ann in enumerate(annotations):
                if index in taken or ann['label'] != det['label']:
                    continue
                iou_score = self.iou(det['box'], ann['box'])
                if iou_score < iou_threshold:
                    continue
                if best_index is None or iou_score > best_iou:
                    best_index, best_iou = index, iou_score

            if best_index is None:
                unmatched_detections.append(det)
            else:
                taken.add(best_index)
                matched_pairs.append((det, annotations[best_index], best_iou))

        unmatched_annotations = [
            ann for index, ann in enumerate(annotations) if index not in taken
        ]
        return matched_pairs, unmatched_detections, unmatched_annotations

    def _categorize(self, yolo_count, labelme_count, matched_pairs,
                    unmatched_detections, unmatched_annotations):
        """Return the output category name for one image.

        * ``low_iou`` - some matched pair has an IoU below ``_LOW_IOU_LIMIT``
        * ``more``    - more detections than annotations
        * ``miss``    - fewer detections than annotations
        * ``low_iou`` - equal counts, but some box found no partner
        * ``correct`` - equal counts and every box is matched
        """
        if any(iou_score < self._LOW_IOU_LIMIT for _, _, iou_score in matched_pairs):
            return "low_iou"
        if yolo_count > labelme_count:
            return "more"
        if yolo_count < labelme_count:
            return "miss"
        if unmatched_detections or unmatched_annotations:
            return "low_iou"
        return "correct"

    def compare_detections(self, image_path, output_base_dir, iou_threshold=0.5):
        """Compare the detections for one image with its LabelMe annotations.

        The annotation file is expected next to the image with the same stem
        and a ``.json`` suffix; when it is missing a message is printed and
        nothing else happens. Otherwise the image, its JSON and a
        ``<stem>_info.txt`` report are written to
        ``<output_base_dir>/<category>/``.
        """
        base_name = os.path.splitext(image_path)[0]
        json_path = base_name + '.json'
        if not os.path.exists(json_path):
            print(f"未找到标注文件: {json_path}")
            return

        # Close the image file handle as soon as inference is done.
        with Image.open(image_path) as image:
            yolo_detections = self.detect_img(image)
        labelme_annotations = self.load_labelme_annotations(json_path)

        matched_pairs, unmatched_detections, unmatched_annotations = (
            self._match_detections(yolo_detections, labelme_annotations, iou_threshold)
        )

        yolo_count = len(yolo_detections)
        labelme_count = len(labelme_annotations)
        matched_count = len(matched_pairs)
        category = self._categorize(
            yolo_count, labelme_count, matched_pairs,
            unmatched_detections, unmatched_annotations,
        )

        output_dir = os.path.join(output_base_dir, category)
        os.makedirs(output_dir, exist_ok=True)

        shutil.copy2(image_path, output_dir)
        shutil.copy2(json_path, output_dir)

        # The annotated image is saved under the original file name, so it
        # replaces the plain copy made above whenever drawing succeeds; the
        # plain copy only remains if drawing fails.
        result_image_path = os.path.join(output_dir, os.path.basename(image_path))
        self.draw_detections_on_image(image_path, yolo_detections, result_image_path)

        report = [
            f"图片: {os.path.basename(image_path)}\n",
            f"分类: {category}\n",
            f"YOLO检测数: {yolo_count}\n",
            f"LabelMe标注数: {labelme_count}\n",
            f"匹配数: {matched_count}\n",
            f"多检数: {len(unmatched_detections)}\n",
            f"漏检数: {len(unmatched_annotations)}\n",
            "\n匹配详情:\n",
        ]
        for det, _ann, iou_score in matched_pairs:
            report.append(f" {det['label']}: IOU={iou_score:.3f}\n")
        if unmatched_detections:
            report.append("\n多检框:\n")
            for det in unmatched_detections:
                report.append(f" {det['label']}: score={det['score']:.3f}\n")
        if unmatched_annotations:
            report.append("\n漏检框:\n")
            for ann in unmatched_annotations:
                report.append(f" {ann['label']}\n")

        info_file = os.path.join(output_dir, f"{os.path.basename(base_name)}_info.txt")
        with open(info_file, 'w', encoding='utf-8') as f:
            f.writelines(report)

        print(f"处理完成: {os.path.basename(image_path)} -> {category}")

    def _list_images(self, image_dir):
        """Return the sorted paths of all image files below *image_dir*."""
        suffixes = self._IMAGE_EXTENSIONS
        return sorted(
            os.path.join(root, name)
            for root, _dirs, files in os.walk(image_dir)
            for name in files
            if name.lower().endswith(suffixes)
        )

    def evaluate_dataset(self, image_dir, output_base_dir, iou_threshold=0.5):
        """Run ``compare_detections`` on every image found below *image_dir*.

        The directory is walked recursively; files ending in .jpg, .jpeg,
        .png or .bmp (case-insensitive) are processed.
        """
        for img_path in self._list_images(image_dir):
            self.compare_detections(img_path, output_base_dir, iou_threshold)
# Usage example
if __name__ == '__main__':
    # Folder of images to evaluate and the folder that receives the
    # per-category results.
    val_images = r"D:\data\det_chantu\val_penzi"
    results_root = r"D:\data\det_chantu\evaluation_results2"

    evaluator = YOLO_Class(
        r'D:\project\yolov12-main_new\weights\chan4\weights\best.pt',
        device="cuda:0",
    )
    evaluator.evaluate_dataset(val_images, results_root, iou_threshold=0.5)