【目标检测】Anchor-based模型:基于K-means算法获取自制数据集的Anchor(yolo源码)

在Anchor-based目标检测模型中,根据数据集选择合适的Anchor有利于加快模型的收敛速度以及减少模型的边框预测误差。本篇文章首先介绍Anchor在目标检测模型中的作用;然后介绍K-means聚类算法;最后介绍yolo源码自制数据集的Anchor的获取方法。

本文目录

1 Anchor的理解

Anchor-based目标检测模型中(SSD,YOLO v3/v4/v5等),其通过预先设定的具有不同宽高比尺度的Anchor得到目标的粗略边框,并训练回归参数对Anchor进行修正以得到目标的准确边框。以YOLOv5为例,其基于Anchor的检测流程如图1所示,在不同尺度的输出特征图使用不同尺度的Anchor实现对不同大小目标的检测,在一个检测单元中,使用不同宽高比的Anchor对目标进行预测。

  • 在较大的特征图上使用较小的Anchor,实现对小目标的检测;
  • 在中等的特征图上使用中等的Anchor,实现对中等大小目标的检测;
  • 在较小的特征图上使用较大的Anchor,实现对大目标的检测;


图1 Anchor设定方法

在目标检测模型中,目标的类别预测结果不会随着其在图像中位置的改变而改变;但目标的边框信息会随着其在图像中位置的改变而改变,通过设定Anchor,实现目标边框对Anchor的偏移量预测,该偏移量不会随着目标在图像中位置的改变而改变,更有利于模型的回归,符合神经网络位移不变性

模型在训练中,根据Anchor和目标实际边框之间的IoU大小,分配用于检测该目标的Anchor。根据数据集科学的设置Anchor有利于加快模型训练时的收敛速度减少模型对目标的漏检率

2 K-means聚类算法

K-means是最常见的聚类算法,算法接受未标记的数据集,然后将数据聚集成不同的组,其方法如下:

  1. 首先选取K个随机的点,作为聚类中心;
  2. 对于数据集中的每一个数据,计算其与聚类中心点的距离,并将其与距离最近的中心点关联起来,与同一个中心点关联的所有点聚成一类;
  3. 计算每一类的平均值,将该类的中心点移动到平均值的位置;
  4. 重复步骤2-3,直到中心点不再改变。

图2 K-means分类示例

3 基于K-means获取自制数据集Anchor

Anchors获取

基于Scipy库中的kmeans函数计算K-means聚类(kmean_anchors)

  • 输入:shapes(原图大小),labels(图像标签),n(聚类中心点个数),img_size(模型训练图像大小)
  • 输出:anchors(在模型训练图像上的绝对大小)
python 复制代码
def kmean_anchors(shapes, labels, n=9, img_size=640):
	'''
	shapes: 数据集图像大小
	labels: 数据集标签 [N, 5] 5->(cls, x, y, w, h)(相对大小) 
	n: Anchors数量(在img_size上绝对大小)
	img_sie:模型训练图像大小
	'''
	from scipy.cluster.vq import kmeans # 导入kmeans函数, 返回-> k:Anchors; distortion:每一类内点到中心点的平均距离
	
	# 数据集图像输入模型的大小(保持原图比例, 长边缩放为img_size)
	shapes = img_size * shapes / shapes.max(1, keepdims=True)
	# 得到所有目标的边框(在输入图像上绝对大小)
	wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, labels)]) 
	# 滤除边框像素过小的目标
	wh = wh0[(wh0 >= 2.0).any(1)].astype(np.float32)
	# K-means
	try:
		assert n <= len(wh)  # 中心点个数应小于等于边框数目
		s = wh.std(0)  # 坐标标准差
		k = kmeans(wh / s, n, iter=30)[0] * s  # points
		assert n == len(k) # # kmeans 可能无法找到足够的点,若输入数据太少或太小
		except Exception:
			k = np.sort(npr.rand(n * 2)).reshape(n, 2) * img_size  # 随机初始化
	k = k[np.argsort(k.prode(1))]  # 根据面积从小到大对得到的Anchors进行排序
	print('Anchors:', k)

