复现YOLO v1 PyTorch

复现YOLO v1 PyTorch

Paper: [1506.02640] You Only Look Once: Unified, Real-Time Object Detection (arxiv.org)

Github: EclipseR33/yolo_v1_pytorch (github.com)

数据集

VOC2007:The PASCAL Visual Object Classes Challenge 2007 (VOC2007)

VOC2012:The PASCAL Visual Object Classes Challenge 2012 (VOC2012)

PASCAL VOC 07/12的目录结构都是一致的,因此只需要针对VOC07编写代码再扩展即可。VOC2007目录下有5个文件夹。我们需要其中的'Annotations'(存有标注信息),'ImageSets'(存有train、val、test各类文件名), 'JPEGImages'(存有图像)。VOC中的图像都是.jpg文件,ImageSets中的文件都是.txt文件,Annotations中的注释都是.xml文件。

xml 复制代码
//这是一个xml注释的示例,我们需要其中的<object>信息
<annotation>
	<folder>VOC2007</folder>
	<filename>000001.jpg</filename>
	<source>
		<database>The VOC2007 Database</database>
		<annotation>PASCAL VOC2007</annotation>
		<image>flickr</image>
		<flickrid>341012865</flickrid>
	</source>
	<owner>
		<flickrid>Fried Camels</flickrid>
		<name>Jinky the Fruit Bat</name>
	</owner>
	<size>
		<width>353</width>
		<height>500</height>
		<depth>3</depth>
	</size>
	<segmented>0</segmented>
	<object>
		<name>dog</name>
		<pose>Left</pose>
		<truncated>1</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>48</xmin>
			<ymin>240</ymin>
			<xmax>195</xmax>
			<ymax>371</ymax>
		</bndbox>
	</object>
	<object>
		<name>person</name>		// name中包含的就是class信息
		<pose>Left</pose>
		<truncated>1</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>8</xmin>
			<ymin>12</ymin>
			<xmax>352</xmax>
			<ymax>498</ymax>
		</bndbox>
	</object>
</annotation>
find_classes.py

首先我们需要获得VOC数据集中的所有class信息并为其编号,将该信息存储到json文件中

python 复制代码
# 路径: ./dataset/find_classes.py
import xml.etree.ElementTree as ET
from tqdm import tqdm
import json
import os


def xml2dict(xml):
    """
    Recursively convert an XML element into a nested dict.

    Every child tag becomes a key. A leaf child (no sub-elements) contributes
    its text, e.g. ``<name>person</name>`` yields ``{'name': 'person'}``;
    a non-leaf child contributes the dict built from its own children.
    When the same tag occurs several times the values are gathered in a list.

    :param xml: an ``xml.etree.ElementTree.Element``
    :return: dict mapping child tag -> text, dict, or list of those
    """
    # Pre-register every child tag so the loop can tell "first value" apart
    # from "tag already holds a value".
    result = {child.tag: None for child in xml}
    for child in xml:
        # len(child) is the number of sub-elements; zero means a leaf.
        value = child.text if len(child) == 0 else xml2dict(child)
        existing = result[child.tag]
        if existing is None:
            # First value seen for this tag: store it directly.
            result[child.tag] = value
        elif isinstance(existing, list):
            # Tag already repeated at least once: extend the list.
            existing.append(value)
        else:
            # Second occurrence: promote the single value to a list.
            result[child.tag] = [existing, value]
    return result


# Destination of the generated class-name -> index mapping.
json_path = './classes.json'

# Dataset root (the original author notes VOC2007 and VOC2012 share the same classes).
root = r'F:\AI\Dataset\VOC2012\VOCdevkit\VOC2012'

# Full path of every annotation file.
annotation_root = os.path.join(root, 'Annotations')
annotation_list = [os.path.join(annotation_root, a) for a in os.listdir(annotation_root)]

# Collect the distinct object names that occur in the annotations.
s = set()
for annotation in tqdm(annotation_list):
    xml = ET.parse(annotation).getroot()
    data = xml2dict(xml)['object']
    # A single <object> is parsed to a dict, several to a list of dicts.
    if not isinstance(data, list):
        data = [data]
    for d in data:
        s.add(d['name'])

# Sort the names so the assigned indices are reproducible.
s = sorted(s)

# Keyed by class name, which makes the xml -> label conversion a plain lookup.
data = {value: i for i, value in enumerate(s)}
json_str = json.dumps(data)

with open(json_path, 'w') as f:
    f.write(json_str)

运行./dataset/find_classes.py之后,我们在指定目录下得到一个json文件

json 复制代码
{"aeroplane": 0, "bicycle": 1, "bird": 2, "boat": 3, "bottle": 4, "bus": 5, "car": 6, "cat": 7, "chair": 8, "cow": 9, "diningtable": 10, "dog": 11, "horse": 12, "motorbike": 13, "person": 14, "pottedplant": 15, "sheep": 16, "sofa": 17, "train": 18, "tvmonitor": 19}

接下来我们要开始写data.py文件,主要流程:

1.通过ImageSets中train.txt, val.txt, test.txt文件的指引寻找数据集对应的所有图像名与其地址。

2.编写getitem,读取xml中的信息并转化为label形式,读取图像,并将两者都转为tensor

这里设置的dataset传出的label都是xmin, ymin, xmax, ymax, class的VOC形式,并且是直接的坐标数值,并没有使用百分比表示。

data.py
python 复制代码
# 路径: ./dataset/data.py
from dataset.transform import *		# 导入我们重写的transform类

from torch.utils.data import Dataset
import xml.etree.ElementTree as ET
from PIL import Image
import numpy as np
import json
import os


def get_file_name(root, layout_txt):
    """
    Read the entries listed in an ImageSets split file.

    :param root: dataset root directory
    :param layout_txt: path of the split file relative to ``root``
    :return: list with one stripped, non-empty line per entry

    The file is read line by line and blank lines are skipped, so the result
    is the same whether or not the file ends with a newline (slicing off the
    last element would drop a real entry when it does not).
    """
    with open(os.path.join(root, layout_txt)) as f:
        stripped = (line.strip() for line in f)
        return [name for name in stripped if name]


def xml2dict(xml):
    """
    Turn an XML element into a nested dict (recursive).

    Keys are the child tags. Leaf children map to their text, children with
    sub-elements map to a nested dict, and repeated tags are collected into
    a list in document order.
    """
    data = dict.fromkeys(c.tag for c in xml)
    for node in xml:
        # A node with sub-elements is converted recursively, a leaf gives its text.
        payload = xml2dict(node) if len(node) else node.text
        previous = data[node.tag]
        if previous is None:
            data[node.tag] = payload
            continue
        if not isinstance(previous, list):
            # Second occurrence of the tag: wrap the first value in a list.
            previous = data[node.tag] = [previous]
        previous.append(payload)
    return data


class VOC0712Dataset(Dataset):
    """
    PASCAL VOC 2007/2012 detection dataset.

    Each item is ``(image, label)`` where ``label`` has one row per object:
    ``xmin, ymin, xmax, ymax, class`` in absolute pixel coordinates.
    With ``get_info=True`` the item additionally carries the image name
    (without extension) and the original PIL image size.
    """

    def __init__(self, root, class_path, transforms, mode, data_range=None, get_info=False):
        """
        :param root: in 'train' mode a sequence whose first two entries are the
                     VOC2007 and VOC2012 roots; in 'test' mode a single root
                     (or a sequence, of which only the first entry is used)
        :param class_path: json file mapping class name -> class index
        :param transforms: callable ``(image, label) -> (image, label)`` or None
        :param mode: 'train' or 'test'
        :param data_range: optional ``(start, stop)`` slice applied to the file lists
        :param get_info: whether ``__getitem__`` also returns image name and size
        """
        # Class name -> index mapping.
        with open(class_path, 'r') as f:
            self.classes = json.load(f)

        # Build the split-file paths with os.path.join so they are valid on
        # every platform (a literal backslash is not a separator outside Windows).
        main_dir = os.path.join('ImageSets', 'Main')
        layout_txt = None
        if mode == 'train':
            # Both roots are combined with their train and val splits, giving
            # four split files in total.
            root = [root[0], root[0], root[1], root[1]]
            layout_txt = [os.path.join(main_dir, 'train.txt'), os.path.join(main_dir, 'val.txt'),
                          os.path.join(main_dir, 'train.txt'), os.path.join(main_dir, 'val.txt')]
        elif mode == 'test':
            # Normalise to a sequence so both modes share the loop below.
            if not isinstance(root, (list, tuple)):
                root = [root]
            layout_txt = [os.path.join(main_dir, 'test.txt')]
        assert layout_txt is not None, 'Unknown mode'

        self.transforms = transforms
        self.get_info = get_info

        # Several roots may be involved, so full paths are stored.
        self.image_list = []
        self.annotation_list = []
        for r, txt in zip(root, layout_txt):
            names = get_file_name(r, txt)  # read each split file only once
            self.image_list += [os.path.join(r, 'JPEGImages', t + '.jpg') for t in names]
            self.annotation_list += [os.path.join(r, 'Annotations', t + '.xml') for t in names]

        # Optional sub-range of the data; None keeps everything.
        if data_range is not None:
            self.image_list = self.image_list[data_range[0]: data_range[1]]
            self.annotation_list = self.annotation_list[data_range[0]: data_range[1]]

    def __len__(self):
        """Number of samples."""
        return len(self.annotation_list)

    def __getitem__(self, idx):
        image = Image.open(self.image_list[idx])
        image_size = image.size  # PIL size of the untransformed image
        label = self.label_process(self.annotation_list[idx])

        if self.transforms is not None:
            # Detection transforms must update the boxes together with the
            # image, hence the joint (image, label) call.
            image, label = self.transforms(image, label)
        if self.get_info:
            return image, label, os.path.basename(self.image_list[idx]).split('.')[0], image_size
        else:
            return image, label

    def label_process(self, annotation):
        """
        Parse one annotation file into a numpy array with one row per object:
        ``xmin, ymin, xmax, ymax, class``.
        """
        xml = ET.parse(annotation).getroot()
        data = xml2dict(xml)['object']
        # A single <object> is parsed to a dict, several to a list of dicts.
        objects = data if isinstance(data, list) else [data]
        label = [[float(obj['bndbox']['xmin']), float(obj['bndbox']['ymin']),
                  float(obj['bndbox']['xmax']), float(obj['bndbox']['ymax']),
                  self.classes[obj['name']]]
                 for obj in objects]
        return np.array(label)


