【代码整理】基于COCO格式的pytorch Dataset类实现

import模块

python 复制代码
import numpy as np
import torch
from functools import partial
from PIL import Image
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import random
import albumentations as A
from pycocotools.coco import COCO
import os
import cv2
import matplotlib.pyplot as plt

基于albumentations库自定义数据预处理/数据增强

python 复制代码
class Transform():
    '''数据预处理/数据增强(基于albumentations库)
    '''
    def __init__(self, imgSize):
        maxSize = max(imgSize[0], imgSize[1])
        # 训练时增强
        self.trainTF = A.Compose([
                A.BBoxSafeRandomCrop(p=0.5),
                # 最长边限制为imgSize
                A.LongestMaxSize(max_size=maxSize),
                A.HorizontalFlip(p=0.5),
                # 参数:随机色调、饱和度、值变化
                A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, always_apply=False, p=0.5),
                # 随机明亮对比度
                A.RandomBrightnessContrast(p=0.2),   
                # 高斯噪声
                A.GaussNoise(var_limit=(0.05, 0.09), p=0.4),     
                A.OneOf([
                    # 使用随机大小的内核将运动模糊应用于输入图像
                    A.MotionBlur(p=0.2),   
                    # 中值滤波
                    A.MedianBlur(blur_limit=3, p=0.1),    
                    # 使用随机大小的内核模糊输入图像
                    A.Blur(blur_limit=3, p=0.1),  
                ], p=0.2),
                # 较短的边做padding
                A.PadIfNeeded(imgSize[0], imgSize[1], border_mode=cv2.BORDER_CONSTANT, value=[0,0,0]),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ],
            bbox_params=A.BboxParams(format='coco', min_area=0, min_visibility=0.1, label_fields=['category_ids']),
            )
        # 验证时增强
        self.validTF = A.Compose([
                # 最长边限制为imgSize
                A.LongestMaxSize(max_size=maxSize),
                # 较短的边做padding
                A.PadIfNeeded(imgSize[0], imgSize[1], border_mode=0, mask_value=[0,0,0]),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ],
            bbox_params=A.BboxParams(format='coco', min_area=0, min_visibility=0.1, label_fields=['category_ids']),
            )

自定义数据集读取类COCODataset实现

python 复制代码
class COCODataset(Dataset):

    def __init__(self, annPath, imgDir, inputShape=[800, 600], trainMode=True):
        '''__init__() 为默认构造函数,传入数据集类别(训练或测试),以及数据集路径

        Args:
            :param annPath:     COCO annotation 文件路径
            :param imgDir:      图像的根目录
            :param inputShape: 网络要求输入的图像尺寸
            :param trainMode:   训练集/测试集

        Returns:
            FRCNNDataset
        '''      
        self.mode = trainMode
        self.tf = Transform(imgSize=inputShape)
        self.imgDir = imgDir
        self.annPath = annPath
        self.DataNums = len(os.listdir(imgDir))
        # 为实例注释初始化COCO的API
        self.coco=COCO(annPath)
        # 获取数据集中所有图像对应的imgId
        self.imgIds = list(self.coco.imgs.keys())

    def __len__(self):
        '''重载data.Dataset父类方法, 返回数据集大小
        '''
        return len(self.imgIds)

    def __getitem__(self, index):
        '''重载data.Dataset父类方法, 获取数据集中数据内容
           这里通过pycocotools来读取图像和标签
        '''   
        # 通过imgId获取图像信息imgInfo: 例:{'id': 12465, 'license': 1, 'height': 375, 'width': 500, 'file_name': '2011_003115.jpg'}
        imgId = self.imgIds[index]
        imgInfo = self.coco.loadImgs(imgId)[0]
        # 载入图像 (通过imgInfo获取图像名,得到图像路径)               
        image = Image.open(os.path.join(self.imgDir, imgInfo['file_name']))
        image = np.array(image.convert('RGB'))
        # 得到图像里包含的BBox的所有id
        imgAnnIds = self.coco.getAnnIds(imgIds=imgId)   
        # 通过BBox的id找到对应的BBox信息
        anns = self.coco.loadAnns(imgAnnIds) 
        # 获取BBox的坐标和类别
        labels, boxes = [], []
        for ann in anns:
            labelName = ann['category_id']
            labels.append(labelName)
            boxes.append(ann['bbox'])
        labels = np.array(labels)
        boxes = np.array(boxes)
        
        # 训练/验证时的数据增强各不相同
        if(self.mode):
            # albumentation的图像维度得是[W,H,C]
            transformed = self.tf.trainTF(image=image, bboxes=boxes, category_ids=labels)
        else:
            transformed = self.tf.validTF(image=image, bboxes=boxes, category_ids=labels)
        # 这里的box是coco格式(xywh)
        image, box, label = transformed['image'], transformed['bboxes'], transformed['category_ids']
        return image.transpose(2,0,1), np.array(box), np.array(label)

