YOLOv5 分类模型 数据集加载
flyfish
数据集加载的纯 Python 实现,不依赖 torch 库
简化实现
py
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
class DatasetFolder:
    """Folder-per-class dataset implemented with the standard library only (no torch).

    Expected layout: every immediate sub-directory of ``root`` is one class and
    every file below it (recursively, symlinks followed) is one sample, e.g.::

        root/n01440764/ILSVRC2012_val_00000293.JPEG
        root/n01443537/ILSVRC2012_val_00000236.JPEG

    Attributes:
        root: The root directory exactly as passed to the constructor.
        loader: Callable mapping a sample path to the loaded sample.
        classes: Class (sub-directory) names, sorted ascending.
        class_to_idx: Mapping from class name to its integer label.
        samples: ``(path, class_index)`` tuples, grouped by class, sorted by path.
        targets: The ``class_index`` of every entry of ``samples``, in order.
    """

    def __init__(
        self,
        root: str,
        loader: Optional[Callable[[str], Any]] = None,
        *,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        allow_empty: bool = True,
    ) -> None:
        """Scan ``root`` and index every sample file.

        Args:
            root: Dataset root directory (``~`` is expanded when scanning).
            loader: Turns a sample path into a sample; ``__getitem__`` calls it.
                Defaults to reading the file's raw bytes, so no image library
                is required.
            is_valid_file: Optional predicate on a file path; files for which it
                returns False are skipped. ``None`` accepts every file.
            allow_empty: When False, raise if some class folder yields no files.

        Raises:
            FileNotFoundError: If ``root`` has no class folders, or if
                ``allow_empty`` is False and a class folder is empty.
        """
        self.root = root
        # A plain function stored on the instance is not bound, so
        # self.loader(path) calls it with just the path.
        self.loader = loader if loader is not None else self._default_loader
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(
            self.root,
            class_to_idx,
            is_valid_file=is_valid_file,
            allow_empty=allow_empty,
        )
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def _default_loader(path: str) -> bytes:
        """Return the raw bytes of the file at *path* (no decoding)."""
        with open(path, "rb") as f:
            return f.read()

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Optional[Dict[str, int]] = None,
        *,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        allow_empty: bool = True,
    ) -> List[Tuple[str, int]]:
        """Collect ``(path, class_index)`` pairs for every file under each class folder.

        Args:
            directory: Dataset root; ``~`` is expanded.
            class_to_idx: Class-name -> label mapping. When None it is derived
                with :meth:`find_classes`.
            is_valid_file: Optional predicate; files failing it are skipped.
            allow_empty: When False, raise if a class yields no files.

        Returns:
            Samples ordered by class name, then by walked directory, then by
            file name.

        Raises:
            ValueError: If ``class_to_idx`` is an empty mapping.
            FileNotFoundError: If ``allow_empty`` is False and some class in
                ``class_to_idx`` has no valid file.
        """
        directory = os.path.expanduser(directory)
        if class_to_idx is None:
            # There is no ``self`` inside a staticmethod; find_classes is static
            # as well, so call it through the class.
            _, class_to_idx = DatasetFolder.find_classes(directory)
        elif not class_to_idx:
            raise ValueError("'class_to_idx' must have at least one entry to collect any samples.")

        instances: List[Tuple[str, int]] = []
        available_classes = set()
        for target_class in sorted(class_to_idx):
            class_index = class_to_idx[target_class]
            target_dir = os.path.join(directory, target_class)
            if not os.path.isdir(target_dir):
                continue
            # os.walk yields (dirpath, dirnames, filenames); sorting makes the
            # sample order deterministic across file systems.
            for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
                for fname in sorted(fnames):
                    path = os.path.join(root, fname)
                    if is_valid_file is None or is_valid_file(path):
                        instances.append((path, class_index))
                        available_classes.add(target_class)

        empty_classes = set(class_to_idx) - available_classes
        if empty_classes and not allow_empty:
            msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}."
            raise FileNotFoundError(msg)
        return instances

    @staticmethod
    def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
        """Return the sorted class-folder names and a name -> index mapping.

        Indices are assigned 0, 1, 2, ... in ascending name order.

        Raises:
            FileNotFoundError: If *directory* contains no sub-directories.
        """
        expanded = os.path.expanduser(directory)
        # Use the scandir iterator as a context manager so its handle is closed.
        with os.scandir(expanded) as entries:
            classes = sorted(entry.name for entry in entries if entry.is_dir())
        if not classes:
            raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(self.loader(path), class_index)`` for the *index*-th sample."""
        path, target = self.samples[index]
        sample = self.loader(path)
        return sample, target

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.samples)

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(root={self.root!r}, "
            f"classes={len(self.classes)}, samples={len(self.samples)})"
        )
dataset = DatasetFolder(root="/media/a/flyfish/test")
print(dataset)
# Dump the three collections built by the dataset, one labelled line each.
for label, value in (
    ("dataset.targets:", dataset.targets),
    ("dataset.classes:", dataset.classes),
    ("samples:", dataset.samples),
):
    print(label, value)
find_classes
将类别名(即文件夹名)与标签索引一一对应:
0, 1, 2
是标签索引;
'n01440764', 'n01443537', 'n01484850'
既是类别名,也是文件夹名;
类别名按升序排序后,依次分配索引 0, 1, 2。
dataset.targets: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
dataset.classes: ['n01440764', 'n01443537', 'n01484850']
每个样本是一个 (路径, 标签) 元组:第一个元素是图像文件的路径(这里 root 是绝对路径,所以得到的也是绝对路径),第二个元素是标签索引
samples: [('/media/a/flyfish/test/n01440764/ILSVRC2012_val_00000293.JPEG', 0),
('/media/a/flyfish/test/n01440764/ILSVRC2012_val_00002138.JPEG', 0),
('/media/a/flyfish/test/n01440764/ILSVRC2012_val_00003014.JPEG', 0),
('/media/a/flyfish/test/n01440764/ILSVRC2012_val_00006697.JPEG', 0),
('/media/a/flyfish/test/n01443537/ILSVRC2012_val_00000236.JPEG', 1),
('/media/a/flyfish/test/n01443537/ILSVRC2012_val_00000262.JPEG', 1),
('/media/a/flyfish/test/n01443537/ILSVRC2012_val_00000307.JPEG', 1),
('/media/a/flyfish/test/n01443537/ILSVRC2012_val_00000994.JPEG', 1),
('/media/a/flyfish/test/n01484850/ILSVRC2012_val_00002338.JPEG', 2),
('/media/a/flyfish/test/n01484850/ILSVRC2012_val_00002752.JPEG', 2),
('/media/a/flyfish/test/n01484850/ILSVRC2012_val_00004311.JPEG', 2),
('/media/a/flyfish/test/n01484850/ILSVRC2012_val_00004329.JPEG', 2)]
还可以把功能做得更丰富,例如检查文件扩展名,判断它是否属于支持的图像格式:
py
# Lower-case file-name suffixes that count as image files.
IMG_EXTENSIONS = (
    ".jpg",
    ".jpeg",
    ".png",
    ".ppm",
    ".bmp",
    ".pgm",
    ".tif",
    ".tiff",
    ".webp",
)
def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool:
    """Check whether *filename* ends with one of the allowed *extensions*.

    Both the file name and the extensions are lower-cased before comparing, so
    the check is case-insensitive in both directions.

    Args:
        filename: File name or path to check.
        extensions: A single suffix such as ``".jpg"``, or an iterable of
            suffixes (tuple, list, ...).

    Returns:
        True if ``filename`` ends with any of ``extensions``, ignoring case.
    """
    # Previously only the file name was lower-cased, so an upper-case
    # extension such as ".JPG" could never match.
    if isinstance(extensions, str):
        suffixes: Union[str, Tuple[str, ...]] = extensions.lower()
    else:
        suffixes = tuple(ext.lower() for ext in extensions)
    return filename.lower().endswith(suffixes)
def is_image_file(filename: str) -> bool:
    """Return True if *filename* ends with one of the suffixes in IMG_EXTENSIONS.

    The comparison is delegated to has_file_allowed_extension, which
    lower-cases the file name first.
    """
    allowed = IMG_EXTENSIONS
    return has_file_allowed_extension(filename, extensions=allowed)
测试
py
# Quick sanity check: one image file and one non-image file.
for _path in (
    "/media/a/flyfish/data/imagewoof/val/n02086240/1.jpeg",  # expected: True
    "/media/a/flyfish/data/imagewoof/val/n02086240/1.txt",  # expected: False
):
    r = is_image_file(_path)
    print(r)