K最近鄰(KNN)算法的java實現
1.急切學習與懶惰學習
急切學習:在給定訓練元組之后、接收到測試元組之前就構造好泛化(即分類)模型。
屬于急切學習的算法有:決策樹、貝葉斯、基于規則的分類、后向傳播分類、SVM和基于關聯規則挖掘的分類等等。
懶惰學習:直至給定一個測試元組才開始構造泛化模型,也稱為基于實例的學習法。
屬于急切學習的算法有:KNN分類、基于案例的推理分類。
2.KNN的優缺點
優點:原理簡單,實現起來比較方便。支持增量學習。能對超多邊形的復雜決策空間建模。
缺點:計算開銷大,需要有效的存儲技術和并行硬件的支撐。
3.KNN算法原理
基于類比學習,通過比較訓練元組和測試元組的相似度來學習。
將訓練元組和測試元組看作是n維(若元組有n的屬性)空間內的點,給定一條測試元組,搜索n維空間,找出與測試
元組最相近的k個點(即訓練元組),最后取這k個點中的多數類作為測試元組的類別。
相近的度量方法:用空間內兩個點的距離來度量。距離越大,表示兩個點越不相似。
距離的選擇:可采用歐幾里得距離、曼哈頓距離或其它距離度量。多采用歐幾里得距離,簡單!
4.KNN算法中的細節處理
- 數值屬性規范化:將數值屬性規范到0-1區間以便于計算,也可防止大數值型屬性對分類的主導作用。
可選的方法有:v' = (v - vmin)/ (vmax - vmin),當然也可以采用其它的規范化方法
- 比較的屬性是分類類型而不是數值類型的:同則差為0,異則差為1.
有時候可以作更為精確的處理,比如黑色與白色的差肯定要大于灰色與白色的差。
- 缺失值的處理:取最大的可能差,對于分類屬性,如果屬性A的一個或兩個對應值丟失,則取差值為1;
如果A是數值屬性,若兩個比較的元組A屬性值均缺失,則取差值為1,若只有一個缺失,另一個值為v,
則取差值為|1-v|和|0-v|中的最大值
- 確定K的值:通過實驗確定。進行若干次實驗,取分類誤差率最小的k值。
- 對噪聲數據或不相關屬性的處理:對屬性賦予相關性權重w,w越大說明屬性對分類的影響越相關。對噪聲數據可以將所在
的元組直接cut掉。
5.KNN算法流程
- 準備數據,對數據進行預處理
- 選用合適的數據結構存儲訓練數據和測試元組
- 設定參數,如k
- 維護一個大小為k的的按距離由大到小的優先級隊列,用于存儲最近鄰訓練元組
- 隨機從訓練元組中選取k個元組作為初始的最近鄰元組,分別計算測試元組到這k個元組的距離,將訓練元組標號和距離存入優先級隊列
- 遍歷訓練元組集,計算當前訓練元組與測試元組的距離,將所得距離L與優先級隊列中的最大距離Lmax進行比較。若L>=Lmax,則舍棄該元組,遍歷下一個元組。若L < Lmax,刪除優先級隊列中最大距離的元組,將當前訓練元組存入優先級隊列。
- 遍歷完畢,計算優先級隊列中k個元組的多數類,并將其作為測試元組的類別。
- 測試元組集測試完畢后計算誤差率,繼續設定不同的k值重新進行訓練,最后取誤差率最小的k值。
6.KNN算法的改進策略
- 將存儲的訓練元組預先排序并安排在搜索樹中(如何排序有待研究)
- 并行實現
- 部分距離計算,取n個屬性的“子集”計算出部分距離,若超過設定的閾值則停止對當前元組作進一步計算。轉向下一個元組。
- 剪枝或精簡:刪除證明是“無用的”元組。
7.KNN算法java實現
本算法只適合學習使用,可以大致了解一下KNN算法的原理。
算法作了如下的假定與簡化處理:
1.小規模數據集
2.假設所有數據及類別都是數值類型的
3.直接根據數據規模設定了k值
4.對原訓練集進行測試
KNN實現代碼如下:
package KNN; /** * KNN結點類,用來存儲最近鄰的k個元組相關的信息 * @author Rowen * @qq 443773264 * @mail luowen3405@163.com * @blog blog.csdn.net/luowen3405 * @data 2011.03.25 */ public class KNNNode { private int index; // 元組標號 private double distance; // 與測試元組的距離 private String c; // 所屬類別 public KNNNode(int index, double distance, String c) { super(); this.index = index; this.distance = distance; this.c = c; } public int getIndex() { return index; } public void setIndex(int index) { this.index = index; } public double getDistance() { return distance; } public void setDistance(double distance) { this.distance = distance; } public String getC() { return c; } public void setC(String c) { this.c = c; } }
package KNN; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.PriorityQueue; /** * KNN算法主體類 * @author Rowen * @qq 443773264 * @mail luowen3405@163.com * @blog blog.csdn.net/luowen3405 * @data 2011.03.25 */ public class KNN { /** * 設置優先級隊列的比較函數,距離越大,優先級越高 */ private Comparator<KNNNode> comparator = new Comparator<KNNNode>() { public int compare(KNNNode o1, KNNNode o2) { if (o1.getDistance() >= o2.getDistance()) { return 1; } else { return 0; } } }; /** * 獲取K個不同的隨機數 * @param k 隨機數的個數 * @param max 隨機數最大的范圍 * @return 生成的隨機數數組 */ public List<Integer> getRandKNum(int k, int max) { List<Integer> rand = new ArrayList<Integer>(k); for (int i = 0; i < k; i++) { int temp = (int) (Math.random() * max); if (!rand.contains(temp)) { rand.add(temp); } else { i--; } } return rand; } /** * 計算測試元組與訓練元組之前的距離 * @param d1 測試元組 * @param d2 訓練元組 * @return 距離值 */ public double calDistance(List<Double> d1, List<Double> d2) { double distance = 0.00; for (int i = 0; i < d1.size(); i++) { distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i)); } return distance; } /** * 執行KNN算法,獲取測試元組的類別 * @param datas 訓練數據集 * @param testData 測試元組 * @param k 設定的K值 * @return 測試元組的類別 */ public String knn(List<List<Double>> datas, List<Double> testData, int k) { PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator); List<Integer> randNum = getRandKNum(k, datas.size()); for (int i = 0; i < k; i++) { int index = randNum.get(i); List<Double> currData = datas.get(index); String c = currData.get(currData.size() - 1).toString(); KNNNode node = new KNNNode(index, calDistance(testData, currData), c); pq.add(node); } for (int i = 0; i < datas.size(); i++) { List<Double> t = datas.get(i); double distance = calDistance(testData, t); KNNNode top = pq.peek(); if (top.getDistance() > distance) { pq.remove(); pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString())); } } return getMostClass(pq); } /** * 獲取所得到的k個最近鄰元組的多數類 * @param pq 存儲k個最近近鄰元組的優先級隊列 * @return 多數類的名稱 */ private String getMostClass(PriorityQueue<KNNNode> pq) { Map<String, Integer> classCount = new HashMap<String, Integer>(); for (int i = 0; i < pq.size(); i++) { KNNNode node = pq.remove(); String c = node.getC(); if (classCount.containsKey(c)) { classCount.put(c, classCount.get(c) + 1); } else { classCount.put(c, 1); } } int maxIndex = -1; int maxCount = 0; Object[] classes = classCount.keySet().toArray(); for (int i = 0; i < classes.length; i++) { if (classCount.get(classes[i]) > maxCount) { maxIndex = i; maxCount = classCount.get(classes[i]); } } return classes[maxIndex].toString(); } }
package KNN; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.util.ArrayList; import java.util.List; /** * KNN算法測試類 * @author Rowen * @qq 443773264 * @mail luowen3405@163.com * @blog blog.csdn.net/luowen3405 * @data 2011.03.25 */ public class TestKNN { /** * 從數據文件中讀取數據 * @param datas 存儲數據的集合對象 * @param path 數據文件的路徑 */ public void read(List<List<Double>> datas, String path){ try { BufferedReader br = new BufferedReader(new FileReader(new File(path))); String data = br.readLine(); List<Double> l = null; while (data != null) { String t[] = data.split(" "); l = new ArrayList<Double>(); for (int i = 0; i < t.length; i++) { l.add(Double.parseDouble(t[i])); } datas.add(l); data = br.readLine(); } } catch (Exception e) { e.printStackTrace(); } } /** * 程序執行入口 * @param args */ public static void main(String[] args) { TestKNN t = new TestKNN(); String datafile = new File("").getAbsolutePath() + File.separator + "datafile"; String testfile = new File("").getAbsolutePath() + File.separator + "testfile"; try { List<List<Double>> datas = new ArrayList<List<Double>>(); List<List<Double>> testDatas = new ArrayList<List<Double>>(); t.read(datas, datafile); t.read(testDatas, testfile); KNN knn = new KNN(); for (int i = 0; i < testDatas.size(); i++) { List<Double> test = testDatas.get(i); System.out.print("測試元組: "); for (int j = 0; j < test.size(); j++) { System.out.print(test.get(j) + " "); } System.out.print("類別為: "); System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 3))))); } } catch (Exception e) { e.printStackTrace(); } } }
訓練數據文件:
1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0
1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5
程序運行結果:
測試元組: 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 類別為: 1 測試元組: 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 類別為: 1 測試元組: 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 類別為: 1 測試元組: 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 類別為: 0 測試元組: 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 類別為: 1 測試元組: 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 類別為: 0
由結果可以看出,分類的測試結果是比較準確的!
轉自:http://blog.csdn.net/luowen3405/article/details/6278764
參考:http://blog.csdn.net/xlm289348/article/details/8876353
http://coolshell.cn/articles/8052.html#more-8052