K最近鄰(KNN)算法的java實現
1.急切學習與懶惰學習
急切學習:在給定訓練元組之后、接收到測試元組之前就構造好泛化(即分類)模型。
屬于急切學習的算法有:決策樹、貝葉斯、基于規則的分類、后向傳播分類、SVM和基于關聯規則挖掘的分類等等。
懶惰學習:直至給定一個測試元組才開始構造泛化模型,也稱為基于實例的學習法。
屬于急切學習的算法有:KNN分類、基于案例的推理分類。
2.KNN的優缺點
優點:原理簡單,實現起來比較方便。支持增量學習。能對超多邊形的復雜決策空間建模。
缺點:計算開銷大,需要有效的存儲技術和并行硬件的支撐。
3.KNN算法原理
基于類比學習,通過比較訓練元組和測試元組的相似度來學習。
將訓練元組和測試元組看作是n維(若元組有n的屬性)空間內的點,給定一條測試元組,搜索n維空間,找出與測試
元組最相近的k個點(即訓練元組),最后取這k個點中的多數類作為測試元組的類別。
相近的度量方法:用空間內兩個點的距離來度量。距離越大,表示兩個點越不相似。
距離的選擇:可采用歐幾里得距離、曼哈頓距離或其它距離度量。多采用歐幾里得距離,簡單!
4.KNN算法中的細節處理
- 數值屬性規范化:將數值屬性規范到0-1區間以便于計算,也可防止大數值型屬性對分類的主導作用。
可選的方法有:v' = (v - vmin)/ (vmax - vmin),當然也可以采用其它的規范化方法
- 比較的屬性是分類類型而不是數值類型的:同則差為0,異則差為1.
有時候可以作更為精確的處理,比如黑色與白色的差肯定要大于灰色與白色的差。
- 缺失值的處理:取最大的可能差,對于分類屬性,如果屬性A的一個或兩個對應值丟失,則取差值為1;
如果A是數值屬性,若兩個比較的元組A屬性值均缺失,則取差值為1,若只有一個缺失,另一個值為v,
則取差值為|1-v|和|0-v|中的最大值
- 確定K的值:通過實驗確定。進行若干次實驗,取分類誤差率最小的k值。
- 對噪聲數據或不相關屬性的處理:對屬性賦予相關性權重w,w越大說明屬性對分類的影響越相關。對噪聲數據可以將所在
的元組直接cut掉。
5.KNN算法流程
- 準備數據,對數據進行預處理
- 選用合適的數據結構存儲訓練數據和測試元組
- 設定參數,如k
- 維護一個大小為k的的按距離由大到小的優先級隊列,用于存儲最近鄰訓練元組
- 隨機從訓練元組中選取k個元組作為初始的最近鄰元組,分別計算測試元組到這k個元組的距離,將訓練元組標號和距離存入優先級隊列
- 遍歷訓練元組集,計算當前訓練元組與測試元組的距離,將所得距離L與優先級隊列中的最大距離Lmax進行比較。若L>=Lmax,則舍棄該元組,遍歷下一個元組。若L < Lmax,刪除優先級隊列中最大距離的元組,將當前訓練元組存入優先級隊列。
- 遍歷完畢,計算優先級隊列中k個元組的多數類,并將其作為測試元組的類別。
- 測試元組集測試完畢后計算誤差率,繼續設定不同的k值重新進行訓練,最后取誤差率最小的k值。
6.KNN算法的改進策略
- 將存儲的訓練元組預先排序并安排在搜索樹中(如何排序有待研究)
- 并行實現
- 部分距離計算,取n個屬性的“子集”計算出部分距離,若超過設定的閾值則停止對當前元組作進一步計算。轉向下一個元組。
- 剪枝或精簡:刪除證明是“無用的”元組。
7.KNN算法java實現
本算法只適合學習使用,可以大致了解一下KNN算法的原理。
算法作了如下的假定與簡化處理:
1.小規模數據集
2.假設所有數據及類別都是數值類型的
3.直接根據數據規模設定了k值
4.對原訓練集進行測試
KNN實現代碼如下:
package KNN;
/**
* KNN結點類,用來存儲最近鄰的k個元組相關的信息
* @author Rowen
* @qq 443773264
* @mail luowen3405@163.com
* @blog blog.csdn.net/luowen3405
* @data 2011.03.25
*/
public class KNNNode {
private int index; // 元組標號
private double distance; // 與測試元組的距離
private String c; // 所屬類別
public KNNNode(int index, double distance, String c) {
super();
this.index = index;
this.distance = distance;
this.c = c;
}
public int getIndex() {
return index;
}
public void setIndex(int index) {
this.index = index;
}
public double getDistance() {
return distance;
}
public void setDistance(double distance) {
this.distance = distance;
}
public String getC() {
return c;
}
public void setC(String c) {
this.c = c;
}
}
package KNN;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
/**
* KNN算法主體類
* @author Rowen
* @qq 443773264
* @mail luowen3405@163.com
* @blog blog.csdn.net/luowen3405
* @data 2011.03.25
*/
public class KNN {
/**
* 設置優先級隊列的比較函數,距離越大,優先級越高
*/
private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {
public int compare(KNNNode o1, KNNNode o2) {
if (o1.getDistance() >= o2.getDistance()) {
return 1;
} else {
return 0;
}
}
};
/**
* 獲取K個不同的隨機數
* @param k 隨機數的個數
* @param max 隨機數最大的范圍
* @return 生成的隨機數數組
*/
public List<Integer> getRandKNum(int k, int max) {
List<Integer> rand = new ArrayList<Integer>(k);
for (int i = 0; i < k; i++) {
int temp = (int) (Math.random() * max);
if (!rand.contains(temp)) {
rand.add(temp);
} else {
i--;
}
}
return rand;
}
/**
* 計算測試元組與訓練元組之前的距離
* @param d1 測試元組
* @param d2 訓練元組
* @return 距離值
*/
public double calDistance(List<Double> d1, List<Double> d2) {
double distance = 0.00;
for (int i = 0; i < d1.size(); i++) {
distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));
}
return distance;
}
/**
* 執行KNN算法,獲取測試元組的類別
* @param datas 訓練數據集
* @param testData 測試元組
* @param k 設定的K值
* @return 測試元組的類別
*/
public String knn(List<List<Double>> datas, List<Double> testData, int k) {
PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator);
List<Integer> randNum = getRandKNum(k, datas.size());
for (int i = 0; i < k; i++) {
int index = randNum.get(i);
List<Double> currData = datas.get(index);
String c = currData.get(currData.size() - 1).toString();
KNNNode node = new KNNNode(index, calDistance(testData, currData), c);
pq.add(node);
}
for (int i = 0; i < datas.size(); i++) {
List<Double> t = datas.get(i);
double distance = calDistance(testData, t);
KNNNode top = pq.peek();
if (top.getDistance() > distance) {
pq.remove();
pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString()));
}
}
return getMostClass(pq);
}
/**
* 獲取所得到的k個最近鄰元組的多數類
* @param pq 存儲k個最近近鄰元組的優先級隊列
* @return 多數類的名稱
*/
private String getMostClass(PriorityQueue<KNNNode> pq) {
Map<String, Integer> classCount = new HashMap<String, Integer>();
for (int i = 0; i < pq.size(); i++) {
KNNNode node = pq.remove();
String c = node.getC();
if (classCount.containsKey(c)) {
classCount.put(c, classCount.get(c) + 1);
} else {
classCount.put(c, 1);
}
}
int maxIndex = -1;
int maxCount = 0;
Object[] classes = classCount.keySet().toArray();
for (int i = 0; i < classes.length; i++) {
if (classCount.get(classes[i]) > maxCount) {
maxIndex = i;
maxCount = classCount.get(classes[i]);
}
}
return classes[maxIndex].toString();
}
} package KNN;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;
/**
* KNN算法測試類
* @author Rowen
* @qq 443773264
* @mail luowen3405@163.com
* @blog blog.csdn.net/luowen3405
* @data 2011.03.25
*/
public class TestKNN {
/**
* 從數據文件中讀取數據
* @param datas 存儲數據的集合對象
* @param path 數據文件的路徑
*/
public void read(List<List<Double>> datas, String path){
try {
BufferedReader br = new BufferedReader(new FileReader(new File(path)));
String data = br.readLine();
List<Double> l = null;
while (data != null) {
String t[] = data.split(" ");
l = new ArrayList<Double>();
for (int i = 0; i < t.length; i++) {
l.add(Double.parseDouble(t[i]));
}
datas.add(l);
data = br.readLine();
}
} catch (Exception e) {
e.printStackTrace();
}
}
/**
* 程序執行入口
* @param args
*/
public static void main(String[] args) {
TestKNN t = new TestKNN();
String datafile = new File("").getAbsolutePath() + File.separator + "datafile";
String testfile = new File("").getAbsolutePath() + File.separator + "testfile";
try {
List<List<Double>> datas = new ArrayList<List<Double>>();
List<List<Double>> testDatas = new ArrayList<List<Double>>();
t.read(datas, datafile);
t.read(testDatas, testfile);
KNN knn = new KNN();
for (int i = 0; i < testDatas.size(); i++) {
List<Double> test = testDatas.get(i);
System.out.print("測試元組: ");
for (int j = 0; j < test.size(); j++) {
System.out.print(test.get(j) + " ");
}
System.out.print("類別為: ");
System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 3)))));
}
} catch (Exception e) {
e.printStackTrace();
}
}
}
訓練數據文件:
1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0
1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5
程序運行結果:
測試元組: 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 類別為: 1
測試元組: 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 類別為: 1
測試元組: 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 類別為: 1
測試元組: 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 類別為: 0
測試元組: 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 類別為: 1
測試元組: 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 類別為: 0
由結果可以看出,分類的測試結果是比較準確的!
轉自:http://blog.csdn.net/luowen3405/article/details/6278764
參考:http://blog.csdn.net/xlm289348/article/details/8876353
http://coolshell.cn/articles/8052.html#more-8052