YOLOv5 分类模型 Top 1和Top 5 指标实现

YOLOv5 分类模型 Top 1和Top 5 指标实现

flyfish

复制代码
import time
from models.common import DetectMultiBackend
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
import cv2
import numpy as np

import torch
from utils.augmentations import classify_transforms


class DatasetFolder:

    def __init__(
        self,
        root: str,

    ) -> None:
        self.root = root
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(self.root, class_to_idx)

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Optional[Dict[str, int]] = None,

    ) -> List[Tuple[str, int]]:

        directory = os.path.expanduser(directory)

        if class_to_idx is None:
            _, class_to_idx = self.find_classes(directory)
        elif not class_to_idx:
            raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

        instances = []
        available_classes = set()
        for target_class in sorted(class_to_idx.keys()):
            class_index = class_to_idx[target_class]
            target_dir = os.path.join(directory, target_class)
            if not os.path.isdir(target_dir):
                continue
            for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
                for fname in sorted(fnames):
                    path = os.path.join(root, fname)
                    if 1:  # 验证:
                        item = path, class_index
                        instances.append(item)

                        if target_class not in available_classes:
                            available_classes.add(target_class)

        empty_classes = set(class_to_idx.keys()) - available_classes
        if empty_classes:
            msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "

        return instances

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:

        classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
        if not classes:
            raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index: int) -> Tuple[Any, Any]:

        path, target = self.samples[index]
        sample = self.loader(path)

        return sample, target

    def __len__(self) -> int:
        return len(self.samples)

    def loader(self, path):
        print("path:", path)
        img = cv2.imread(path)  # BGR HWC
        return img


def time_sync():
    return time.time()


dataset = DatasetFolder(root="/media/flyfish/test/val")

# image, label=dataset[7]
# print(image.shape)
#
weights = "/media/flyfish/yolov5-6.2/classes10.pt"
device = "cpu"
model = DetectMultiBackend(weights, device=device, dnn=False, fp16=False)
model.eval()

transforms = classify_transforms(224)

pred, targets, loss, dt = [], [], 0, [0.0, 0.0, 0.0]
# current batch size =1
for i, (images, labels) in enumerate(dataset):
    print("i:", i)
    print(images.shape, labels)
    im = cv2.cvtColor(images, cv2.COLOR_BGR2RGB)
    im = transforms(im)
    images = im.unsqueeze(0).to("cpu")
 
    print(images.shape)


        
    t1 = time_sync()
    images = images.to(device, non_blocking=True)
    t2 = time_sync()
    # dt[0] += t2 - t1

    y = model(images)
    y=y.numpy()
   
    print("y:", y)
    t3 = time_sync()
    # dt[1] += t3 - t2

    tmp1=y.argsort()[:,::-1][:, :5]
   
    print("tmp1:", tmp1)
    pred.append(tmp1)

    print("labels:", labels)

    
    targets.append(labels)

    print("for pred:", pred)  # list
    print("for targets:", targets)  # list

    # dt[2] += time_sync() - t3


pred, targets = np.concatenate(pred), np.array(targets)
print("pred:", pred)
print("pred:", pred.shape)
print("targets:", targets)
print("targets:", targets.shape)
correct = ((targets[:, None] == pred)).astype(np.float32)
print("correct:", correct.shape)
print("correct:", correct)
acc = np.stack((correct[:, 0], correct.max(1)), axis=1)  # (top1, top5) accuracy
print("acc:", acc.shape)
print("acc:", acc)
top = acc.mean(0)
print("top1:", top[0])
print("top5:", top[1])

输出

复制代码
pred: [[7 4 0 5 9]
 [9 2 4 6 7]
 [8 9 6 2 1]
 [8 9 6 2 7]
 [9 2 4 6 3]
 [6 7 1 2 9]
 [4 2 1 8 9]
 [6 8 9 5 2]
 [8 7 4 2 6]
 [9 8 2 6 4]
 [2 9 8 0 6]
 [7 4 8 6 3]]
pred: (12, 5)
targets: [0 0 0 0 1 1 1 1 2 2 2 2]
targets: (12,)
correct: (12, 5)
correct: [[          0           0           1           0           0]
 [          0           0           0           0           0]
 [          0           0           0           0           0]
 [          0           0           0           0           0]
 [          0           0           0           0           0]
 [          0           0           1           0           0]
 [          0           0           1           0           0]
 [          0           0           0           0           0]
 [          0           0           0           1           0]
 [          0           0           1           0           0]
 [          1           0           0           0           0]
 [          0           0           0           0           0]]
acc: (12, 2)
acc: [[          0           1]
 [          0           0]
 [          0           0]
 [          0           0]
 [          0           0]
 [          0           1]
 [          0           1]
 [          0           0]
 [          0           1]
 [          0           1]
 [          1           1]
 [          0           0]]
top1: 0.083333336
top5: 0.5

Yolov5 6.2 原版输出

复制代码
pred: tensor([[6, 7, 1, 2, 9],
        [9, 2, 4, 6, 3],
        [7, 4, 0, 5, 9],
        [9, 8, 2, 6, 4],
        [6, 8, 9, 5, 2],
        [8, 7, 4, 2, 6],
        [9, 2, 4, 6, 7],
        [2, 9, 8, 0, 6],
        [8, 9, 6, 2, 7],
        [7, 4, 8, 6, 3],
        [4, 2, 1, 8, 9],
        [8, 9, 6, 2, 1]])
pred: torch.Size([12, 5])
targets: tensor([1, 1, 0, 2, 1, 2, 0, 2, 0, 2, 1, 0])
targets: torch.Size([12])
correct: torch.Size([12, 5])
acc: torch.Size([12, 2])
top1: 0.0833333358168602
top5: 0.5

文本代码是按照标签,即文件夹名字排序的,pred和target都是一一对应的,与Yolov5 6.2 原版相同

相关推荐
weixin_5806140028 分钟前
如何提取SQL日期中的年份_使用YEAR或EXTRACT函数
jvm·数据库·python
2301_8135995536 分钟前
SQL生产环境规范_数据库使用最佳实践
jvm·数据库·python
李可以量化36 分钟前
QMT 量化实战:用 Python 实现线性回归通道,精准识别趋势中的支撑与压力(下)
python·qmt·量化 qmt ptrade
a95114164243 分钟前
Go 中通过 channel 传递切片时的数据竞争与深拷贝解决方案
jvm·数据库·python
Dxy123931021644 分钟前
Python 使用正则表达式将多个空格替换为一个空格
开发语言·python·正则表达式
qq_189807031 小时前
如何修改RAC数据库名_NID工具在集群环境下的改名步骤
jvm·数据库·python
zhangchaoxies1 小时前
如何检测SQL注入风险_利用模糊测试技术发现漏洞
jvm·数据库·python
Luca_kill2 小时前
MCP数据采集革命:从传统爬虫到智能代理的技术进化
爬虫·python·ai·数据采集·mcp·webscraping·集蜂云
zhangchaoxies2 小时前
CSS如何实现响应式弹性网格布局_配合media query修改flex-wrap属性
jvm·数据库·python
ZC跨境爬虫2 小时前
Scrapy分布式爬虫(单机模拟多节点):豆瓣Top250项目设置与数据流全解析
分布式·爬虫·python·scrapy