YOLOv5 分类模型 Top 1和Top 5 指标实现

YOLOv5 分类模型 Top 1和Top 5 指标实现

flyfish

复制代码
import time
from models.common import DetectMultiBackend
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
import cv2
import numpy as np

import torch
from utils.augmentations import classify_transforms


class DatasetFolder:

    def __init__(
        self,
        root: str,

    ) -> None:
        self.root = root
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(self.root, class_to_idx)

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Optional[Dict[str, int]] = None,

    ) -> List[Tuple[str, int]]:

        directory = os.path.expanduser(directory)

        if class_to_idx is None:
            _, class_to_idx = self.find_classes(directory)
        elif not class_to_idx:
            raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

        instances = []
        available_classes = set()
        for target_class in sorted(class_to_idx.keys()):
            class_index = class_to_idx[target_class]
            target_dir = os.path.join(directory, target_class)
            if not os.path.isdir(target_dir):
                continue
            for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
                for fname in sorted(fnames):
                    path = os.path.join(root, fname)
                    if 1:  # 验证:
                        item = path, class_index
                        instances.append(item)

                        if target_class not in available_classes:
                            available_classes.add(target_class)

        empty_classes = set(class_to_idx.keys()) - available_classes
        if empty_classes:
            msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "

        return instances

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:

        classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
        if not classes:
            raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index: int) -> Tuple[Any, Any]:

        path, target = self.samples[index]
        sample = self.loader(path)

        return sample, target

    def __len__(self) -> int:
        return len(self.samples)

    def loader(self, path):
        print("path:", path)
        img = cv2.imread(path)  # BGR HWC
        return img


def time_sync():
    return time.time()


dataset = DatasetFolder(root="/media/flyfish/test/val")

# image, label=dataset[7]
# print(image.shape)
#
weights = "/media/flyfish/yolov5-6.2/classes10.pt"
device = "cpu"
model = DetectMultiBackend(weights, device=device, dnn=False, fp16=False)
model.eval()

transforms = classify_transforms(224)

pred, targets, loss, dt = [], [], 0, [0.0, 0.0, 0.0]
# current batch size =1
for i, (images, labels) in enumerate(dataset):
    print("i:", i)
    print(images.shape, labels)
    im = cv2.cvtColor(images, cv2.COLOR_BGR2RGB)
    im = transforms(im)
    images = im.unsqueeze(0).to("cpu")
 
    print(images.shape)


        
    t1 = time_sync()
    images = images.to(device, non_blocking=True)
    t2 = time_sync()
    # dt[0] += t2 - t1

    y = model(images)
    y=y.numpy()
   
    print("y:", y)
    t3 = time_sync()
    # dt[1] += t3 - t2

    tmp1=y.argsort()[:,::-1][:, :5]
   
    print("tmp1:", tmp1)
    pred.append(tmp1)

    print("labels:", labels)

    
    targets.append(labels)

    print("for pred:", pred)  # list
    print("for targets:", targets)  # list

    # dt[2] += time_sync() - t3


pred, targets = np.concatenate(pred), np.array(targets)
print("pred:", pred)
print("pred:", pred.shape)
print("targets:", targets)
print("targets:", targets.shape)
correct = ((targets[:, None] == pred)).astype(np.float32)
print("correct:", correct.shape)
print("correct:", correct)
acc = np.stack((correct[:, 0], correct.max(1)), axis=1)  # (top1, top5) accuracy
print("acc:", acc.shape)
print("acc:", acc)
top = acc.mean(0)
print("top1:", top[0])
print("top5:", top[1])

输出

复制代码
pred: [[7 4 0 5 9]
 [9 2 4 6 7]
 [8 9 6 2 1]
 [8 9 6 2 7]
 [9 2 4 6 3]
 [6 7 1 2 9]
 [4 2 1 8 9]
 [6 8 9 5 2]
 [8 7 4 2 6]
 [9 8 2 6 4]
 [2 9 8 0 6]
 [7 4 8 6 3]]
pred: (12, 5)
targets: [0 0 0 0 1 1 1 1 2 2 2 2]
targets: (12,)
correct: (12, 5)
correct: [[          0           0           1           0           0]
 [          0           0           0           0           0]
 [          0           0           0           0           0]
 [          0           0           0           0           0]
 [          0           0           0           0           0]
 [          0           0           1           0           0]
 [          0           0           1           0           0]
 [          0           0           0           0           0]
 [          0           0           0           1           0]
 [          0           0           1           0           0]
 [          1           0           0           0           0]
 [          0           0           0           0           0]]
acc: (12, 2)
acc: [[          0           1]
 [          0           0]
 [          0           0]
 [          0           0]
 [          0           0]
 [          0           1]
 [          0           1]
 [          0           0]
 [          0           1]
 [          0           1]
 [          1           1]
 [          0           0]]
top1: 0.083333336
top5: 0.5

Yolov5 6.2 原版输出

复制代码
pred: tensor([[6, 7, 1, 2, 9],
        [9, 2, 4, 6, 3],
        [7, 4, 0, 5, 9],
        [9, 8, 2, 6, 4],
        [6, 8, 9, 5, 2],
        [8, 7, 4, 2, 6],
        [9, 2, 4, 6, 7],
        [2, 9, 8, 0, 6],
        [8, 9, 6, 2, 7],
        [7, 4, 8, 6, 3],
        [4, 2, 1, 8, 9],
        [8, 9, 6, 2, 1]])
pred: torch.Size([12, 5])
targets: tensor([1, 1, 0, 2, 1, 2, 0, 2, 0, 2, 1, 0])
targets: torch.Size([12])
correct: torch.Size([12, 5])
acc: torch.Size([12, 2])
top1: 0.0833333358168602
top5: 0.5

文本代码是按照标签,即文件夹名字排序的,pred和target都是一一对应的,与Yolov5 6.2 原版相同

相关推荐
梨落秋霜13 小时前
Python入门篇【文件处理】
android·java·python
Java 码农14 小时前
RabbitMQ集群部署方案及配置指南03
java·python·rabbitmq
张登杰踩15 小时前
VIA标注格式转Labelme标注格式
python
Learner15 小时前
Python数据类型(四):字典
python
odoo中国16 小时前
Odoo 19 模块结构概述
开发语言·python·module·odoo·核心组件·py文件按
Jelena1577958579216 小时前
Java爬虫api接口测试
python
踩坑记录17 小时前
leetcode hot100 3.无重复字符的最长子串 medium 滑动窗口(双指针)
python·leetcode
诸神缄默不语18 小时前
Python处理Word文档完全指南:从基础到进阶
python
海棠AI实验室19 小时前
第四章 项目目录结构:src/、configs/、data/、tests/ 的黄金布局
python·项目目录结构
爱笑的眼睛1120 小时前
超越可视化:降维算法组件的深度解析与工程实践
java·人工智能·python·ai