无聊时写个k-means

采用欧氏距离（Euclidean distance）作为相似度度量：随机生成 n 个三维点 (x, y, z)，并将其聚类成 k 类。


import java.util.*;

public class KMeans {

    /** Upper bound on assign/update rounds; integer-truncated means can otherwise oscillate forever. */
    private static final int MAX_ITERATIONS = 1000;

    /** Exclusive upper bound for randomly generated coordinates, i.e. values fall in [0, 100). */
    private static final int COORDINATE_BOUND = 100;

    /**
     * Generates {@code n} random 3-D integer points and clusters them into {@code k} groups with
     * k-means (squared Euclidean distance), then prints every cluster.
     *
     * @param args optional {@code n k}; when exactly two arguments are not given, 20 points and 3
     *     clusters are used
     * @throws NumberFormatException if an argument is not an integer
     * @throws IllegalArgumentException unless {@code 1 <= k <= n}
     */
    public static void main(String[] args) {
        int n = 20;
        int k = 3;
        if (args.length == 2) {
            n = Integer.parseInt(args[0]);
            k = Integer.parseInt(args[1]);
        }
        // k > n would make distinct-center selection impossible; k < 1 or n < 1 is meaningless.
        if (n < 1 || k < 1 || k > n) {
            throw new IllegalArgumentException("need 1 <= k <= n, got n=" + n + ", k=" + k);
        }

        Random random = new Random();
        Point[] points = randomPoints(n, COORDINATE_BOUND, random);
        Point[] centers = pickInitialCenters(points, k, random);
        List<List<Point>> category = cluster(points, centers, MAX_ITERATIONS);

        // Output
        System.out.println("分成 " + k + "类的结果:");
        for (int i = 0; i < category.size(); i++) {
            System.out.println("第 " + i + "类:" + category.get(i));
        }
    }

    /** Creates {@code n} points whose coordinates are drawn uniformly from [0, bound). */
    private static Point[] randomPoints(int n, int bound, Random random) {
        Point[] points = new Point[n];
        for (int i = 0; i < n; i++) {
            Point p = new Point();
            p.x = random.nextInt(bound);
            p.y = random.nextInt(bound);
            p.z = random.nextInt(bound);
            points[i] = p;
        }
        return points;
    }

    /**
     * Copies {@code k} points taken from distinct random indices to serve as initial centers.
     * Requires {@code 1 <= k <= points.length}. The returned centers are fresh objects, so moving
     * them never mutates the input points.
     */
    private static Point[] pickInitialCenters(Point[] points, int k, Random random) {
        int[] indices = new int[points.length];
        for (int i = 0; i < indices.length; i++) {
            indices[i] = i;
        }
        Point[] centers = new Point[k];
        for (int i = 0; i < k; i++) {
            // Partial Fisher-Yates shuffle: swap a not-yet-chosen index into slot i.
            int j = i + random.nextInt(indices.length - i);
            int tmp = indices[i];
            indices[i] = indices[j];
            indices[j] = tmp;

            Point source = points[indices[i]];
            Point center = new Point();
            center.x = source.x;
            center.y = source.y;
            center.z = source.z;
            centers[i] = center;
        }
        return centers;
    }

    /**
     * Runs Lloyd's iterations, updating {@code centers} in place, and returns the clusters that
     * correspond to the final centers. Stops once no center moves or after {@code maxIterations}
     * update rounds, whichever comes first.
     */
    private static List<List<Point>> cluster(Point[] points, Point[] centers, int maxIterations) {
        List<List<Point>> clusters = assign(points, centers);
        for (int iteration = 0; iteration < maxIterations; iteration++) {
            if (!updateCenters(clusters, centers)) {
                break; // converged: the current assignment already matches the centers
            }
            clusters = assign(points, centers);
        }
        return clusters;
    }

    /** Puts each point into the cluster of its nearest center; ties go to the lower center index. */
    private static List<List<Point>> assign(Point[] points, Point[] centers) {
        List<List<Point>> clusters = new ArrayList<>(centers.length);
        for (int i = 0; i < centers.length; i++) {
            clusters.add(new ArrayList<>());
        }
        for (Point p : points) {
            long best = Long.MAX_VALUE;
            int closest = 0;
            for (int j = 0; j < centers.length; j++) {
                long d = squaredDistance(p, centers[j]);
                if (d < best) {
                    best = d;
                    closest = j;
                }
            }
            clusters.get(closest).add(p);
        }
        return clusters;
    }

    /**
     * Squared Euclidean distance, computed in {@code long} so large coordinates neither overflow
     * nor saturate (the square root is unnecessary for nearest-center comparison).
     */
    private static long squaredDistance(Point a, Point b) {
        long dx = (long) a.x - b.x;
        long dy = (long) a.y - b.y;
        long dz = (long) a.z - b.z;
        return dx * dx + dy * dy + dz * dz;
    }

    /**
     * Moves each center to the integer (truncated) mean of its cluster.
     *
     * @return {@code true} if at least one center changed position
     */
    private static boolean updateCenters(List<List<Point>> clusters, Point[] centers) {
        boolean moved = false;
        for (int i = 0; i < centers.length; i++) {
            List<Point> members = clusters.get(i);
            if (members.isEmpty()) {
                continue; // keep the previous center rather than dividing by zero
            }
            long sumX = 0;
            long sumY = 0;
            long sumZ = 0;
            for (Point p : members) {
                sumX += p.x;
                sumY += p.y;
                sumZ += p.z;
            }
            int size = members.size();
            int x = (int) (sumX / size);
            int y = (int) (sumY / size);
            int z = (int) (sumZ / size);

            Point center = centers[i];
            if (center.x != x || center.y != y || center.z != z) {
                center.x = x;
                center.y = y;
                center.z = z;
                moved = true;
            }
        }
        return moved;
    }
}
class Point {
    // Mutable, package-private coordinates; callers assign them directly after construction.
    int x;
    int y;
    int z;

    /**
     * Renders this point as {@code (x, y, z)}, e.g. {@code (1, -2, 30)}.
     */
    @Override
    public String toString() {
        return new StringBuilder()
                .append('(').append(x)
                .append(", ").append(y)
                .append(", ").append(z)
                .append(')')
                .toString();
    }
}

k-means 算法（欧氏距离）

相关文章: