Yolo检测器中的anchor聚类(python实现)
为什么不使用cpp了?使用cpp手搓机器学习的效率还是太低了,还是得上python,有numpy库效率提高一大截
📝 题目描述
难度: 🟡 中等
标签:机器学习kmeans
照例先看输入输出,看不懂再看题面
输入描述:

输出描述:

样例:
输入:
bash
12 4 20
12 23
34 21
43 23
199 23
34 23
108 12
200 107
12 78
123 110
34 23
56 48
78 66
输出:
bash
133 94
121 27
36 22
12 50
说明:

从输入输出描述就可以看出来要干什么了,但是聚类中心、结束条件、距离公式还不知道,因此要从题面找一下:

以及最后的提示:

条件齐全,这下就可以开始做题了
� 解题思路
模拟即可。但是要注意每次迭代完成后,要求聚类中心向下取整。
步骤:
- 读取数据
- 确定聚类中心
- 计算到每个聚类中心的dis(记得用1-iou计算),分配框到距离最近的中心
- 分配完所有点,使用题目给的公式移动聚类中心(向下取整),并计算新旧聚类中心移动的总距离
- 判断是否达到聚类完成条件,否则重复3到5
- 按照聚类中心框的面积从大到小排序,输出排序后的框
� 代码实现
这里放的是原始实现,没用numpy,后面研究一下在补个numpy版本的
python
import sys
import math
def iou(l1:list,l2:list):
w1,h1=l1
w2,h2=l2
res=0
intersection=min(w1,w2)*min(h1,h2)
union=w1*h1+w2*h2-intersection
IOU=intersection/(union + 1e-16)
dis=1-IOU
return dis
def center_comp(lists):
ave_w=ave_h=0
if len(lists)==0:
return -1,-1
for items in lists:
ave_w+=items[0]
ave_h+=items[1]
ave_w/=len(lists)
ave_h/=len(lists)
return ave_w,ave_h
def kmeans(N,K,T,check_list):
center=[]
for items in check_list[:K]:
## 注意这里需要用深拷贝,因为后面要修改center的值,这里浅拷贝会同时修改check_list
center.append(list(map(float,items.copy())))
# 开始迭代
for i in range(T):
center_points=[[] for _ in range(K)]
for items in check_list:
min_index=[-1,1e9]
# 逐个计算中心距离
for _ in range(K):
sq=iou(items,center[_])
if min_index[0]==-1:
min_index[0]=_
min_index[1]=sq
continue
if sq<min_index[1]:
min_index[0]=_
min_index[1]=sq
center_points[min_index[0]].append(items)
# 一次迭代完成,开始计算新中心和移动距离
sum_move=0
for _ in range(K):
new_w,new_h=center_comp(center_points[_])
# 如果返回的是-1,-1说明该类没有点,那么不动该点,继续检查下一个
if new_w==-1 and new_h==-1:
continue
sum_move+=iou([new_w,new_h],center[_])
## 这里注意要向下取整
center[_][0],center[_][1]=int(new_w),int(new_h)
if sum_move<=1e-4:
break
center.sort(key=lambda x:x[0]*x[1],reverse=True)
return center
# 读数
# 计算距离交并比,并划入对应的类
# 重新计算中心
# 迭代
def main():
line=sys.stdin.readline()
line=line.strip()
N,K,T=map(int,line.split())
check_list=[]
for _ in range(N):
check_w,check_h=map(int,sys.stdin.readline().split())
check_list.append([check_w,check_h])
res=kmeans(N,K,T,check_list)
for _ in res:
print('%d %d'%(_[0],_[1]))
if __name__=='__main__':
main()
📊 复杂度分析
时空复杂度都比较高,这里不计算了
日期: 2026-4-11
1h写完但是无法运行,使用豆包查bug改了半小时,最终1.5h完成题目,下次尽量先自己对照检查一下问题,python改bug相对简单很多