if __name__ == "__main__":
    from dataset.draw_bbox import draw

    # Manual smoke test: build the combined 07+12 training set and draw the
    # boxes of a handful of samples.
    root0712 = [r'F:\AI\Dataset\VOC2007\VOCdevkit\VOC2007', r'F:\AI\Dataset\VOC2012\VOCdevkit\VOC2012']

    transforms = Compose([ToTensor(), RandomHorizontalFlip(0.5), Resize(448)])
    ds = VOC0712Dataset(root0712, 'classes.json', transforms, 'train', get_info=True)
    print(len(ds))

    for i, (image, label, image_name, image_size) in enumerate(ds):
        if i >= 1010:
            break
        # Only samples 1001..1009 are displayed.
        if i > 1000:
            print(label.dtype)
            print(tuple(image.size()[1:]))
            draw(image, label, ds.classes)
    print('VOC2007Dataset')
transform.py
python 复制代码
# 路径: ./dataset/transform.py
import torch
import torchvision
import random


class Compose:
    """Apply a sequence of ``(image, label) -> (image, label)`` transforms in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, label):
        # Feed the output pair of each step into the next one.
        for step in self.transforms:
            outcome = step(image, label)
            image, label = outcome
        return image, label


class ToTensor:
    """Convert the image with torchvision's ``ToTensor`` and the label with ``torch.tensor``."""

    def __init__(self):
        self.totensor = torchvision.transforms.ToTensor()

    def __call__(self, image, label):
        image_tensor = self.totensor(image)
        label_tensor = torch.tensor(label)
        return image_tensor, label_tensor


