YOLOv1 目标检测

相关文章

项目地址:YOLOv1 VOC 2007

笔者训练的权重地址:阿里云盘分享

10 秒文章速览

本文主要讲解 PASCAL VOC 2007 数据集的信息与加载

IOU 计算

在这里我们在设置一个计算 iou 的函数,只不过这个函数用于最终目标的检测,而非训练。这个就不解释了

python 复制代码
def iou(box1, box2):
    x1_min, y1_min, x1_max, y1_max = box1
    x2_min, y2_min, x2_max, y2_max = box2
    s1 = (x1_max - x1_min)*(y1_max - y1_min)
    s2 = (x2_max - x2_min)*(y2_max - y2_min)
    xmin = max(x1_min, x2_min)
    ymin = max(y1_min, y2_min)
    xmax = min(x1_max, x2_max)
    ymax = min(y1_max, y2_max)

    w = max(0, xmax - xmin)
    h = max(0, ymax - ymin)
    a1 = w * h
    a2 = s1 + s2 - a1
    iou = a1 / a2

    return iou

NMS

python 复制代码
# 非极大值抑制,筛选框
def NMS(bboxes, cond_threshold=0.2, iou_threshold=0.5):
    def filter_cond_score(x):
        # 计算两个预测框的置信度分数
        max_classes = max(x[:20])
        bbox1_cond_score = max_classes * x[24]
        bbox2_cond_score = max_classes * x[29]
        # 返回置信度分数大的预测框
        if bbox1_cond_score > bbox2_cond_score:
            return x[0:20] + x[20:25]
        else:
            return x[0:20] + x[25:30]

    # 每个网格只有一个预测框能活
    bboxes = map(filter_cond_score, bboxes)
    # 根据置信度分数从大到小排序
    bboxes = sorted(bboxes, key=lambda x: max(x[0:20]) * x[-1], reverse=True)
    # 筛选达不到阈值的预测框
    bboxes = list(filter(lambda x: max(x[0:20]) * x[-1] > cond_threshold, bboxes))

    lucky_bboxes = []
    while len(bboxes) != 0:
        lucky_bboxes.append(bboxes.pop(0))
        del_idx = []
        for num, box in enumerate(bboxes):
            # 还原预测框坐标
            x1, y1, w1, h1 = lucky_bboxes[-1][20:24]
            x2, y2, w2, h2 = box[20:24]
            x1_min, y1_min, x1_max, y1_max = x1-w1/2, y1-h1/2, x1+w1/2, y1+h1/2
            x2_min, y2_min, x2_max, y2_max = x2-w2/2, y2-h2/2, x2+w2/2, y2+h2/2

            if iou([x1_min, y1_min, x1_max, y1_max], [x2_min, y2_min, x2_max, y2_max]) > iou_threshold:
                del_idx.append(num)
            # 当存在一个预测框刚好在另一个预测框内的情况时,就不能单纯的计算 IOU 了
            if x1_min < x2_min < x2_max < x1_max and y1_min < y2_min < y2_max < y1_max:
                del_idx.append(num)

        # 批量过滤预测框
        for n, i in enumerate(del_idx):
            bboxes.pop(i-n)

    return lucky_bboxes

未进行 NMS 处理,对了同一个目标都存在多个预测框,效果不是很好

进行 NMS 处理后,过滤了无关紧要的预测框,very good!👍👍👍

目标检测

终于熬出头了,一切为了这一刻🎉

python 复制代码
def draw():
    # 读取图片
    # original_img:原图,new_img:经处理,可作为模型输入的图片
    path = path_set[n]
    original_img = tf.io.read_file(path)
    original_img = tf.image.decode_jpeg(original_img, channels=3).numpy()
    new_img = tf.image.resize(original_img, (448, 448))/255
    new_img = tf.expand_dims(new_img, 0)
    pred = model(new_img).numpy()[0]

    # 根据模型输出,先把归一化还原出原来的尺寸
    height, width = original_img.shape[:2]
    w_grid = width / 7
    h_grid = height / 7
    x_grid = np.array([range(7) for i in range(7)], dtype='float32') * w_grid
    y_grid = np.array([[i] * 7 for i in range(7)], dtype='float32') * h_grid
    pred[..., [20, 25]] = pred[..., [20, 25]] * w_grid + np.repeat(x_grid[..., np.newaxis], 2, -1)
    pred[..., [21, 26]] = pred[..., [21, 26]] * h_grid + np.repeat(y_grid[..., np.newaxis], 2, -1)
    pred[..., [22, 27]] *= width
    pred[..., [23, 28]] *= height

    # 调整输出的形状
    bboxes = pred.reshape(49, 30)
    bboxes = bboxes.tolist()

    # NMS 筛选
    bboxes = NMS(bboxes)

    retval, baseLine = cv2.getTextSize('abc', cv2.FONT_ITALIC, 1, 2)
    # 遍历 bboxes,并绘制框
    for i in bboxes:
        x, y, w, h = i[20:-1]
        cls = np.argmax(i[0:20])
        # 计算绘制坐标
        pt1 = int(x-w/2), int(y-h/2)
        pt2 = int(x+w/2), int(y+h/2)
        # 绘制预测框
        cv2.rectangle(original_img, pt1, pt2, classes[classes_name[cls]]['color'], 1)
        topleft = (pt1[0] + 3, pt1[1] + retval[1] + 3)
        cv2.putText(original_img, classes[classes_name[cls]]['name'], topleft, cv2.FONT_ITALIC, 0.6, classes[classes_name[cls]]['color'], 2)

    plt.figure(dpi=128)
    plt.axis('off')
    plt.imshow(original_img)
    plt.show()


n = 4
draw()

最终效果如下,当然这没什么值得骄傲的,毕竟这些用的是训练集中的数据

下面来看看对验证集数据的检测如何(似乎过得去吧,毕竟是笔者挑选过的😁)

相关推荐
CoovallyAIHub1 分钟前
工业质检新突破!YOLO-pdd多尺度PCB缺陷检测算法实现99%高精度
深度学习·算法·计算机视觉
悟乙己4 分钟前
译|生存分析Survival Analysis案例入门讲解(一)
人工智能·机器学习·数据挖掘·生存分析·因果推荐
无奈何杨6 分钟前
从“指点江山”到“赛博求雨”的心路历程
人工智能
老贾专利烩16 分钟前
智能健康项链专利拆解:ECG 与 TBI 双模态监测的硬件架构与信号融合
人工智能·科技·健康医疗
无奈何杨18 分钟前
MCP Server工具参数设计与AI约束指南
人工智能
青梅主码18 分钟前
中国在世界人工智能大会上发布《人工智能全球治理行动计划》:中美 AI 竞争白热化,贸易紧张局势下的全球治理新篇章
人工智能
loopdeloop30 分钟前
机器学习、深度学习与数据挖掘:核心技术差异、应用场景与工程实践指南
深度学习·机器学习·数据挖掘
loopdeloop32 分钟前
机器学习、深度学习与数据挖掘:三大技术领域的深度解析
深度学习·机器学习·数据挖掘
张较瘦_1 小时前
[论文阅读] 人工智能 + 软件工程 | CASCADE:用LLM+编译器技术破解JavaScript混淆难题
javascript·论文阅读·人工智能
呆头鹅AI工作室1 小时前
[2025CVPR-图象分类方向]CATANet:用于轻量级图像超分辨率的高效内容感知标记聚合
图像处理·人工智能·深度学习·目标检测·机器学习·计算机视觉·分类