其他

python 复制代码
# DataLoader中collate_fn参数使用
# 由于检测数据集每张图像上的目标数量不一
# 因此需要自定义的如何组织一个batch里输出的内容
def frcnn_dataset_collate(batch):
    images = []
    bboxes = []
    labels = []
    for img, box, label in batch:
        images.append(img)
        bboxes.append(box)
        labels.append(label)
    images = torch.from_numpy(np.array(images))
    return images, bboxes, labels



# 设置Dataloader的种子
# DataLoader中worker_init_fn参数使
# 为每个 worker 设置了一个基于初始种子和 worker ID 的独特的随机种子, 这样每个 worker 将产生不同的随机数序列,从而有助于数据加载过程的随机性和多样性
def worker_init_fn(worker_id, seed):
    worker_seed = worker_id + seed
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    torch.manual_seed(worker_seed)


# 固定全局随机数种子
def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

batch数据集可视化

python 复制代码
def visBatch(dataLoader:DataLoader):
    '''可视化训练集一个batch
    Args:
        dataLoader: torch的data.DataLoader
    Retuens:
        None     
    '''
    catName = {1:'person', 2:'bicycle', 3:'car', 4:'motorcycle', 5:'airplane', 6:'bus',
               7:'train', 8:'truck', 9:'boat', 10:'traffic light', 11:'fire hydrant',
               13:'stop sign', 14:'parking meter', 15:'bench', 16:'bird', 17:'cat', 18:'dog',
               19:'horse', 20:'sheep', 21:'cow', 22:'elephant', 23:'bear', 24:'zebra', 25:'giraffe',
               27:'backpack', 28:'umbrella', 31:'handbag', 32:'tie', 33:'suitcase', 34:'frisbee',
               35:'skis', 36:'snowboard', 37:'sports ball', 38:'kite', 39:'baseball bat',
               40:'baseball glove', 41:'skateboard', 42:'surfboard', 43:'tennis racket',
               44:'bottle', 46:'wine glass', 47:'cup', 48:'fork', 49:'knife', 50:'spoon', 51:'bowl',
               52:'banana', 53:'apple', 54:'sandwich', 55:'orange', 56:'broccoli', 57:'carrot',
               58:'hot dog', 59:'pizza', 60:'donut', 61:'cake', 62:'chair', 63:'couch',
               64:'potted plant', 65:'bed', 67:'dining table', 70:'toilet', 72:'tv', 73:'laptop',
               74:'mouse', 75:'remote', 76:'keyboard', 77:'cell phone', 78:'microwave',
               79:'oven', 80:'toaster', 81:'sink', 82:'refrigerator', 84:'book', 85:'clock',
               86:'vase', 87:'scissors', 88:'teddy bear', 89:'hair drier', 90:'toothbrush'}
    
    for step, batch in enumerate(dataLoader):
        images, boxes, labels = batch[0], batch[1], batch[2]
        # 只可视化一个batch的图像:
        if step > 0: break
        # 图像均值
        mean = np.array([0.485, 0.456, 0.406]) 
        # 标准差
        std = np.array([[0.229, 0.224, 0.225]]) 
        plt.figure(figsize = (8,8))
        for idx, imgBoxLabel in enumerate(zip(images, boxes, labels)):
            img, box, label = imgBoxLabel
            ax = plt.subplot(4,4,idx+1)
            img = img.numpy().transpose((1,2,0))
            # 由于在数据预处理时我们对数据进行了标准归一化,可视化的时候需要将其还原
            img = img * std + mean
            for instBox, instLabel in zip(box, label):
                x, y, w, h = round(instBox[0]),round(instBox[1]), round(instBox[2]), round(instBox[3])
                # 显示框
                ax.add_patch(plt.Rectangle((x, y), w, h, color='blue', fill=False, linewidth=2))
                # 显示类别
                ax.text(x, y, catName[instLabel], bbox={'facecolor':'white', 'alpha':0.5})
            plt.imshow(img)
            # 在图像上方展示对应的标签
            # 取消坐标轴
            plt.axis("off")
             # 微调行间距
            plt.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95, wspace=0.05, hspace=0.05)
        plt.show()

