聚類算法Kmeans/K-均值算法

jopen 9年前發布 | 54K 次閱讀 算法

Kmeans是最簡單的聚類算法之一,但是運用十分廣泛,最近看到別人找實習筆試時有考到Kmeans,故復習一下順手整理成一篇筆記。Kmeans的目的是:把n個樣本點劃分到k個類簇中,使得每個點都屬于離它最近的質心對應的類簇,以之作為聚類的標準。

前記

        Kmeans是最簡單的聚類算法之一,但是運用十分廣泛,最近看到別人找實習筆試時有考到Kmeans,故復習一下順手整理成一篇筆記。Kmeans的目標是:把n 個樣本點劃分到k 個類簇中,使得每個點都屬于離它最近的質心對應的類簇,以之作為聚類的標準。質心,是指一個類簇內部所有樣本點的均值

算法描述

Step 1. 從數據集中隨機選取K個點作為初始質心
        將每個點指派到最近的質心,形成k個類簇
Step 2. repeat
            重新計算各個類簇的質心(即類內部點的均值)
            重新將每個點指派到最近的質心,形成k個類簇
        until    質心不再波動

        例如下圖的樣本集,我們目標是分成3個類簇,初始隨機選擇的3個質心比較集中,但是迭代4次之后,質心趨于穩定,并將樣本集分為3部分。

聚類算法Kmeans/K-均值算法

        Kmeans算法,對于距離度量可以使用余弦相似度,也可以使用歐式距離或其它標準;質心,是指一個類簇內部所有樣本點的均值;隨機初始化的質心,當隨機效果不理想時,Kmeans算法的迭代次數變多。Kmeans算法思想比較簡單,但實用。

代碼實現

package kmeans;

public class Point {
    public double[] x;    // 特征維度
    public int len_arr;    // 特征維數
    public boolean isSample = false;    // True判斷是數據集的點,False是第二次kmenas所計算得來的質心
    public int id;    // 質心分配的id=0
    public String text;    // 用于描述鳶尾花種類

    public Point(double[] x, int len_arr, boolean isSample, int id) {
        this.x = x;
        this.len_arr = len_arr;
        this.isSample = isSample;
        this.id = id;
    }

    // 計算歐氏距離
    public double Distance(Point other) {
        double sum = 0;

        for (int i = 0; i < len_arr; i++) {
            sum += Math.pow(x[i] - other.x[i], 2);
        }
        sum = Math.sqrt(sum);

        return sum;
    }

    // 以下兩個方法用于數據結構Set, 第一次kmeans生成k個隨機點時用到
    @Override
    public boolean equals(Object other) {
        if (other.getClass() != Point.class) {
            return false;
        }
        return id == ((Point) other).id;
    }

    @Override
    public int hashCode() {
        return id;
    }
}
package kmeans;

import java.util.*;

public class Cluster {
    public int id;    // 簇id
    public Point center;    // 簇質心
    public List<Point> members = new ArrayList<>();    // 簇中成員(數據集點)

    public Cluster(int id, Point center) {
        this.id = id;
        this.center = center;
    }

    @Override
    public boolean equals(Object o) {
        if (o.getClass() != Cluster.class) {
            return false;
        }
        return id == ((Cluster) o).id;
    }
}
package kmeans;

import java.util.*;

public class Kmeans {
    public List<Point> samples;    // 數據集點
    public List<Cluster> clusters = new ArrayList<>(); // 存放聚類類簇結果
    public int k;    // 聚類個數
    public int arr_len;    // 數據集點特征維數
    public int steps;    // 最大迭代次數

    public Kmeans(List<Point> samples, int k, int arr_len, int steps) {
        this.samples = samples;
        this.k = k;
        this.arr_len = arr_len;
        this.steps = steps;
    }

