# 第2题-终端款型聚类识别 - problem_ide - CodeFun2000
# (Problem 2: terminal model clustering identification; language: python)
import sys
import numpy as np
def solve():
    """Run k-means on 4-feature points read from stdin and print cluster sizes.

    Input (whitespace-separated tokens on stdin):
        k m n            -- number of clusters, number of points, max iterations
        m * 4 numbers    -- the points, 4 feature values each

    Algorithm:
        * The first k points are the initial centers.
        * Each iteration assigns every point to its nearest center by
          Euclidean distance; ties go to the lower-index center.
        * Each center then moves to the mean of its members. A center with no
          members keeps its previous position.
        * Stops after n iterations, or earlier once no center moves by
          1e-8 or more.

    Output: the k cluster sizes in ascending order, space-separated. Prints
    nothing if stdin is empty. Assumes k <= m (TODO: confirm against the
    problem statement); otherwise indexing the initial centers fails.
    """
    tokens = sys.stdin.read().split()
    if not tokens:
        return
    k, m, n = int(tokens[0]), int(tokens[1]), int(tokens[2])
    # Take exactly m*4 feature values so stray trailing tokens cannot break
    # the reshape into an (m, 4) matrix.
    data = np.array(tokens[3:3 + 4 * m], dtype=np.float64).reshape(m, 4)

    centers = data[:k].copy()
    labels = np.zeros(m, dtype=int)
    for _ in range(n):
        old_centers = centers.copy()
        # (m, k) matrix of Euclidean distances, via broadcasting
        # (m, 1, 4) - (1, k, 4) instead of nested Python loops.
        diff = data[:, np.newaxis, :] - old_centers[np.newaxis, :, :]
        distances = np.sqrt(np.sum(diff ** 2, axis=2))
        # argmin returns the first minimum, so ties go to the lower index.
        labels = np.argmin(distances, axis=1)

        # Move each non-empty cluster's center to its members' mean.
        for j in range(k):
            members = data[labels == j]
            if len(members) > 0:
                centers[j] = members.mean(axis=0)

        # Converged once the largest center movement is negligible.
        max_shift = np.sqrt(np.sum((centers - old_centers) ** 2, axis=1)).max()
        if max_shift < 1e-8:
            break

    counts = sorted(int(np.sum(labels == j)) for j in range(k))
    print(" ".join(map(str, counts)))
if __name__ == "__main__":
    # Solve only when executed as a script, not when imported.
    solve()