example

python 复制代码
# for test only:
if __name__ == "__main__":
    # 固定随机种子
    seed = 23
    seed_everything(seed)
    # BatcchSize
    BS = 16
    # 图像尺寸
    imgSize = [800, 800]

    trainAnnPath = "E:/datasets/Universal/COCO2017/annotations/instances_train2017.json"
    testAnnPath = "E:/datasets/Universal/COCO2017/annotations/instances_val2017.json"
    imgDir =  "E:/datasets/Universal/COCO2017/train2017"
    # 自定义数据集读取类
    trainDataset = COCODataset(trainAnnPath, imgDir, imgSize, trainMode=True)
    trainDataLoader = DataLoader(trainDataset, shuffle=True, batch_size = BS, num_workers=2, pin_memory=True,
                                    collate_fn=frcnn_dataset_collate, worker_init_fn=partial(worker_init_fn, seed=seed))
    # validDataset = COCODataset(testAnnPath, imgDir, imgSize, trainMode=False)
    # validDataLoader = DataLoader(validDataset, shuffle=True, batch_size = BS, num_workers = 1, pin_memory=True, 
                                  # collate_fn=frcnn_dataset_collate, worker_init_fn=partial(worker_init_fn, seed=seed))



    print(f'训练集大小 : {trainDataset.__len__()}')
    visBatch(trainDataLoader)
    for step, batch in enumerate(trainDataLoader):
        images, boxes, labels = batch[0], batch[1], batch[2]
        # torch.Size([bs, 3, 800, 800])
        print(f'images.shape : {images.shape}')   
        # 列表形式,因为每个框里的实例数量不一,所以每个列表里的box数量不一
        print(f'len(boxes) : {len(boxes)}')     
        # 列表形式,因为每个框里的实例数量不一,所以每个列表里的label数量不一  
        print(f'len(labels) : {len(labels)}')     
        break

输出

bash 复制代码
images.shape : torch.Size([16, 3, 800, 800])
len(boxes) : 16
len(labels) : 16
相关推荐
cdut_suye2 分钟前
Linux工具使用指南:从apt管理、gcc编译到makefile构建与gdb调试
java·linux·运维·服务器·c++·人工智能·python
开发者每周简报22 分钟前
微软的AI转型故事
人工智能·microsoft
dundunmm25 分钟前
机器学习之scikit-learn(简称 sklearn)
python·算法·机器学习·scikit-learn·sklearn·分类算法
古希腊掌管学习的神25 分钟前
[机器学习]sklearn入门指南(1)
人工智能·python·算法·机器学习·sklearn
一道微光39 分钟前
Mac的M2芯片运行lightgbm报错,其他python包可用,x86_x64架构运行
开发语言·python·macos
普密斯科技1 小时前
手机外观边框缺陷视觉检测智慧方案
人工智能·计算机视觉·智能手机·自动化·视觉检测·集成测试
四口鲸鱼爱吃盐1 小时前
Pytorch | 利用AI-FGTM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python
lishanlu1361 小时前
Pytorch分布式训练
人工智能·ddp·pytorch并行训练
是娜个二叉树!1 小时前
图像处理基础 | 格式转换.rgb转.jpg 灰度图 python
开发语言·python
互联网杂货铺1 小时前
Postman接口测试:全局变量/接口关联/加密/解密
自动化测试·软件测试·python·测试工具·职场和发展·测试用例·postman