利用 Python 的包管理和动态属性获取(`__init__.py` 文件和 `getattr` 函数)特性来实现工厂方法模式

Python 提供了许多灵活的特性,例如包的 __init__.py 文件和 getattr 函数,这些特性可以帮助我们实现工厂方法模式来动态地创建不同类型的数据集实例。

1. 背景介绍

在深度学习项目中,我们通常需要处理多种类型的数据集,例如 COCO、Pascal VOC 和自定义的交通数据集。为了简化和统一数据集的加载过程,我们可以利用 Python 的包管理和动态属性获取特性来实现工厂方法模式。

  • 包的 __init__.py 文件 :通过在包的 __init__.py 文件中导入模块,我们可以在初始化包时自动加载所有必要的类和函数。
  • getattr 函数getattr 函数允许我们动态地获取对象的属性或方法,这对于实现工厂方法模式非常有用,因为我们可以根据配置或输入动态地创建对象,而无需在代码中硬编码每种数据集的构建逻辑。

接下来,我们将通过具体的代码示例来展示如何使用这些特性来实现数据集的动态加载。

2. 模块和类的定义

在我们的项目中,数据集类被定义在 datasets 模块中。我们将定义一个 COCODataset 类,并在 datasets 模块的 __init__.py 文件中导入它。需要注意的是,COCODataset 只是众多数据集类中的一种,其他数据集类如 PascalVOCDatasetTrafficDataset 等也可以通过类似的方式定义和使用。

定义 COCODataset

datasets 模块中创建一个名为 coco.py 的文件,并定义 COCODataset 类。这个类继承自 torchvision.datasets.coco.CocoDetection,并添加了一些自定义逻辑。

python 复制代码
# datasets/coco.py
import torchvision

class COCODataset(torchvision.datasets.coco.CocoDetection):
    def __init__(self, ann_file, root, remove_images_without_annotations, transforms=None):
        super(COCODataset, self).__init__(root, ann_file)
        # 自定义逻辑...
  • __init__ 方法COCODataset 类的构造函数接受 ann_file(注释文件路径)、root(图像根目录)、remove_images_without_annotations(是否移除没有注释的图像)和 transforms(图像变换)四个参数。这些参数与后面 DatasetCatalogget 方法返回的 args 对应。
  • 详细实现见附录
导入 COCODataset

datasets 模块的 __init__.py 文件中导入 COCODataset 类。这样可以确保在使用 datasets 模块时,所有数据集类都已加载。

python 复制代码
# datasets/__init__.py
from .coco import COCODataset
from .voc import PascalVOCDataset
from .concat_dataset import ConcatDataset
from .traffic_dataset import TrafficDataset
from .carWinBiaoZhi_dataset import CarWinBiaoZhiDataset
from .carWinBiaoZhi_dataset_V2 import CarWinBiaoZhiDatasetV2
from .carWinBiaoZhi_dataset_V2_1 import CarWinBiaoZhiDatasetV2_1
from .GsData import CgTrafficData
from .GsData_xianQuan import CgTrafficDataWithXianQuan
from .GsData_1cls import CgTrafficData1Cls
from .GsData_ForSemi import CgTrafficDataSemi
from .GsRadarData import CgTrafficRadarData

__all__ = [
    "COCODataset", "ConcatDataset", "PascalVOCDataset", "TrafficDataset",
    "CarWinBiaoZhiDataset", "CarWinBiaoZhiDatasetV2", "CarWinBiaoZhiDatasetV2_1", 
    "CgTrafficData", "CgTrafficDataWithXianQuan", "CgTrafficDataSemi", 
    "CgTrafficRadarData", "CgTrafficData1Cls"
]

3. 使用 getattr 动态获取工厂方法

在构建数据集实例时,我们通过 getattr 函数动态获取工厂方法。以下是实现这一过程的核心代码:

python 复制代码
# build_dataset.py
from . import datasets as D

def build_dataset(dataset_list, transforms, dataset_catalog, is_train=True):
    if not isinstance(dataset_list, (list, tuple)):
        raise RuntimeError(
            "dataset_list 应该是一个字符串列表,得到的是 {}".format(dataset_list)
        )
    
    datasets = []  # 初始化数据集列表
    
    for dataset_name in dataset_list:
        # 从 dataset_catalog 中获取数据集信息
        data = dataset_catalog.get(dataset_name)
        
        # 获取数据集的工厂方法
        factory = getattr(D, data["factory"])
        
        # 获取数据集的参数
        args = data["args"]
        
        # 设置数据集的变换
        args["transforms"] = transforms
        
        # 使用工厂方法创建数据集实例
        dataset = factory(**args)
        
        # 将创建的数据集添加到列表中
        datasets.append(dataset)
    
    # 如果是测试模式,返回数据集列表
    if not is_train:
        return datasets
    
    # 如果是训练模式,将所有数据集合并为一个数据集
    dataset = datasets[0]
    if len(datasets) > 1:
        dataset = D.ConcatDataset(datasets)
    
    return [dataset]

4. 数据集目录管理 (DatasetCatalog)

为了集中管理数据集的路径和相关信息,我们定义了 DatasetCatalog 类。这个类包含了所有数据集的配置信息,并提供了一个静态方法 get 来获取特定数据集的配置信息。

python 复制代码
# paths_catalog.py
import os

