Pytorch如何加载和读取VOC数据集用来做目标检测?

要用 PyTorch 做 VOC 目标检测,直接用内置的 VOCDetection 类是最省事的方案,但它的输出是原始的 XML 字符串,没法直接喂给模型。为了让你能直接上手训练,我写了一个自定义 Dataset 的完整示例。

这个版本直接把 XML 解析成模型需要的 boxeslabels,拿来就能用:

📦 核心代码:自定义 VOC Dataset

python 复制代码
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import xml.etree.ElementTree as ET

class VOCDetection(Dataset):
    """
    Pascal VOC 数据集类,用于目标检测
    """
    # Pascal VOC 的 20 个类别 (加上背景)
    CLASSES = [
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
        'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
    ]
    
    def __init__(self, root_dir, year='2012', image_set='train', transforms=None):
        """
        Args:
            root_dir: VOCdevkit 的根目录 (例如: './VOCdevkit')
            year: 年份 ('2007' 或 '2012')
            image_set: 'train', 'val', 'trainval'
            transforms: 数据增强操作
        """
        super().__init__()
        self.root_dir = root_dir
        self.year = year
        self.image_set = image_set
        self.transforms = transforms
        
        # 构建数据集路径
        voc_dir = f'VOC{year}'
        self.img_dir = os.path.join(root_dir, voc_dir, 'JPEGImages')
        self.anno_dir = os.path.join(root_dir, voc_dir, 'Annotations')
        split_txt = os.path.join(root_dir, voc_dir, 'ImageSets', 'Main', f'{image_set}.txt')
        
        # 读取图片文件名列表
        with open(split_txt, 'r') as f:
            self.file_names = f.read().strip().split()

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        # 1. 获取文件名
        file_name = self.file_names[idx]
        
        # 2. 加载图片
        img_path = os.path.join(self.img_dir, f'{file_name}.jpg')
        image = Image.open(img_path).convert('RGB')
        
        # 3. 解析标注文件 (XML)
        anno_path = os.path.join(self.anno_dir, f'{file_name}.xml')
        target = self.parse_annotation(anno_path)
        
        # 4. 应用数据增强 (如果有)
        if self.transforms:
            image = self.transforms(image)
        
        return image, target

    def parse_annotation(self, anno_path):
        """
        解析 XML 文件,返回字典格式的标注
        """
        tree = ET.parse(anno_path)
        root = tree.getroot()
        
        boxes = []  # 边界框
        labels = [] # 类别索引 (从 1 开始,0 通常留给背景)
        
        for obj in root.iter('object'):
            # 获取类别名并转换为索引
            cls_name = obj.find('name').text.lower().strip()
            cls_id = self.CLASSES.index(cls_name) + 1 # +1 是因为 0 是背景
            
            # 获取边界框坐标
            xml_box = obj.find('bndbox')
            xmin = float(xml_box.find('xmin').text)
            ymin = float(xml_box.find('ymin').text)
            xmax = float(xml_box.find('xmax').text)
            ymax = float(xml_box.find('ymax').text)
            
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(cls_id)
        
        # 转换为 Tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        
        # 构建目标字典 (符合 torchvision 检测模型的格式)
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        # target["area"] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) # 可选: 计算面积
        # target["iscrowd"] = torch.zeros((len(boxes),), dtype=torch.int64) # 可选: 标记是否为人群
        
        return target

🚀 如何使用这个 Dataset

写好 Dataset 后,只需要配合 DataLoader 就能开始训练了。

python 复制代码
from torch.utils.data import DataLoader
import torchvision.transforms as T

# 1. 定义数据预处理 (目标检测通常只转 Tensor)
# 注意:对于检测任务,ToTensor() 会自动将像素值归一化到 [0,1]
transform = T.Compose([
    T.ToTensor(), 
])

# 2. 实例化数据集
# 假设你的数据放在 './VOCdevkit' 文件夹下
dataset = VOCDetection(
    root_dir='./VOCdevkit', 
    year='2012', 
    image_set='train', 
    transforms=transform
)

# 3. 创建 DataLoader
# 注意:检测任务的 batch_size 通常设为 1,或者使用特殊的 collate_fn
# 这里为了简单,先不设 batch_size > 1 (因为每张图的框数量不同,直接堆叠会报错)
data_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)

# 4. 遍历数据 (模拟训练循环)
for batch_idx, (images, targets) in enumerate(data_loader):
    # images: [1, 3, H, W] 的 Tensor
    # targets: 包含 'boxes' 和 'labels' 的字典 (列表形式,因为 batch_size=1)
    
    print(f"批次 {batch_idx}")
    print(f"图片形状: {images.shape}")
    print(f"第一张图有 {len(targets[0]['boxes'])} 个标注框")
    print(f"类别索引: {targets[0]['labels'].tolist()}")
    
    # 这里就可以把 images 和 targets 送入模型了
    # outputs = model(images, targets)
    
    if batch_idx == 2: # 只演示前 3 个 batch
        break

💡 关键点说明

  • 为什么返回字典?torchvision 里的预训练检测模型(如 Faster R-CNN)要求 targets 是一个包含 boxeslabels 的 Tensor 字典。
  • 坐标格式 :代码中解析出的坐标是 (xmin, ymin, xmax, ymax),这是 PyTorch 官方模型要求的格式。
  • 关于 Batch Size :你会发现上面的代码 batch_size=1。这是因为每张图片的物体数量不同,PyTorch 默认无法将它们堆叠成一个 Tensor。在实际训练中,通常使用 batch_size=1,或者写一个复杂的 collate_fn 来处理变长数据。
相关推荐
测试_AI_一辰2 小时前
Agent & RAG 测试工程05:把 RAG 的检索过程跑清楚:chunk 是什么、怎么来的、怎么被命中的
开发语言·人工智能·功能测试·自动化·ai编程
Henry-SAP2 小时前
SAP(ERP) 组织结构业务视角解析
大数据·人工智能·sap·erp·sap pp
龙腾亚太2 小时前
航空零部件加工变形难题破解:数字孪生 + 深度学习的精度控制实战
人工智能·深度学习·数字孪生·ai工程师·ai证书·转型ai
Coding茶水间2 小时前
基于深度学习的输电电力设备检测系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
开发语言·人工智能·深度学习·yolo·目标检测·机器学习
是Dream呀2 小时前
基于深度学习的人类活动识别模型研究:HAR-DeepConvLG的设计与应用
人工智能·深度学习
jkyy20142 小时前
健康座舱:健康有益赋能新能源汽车开启移动健康新场景
人工智能·物联网·汽车·健康医疗
冀博2 小时前
从零到一:我如何用 LangChain + 智谱 AI 搭建具备“记忆与手脚”的智能体
人工智能·langchain
AI周红伟2 小时前
周红伟:中国信息通信研究院院长余晓晖关于智算:《算力互联互通行动计划》和《关于深入实施“人工智能+”行动的意见》的意见
人工智能
橘子师兄3 小时前
C++AI大模型接入SDK—ChatSDK封装
开发语言·c++·人工智能·后端