數據挖掘-聚類-K-means算法Java實現

jopen 11年前發布 | 119K 次閱讀 算法

K-Means算法是最古老也是應用最廣泛的聚類算法,它使用質心定義原型,質心是一組點的均值,通常該算法用于n維連續空間中的對象。


K-Means算法流程

step1:選擇K個點作為初始質心

step2:repeat

               將每個點指派到最近的質心,形成K個簇

               重新計算每個簇的質心

            until 質心不在變化 


例如下圖的樣本集,初始選擇是三個質心比較集中,但是迭代3次之后,質心趨于穩定,并將樣本集分為3部分
數據挖掘-聚類-K-means算法Java實現

我們對每一個步驟都進行分析

step1:選擇K個點作為初始質心

這一步首先要知道K的值,也就是說K是手動設置的,而不是像EM算法那樣自動聚類成n個簇

其次,如何選擇初始質心

     最簡單的方式無異于,隨機選取質心了,然后多次運行,取效果最好的那個結果。這個方法,簡單但不見得有效,有很大的可能是得到局部最優。

     另一種復雜的方式是,隨機選取一個質心,然后計算離這個質心最遠的樣本點,對于每個后繼質心都選取已經選取過的質心的最遠點。使用這種方式,可以確保質心是隨機的,并且是散開的。


step2:repeat

               將每個點指派到最近的質心,形成K個簇

               重新計算每個簇的質心

            until 質心不在變化 

如何定義最近的概念,對于歐式空間中的點,可以使用歐式空間,對于文檔可以用余弦相似性等等。對于給定的數據,可能適應與多種合適的鄰近性度量。


其他問題

離群點的處理

離群點可能過度影響簇的發現,導致簇的最終發布會與我們的預想有較大出入,所以提前發現并剔除離群點是有必要的。

在我的工作中,是利用方差來剔除離群點,結果顯示效果非常好。


簇分裂和簇合并

使用較大的K,往往會使得聚類的結果看上去更加合理,但很多情況下,我們并不想增加簇的個數。

這時可以交替采用簇分裂和簇合并。這種方式可以避開局部極小,并且能夠得到具有期望個數簇的結果。


貼上代碼java版,以后有時間寫個python版的
抽象了點,簇,和距離
Point.class

public class Point {
private double x;
private double y;
private int id;
private boolean beyond;//標識是否屬于樣本

    public Point(int id, double x, double y) {  
        this.id = id;  
        this.x = x;  
        this.y = y;  
        this.beyond = true;  
    }  

    public Point(int id, double x, double y, boolean beyond) {  
        this.id = id;  
        this.x = x;  
        this.y = y;  
        this.beyond = beyond;  
    }  

    public double getX() {  
        return x;  
    }  

    public double getY() {  
        return y;  
    }  

    public int getId() {  
        return id;  
    }  

    public boolean isBeyond() {  
        return beyond;  
    }  

    @Override  
    public String toString() {  
        return "Point{" +  
                "id=" + id +  
                ", x=" + x +  
                ", y=" + y +  
                '}';  
    }  

    @Override  
    public boolean equals(Object o) {  
        if (this == o) return true;  
        if (o == null || getClass() != o.getClass()) return false;  

        Point point = (Point) o;  

        if (Double.compare(point.x, x) != 0) return false;  
        if (Double.compare(point.y, y) != 0) return false;  

        return true;  
    }  

    @Override  
    public int hashCode() {  
        int result;  
        long temp;  
        temp = x != +0.0d ? Double.doubleToLongBits(x) : 0L;  
        result = (int) (temp ^ (temp >>> 32));  
        temp = y != +0.0d ? Double.doubleToLongBits(y) : 0L;  
        result = 31 * result + (int) (temp ^ (temp >>> 32));  
        return result;  
    }  
}</pre>Cluster.class <br />

public class Cluster {
private int id;//標識
private Point center;//中心
private List members = new ArrayList();//成員

