使用golang实现k-means

k-means聚类算法

K-Means 是一种无监督的聚类算法，其目标是把数据划分成若干个簇（cluster）。簇的个数 K 需要事先指定。

k-means流程

  1. 随机选取 K 个点作为初始质心。
  2. 对每个点，找到离它最近的质心，并把该点分配到这个质心所代表的簇中。
  3. 对每个簇，计算簇内所有点的均值，作为该簇新的质心。
  4. 如果新的质心与上一轮相比仍有变化，则重复步骤 2–3，直到质心位置稳定。

本文目标

  1. 在golang中实现k-means算法。
  2. 使用matplotlib绘制聚类散点图。
  3. 尝试并行处理。
  4. 与sklearn结果对比。

执行结果

左侧:输出

中间:本文效果

右侧:sklearn效果

Q: 为什么样本一样,结果不同?

A: 原因有两方面：一是两者判断收敛（结束迭代）的阈值不同；二是初始的 K 个质心选择方式不同——本文直接取前 K 个样本点作为初始质心，而 sklearn 默认使用 k-means++ 初始化。

确定数据结构

我把 k-means 封装成了 KMeans 结构体，另外定义了表示二维平面上点的 Point 类型。实现如下：

Go 代码：
// Point is a sample (or centroid) in the 2-D plane.
type Point struct {
	X, Y float64
}

// KMeans holds the configuration and running state of a k-means clustering.
// Set it up with Init, then call Do.
type KMeans struct {
	points       []Point      // samples to cluster, stored as passed to Init
	k            int          // number of clusters
	distanceFunc distanceFunc // metric used to find each point's nearest centroid
	avgPoints    []Point      // current centroids, one per cluster
}

// distanceFunc measures how far apart two points are.
type distanceFunc func(p1, p2 Point) float64

生成随机样本

Go 代码：
// generateRandomPoints returns n points whose X and Y coordinates are each
// drawn from the global math/rand source (rand.Float64) and scaled by 100.
func generateRandomPoints(n int) []tool.Point {
	points := make([]tool.Point, n)
	for i := range points {
		// X is drawn before Y, matching the field order of the literal.
		x := rand.Float64() * 100
		y := rand.Float64() * 100
		points[i] = tool.Point{X: x, Y: y}
	}
	return points
}

确定距离度量函数

Go 代码：
	// Euclidean distance between two points. math.Hypot computes
	// Sqrt(dx*dx + dy*dy) while avoiding intermediate overflow/underflow, and
	// skips the general-purpose math.Pow.
	distance := func(p1, p2 tool.Point) float64 {
		return math.Hypot(p1.X-p2.X, p1.Y-p2.Y)
	}

初始化kmeans对象

Go 代码：
// Init configures kMeans to split points into k clusters, using distanceFunc
// as the metric. The points slice is stored as-is (not copied), and the
// centroid slice is allocated with k zero-valued entries; the centroids are
// seeded later, when Do runs.
func (kMeans *KMeans) Init(k int, points []Point, distanceFunc distanceFunc) {
	kMeans.k, kMeans.points, kMeans.distanceFunc = k, points, distanceFunc
	kMeans.avgPoints = make([]Point, k)
}
// initializeAvgPoints seeds the centroids with the first k points.
//
// Copying from the whole points slice (rather than slicing it to k) lets copy
// stop at whichever slice is shorter, so having fewer than k points no longer
// panics; centroids beyond len(points) keep their current value (the zero
// Point right after Init).
func (kMeans *KMeans) initializeAvgPoints() {
	copy(kMeans.avgPoints, kMeans.points)
}

确定算法大致流程

Go 代码：
// Do runs the clustering and returns the final clusters together with the
// number of centroid updates that were applied.
//
// The centroids are first seeded by initializeAvgPoints. Each round then
// assigns every point to its nearest centroid (computeDistanceToAvgPoints)
// and computes a candidate set of centroids from those clusters
// (updateAvgPoints). checkFunc(current, candidate) decides whether to stop:
// if it returns true, that round's clusters are returned and the candidate
// centroids are discarded; otherwise the candidate becomes the current set
// and another round starts. There is no iteration cap, so the loop ends only
// when checkFunc returns true.
func (kMeans *KMeans) Do(checkFunc func(oldCentroids, newCentroids []Point) bool) ([][]Point, int) {
	kMeans.initializeAvgPoints()

	for updates := 0; ; updates++ {
		clusters := kMeans.computeDistanceToAvgPoints()
		candidate := kMeans.updateAvgPoints(clusters)
		if checkFunc(kMeans.avgPoints, candidate) {
			return clusters, updates
		}
		kMeans.avgPoints = candidate
	}
}

