import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Random;
/**
 * A point in the 2-D plane with value-based equality.
 *
 * <p>The coordinates are mutable, package-private fields because other code in this file reads
 * them directly. Because they are mutable, do not modify a {@code Point} while it is stored as a
 * key in a hash-based collection.
 */
class Point {
    /** Horizontal coordinate. */
    double x;
    /** Vertical coordinate. */
    double y;

    /**
     * Creates a point at the given coordinates.
     *
     * @param x horizontal coordinate
     * @param y vertical coordinate
     */
    public Point(double x, double y) {
        this.x = x;
        this.y = y;
    }

    /**
     * Returns true when {@code o} is a {@code Point} with the same coordinates.
     *
     * <p>Coordinates are compared with {@link Double#compare}, matching {@link Double#equals}:
     * {@code NaN} equals {@code NaN}, and {@code 0.0} differs from {@code -0.0}.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof Point)) {
            return false;
        }
        Point other = (Point) o;
        return Double.compare(x, other.x) == 0 && Double.compare(y, other.y) == 0;
    }

    /** Returns a hash code consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return Objects.hash(x, y);
    }

    /** Returns the point formatted as {@code (x, y)}. */
    @Override
    public String toString() {
        return "(" + x + ", " + y + ")";
    }
}
/**
 * K-means clustering of 2-D {@link Point}s using Lloyd's algorithm.
 *
 * <p>Construct the clusterer with the number of clusters and the input points, call
 * {@link #run()}, then read the result with {@link #getCentroids()}. Instances are not thread-safe.
 */
public class KMeansClustering {
    /**
     * Iteration cap used by the two-argument constructor. Lloyd's algorithm normally stops early
     * once centroids stop moving; the cap guarantees termination when random re-seeding of empty
     * clusters keeps the centroids from settling.
     */
    private static final int DEFAULT_MAX_ITERATIONS = 300;

    private final int k;
    private final List<Point> points;
    private final int maxIterations;
    private final Random random;
    private List<Point> centroids;

    /**
     * Creates a clusterer with the default iteration cap and an unseeded random source.
     *
     * @param k number of clusters; must be positive
     * @param points input points; the list is read (not copied) each time {@link #run()} is called
     * @throws IllegalArgumentException if {@code k <= 0}
     * @throws NullPointerException if {@code points} is null
     */
    public KMeansClustering(int k, List<Point> points) {
        this(k, points, DEFAULT_MAX_ITERATIONS, new Random());
    }

    /**
     * Creates a clusterer with an explicit iteration cap and random source. Passing a seeded
     * {@link Random} makes runs reproducible.
     *
     * @param k number of clusters; must be positive
     * @param points input points; the list is read (not copied) each time {@link #run()} is called
     * @param maxIterations maximum number of refinement iterations; must be positive
     * @param random source of randomness for initialization and empty-cluster re-seeding
     * @throws IllegalArgumentException if {@code k <= 0} or {@code maxIterations <= 0}
     * @throws NullPointerException if {@code points} or {@code random} is null
     */
    public KMeansClustering(int k, List<Point> points, int maxIterations, Random random) {
        if (k <= 0) {
            throw new IllegalArgumentException("k must be positive: " + k);
        }
        if (maxIterations <= 0) {
            throw new IllegalArgumentException("maxIterations must be positive: " + maxIterations);
        }
        this.k = k;
        this.points = Objects.requireNonNull(points, "points");
        this.maxIterations = maxIterations;
        this.random = Objects.requireNonNull(random, "random");
        this.centroids = new ArrayList<>();
    }

    /**
     * Runs the clustering and stores the resulting centroids. Each call starts over from a fresh
     * random initialization.
     *
     * <p>Stops when an iteration leaves every centroid at exactly the same coordinates, or after
     * {@code maxIterations} iterations, whichever comes first.
     *
     * @throws IllegalArgumentException if there are fewer input points than clusters
     */
    public void run() {
        if (points.size() < k) {
            throw new IllegalArgumentException(
                    "need at least k=" + k + " points, got " + points.size());
        }
        initializeCentroids();
        for (int iteration = 0; iteration < maxIterations; iteration++) {
            List<Point> newCentroids = recomputeCentroids(assignToClusters());
            boolean converged = sameCoordinates(centroids, newCentroids);
            centroids = newCentroids;
            if (converged) {
                return;
            }
        }
    }