class RandomHorizontalFlip:
    """Flip image and boxes left-right with probability ``p``."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, image, label):
        """
        :param image: tensor whose last dimension is the width
        :param label: tensor with rows starting ``xmin, ymin, xmax, ymax``
        :return: ``(image, label)``; when flipped, ``label`` is a new tensor

        Mirroring maps x to ``width - x``, so the old right edge becomes the
        new left edge. The two columns are swapped accordingly so that
        ``xmin <= xmax`` still holds afterwards, and the caller's tensor is
        left untouched.
        """
        if random.random() < self.p:
            width = image.shape[-1]
            image = image.flip(-1)      # horizontal flip
            label = label.clone()       # do not modify the caller's tensor
            old_xmin = label[:, 0].clone()
            label[:, 0] = width - label[:, 2]
            label[:, 2] = width - old_xmin
        return image, label


class Resize:
    """
    Resize an image into an ``image_size`` x ``image_size`` canvas and rescale
    its boxes accordingly.
    """

    def __init__(self, image_size, keep_ratio=True):
        """
        :param image_size: int, side length of the square output
        :param keep_ratio: True  -> one common scale factor for both sides; the
                                    remainder of the canvas is zero padded
                           False -> height and width are scaled independently
                                    to ``image_size``
        """
        self.image_size = image_size
        self.keep_ratio = keep_ratio

    def __call__(self, image, label):
        """
        :param image: tensor [3, h, w]
        :param label: tensor with rows starting ``xmin, ymin, xmax, ymax``;
                      it is updated in place
        :return: ``(image, label)``
        """
        target = self.image_size
        src_h, src_w = tuple(image.size()[1:])

        # Normalise the boxes by the source size first.
        label[:, [0, 2]] = label[:, [0, 2]] / src_w
        label[:, [1, 3]] = label[:, [1, 3]] / src_h

        if self.keep_ratio:
            scale_h = scale_w = min(target / src_h, target / src_w)
        else:
            scale_h, scale_w = target / src_h, target / src_w

        # Size of the resized content, never larger than the canvas.
        dst_h = min(int(scale_h * src_h), target)
        dst_w = min(int(scale_w * src_w), target)

        # Bring the boxes back to pixels in the resized image.
        label[:, [0, 2]] = label[:, [0, 2]] * dst_w
        label[:, [1, 3]] = label[:, [1, 3]] * dst_h

        resize = torchvision.transforms.Resize([dst_h, dst_w])
        # Padding goes to the right and bottom only, so the content stays
        # aligned with the top-left corner and the boxes need no offset.
        pad = torch.nn.ZeroPad2d((0, target - dst_w, 0, target - dst_h))
        image = pad(resize(image))

        assert list(image.size()) == [3, self.image_size, self.image_size]
        return image, label
draw_bbox.py
python 复制代码
# 路径: ./dataset/draw_bbox.py
import torchvision.transforms as F
import numpy as np
from PIL import ImageDraw, ImageFont
import matplotlib.pyplot as plt


# Color names used by draw(): entry i is the outline/text color for class index i.
colors = ['Pink', 'Crimson', 'Magenta', 'Indigo', 'BlueViolet',
          'Blue', 'GhostWhite', 'LightSteelBlue', 'Brown', 'SkyBlue',
          'Tomato', 'SpringGreen', 'Green', 'Yellow', 'Olive',
          'Gold', 'Wheat', 'Orange', 'Gray', 'Red']


def draw(image, bbox, classes, show_conf=False, conf_th=0.0):
    """
    Draw bounding boxes with their class names on an image and show it.

    :param image: image tensor accepted by torchvision's ``ToPILImage``
    :param bbox: tensor with one row per box; the first four values are the
                 box corners (x, y, x, y), the last value is the class index
                 and, when ``show_conf`` is set, the second to last value is
                 the confidence
    :param classes: dict mapping class name -> class index
    :param show_conf: append the confidence to the label text and skip boxes
                      whose confidence is below ``conf_th``
    :param conf_th: confidence threshold, only used with ``show_conf``
    """
    keys = list(classes.keys())
    values = list(classes.values())

    # Font (including size). Arial is not installed everywhere, so fall back
    # to Pillow's built-in font instead of failing.
    try:
        font = ImageFont.truetype('arial.ttf', 10)
    except OSError:
        font = ImageFont.load_default()

    transform = F.ToPILImage()
    image = transform(image)
    draw_image = ImageDraw.Draw(image)

    bbox = np.array(bbox.cpu())

    for b in bbox:
        print(b)
        if show_conf and b[-2] < conf_th:
            continue
        class_id = int(b[-1])
        color = colors[class_id]
        # Order the corners: rectangle() needs x0 <= x1 and y0 <= y1, but a
        # box may arrive with its x coordinates swapped (e.g. after mirroring).
        x0, x1 = sorted((float(b[0]), float(b[2])))
        y0, y1 = sorted((float(b[1]), float(b[3])))
        draw_image.rectangle([x0, y0, x1, y1], outline=color, width=3)

        text = keys[values.index(class_id)]
        if show_conf:
            text += ' {:.2f}'.format(b[-2])
        # Label goes just inside the top-left corner of the box.
        draw_image.text([x0 + 5, y0 + 5], text, fill=color, font=font)

    plt.figure()
    plt.imshow(image)
    plt.show()

模型

darknet.py

根据论文的描述,首先构建Darknet的Backbone部分

下图是Backbone的基本结构(原文此处的图片未能正常显示)。

请添加图片描述

下文指出了darknet全部使用LeakyReLU并且参数为0.1,除了最后的全连接层不需要激活函数。

python 复制代码
# 路径: ./model/darknet.py
import torch.nn as nn


def conv(in_ch, out_ch, k_size=3, stride=1, padding=1):
    """
    Build one convolution unit: a bias-free ``Conv2d`` followed by
    ``LeakyReLU(0.1)``.

    :param in_ch: input channels
    :param out_ch: output channels
    :param k_size: kernel size
    :param stride: convolution stride
    :param padding: zero padding on each side
    :return: ``nn.Sequential`` holding the two layers
    """
    convolution = nn.Conv2d(in_ch, out_ch, kernel_size=k_size, stride=stride,
                            padding=padding, bias=False)
    activation = nn.LeakyReLU(negative_slope=0.1)
    return nn.Sequential(convolution, activation)


def make_layer(param):
    """
    Stack convolution units into one ``nn.Sequential``.

    :param param: either a single argument list for ``conv`` or a list of
                  such argument lists, applied in order
    """
    # A flat argument list describes a single unit; wrap it for uniform handling.
    specs = param if isinstance(param[0], list) else [param]
    return nn.Sequential(*(conv(*spec) for spec in specs))


class Block(nn.Module):
    """A stack of convolution units, optionally followed by a 2x2 max pooling."""

    def __init__(self, param, use_pool=True):
        """
        :param param: convolution specification(s) forwarded to ``make_layer``
        :param use_pool: whether ``forward`` applies the pooling layer
        """
        super().__init__()

        self.conv = make_layer(param)
        self.pool = nn.MaxPool2d(2)
        self.use_pool = use_pool

    def forward(self, x):
        features = self.conv(x)
        return self.pool(features) if self.use_pool else features


class DarkNet(nn.Module):
    """
    Convolutional backbone built from six ``Block`` stages.

    Each convolution entry is ``[in_ch, out_ch, kernel, stride, padding]``.
    The first four stages end with a 2x2 max pooling, the last two do not.
    """

    def __init__(self):
        super(DarkNet, self).__init__()

        self.conv1 = Block([[3, 64, 7, 2, 3]])
        self.conv2 = Block([[64, 192, 3, 1, 1]])
        self.conv3 = Block([[192, 128, 1, 1, 0],
                            [128, 256, 3, 1, 1],
                            [256, 256, 1, 1, 0],
                            [256, 512, 3, 1, 1]])
        self.conv4 = Block([[512, 256, 1, 1, 0],
                            [256, 512, 3, 1, 1],
                            [512, 256, 1, 1, 0],
                            [256, 512, 3, 1, 1],
                            [512, 256, 1, 1, 0],
                            [256, 512, 3, 1, 1],
                            [512, 256, 1, 1, 0],
                            [256, 512, 3, 1, 1],
                            [512, 512, 1, 1, 0],
                            [512, 1024, 3, 1, 1]])
        self.conv5 = Block([[1024, 512, 1, 1, 0],
                            [512, 1024, 3, 1, 1],
                            [1024, 512, 1, 1, 0],
                            [512, 1024, 3, 1, 1],
                            [1024, 1024, 3, 1, 1],
                            [1024, 1024, 3, 2, 1]], False)
        self.conv6 = Block([[1024, 1024, 3, 1, 1],
                            [1024, 1024, 3, 1, 1]], False)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # a=0.1 is the negative slope of the LeakyReLU(0.1) activations;
                # without it the gain is computed for a slope of 0.
                nn.init.kaiming_normal_(m.weight, a=0.1, mode='fan_out', nonlinearity='leaky_relu')

    def forward(self, x):
        # Run the six stages in order.
        for stage in (self.conv1, self.conv2, self.conv3,
                      self.conv4, self.conv5, self.conv6):
            x = stage(x)
        return x


if __name__ == "__main__":
    import torch

    # Smoke test: push one random 448x448 image through the backbone and
    # print the network and the size of its output.
    x = torch.randn([1, 3, 448, 448])

    net = DarkNet()
    print(net)
    features = net(x)
    print(features.size())
resnet.py

作者还说明了Darknet需要预训练,由于在ImageNet上进行预训练耗时过长,我选择使用修改过的Resnet50作为Backbone并使用Pytorch官方的预训练参数。

python 复制代码
# 路径: ./model/resnet.py
import torch
from torchvision.models.resnet import ResNet, Bottleneck


"""
通过继承Pytorch的ResNet代码,重写其中的_forward_impl来去除最后的avgpool与fc层
此外我将Resnet50原有的layer4省略并额外增加了两个maxpool层,使得Resnet输出的特征图与Darknet一致
均为[1024, 7, 7]
"""
class ResNet_(ResNet):
    """
    torchvision ``ResNet`` with a custom forward pass: the stem, ``layer1``
    to ``layer3`` and two extra max poolings. ``layer4``, ``avgpool`` and
    ``fc`` are still constructed by the parent class but are not called here.
    """

    def __init__(self, block, layers):
        super().__init__(block=block, layers=layers)

    def _forward_impl(self, x):
        # The same max-pool module is reused after layer2 and after layer3.
        pipeline = (
            self.conv1, self.bn1, self.relu, self.maxpool,
            self.layer1,
            self.layer2, self.maxpool,
            self.layer3, self.maxpool,
        )
        for stage in pipeline:
            x = stage(x)
        return x

    def forward(self, x):
        return self._forward_impl(x)


def _resnet(block, layers, pretrained):
    """
    Build a ``ResNet_`` and optionally load pretrained weights.

    :param block: residual block class
    :param layers: number of blocks per stage
    :param pretrained: path of a state-dict checkpoint, or None to skip loading
    :return: the constructed model
    """
    model = ResNet_(block, layers)
    if pretrained is not None:
        # map_location='cpu' lets a checkpoint that was saved from a GPU be
        # loaded on a machine without CUDA; load_state_dict copies the values
        # into the model's own parameters afterwards.
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(pretrained, map_location='cpu')
        model.load_state_dict(state_dict)
    return model


def resnet_1024ch(pretrained=None) -> ResNet:
    """
    Build the truncated backbone from ``Bottleneck`` blocks with stage
    depths ``[3, 4, 6, 3]``.

    :param pretrained: optional checkpoint path forwarded to ``_resnet``
    """
    return _resnet(Bottleneck, [3, 4, 6, 3], pretrained)


if __name__ == '__main__':
    # Smoke test: load the pretrained weights and print the network and the
    # size of its output for one random 448x448 image.
    x = torch.randn([1, 3, 448, 448])
    net = resnet_1024ch('resnet50-19c8e357.pth')
    print(net)

    features = net(x)
    print(features.size())
yolo.py

接下来复现yolo v1的目标检测核心部分。

​ Yolo v1的目标检测思路是将一张图像分割成S × S个grid cell,每个grid cell都会预测B个bbox,同时每个grid cell都会预测一个类别,作为其预测出的B个bbox的共同类别。如果一个目标的中心坐标位于某个grid cell中,那么这个grid cell就负责预测这个目标。一个bbox由x,y,w,h,conf五个参数表示,分别是中心X坐标,中心Y坐标(这里的x、y是相对于grid cell的x、y),宽,高,object置信度。其中这五个参数都是百分比的表示形式,x、y需要除以grid cell的宽、高;w、h则需要除以整张图片的宽、高;对于object conf,如果没有对应bbox则label为0,如果有对应bbox则label为预测出的bbox与对应bbox的IoU交并比。而且在测试时最终conf=object conf * conditional class probabilities。最终结果是xywh都在0到1之间。负责预测类别的tensor长度应为class个数。模型结构上最后一个全连接层的大小则是S × S × ( B * 5 + C )。

​ 最后预测bbox的全连接层之间有Dropout层,rate=0.5


损失函数上,YOLO v1以均方损失为主,在部分参数的损失计算上使用了一些技巧。如w、h计算中就先开根号使得小目标的损失更为显著,加入$\lambda_{coord}=5$和$\lambda_{noobj}=0.5$以增加xy的损失值(0到1中开根号会让值变大),降低框出背景bbox的损失以适应该类bbox较多的状况。复现过程中xywhc的五个损失值都直接计算、class的损失值由CrossEntropyLoss函数直接计算。文中也定义了预测的bbox与label中的bbox匹配方法——IoU匹配。以IoU较大的一对bbox确定对应关系,而且这个过程取走一对之后不能重复取。只有对应的bbox才能计算x y w h conf各类损失。如果一个预测bbox没有对应的label bbox,那么就认为其标记到了背景,只计算conf。class的损失只有负责label bbox的grid cell才需要计算,因此其计算次数最多只有$S^2$次。

这里需注意计算损失时如果没有使用pytorch中官方定义的损失函数,那么需要先通过sigmoid函数将模型输出限制到0~1之间才能进行直接的损失计算。

yolo.py的代码包含了yolo模型、yolo损失计算、yolo后处理三个部分。对前两个模块,yolo.py都有相应的测试代码,而后处理部分需要在test.py中测试,其不会被yolo.py内部调用。

python 复制代码
# 路径: ./model/yolo.py
import numpy as np

from model.darknet import DarkNet
from model.resnet import resnet_1024ch

import torch
import torch.nn as nn
import torchvision


# yolo模型
class yolo(nn.Module):
    """YOLO v1 network: a backbone followed by two fully connected layers."""

    def __init__(self, s, cell_out_ch, backbone_name, pretrain=None):
        """
        :param s: number of grid cells per side
        :param cell_out_ch: number of output values per grid cell
        :param backbone_name: 'darknet' or 'resnet'
        :param pretrain: checkpoint path, only passed to the resnet backbone
        """
        super().__init__()

        self.s = s
        self.conv = None

        if backbone_name == 'darknet':
            backbone = DarkNet()
        elif backbone_name == 'resnet':
            backbone = resnet_1024ch(pretrained=pretrain)
        else:
            backbone = None
        self.backbone = backbone
        self.backbone_name = backbone_name

        assert self.backbone is not None, 'Wrong backbone name'

        # The head expects a backbone output with 1024 * s * s values per image.
        hidden = 4096
        self.fc = nn.Sequential(
            nn.Linear(1024 * s * s, hidden),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.5),
            nn.Linear(hidden, s * s * cell_out_ch)
        )

    def forward(self, x):
        """
        :return: tensor [batch, s * s, cell_out_ch]
        """
        n = x.size(0)
        features = torch.flatten(self.backbone(x), 1)
        return self.fc(features).view(n, self.s ** 2, -1)


# yolo损失计算
class yolo_loss:
    """
    YOLO v1 training loss, computed image by image and grid cell by grid cell.

    The network output of one image is ``[s * s, b * 5 + num_classes]``: per
    cell ``b`` boxes of ``(c_x, c_y, w, h, obj_conf)`` followed by the class
    scores. Box values are passed through a sigmoid before they are used,
    both for the IoU matching and for the loss terms.
    """

    def __init__(self, device, s, b, image_size, num_classes):
        """
        :param device: device on which the loss tensors are created
        :param s: number of grid cells per side
        :param b: number of boxes predicted per grid cell
        :param image_size: side length of the (square) network input in pixels
        :param num_classes: number of classes
        """
        self.device = device
        self.s = s
        self.b = b
        self.image_size = image_size
        self.num_classes = num_classes
        self.batch_size = 0

    def __call__(self, input, target):
        """
        :param input: network output, tensor [batch, s * s, b * 5 + num_classes]
        :param target: per-image label tensors with rows
                       ``xmin, ymin, xmax, ymax, class`` in pixels
        :return: batch means of (total, xy, wh, conf, class) loss

        No batched PyTorch loss is used, so every image of the batch is
        processed separately.
        """
        self.batch_size = input.size(0)

        # Convert the pixel labels to the per-cell representation.
        target = [self.label_direct2grid(target[i]) for i in range(self.batch_size)]

        # Match predictors and labels by IoU inside each grid cell.
        match = []
        conf = []
        for i in range(self.batch_size):
            m, c = self.match_pred_target(input[i], target[i])
            match.append(m)
            conf.append(c)

        loss = torch.zeros([self.batch_size], dtype=torch.float, device=self.device)
        xy_loss = torch.zeros_like(loss)
        wh_loss = torch.zeros_like(loss)
        conf_loss = torch.zeros_like(loss)
        class_loss = torch.zeros_like(loss)
        for i in range(self.batch_size):
            loss[i], xy_loss[i], wh_loss[i], conf_loss[i], class_loss[i] = \
                self.compute_loss(input[i], target[i], match[i], conf[i])
        return torch.mean(loss), torch.mean(xy_loss), torch.mean(wh_loss), torch.mean(conf_loss), torch.mean(class_loss)

    def label_direct2grid(self, label):
        """
        Convert pixel labels into the per-cell representation.

        :param label: tensor with rows ``xmin, ymin, xmax, ymax, class``
        :return: list of length ``s * s``; entry ``k`` is None when no box
                 centre falls into cell ``k``, otherwise a tensor with rows
                 ``c_x, c_y, w, h, class`` where c_x/c_y are relative to the
                 cell and w/h relative to the image
        """
        output = [None for _ in range(self.s ** 2)]
        size = self.image_size // self.s  # cell side length in pixels

        n_bbox = label.size(0)
        label_c = torch.zeros_like(label)

        # Centre and extent; abs() makes the extent independent of corner order.
        label_c[:, 0] = (label[:, 0] + label[:, 2]) / 2
        label_c[:, 1] = (label[:, 1] + label[:, 3]) / 2
        label_c[:, 2] = abs(label[:, 0] - label[:, 2])
        label_c[:, 3] = abs(label[:, 1] - label[:, 3])
        label_c[:, 4] = label[:, 4]

        # Grid column / row of every centre.
        idx_x = [int(label_c[i][0]) // size for i in range(n_bbox)]
        idx_y = [int(label_c[i][1]) // size for i in range(n_bbox)]

        # Centre relative to its cell, extent relative to the image.
        label_c[:, 0] = torch.div(torch.fmod(label_c[:, 0], size), size)
        label_c[:, 1] = torch.div(torch.fmod(label_c[:, 1], size), size)
        label_c[:, 2] = torch.div(label_c[:, 2], self.image_size)
        label_c[:, 3] = torch.div(label_c[:, 3], self.image_size)

        for i in range(n_bbox):
            idx = idx_y[i] * self.s + idx_x[i]  # row-major cell index
            if output[idx] is None:
                output[idx] = torch.unsqueeze(label_c[i], dim=0)
            else:
                output[idx] = torch.cat([output[idx], torch.unsqueeze(label_c[i], dim=0)], dim=0)
        return output

    def match_pred_target(self, input, target):
        """
        Greedily assign label boxes to the predictors of every grid cell.

        :param input: network output of one image, [s * s, b * 5 + num_classes]
        :param target: per-cell labels as returned by ``label_direct2grid``
        :return: ``(match, conf)``, two lists of length ``s * s``. For a cell
                 without labels both entries are None; otherwise ``match[k]``
                 holds, per predictor in order, the row index of the label it
                 was paired with and ``conf[k]`` the IoU of that pair. At most
                 ``min(b, number of labels)`` pairs are produced and no label
                 is used twice.
        """
        match = []
        conf = []
        with torch.no_grad():
            # Use the same sigmoid-activated boxes that compute_loss compares
            # against the labels, so the IoU refers to the actual prediction.
            input_bbox = torch.sigmoid(input[:, :self.b * 5].reshape(-1, self.b, 5))
            for cell in range(self.s ** 2):
                iou = match_get_iou(input_bbox[cell], target[cell], self.s, cell)
                if iou is None:
                    match.append(None)
                    conf.append(None)
                    continue

                # iou[p][t]: IoU of predictor p with label t of this cell.
                keep = np.ones([len(iou[0])], dtype=bool)
                m = []
                c = []
                for p in range(self.b):
                    if not keep.any():
                        break
                    # argmax runs on the still available labels only; map the
                    # position back to the label's real row index so that ties
                    # cannot select a label that was already taken.
                    candidates = np.flatnonzero(keep)
                    best = int(candidates[np.argmax(iou[p][keep])])
                    m.append(best)
                    c.append(float(iou[p][best]))
                    keep[best] = False
                match.append(m)
                conf.append(c)
        return match, conf

    def compute_loss(self, input, target, match, conf):
        """
        Loss of one image.

        :param input: network output of one image, [s * s, b * 5 + num_classes]
        :param target: per-cell labels from ``label_direct2grid``
        :param match: per-cell label indices from ``match_pred_target``
        :param conf: per-cell IoU values from ``match_pred_target``
        :return: sums over all cells of (total, xy, wh, conf, class) loss
        """
        ce_loss = nn.CrossEntropyLoss()

        input_bbox = input[:, :self.b * 5].reshape(-1, self.b, 5)
        input_class = input[:, self.b * 5:].reshape(-1, self.num_classes)

        # Squash the box values into (0, 1) before comparing them with labels.
        input_bbox = torch.sigmoid(input_bbox)
        loss = torch.zeros([self.s ** 2], dtype=torch.float, device=self.device)
        xy_loss = torch.zeros_like(loss)
        wh_loss = torch.zeros_like(loss)
        conf_loss = torch.zeros_like(loss)
        class_loss = torch.zeros_like(loss)
        # Every grid cell is handled on its own; the results are summed below.
        for i in range(self.s ** 2):
            # 0 xy_loss, 1 wh_loss, 2 conf_loss, 3 class_loss
            l = torch.zeros([4], dtype=torch.float, device=self.device)
            if target[i] is None:
                # Cell without object: only the confidence is penalised,
                # weighted with lambda_noobj = 0.5.
                obj_conf_target = torch.zeros([self.b], dtype=torch.float, device=self.device)
                l[2] = torch.sum(torch.mul(0.5, torch.pow(input_bbox[i, :, 4] - obj_conf_target, 2)))
            else:
                # Cell with object(s): coordinate terms weighted with lambda_coord = 5.
                # NOTE(review): match[i] may be shorter than b, in which case the
                # subtraction relies on broadcasting - confirm this is intended.
                l[0] = torch.mul(5, torch.sum(torch.pow(input_bbox[i, :, 0] - target[i][match[i], 0], 2) +
                                              torch.pow(input_bbox[i, :, 1] - target[i][match[i], 1], 2)))

                # Width/height are compared after a square root.
                l[1] = torch.mul(5, torch.sum(torch.pow(torch.sqrt(input_bbox[i, :, 2]) -
                                                        torch.sqrt(target[i][match[i], 2]), 2) +
                                              torch.pow(torch.sqrt(input_bbox[i, :, 3]) -
                                                        torch.sqrt(target[i][match[i], 3]), 2)))
                # The confidence target is the IoU of the matched pair.
                obj_conf_target = torch.tensor(conf[i], dtype=torch.float, device=self.device)
                l[2] = torch.sum(torch.pow(input_bbox[i, :, 4] - obj_conf_target, 2))

                # The cell's class scores are compared with the class of every
                # label located in the cell.
                l[3] = ce_loss(input_class[i].unsqueeze(dim=0).repeat(target[i].size(0), 1),
                               target[i][:, 4].long())
            loss[i] = torch.sum(l)
            xy_loss[i] = torch.sum(l[0])
            wh_loss[i] = torch.sum(l[1])
            conf_loss[i] = torch.sum(l[2])
            class_loss[i] = torch.sum(l[3])
        return torch.sum(loss), torch.sum(xy_loss), torch.sum(wh_loss), torch.sum(conf_loss), torch.sum(class_loss)


def cxcywh2xyxy(bbox):
    """
    Convert boxes from centre/size to corner form, in place.

    :param bbox: 2-D array/tensor with rows starting ``c_x, c_y, w, h``;
                 further columns are left untouched
    :return: the same object, rows now starting ``xmin, ymin, xmax, ymax``
    """
    # Top-left corner: centre minus half the extent.
    bbox[:, :2] = bbox[:, :2] - bbox[:, 2:4] / 2
    # Bottom-right corner: the new top-left corner plus the full extent.
    bbox[:, 2:4] = bbox[:, :2] + bbox[:, 2:4]
    return bbox


def match_get_iou(bbox1, bbox2, s, idx):
    """
    IoU between two sets of boxes that belong to the same grid cell.

    :param bbox1: tensor with rows ``c_x, c_y, w, h, ...`` where c_x/c_y are
                  relative to the cell and w/h relative to the image
    :param bbox2: tensor in the same format, or None
    :param s: number of grid cells per side
    :param idx: row-major index of the grid cell, expected in ``[0, s * s)``
    :return: None when either input is None, otherwise the array from
             ``get_iou`` with one row per box of ``bbox1``
    """
    if bbox1 is None or bbox2 is None:
        return None

    # Top-left corner of the cell as a fraction of the image.
    row, col = divmod(idx, s)
    offset_x, offset_y = col / s, row / s

    corners = []
    for boxes in (bbox1, bbox2):
        arr = np.array(boxes.cpu())
        # Centre relative to the cell -> relative to the whole image.
        arr[:, 0] = arr[:, 0] / s
        arr[:, 1] = arr[:, 1] / s
        # Add the position of the cell itself.
        arr[:, 0] = arr[:, 0] + offset_x
        arr[:, 1] = arr[:, 1] + offset_y
        corners.append(cxcywh2xyxy(arr))

    return get_iou(corners[0], corners[1])


def get_iou(bbox1, bbox2):
    """
    Pairwise intersection over union.

    :param bbox1: array with rows starting ``xmin, ymin, xmax, ymax``
    :param bbox2: array in the same format
    :return: array in which entry ``[i, j]`` is the IoU of ``bbox1[i]`` and ``bbox2[j]``
    """
    area1 = abs(bbox1[:, 2] - bbox1[:, 0]) * abs(bbox1[:, 3] - bbox1[:, 1])
    area2 = abs(bbox2[:, 2] - bbox2[:, 0]) * abs(bbox2[:, 3] - bbox2[:, 1])

    def one_against_all(box, area):
        # Overlap rectangle of `box` with every box of bbox2; a negative
        # extent means no overlap and is clipped to zero.
        left = np.maximum(box[0], bbox2[:, 0])
        top = np.maximum(box[1], bbox2[:, 1])
        right = np.minimum(box[2], bbox2[:, 2])
        bottom = np.minimum(box[3], bbox2[:, 3])
        inter = np.maximum(right - left, 0) * np.maximum(bottom - top, 0)
        return inter / (area + area2 - inter)

    return np.array([one_against_all(box, area) for box, area in zip(bbox1, area1)])


def nms(bbox, conf_th, iou_th):
    """
    Non-maximum suppression on the boxes of one image.

    :param bbox: tensor with rows ``xmin, ymin, xmax, ymax`` followed by two
                 score columns (4 and 5) that are multiplied into the final score
    :param conf_th: boxes whose final score is not above this value are dropped
    :param iou_th: a box is dropped when its IoU with a kept, higher-scoring
                   box exceeds this value
    :return: numpy array with the kept rows, best score first; column 4 holds
             the final score
    """
    boxes = np.array(bbox.cpu())

    # Final score = product of the two score columns.
    boxes[:, 4] = boxes[:, 4] * boxes[:, 5]
    boxes = boxes[boxes[:, 4] > conf_th]

    # Indices sorted by descending score.
    remaining = np.argsort(-boxes[:, 4])

    selected = []
    while remaining.size > 0:
        best, rest = remaining[0], remaining[1:]
        selected.append(best)
        # Keep only the candidates that do not overlap the chosen box too much.
        overlap = get_iou(np.array([boxes[best]]), boxes[rest])[0]
        remaining = rest[overlap <= iou_th]
    return boxes[selected]


# yolo后处理
def output_process(output, image_size, s, b, conf_th, iou_th):
    """
    Post-process a batch of network outputs into final detections.

    :param output: network output, tensor [batch, s * s, b * 5 + num_classes]
    :param image_size: side length of the network input in pixels
    :param s: number of grid cells per side
    :param b: number of boxes predicted per grid cell
    :param conf_th: score threshold forwarded to ``nms``
    :param iou_th: IoU threshold forwarded to ``nms``
    :return: stacked tensor with rows
             ``xmin, ymin, xmax, ymax, score, class_conf, class`` in pixels

    NOTE(review): the per-image results are combined with ``torch.stack``,
    which needs the same number of kept boxes for every image - confirm the
    caller only uses this with a batch size of 1.
    """
    batch_size = output.size(0)
    size = image_size // s

    output = torch.sigmoid(output)

    # Best class and its score per cell, copied to each of the b boxes of the
    # cell (the repeat count follows b instead of being fixed to 2).
    classes_conf, classes = torch.max(output[:, :, b * 5:], dim=2)
    classes = classes.unsqueeze(dim=2).repeat(1, 1, b).unsqueeze(dim=3)
    classes_conf = classes_conf.unsqueeze(dim=2).repeat(1, 1, b).unsqueeze(dim=3)
    bbox = output[:, :, :b * 5].reshape(batch_size, -1, b, 5)

    # [batch, s * s, b, 7]; the class index is stored in the box dtype.
    bbox = torch.cat([bbox, classes_conf, classes.to(bbox.dtype)], dim=3)

    # Relative values -> pixels.
    bbox[:, :, :, [0, 1]] = bbox[:, :, :, [0, 1]] * size
    bbox[:, :, :, [2, 3]] = bbox[:, :, :, [2, 3]] * image_size

    # Add the top-left corner of each grid cell (row-major cell order).
    grid_pos = [(j * image_size // s, i * image_size // s) for i in range(s) for j in range(s)]
    offsets = torch.tensor(grid_pos, dtype=bbox.dtype, device=bbox.device).view(1, s * s, 1, 2)
    bbox[:, :, :, [0, 1]] = bbox[:, :, :, [0, 1]] + offsets

    bbox_direct = bbox.reshape(batch_size, -1, 7)

    # cxcywh to xyxy
    bbox_direct[:, :, 0] = bbox_direct[:, :, 0] - bbox_direct[:, :, 2] / 2
    bbox_direct[:, :, 1] = bbox_direct[:, :, 1] - bbox_direct[:, :, 3] / 2
    bbox_direct[:, :, 2] = bbox_direct[:, :, 0] + bbox_direct[:, :, 2]
    bbox_direct[:, :, 3] = bbox_direct[:, :, 1] + bbox_direct[:, :, 3]

    # Clip to the image. clamp() works on the tensor's own device, so no
    # helper tensors on a possibly different device are needed.
    bbox_direct[:, :, 0:2] = bbox_direct[:, :, 0:2].clamp(min=0)
    bbox_direct[:, :, 2:4] = bbox_direct[:, :, 2:4].clamp(max=image_size)

    # NMS per image, then combine the images again.
    kept = [torch.tensor(nms(boxes.detach(), conf_th, iou_th)) for boxes in bbox_direct]
    return torch.stack(kept)


if __name__ == "__main__":
    import torch

    # --- Test yolo: forward one random image ---
    x = torch.randn([1, 3, 448, 448])

    # Output values per cell: B * 5 + n_classes
    net = yolo(7, 2 * 5 + 20, 'resnet', pretrain=None)
    # net = yolo(7, 2 * 5 + 20, 'darknet', pretrain=None)
    print(net)
    out = net(x)
    print(out)
    print(out.size())

    # --- Test yolo_loss with a tiny setup: s=2, 2 classes ---
    s = 2
    b = 2
    image_size = 448  # h, w
    cell_outputs = [
        [0.45, 0.24, 0.22, 0.3, 0.35, 0.54, 0.66, 0.7, 0.8, 0.8, 0.17, 0.9],
        [0.37, 0.25, 0.5, 0.3, 0.36, 0.14, 0.27, 0.26, 0.33, 0.36, 0.13, 0.9],
        [0.12, 0.8, 0.26, 0.74, 0.8, 0.13, 0.83, 0.6, 0.75, 0.87, 0.75, 0.24],
        [0.1, 0.27, 0.24, 0.37, 0.34, 0.15, 0.26, 0.27, 0.37, 0.34, 0.16, 0.93],
    ]
    input = torch.tensor([cell_outputs])
    boxes = [
        [200, 200, 353, 300, 1],
        [220, 230, 353, 300, 1],
        [15, 330, 200, 400, 0],
        [100, 50, 198, 223, 1],
        [30, 60, 150, 240, 1],
    ]
    target = [torch.tensor(boxes, dtype=torch.float)]

    criterion = yolo_loss('cpu', 2, 2, image_size, 2)
    loss = criterion(input, target)
    print(loss)

scheduler.py

论文作者也给出了他的训练参数。我们首先复现学习率的调整方式。原文分成三段学习率,每一段都有不同的保持epochs长度。此外,我额外增加了热身训练阶段。

python 复制代码
# 路径: ./scheduler.py
from torch.optim.lr_scheduler import _LRScheduler


class Scheduler(_LRScheduler):
    """Learning-rate schedule: linear warm-up followed by three constant stages.

    Epochs are counted with ``self.last_epoch``:

    * ``[0, step_warm_ep)``: linear ramp from ``lr_start`` towards ``step_1_lr``
    * the next ``step_1_ep`` epochs: ``step_1_lr``
    * the next ``step_2_ep`` epochs: ``step_2_lr``
    * every later epoch: ``step_3_lr``

    ``step_3_ep`` is stored for completeness but does not change the result,
    because the last rate is simply held once the third stage is reached.
    Every parameter group of the optimizer receives the same rate.
    """

    def __init__(self, optimizer, step_warm_ep, lr_start, step_1_lr, step_1_ep,
                 step_2_lr, step_2_ep, step_3_lr, step_3_ep, last_epoch=-1):
        # The schedule attributes are assigned before the base constructor is
        # called so that they already exist if it queries get_lr().
        self.optimizer = optimizer
        self.lr_start = lr_start
        self.step_warm_ep = step_warm_ep
        self.step_1_lr = step_1_lr
        self.step_1_ep = step_1_ep
        self.step_2_lr = step_2_lr
        self.step_2_ep = step_2_ep
        self.step_3_lr = step_3_lr
        self.step_3_ep = step_3_ep
        self.last_epoch = last_epoch

        super(Scheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the scheduled rate for ``self.last_epoch``, one entry per param group."""
        return self._lrs_for_current_epoch()

    def _get_closed_form_lr(self):
        """Closed-form variant used by the base class when ``step(epoch)`` is called.

        The schedule depends only on ``self.last_epoch``, so this is the same
        value as ``get_lr()``. Returning ``base_lrs`` here would reset the rate
        to its initial value whenever an explicit epoch is passed to ``step``.
        """
        return self._lrs_for_current_epoch()

    def _lrs_for_current_epoch(self):
        """Shared implementation behind ``get_lr`` and ``_get_closed_form_lr``."""
        if self.last_epoch == 0:
            lr = self.lr_start
        else:
            lr = self._compute_lr_from_epoch()
        return [lr for _ in self.optimizer.param_groups]

    def _compute_lr_from_epoch(self):
        """Map ``self.last_epoch`` to the rate of the stage it falls into."""
        warm_end = self.step_warm_ep
        step_1_end = warm_end + self.step_1_ep
        step_2_end = step_1_end + self.step_2_ep

        if self.last_epoch < warm_end:
            # Linear interpolation between lr_start and step_1_lr.
            lr = ((self.step_1_lr - self.lr_start)/self.step_warm_ep) * self.last_epoch + self.lr_start
        elif self.last_epoch < step_1_end:
            lr = self.step_1_lr
        elif self.last_epoch < step_2_end:
            lr = self.step_2_lr
        else:
            # Third stage and everything after it share the same rate.
            lr = self.step_3_lr
        return lr