随机生成300点,使用以上函数得到的聚类结果示例如图3所示。

图3 聚类结果示例图

Anchors评价

在YOLO中,判断Anchors与目标边框是否匹配有以下两种指标:

  • 根据Anchors和目标实际边框的宽高比(r > anchors_t)
  • 根据Anchors和目标实际边框的IoU(iou > iou_t)
python 复制代码
def metrix(k, wh, m='ratio', thr=0.25):
	k = k[np.argsort(k.prod(1))]  # 根据面积从小到大对得到的Anchors进行排序
	r = wh[:, None] / k[None]  # 计算目标与所有Anchor的宽高比
	# 目标匹配的所有Anchors
	if m == 'ratio':  # 宽高比指标
		x = torch.min(r, 1 / r).min(2)[0]  # 计算所有宽高比的最小值:目标/Anchors; Anchors/目标
	elif m == 'iou':  # iou指标
		x = wh_iou(wh, k)
	best = x.max(1)[0]  # 目标匹配的所有Anchors, 目标匹配的最好Anchor(大小最接近)
	fitness = (best * (best > thr).float()).mean()  # fitness[0, 1], 反映目标实际边框与Anchors的匹配情况
	return x, best, fitness

def print_result(x, best, thr=0.25, img_size=640):
	bpr, aat = bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n  # 最好的匹配结果, 所有匹配结果
    s = f'{PREFIX}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr\n' \
        f'n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, ' \
        f'past_thr={x[x > thr].mean():.3f}-mean: '
    # 添加Anchors信息
    for x in k:
        s += '%i,%i, ' % (round(x[0]), round(x[1]))
    print(s[:-2])

Anchors更新(遗传算法)

根据评价指标,基于遗传算法更新Anchors,以得到在数据集中表现最好的Anchors

python 复制代码
def ga_anchors(k, wh, gen=1000):
	_, _, f = metrix(k, wh)
	sh, mp, s = k.shape, 0.9, 0.1  # 遗传算法中系数更新
	for _ in range(gen):
		v = np.ones(sh)  # 随机更新系数
		while (v==1).all():
			v = ((np.random.random(sh) < mp) * random.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
		kg = (k.copy() * v).clip(min=2.0)  # 根据系数构造新的Anchors
		_, _, fg = metrig(kg, wh)  # 计算新的fitness
		if fg > f:  # 若效果更好, 则更新k
			f, k = fg, kg.copy()  
			print('Evolving anchors with Genetic Algorithm: fitness = {f:.4f}')
相关推荐
肥猪猪爸1 小时前
使用卡尔曼滤波器估计pybullet中的机器人位置
数据结构·人工智能·python·算法·机器人·卡尔曼滤波·pybullet
readmancynn1 小时前
二分基本实现
数据结构·算法
萝卜兽编程1 小时前
优先级队列
c++·算法
盼海2 小时前
排序算法(四)--快速排序
数据结构·算法·排序算法
一直学习永不止步2 小时前
LeetCode题练习与总结:最长回文串--409
java·数据结构·算法·leetcode·字符串·贪心·哈希表
goomind2 小时前
YOLOv8实战木材缺陷识别
人工智能·yolo·目标检测·缺陷检测·pyqt5·木材缺陷识别
Rstln3 小时前
【DP】个人练习-Leetcode-2019. The Score of Students Solving Math Expression
算法·leetcode·职场和发展
芜湖_3 小时前
【山大909算法题】2014-T1
算法·c·单链表
珹洺3 小时前
C语言数据结构——详细讲解 双链表
c语言·开发语言·网络·数据结构·c++·算法·leetcode
吾门3 小时前
YOLO入门教程(三)——训练自己YOLO11实例分割模型并预测【含教程源码+一键分类数据集 + 故障排查】
yolo·分类·数据挖掘