计算聚类(尝试并发)

Q: 为什么可以并发?

A: 因为计算聚类时，每个点寻找最近质心的运算彼此独立，而且所依赖的数据（样本点、质心和距离函数）在计算过程中不会被修改。

Go 代码：
// computeDistanceToAvgPoints assigns every point to its nearest centroid and
// returns the clusters: clusters[i] holds, in their original order, the points
// whose nearest centroid is kMeans.avgPoints[i]. Ties go to the lower index.
//
// Each point's search only reads kMeans.points, kMeans.avgPoints and
// kMeans.distanceFunc, none of which change while it runs, so the points are
// split into contiguous chunks handled by at most GOMAXPROCS goroutines. Every
// worker writes only its own entries of labels, so no locking is needed, and
// the clusters are assembled afterwards in point order so the result (and the
// floating-point sums computed from it) is deterministic.
func (kMeans *KMeans) computeDistanceToAvgPoints() [][]Point {
	points, centroids, dist := kMeans.points, kMeans.avgPoints, kMeans.distanceFunc
	clusters := make([][]Point, len(centroids))
	if len(points) == 0 || len(centroids) == 0 {
		// Nothing to assign, or no centroid to assign it to.
		return clusters
	}

	// labels[i] is the index of the centroid nearest to points[i].
	labels := make([]int, len(points))
	workers := runtime.GOMAXPROCS(0)
	if workers > len(points) {
		workers = len(points)
	}
	chunk := (len(points) + workers - 1) / workers

	var wg sync.WaitGroup
	for start := 0; start < len(points); start += chunk {
		end := start + chunk
		if end > len(points) {
			end = len(points)
		}
		wg.Add(1)
		go func(start, end int) {
			defer wg.Done()
			for i := start; i < end; i++ {
				labels[i] = nearestCentroid(points[i], centroids, dist)
			}
		}(start, end)
	}
	wg.Wait()

	for i, label := range labels {
		clusters[label] = append(clusters[label], points[i])
	}
	return clusters
}

// nearestCentroid returns the index of the centroid closest to p under dist,
// preferring the lowest index on ties. When no distance compares below
// math.MaxFloat64 (for example, all of them are NaN) it returns 0, so callers
// always receive a valid index for a non-empty centroids slice.
func nearestCentroid(p Point, centroids []Point, dist distanceFunc) int {
	best, bestDist := 0, math.MaxFloat64
	for i, c := range centroids {
		if d := dist(c, p); d < bestDist {
			best, bestDist = i, d
		}
	}
	return best
}

更新K均值点

Go 代码：
// updateAvgPoints returns the new centroids: centroid i is the arithmetic mean
// of the points in clusters[i].
//
// A cluster that received no points keeps its previous centroid
// (kMeans.avgPoints[i]). Dividing by a zero length would otherwise produce a
// NaN centroid; with the Euclidean metric every distance to it is then NaN,
// and NaN never compares less-than, so that cluster would stay empty for the
// rest of the run.
func (kMeans *KMeans) updateAvgPoints(clusters [][]Point) []Point {
	centroids := make([]Point, kMeans.k)
	for i, cluster := range clusters {
		if len(cluster) == 0 {
			if i < len(kMeans.avgPoints) {
				centroids[i] = kMeans.avgPoints[i]
			}
			continue
		}
		var sumX, sumY float64
		for _, p := range cluster {
			sumX += p.X
			sumY += p.Y
		}
		n := float64(len(cluster))
		centroids[i] = Point{X: sumX / n, Y: sumY / n}
	}
	return centroids
}

使用matplotlib绘制图像

Python 代码：
import matplotlib.pyplot as plt