if __name__ == '__main__':
    import torch.nn as nn
    import torch.optim as optim

    import numpy as np
    import matplotlib.pyplot as plt

    import warnings
    warnings.filterwarnings('ignore')

    # Demo: record the rate the scheduler sets for each epoch and plot the curve.
    batch_size = 16
    epoch = 135
    scheduler_params = dict(
        lr_start=1e-3,
        step_warm_ep=10,
        step_1_lr=1e-2,
        step_1_ep=75,
        step_2_lr=1e-3,
        step_2_ep=30,
        step_3_lr=1e-4,
        step_3_ep=20,
    )

    # A throwaway two-layer model; it only supplies parameters for the optimizer.
    layers = [nn.Linear(1, 10), nn.Linear(10, 1)]
    model = nn.Sequential(*layers)
    optimizer = optim.SGD(model.parameters(), lr=scheduler_params['lr_start'])
    scheduler = Scheduler(optimizer, **scheduler_params)

    lrs = []
    first_group = optimizer.param_groups[0]
    for _epoch_idx in range(epoch):
        # Read the rate in effect for this epoch, then advance the schedule.
        lrs.append(first_group['lr'])
        scheduler.step()
    print(lrs)

    lrs = np.array(lrs)

    # Visualise the learning-rate curve with matplotlib.
    plt.figure()
    plt.plot(lrs)
    plt.show()