class DatasetCatalog(object):
    DATA_DIR = "/home/Public_DataSets"
    DATASETS = {
        "coco_2017_train": {
            "img_dir": "coco/train2017",
            "ann_file": "coco/annotations/instances_train2017.json"
        },
        "voc_2007_train": {
            "data_dir": "voc/VOC2007",
            "split": "train"
        },
        # ... 其他数据集配置 ...
    }

    @staticmethod
    def get(name):
        if "coco" in name:
            data_dir = DatasetCatalog.DATA_DIR
            attrs = DatasetCatalog.DATASETS[name]
            args = dict(
                root=os.path.join(data_dir, attrs["img_dir"]),
                ann_file=os.path.join(data_dir, attrs["ann_file"]),
            )
            return dict(
                factory="COCODataset",
                args=args,
            )
        elif "voc" in name:
            data_dir = DatasetCatalog.DATA_DIR
            attrs = DatasetCatalog.DATASETS[name]
            args = dict(
                data_dir=os.path.join(data_dir, attrs["data_dir"]),
                split=attrs["split"],
            )
            return dict(
                factory="PascalVOCDataset",
                args=args,
            )
        # ... 其他数据集配置 ...
        raise RuntimeError("Dataset not available: {}".format(name))
说明

get 方法中,我们根据数据集名称动态生成配置字典。例如,对于 COCO 数据集:

python 复制代码
return dict(
    factory="COCODataset",
    args=args,
)
  • factory:指定数据集类的名称,在后续步骤中用于动态获取工厂方法。
  • args:包含构建数据集实例所需的参数。

5. COCO 数据集的举例说明

假设我们有一个名为 "coco_2017_train" 的数据集,我们希望使用 DatasetCatalog 和工厂方法来加载这个数据集。以下是具体的步骤:

  1. 定义数据集配置

    python 复制代码
    # paths_catalog.py 中的 DATASETS 字典
    DATASETS = {
        "coco_2017_train": {
            "img_dir": "coco/train2017",
            "ann_file": "coco/annotations/instances_train2017.json"
        },
        # ... 其他数据集配置 ...
    }
  2. 获取数据集配置

    python 复制代码
    data = DatasetCatalog.get("coco_2017_train")
  3. 动态获取工厂方法

    python 复制代码
    factory = getattr(D, data["factory"])
  4. 创建数据集实例

    python 复制代码
    args = data["args"]
    args["transforms"] = some_transform_function  # 假设我们有一个变换函数
    dataset = factory(**args)

通过这种方式,我们可以动态地加载 COCO 数据集,而无需硬编码每种数据集的构建逻辑。这种设计模式提高了代码的灵活性和可维护性,使得数据集的管理和加载更加方便。

附录: COCODataset 类完整实现
python 复制代码
# datasets/coco.py
import torch
import torchvision
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask
from maskrcnn_benchmark.structures.keypoint import PersonKeypoints

min_keypoints_per_image = 10

def has_valid_annotation(anno):
    if len(anno) == 0:
        return False
    if all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno):
        return False
    if "keypoints" not in anno[0]:
        return True
    if sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) >= min_keypoints_per_image:
        return True
    return False

class COCODataset(torchvision.datasets.coco.CocoDetection):
    def __init__(self, ann_file, root, remove_images_without_annotations, transforms=None):
        super(COCODataset, self).__init__(root, ann_file)
        self.ids = sorted(self.ids)
        if remove_images_without_annotations:
            self.ids = [img_id for img_id in self.ids if has_valid_annotation(self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id)))]
        self.categories = {cat['id']: cat['name'] for cat in self.coco.cats.values()}
        self.json_category_id_to_contiguous_id = {v: i + 1 for i, v in enumerate(self.coco.getCatIds())}
        self.contiguous_category_id_to_json_id = {v: k for k, v in self.json_category_id_to_contiguous_id.items()}
        self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
        self._transforms = transforms

    def __getitem__(self, idx):
        img, anno = super(COCODataset, self).__getitem__(idx)
        anno = [obj for obj in anno if obj["iscrowd"] == 0]
        boxes = [obj["bbox"] for obj in anno]
        boxes = torch.as_tensor(boxes).reshape(-1, 4)
        target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")
        classes = torch.tensor([self.json_category_id_to_contiguous_id[obj["category_id"]] for obj in anno])
        target.add_field("labels", classes)
        if anno and "segmentation" in anno[0]:
            masks = SegmentationMask([obj["segmentation"] for obj in anno], img.size, mode='poly')
            target.add_field("masks", masks)
        if anno and "keypoints" in anno[0]:
            keypoints = PersonKeypoints([obj["keypoints"] for obj in anno], img.size)
            target.add_field("keypoints", keypoints)
        target = target.clip_to_image(remove_empty=True)
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        return img, target, idx

    def get_img_info(self, index):
        return self.coco.imgs[self.id_to_img_map[index]]
  • __init__ 方法:初始化数据集,加载注释,过滤无效注释,并设置类别和图像映射。
  • __getitem__ 方法:获取指定索引的图像和注释,应用可选的变换,并返回图像、目标和索引。
相关推荐
2401_8574396925 分钟前
SSM 架构下 Vue 电脑测评系统:为电脑性能评估赋能
开发语言·php
SoraLuna1 小时前
「Mac畅玩鸿蒙与硬件47」UI互动应用篇24 - 虚拟音乐控制台
开发语言·macos·ui·华为·harmonyos
xlsw_1 小时前
java全栈day20--Web后端实战(Mybatis基础2)
java·开发语言·mybatis
梧桐树04292 小时前
python常用内建模块:collections
python
Dream_Snowar2 小时前
速通Python 第三节
开发语言·python
高山我梦口香糖3 小时前
[react]searchParams转普通对象
开发语言·前端·javascript
信号处理学渣3 小时前
matlab画图,选择性显示legend标签
开发语言·matlab
红龙创客3 小时前
某狐畅游24校招-C++开发岗笔试(单选题)
开发语言·c++
蓝天星空3 小时前
Python调用open ai接口
人工智能·python
jasmine s3 小时前
Pandas
开发语言·python