【代码整理】基于COCO格式的pytorch Dataset类实现

import模块

python 复制代码
import numpy as np
import torch
from functools import partial
from PIL import Image
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import random
import albumentations as A
from pycocotools.coco import COCO
import os
import cv2
import matplotlib.pyplot as plt

基于albumentations库自定义数据预处理/数据增强

python 复制代码
class Transform():
    '''数据预处理/数据增强(基于albumentations库)
    '''
    def __init__(self, imgSize):
        maxSize = max(imgSize[0], imgSize[1])
        # 训练时增强
        self.trainTF = A.Compose([
                A.BBoxSafeRandomCrop(p=0.5),
                # 最长边限制为imgSize
                A.LongestMaxSize(max_size=maxSize),
                A.HorizontalFlip(p=0.5),
                # 参数:随机色调、饱和度、值变化
                A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, always_apply=False, p=0.5),
                # 随机明亮对比度
                A.RandomBrightnessContrast(p=0.2),   
                # 高斯噪声
                A.GaussNoise(var_limit=(0.05, 0.09), p=0.4),     
                A.OneOf([
                    # 使用随机大小的内核将运动模糊应用于输入图像
                    A.MotionBlur(p=0.2),   
                    # 中值滤波
                    A.MedianBlur(blur_limit=3, p=0.1),    
                    # 使用随机大小的内核模糊输入图像
                    A.Blur(blur_limit=3, p=0.1),  
                ], p=0.2),
                # 较短的边做padding
                A.PadIfNeeded(imgSize[0], imgSize[1], border_mode=cv2.BORDER_CONSTANT, value=[0,0,0]),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ],
            bbox_params=A.BboxParams(format='coco', min_area=0, min_visibility=0.1, label_fields=['category_ids']),
            )
        # 验证时增强
        self.validTF = A.Compose([
                # 最长边限制为imgSize
                A.LongestMaxSize(max_size=maxSize),
                # 较短的边做padding
                A.PadIfNeeded(imgSize[0], imgSize[1], border_mode=0, mask_value=[0,0,0]),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ],
            bbox_params=A.BboxParams(format='coco', min_area=0, min_visibility=0.1, label_fields=['category_ids']),
            )

自定义数据集读取类COCODataset实现

python 复制代码
class COCODataset(Dataset):

    def __init__(self, annPath, imgDir, inputShape=[800, 600], trainMode=True):
        '''__init__() 为默认构造函数,传入数据集类别(训练或测试),以及数据集路径

        Args:
            :param annPath:     COCO annotation 文件路径
            :param imgDir:      图像的根目录
            :param inputShape: 网络要求输入的图像尺寸
            :param trainMode:   训练集/测试集

        Returns:
            FRCNNDataset
        '''      
        self.mode = trainMode
        self.tf = Transform(imgSize=inputShape)
        self.imgDir = imgDir
        self.annPath = annPath
        self.DataNums = len(os.listdir(imgDir))
        # 为实例注释初始化COCO的API
        self.coco=COCO(annPath)
        # 获取数据集中所有图像对应的imgId
        self.imgIds = list(self.coco.imgs.keys())

    def __len__(self):
        '''重载data.Dataset父类方法, 返回数据集大小
        '''
        return len(self.imgIds)

    def __getitem__(self, index):
        '''重载data.Dataset父类方法, 获取数据集中数据内容
           这里通过pycocotools来读取图像和标签
        '''   
        # 通过imgId获取图像信息imgInfo: 例:{'id': 12465, 'license': 1, 'height': 375, 'width': 500, 'file_name': '2011_003115.jpg'}
        imgId = self.imgIds[index]
        imgInfo = self.coco.loadImgs(imgId)[0]
        # 载入图像 (通过imgInfo获取图像名,得到图像路径)               
        image = Image.open(os.path.join(self.imgDir, imgInfo['file_name']))
        image = np.array(image.convert('RGB'))
        # 得到图像里包含的BBox的所有id
        imgAnnIds = self.coco.getAnnIds(imgIds=imgId)   
        # 通过BBox的id找到对应的BBox信息
        anns = self.coco.loadAnns(imgAnnIds) 
        # 获取BBox的坐标和类别
        labels, boxes = [], []
        for ann in anns:
            labelName = ann['category_id']
            labels.append(labelName)
            boxes.append(ann['bbox'])
        labels = np.array(labels)
        boxes = np.array(boxes)
        
        # 训练/验证时的数据增强各不相同
        if(self.mode):
            # albumentation的图像维度得是[W,H,C]
            transformed = self.tf.trainTF(image=image, bboxes=boxes, category_ids=labels)
        else:
            transformed = self.tf.validTF(image=image, bboxes=boxes, category_ids=labels)
        # 这里的box是coco格式(xywh)
        image, box, label = transformed['image'], transformed['bboxes'], transformed['category_ids']
        return image.transpose(2,0,1), np.array(box), np.array(label)

