YOLOv5 Classification Model: Dataset Loading (3) Custom Classes

flyfish

YOLOv5 Classification Model: Dataset Loading (1) Sample Processing
YOLOv5 Classification Model: Dataset Loading (2) Slice Processing
YOLOv5 Classification Model Preprocessing (1): Resize and CenterCrop
YOLOv5 Classification Model Preprocessing (2): ToTensor and Normalize
YOLOv5 Classification Model: Top-1 and Top-5 Metrics Explained
YOLOv5 Classification Model: Top-1 and Top-5 Metrics Implementation

Previously, the class names were the folder names, and the class IDs were assigned in alphabetical order of those folder names.

Now the class names are still the folder names, but the class IDs follow the order of the names in the classes_name list. For example:

```python
classes_name = ['n02086240', 'n02087394', 'n02088364', 'n02089973', 'n02093754',
                'n02096294', 'n02099601', 'n02105641', 'n02111889', 'n02115641']
```

n02086240 gets class ID 0

n02087394 gets class ID 1

The code that handles this is:

```python
if classes_name is None or not classes_name:
    classes, class_to_idx = self.find_classes(self.root)
    print("not classes_name")
else:
    classes = classes_name
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    print("is classes_name")
```

The full code:

```python
import time
from models.common import DetectMultiBackend
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
import cv2
import numpy as np

import torch
from PIL import Image
import torchvision.transforms as transforms

import sys

classes_name=['n02086240', 'n02087394', 'n02088364', 'n02089973', 'n02093754', 'n02096294', 'n02099601', 'n02105641', 'n02111889', 'n02115641']
              
class DatasetFolder:

    def __init__(self, root: str) -> None:
        self.root = root

        if classes_name is None or not classes_name:
            classes, class_to_idx = self.find_classes(self.root)
            print("not classes_name")
        else:
            classes = classes_name
            class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
            print("is classes_name")

        print("classes:",classes)
        
        print("class_to_idx:",class_to_idx)
        samples = self.make_dataset(self.root, class_to_idx)

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    def make_dataset(
        self,
        directory: str,
        class_to_idx: Optional[Dict[str, int]] = None,
    ) -> List[Tuple[str, int]]:

        directory = os.path.expanduser(directory)

        if class_to_idx is None:
            _, class_to_idx = self.find_classes(directory)
        elif not class_to_idx:
            raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

        instances = []
        available_classes = set()
        for target_class in sorted(class_to_idx.keys()):
            class_index = class_to_idx[target_class]
            target_dir = os.path.join(directory, target_class)
            if not os.path.isdir(target_dir):
                continue
            for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
                for fname in sorted(fnames):
                    path = os.path.join(root, fname)
                    if True:  # file-extension validation is skipped here; every file becomes a sample
                        item = path, class_index
                        instances.append(item)

                        if target_class not in available_classes:
                            available_classes.add(target_class)

        empty_classes = set(class_to_idx.keys()) - available_classes
        if empty_classes:
            msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "

        return instances

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:

        classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
        if not classes:
            raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index: int) -> Tuple[Any, Any]:

        path, target = self.samples[index]
        sample = self.loader(path)

        return sample, target

    def __len__(self) -> int:
        return len(self.samples)

    def loader(self, path):
        print("path:", path)
        #img = cv2.imread(path)  # BGR HWC
        img=Image.open(path).convert("RGB") # RGB HWC
        return img


def time_sync():
    return time.time()

#sys.exit() 
dataset = DatasetFolder(root="/media/a/flyfish/source/yolov5/datasets/imagewoof/val")

#image, label=dataset[7]

#
weights = "/home/a/classes.pt"
device = "cpu"
model = DetectMultiBackend(weights, device=device, dnn=False, fp16=False)
model.eval()
print(model.names)
print(type(model.names))

mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
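# The mean/std above are the ImageNet channel statistics, the same constants used by torchvision's pretrained models.
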
def preprocess(images):

    # Reimplement PyTorch transforms.Resize: scale the shorter side to target_size
    target_size = 224

    img_w = images.width
    img_h = images.height

    if img_h >= img_w:  # height >= width, so the width is the shorter side
        resize_img = images.resize((target_size, int(target_size * img_h / img_w)), Image.BILINEAR)
    else:
        resize_img = images.resize((int(target_size * img_w / img_h), target_size), Image.BILINEAR)

    # Reimplement PyTorch transforms.CenterCrop: crop a target_size x target_size patch from the center
    width = resize_img.width
    height = resize_img.height

    center_x, center_y = width // 2, height // 2
    left = center_x - (target_size // 2)
    top = center_y - (target_size // 2)
    right = center_x + target_size // 2
    bottom = center_y + target_size // 2
    cropped_img = resize_img.crop((left, top, right, bottom))

    # Reimplement PyTorch ToTensor and Normalize
    images = np.asarray(cropped_img)
    print("preprocess:", images.shape)
    images = images.astype('float32')
    images = (images / 255 - mean) / std  # mean/std broadcast over the channel axis
    images = images.transpose((2, 0, 1))  # HWC to CHW
    print("preprocess:", images.shape)

    images = np.ascontiguousarray(images)
    images = torch.from_numpy(images)
    #images = images.unsqueeze(dim=0).float()
    return images
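
# For reference (a sketch, not used below): the torchvision pipeline that the manual
# preprocess() above is meant to reproduce. "reference_preprocess" is an illustrative
# name, not part of the original script.
reference_preprocess = transforms.Compose([
    transforms.Resize(224),        # shorter side -> 224, bilinear by default
    transforms.CenterCrop(224),
    transforms.ToTensor(),         # HWC uint8 -> CHW float in [0, 1]
    transforms.Normalize(mean=mean, std=std),
])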

pred, targets, loss, dt = [], [], 0, [0.0, 0.0, 0.0]
# current batch size =1
for i, (images, labels) in enumerate(dataset):
    print("i:", i)
    im = preprocess(images)
    images = im.unsqueeze(0).to("cpu").float()
 
    print(images.shape)
    t1 = time_sync()
    images = images.to(device, non_blocking=True)
    t2 = time_sync()
    # dt[0] += t2 - t1

    y = model(images)
    y = y.detach().numpy()
   
    #print("y:", y)
    t3 = time_sync()
    # dt[1] += t3 - t2
    # With batch size > 1 the inference result is 2-D, e.g.:
    #y: [[     4.0855     -1.1707     -1.4998      -0.935     -1.9979      -2.258     -1.4691     -1.0867     -1.9042    -0.99979]]

    tmp1=y.argsort()[:,::-1][:, :5]
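    # argsort() sorts each row's scores in ascending order; [:, ::-1] flips to descending,
    # and [:, :5] keeps the indices of the 5 highest-scoring classes (the top-5 prediction).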

    # With batch size = 1 (an image fed without a batch dimension) the result is 1-D,
    # so the argsort indexing above would need to be adjusted accordingly, e.g.:
    #y: [     3.7441      -1.135     -1.1293     -0.9422     -1.6029     -2.0561      -1.025     -1.5842     -1.3952     -1.1824]
   
    #print("tmp1:", tmp1)
    pred.append(tmp1)
    #print("labels:", labels)
    targets.append(labels)

    #print("for pred:", pred)  # list
    #print("for targets:", targets)  # list
    # dt[2] += time_sync() - t3


pred, targets = np.concatenate(pred), np.array(targets)
print("pred:", pred)
print("pred:", pred.shape)
print("targets:", targets)
print("targets:", targets.shape)
correct = (targets[:, None] == pred).astype(np.float32)
print("correct:", correct.shape)
print("correct:", correct)
acc = np.stack((correct[:, 0], correct.max(1)), axis=1)  # (top1, top5) accuracy
print("acc:", acc.shape)
print("acc:", acc)
top = acc.mean(0)
print("top1:", top[0])
print("top5:", top[1])