【机器学习】K-Means 算法

K-Means 算法是常用的聚类算法。

其作用是在一个高维空间,将空间中的向量按照他们的分布划分为 K 类。

具体的做法是:

  1. 首先使用随机数随机生成 K 个初始值作为中心。
  2. 之后计算空间中的每个向量和这 K 个中心的距离(一般为欧氏距离)。那么就将空间中的所有向量划分成了 K 类。
  3. 之后计算每一类的均值,并将当前均值作为 K 个类的中心。
  4. 再计算每个向量和这新的 K 个中心的距离,重复上述步骤 2 和 3 ,直到没有新的点被重新划为新的类,也就是类中心位置不再发生变化。

下面使用Java 代码实现了当前算法:

java 复制代码
import java.util.*;
import java.util.stream.Collectors;

/**
 * 数据点类
 */
class Point {
    private double x;
    private double y;
    private int clusterId;
    
    public Point(double x, double y) {
        this.x = x;
        this.y = y;
        this.clusterId = -1; // 初始未分配
    }
    
    public double getX() { return x; }
    public double getY() { return y; }
    public int getClusterId() { return clusterId; }
    public void setClusterId(int clusterId) { this.clusterId = clusterId; }
    
    /**
     * 计算两点之间的欧氏距离
     */
    public double distanceTo(Point other) {
        double dx = this.x - other.x;
        double dy = this.y - other.y;
        return Math.sqrt(dx * dx + dy * dy);
    }
    
    @Override
    public String toString() {
        return String.format("(%.2f, %.2f) -> Cluster %d", x, y, clusterId);
    }
}

/**
 * KMeans 聚类算法实现
 */
public class KMeans {
    private List<Point> points;          // 所有数据点
    private List<Point> centroids;       // 聚类中心
    private int k;                       // 簇的数量
    private int maxIterations;           // 最大迭代次数
    
    public KMeans(int k, int maxIterations) {
        this.k = k;
        this.maxIterations = maxIterations;
        this.points = new ArrayList<>();
        this.centroids = new ArrayList<>();
    }
    
    /**
     * 添加数据点
     */
    public void addPoint(Point point) {
        points.add(point);
    }
    
    /**
     * 添加多个数据点
     */
    public void addPoints(List<Point> points) {
        this.points.addAll(points);
    }
    
    /**
     * 随机初始化聚类中心
     */
    private void initializeCentroids() {
        centroids.clear();
        List<Point> shuffledPoints = new ArrayList<>(points);
        Collections.shuffle(shuffledPoints);
        
        for (int i = 0; i < k && i < shuffledPoints.size(); i++) {
            Point centroid = new Point(shuffledPoints.get(i).getX(), 
                                      shuffledPoints.get(i).getY());
            centroids.add(centroid);
        }
    }
    
    /**
     * 将点分配到最近的聚类中心
     */
    private void assignPointsToClusters() {
        for (Point point : points) {
            double minDistance = Double.MAX_VALUE;
            int closestCentroidId = -1;
            
            for (int i = 0; i < centroids.size(); i++) {
                double distance = point.distanceTo(centroids.get(i));
                if (distance < minDistance) {
                    minDistance = distance;
                    closestCentroidId = i;
                }
            }
            
            point.setClusterId(closestCentroidId);
        }
    }
    
    /**
     * 重新计算聚类中心
     * @return 中心点是否变化
     */
    private boolean updateCentroids() {
        boolean changed = false;
        
        for (int i = 0; i < k; i++) {
            final int clusterId = i;
            // 获取属于当前簇的所有点
            List<Point> clusterPoints = points.stream()
                .filter(p -> p.getClusterId() == clusterId)
                .collect(Collectors.toList());
            
            if (clusterPoints.isEmpty()) {
                continue; // 避免空簇
            }
            
            // 计算新的中心点
            double sumX = 0, sumY = 0;
            for (Point p : clusterPoints) {
                sumX += p.getX();
                sumY += p.getY();
            }
            
            double newX = sumX / clusterPoints.size();
            double newY = sumY / clusterPoints.size();
            
            // 检查中心点是否变化
            if (Math.abs(centroids.get(i).getX() - newX) > 0.001 ||
                Math.abs(centroids.get(i).getY() - newY) > 0.001) {
                changed = true;
            }
            
            centroids.set(i, new Point(newX, newY));
        }
        
        return changed;
    }
    