其他

python 复制代码
# DataLoader中collate_fn参数使用
# 由于检测数据集每张图像上的目标数量不一
# 因此需要自定义的如何组织一个batch里输出的内容
def frcnn_dataset_collate(batch):
    images = []
    bboxes = []
    labels = []
    for img, box, label in batch:
        images.append(img)
        bboxes.append(box)
        labels.append(label)
    images = torch.from_numpy(np.array(images))
    return images, bboxes, labels



# 设置Dataloader的种子
# DataLoader中worker_init_fn参数使
# 为每个 worker 设置了一个基于初始种子和 worker ID 的独特的随机种子, 这样每个 worker 将产生不同的随机数序列,从而有助于数据加载过程的随机性和多样性
def worker_init_fn(worker_id, seed):
    worker_seed = worker_id + seed
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    torch.manual_seed(worker_seed)


# 固定全局随机数种子
def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

batch数据集可视化

python 复制代码
def visBatch(dataLoader:DataLoader):
    '''可视化训练集一个batch
    Args:
        dataLoader: torch的data.DataLoader
    Retuens:
        None     
    '''
    catName = {1:'person', 2:'bicycle', 3:'car', 4:'motorcycle', 5:'airplane', 6:'bus',
               7:'train', 8:'truck', 9:'boat', 10:'traffic light', 11:'fire hydrant',
               13:'stop sign', 14:'parking meter', 15:'bench', 16:'bird', 17:'cat', 18:'dog',
               19:'horse', 20:'sheep', 21:'cow', 22:'elephant', 23:'bear', 24:'zebra', 25:'giraffe',
               27:'backpack', 28:'umbrella', 31:'handbag', 32:'tie', 33:'suitcase', 34:'frisbee',
               35:'skis', 36:'snowboard', 37:'sports ball', 38:'kite', 39:'baseball bat',
               40:'baseball glove', 41:'skateboard', 42:'surfboard', 43:'tennis racket',
               44:'bottle', 46:'wine glass', 47:'cup', 48:'fork', 49:'knife', 50:'spoon', 51:'bowl',
               52:'banana', 53:'apple', 54:'sandwich', 55:'orange', 56:'broccoli', 57:'carrot',
               58:'hot dog', 59:'pizza', 60:'donut', 61:'cake', 62:'chair', 63:'couch',
               64:'potted plant', 65:'bed', 67:'dining table', 70:'toilet', 72:'tv', 73:'laptop',
               74:'mouse', 75:'remote', 76:'keyboard', 77:'cell phone', 78:'microwave',
               79:'oven', 80:'toaster', 81:'sink', 82:'refrigerator', 84:'book', 85:'clock',
               86:'vase', 87:'scissors', 88:'teddy bear', 89:'hair drier', 90:'toothbrush'}
    
    for step, batch in enumerate(dataLoader):
        images, boxes, labels = batch[0], batch[1], batch[2]
        # 只可视化一个batch的图像:
        if step > 0: break
        # 图像均值
        mean = np.array([0.485, 0.456, 0.406]) 
        # 标准差
        std = np.array([[0.229, 0.224, 0.225]]) 
        plt.figure(figsize = (8,8))
        for idx, imgBoxLabel in enumerate(zip(images, boxes, labels)):
            img, box, label = imgBoxLabel
            ax = plt.subplot(4,4,idx+1)
            img = img.numpy().transpose((1,2,0))
            # 由于在数据预处理时我们对数据进行了标准归一化,可视化的时候需要将其还原
            img = img * std + mean
            for instBox, instLabel in zip(box, label):
                x, y, w, h = round(instBox[0]),round(instBox[1]), round(instBox[2]), round(instBox[3])
                # 显示框
                ax.add_patch(plt.Rectangle((x, y), w, h, color='blue', fill=False, linewidth=2))
                # 显示类别
                ax.text(x, y, catName[instLabel], bbox={'facecolor':'white', 'alpha':0.5})
            plt.imshow(img)
            # 在图像上方展示对应的标签
            # 取消坐标轴
            plt.axis("off")
             # 微调行间距
            plt.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95, wspace=0.05, hspace=0.05)
        plt.show()

