【PyTorch】图像目标检测

图像目标检测是什么

Object Detection

判断图像中目标位置

目标检测两要素

  1. 分类:分类向量 p0, ..., pn
  2. 回归:回归边界框 x1, y1, x2, y2

模型如何完成目标检测

将3D张量映射到两个张量

  1. 分类张量:shape为 N, c+1
  2. 边界框张量:shape为 N, 4

Recent Advances in Deep Learning for Object Detection

边界框数量N如何确定?

传统方法——滑动窗策略

缺点:

  1. 重复计算量大
  2. 窗口大小难确定

利用卷积减少重复计算

重要概念:特征图一个像素 对应原图一块区域

目标检测模型简介

按流程分为:one-stage和two-stage

Faster RCNN

经典two stage检测网络

Faster RCNN 数据流

  1. Feature map:256, h_f, w_f
  2. 2 Softmax:num_anchors, h_f, w_f
  3. Regressors:num_anchors*4, h_f, w_f
  4. NMS OUT:n_proposals=2000, 4
  5. ROI Layer:512, 256, 7, 7
  6. FC1 FC2:512, 1024
  7. c+1 Softmax: 512, c+1
  8. Regressors:512, (c+1)*4

Faster RCNN 主要组件

  1. backbone
  2. rpn
  3. filter_proposals(NMS)
  4. roi_heads

Faster RCNN 行人检测

数据: PennFudanPed数据集, 170张行人照片共345个行人标签

官方地址: http://www.cis.upenn.edu/~jshi/ped_html/

模型: fasterrcnn_resnet50_fpn 进行finetune

目标检测推荐github:https://github.com/amusi/awesome-object-detection

代码如下:

python 复制代码
import os
import time
import torch.nn as nn
import torch
import random
import numpy as np
import torchvision.transforms as transforms
import torchvision
from PIL import Image
import torch.nn.functional as F
from my_dataset import PennFudanDataset
from common_tools import set_seed
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms import functional as F
import enviroments

set_seed(1)  # set the random seed (project helper from common_tools)

# Directory that contains this script; the visualisation image folder is
# resolved relative to it in the __main__ section.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Class names looked up by integer label id (passed to vis_bbox as `classes`).
# NOTE(review): presumably the COCO category table used by torchvision's
# pretrained detectors, with 'N/A' filling unused ids -- confirm.
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]