def readFromFile(path='points.txt'):
    """Read the clusters written by main.go's writeToFile.

    The file holds one "x y" pair per line; a line containing only "----"
    ends the current cluster. Blank lines are ignored, and a "----" that
    closes an empty cluster is skipped, so empty clusters are not returned.

    Args:
        path: file to read; defaults to 'points.txt', the file main.go writes.

    Returns:
        A list of clusters, each a list of (x, y) float tuples.
    """
    clusters = []
    current_cluster = []

    with open(path, 'r') as file:
        for line in file:
            stripped = line.strip()
            if not stripped:
                # Tolerate stray empty lines instead of failing the float parse.
                continue
            if stripped == "----":
                if current_cluster:
                    clusters.append(current_cluster)
                    current_cluster = []
            else:
                x, y = map(float, stripped.split())
                current_cluster.append((x, y))

    if current_cluster:
        clusters.append(current_cluster)

    return clusters


clusters = readFromFile()

# One scatter call per cluster, so each cluster is drawn as a separate series.
for members in clusters:
    xs = [px for px, _ in members]
    ys = [py for _, py in members]
    plt.scatter(xs, ys, marker='.')

plt.xlabel('X')
plt.ylabel('Y')
plt.title('Clustered Scatter Plot')
plt.show()

使用sklearn计算

Python 代码：
import sys

import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

# Load the raw samples that main.go writes to rawpoints.txt: one "x y" pair per
# line. Blank lines are skipped instead of crashing the float parse.
points = []
with open('rawpoints.txt', 'r') as file:
    for line in file:
        if not line.strip():
            continue
        x, y = map(float, line.split())
        points.append([x, y])

X = np.array(points)

# Number of clusters; keep it equal to k in main.go. An optional first
# command-line argument overrides the default of 5.
n_clusters = int(sys.argv[1]) if len(sys.argv) > 1 else 5
# Pin n_init (number of runs with different centroid seeds) explicitly so the
# result does not depend on the installed scikit-learn version's default.
kmeans = KMeans(n_clusters=n_clusters, n_init=10)
kmeans.fit(X)
labels = kmeans.labels_
centroids = kmeans.cluster_centers_

plt.scatter(X[:, 0], X[:, 1], c=labels, cmap='viridis', marker='.')
plt.scatter(centroids[:, 0], centroids[:, 1], marker='x', s=10, color='red', label='Centroids')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('K-Means Clustering')
plt.legend()
plt.show()

完整代码

项目结构如下

项目目录：
KMeans-golang
│  go.mod
│  main.go
│  raw.py
│  show.py
└─tool
	kmeans.go

main.go

Go 代码：
package main

import (
	"fmt"
	"k-means/tool"
	"math"
	"math/rand"
	"os"
	"os/exec"
	"time"
)

// generateRandomPoints returns n points whose X and Y coordinates are each
// drawn from the global math/rand source (rand.Float64) and scaled by 100.
func generateRandomPoints(n int) []tool.Point {
	points := make([]tool.Point, n)
	for i := range points {
		// X is drawn before Y, matching the field order of the literal.
		x := rand.Float64() * 100
		y := rand.Float64() * 100
		points[i] = tool.Point{X: x, Y: y}
	}
	return points
}

// wrapper runs fun once and prints how long it took, labelled with name:
// first as a formatted time.Duration, then as a nanosecond count.
func wrapper(name string, fun func()) {
	begin := time.Now()
	fun()
	took := time.Now().Sub(begin)

	fmt.Printf("%s 函数执行时间:%s\n", name, took)
	fmt.Printf("%s 函数执行时间(纳秒):%dns\n", name, took.Nanoseconds())
}

// writeToFile writes the clusters to points.txt: one "x y" line per point,
// followed by a "----" line after every cluster (the format show.py parses).
func writeToFile(points [][]tool.Point) {
	writePointsFile("points.txt", func(file *os.File) error {
		for _, row := range points {
			for _, p := range row {
				if _, err := fmt.Fprintf(file, "%f %f\n", p.X, p.Y); err != nil {
					return err
				}
			}
			if _, err := fmt.Fprintf(file, "----\n"); err != nil {
				return err
			}
		}
		return nil
	})
}

// writeToFile2 writes the raw, unclustered samples to rawpoints.txt, one
// "x y" line per point (the format raw.py parses).
func writeToFile2(points []tool.Point) {
	writePointsFile("rawpoints.txt", func(file *os.File) error {
		for _, p := range points {
			if _, err := fmt.Fprintf(file, "%f %f\n", p.X, p.Y); err != nil {
				return err
			}
		}
		return nil
	})
}

