Data Mining - Clustering - A Java Implementation of the K-means Algorithm
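K-Means is one of the oldest and most widely used clustering algorithms. It represents each cluster by a prototype, its centroid, which is the mean of a group of points. The algorithm is typically applied to objects in a continuous n-dimensional space.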
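K-Means algorithm steps
step1: Select K points as the initial centroids
step2: repeat
Assign each point to its nearest centroid, forming K clusters
Recompute the centroid of each cluster
until the centroids no longer change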
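The loop above alternates two steps until the centroids stop moving. The following is a minimal sketch of that loop on 2-D points stored as `double[]` pairs. The class name `BasicKMeans`, the `maxIter` cap, and the choice to keep an empty cluster's old centroid are assumptions for illustration and do not come from the original article. The stopping test follows the "until the centroids no longer change" rule.

```java
import java.util.Arrays;

/** Minimal sketch of the basic K-means loop on 2-D points (hypothetical class name). */
public class BasicKMeans {

    /** Index of the centroid closest to p, using squared Euclidean distance. */
    static int nearest(double[] p, double[][] centroids) {
        int best = 0;
        double bestD = Double.MAX_VALUE;
        for (int c = 0; c < centroids.length; c++) {
            double dx = p[0] - centroids[c][0], dy = p[1] - centroids[c][1];
            double d = dx * dx + dy * dy;
            if (d < bestD) { bestD = d; best = c; }
        }
        return best;
    }

    /** Runs K-means from the given initial centroids and returns each point's cluster label. */
    public static int[] run(double[][] points, double[][] initial, int maxIter) {
        int k = initial.length;
        double[][] centroids = new double[k][];
        for (int c = 0; c < k; c++) centroids[c] = initial[c].clone();
        int[] labels = new int[points.length];
        for (int iter = 0; iter < maxIter; iter++) {
            // Assignment step: attach each point to its nearest centroid.
            for (int i = 0; i < points.length; i++) labels[i] = nearest(points[i], centroids);
            // Update step: each centroid becomes the mean of its members.
            double[][] next = new double[k][2];
            int[] counts = new int[k];
            for (int i = 0; i < points.length; i++) {
                next[labels[i]][0] += points[i][0];
                next[labels[i]][1] += points[i][1];
                counts[labels[i]]++;
            }
            for (int c = 0; c < k; c++) {
                if (counts[c] == 0) next[c] = centroids[c].clone(); // empty cluster keeps its centroid
                else { next[c][0] /= counts[c]; next[c][1] /= counts[c]; }
            }
            // Stop once the centroids no longer change.
            if (Arrays.deepEquals(next, centroids)) break;
            centroids = next;
        }
        return labels;
    }

    public static void main(String[] args) {
        double[][] pts = {{1, 1}, {1.5, 2}, {1, 1.5}, {8, 8}, {9, 8.5}, {8.5, 9}};
        int[] labels = run(pts, new double[][]{{1, 1}, {8, 8}}, 100);
        System.out.println(Arrays.toString(labels)); // prints [0, 0, 0, 1, 1, 1]
    }
}
```

In this example the second pass reproduces the same means, so the loop exits after two iterations.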

Let's analyze each step in turn.
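step1: Select K points as the initial centroids
This step first requires knowing K. In other words, K is set manually. This differs from EM clustering, which in some implementations can determine the number of clusters n automatically.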
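Second, how should the initial centroids be chosen?
The simplest approach is to pick the centroids at random, run the algorithm several times, and keep the best result. This is simple but not necessarily effective, since each run may well end up in a local optimum.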
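A more elaborate approach starts by picking one centroid at random and then taking the sample point farthest from it. Each subsequent centroid is the point farthest from all the centroids chosen so far. The initial centroids therefore start from a random choice and are guaranteed to be spread out.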
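Both initialization strategies can be sketched as follows. This is a hypothetical helper: `InitCentroids`, `randomInit` and `farthestInit` are illustrative names. "Farthest from the already chosen centroids" is interpreted here as the sample whose distance to its nearest chosen centroid is largest, which is the usual reading. Squared Euclidean distance is used because it ranks points in the same order as Euclidean distance.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Sketch of two ways to pick K initial centroids from 2-D samples (hypothetical names). */
public class InitCentroids {

    static double dist2(double[] a, double[] b) {
        double dx = a[0] - b[0], dy = a[1] - b[1];
        return dx * dx + dy * dy;
    }

    /** Random initialization: K distinct samples chosen uniformly at random (requires k <= n). */
    public static List<double[]> randomInit(double[][] points, int k, Random rnd) {
        List<Integer> idx = new ArrayList<>();
        for (int i = 0; i < points.length; i++) idx.add(i);
        Collections.shuffle(idx, rnd);
        List<double[]> centers = new ArrayList<>();
        for (int i = 0; i < k; i++) centers.add(points[idx.get(i)]);
        return centers;
    }

    /** Farthest-point initialization: a random first centroid, then repeatedly the sample
     *  whose distance to its nearest already-chosen centroid is largest. */
    public static List<double[]> farthestInit(double[][] points, int k, Random rnd) {
        List<double[]> centers = new ArrayList<>();
        centers.add(points[rnd.nextInt(points.length)]);
        while (centers.size() < k) {
            double[] best = null;
            double bestD = -1;
            for (double[] p : points) {
                double nearest = Double.MAX_VALUE; // distance to the closest chosen centroid
                for (double[] c : centers) nearest = Math.min(nearest, dist2(p, c));
                if (nearest > bestD) { bestD = nearest; best = p; }
            }
            centers.add(best);
        }
        return centers;
    }

    public static void main(String[] args) {
        double[][] pts = {{0, 0}, {0.5, 0}, {10, 10}, {10, 10.5}, {20, 0}};
        for (double[] c : farthestInit(pts, 3, new Random(42))) {
            System.out.println(c[0] + "," + c[1]);
        }
    }
}
```

Farthest-point selection tends to pick outliers as centroids. This is one more reason to remove outliers before clustering.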
step2: repeat
Assign each point to its nearest centroid, forming K clusters
Recompute the centroid of each cluster
until the centroids no longer change
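How is "nearest" defined? For points in Euclidean space, Euclidean distance can be used. For documents, cosine similarity can be used, and so on. A given dataset may admit several suitable proximity measures.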
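The proximity measures mentioned above can be written as plain functions over coordinate arrays. Manhattan distance is included as well, since the implementation later names it as an alternative. This is a sketch with hypothetical names. Cosine is a similarity, not a distance: a nearest-centroid step would pick the largest value, or use `1 - cosine` as a distance. Cosine is also undefined (NaN) for an all-zero vector.

```java
/** Sketch of proximity measures for the nearest-centroid step (hypothetical class name). */
public class Proximity {

    /** Euclidean (L2) distance, suited to points in Euclidean space. */
    public static double euclidean(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) s += (a[i] - b[i]) * (a[i] - b[i]);
        return Math.sqrt(s);
    }

    /** Manhattan (L1) distance. */
    public static double manhattan(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) s += Math.abs(a[i] - b[i]);
        return s;
    }

    /** Cosine similarity, common for document term vectors; larger means more similar. */
    public static double cosine(double[] a, double[] b) {
        double dot = 0, na = 0, nb = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        return dot / (Math.sqrt(na) * Math.sqrt(nb));
    }

    public static void main(String[] args) {
        double[] p = {0, 0}, q = {3, 4};
        System.out.println(euclidean(p, q)); // prints 5.0
        System.out.println(manhattan(p, q)); // prints 7.0
        System.out.println(cosine(new double[]{1, 0}, new double[]{0, 1})); // prints 0.0
    }
}
```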
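Other issues
Handling outliers
Outliers can have an outsized influence on the clusters that are found, making the final cluster distribution differ considerably from what we expect. It is therefore worth detecting and removing them beforehand.
In my own work I removed outliers based on the variance, and the results were very good.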
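The article does not say exactly how the variance was used. One common rule, assumed here, is to drop values lying more than `t` standard deviations from the mean. The sketch works on 1-D values, matching the one-dimensional samples that `KMeansCluster` clusters. `VarianceFilter`, `removeOutliers` and the threshold `t` are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Sketch: drop 1-D samples more than t standard deviations from the mean.
 *  The t-sigma rule is an assumption; the article only says variance was used. */
public class VarianceFilter {

    public static List<Double> removeOutliers(List<Double> values, double t) {
        int n = values.size();
        if (n == 0) return new ArrayList<>();
        double mean = 0;
        for (double v : values) mean += v;
        mean /= n;
        double var = 0;
        for (double v : values) var += (v - mean) * (v - mean);
        double std = Math.sqrt(var / n); // population standard deviation
        List<Double> kept = new ArrayList<>();
        for (double v : values) {
            if (Math.abs(v - mean) <= t * std) kept.add(v); // keep values within t sigma
        }
        return kept;
    }

    public static void main(String[] args) {
        List<Double> data = Arrays.asList(1.0, 1.2, 0.9, 1.1, 1.0, 0.8, 1.3, 1.0, 0.9, 1.1, 50.0);
        System.out.println(removeOutliers(data, 2.0)); // 50.0 is dropped
    }
}
```

A single extreme value also inflates the standard deviation it is measured against. With few samples, a large `t` may therefore never flag anything, which is why the example uses `t = 2`.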
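Splitting and merging clusters
Using a larger K often makes the clustering look more reasonable, but in many cases we do not want to increase the number of clusters.
In that case, cluster splitting and cluster merging can be applied alternately. This can help escape local minima while still producing a result with the desired number of clusters.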
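The article does not say how to choose which clusters to split or merge. A common choice, assumed in this sketch, is a two-part step:

1. Split the cluster with the largest SSE (sum of squared distances to its centroid).
2. Merge the two clusters whose centroids are closest, so the count returns to K.

The two-seed split heuristic and all names (`SplitMerge`, `splitThenMerge`) are hypothetical. The sketch assumes non-empty 2-D clusters.

```java
import java.util.ArrayList;
import java.util.List;

/** Sketch of one split-then-merge step that keeps the cluster count fixed.
 *  Assumed criteria: split the largest-SSE cluster, merge the closest pair of centroids. */
public class SplitMerge {

    static double d2(double[] a, double[] b) {
        double dx = a[0] - b[0], dy = a[1] - b[1];
        return dx * dx + dy * dy;
    }

    static double[] centroid(List<double[]> c) {
        double x = 0, y = 0;
        for (double[] p : c) { x += p[0]; y += p[1]; }
        return new double[]{x / c.size(), y / c.size()};
    }

    /** Sum of squared distances from the members to their centroid. */
    static double sse(List<double[]> c) {
        double[] m = centroid(c);
        double s = 0;
        for (double[] p : c) s += d2(p, m);
        return s;
    }

    /** The member of c farthest from q. */
    static double[] farthestFrom(List<double[]> c, double[] q) {
        double[] best = c.get(0);
        for (double[] p : c) if (d2(p, q) > d2(best, q)) best = p;
        return best;
    }

    public static List<List<double[]>> splitThenMerge(List<List<double[]>> clusters) {
        List<List<double[]>> out = new ArrayList<>(clusters);
        // Split: choose the cluster with the largest SSE.
        int worst = 0;
        for (int i = 1; i < out.size(); i++) if (sse(out.get(i)) > sse(out.get(worst))) worst = i;
        List<double[]> big = out.get(worst);
        // Seed two halves with two far-apart members; each member joins the nearer seed.
        double[] a = farthestFrom(big, centroid(big));
        double[] b = farthestFrom(big, a);
        List<double[]> left = new ArrayList<>(), right = new ArrayList<>();
        for (double[] p : big) {
            if (d2(p, a) <= d2(p, b)) left.add(p); else right.add(p);
        }
        if (left.isEmpty() || right.isEmpty()) return out; // nothing to split
        out.remove(worst);
        out.add(left);
        out.add(right);
        // Merge: join the two clusters whose centroids are closest, restoring the count.
        int bi = 0, bj = 1;
        for (int i = 0; i < out.size(); i++)
            for (int j = i + 1; j < out.size(); j++)
                if (d2(centroid(out.get(i)), centroid(out.get(j)))
                        < d2(centroid(out.get(bi)), centroid(out.get(bj)))) { bi = i; bj = j; }
        List<double[]> merged = new ArrayList<>(out.get(bi));
        merged.addAll(out.get(bj));
        out.remove(bj); // remove the higher index first so bi stays valid
        out.remove(bi);
        out.add(merged);
        return out;
    }
}
```

After such a step, K-means would normally be rerun from the new centroids. If the two halves of the split turn out to be the closest pair, the step simply undoes itself. In that case a different split or merge candidate has to be tried.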
Abstractions for points, clusters, and distances
Point.class
public class Point {
    private final double x;
    private final double y;
    private final int id;
    private final boolean beyond; // true if the point is a real sample, false for a computed centroid

    public Point(int id, double x, double y) {
        this(id, x, y, true);
    }

    public Point(int id, double x, double y, boolean beyond) {
        this.id = id;
        this.x = x;
        this.y = y;
        this.beyond = beyond;
    }

    public double getX() { return x; }
    public double getY() { return y; }
    public int getId() { return id; }
    public boolean isBeyond() { return beyond; }

    @Override
    public String toString() {
        return "Point{" + "id=" + id + ", x=" + x + ", y=" + y + '}';
    }

    // Equality uses the coordinates only; id and beyond are ignored.
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        Point point = (Point) o;
        return Double.compare(point.x, x) == 0 && Double.compare(point.y, y) == 0;
    }

    @Override
    public int hashCode() {
        return 31 * Double.hashCode(x) + Double.hashCode(y);
    }
}
Cluster.class
import java.util.ArrayList;
import java.util.List;

public class Cluster {
    private final int id;                            // identifier
    private Point center;                            // center (centroid)
    private List<Point> members = new ArrayList<>(); // member points

    public Cluster(int id, Point center) {
        this.id = id;
        this.center = center;
    }

    public Cluster(int id, Point center, List<Point> members) {
        this.id = id;
        this.center = center;
        this.members = members;
    }

    // Point.equals compares coordinates, so two samples with identical coordinates also trigger this.
    public void addPoint(Point newPoint) {
        if (members.contains(newPoint))
            throw new IllegalStateException("Trying to process the same sample twice!");
        members.add(newPoint);
    }

    public int getId() { return id; }
    public Point getCenter() { return center; }
    public void setCenter(Point center) { this.center = center; }
    public List<Point> getMembers() { return members; }

    @Override
    public String toString() {
        return "Cluster{" + "id=" + id + ", center=" + center + ", members=" + members + "}";
    }
}
The abstract distance can be implemented as Euclidean, Manhattan, or any other distance formula.
public abstract class AbstractDistance {
    public abstract double getDis(Point p1, Point p2);
}
Point pair (a sample, a candidate center, and the distance between them)
public class Distence implements Comparable<Distence> {
    private final Point source;              // sample point
    private final Point dest;                // candidate center
    private final double dis;                // distance between them, computed once
    private final AbstractDistance distance;

    public Distence(Point source, Point dest, AbstractDistance distance) {
        this.source = source;
        this.dest = dest;
        this.distance = distance;
        this.dis = distance.getDis(source, dest);
    }

    public Point getSource() { return source; }
    public Point getDest() { return dest; }
    public double getDis() { return dis; }

    // Order by distance. Double.compare honours the Comparable contract (0 on ties);
    // on a tie a TreeSet keeps only the first-added entry, so the earlier center wins.
    @Override
    public int compareTo(Distence o) {
        return Double.compare(dis, o.dis);
    }
}
Core implementation class
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.TreeSet;

public class KMeansCluster {
    private int k;                  // number of clusters
    private int num = 100000;       // maximum number of iterations
    private List<Double> datas;     // raw sample values
    private String address;         // path to the sample file
    private final List<Point> data = new ArrayList<>();
    private final AbstractDistance distance = new AbstractDistance() {
        @Override
        public double getDis(Point p1, Point p2) {
            // Euclidean distance
            return Math.sqrt(Math.pow(p1.getX() - p2.getX(), 2) + Math.pow(p1.getY() - p2.getY(), 2));
        }
    };

    public KMeansCluster(int k, int num, String address) {
        this.k = k;
        this.num = num;
        this.address = address;
    }

    public KMeansCluster(int k, String address) {
        this.k = k;
        this.address = address;
    }

    public KMeansCluster(int k, List<Double> datas) {
        this.k = k;
        this.datas = datas;
    }

    public KMeansCluster(int k, int num, List<Double> datas) {
        this.k = k;
        this.num = num;
        this.datas = datas;
    }

    private void check() {
        if (k <= 0)
            throw new IllegalArgumentException("k must be a number > 0");
        if (address == null && datas == null)
            throw new IllegalArgumentException("program can't get real data");
    }

    /**
     * Initializes the data.
     *
     * @throws java.io.FileNotFoundException
     */
    public void init() throws FileNotFoundException {
        check();
        // Reading the sample file at `address` is not implemented; only `datas` is used.
        if (datas == null)
            throw new UnsupportedOperationException("loading data from address is not implemented");
        // Wrap each raw value as a point on the x-axis (y = 0).
        for (int i = 0, j = datas.size(); i < j; i++)
            data.add(new Point(i, datas.get(i), 0));
    }

    /**
     * First round: randomly picks k distinct samples as the centers.
     */
    public Set<Point> chooseCenter() {
        if (new HashSet<>(data).size() < k)
            throw new IllegalStateException("fewer than k distinct samples");
        Set<Point> center = new HashSet<>();
        Random ran = new Random();
        while (center.size() < k)
            center.add(data.get(ran.nextInt(data.size())));
        return center;
    }

    /**
     * Creates one cluster per center; a center that is a real sample becomes its first member.
     */
    public List<Cluster> prepare(Set<Point> center) {
        List<Cluster> cluster = new ArrayList<>();
        int id = 0;
        for (Point p : center) {
            Cluster c = new Cluster(id++, p);
            if (p.isBeyond())
                c.addPoint(p);
            cluster.add(c);
        }
        return cluster;
    }

    /**
     * Assigns every sample to the cluster whose center is nearest.
     */
    public List<Cluster> clustering(Set<Point> center, List<Cluster> cluster) {
        Point[] p = center.toArray(new Point[0]);
        TreeSet<Distence> distence = new TreeSet<>(); // distances to the centers, smallest first
        for (Point source : data) {
            // Skip samples that prepare() already placed in a cluster as its center.
            boolean assigned = false;
            for (Cluster c : cluster)
                if (c.getMembers().contains(source))
                    assigned = true;
            if (assigned)
                continue;
            distence.clear();
            for (Point dest : p)
                distence.add(new Distence(source, dest, distance));
            Distence min = distence.first();
            for (Cluster c : cluster)
                if (c.getCenter().equals(min.getDest()))
                    c.addPoint(min.getSource());
        }
        return cluster;
    }

    /**
     * Iterates: the new center of each cluster is the mean of its members.
     */
    public List<Cluster> cluster(List<Cluster> cluster) {
        // double error;
        Set<Point> lastCenter = new HashSet<>();
        for (int m = 0; m < num; m++) {
            // error = 0;
            Set<Point> center = new HashSet<>();
            // Recompute the cluster centers.
            for (Cluster c : cluster) {
                List<Point> ps = c.getMembers();
                int size = ps.size();
                // Clusters with fewer than 3 members keep their current center.
                if (size < 3) {
                    center.add(c.getCenter());
                    continue;
                }
                double x = 0.0, y = 0.0;
                for (Point pt : ps) {
                    x += pt.getX();
                    y += pt.getY();
                }
                // The mean of the members is the new center.
                center.add(new Point(-1, x / size, y / size, false));
            }
            if (lastCenter.containsAll(center)) // centers no longer change: stop iterating
                break;
            lastCenter = center;
            // Reassign all samples to the new centers.
            cluster = clustering(center, prepare(center));
            // for (int nz = 0; nz < k; nz++) {
            //     error += cluster.get(nz).getError(); // compute the error
            // }
        }
        return cluster;
    }

    /**
     * Prints the clustering result to the console.
     */
    public void out2console(List<Cluster> cs) {
        for (int i = 0; i < cs.size(); i++) {
            System.out.println("No." + (i + 1) + " cluster:");
            for (Point p : cs.get(i).getMembers())
                System.out.println("\t" + p.getX() + " ");
            System.out.println();
        }
    }
}
The code has not been carefully optimized yet, so its execution efficiency may still leave something to be desired.
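In `cluster()`, the error accumulation is commented out. It calls a `getError()` method that `Cluster` does not define. Below is a sketch of what such an error could be, assuming the usual sum of squared errors (SSE) between the members and their center, with its own minimal types. The names `ClusteringError`, `sse` and `totalSse` are hypothetical. The same total is also a natural score for the "run several times and keep the best result" approach to initialization: keep the run with the lowest SSE.

```java
import java.util.Arrays;
import java.util.List;

/** Sketch of the clustering error hinted at by the commented-out getError() call.
 *  The SSE definition and all names here are assumptions, not the original code. */
public class ClusteringError {

    /** SSE of one cluster: sum of squared Euclidean distances from members to the center. */
    public static double sse(List<double[]> members, double[] center) {
        double s = 0;
        for (double[] p : members) {
            double dx = p[0] - center[0], dy = p[1] - center[1];
            s += dx * dx + dy * dy;
        }
        return s;
    }

    /** Total SSE over all clusters; for the same K, lower means tighter clusters. */
    public static double totalSse(List<List<double[]>> clusters, List<double[]> centers) {
        double total = 0;
        for (int i = 0; i < clusters.size(); i++) total += sse(clusters.get(i), centers.get(i));
        return total;
    }

    public static void main(String[] args) {
        List<double[]> c = Arrays.asList(new double[]{1, 0}, new double[]{3, 0});
        System.out.println(sse(c, new double[]{2, 0})); // prints 2.0
    }
}
```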
This article was uploaded and shared by user jopen for learning and exchange only; copyright belongs to the original author.