要用 PyTorch 做 VOC 目标检测,直接用内置的 VOCDetection 类是最省事的方案,但它的输出是原始的 XML 字符串,没法直接喂给模型。为了让你能直接上手训练,我写了一个自定义 Dataset 的完整示例。
这个版本直接把 XML 解析成模型需要的 boxes 和 labels,拿来就能用:
📦 核心代码:自定义 VOC Dataset
python
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import xml.etree.ElementTree as ET
class VOCDetection(Dataset):
"""
Pascal VOC 数据集类,用于目标检测
"""
# Pascal VOC 的 20 个类别 (加上背景)
CLASSES = [
'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
]
def __init__(self, root_dir, year='2012', image_set='train', transforms=None):
"""
Args:
root_dir: VOCdevkit 的根目录 (例如: './VOCdevkit')
year: 年份 ('2007' 或 '2012')
image_set: 'train', 'val', 'trainval'
transforms: 数据增强操作
"""
super().__init__()
self.root_dir = root_dir
self.year = year
self.image_set = image_set
self.transforms = transforms
# 构建数据集路径
voc_dir = f'VOC{year}'
self.img_dir = os.path.join(root_dir, voc_dir, 'JPEGImages')
self.anno_dir = os.path.join(root_dir, voc_dir, 'Annotations')
split_txt = os.path.join(root_dir, voc_dir, 'ImageSets', 'Main', f'{image_set}.txt')
# 读取图片文件名列表
with open(split_txt, 'r') as f:
self.file_names = f.read().strip().split()
def __len__(self):
return len(self.file_names)
def __getitem__(self, idx):
# 1. 获取文件名
file_name = self.file_names[idx]
# 2. 加载图片
img_path = os.path.join(self.img_dir, f'{file_name}.jpg')
image = Image.open(img_path).convert('RGB')
# 3. 解析标注文件 (XML)
anno_path = os.path.join(self.anno_dir, f'{file_name}.xml')
target = self.parse_annotation(anno_path)
# 4. 应用数据增强 (如果有)
if self.transforms:
image = self.transforms(image)
return image, target
def parse_annotation(self, anno_path):
"""
解析 XML 文件,返回字典格式的标注
"""
tree = ET.parse(anno_path)
root = tree.getroot()
boxes = [] # 边界框
labels = [] # 类别索引 (从 1 开始,0 通常留给背景)
for obj in root.iter('object'):
# 获取类别名并转换为索引
cls_name = obj.find('name').text.lower().strip()
cls_id = self.CLASSES.index(cls_name) + 1 # +1 是因为 0 是背景
# 获取边界框坐标
xml_box = obj.find('bndbox')
xmin = float(xml_box.find('xmin').text)
ymin = float(xml_box.find('ymin').text)
xmax = float(xml_box.find('xmax').text)
ymax = float(xml_box.find('ymax').text)
boxes.append([xmin, ymin, xmax, ymax])
labels.append(cls_id)
# 转换为 Tensor
boxes = torch.as_tensor(boxes, dtype=torch.float32)
labels = torch.as_tensor(labels, dtype=torch.int64)
# 构建目标字典 (符合 torchvision 检测模型的格式)
target = {}
target["boxes"] = boxes
target["labels"] = labels
# target["area"] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) # 可选: 计算面积
# target["iscrowd"] = torch.zeros((len(boxes),), dtype=torch.int64) # 可选: 标记是否为人群
return target
🚀 如何使用这个 Dataset
写好 Dataset 后,只需要配合 DataLoader 就能开始训练了。
python
from torch.utils.data import DataLoader
import torchvision.transforms as T
# 1. 定义数据预处理 (目标检测通常只转 Tensor)
# 注意:对于检测任务,ToTensor() 会自动将像素值归一化到 [0,1]
transform = T.Compose([
T.ToTensor(),
])
# 2. 实例化数据集
# 假设你的数据放在 './VOCdevkit' 文件夹下
dataset = VOCDetection(
root_dir='./VOCdevkit',
year='2012',
image_set='train',
transforms=transform
)
# 3. 创建 DataLoader
# 注意:检测任务的 batch_size 通常设为 1,或者使用特殊的 collate_fn
# 这里为了简单,先不设 batch_size > 1 (因为每张图的框数量不同,直接堆叠会报错)
data_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
# 4. 遍历数据 (模拟训练循环)
for batch_idx, (images, targets) in enumerate(data_loader):
# images: [1, 3, H, W] 的 Tensor
# targets: 包含 'boxes' 和 'labels' 的字典 (列表形式,因为 batch_size=1)
print(f"批次 {batch_idx}")
print(f"图片形状: {images.shape}")
print(f"第一张图有 {len(targets[0]['boxes'])} 个标注框")
print(f"类别索引: {targets[0]['labels'].tolist()}")
# 这里就可以把 images 和 targets 送入模型了
# outputs = model(images, targets)
if batch_idx == 2: # 只演示前 3 个 batch
break
💡 关键点说明
- 为什么返回字典? :
torchvision里的预训练检测模型(如 Faster R-CNN)要求targets是一个包含boxes和labels的 Tensor 字典。 - 坐标格式 :代码中解析出的坐标是
(xmin, ymin, xmax, ymax),这是 PyTorch 官方模型要求的格式。 - 关于 Batch Size :你会发现上面的代码
batch_size=1。这是因为每张图片的物体数量不同,PyTorch 默认无法将它们堆叠成一个 Tensor。在实际训练中,通常使用batch_size=1,或者写一个复杂的collate_fn来处理变长数据。