// writePointsFile creates (or truncates) name, runs write against it, closes
// it, and reports the outcome on stdout. The Close error is checked rather
// than discarded by a deferred Close, so a file that fails to close cleanly
// is not reported as successfully written.
func writePointsFile(name string, write func(file *os.File) error) {
	file, err := os.Create(name)
	if err != nil {
		fmt.Println("Failed to create file:", err)
		return
	}
	writeErr := write(file)
	closeErr := file.Close()
	switch {
	case writeErr != nil:
		fmt.Println("Failed to write to file:", writeErr)
	case closeErr != nil:
		fmt.Println("Failed to close file:", closeErr)
	default:
		fmt.Println("Data written to file successfully.")
	}
}
// main generates random samples, clusters them with tool.KMeans, writes the
// raw samples and the resulting clusters to rawpoints.txt and points.txt, and
// then launches raw.py (sklearn, for comparison) and show.py (our clusters)
// so both plots can be compared. Each stage is timed with wrapper.
func main() {
	// rand.Seed is deprecated since Go 1.20 (the global source is seeded
	// automatically there); it is kept so older toolchains still produce
	// different samples on every run.
	rand.Seed(time.Now().UnixNano())
	// Sample data.
	var points []tool.Point
	wrapper("生成样本", func() {
		points = generateRandomPoints(100)
	})

	k := 5
	// Euclidean distance. math.Hypot computes Sqrt(dx*dx + dy*dy) while
	// avoiding intermediate overflow/underflow, and skips math.Pow.
	distance := func(p1, p2 tool.Point) float64 {
		return math.Hypot(p1.X-p2.X, p1.Y-p2.Y)
	}
	kMeansObj := new(tool.KMeans)
	kMeansObj.Init(k, points, distance)

	var (
		finalClusters [][]tool.Point
		count         int
	)

	wrapper("执行算法", func() {
		finalClusters, count = kMeansObj.Do(func(oldCentroids, newCentroids []tool.Point) bool {
			// Converged once no centroid moved by more than epsilon.
			epsilon := 0.000001
			for i := 0; i < len(oldCentroids); i++ {
				if distance(oldCentroids[i], newCentroids[i]) > epsilon {
					return false
				}
			}
			return true
		})
	})
	fmt.Println("count: ", count)
	wrapper("写入文件", func() {
		writeToFile2(points)
		writeToFile(finalClusters)
	})

	// Interpreter for the plotting scripts; override with KMEANS_PYTHON.
	python := os.Getenv("KMEANS_PYTHON")
	if python == "" {
		python = "C:\\Projects\\PycharmProjects\\deelLearn\\venv\\Scripts\\python.exe"
	}
	// Run the sklearn comparison concurrently with our own plot so both
	// windows are open at once. Waiting on rawDone keeps main from returning
	// (which would end the process) before raw.py has run to completion.
	rawDone := make(chan struct{})
	go func() {
		defer close(rawDone)
		runPython(python, "raw.py")
	}()
	runPython(python, "show.py")
	<-rawDone
}

// runPython runs script with the given Python interpreter, forwarding its
// output to this process so tracebacks are visible, and reports a failure to
// start or a non-zero exit. exec.Cmd.Run already waits for the process, so
// no separate Wait call is made (a second Wait only returns an error).
func runPython(python, script string) {
	cmd := exec.Command(python, script)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		fmt.Printf("running %s with %s: %v\n", script, python, err)
	}
}

raw.py

Python 代码：
import sys

import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

# Load the raw samples that main.go writes to rawpoints.txt: one "x y" pair per
# line. Blank lines are skipped instead of crashing the float parse.
points = []
with open('rawpoints.txt', 'r') as file:
    for line in file:
        if not line.strip():
            continue
        x, y = map(float, line.split())
        points.append([x, y])

X = np.array(points)

# Number of clusters; keep it equal to k in main.go. An optional first
# command-line argument overrides the default of 5.
n_clusters = int(sys.argv[1]) if len(sys.argv) > 1 else 5
# Pin n_init (number of runs with different centroid seeds) explicitly so the
# result does not depend on the installed scikit-learn version's default.
kmeans = KMeans(n_clusters=n_clusters, n_init=10)
kmeans.fit(X)
labels = kmeans.labels_
centroids = kmeans.cluster_centers_

