coco.py文件详解

python 复制代码
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
COCO dataset which returns image_id for evaluation.

Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
"""
from pathlib import Path#处理文件和目录路径的模块

import torch
import torch.utils.data#用于创建和操作张量(tensors),以及构建数据加载器用于训练和测试
import torchvision
from pycocotools import mask as coco_mask#处理 COCO 数据集的 Python 工具包
import datasets.transforms as T#自定义的数据预处理和增强操作


class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(self, img_folder, ann_file, transforms, return_masks):
        super(CocoDetection, self).__init__(img_folder, ann_file)#img_folder 是图像文件夹的路径,ann_file 是 COCO 标注文件的路径
        self._transforms = transforms#transforms 是一系列用于数据预处理和增强的转换操作
        self.prepare = ConvertCocoPolysToMask(return_masks)#return_masks 是一个布尔值,表示是否返回目标的遮罩信息

    '''
    __getitem__(self, idx) 方法重写了父类的同名方法。它用于获取指定索引 idx 对应的图像和目标。
    首先,通过调用父类的 __getitem__ 方法,获取原始的图像和目标数据。
    然后,从 self.ids 中获取对应索引的图像 ID,并将图像 ID 和目标数据组织成一个字典 target: {'image_id': image_id, 'annotations': target}。
    接下来,通过调用 self.prepare() 方法对图像和目标数据进行进一步处理,如将多边形转换为遮罩。
    最后,如果存在数据转换操作 self._transforms,则将图像和目标数据传递给它们进行处理。
    返回处理后的图像和目标数据。
    '''
    def __getitem__(self, idx):
        img, target = super(CocoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': target}
        img, target = self.prepare(img, target)
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        return img, target

'''segmentations 是包含多边形分割信息的列表。
height 和 width 分别是目标图像的高度和宽度。
masks 是一个空列表,用于存储遮罩掩膜。
对于每个多边形 polygons,使用 coco_mask.frPyObjects() 函数将其转换为 COCO 格式的 RLE 编码,并且 height 和 width 参数告诉该函数需要生成怎样尺寸的遮罩掩膜。
使用 coco_mask.decode() 函数将 RLE 编码转换为实际的遮罩掩膜。
如果遮罩掩膜的维数小于三,则在最后一维上添加一个新的维度。
将遮罩掩膜转换为 PyTorch 张量类型,并且只保留第二维和第三维的像素信息。在第二维和第三维上,利用 any() 函数对所有像素点进行逻辑或运算,最终将多边形分割信息转换成二值化的遮罩掩膜。
将转换后的遮罩掩膜添加到 masks 列表中。
测试 masks 是否为空列表,如果是,则创建一个全 0 的遮罩掩膜,大小为 (0, height, width)。
最后,通过 torch.stack() 函数将列表中所有遮罩掩膜沿着新的第 0 维进行叠加,得到一个形状为 (N, height, width) 的张量,其中 N 是分割的数量。'''
def convert_coco_poly_to_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks

'''__init__ 方法用于初始化 ConvertCocoPolysToMask 类的实例。它接受一个参数 return_masks,默认为 False。该参数用于指定是否返回遮罩掩膜。

__call__ 方法是类的可调用方法,在实例被调用时会执行。它接受两个参数 image 和 target,分别表示输入的图像和目标。

首先,获取图像的宽度和高度,并将其保存为变量 w 和 h。
获取目标的图像ID,并将其转换为PyTorch张量类型。
从目标中提取注释信息 anno。
过滤掉包含 iscrowd 属性的注释对象或 iscrowd 值为0的注释对象。
提取注释对象的边界框,并将其转换为PyTorch张量类型。然后对边界框进行归一化处理(从绝对坐标转换为相对坐标)。同时,将边界框的坐标限制在图像边界内。
提取注释对象的类别标签,并将其转换为PyTorch张量类型。
如果设置了 return_masks 为 True,则提取注释对象的多边形分割信息,并调用 convert_coco_poly_to_mask 函数将分割信息转换为遮罩掩膜。
检查是否存在关键点信息,并将其提取为PyTorch张量类型。
过滤掉无效的边界框,即宽度和高度小于等于0的边界框。
将过滤后的边界框、类别标签、遮罩掩膜(如果设置了 return_masks)、关键点信息存储到 target 字典中。
提取注释对象的区域面积和 iscrowd 属性,并将其存储到 target 字典中。
存储原始图像的尺寸信息和当前图像的尺寸信息到 target 字典中。
返回图像和 target 字典作为输出。'''
class ConvertCocoPolysToMask(object):
    def __init__(self, return_masks=False):
        self.return_masks = return_masks

    def __call__(self, image, target):
        w, h = image.size

        image_id = target["image_id"]
        image_id = torch.tensor([image_id])

        anno = target["annotations"]

        anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = [obj["category_id"] for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)

        if self.return_masks:
            segmentations = [obj["segmentation"] for obj in anno]
            masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        if self.return_masks:
            masks = masks[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        if self.return_masks:
            target["masks"] = masks
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]

        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])

        return image, target

'''首先创建了一个 normalize 的转换操作,它将图像转换为张量并进行标准化。具体来说,它使用了均值 [0.485, 0.456, 0.406] 和标准差 [0.229, 0.224, 0.225] 进行标准化。
定义了一系列尺度 scales,用于对图像进行随机调整大小。
如果 image_set 的值为 'train',则返回一系列的转换操作:
随机水平翻转图像。
随机选择以下两种转换操作之一:
将图像随机调整为 scales 中的某个尺度,并保证最大边长不超过1333像素。
先随机调整图像的短边长度为 [400, 500, 600] 中的某个值,然后随机裁剪出大小为 [384, 600] 的图像,并将其随机调整为 scales 中的某个尺度,并保证最大边长不超过1333像素。
对图像进行归一化操作。
如果 image_set 的值为 'val',则返回一系列的转换操作:
将图像随机调整为 [800] 中的某个尺度,并保证最大边长不超过1333像素。
对图像进行归一化操作。
如果 image_set 的值不是 'train' 或 'val',则抛出一个异常,表示未知的 image_set 值。'''
def make_coco_transforms(image_set):

    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]

    if image_set == 'train':
        return T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomSelect(
                T.RandomResize(scales, max_size=1333),
                T.Compose([
                    T.RandomResize([400, 500, 600]),
                    T.RandomSizeCrop(384, 600),
                    T.RandomResize(scales, max_size=1333),
                ])
            ),
            normalize,
        ])

    if image_set == 'val':
        return T.Compose([
            T.RandomResize([800], max_size=1333),
            normalize,
        ])

    raise ValueError(f'unknown {image_set}')

'''首先,根据传入的 args.coco_path 构建 COCO 数据集的根路径 root。
然后,对 root 进行存在性检查,如果该路径不存在,则抛出异常。
定义了一个变量 mode,其值为 'instances',表示数据集的模式。
定义了一个字典 PATHS,其中包含了不同 image_set 对应的图像文件夹路径和注释文件路径。
根据传入的 image_set 从 PATHS 字典中获取对应的图像文件夹路径和注释文件路径,并分别赋值给 img_folder 和 ann_file 变量。
调用 CocoDetection 类创建一个 COCO 数据集对象 dataset。CocoDetection 是一个用于处理 COCO 数据集的类,它接收图像文件夹路径、注释文件路径、转换操作和返回掩码选项作为参数。
最后,返回创建的 COCO 数据集对象 dataset。'''
def build(image_set, args):
    root = Path(args.coco_path)
    assert root.exists(), f'provided COCO path {root} does not exist'
    mode = 'instances'
    PATHS = {
        "train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'),
        "val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'),
    }

    img_folder, ann_file = PATHS[image_set]
    dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=args.masks)
    return dataset
相关推荐
Kacey Huang42 分钟前
YOLOv1、YOLOv2、YOLOv3目标检测算法原理与实战第十三天|YOLOv3实战、安装Typora
人工智能·算法·yolo·目标检测·计算机视觉
日日行不惧千万里5 小时前
如何用YOLOv8训练一个识别安全帽的模型?
python·yolo
Coovally AI模型快速验证14 小时前
MMYOLO:打破单一模式限制,多模态目标检测的革命性突破!
人工智能·算法·yolo·目标检测·机器学习·计算机视觉·目标跟踪
红色的山茶花1 天前
YOLOv10-1.1部分代码阅读笔记-predictor.py
笔记·深度学习·yolo
AI街潜水的八角2 天前
工业缺陷检测实战——基于深度学习YOLOv10神经网络PCB缺陷检测系统
pytorch·深度学习·yolo
金色旭光2 天前
目标检测高频评价指标的计算过程
算法·yolo
AI街潜水的八角2 天前
PyTorch框架——基于深度学习YOLOv8神经网络学生课堂行为检测识别系统
pytorch·深度学习·yolo
Hugh&3 天前
(开源)基于Django+Yolov8+Tensorflow的智能鸟类识别平台
python·yolo·django·tensorflow
天天代码码天天3 天前
C# OpenCvSharp 部署读光-票证检测矫正模型(cv_resnet18_card_correction)
人工智能·深度学习·yolo·目标检测·计算机视觉·c#·票证检测矫正
前网易架构师-高司机3 天前
行人识别检测数据集,yolo格式,PASICAL VOC XML,COCO JSON,darknet等格式的标注都支持,准确识别率可达99.5%
xml·yolo·行人检测数据集