Lnton羚通关于【PyTorch】教程:torchvision 目标检测微调

torchvision 目标检测微调

本教程将使用Penn-Fudan Database for Pedestrian Detection and Segmentation 微调 预训练的Mask R-CNN 模型。 它包含 170 张图片,345 个行人实例。

定义数据集

用于训练目标检测、实例分割和人物关键点检测的参考脚本允许轻松支持添加新的自定义数据集。数据集应继承自标准的 torch.utils.data.dataset 类,并实现 lengetitem

getitem 需要返回:

image: PIL 图像 (H, W)

target: 字典数据,需要包含字段

boxes (FloatTensor[N, 4]): N 个 Bounding box 的位置坐标 [x0, y0, x1, y1], 0~W, 0~H

labels (Int64Tensor[N]): 每个 Bounding box 的类别标签,0 代表背景类。

image_id (Int64Tensor[1]): 图像的标签 id,在数据集中是唯一的。

area (Tensor[N]): Bounding box 的面积,在 COCO 度量里使用,可以分别对不同大小的目标进行度量。

iscrowd (UInt8Tensor[N]): 如果 iscrowd=True 在评估时忽略。

(optionally) masks (UInt8Tensor[N, H, W]): 可选的 分割掩码

(optionally) keypoints (FloatTensor[N, K, 3]): 对于 N 个目标来说,包含 K 个关键点 [x, y, visibility], visibility=0 表示关键点不可见。

如果模型可以返回上述方法,可以在训练、评估都能使用,可以用 pycocotools 里的脚本进行评估。

pip install pycocotools 安装工具。

关于 labels 有个说明,模型默认 0 为背景。如果数据集没有背景类别,不需要在标签里添加 0 。 例如,假设有 cat 和 dog 两类,定义了 1 表示 cat , 2 表示 dog , 如果一个图像有两个类别,类别的 tensor 为 [1, 2] 。

此外,如果希望在训练时使用纵横比分组,那么建议实现 get_height_and_width 方法,该方法将返回图像的高度和宽度,如果未提供此方法,我们将通过 getitem 查询数据集的所有元素,这会将图像加载到内存中,并且比提供自定义方法的速度慢。

为 PennFudan 写自定义数据集

文件夹结构如下:

PennFudanPed/
  PedMasks/
    FudanPed00001_mask.png
    FudanPed00002_mask.png
    FudanPed00003_mask.png
    FudanPed00004_mask.png
    ...
  PNGImages/
    FudanPed00001.png
    FudanPed00002.png
    FudanPed00003.png
    FudanPed00004.png

这是图像的标注信息,包含了 mask 以及 bounding box 。每个图像都有对应的分割掩码,每个颜色代表不同的实例。

import os 
import numpy as np 
import torch 
from PIL import Image

class PennFudanDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
    
        ## 加载所有图像,sort 保证他们能够对应起来
        self.images = list(sorted(os.listdir(os.path.join(self.root, 'PNGImages'))))
        self.masks = list(sorted(os.listdir(os.path.join(self.root, 'PedMasks'))))
    
    def __getitem__(self, idx):
        img_path = os.path.join(self.root, 'PNGImages', self.images[idx])
        mask_path = os.path.join(self.root, 'PedMasks', self.masks[idx])
        image = Image.open(img_path).convert('RGB')
    
        ## mask 图像并没有转换为 RGB,里面存储的是标签,0表示的是背景
        mask = Image.open(mask_path)
    
        # 转换为 numpy
        mask = np.array(mask) 
    
        # 实例解码成不同的颜色
        obj_ids = np.unique(mask)
    
        # 移除背景
        obj_ids = obj_ids[1:]
    
        masks = mask == obj_ids[:, None, None]
    
        # get bounding box coordinates for each mask
        num_objs = len(obj_ids)
        boxes = []
        for i in range(num_objs):
            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            boxes.append([xmin, ymin, xmax, ymax])
        
        # 转换为 tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.ones((num_objs,), dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
    
        target = {}
    
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd
    
        if self.transforms is not None:
            image, target = self.transforms(image, target)
    
        return image, target
  
    def __len__(self):
        return len(self.images)

Lnton羚通专注于音视频算法、算力、云平台的高科技人工智能企业。 公司基于视频分析技术、视频智能传输技术、远程监测技术以及智能语音融合技术等, 拥有多款可支持ONVIF、RTSP、GB/T28181等多协议、多路数的音视频智能分析服务器/云平台。

相关推荐
管二狗赶快去工作!4 分钟前
体系结构论文(五十四):Reliability-Aware Runahead 【22‘ HPCA】
人工智能·神经网络·dnn·体系结构·实时系统
AI绘画君13 分钟前
Stable Diffusion绘画 | AI 图片智能扩充,超越PS扩图的AI扩图功能(附安装包)
人工智能·ai作画·stable diffusion·aigc·ai绘画·ai扩图
AAI机器之心15 分钟前
LLM大模型:开源RAG框架汇总
人工智能·chatgpt·开源·大模型·llm·大语言模型·rag
Evand J36 分钟前
物联网智能设备:未来生活的变革者
人工智能·物联网·智能手机·智能家居·智能手表
HyperAI超神经1 小时前
Meta 首个多模态大模型一键启动!首个多针刺绣数据集上线,含超 30k 张图片
大数据·人工智能·深度学习·机器学习·语言模型·大模型·数据集
sp_fyf_20241 小时前
计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-03
人工智能·算法·机器学习·计算机视觉·语言模型·自然语言处理
新缸中之脑1 小时前
10个令人惊叹的AI工具
人工智能
学步_技术1 小时前
自动驾驶系列—线控悬架技术:自动驾驶背后的动力学掌控者
人工智能·机器学习·自动驾驶·线控系统·悬挂系统
Eric.Lee20211 小时前
数据集-目标检测系列- 螃蟹 检测数据集 crab >> DataBall
python·深度学习·算法·目标检测·计算机视觉·数据集·螃蟹检测
DogDaoDao2 小时前
【预备理论知识——2】深度学习:线性代数概述
人工智能·深度学习·线性代数