plt.scatter(X[:, 0], X[:, 1], c=labels, cmap='viridis', marker='.')
plt.scatter(centroids[:, 0], centroids[:, 1], marker='x', s=10, color='red', label='Centroids')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('K-Means Clustering')
plt.legend()
plt.show()

show.py

Python 代码：
import matplotlib.pyplot as plt


def readFromFile(path='points.txt'):
    """Read the clusters written by main.go's writeToFile.

    The file holds one "x y" pair per line; a line containing only "----"
    ends the current cluster. Blank lines are ignored, and a "----" that
    closes an empty cluster is skipped, so empty clusters are not returned.

    Args:
        path: file to read; defaults to 'points.txt', the file main.go writes.

    Returns:
        A list of clusters, each a list of (x, y) float tuples.
    """
    clusters = []
    current_cluster = []

    with open(path, 'r') as file:
        for line in file:
            stripped = line.strip()
            if not stripped:
                # Tolerate stray empty lines instead of failing the float parse.
                continue
            if stripped == "----":
                if current_cluster:
                    clusters.append(current_cluster)
                    current_cluster = []
            else:
                x, y = map(float, stripped.split())
                current_cluster.append((x, y))

    if current_cluster:
        clusters.append(current_cluster)

    return clusters


clusters = readFromFile()

# One scatter call per cluster, so each cluster is drawn as a separate series.
for members in clusters:
    xs = [px for px, _ in members]
    ys = [py for _, py in members]
    plt.scatter(xs, ys, marker='.')

plt.xlabel('X')
plt.ylabel('Y')
plt.title('Clustered Scatter Plot')
plt.show()

tool/kmeans.go

Go 代码：
package tool

import (
	"math"
	"runtime"
	"sync"
)

// Point is a sample (or centroid) in the 2-D plane.
type Point struct {
	X, Y float64
}

// KMeans holds the configuration and running state of a k-means clustering.
// Set it up with Init, then call Do.
type KMeans struct {
	points       []Point      // samples to cluster, stored as passed to Init
	k            int          // number of clusters
	distanceFunc distanceFunc // metric used to find each point's nearest centroid
	avgPoints    []Point      // current centroids, one per cluster
}

// distanceFunc measures how far apart two points are.
type distanceFunc func(p1, p2 Point) float64

// Init configures kMeans to split points into k clusters, using distanceFunc
// as the metric. The points slice is stored as-is (not copied), and the
// centroid slice is allocated with k zero-valued entries; the centroids are
// seeded later, when Do runs.
func (kMeans *KMeans) Init(k int, points []Point, distanceFunc distanceFunc) {
	kMeans.k, kMeans.points, kMeans.distanceFunc = k, points, distanceFunc
	kMeans.avgPoints = make([]Point, k)
}

// Do runs the clustering and returns the final clusters together with the
// number of centroid updates that were applied.
//
// The centroids are first seeded by initializeAvgPoints. Each round then
// assigns every point to its nearest centroid (computeDistanceToAvgPoints)
// and computes a candidate set of centroids from those clusters
// (updateAvgPoints). checkFunc(current, candidate) decides whether to stop:
// if it returns true, that round's clusters are returned and the candidate
// centroids are discarded; otherwise the candidate becomes the current set
// and another round starts. There is no iteration cap, so the loop ends only
// when checkFunc returns true.
func (kMeans *KMeans) Do(checkFunc func(oldCentroids, newCentroids []Point) bool) ([][]Point, int) {
	kMeans.initializeAvgPoints()

	for updates := 0; ; updates++ {
		clusters := kMeans.computeDistanceToAvgPoints()
		candidate := kMeans.updateAvgPoints(clusters)
		if checkFunc(kMeans.avgPoints, candidate) {
			return clusters, updates
		}
		kMeans.avgPoints = candidate
	}
}

// initializeAvgPoints seeds the centroids with the first k points.
//
// Copying from the whole points slice (rather than slicing it to k) lets copy
// stop at whichever slice is shorter, so having fewer than k points no longer
// panics; centroids beyond len(points) keep their current value (the zero
// Point right after Init).
func (kMeans *KMeans) initializeAvgPoints() {
	copy(kMeans.avgPoints, kMeans.points)
}