    public void run() {
        FirstStep();    // 算法Step 1
        double oldDist = Loss();    // 計算各個類簇內點到質心的距離和
        double newDist = 0;
        for (int i = 0; i < steps; i++) {
            SecondStep();    // 算法Step 2
            newDist = Loss();
            if (oldDist - newDist < 0.01) {    // 如果質心不再變化,則停止學習
                break;
            }
            System.out.println("Step " + i + ":" + (oldDist - newDist));
            oldDist = newDist;
        }
        
        // 打印結果
        for (int i = 0; i < clusters.size(); i++) {
            System.out.println("第" + i + "個簇:");
            for (Point p : clusters.get(i).members) {
                if (!p.isSample) {
                    continue;
                }
                System.out.print("(");
                for (int xi = 0; xi < p.x.length; xi++) {
                    if (xi != 0) {
                        System.out.print(",");
                    }
                    System.out.print(p.x[xi]);
                }
                System.out.print(")");
                System.out.println("\t" + p.text);
            }
        }
    }

    public void FirstStep() {    // 算法Step 1
        Set<Point> centers = new HashSet<>();    // 從樣本數據集中隨機選取k個不重復的質心
        int id = 0;    // 類簇id
        while (centers.size() < k) {
            Random r = new Random();    // 隨機選取樣本數據集的數據下標
            int ti = r.nextInt(samples.size()) % samples.size();
            if (centers.contains(samples.get(ti))) {
                continue;
            }
            centers.add(samples.get(ti));
            Cluster clu = new Cluster(id++, samples.get(ti));
            clusters.add(clu);
        }

        Classify();    // 開始根據k個質心進行聚類
    }

    public void SecondStep() {    // 算法Step 2
        List<Cluster> newClusters = new ArrayList<>();
        for (Cluster clu : clusters) {
            double[] tx = new double[arr_len];
            for (Point p : clu.members) {
                for (int i = 0; i < arr_len; i++) {
                    tx[i] += p.x[i];
                }
            }
            for (int i = 0; i < arr_len; i++) {
                tx[i] /= clu.members.size();
            }    // 重新在各個類簇內部計算新的質心
            Point newCenter = new Point(tx, arr_len, false, 0);
            Cluster newClu = new Cluster(clu.id, newCenter);
            newClusters.add(newClu);
        }
        clusters.clear();
        clusters = newClusters;

        Classify();    // 根據新的質心重新聚類
    }

    public void Classify() {    // 聚類步驟,將各個點分配到距離最近的質心所在的類簇
        for (int i = 0; i < samples.size(); i++) {
            double mindistance = Double.MAX_VALUE;
            int clu_Id = -1;
            for (Cluster clu : clusters) {
                if (samples.get(i).Distance(clu.center) < mindistance) {
                    mindistance = samples.get(i).Distance(clu.center);
                    clu_Id = clu.id;
                }
            }

            for (int j = 0; j < clusters.size(); j++) {
                if (clusters.get(j).id == clu_Id) {
                    clusters.get(j).members.add(samples.get(i));
                    break;
                }
            }
        }
    }

    public double Loss() {    // 計算類簇內部各個點到質心的距離
        double sum = 0;

        for (Cluster clu : clusters) {
            for (Point p : clu.members) {
                sum += p.Distance(clu.center);
            }
        }

        return sum;
    }
}
package kmeans;

import java.util.*;

public class Keyven {
    public static void main(String[] args) {
        Scanner input = new Scanner(System.in);

        int n = input.nextInt();
        int arr_len = input.nextInt();
        List<Point> samples = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            double[] x = new double[arr_len];
            for (int j = 0; j < arr_len; j++) {
                x[j] = input.nextDouble();
            }
            String text = input.nextLine();
            Point p = new Point(x, arr_len, true, i + 1);
            p.text = text;
            samples.add(p);
        }
        Kmeans km = new Kmeans(samples, 3, arr_len, 1000);
        km.run();

        input.close();
    }
}

實驗效果

聚類算法Kmeans/K-均值算法

        鳶尾花的數據集下載:http://archive.ics.uci.edu/ml/

算法分析

(1)離群點的處理:離群點一般稱為噪音,離群點有可能影響類簇的發現,導致實驗效果不合理,因此在進行Kmeans之前發現并提出離群點是有必要的。