def vis_bbox(img, output, classes, max_vis=40, prob_thres=0.4):
    """Draw detections from ``output`` on top of ``img`` and show the figure.

    Args:
        img: Image accepted by ``Axes.imshow`` (the caller passes a PIL image).
        output: Mapping with "boxes", "scores" and "labels" tensors. Each box
            row is read as ``x1, y1, x2, y2``.
        classes: Sequence of class names indexed by integer label id.
        max_vis: Only the first ``max_vis`` detections are considered.
        prob_thres: Detections whose score is below this value are skipped.
    """
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(img, aspect='equal')

    # Read from the `output` argument; relying on a caller-side global of a
    # particular name would make the function unusable anywhere else.
    out_boxes = output["boxes"].cpu()
    out_scores = output["scores"].cpu()
    out_labels = output["labels"].cpu()

    num_boxes = out_boxes.shape[0]
    for idx in range(min(num_boxes, max_vis)):
        score = float(out_scores[idx])
        # Filter first so that skipped detections never trigger a name lookup.
        if score < prob_thres:
            continue

        bbox = out_boxes[idx].tolist()
        class_name = classes[int(out_labels[idx])]

        # Rectangle takes the top-left corner plus width and height.
        ax.add_patch(plt.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1],
                                   fill=False, edgecolor='red', linewidth=3.5))
        ax.text(bbox[0], bbox[1] - 2, '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5), fontsize=14, color='white')
    plt.show()
    plt.close(fig)


class Compose(object):
    """Chain paired transforms that each map (image, target) -> (image, target)."""

    def __init__(self, transforms):
        # Kept as given; the steps are applied in iteration order on each call.
        self.transforms = transforms

    def __call__(self, image, target):
        """Run every transform in order and return the final (image, target)."""
        current_image, current_target = image, target
        for step in self.transforms:
            current_image, current_target = step(current_image, current_target)
        return current_image, current_target


class RandomHorizontalFlip(object):
    """Mirror the image and its boxes left-right when a random draw is below ``prob``."""

    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        """Return (image, target), flipped along the last image axis when selected.

        When the flip is applied, ``target["boxes"]`` is updated in place.
        """
        if not (random.random() < self.prob):
            return image, target

        _, img_width = image.shape[-2:]
        mirrored = image.flip(-1)

        boxes = target["boxes"]
        # Mirroring swaps the roles of columns 0 and 2: the new column 0 is
        # width - old column 2, and the new column 2 is width - old column 0.
        boxes[:, [0, 2]] = img_width - boxes[:, [2, 0]]
        target["boxes"] = boxes
        return mirrored, target


class ToTensor(object):
    """Paired transform: convert the image with ``F.to_tensor``, pass the target through."""

    def __call__(self, image, target):
        # F is torchvision.transforms.functional here: the later module-level
        # import rebinds the name first bound to torch.nn.functional.
        tensor_image = F.to_tensor(image)
        return tensor_image, target


if __name__ == "__main__":
    # Fine-tune a pretrained Faster R-CNN on the PennFudanPed data, then
    # visualise predictions on a few randomly chosen images.

    # config
    LR = 0.001
    num_classes = 2  # output size of the replacement box predictor (presumably background + pedestrian)
    batch_size = 1
    start_epoch, max_epoch = 0, 5
    train_dir = enviroments.pennFudanPed_data_dir
    # ToTensor must run first: RandomHorizontalFlip uses tensor ops (.shape, .flip).
    train_transform = Compose([ToTensor(), RandomHorizontalFlip(0.5)])

    # step 1: data
    train_set = PennFudanDataset(data_dir=train_dir, transforms=train_transform)

    def collate_fn(batch):
        """Transpose a batch of (image, target) samples into (images, targets) tuples."""
        return tuple(zip(*batch))

    train_loader = DataLoader(train_set, batch_size=batch_size, collate_fn=collate_fn)

    # step 2: model
    # NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
    # of `weights=`; kept as-is for the version this script targets -- confirm.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # Replace the pre-trained head with a new one sized for num_classes.
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    model.to(device)

    # step 3: loss
    # Computed inside the model; see torchvision/models/detection/roi_heads.py:
    # def fastrcnn_loss(class_logits, box_regression, labels, regression_targets)

    # step 4: optimizer scheduler
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=LR, momentum=0.9, weight_decay=0.0005)
    # With step_size=10 and max_epoch=5 the learning rate never decays in this run.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    # step 5: Iteration
    num_iters = len(train_loader)
    for epoch in range(start_epoch, max_epoch):

        model.train()
        # `iteration` rather than `iter`, so the builtin is not shadowed.
        for iteration, (images, targets) in enumerate(train_loader):

            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            # images is a list of tensors; targets is a list of dicts
            # (e.g. "boxes", "labels"). In train mode the model returns a dict of losses.
            loss_dict = model(images, targets)

            losses = sum(loss_dict.values())

            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} ".format(
                epoch, max_epoch, iteration + 1, num_iters, losses.item()))

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

        lr_scheduler.step()

    # test
    model.eval()

    # config
    vis_num = 5
    vis_dir = os.path.join(BASE_DIR, "..", "..", "data", "PennFudanPed", "PNGImages")
    img_names = [name for name in os.listdir(vis_dir) if name.endswith(".png")]
    random.shuffle(img_names)
    preprocess = transforms.Compose([transforms.ToTensor(), ])

    # Slicing caps the loop at the number of available images, so a folder
    # with fewer than vis_num PNGs no longer raises IndexError.
    for img_name in img_names[:vis_num]:

        path_img = os.path.join(vis_dir, img_name)
        # preprocess; the context manager closes the underlying file handle
        with Image.open(path_img) as raw_image:
            input_image = raw_image.convert("RGB")

        # The model already lives on `device`, so only the input is moved.
        img_chw = preprocess(input_image).to(device)

        # forward
        input_list = [img_chw]
        with torch.no_grad():
            tic = time.time()
            print("input img tensor shape:{}".format(input_list[0].shape))
            output_list = model(input_list)
            output_dict = output_list[0]
            print("pass: {:.3f}s".format(time.time() - tic))

        # visualization
        # Index 1 of the name table is 'person', which is the name drawn for label 1.
        vis_bbox(input_image, output_dict, COCO_INSTANCE_CATEGORY_NAMES, max_vis=20, prob_thres=0.5)
相关推荐
冬奇Lab1 分钟前
如何让 AI Skill 质量有据可查?Benchmark 驱动的评测体系设计
人工智能·agent
腾科IT教育1 小时前
Spring AI Alibaba 向量(VectorStore)
人工智能·spring·microsoft
IT_陈寒1 小时前
React中useEffect依赖项这个坑我居然踩了三天
前端·人工智能·后端
江畔柳前堤1 小时前
github实战指南02-仓库管理与 Issue
人工智能·深度学习·github·信号处理·caffe·wps·issue
邵宇然2 小时前
内存分配优化:基于 Unsafe 指针与内存对齐的 Rust 区域分配器
人工智能
海兰2 小时前
【游戏】迷雾镇(Mist Town)AI 沙箱游戏详细设计与部署指南(附源代码)
人工智能·游戏
小赖同学啊2 小时前
智能连接器集群化高可用生产方案
linux·运维·人工智能
ZStack开发者社区2 小时前
基于AI Agent的ZCF API文档全链路自动化
运维·人工智能·自动化
沈麽鬼2 小时前
别瞎用AI写代码!90%开发者都搞错了AI编程的底层逻辑
人工智能·ai编程·trae
小陈爱编程2 小时前
我终于把 Codex 的 API 配置理顺了:从踩坑到跑通
人工智能