K-Means 算法是常用的聚类算法。
其作用是在一个高维空间,将空间中的向量按照他们的分布划分为 K 类。
具体的做法是:
- 首先使用随机数随机生成 K 个初始值作为中心。
- 之后计算空间中的每个向量和这 K 个中心的距离(一般为欧氏距离)。那么就将空间中的所有向量划分成了 K 类。
- 之后计算每一类的均值,并将当前均值作为 K 个类的中心。
- 再计算每个向量和这新的 K 个中心的距离,重复上述步骤 2 和 3 ,直到没有新的点被重新划为新的类,也就是类中心位置不再发生变化。
下面使用Java 代码实现了当前算法:
java
import java.util.*;
import java.util.stream.Collectors;
/**
* 数据点类
*/
class Point {
private double x;
private double y;
private int clusterId;
public Point(double x, double y) {
this.x = x;
this.y = y;
this.clusterId = -1; // 初始未分配
}
public double getX() { return x; }
public double getY() { return y; }
public int getClusterId() { return clusterId; }
public void setClusterId(int clusterId) { this.clusterId = clusterId; }
/**
* 计算两点之间的欧氏距离
*/
public double distanceTo(Point other) {
double dx = this.x - other.x;
double dy = this.y - other.y;
return Math.sqrt(dx * dx + dy * dy);
}
@Override
public String toString() {
return String.format("(%.2f, %.2f) -> Cluster %d", x, y, clusterId);
}
}
/**
* KMeans 聚类算法实现
*/
public class KMeans {
private List<Point> points; // 所有数据点
private List<Point> centroids; // 聚类中心
private int k; // 簇的数量
private int maxIterations; // 最大迭代次数
public KMeans(int k, int maxIterations) {
this.k = k;
this.maxIterations = maxIterations;
this.points = new ArrayList<>();
this.centroids = new ArrayList<>();
}
/**
* 添加数据点
*/
public void addPoint(Point point) {
points.add(point);
}
/**
* 添加多个数据点
*/
public void addPoints(List<Point> points) {
this.points.addAll(points);
}
/**
* 随机初始化聚类中心
*/
private void initializeCentroids() {
centroids.clear();
List<Point> shuffledPoints = new ArrayList<>(points);
Collections.shuffle(shuffledPoints);
for (int i = 0; i < k && i < shuffledPoints.size(); i++) {
Point centroid = new Point(shuffledPoints.get(i).getX(),
shuffledPoints.get(i).getY());
centroids.add(centroid);
}
}
/**
* 将点分配到最近的聚类中心
*/
private void assignPointsToClusters() {
for (Point point : points) {
double minDistance = Double.MAX_VALUE;
int closestCentroidId = -1;
for (int i = 0; i < centroids.size(); i++) {
double distance = point.distanceTo(centroids.get(i));
if (distance < minDistance) {
minDistance = distance;
closestCentroidId = i;
}
}
point.setClusterId(closestCentroidId);
}
}
/**
* 重新计算聚类中心
* @return 中心点是否变化
*/
private boolean updateCentroids() {
boolean changed = false;
for (int i = 0; i < k; i++) {
final int clusterId = i;
// 获取属于当前簇的所有点
List<Point> clusterPoints = points.stream()
.filter(p -> p.getClusterId() == clusterId)
.collect(Collectors.toList());
if (clusterPoints.isEmpty()) {
continue; // 避免空簇
}
// 计算新的中心点
double sumX = 0, sumY = 0;
for (Point p : clusterPoints) {
sumX += p.getX();
sumY += p.getY();
}
double newX = sumX / clusterPoints.size();
double newY = sumY / clusterPoints.size();
// 检查中心点是否变化
if (Math.abs(centroids.get(i).getX() - newX) > 0.001 ||
Math.abs(centroids.get(i).getY() - newY) > 0.001) {
changed = true;
}
centroids.set(i, new Point(newX, newY));
}
return changed;
}
/**
* 执行 KMeans 聚类
*/
public void fit() {
if (points.size() < k) {
throw new IllegalArgumentException("数据点数量不能少于簇的数量");
}
// 1. 初始化中心点
initializeCentroids();
// 2. 迭代优化
for (int iteration = 0; iteration < maxIterations; iteration++) {
// 分配点到最近的簇
assignPointsToClusters();
// 更新中心点
boolean centroidsChanged = updateCentroids();
// 如果中心点不再变化,提前结束
if (!centroidsChanged) {
System.out.println("在第 " + (iteration + 1) + " 次迭代后收敛");
break;
}
if (iteration == maxIterations - 1) {
System.out.println("达到最大迭代次数: " + maxIterations);
}
}
}
/**
* 计算 SSE(误差平方和)
*/
public double calculateSSE() {
double sse = 0.0;
for (Point point : points) {
Point centroid = centroids.get(point.getClusterId());
sse += Math.pow(point.distanceTo(centroid), 2);
}
return sse;
}
/**
* 打印聚类结果
*/
public void printResults() {
System.out.println("\n=== KMeans 聚类结果 ===");
System.out.println("簇数量 K = " + k);
System.out.println("SSE = " + String.format("%.4f", calculateSSE()));
System.out.println("\n聚类中心:");
for (int i = 0; i < centroids.size(); i++) {
System.out.println("簇 " + i + " 中心: " + centroids.get(i));
}
System.out.println("\n数据点分配:");
Map<Integer, List<Point>> clusters = points.stream()
.collect(Collectors.groupingBy(Point::getClusterId));
for (Map.Entry<Integer, List<Point>> entry : clusters.entrySet()) {
System.out.println("\n簇 " + entry.getKey() + " (" + entry.getValue().size() + " 个点):");
for (Point p : entry.getValue()) {
System.out.println(" " + p);
}
}
}
/**
* 获取聚类结果
*/
public List<Point> getPoints() {
return points;
}
public List<Point> getCentroids() {
return centroids;
}
/**
* 主函数 - 示例使用
*/
public static void main(String[] args) {
// 创建 KMeans 实例
KMeans kmeans = new KMeans(3, 100);
// 生成示例数据(三个明显的簇)
Random rand = new Random(42);
List<Point> sampleData = new ArrayList<>();
// 簇1:中心在 (2, 2)
for (int i = 0; i < 20; i++) {
double x = 2 + rand.nextGaussian() * 0.5;
double y = 2 + rand.nextGaussian() * 0.5;
sampleData.add(new Point(x, y));
}
// 簇2:中心在 (8, 2)
for (int i = 0; i < 20; i++) {
double x = 8 + rand.nextGaussian() * 0.5;
double y = 2 + rand.nextGaussian() * 0.5;
sampleData.add(new Point(x, y));
}
// 簇3:中心在 (5, 8)
for (int i = 0; i < 20; i++) {
double x = 5 + rand.nextGaussian() * 0.5;
double y = 8 + rand.nextGaussian() * 0.5;
sampleData.add(new Point(x, y));
}
// 添加数据并执行聚类
kmeans.addPoints(sampleData);
kmeans.fit();
// 打印结果
kmeans.printResults();
}
}