【目标检测】YOLOv5算法实现(九):模型预测

本系列文章记录本人硕士阶段YOLO系列目标检测算法自学及其代码实现的过程。其中算法具体实现借鉴于ultralytics YOLO源码Github,删减了源码中部分内容,满足个人科研需求。

本系列文章主要以YOLOv5为例完成算法的实现,后续修改、增加相关模块即可实现其他版本的YOLO算法。

文章地址:
YOLOv5算法实现(一):算法框架概述
YOLOv5算法实现(二):模型加载
YOLOv5算法实现(三):数据集加载
YOLOv5算法实现(四):正样本匹配与损失计算
YOLOv5算法实现(五):预测结果后处理
YOLOv5算法实现(六):评价指标及实现
YOLOv5算法实现(七):模型训练
YOLOv5算法实现(八):模型验证
YOLOv5算法实现(九):模型预测

本文目录

引言

本篇文章综合之前文章中的功能,实现模型的预测。模型预测的逻辑如图1所示。


图1 模型预测流程

模型预测(predict.py)

python 复制代码
def predice():
    img_size = 640  # 必须是32的整数倍 [416, 512, 608]
    file = "yolov5s"
    cfg = f"cfg/models/{file}.yaml"  # 改成生成的.cfg文件
    weights_path = f"weights/{file}/{file}.pt"  # 改成自己训练好的权重文件
    json_path = "data/dataset.json"  # json标签文件
    img_path = "test.jpg"
    save_path = f"results/{file}/test_result8.jpg"
    assert os.path.exists(cfg), "cfg file {} dose not exist.".format(cfg)
    assert os.path.exists(weights_path), "weights file {} dose not exist.".format(weights_path)
    assert os.path.exists(json_path), "json file {} dose not exist.".format(json_path)
    assert os.path.exists(img_path), "image file {} dose not exist.".format(img_path)

    with open(json_path, 'r') as f:
        class_dict = json.load(f)

    category_index = {str(v): str(k) for k, v in class_dict.items()}

    input_size = (img_size, img_size)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    device = torch.device("cpu")
    model = Model(cfg, ch=3, nc=3)
    weights_dict = torch.load(weights_path, map_location='cpu')
    weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
    model.load_state_dict(weights_dict, strict=False)
    model.to(device)

    model.eval()
    with torch.no_grad():
        # init
        img = torch.zeros((1, 3, img_size, img_size), device=device)
        model(img)

        img_o = cv2.imread(img_path)  # BGR
        assert img_o is not None, "Image Not Found " + img_path

        img = letterbox(img_o, new_shape=input_size, auto=True, color=(0, 0, 0))[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        img = torch.from_numpy(img).to(device).float()
        img /= 255.0  # scale (0, 255) to (0, 1)
        img = img.unsqueeze(0)  # add batch dimension

        t1 = torch_utils.time_synchronized()
        pred = model(img)[0]  # only get inference result
        t2 = torch_utils.time_synchronized()
        print("inference time: {}s".format(t2 - t1))
        print('model: {}'.format(file))



        pred = utils.non_max_suppression(pred, conf_thres=0.1, iou_thres=0.6, multi_label=True)[0]
        t3 = time.time()
        print("post-processing time: {}s".format(t3 - t2))



        # process detections
        pred[:, :4] = utils.scale_coords(img.shape[2:], pred[:, :4], img_o.shape).round()


        bboxes = pred[:, :4].detach().cpu().numpy()
        scores = pred[:, 4].detach().cpu().numpy()
        classes = pred[:, 5].detach().cpu().numpy().astype(np.int) + 1

        pil_img = Image.fromarray(img_o[:, :, ::-1])
        plot_img = draw_objs(pil_img,
                             bboxes,
                             classes,
                             scores,
                             category_index=category_index,
                             box_thresh=0.2,
                             line_thickness=3,
                             font='arial.ttf',
                             font_size=30)
        plt.imshow(plot_img)
        plt.show()
        # 保存预测的图片结果

        plot_img.save(save_path)


if __name__ == "__main__":
    predict()
相关推荐
棒棒的皮皮6 小时前
【深度学习】YOLO模型评估之指标、可视化曲线分析
人工智能·深度学习·yolo·计算机视觉
MF_AI7 小时前
大型烟雾火灾检测识别数据集:25w+图像,2类,yolo标注
图像处理·人工智能·深度学习·yolo·计算机视觉
AI小怪兽8 小时前
基于YOLO11的航空安保与异常无人机检测系统(Python源码+数据集+Pyside6界面)
开发语言·人工智能·python·yolo·计算机视觉·无人机
源来猿往9 小时前
yolov8n结构化剪枝
算法·yolo·剪枝
才不做选择13 小时前
基于 YOLOv8 的部落冲突 (Clash of Clans) 目标检测系统
人工智能·python·yolo·目标检测
棒棒的皮皮1 天前
【深度学习】YOLO核心原理介绍
人工智能·深度学习·yolo·计算机视觉
吃人陈乐游刘2 天前
08实战经验yoloV8部署(2026年01月)
yolo
思通数科x2 天前
传统监控事后难追责?AI 违规识别智能系统让响应时间缩短
图像处理·人工智能·深度学习·目标检测·机器学习·计算机视觉·自然语言处理
熬夜不洗澡2 天前
如何在pycharm中使用Yolo
ide·yolo·pycharm
智驱力人工智能2 天前
从占座到智座 非授权人员座位占用监测系统的产品化思考与实践 椅位占用检测 非员工座位占用AI预警 边缘计算非授权座位识别设备
人工智能·opencv·算法·安全·yolo·计算机视觉·边缘计算