    /** Groups every input point under the index of its nearest current centroid. */
    private List<List<Point>> assignToClusters() {
        List<List<Point>> clusters = new ArrayList<>(k);
        for (int i = 0; i < k; i++) {
            clusters.add(new ArrayList<Point>());
        }
        for (Point point : points) {
            clusters.get(findClosestCentroid(point)).add(point);
        }
        return clusters;
    }

    /** Computes one new centroid per cluster: the mean, or a random input point if empty. */
    private List<Point> recomputeCentroids(List<List<Point>> clusters) {
        List<Point> result = new ArrayList<>(k);
        for (List<Point> cluster : clusters) {
            if (cluster.isEmpty()) {
                // An empty cluster has no mean; re-seed it from a random input point so that
                // exactly k centroids are always maintained.
                result.add(copyOf(points.get(random.nextInt(points.size()))));
            } else {
                result.add(mean(cluster));
            }
        }
        return result;
    }

    /** Returns the arithmetic mean of a non-empty cluster. */
    private static Point mean(List<Point> cluster) {
        double sumX = 0;
        double sumY = 0;
        for (Point p : cluster) {
            sumX += p.x;
            sumY += p.y;
        }
        return new Point(sumX / cluster.size(), sumY / cluster.size());
    }

    /**
     * Returns true when both lists hold centroids with identical coordinates, position by position.
     * Coordinates are compared directly rather than through {@code Point.equals}, which is not
     * relied on to be value-based.
     */
    private static boolean sameCoordinates(List<Point> before, List<Point> after) {
        if (before.size() != after.size()) {
            return false;
        }
        for (int i = 0; i < before.size(); i++) {
            Point a = before.get(i);
            Point b = after.get(i);
            if (Double.compare(a.x, b.x) != 0 || Double.compare(a.y, b.y) != 0) {
                return false;
            }
        }
        return true;
    }

    /** Returns the index of the centroid nearest to {@code point}; ties go to the lower index. */
    private int findClosestCentroid(Point point) {
        double minDistance = Double.POSITIVE_INFINITY;
        int closestCentroidIndex = 0;
        for (int i = 0; i < centroids.size(); i++) {
            double distance = squaredDistance(point, centroids.get(i));
            if (distance < minDistance) {
                minDistance = distance;
                closestCentroidIndex = i;
            }
        }
        return closestCentroidIndex;
    }

    /**
     * Squared Euclidean distance. Used for nearest-centroid search because it orders centroids
     * the same way as true distance, without the square root.
     */
    private static double squaredDistance(Point p1, Point p2) {
        double dx = p1.x - p2.x;
        double dy = p1.y - p2.y;
        return dx * dx + dy * dy;
    }

    /**
     * Picks k distinct input points, sampled without replacement, as the initial centroids.
     * Sampling with replacement could pick the same point twice, which leaves a duplicate
     * centroid whose cluster is always empty.
     */
    private void initializeCentroids() {
        List<Integer> indices = new ArrayList<>(points.size());
        for (int i = 0; i < points.size(); i++) {
            indices.add(i);
        }
        Collections.shuffle(indices, random);
        List<Point> initial = new ArrayList<>(k);
        for (int i = 0; i < k; i++) {
            initial.add(copyOf(points.get(indices.get(i))));
        }
        centroids = initial;
    }

    /** Copies a point so centroids never alias caller-owned input points. */
    private static Point copyOf(Point p) {
        return new Point(p.x, p.y);
    }

    /**
     * Returns the centroids computed by the most recent {@link #run()}. The list is empty if
     * {@code run()} has not been called.
     *
     * @return a copy of the centroid list; modifying it does not affect this clusterer
     */
    public List<Point> getCentroids() {
        return new ArrayList<>(centroids);
    }

    /** Clusters a small hard-coded data set into three groups and prints the centroids. */
    public static void main(String[] args) {
        List<Point> points = new ArrayList<>();
        points.add(new Point(1, 1));
        points.add(new Point(1.5, 2));
        points.add(new Point(3, 4));
        points.add(new Point(5, 7));
        points.add(new Point(3.5, 5));
        points.add(new Point(4.5, 5));
        points.add(new Point(3.5, 4.5));
        int k = 3;
        KMeansClustering kMeans = new KMeansClustering(k, points);
        kMeans.run();
        List<Point> centroids = kMeans.getCentroids();
        for (int i = 0; i < centroids.size(); i++) {
            System.out.println("Centroid " + (i + 1) + ": (" + centroids.get(i).x + ", " + centroids.get(i).y + ")");
        }
    }
}