    public Cluster(int id, Point center) {  
        this.id = id;  
        this.center = center;  
    }  

    public Cluster(int id, Point center, List<point> members) {  
        this.id = id;  
        this.center = center;  
        this.members = members;  
    }  

    public void addPoint(Point newPoint) {  
        if (!members.contains(newPoint))  
            members.add(newPoint);  
        else  
            throw new IllegalStateException("試圖處理同一個樣本數據!");  
    }  

    public int getId() {  
        return id;  
    }  

    public Point getCenter() {  
        return center;  
    }  

    public void setCenter(Point center) {  
        this.center = center;  
    }  

    public List<point> getMembers() {  
        return members;  
    }  

    @Override  
    public String toString() {  
        return "Cluster{" +  
                "id=" + id +  
                ", center=" + center +  
                ", members=" + members +  
                "}";  
    }  
}</point></point></point></point></pre>抽象的距離,可以具體實現為歐式,曼式或其他距離公式 <br />

public abstract class AbstractDistance {  
        abstract public double getDis(Point p1, Point p2);  
    }
點對

public class Distence implements Comparable {
private Point source;
private Point dest;
private double dis;
private AbstractDistance distance;

    public Distence(Point source, Point dest, AbstractDistance distance) {  
        this.source = source;  
        this.dest = dest;  
        this.distance = distance;  
        dis = distance.getDis(source, dest);  
    }  

    public Point getSource() {  
        return source;  
    }  

    public Point getDest() {  
        return dest;  
    }  

    public double getDis() {  
        return dis;  
    }  

    @Override  
    public int compareTo(Distence o) {  
        if (o.getDis() > dis)  
            return -1;  
        else  
            return 1;  
    }  
}</distence></pre> <p>核心實現類<br />

</p>

public class KMeansCluster {
private int k;//簇的個數
private int num = 100000;//迭代次數
private List datas;//原始樣本集
private String address;//樣本集路徑
private List data = new ArrayList();
private AbstractDistance distance = new AbstractDistance() {
@Override
public double getDis(Point p1, Point p2) {
//歐幾里德距離
return Math.sqrt(Math.pow(p1.getX() - p2.getX(), 2) + Math.pow(p1.getY() - p2.getY(), 2));
}
};

    public KMeansCluster(int k, int num, String address) {  
        this.k = k;  
        this.num = num;  
        this.address = address;  
    }  

    public KMeansCluster(int k, String address) {  
        this.k = k;  
        this.address = address;  
    }  

    public KMeansCluster(int k, List<double> datas) {  
        this.k = k;  
        this.datas = datas;  
    }  

    public KMeansCluster(int k, int num, List<double> datas) {  
        this.k = k;  
        this.num = num;  
        this.datas = datas;  
    }  

    private void check() {  
        if (k == 0)  
            throw new IllegalArgumentException("k must be the number > 0");  

        if (address == null && datas == null)  
            throw new IllegalArgumentException("program can't get real data");  
    }  

    /** 
     * 初始化數據 
     * 
     * @throws java.io.FileNotFoundException 
     */  
    public void init() throws FileNotFoundException {  
        check();  
        //讀取文件,init data  
        //處理原始數據  
        for (int i = 0, j = datas.size(); i < j; i++)  
            data.add(new Point(i, datas.get(i), 0));  
    }  

    /** 
     * 第一次隨機選取中心點 
     * 
     * @return 
     */  
    public Set<point> chooseCenter() {  
        Set<point> center = new HashSet<point>();  
        Random ran = new Random();  
        int roll = 0;  
        while (center.size() < k) {  
            roll = ran.nextInt(data.size());  
            center.add(data.get(roll));  
        }  
        return center;  
    }  

    /** 
     * @param center 
     * @return 
     */  
    public List<cluster> prepare(Set<point> center) {  
        List<cluster> cluster = new ArrayList<cluster>();  
        Iterator<point> it = center.iterator();  
        int id = 0;  
        while (it.hasNext()) {  
            Point p = it.next();  
            if (p.isBeyond()) {  
                Cluster c = new Cluster(id++, p);  
                c.addPoint(p);  
                cluster.add(c);  
            } else  
                cluster.add(new Cluster(id++, p));  
        }  
        return cluster;  
    }  

    /** 
     * 第一次運算,中心點為樣本值 
     * 
     * @param center 
     * @param cluster 
     * @return 
     */  
    public List<cluster> clustering(Set<point> center, List<cluster> cluster) {  
        Point[] p = center.toArray(new Point[0]);  
        TreeSet<distence> distence = new TreeSet<distence>();//存放距離信息  
        Point source;  
        Point dest;  
        boolean flag = false;  
        for (int i = 0, n = data.size(); i < n; i++) {  
            distence.clear();  
            for (int j = 0; j < center.size(); j++) {  
                if (center.contains(data.get(i)))  
                    break;  

                flag = true;  
                // 計算距離  
                source = data.get(i);  
                dest = p[j];  
                distence.add(new Distence(source, dest, distance));  
            }  
            if (flag == true) {  
                Distence min = distence.first();  
                for (int m = 0, k = cluster.size(); m < k; m++) {  
                    if (cluster.get(m).getCenter().equals(min.getDest()))  
                        cluster.get(m).addPoint(min.getSource());  

                }  
            }  
            flag = false;  
        }  

        return cluster;  
    }  

    /** 
     * 迭代運算,中心點為簇內樣本均值 
     * 
     * @param cluster 
     * @return 
     */  
    public List<cluster> cluster(List<cluster> cluster) {  
//        double error;  
        Set<point> lastCenter = new HashSet<point>();  
        for (int m = 0; m < num; m++) {  
//            error = 0;  
            Set<point> center = new HashSet<point>();  
            // 重新計算聚類中心  
            for (int j = 0; j < k; j++) {  
                List<point> ps = cluster.get(j).getMembers();  
                int size = ps.size();  
                if (size < 3) {  
                    center.add(cluster.get(j).getCenter());  
                    continue;  
                }  
                // 計算距離  
                double x = 0.0, y = 0.0;  
                for (int k1 = 0; k1 < size; k1++) {  
                    x += ps.get(k1).getX();  
                    y += ps.get(k1).getY();  
                }  
                //得到新的中心點  
                Point nc = new Point(-1, x / size, y / size, false);  
                center.add(nc);  
            }  
            if (lastCenter.containsAll(center))//中心點不在變化,退出迭代  
                break;  
            lastCenter = center;  
            // 迭代運算  
            cluster = clustering(center, prepare(center));  
//            for (int nz = 0; nz < k; nz++) {  
//                error += cluster.get(nz).getError();//計算誤差  
//            }  
        }  
        return cluster;  
    }  

    /** 
     * 輸出聚類信息到控制臺 
     * 
     * @param cs 
     */  
    public void out2console(List<cluster> cs) {  
        for (int i = 0; i < cs.size(); i++) {  
            System.out.println("No." + (i + 1) + " cluster:");  
            Cluster c = cs.get(i);  
            List<point> p = c.getMembers();  
            for (int j = 0; j < p.size(); j++) {  
                System.out.println("\t" + p.get(j).getX() + " ");  
            }  
            System.out.println();  
        }  
    }  
}</point></cluster></point></point></point></point></point></cluster></cluster></distence></distence></cluster></point></cluster></point></cluster></cluster></point></cluster></point></point></point></double></double></point></point></double></pre>代碼還沒有仔細優化,執行的效率可能還存在一定的問題
 本文由用戶 jopen 自行上傳分享,僅供網友學習交流。所有權歸原作者,若您的權利被侵害,請聯系管理員。
 轉載本站原創文章,請注明出處,并保留原始鏈接、圖片水印。
 本站是一個以用戶分享為主的開源技術平臺,歡迎各類分享!