train.py

训练参数主要参考了论文原文。由于batch size较小,学习率都等比减小。优化器选择了SGD并带有momentum与weight_decay。代码中有冻结Backbone的选项。

root0712 存储了VOC07/12两年数据集的根目录

model_root 存储模型路径

backbone 'resnet'则代表使用修改后的resnet;'darknet'则代表使用原文中的darknet

pretrain None则不使用预训练模型;或直接输入预训练权重地址

with_amp 混合精度选项

transforms 需要使用代码中定义的transforms而不是PyTorch直接给出的transforms

start_epoch 中断训练重启的开始epoch

epoch 总epoch数

freeze_backbone_till 若为-1则不冻结backbone,其他数则会冻结backbone直到该epoch后解冻

python 复制代码
# 路径: ./train.py
from dataset.data import VOC0712Dataset
from dataset.transform import *
from model.yolo import yolo, yolo_loss
from scheduler import Scheduler

import torch
from torch.utils.data import DataLoader
from torch import optim
from torch.cuda.amp import autocast, GradScaler

from tqdm import tqdm
import pandas as pd
import json
import os

import warnings


class CFG:
    """Training configuration; train() reads these values as class attributes."""
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    # Root directories of the VOC2007 and VOC2012 datasets.
    root0712 = [r'F:\AI\Dataset\VOC2007\VOCdevkit\VOC2007', r'F:\AI\Dataset\VOC2012\VOCdevkit\VOC2012']
    class_path = r'./dataset/classes.json'
    # Directory that train() writes 'yolo.pth' and 'result.csv' into.
    model_root = r'./log/ex7'
    # If model_path is a str path to a weights file, those weights are loaded into the model.
    model_path = None
    # NOTE: the author did not write directory-creation code alongside this config;
    # create model_root by hand unless the training code creates it for you.

    backbone = 'resnet'
    pretrain = 'model/resnet50-19c8e357.pth'
    # Mixed precision is optional; when False, regular precision is used.
    with_amp = True
    S = 7
    B = 2
    image_size = 448

    transforms = Compose([
        ToTensor(),
        RandomHorizontalFlip(0.5),
        Resize(448, keep_ratio=False)
    ])

    start_epoch = 0
    epoch = 135
    batch_size = 16
    num_workers = 2

    # freeze_backbone_till = -1 means the backbone is never frozen.
    freeze_backbone_till = 30

    # Keyword arguments forwarded to Scheduler(optimizer, **scheduler_params).
    scheduler_params = {
        'lr_start': 1e-3 / 4,
        'step_warm_ep': 10,
        'step_1_lr': 1e-2 / 4,
        'step_1_ep': 75,
        'step_2_lr': 1e-3 / 4,
        'step_2_ep': 40,
        'step_3_lr': 1e-4 / 4,
        'step_3_ep': 10
    }

    # SGD hyper-parameters.
    momentum = 0.9
    weight_decay = 0.0005


