K最近鄰(KNN)算法的java實現

jopen 10年前發布 | 127K 次閱讀 算法

1.急切學習與懶惰學習

急切學習:在給定訓練元組之后、接收到測試元組之前就構造好泛化(即分類)模型。

屬于急切學習的算法有:決策樹、貝葉斯、基于規則的分類、后向傳播分類、SVM和基于關聯規則挖掘的分類等等。

 

懶惰學習:直至給定一個測試元組才開始構造泛化模型,也稱為基于實例的學習法。

屬于急切學習的算法有:KNN分類、基于案例的推理分類。

 

2.KNN的優缺點

優點:原理簡單,實現起來比較方便。支持增量學習。能對超多邊形的復雜決策空間建模。

缺點:計算開銷大,需要有效的存儲技術和并行硬件的支撐。

 

3.KNN算法原理

基于類比學習,通過比較訓練元組和測試元組的相似度來學習。

將訓練元組和測試元組看作是n維(若元組有n的屬性)空間內的點,給定一條測試元組,搜索n維空間,找出與測試

元組最相近的k個點(即訓練元組),最后取這k個點中的多數類作為測試元組的類別。

 

相近的度量方法:用空間內兩個點的距離來度量。距離越大,表示兩個點越不相似。

 

距離的選擇:可采用歐幾里得距離、曼哈頓距離或其它距離度量。多采用歐幾里得距離,簡單!

 

4.KNN算法中的細節處理

  • 數值屬性規范化:將數值屬性規范到0-1區間以便于計算,也可防止大數值型屬性對分類的主導作用。

可選的方法有:v' = (v - vmin)/ (vmax - vmin),當然也可以采用其它的規范化方法

  • 比較的屬性是分類類型而不是數值類型的:同則差為0,異則差為1.

有時候可以作更為精確的處理,比如黑色與白色的差肯定要大于灰色與白色的差。

  • 缺失值的處理:取最大的可能差,對于分類屬性,如果屬性A的一個或兩個對應值丟失,則取差值為1;

如果A是數值屬性,若兩個比較的元組A屬性值均缺失,則取差值為1,若只有一個缺失,另一個值為v,

則取差值為|1-v|和|0-v|中的最大值

  • 確定K的值:通過實驗確定。進行若干次實驗,取分類誤差率最小的k值。
  • 對噪聲數據或不相關屬性的處理:對屬性賦予相關性權重w,w越大說明屬性對分類的影響越相關。對噪聲數據可以將所在

的元組直接cut掉。

 

5.KNN算法流程

  • 準備數據,對數據進行預處理
  • 選用合適的數據結構存儲訓練數據和測試元組
  • 設定參數,如k
  • 維護一個大小為k的的按距離由大到小的優先級隊列,用于存儲最近鄰訓練元組
  • 隨機從訓練元組中選取k個元組作為初始的最近鄰元組,分別計算測試元組到這k個元組的距離,將訓練元組標號和距離存入優先級隊列
  • 遍歷訓練元組集,計算當前訓練元組與測試元組的距離,將所得距離L與優先級隊列中的最大距離Lmax進行比較。若L>=Lmax,則舍棄該元組,遍歷下一個元組。若L < Lmax,刪除優先級隊列中最大距離的元組,將當前訓練元組存入優先級隊列。
  • 遍歷完畢,計算優先級隊列中k個元組的多數類,并將其作為測試元組的類別。
  • 測試元組集測試完畢后計算誤差率,繼續設定不同的k值重新進行訓練,最后取誤差率最小的k值。

6.KNN算法的改進策略

  • 將存儲的訓練元組預先排序并安排在搜索樹中(如何排序有待研究)
  • 并行實現
  • 部分距離計算,取n個屬性的“子集”計算出部分距離,若超過設定的閾值則停止對當前元組作進一步計算。轉向下一個元組。
  • 剪枝或精簡:刪除證明是“無用的”元組。

7.KNN算法java實現

 

本算法只適合學習使用,可以大致了解一下KNN算法的原理。

 

算法作了如下的假定與簡化處理:

1.小規模數據集

2.假設所有數據及類別都是數值類型的

3.直接根據數據規模設定了k值

4.對原訓練集進行測試

 

KNN實現代碼如下:

    package KNN;  
    /** 
     * KNN結點類,用來存儲最近鄰的k個元組相關的信息 
     * @author Rowen 
     * @qq 443773264 
     * @mail luowen3405@163.com 
     * @blog blog.csdn.net/luowen3405 
     * @data 2011.03.25 
     */  
    public class KNNNode {  
        private int index; // 元組標號  
        private double distance; // 與測試元組的距離  
        private String c; // 所屬類別  
        public KNNNode(int index, double distance, String c) {  
            super();  
            this.index = index;  
            this.distance = distance;  
            this.c = c;  
        }  


        public int getIndex() {  
            return index;  
        }  
        public void setIndex(int index) {  
            this.index = index;  
        }  
        public double getDistance() {  
            return distance;  
        }  
        public void setDistance(double distance) {  
            this.distance = distance;  
        }  
        public String getC() {  
            return c;  
        }  
        public void setC(String c) {  
            this.c = c;  
        }  
    }  

 

    package KNN;  
    import java.util.ArrayList;  
    import java.util.Comparator;  
    import java.util.HashMap;  
    import java.util.List;  
    import java.util.Map;  
    import java.util.PriorityQueue;  

    /** 
     * KNN算法主體類 
     * @author Rowen 
     * @qq 443773264 
     * @mail luowen3405@163.com 
     * @blog blog.csdn.net/luowen3405 
     * @data 2011.03.25 
     */  
    public class KNN {  
        /** 
         * 設置優先級隊列的比較函數,距離越大,優先級越高 
         */  
        private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {  
            public int compare(KNNNode o1, KNNNode o2) {  
                if (o1.getDistance() >= o2.getDistance()) {  
                    return 1;  
                } else {  
                    return 0;  
                }  
            }  
        };  
        /** 
         * 獲取K個不同的隨機數 
         * @param k 隨機數的個數 
         * @param max 隨機數最大的范圍 
         * @return 生成的隨機數數組 
         */  
        public List<Integer> getRandKNum(int k, int max) {  
            List<Integer> rand = new ArrayList<Integer>(k);  
            for (int i = 0; i < k; i++) {  
                int temp = (int) (Math.random() * max);  
                if (!rand.contains(temp)) {  
                    rand.add(temp);  
                } else {  
                    i--;  
                }  
            }  
            return rand;  
        }  
        /** 
         * 計算測試元組與訓練元組之前的距離 
         * @param d1 測試元組 
         * @param d2 訓練元組 
         * @return 距離值 
         */  
        public double calDistance(List<Double> d1, List<Double> d2) {  
            double distance = 0.00;  
            for (int i = 0; i < d1.size(); i++) {  
                distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));  
            }  
            return distance;  
        }  
        /** 
         * 執行KNN算法,獲取測試元組的類別 
         * @param datas 訓練數據集 
         * @param testData 測試元組 
         * @param k 設定的K值 
         * @return 測試元組的類別 
         */  
        public String knn(List<List<Double>> datas, List<Double> testData, int k) {  
            PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator);  
            List<Integer> randNum = getRandKNum(k, datas.size());  
            for (int i = 0; i < k; i++) {  
                int index = randNum.get(i);  
                List<Double> currData = datas.get(index);  
                String c = currData.get(currData.size() - 1).toString();  
                KNNNode node = new KNNNode(index, calDistance(testData, currData), c);  
                pq.add(node);  
            }  
            for (int i = 0; i < datas.size(); i++) {  
                List<Double> t = datas.get(i);  
                double distance = calDistance(testData, t);  
                KNNNode top = pq.peek();  
                if (top.getDistance() > distance) {  
                    pq.remove();  
                    pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString()));  
                }  
            }  

            return getMostClass(pq);  
        }  
        /** 
         * 獲取所得到的k個最近鄰元組的多數類 
         * @param pq 存儲k個最近近鄰元組的優先級隊列 
         * @return 多數類的名稱 
         */  
        private String getMostClass(PriorityQueue<KNNNode> pq) {  
            Map<String, Integer> classCount = new HashMap<String, Integer>();  
            for (int i = 0; i < pq.size(); i++) {  
                KNNNode node = pq.remove();  
                String c = node.getC();  
                if (classCount.containsKey(c)) {  
                    classCount.put(c, classCount.get(c) + 1);  
                } else {  
                    classCount.put(c, 1);  
                }  
            }  
            int maxIndex = -1;  
            int maxCount = 0;  
            Object[] classes = classCount.keySet().toArray();  
            for (int i = 0; i < classes.length; i++) {  
                if (classCount.get(classes[i]) > maxCount) {  
                    maxIndex = i;  
                    maxCount = classCount.get(classes[i]);  
                }  
            }  
            return classes[maxIndex].toString();  
        }  
    }  

    package KNN;  
    import java.io.BufferedReader;  
    import java.io.File;  
    import java.io.FileReader;  
    import java.util.ArrayList;  
    import java.util.List;  
    /** 
     * KNN算法測試類 
     * @author Rowen 
     * @qq 443773264 
     * @mail luowen3405@163.com 
     * @blog blog.csdn.net/luowen3405 
     * @data 2011.03.25 
     */  
    public class TestKNN {  

        /** 
         * 從數據文件中讀取數據 
         * @param datas 存儲數據的集合對象 
         * @param path 數據文件的路徑 
         */  
        public void read(List<List<Double>> datas, String path){  
            try {  
                BufferedReader br = new BufferedReader(new FileReader(new File(path)));  
                String data = br.readLine();  
                List<Double> l = null;  
                while (data != null) {  
                    String t[] = data.split(" ");  
                    l = new ArrayList<Double>();  
                    for (int i = 0; i < t.length; i++) {  
                        l.add(Double.parseDouble(t[i]));  
                    }  
                    datas.add(l);  
                    data = br.readLine();  
                }  
            } catch (Exception e) {  
                e.printStackTrace();  
            }  
        }  

        /** 
         * 程序執行入口 
         * @param args 
         */  
        public static void main(String[] args) {  
            TestKNN t = new TestKNN();  
            String datafile = new File("").getAbsolutePath() + File.separator + "datafile";  
            String testfile = new File("").getAbsolutePath() + File.separator + "testfile";  
            try {  
                List<List<Double>> datas = new ArrayList<List<Double>>();  
                List<List<Double>> testDatas = new ArrayList<List<Double>>();  
                t.read(datas, datafile);  
                t.read(testDatas, testfile);  
                KNN knn = new KNN();  
                for (int i = 0; i < testDatas.size(); i++) {  
                    List<Double> test = testDatas.get(i);  
                    System.out.print("測試元組: ");  
                    for (int j = 0; j < test.size(); j++) {  
                        System.out.print(test.get(j) + " ");  
                    }  
                    System.out.print("類別為: ");  
                    System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 3)))));  
                }  
            } catch (Exception e) {  
                e.printStackTrace();  
            }  
        }  
    }  

 

訓練數據文件:

    1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1  
    1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1  
    1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1  
    1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0  
    1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1  
    1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0  

 

    1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5  
    1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8  
    1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2  
    1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5  
    1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5  
    1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5  

 

程序運行結果:

    測試元組: 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 類別為: 1  
    測試元組: 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 類別為: 1  
    測試元組: 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 類別為: 1  
    測試元組: 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 類別為: 0  
    測試元組: 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 類別為: 1  
    測試元組: 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 類別為: 0  

 

由結果可以看出,分類的測試結果是比較準確的!

 

 

轉自:http://blog.csdn.net/luowen3405/article/details/6278764

 

參考:http://blog.csdn.net/xlm289348/article/details/8876353

http://coolshell.cn/articles/8052.html#more-8052

 本文由用戶 jopen 自行上傳分享,僅供網友學習交流。所有權歸原作者,若您的權利被侵害,請聯系管理員。
 轉載本站原創文章,請注明出處,并保留原始鏈接、圖片水印。
 本站是一個以用戶分享為主的開源技術平臺,歡迎各類分享!