(2)初始質心的選取:初始質心的隨機選取有可能出現過度集中的情況,導致迭代次數增多,這時可以使用Kmeans++來解決這個問題,Kmeans++算法步驟如下圖:

聚類算法Kmeans/K-均值算法

也可以使用另外一種方法:隨機地選擇第一個點,或取所有點的質心作為第一個點。然后,對于每個后繼初始質心,選擇離已經選取過的初始質心最遠的點。使用這種方法,確保了選擇的初始質心不僅是隨機的,而且是散開的。但是,這種方法可能選中離群點。此外,求離當前初始質心集最遠的點開銷也非常大。

(3)算法終止條件:一般是目標函數達到最優或者達到最大的迭代次數即可終止。對于不同的距離度量,目標函數往往不同。當采用歐式距離時,目標函數一般為最小化對象到其簇質心的距離的平方和,如下:

聚類算法Kmeans/K-均值算法

當采用余弦相似度時,目標函數一般為最大化對象到其簇質心的余弦相似度和,如下:

聚類算法Kmeans/K-均值算法

(4)K值得選取:Kmeans算法的聚類個數K 值是由用戶設定的,因為一開始我們并不知道數據集的分布,Kmeans又不像EM算法那樣自動學習聚類成K 個類簇。為解決這個問題,可以將Kmeans與層次聚類結合,首先采用層次聚類算法粗略決定聚類個數,并找到初始聚類,然后用Kmeans來優化聚類結果。

擴展

        其它聚類算法:譜聚類、層次聚類,等。這里僅簡單地介紹層次聚類

        層次聚類,是一種很直觀的算法。顧名思義就是要一層一層地進行聚類,可以從下而上地把小的cluster合并聚集,也可以從上而下地將大的cluster進行分割,一般采用從下而上地聚類。
        從下而上地合并cluster,就是每次找到距離最短的兩個cluster,然后進行合并成一個大的cluster,直到全部合并為一個cluster。整個過程就是建立一個樹結構,類似于下圖。

聚類算法Kmeans/K-均值算法

        那么,如何判斷兩個cluster之間的距離呢?一開始每個數據點獨自作為一個類,它們的距離就是這兩個點之間的距離。而對于包含不止一個數據點的cluster,就可以選擇多種方法了,最常用的就是average-linkage ,這種方法就是把兩個集合中的點兩兩的距離全部放在一起求一個平均值。

        只要得到了上面那樣的聚類樹,想要分多少個cluster都可以直接根據樹結構來得到結果。

后記

        注意,K-means算法與KNN算法沒有關系,K-means算法是一種聚類算法,而KNN(K近鄰算法)是一種分類算法,下面舉一個例子來說明KNN算法。假如手頭有一堆已經標記好分類的數據點集,新進來一個點,需要我們預測其類別,我們可以取該點的k 個鄰居(距離該點最近的k 個點),如果這k 個鄰居點大多數屬于某一個類別C,則我們預測該點很大可能也屬于類別C。例如下圖中的黑點為預測點,取其7個鄰居點,黃色居多,利用極大似然估計,我們可以認為黑色點屬于黃色。

        KNN算法可以使用Kd樹來實現,具體請參考《統計機器學習 · 李航 著》,這里有一篇Kd-Tree的博文:Kd Tree算法原理和開源實現代碼

聚類算法Kmeans/K-均值算法

Reference

數據挖掘-聚類-K-means算法Java實現

簡單之美Kmeans

基本Kmeans算法介紹及其實現

聚類算法實踐(一)——層次聚類、K-means聚類

聚類(2)——層次聚類 Hierarchical Clustering

最流行的4個機器學習數據集

kmeans python版


 本文由用戶 jopen 自行上傳分享,僅供網友學習交流。所有權歸原作者,若您的權利被侵害,請聯系管理員。
 轉載本站原創文章,請注明出處,并保留原始鏈接、圖片水印。
 本站是一個以用戶分享為主的開源技術平臺,歡迎各類分享!