def collate_fn(batch):
    """Transpose a list of per-sample tuples into one tuple per field.

    For example ``[(img0, lbl0), (img1, lbl1)]`` becomes
    ``((img0, img1), (lbl0, lbl1))``; nothing is stacked or padded.
    """
    per_field = zip(*batch)
    return tuple(per_field)


class AverageMeter:
    """Keeps the most recent value and the weighted running mean of a scalar."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` as seen ``n`` times and refresh the running mean."""
        weighted = val * n
        self.count += n
        self.sum += weighted
        self.val = val
        self.avg = self.sum / self.count


def train():
    """Train the YOLO network with the settings in ``CFG``.

    Each epoch runs one pass over the 'train' split, then measures the loss on
    the 'test' split without gradients. The weights are written to
    ``<CFG.model_root>/yolo.pth`` whenever the average training loss improves,
    and ``<CFG.model_root>/result.csv`` (train loss, test loss, learning rate
    per epoch) is rewritten after every epoch.
    """
    device = torch.device(CFG.device)
    print('Train:\nDevice:{}'.format(device))

    with open(CFG.class_path, 'r') as f:
        classes = json.load(f)
    CFG.num_classes = len(classes)

    # Checkpoints and the CSV log go here; create it up front so the first
    # save cannot fail after a whole epoch of training.
    os.makedirs(CFG.model_root, exist_ok=True)

    # NOTE(review): the 'test' split reuses CFG.transforms, which includes
    # RandomHorizontalFlip -- confirm that random augmentation is intended there.
    train_ds = VOC0712Dataset(CFG.root0712, CFG.class_path, CFG.transforms, 'train')
    test_ds = VOC0712Dataset(CFG.root0712, CFG.class_path, CFG.transforms, 'test')

    train_dl = DataLoader(train_ds, batch_size=CFG.batch_size, shuffle=True,
                          num_workers=CFG.num_workers, collate_fn=collate_fn)
    test_dl = DataLoader(test_ds, batch_size=CFG.batch_size, shuffle=False,
                         num_workers=CFG.num_workers, collate_fn=collate_fn)

    yolo_net = yolo(s=CFG.S, cell_out_ch=CFG.B * 5 + CFG.num_classes, backbone_name=CFG.backbone, pretrain=CFG.pretrain)
    yolo_net.to(device)

    if CFG.model_path is not None:
        # map_location lets a checkpoint saved on another device load here.
        yolo_net.load_state_dict(torch.load(CFG.model_path, map_location=device))

    # Tracked locally so that the configuration itself is not overwritten.
    backbone_frozen = CFG.freeze_backbone_till != -1
    if backbone_frozen:
        print('Freeze Backbone')
        for param in yolo_net.backbone.parameters():
            param.requires_grad_(False)

    # Give the optimizer ALL parameters, including the frozen backbone. Frozen
    # parameters have no gradient and are skipped by the optimizer step, but
    # they must already be registered so that they are updated once unfrozen.
    # Registering only the currently trainable ones would leave the backbone
    # untrained forever.
    optimizer = optim.SGD(yolo_net.parameters(), lr=CFG.scheduler_params['lr_start'],
                          momentum=CFG.momentum, weight_decay=CFG.weight_decay)
    criterion = yolo_loss(CFG.device, CFG.S, CFG.B, CFG.image_size, len(train_ds.classes))
    scheduler = Scheduler(optimizer, **CFG.scheduler_params)
    scaler = GradScaler()

    # When resuming, fast-forward the schedule to the starting epoch.
    for _ in range(CFG.start_epoch):
        scheduler.step()

    best_train_loss = 1e+9
    train_losses = []
    test_losses = []
    lrs = []
    checkpoint_path = os.path.join(CFG.model_root, 'yolo.pth')
    for epoch in range(CFG.start_epoch, CFG.epoch):
        if backbone_frozen and epoch >= CFG.freeze_backbone_till:
            print('Unfreeze Backbone')
            for param in yolo_net.backbone.parameters():
                param.requires_grad_(True)
            backbone_frozen = False

        # Train
        yolo_net.train()
        loss_score = AverageMeter()
        dl = tqdm(train_dl, total=len(train_dl))
        for images, labels in dl:
            batch_size = len(labels)

            images = torch.stack(images)
            images = images.to(device)
            labels = [label.to(device) for label in labels]

            optimizer.zero_grad()

            if CFG.with_amp:
                with autocast():
                    outputs = yolo_net(images)
                    loss, xy_loss, wh_loss, conf_loss, class_loss = criterion(outputs, labels)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = yolo_net(images)
                loss, xy_loss, wh_loss, conf_loss, class_loss = criterion(outputs, labels)

                loss.backward()
                optimizer.step()

            loss_value = loss.detach().item()
            loss_score.update(loss_value, batch_size)
            dl.set_postfix(Mode='Train', AvgLoss=loss_score.avg, Loss=loss_value,
                           Epoch=epoch, LR=optimizer.param_groups[0]['lr'])
        # Log the rate that was used for this epoch before moving the schedule on.
        lrs.append(optimizer.param_groups[0]['lr'])
        scheduler.step()

        train_losses.append(loss_score.avg)
        print('Train Loss: {:.4f}'.format(loss_score.avg))

        if best_train_loss > loss_score.avg:
            print('Save yolo_net to {}'.format(checkpoint_path))
            torch.save(yolo_net.state_dict(), checkpoint_path)
            best_train_loss = loss_score.avg

        loss_score.reset()
        with torch.no_grad():
            # Test
            yolo_net.eval()
            dl = tqdm(test_dl, total=len(test_dl))
            for images, labels in dl:
                batch_size = len(labels)

                images = torch.stack(images)
                images = images.to(device)
                labels = [label.to(device) for label in labels]

                outputs = yolo_net(images)
                loss, xy_loss, wh_loss, conf_loss, class_loss = criterion(outputs, labels)

                loss_value = loss.detach().item()
                loss_score.update(loss_value, batch_size)
                dl.set_postfix(Mode='Test', AvgLoss=loss_score.avg, Loss=loss_value, Epoch=epoch)
        test_losses.append(loss_score.avg)
        print('Test Loss: {:.4f}'.format(loss_score.avg))

        df = pd.DataFrame({'Train Loss': train_losses, 'Test Loss': test_losses, 'LR': lrs})
        df.to_csv(os.path.join(CFG.model_root, 'result.csv'), index=True)


