YOLOv5 分类模型 数据集加载 3

YOLOv5 分类模型 数据集加载 3 自定义类别

flyfish

YOLOv5 分类模型 数据集加载 1 样本处理
YOLOv5 分类模型 数据集加载 2 切片处理
YOLOv5 分类模型的预处理(1) Resize 和 CenterCrop
YOLOv5 分类模型的预处理(2)ToTensor 和 Normalize
YOLOv5 分类模型 Top 1和Top 5 指标说明
YOLOv5 分类模型 Top 1和Top 5 指标实现

之前的处理方式是类别名字是文件夹名字,类别ID是按照文件夹名字的字母顺序

现在是类别名字是文件夹名字,按照文件列表名字顺序 例如

py 复制代码
classes_name=['n02086240', 'n02087394', 'n02088364', 'n02089973', 'n02093754', 
'n02096294', 'n02099601', 'n02105641', 'n02111889', 'n02115641']

n02086240 类别ID是0

n02087394 类别ID是1

代码处理是

py 复制代码
if classes_name is None or not classes_name:
    classes, class_to_idx = self.find_classes(self.root)
    print("not classes_name")

else:
    classes = classes_name
    class_to_idx ={cls_name: i for i, cls_name in enumerate(classes)}
    print("is classes_name")

完整

py 复制代码
import time
from models.common import DetectMultiBackend
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
import cv2
import numpy as np

import torch
from PIL import Image
import torchvision.transforms as transforms

import sys

classes_name=['n02086240', 'n02087394', 'n02088364', 'n02089973', 'n02093754', 'n02096294', 'n02099601', 'n02105641', 'n02111889', 'n02115641']
              
class DatasetFolder:

    def __init__(
        self,
        root: str,

    ) -> None:
        self.root = root

        if classes_name is None or not classes_name:
            classes, class_to_idx = self.find_classes(self.root)
            print("not classes_name")

        else:
            classes = classes_name
            class_to_idx ={cls_name: i for i, cls_name in enumerate(classes)}
            print("is classes_name")

        print("classes:",classes)
        
        print("class_to_idx:",class_to_idx)
        samples = self.make_dataset(self.root, class_to_idx)

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Optional[Dict[str, int]] = None,

    ) -> List[Tuple[str, int]]:

        directory = os.path.expanduser(directory)

        if class_to_idx is None:
            _, class_to_idx = self.find_classes(directory)
        elif not class_to_idx:
            raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

        instances = []
        available_classes = set()
        for target_class in sorted(class_to_idx.keys()):
            class_index = class_to_idx[target_class]
            target_dir = os.path.join(directory, target_class)
            if not os.path.isdir(target_dir):
                continue
            for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
                for fname in sorted(fnames):
                    path = os.path.join(root, fname)
                    if 1:  # 验证:
                        item = path, class_index
                        instances.append(item)

                        if target_class not in available_classes:
                            available_classes.add(target_class)

        empty_classes = set(class_to_idx.keys()) - available_classes
        if empty_classes:
            msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "

        return instances

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:

        classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
        if not classes:
            raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index: int) -> Tuple[Any, Any]:

        path, target = self.samples[index]
        sample = self.loader(path)

        return sample, target

    def __len__(self) -> int:
        return len(self.samples)

    def loader(self, path):
        print("path:", path)
        #img = cv2.imread(path)  # BGR HWC
        img=Image.open(path).convert("RGB") # RGB HWC
        return img


def time_sync():
    return time.time()

#sys.exit() 
dataset = DatasetFolder(root="/media/a/flyfish/source/yolov5/datasets/imagewoof/val")

#image, label=dataset[7]

#
weights = "/home/a/classes.pt"
device = "cpu"
model = DetectMultiBackend(weights, device=device, dnn=False, fp16=False)
model.eval()
print(model.names)
print(type(model.names))

mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
def preprocess(images):
  

    #实现 PyTorch Resize
    target_size =224

    img_w = images.width
    img_h = images.height
    
    if(img_h >= img_w):# hw
 
        resize_img = images.resize((target_size, int(target_size * img_h / img_w)), Image.BILINEAR)
    else:
        resize_img = images.resize((int(target_size * img_w  / img_h),target_size), Image.BILINEAR)

    #实现 PyTorch CenterCrop
    width = resize_img.width
    height = resize_img.height

    center_x,center_y = width//2,height//2
    left = center_x - (target_size//2)
    top = center_y- (target_size//2)
    right =center_x +target_size//2
    bottom = center_y+target_size//2
    cropped_img = resize_img.crop((left, top, right, bottom))

    #实现 PyTorch ToTensor Normalize
    images = np.asarray(cropped_img)
    print("preprocess:",images.shape)
    images = images.astype('float32')
    images = (images/255-mean)/std
    images = images.transpose((2, 0, 1))# HWC to CHW
    print("preprocess:",images.shape)

    images = np.ascontiguousarray(images)
    images=torch.from_numpy(images)
    #images = images.unsqueeze(dim=0).float()
    return images

pred, targets, loss, dt = [], [], 0, [0.0, 0.0, 0.0]
# current batch size =1
for i, (images, labels) in enumerate(dataset):
    print("i:", i)
    im = preprocess(images)
    images = im.unsqueeze(0).to("cpu").float()
 
    print(images.shape)
    t1 = time_sync()
    images = images.to(device, non_blocking=True)
    t2 = time_sync()
    # dt[0] += t2 - t1

    y = model(images)
    y=y.numpy()
   
    #print("y:", y)
    t3 = time_sync()
    # dt[1] += t3 - t2
    #batch size >1 图像推理结果是二维的
    #y: [[     4.0855     -1.1707     -1.4998      -0.935     -1.9979      -2.258     -1.4691     -1.0867     -1.9042    -0.99979]]

    tmp1=y.argsort()[:,::-1][:, :5]

    #batch size =1 图像推理结果是一维的, 就要处理下argsort的维度
    #y: [     3.7441      -1.135     -1.1293     -0.9422     -1.6029     -2.0561      -1.025     -1.5842     -1.3952     -1.1824]
   
    #print("tmp1:", tmp1)
    pred.append(tmp1)
    #print("labels:", labels)
    targets.append(labels)

    #print("for pred:", pred)  # list
    #print("for targets:", targets)  # list
    # dt[2] += time_sync() - t3


pred, targets = np.concatenate(pred), np.array(targets)
print("pred:", pred)
print("pred:", pred.shape)
print("targets:", targets)
print("targets:", targets.shape)
correct = ((targets[:, None] == pred)).astype(np.float32)
print("correct:", correct.shape)
print("correct:", correct)
acc = np.stack((correct[:, 0], correct.max(1)), axis=1)  # (top1, top5) accuracy
print("acc:", acc.shape)
print("acc:", acc)
top = acc.mean(0)
print("top1:", top[0])
print("top5:", top[1])
相关推荐
qq_254674413 小时前
回归、分类、聚类
分类·回归·聚类
B站_计算机毕业设计之家5 小时前
深度血虚:Django水果检测识别系统 CNN卷积神经网络算法 python语言 计算机 大数据✅
python·深度学习·计算机视觉·信息可视化·分类·cnn·django
AI浩5 小时前
PAB-Mamba-YoLo: VSSM 辅助 YOLO 用于断奶仔猪攻击行为检测
yolo
元直数字电路验证12 小时前
感知机:乳腺癌分类实现 & K 均值聚类:从零实现
均值算法·分类·聚类
王哈哈^_^13 小时前
【完整源码+数据集】草莓数据集,yolov8草莓成熟度检测数据集 3207 张,草莓成熟度数据集,目标检测草莓识别算法系统实战教程
人工智能·算法·yolo·目标检测·计算机视觉·视觉检测·毕业设计
油泼辣子多加14 小时前
【实战】自然语言处理--长文本分类(3)HAN算法
算法·自然语言处理·分类
大大dxy大大20 小时前
机器学习实现逻辑回归-癌症分类预测
机器学习·分类·逻辑回归
王哈哈^_^1 天前
YOLOv11视觉检测实战:安全距离测算全解析
人工智能·数码相机·算法·yolo·计算机视觉·目标跟踪·视觉检测
深度学习lover1 天前
<数据集>yolo航拍交通目标识别数据集<目标检测>
人工智能·python·yolo·目标检测·计算机视觉·航拍交通目标识别
Coovally AI模型快速验证1 天前
视觉语言模型(VLM)深度解析:如何用它来处理文档
人工智能·yolo·目标跟踪·语言模型·自然语言处理·开源