// computeDistanceToAvgPoints assigns every point to its nearest centroid and
// returns the clusters: clusters[i] holds, in their original order, the points
// whose nearest centroid is kMeans.avgPoints[i]. Ties go to the lower index.
//
// Each point's search only reads kMeans.points, kMeans.avgPoints and
// kMeans.distanceFunc, none of which change while it runs, so the points are
// split into contiguous chunks handled by at most GOMAXPROCS goroutines. Every
// worker writes only its own entries of labels, so no locking is needed, and
// the clusters are assembled afterwards in point order so the result (and the
// floating-point sums computed from it) is deterministic.
func (kMeans *KMeans) computeDistanceToAvgPoints() [][]Point {
	points, centroids, dist := kMeans.points, kMeans.avgPoints, kMeans.distanceFunc
	clusters := make([][]Point, len(centroids))
	if len(points) == 0 || len(centroids) == 0 {
		// Nothing to assign, or no centroid to assign it to.
		return clusters
	}

	// labels[i] is the index of the centroid nearest to points[i].
	labels := make([]int, len(points))
	workers := runtime.GOMAXPROCS(0)
	if workers > len(points) {
		workers = len(points)
	}
	chunk := (len(points) + workers - 1) / workers

	var wg sync.WaitGroup
	for start := 0; start < len(points); start += chunk {
		end := start + chunk
		if end > len(points) {
			end = len(points)
		}
		wg.Add(1)
		go func(start, end int) {
			defer wg.Done()
			for i := start; i < end; i++ {
				labels[i] = nearestCentroid(points[i], centroids, dist)
			}
		}(start, end)
	}
	wg.Wait()

	for i, label := range labels {
		clusters[label] = append(clusters[label], points[i])
	}
	return clusters
}

// nearestCentroid returns the index of the centroid closest to p under dist,
// preferring the lowest index on ties. When no distance compares below
// math.MaxFloat64 (for example, all of them are NaN) it returns 0, so callers
// always receive a valid index for a non-empty centroids slice.
func nearestCentroid(p Point, centroids []Point, dist distanceFunc) int {
	best, bestDist := 0, math.MaxFloat64
	for i, c := range centroids {
		if d := dist(c, p); d < bestDist {
			best, bestDist = i, d
		}
	}
	return best
}

// updateAvgPoints returns the new centroids: centroid i is the arithmetic mean
// of the points in clusters[i].
//
// A cluster that received no points keeps its previous centroid
// (kMeans.avgPoints[i]). Dividing by a zero length would otherwise produce a
// NaN centroid; with the Euclidean metric every distance to it is then NaN,
// and NaN never compares less-than, so that cluster would stay empty for the
// rest of the run.
func (kMeans *KMeans) updateAvgPoints(clusters [][]Point) []Point {
	centroids := make([]Point, kMeans.k)
	for i, cluster := range clusters {
		if len(cluster) == 0 {
			if i < len(kMeans.avgPoints) {
				centroids[i] = kMeans.avgPoints[i]
			}
			continue
		}
		var sumX, sumY float64
		for _, p := range cluster {
			sumX += p.X
			sumY += p.Y
		}
		n := float64(len(cluster))
		centroids[i] = Point{X: sumX / n, Y: sumY / n}
	}
	return centroids
}
相关推荐
我只会发热4 分钟前
Java SE 与 Java EE:基础与进阶的探索之旅
java·开发语言·java-ee
LZXCyrus5 分钟前
【杂记】vLLM如何指定GPU单卡/多卡离线推理
人工智能·经验分享·python·深度学习·语言模型·llm·vllm
一直学习永不止步14 分钟前
LeetCode题练习与总结:最长回文串--409
java·数据结构·算法·leetcode·字符串·贪心·哈希表
懷淰メ14 分钟前
PyQt飞机大战游戏(附下载地址)
开发语言·python·qt·游戏·pyqt·游戏开发·pyqt5
我感觉。23 分钟前
【机器学习chp4】特征工程
人工智能·机器学习·主成分分析·特征工程
hummhumm28 分钟前
第 22 章 - Go语言 测试与基准测试
java·大数据·开发语言·前端·python·golang·log4j
YRr YRr32 分钟前
深度学习神经网络中的优化器的使用
人工智能·深度学习·神经网络
DieYoung_Alive32 分钟前
一篇文章了解机器学习(下)
人工智能·机器学习
夏沫的梦33 分钟前
生成式AI对产业的影响与冲击
人工智能·aigc
宁静@星空34 分钟前
006-自定义枚举注解
java·开发语言