if __name__ == '__main__':
    # Script entry point: silence all Python warnings, then start training.
    warnings.filterwarnings('ignore')
    train()

voc_eval.py

测试时计算mAP的代码

python 复制代码
# 路径: ./voc_eval.py
"""Adapted from:
    @longcw faster_rcnn_pytorch: https://github.com/longcw/faster_rcnn_pytorch
    @rbgirshick py-faster-rcnn https://github.com/rbgirshick/py-faster-rcnn
    Licensed under The MIT License [see LICENSE for details]
    这个文件的代码有改动,我删除了缓存的代码,因此也不需要传入缓存文件夹
"""

import xml.etree.ElementTree as ET
import numpy as np
import os


def voc_ap(rec, prec, use_07_metric=False):
    """Compute VOC average precision from recall and precision arrays.

    rec, prec: arrays of equal length holding the cumulative recall and
        precision after each detection.
    use_07_metric: when true, use the VOC07 11-point interpolation; otherwise
        (default) integrate the area under the precision envelope.
    """
    if not use_07_metric:
        # Pad both curves with sentinel values at either end.
        recall = np.concatenate(([0.], rec, [1.]))
        envelope = np.concatenate(([0.], prec, [0.]))

        # Make precision non-increasing by sweeping a running maximum from the right.
        for k in reversed(range(1, envelope.size)):
            envelope[k - 1] = np.maximum(envelope[k - 1], envelope[k])

        # Only positions where recall changes contribute to the area.
        steps = np.where(recall[1:] != recall[:-1])[0]
        widths = recall[steps + 1] - recall[steps]
        return np.sum(widths * envelope[steps + 1])

    # 11-point metric: average the best precision reachable at each recall level.
    ap = 0.
    for level in np.arange(0., 1.1, 0.1):
        reached = rec >= level
        best = np.max(prec[reached]) if np.sum(reached) != 0 else 0
        ap = ap + best / 11.
    return ap


def parse_rec(filename):
    """Parse a PASCAL VOC xml file into a list of per-object dicts.

    Each dict has the keys 'name', 'pose', 'truncated', 'difficult' and 'bbox',
    where 'bbox' is ``[xmin, ymin, xmax, ymax]`` as ints.
    """
    corner_tags = ('xmin', 'ymin', 'xmax', 'ymax')

    def describe(node):
        # Text fields are kept as strings, flags and coordinates become ints.
        record = {
            'name': node.find('name').text,
            'pose': node.find('pose').text,
            'truncated': int(node.find('truncated').text),
            'difficult': int(node.find('difficult').text),
        }
        box = node.find('bndbox')
        record['bbox'] = [int(box.find(tag).text) for tag in corner_tags]
        return record

    annotation = ET.parse(filename)
    return [describe(node) for node in annotation.findall('object')]