    /**
     * 执行 KMeans 聚类
     */
    public void fit() {
        if (points.size() < k) {
            throw new IllegalArgumentException("数据点数量不能少于簇的数量");
        }
        
        // 1. 初始化中心点
        initializeCentroids();
        
        // 2. 迭代优化
        for (int iteration = 0; iteration < maxIterations; iteration++) {
            // 分配点到最近的簇
            assignPointsToClusters();
            
            // 更新中心点
            boolean centroidsChanged = updateCentroids();
            
            // 如果中心点不再变化,提前结束
            if (!centroidsChanged) {
                System.out.println("在第 " + (iteration + 1) + " 次迭代后收敛");
                break;
            }
            
            if (iteration == maxIterations - 1) {
                System.out.println("达到最大迭代次数: " + maxIterations);
            }
        }
    }
    
    /**
     * 计算 SSE(误差平方和)
     */
    public double calculateSSE() {
        double sse = 0.0;
        
        for (Point point : points) {
            Point centroid = centroids.get(point.getClusterId());
            sse += Math.pow(point.distanceTo(centroid), 2);
        }
        
        return sse;
    }
    
    /**
     * 打印聚类结果
     */
    public void printResults() {
        System.out.println("\n=== KMeans 聚类结果 ===");
        System.out.println("簇数量 K = " + k);
        System.out.println("SSE = " + String.format("%.4f", calculateSSE()));
        
        System.out.println("\n聚类中心:");
        for (int i = 0; i < centroids.size(); i++) {
            System.out.println("簇 " + i + " 中心: " + centroids.get(i));
        }
        
        System.out.println("\n数据点分配:");
        Map<Integer, List<Point>> clusters = points.stream()
            .collect(Collectors.groupingBy(Point::getClusterId));
        
        for (Map.Entry<Integer, List<Point>> entry : clusters.entrySet()) {
            System.out.println("\n簇 " + entry.getKey() + " (" + entry.getValue().size() + " 个点):");
            for (Point p : entry.getValue()) {
                System.out.println("  " + p);
            }
        }
    }
    
    /**
     * 获取聚类结果
     */
    public List<Point> getPoints() {
        return points;
    }
    
    public List<Point> getCentroids() {
        return centroids;
    }
    
    /**
     * 主函数 - 示例使用
     */
    public static void main(String[] args) {
        // 创建 KMeans 实例
        KMeans kmeans = new KMeans(3, 100);
        
        // 生成示例数据(三个明显的簇)
        Random rand = new Random(42);
        List<Point> sampleData = new ArrayList<>();
        
        // 簇1:中心在 (2, 2)
        for (int i = 0; i < 20; i++) {
            double x = 2 + rand.nextGaussian() * 0.5;
            double y = 2 + rand.nextGaussian() * 0.5;
            sampleData.add(new Point(x, y));
        }
        
        // 簇2:中心在 (8, 2)
        for (int i = 0; i < 20; i++) {
            double x = 8 + rand.nextGaussian() * 0.5;
            double y = 2 + rand.nextGaussian() * 0.5;
            sampleData.add(new Point(x, y));
        }
        
        // 簇3:中心在 (5, 8)
        for (int i = 0; i < 20; i++) {
            double x = 5 + rand.nextGaussian() * 0.5;
            double y = 8 + rand.nextGaussian() * 0.5;
            sampleData.add(new Point(x, y));
        }
        
        // 添加数据并执行聚类
        kmeans.addPoints(sampleData);
        kmeans.fit();
        
        // 打印结果
        kmeans.printResults();
    }
}
相关推荐
A923A2 小时前
【洛谷刷题 | 第十天】
算法·洛谷·sprintf·sscanf
Mr_Xuhhh2 小时前
LeetCode 热题 100 刷题笔记:数组与排列的经典解法
数据结构·算法·leetcode
老四啊laosi2 小时前
[双指针] 3. 力扣--快乐数
算法·leetcode·快慢指针
rit84324992 小时前
利用随机有限集(RFS)理论结合ILQR和MPC控制蜂群的MATLAB实现
算法·matlab
会编程的土豆2 小时前
leetcode hot 100 之哈希
算法·leetcode·哈希算法
秋天的一阵风2 小时前
【LeetCode 刷题系列|第 3 篇】详解大数相加:从模拟竖式到简洁写法的优化之路🔢
前端·算法·面试
qwehjk20082 小时前
分布式计算C++库
开发语言·c++·算法
m0_716765232 小时前
C++提高编程--仿函数、常用遍历算法(for_each、transform)详解
java·开发语言·c++·经验分享·算法·青少年编程·visual studio
寻寻觅觅☆2 小时前
东华OJ-基础题-59-倒数数列(C++)
开发语言·c++·算法