example

python 复制代码
# for test only:
if __name__ == "__main__":
    # 固定随机种子
    seed = 23
    seed_everything(seed)
    # BatcchSize
    BS = 16
    # 图像尺寸
    imgSize = [800, 800]

    trainAnnPath = "E:/datasets/Universal/COCO2017/annotations/instances_train2017.json"
    testAnnPath = "E:/datasets/Universal/COCO2017/annotations/instances_val2017.json"
    imgDir =  "E:/datasets/Universal/COCO2017/train2017"
    # 自定义数据集读取类
    trainDataset = COCODataset(trainAnnPath, imgDir, imgSize, trainMode=True)
    trainDataLoader = DataLoader(trainDataset, shuffle=True, batch_size = BS, num_workers=2, pin_memory=True,
                                    collate_fn=frcnn_dataset_collate, worker_init_fn=partial(worker_init_fn, seed=seed))
    # validDataset = COCODataset(testAnnPath, imgDir, imgSize, trainMode=False)
    # validDataLoader = DataLoader(validDataset, shuffle=True, batch_size = BS, num_workers = 1, pin_memory=True, 
                                  # collate_fn=frcnn_dataset_collate, worker_init_fn=partial(worker_init_fn, seed=seed))



    print(f'训练集大小 : {trainDataset.__len__()}')
    visBatch(trainDataLoader)
    for step, batch in enumerate(trainDataLoader):
        images, boxes, labels = batch[0], batch[1], batch[2]
        # torch.Size([bs, 3, 800, 800])
        print(f'images.shape : {images.shape}')   
        # 列表形式,因为每个框里的实例数量不一,所以每个列表里的box数量不一
        print(f'len(boxes) : {len(boxes)}')     
        # 列表形式,因为每个框里的实例数量不一,所以每个列表里的label数量不一  
        print(f'len(labels) : {len(labels)}')     
        break

输出

bash 复制代码
images.shape : torch.Size([16, 3, 800, 800])
len(boxes) : 16
len(labels) : 16
相关推荐
菜就多练_08281 分钟前
《深度学习》OpenCV 摄像头OCR 过程及案例解析
人工智能·深度学习·opencv·ocr
达柳斯·绍达华·宁16 分钟前
CNN中的平移不变性和平移等变性
人工智能·神经网络·cnn
弱冠少年37 分钟前
websockets库使用(基于Python)
开发语言·python·numpy
没有余地 EliasJie1 小时前
Windows Ubuntu下搭建深度学习Pytorch训练框架与转换环境TensorRT
pytorch·windows·深度学习·ubuntu·pycharm·conda·tensorflow
技术无疆1 小时前
【Python】Streamlit:为数据科学与机器学习打造的简易应用框架
开发语言·人工智能·python·深度学习·神经网络·机器学习·数据挖掘
xuehaishijue1 小时前
红外画面空中目标检测系统源码分享
人工智能·目标检测·计算机视觉
羊小猪~~1 小时前
机器学习/数据分析--用通俗语言讲解时间序列自回归(AR)模型,并用其预测天气,拟合度98%+
人工智能·python·机器学习·数据挖掘·数据分析·回归·时序数据库
浊酒南街1 小时前
吴恩达深度学习笔记:卷积神经网络(Foundations of Convolutional Neural Networks)2.7-2.8
人工智能·深度学习·神经网络
qq_273900231 小时前
解析TMalign文本文件中的转换矩阵
python·生物信息学