def voc_eval(detpath,
             annopath,
             imagesetfile,
             classname,
             ovthresh=0.5,
             use_07_metric=False):
    """rec, prec, ap = voc_eval(detpath,
                                annopath,
                                imagesetfile,
                                classname,
                                [ovthresh],
                                [use_07_metric])
    Top level function that does the PASCAL VOC evaluation for one class.
    detpath: Path to detections
        detpath.format(classname) should produce the detection results file.
        Each line is "<image id> <confidence> <xmin> <ymin> <xmax> <ymax>",
        separated by single spaces.
    annopath: Path to annotations
        annopath.format(imagename) should be the xml annotations file.
    imagesetfile: Text file containing the list of images, one image per line.
    classname: Category name
    [ovthresh]: Overlap threshold (default = 0.5)
    [use_07_metric]: Whether to use VOC07's 11 point AP computation
        (default False)

    Returns the cumulative recall and precision arrays (one entry per
    detection, in order of decreasing confidence) and the average precision.
    With an empty detections file both arrays are empty and ap is 0.
    """
    # read list of images, ignoring blank lines
    with open(imagesetfile, 'r') as f:
        imagenames = [line.strip() for line in f if line.strip()]

    # load annots
    recs = {imagename: parse_rec(annopath.format(imagename)) for imagename in imagenames}

    # extract gt objects for this class
    class_recs = {}
    npos = 0
    for imagename in imagenames:
        R = [obj for obj in recs[imagename] if obj['name'] == classname]
        bbox = np.array([x['bbox'] for x in R])
        # np.bool was removed from NumPy; the builtin bool is the supported dtype.
        difficult = np.array([x['difficult'] for x in R]).astype(bool)
        det = [False] * len(R)
        # difficult objects do not count as positives
        npos = npos + int(np.sum(~difficult))
        class_recs[imagename] = {'bbox': bbox,
                                 'difficult': difficult,
                                 'det': det}

    # read dets, ignoring blank lines
    detfile = detpath.format(classname)
    with open(detfile, 'r') as f:
        splitlines = [line.strip().split(' ') for line in f if line.strip()]

    image_ids = [x[0] for x in splitlines]
    confidence = np.array([float(x[1]) for x in splitlines])
    BB = np.array([[float(z) for z in x[2:]] for x in splitlines])

    # sort by decreasing confidence
    sorted_ind = np.argsort(-confidence)
    if splitlines:
        # With no detections BB is one-dimensional and cannot be row-indexed.
        BB = BB[sorted_ind, :]
    image_ids = [image_ids[x] for x in sorted_ind]

    # go down dets and mark TPs and FPs
    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    for d in range(nd):
        R = class_recs[image_ids[d]]
        bb = BB[d, :].astype(float)
        ovmax = -np.inf
        BBGT = R['bbox'].astype(float)

        if BBGT.size > 0:
            # compute overlaps
            # intersection
            ixmin = np.maximum(BBGT[:, 0], bb[0])
            iymin = np.maximum(BBGT[:, 1], bb[1])
            ixmax = np.minimum(BBGT[:, 2], bb[2])
            iymax = np.minimum(BBGT[:, 3], bb[3])
            iw = np.maximum(ixmax - ixmin + 1., 0.)
            ih = np.maximum(iymax - iymin + 1., 0.)
            inters = iw * ih

            # union
            uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
                   (BBGT[:, 2] - BBGT[:, 0] + 1.) *
                   (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)

            overlaps = inters / uni
            ovmax = np.max(overlaps)
            jmax = np.argmax(overlaps)
        if ovmax > ovthresh:
            # matches on difficult objects are neither TP nor FP
            if not R['difficult'][jmax]:
                if not R['det'][jmax]:
                    tp[d] = 1.
                    R['det'][jmax] = 1
                else:
                    # this ground truth was already claimed by a better detection
                    fp[d] = 1.
        else:
            fp[d] = 1.

    # compute precision recall
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    rec = tp / float(npos)
    # avoid divide by zero in case the first detection matches a difficult
    # ground truth
    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    ap = voc_ap(rec, prec, use_07_metric)

    return rec, prec, ap

test.py

python 复制代码
# 路径: ./test.py
from dataset.data import VOC0712Dataset, Compose, ToTensor, Resize
from dataset.draw_bbox import draw
from model.yolo import yolo, output_process
from voc_eval import voc_eval

import torch

from tqdm import tqdm
import json


class CFG:
    """Evaluation configuration; test() reads these values as class attributes."""
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    root = r'F:\AI\Dataset\VOC2007\VOCdevkit\VOC2007'
    class_path = r'dataset/classes.json'
    model_path = r'log/ex7/yolo.pth'
    # The det folder must also be created by hand; it stores one detection file per class.
    detpath = r'det\{}.txt'
    annopath = r'F:\AI\Dataset\VOC2007\VOCdevkit\VOC2007\Annotations\{}.xml'
    imagesetfile = r'F:\AI\Dataset\VOC2007\VOCdevkit\VOC2007\ImageSets\Main\test.txt'
    # Filled in by test() with the class names read from class_path.
    classname = None

    # Passed to the dataset as data_range.
    test_range = None
    # When True, test() draws the predictions and the labels for every image.
    show_image = False
    # When True, test() writes the detection files and prints AP / mAP.
    get_ap = True

    backbone = 'resnet'
    S = 7
    B = 2
    image_size = 448

    # Passed to the dataset as get_info.
    get_info = True
    transforms = Compose([
        ToTensor(),
        Resize(448, keep_ratio=False)
    ])

    # Filled in by test() with the number of classes.
    num_classes = 0

    # Thresholds handed to output_process.
    conf_th = 0.2
    iou_th = 0.5


def test():
    """Run the trained network over the test split and report AP per class and mAP.

    Every image is passed through the network and ``output_process``; the
    resulting boxes are rescaled with the ``input_size`` supplied by the
    dataset and grouped by predicted class. When ``CFG.get_ap`` is set, one
    detection file per class is written to ``CFG.detpath`` and scored with
    ``voc_eval``.
    """
    device = torch.device(CFG.device)
    print('Test:\nDevice:{}'.format(device))

    dataset = VOC0712Dataset(CFG.root, CFG.class_path, CFG.transforms, 'test',
                             data_range=CFG.test_range, get_info=CFG.get_info)

    with open(CFG.class_path, 'r') as f:
        classes = json.load(f)
    CFG.classname = list(classes.keys())
    CFG.num_classes = len(CFG.classname)

    yolo_net = yolo(s=CFG.S, cell_out_ch=CFG.B * 5 + CFG.num_classes, backbone_name=CFG.backbone)
    yolo_net.to(device)

    # map_location lets a checkpoint saved on another device load here.
    yolo_net.load_state_dict(torch.load(CFG.model_path, map_location=device))
    # Inference must run in eval mode; a freshly built module is in training
    # mode, which changes how layers such as batch norm and dropout behave.
    yolo_net.eval()

    bboxes = []
    with torch.no_grad():
        for image, label, image_name, input_size in tqdm(dataset):
            image = image.unsqueeze(dim=0)
            image = image.to(device)

            output = yolo_net(image)
            output = output_process(output.cpu(), CFG.image_size, CFG.S, CFG.B, CFG.conf_th, CFG.iou_th)

            if CFG.show_image:
                draw(image.squeeze(dim=0), output.squeeze(dim=0), classes, show_conf=True)
                draw(image.squeeze(dim=0), label, classes, show_conf=True)

            # Scale boxes from the network input size back to the source image;
            # input_size is presumably (width, height) -- TODO confirm.
            output[:, :, [0, 2]] = output[:, :, [0, 2]] * input_size[0] / CFG.image_size
            output[:, :, [1, 3]] = output[:, :, [1, 3]] * input_size[1] / CFG.image_size

            # Each row becomes [image_name, score, x1, y1, x2, y2, class_index].
            # The score multiplies the two values before the class index
            # (presumably box confidence and class probability).
            boxes = output.squeeze(dim=0).numpy().tolist()
            bboxes += [[image_name, box[-3] * box[-2]] + box[:4] + [int(box[-1])]
                       for box in boxes]

    # Group detections by class index and drop the index from each row.
    det_list = [[] for _ in range(CFG.num_classes)]
    for b in bboxes:
        det_list[b[-1]].append(b[:-1])

    if CFG.get_ap:
        # Named mean_ap so the builtin `map` is not shadowed.
        mean_ap = 0
        for class_name, dets in zip(CFG.classname, det_list):
            file_path = CFG.detpath.format(class_name)
            txt = '\n'.join([' '.join([str(i) for i in item]) for item in dets])
            with open(file_path, 'w') as f:
                f.write(txt)

            rec, prec, ap = voc_eval(CFG.detpath, CFG.annopath, CFG.imagesetfile, class_name)
            print(rec)
            print(prec)
            mean_ap += ap
            print(ap)

        mean_ap /= CFG.num_classes
        print('mAP', mean_ap)


if __name__ == '__main__':
    # Script entry point: run the evaluation configured in CFG.
    test()
相关推荐
q567315232 分钟前
在 Bash 中获取 Python 模块变量列
开发语言·python·bash
是萝卜干呀3 分钟前
Backend - Python 爬取网页数据并保存在Excel文件中
python·excel·table·xlwt·爬取网页数据
代码欢乐豆4 分钟前
数据采集之selenium模拟登录
python·selenium·测试工具
喵~来学编程啦11 分钟前
【论文精读】LPT: Long-tailed prompt tuning for image classification
人工智能·深度学习·机器学习·计算机视觉·论文笔记
深圳市青牛科技实业有限公司24 分钟前
【青牛科技】应用方案|D2587A高压大电流DC-DC
人工智能·科技·单片机·嵌入式硬件·机器人·安防监控
狂奔solar39 分钟前
yelp数据集上识别潜在的热门商家
开发语言·python
Tassel_YUE40 分钟前
网络自动化04:python实现ACL匹配信息(主机与主机信息)
网络·python·自动化
水豚AI课代表1 小时前
分析报告、调研报告、工作方案等的提示词
大数据·人工智能·学习·chatgpt·aigc
几两春秋梦_1 小时前
符号回归概念
人工智能·数据挖掘·回归
聪明的墨菲特i1 小时前
Python爬虫学习
爬虫·python·学习