Python 提供了许多灵活的特性,例如包的 __init__.py
文件和 getattr
函数,这些特性可以帮助我们实现工厂方法模式来动态地创建不同类型的数据集实例。
1. 背景介绍
在深度学习项目中,我们通常需要处理多种类型的数据集,例如 COCO、Pascal VOC 和自定义的交通数据集。为了简化和统一数据集的加载过程,我们可以利用 Python 的包管理和动态属性获取特性来实现工厂方法模式。
- 包的
__init__.py
文件 :通过在包的__init__.py
文件中导入模块,我们可以在初始化包时自动加载所有必要的类和函数。 getattr
函数 :getattr
函数允许我们动态地获取对象的属性或方法,这对于实现工厂方法模式非常有用,因为我们可以根据配置或输入动态地创建对象,而无需在代码中硬编码每种数据集的构建逻辑。
接下来,我们将通过具体的代码示例来展示如何使用这些特性来实现数据集的动态加载。
2. 模块和类的定义
在我们的项目中,数据集类被定义在 datasets
模块中。我们将定义一个 COCODataset
类,并在 datasets
模块的 __init__.py
文件中导入它。需要注意的是,COCODataset
只是众多数据集类中的一种,其他数据集类如 PascalVOCDataset
、TrafficDataset
等也可以通过类似的方式定义和使用。
定义 COCODataset
类
在 datasets
模块中创建一个名为 coco.py
的文件,并定义 COCODataset
类。这个类继承自 torchvision.datasets.coco.CocoDetection
,并添加了一些自定义逻辑。
python
# datasets/coco.py
import torchvision
class COCODataset(torchvision.datasets.coco.CocoDetection):
def __init__(self, ann_file, root, remove_images_without_annotations, transforms=None):
super(COCODataset, self).__init__(root, ann_file)
# 自定义逻辑...
__init__
方法 :COCODataset
类的构造函数接受ann_file
(注释文件路径)、root
(图像根目录)、remove_images_without_annotations
(是否移除没有注释的图像)和transforms
(图像变换)四个参数。这些参数与后面DatasetCatalog
中get
方法返回的args
对应。- 详细实现见附录
导入 COCODataset
类
在 datasets
模块的 __init__.py
文件中导入 COCODataset
类。这样可以确保在使用 datasets
模块时,所有数据集类都已加载。
python
# datasets/__init__.py
from .coco import COCODataset
from .voc import PascalVOCDataset
from .concat_dataset import ConcatDataset
from .traffic_dataset import TrafficDataset
from .carWinBiaoZhi_dataset import CarWinBiaoZhiDataset
from .carWinBiaoZhi_dataset_V2 import CarWinBiaoZhiDatasetV2
from .carWinBiaoZhi_dataset_V2_1 import CarWinBiaoZhiDatasetV2_1
from .GsData import CgTrafficData
from .GsData_xianQuan import CgTrafficDataWithXianQuan
from .GsData_1cls import CgTrafficData1Cls
from .GsData_ForSemi import CgTrafficDataSemi
from .GsRadarData import CgTrafficRadarData
__all__ = [
"COCODataset", "ConcatDataset", "PascalVOCDataset", "TrafficDataset",
"CarWinBiaoZhiDataset", "CarWinBiaoZhiDatasetV2", "CarWinBiaoZhiDatasetV2_1",
"CgTrafficData", "CgTrafficDataWithXianQuan", "CgTrafficDataSemi",
"CgTrafficRadarData", "CgTrafficData1Cls"
]
3. 使用 getattr
动态获取工厂方法
在构建数据集实例时,我们通过 getattr
函数动态获取工厂方法。以下是实现这一过程的核心代码:
python
# build_dataset.py
from . import datasets as D
def build_dataset(dataset_list, transforms, dataset_catalog, is_train=True):
if not isinstance(dataset_list, (list, tuple)):
raise RuntimeError(
"dataset_list 应该是一个字符串列表,得到的是 {}".format(dataset_list)
)
datasets = [] # 初始化数据集列表
for dataset_name in dataset_list:
# 从 dataset_catalog 中获取数据集信息
data = dataset_catalog.get(dataset_name)
# 获取数据集的工厂方法
factory = getattr(D, data["factory"])
# 获取数据集的参数
args = data["args"]
# 设置数据集的变换
args["transforms"] = transforms
# 使用工厂方法创建数据集实例
dataset = factory(**args)
# 将创建的数据集添加到列表中
datasets.append(dataset)
# 如果是测试模式,返回数据集列表
if not is_train:
return datasets
# 如果是训练模式,将所有数据集合并为一个数据集
dataset = datasets[0]
if len(datasets) > 1:
dataset = D.ConcatDataset(datasets)
return [dataset]
4. 数据集目录管理 (DatasetCatalog
)
为了集中管理数据集的路径和相关信息,我们定义了 DatasetCatalog
类。这个类包含了所有数据集的配置信息,并提供了一个静态方法 get
来获取特定数据集的配置信息。
python
# paths_catalog.py
import os
class DatasetCatalog(object):
DATA_DIR = "/home/Public_DataSets"
DATASETS = {
"coco_2017_train": {
"img_dir": "coco/train2017",
"ann_file": "coco/annotations/instances_train2017.json"
},
"voc_2007_train": {
"data_dir": "voc/VOC2007",
"split": "train"
},
# ... 其他数据集配置 ...
}
@staticmethod
def get(name):
if "coco" in name:
data_dir = DatasetCatalog.DATA_DIR
attrs = DatasetCatalog.DATASETS[name]
args = dict(
root=os.path.join(data_dir, attrs["img_dir"]),
ann_file=os.path.join(data_dir, attrs["ann_file"]),
)
return dict(
factory="COCODataset",
args=args,
)
elif "voc" in name:
data_dir = DatasetCatalog.DATA_DIR
attrs = DatasetCatalog.DATASETS[name]
args = dict(
data_dir=os.path.join(data_dir, attrs["data_dir"]),
split=attrs["split"],
)
return dict(
factory="PascalVOCDataset",
args=args,
)
# ... 其他数据集配置 ...
raise RuntimeError("Dataset not available: {}".format(name))
说明
在 get
方法中,我们根据数据集名称动态生成配置字典。例如,对于 COCO 数据集:
python
return dict(
factory="COCODataset",
args=args,
)
factory
:指定数据集类的名称,在后续步骤中用于动态获取工厂方法。args
:包含构建数据集实例所需的参数。
5. COCO 数据集的举例说明
假设我们有一个名为 "coco_2017_train"
的数据集,我们希望使用 DatasetCatalog
和工厂方法来加载这个数据集。以下是具体的步骤:
-
定义数据集配置:
python# paths_catalog.py 中的 DATASETS 字典 DATASETS = { "coco_2017_train": { "img_dir": "coco/train2017", "ann_file": "coco/annotations/instances_train2017.json" }, # ... 其他数据集配置 ... }
-
获取数据集配置:
pythondata = DatasetCatalog.get("coco_2017_train")
-
动态获取工厂方法:
pythonfactory = getattr(D, data["factory"])
-
创建数据集实例:
pythonargs = data["args"] args["transforms"] = some_transform_function # 假设我们有一个变换函数 dataset = factory(**args)
通过这种方式,我们可以动态地加载 COCO 数据集,而无需硬编码每种数据集的构建逻辑。这种设计模式提高了代码的灵活性和可维护性,使得数据集的管理和加载更加方便。
附录: COCODataset
类完整实现
python
# datasets/coco.py
import torch
import torchvision
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask
from maskrcnn_benchmark.structures.keypoint import PersonKeypoints
min_keypoints_per_image = 10
def has_valid_annotation(anno):
if len(anno) == 0:
return False
if all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno):
return False
if "keypoints" not in anno[0]:
return True
if sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) >= min_keypoints_per_image:
return True
return False
class COCODataset(torchvision.datasets.coco.CocoDetection):
def __init__(self, ann_file, root, remove_images_without_annotations, transforms=None):
super(COCODataset, self).__init__(root, ann_file)
self.ids = sorted(self.ids)
if remove_images_without_annotations:
self.ids = [img_id for img_id in self.ids if has_valid_annotation(self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id)))]
self.categories = {cat['id']: cat['name'] for cat in self.coco.cats.values()}
self.json_category_id_to_contiguous_id = {v: i + 1 for i, v in enumerate(self.coco.getCatIds())}
self.contiguous_category_id_to_json_id = {v: k for k, v in self.json_category_id_to_contiguous_id.items()}
self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
self._transforms = transforms
def __getitem__(self, idx):
img, anno = super(COCODataset, self).__getitem__(idx)
anno = [obj for obj in anno if obj["iscrowd"] == 0]
boxes = [obj["bbox"] for obj in anno]
boxes = torch.as_tensor(boxes).reshape(-1, 4)
target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")
classes = torch.tensor([self.json_category_id_to_contiguous_id[obj["category_id"]] for obj in anno])
target.add_field("labels", classes)
if anno and "segmentation" in anno[0]:
masks = SegmentationMask([obj["segmentation"] for obj in anno], img.size, mode='poly')
target.add_field("masks", masks)
if anno and "keypoints" in anno[0]:
keypoints = PersonKeypoints([obj["keypoints"] for obj in anno], img.size)
target.add_field("keypoints", keypoints)
target = target.clip_to_image(remove_empty=True)
if self._transforms is not None:
img, target = self._transforms(img, target)
return img, target, idx
def get_img_info(self, index):
return self.coco.imgs[self.id_to_img_map[index]]
__init__
方法:初始化数据集,加载注释,过滤无效注释,并设置类别和图像映射。__getitem__
方法:获取指定索引的图像和注释,应用可选的变换,并返回图像、目标和索引。