Pytorch如何加载和读取VOC数据集用来做目标检测?

要用 PyTorch 做 VOC 目标检测,直接用内置的 VOCDetection 类是最省事的方案,但它的输出是原始的 XML 字符串,没法直接喂给模型。为了让你能直接上手训练,我写了一个自定义 Dataset 的完整示例。

这个版本直接把 XML 解析成模型需要的 boxeslabels,拿来就能用:

📦 核心代码:自定义 VOC Dataset

python 复制代码
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import xml.etree.ElementTree as ET

class VOCDetection(Dataset):
    """
    Pascal VOC 数据集类,用于目标检测
    """
    # Pascal VOC 的 20 个类别 (加上背景)
    CLASSES = [
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
        'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
    ]
    
    def __init__(self, root_dir, year='2012', image_set='train', transforms=None):
        """
        Args:
            root_dir: VOCdevkit 的根目录 (例如: './VOCdevkit')
            year: 年份 ('2007' 或 '2012')
            image_set: 'train', 'val', 'trainval'
            transforms: 数据增强操作
        """
        super().__init__()
        self.root_dir = root_dir
        self.year = year
        self.image_set = image_set
        self.transforms = transforms
        
        # 构建数据集路径
        voc_dir = f'VOC{year}'
        self.img_dir = os.path.join(root_dir, voc_dir, 'JPEGImages')
        self.anno_dir = os.path.join(root_dir, voc_dir, 'Annotations')
        split_txt = os.path.join(root_dir, voc_dir, 'ImageSets', 'Main', f'{image_set}.txt')
        
        # 读取图片文件名列表
        with open(split_txt, 'r') as f:
            self.file_names = f.read().strip().split()

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        # 1. 获取文件名
        file_name = self.file_names[idx]
        
        # 2. 加载图片
        img_path = os.path.join(self.img_dir, f'{file_name}.jpg')
        image = Image.open(img_path).convert('RGB')
        
        # 3. 解析标注文件 (XML)
        anno_path = os.path.join(self.anno_dir, f'{file_name}.xml')
        target = self.parse_annotation(anno_path)
        
        # 4. 应用数据增强 (如果有)
        if self.transforms:
            image = self.transforms(image)
        
        return image, target

    def parse_annotation(self, anno_path):
        """
        解析 XML 文件,返回字典格式的标注
        """
        tree = ET.parse(anno_path)
        root = tree.getroot()
        
        boxes = []  # 边界框
        labels = [] # 类别索引 (从 1 开始,0 通常留给背景)
        
        for obj in root.iter('object'):
            # 获取类别名并转换为索引
            cls_name = obj.find('name').text.lower().strip()
            cls_id = self.CLASSES.index(cls_name) + 1 # +1 是因为 0 是背景
            
            # 获取边界框坐标
            xml_box = obj.find('bndbox')
            xmin = float(xml_box.find('xmin').text)
            ymin = float(xml_box.find('ymin').text)
            xmax = float(xml_box.find('xmax').text)
            ymax = float(xml_box.find('ymax').text)
            
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(cls_id)
        
        # 转换为 Tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        
        # 构建目标字典 (符合 torchvision 检测模型的格式)
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        # target["area"] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) # 可选: 计算面积
        # target["iscrowd"] = torch.zeros((len(boxes),), dtype=torch.int64) # 可选: 标记是否为人群
        
        return target

🚀 如何使用这个 Dataset

写好 Dataset 后,只需要配合 DataLoader 就能开始训练了。

python 复制代码
from torch.utils.data import DataLoader
import torchvision.transforms as T

# 1. 定义数据预处理 (目标检测通常只转 Tensor)
# 注意:对于检测任务,ToTensor() 会自动将像素值归一化到 [0,1]
transform = T.Compose([
    T.ToTensor(), 
])

# 2. 实例化数据集
# 假设你的数据放在 './VOCdevkit' 文件夹下
dataset = VOCDetection(
    root_dir='./VOCdevkit', 
    year='2012', 
    image_set='train', 
    transforms=transform
)

# 3. 创建 DataLoader
# 注意:检测任务的 batch_size 通常设为 1,或者使用特殊的 collate_fn
# 这里为了简单,先不设 batch_size > 1 (因为每张图的框数量不同,直接堆叠会报错)
data_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)

# 4. 遍历数据 (模拟训练循环)
for batch_idx, (images, targets) in enumerate(data_loader):
    # images: [1, 3, H, W] 的 Tensor
    # targets: 包含 'boxes' 和 'labels' 的字典 (列表形式,因为 batch_size=1)
    
    print(f"批次 {batch_idx}")
    print(f"图片形状: {images.shape}")
    print(f"第一张图有 {len(targets[0]['boxes'])} 个标注框")
    print(f"类别索引: {targets[0]['labels'].tolist()}")
    
    # 这里就可以把 images 和 targets 送入模型了
    # outputs = model(images, targets)
    
    if batch_idx == 2: # 只演示前 3 个 batch
        break

💡 关键点说明

  • 为什么返回字典?torchvision 里的预训练检测模型(如 Faster R-CNN)要求 targets 是一个包含 boxeslabels 的 Tensor 字典。
  • 坐标格式 :代码中解析出的坐标是 (xmin, ymin, xmax, ymax),这是 PyTorch 官方模型要求的格式。
  • 关于 Batch Size :你会发现上面的代码 batch_size=1。这是因为每张图片的物体数量不同,PyTorch 默认无法将它们堆叠成一个 Tensor。在实际训练中,通常使用 batch_size=1,或者写一个复杂的 collate_fn 来处理变长数据。
相关推荐
智算菩萨10 小时前
可验证奖励强化学习(RLVR):如何让大模型更可靠?
人工智能·机器学习
YZ09910 小时前
Sora2 AI视频去水印接口
人工智能·音视频·api·ai编程
AI周红伟10 小时前
周红伟:Sglang+Vllm+Qwen3.5企业级部署案例实操
大数据·人工智能·大模型·智能体
Niuguangshuo11 小时前
深度学习:激活函数大全
人工智能·深度学习
人机与认知实验室11 小时前
2028年春晚,会出现机器人主持人吗?
人工智能·机器人
java1234_小锋11 小时前
嵌入模型与Chroma向量数据库 - Qwen3嵌入模型使用 - AI大模型应用开发必备知识
人工智能·向量数据库·chroma
沪漂阿龙11 小时前
大模型如何突破上下文窗口?RoPE、ALiBi与长文本扩展全解析
人工智能
witAI11 小时前
**AI仿真人剧生成软件2025推荐,解锁沉浸式数字内容创作
人工智能·python·量子计算
DoogalStudio11 小时前
DevMind插件设计方案产品需求文档
人工智能·笔记