機器學習與數據挖掘-K最近鄰(KNN)算法的實現(java和python版)
KNN算法基礎思想前面文章可以參考,這里主要講解java和python的兩種簡單實現,也主要是理解簡單的思想。
http://blog.csdn.net/u011067360/article/details/23941577
python版本:
這里實現一個手寫識別算法,這里只簡單識別0~9熟悉,在上篇文章中也展示了手寫識別的應用,可以參考:機器學習與數據挖掘-logistic回歸及手寫識別實例的實現
輸入:每個手寫數字已經事先處理成32*32的二進制文本,存儲為txt文件。0~9每個數字都有10個訓練樣本,5個測試樣本。訓練樣本集如下圖:左邊是文件目錄,右邊是其中一個文件打開顯示的結果,看著像1,這里有0~9,每個數字都有是個樣本來作為訓練集。
第一步:將每個txt文本轉化為一個向量,即32*32的數組轉化為1*1024的數組,這個1*1024的數組用機器學習的術語來說就是特征向量。
<span style="font-size:14px;">def img2vector(filename): returnVect = zeros((1,1024)) fr = open(filename) for i in range(32): lineStr = fr.readline() for j in range(32): returnVect[0,32*i+j] = int(lineStr[j]) return returnVect</span>
第二步:訓練樣本中有10*10個圖片,可以合并成一個100*1024的矩陣,每一行對應一個圖片,也就是一個txt文檔。
def handwritingClassTest(): hwLabels = [] trainingFileList = listdir('trainingDigits') print trainingFileList m = len(trainingFileList) trainingMat = zeros((m,1024)) for i in range(m): fileNameStr = trainingFileList[i] fileStr = fileNameStr.split('.')[0] classNumStr = int(fileStr.split('_')[0]) hwLabels.append(classNumStr) #print hwLabels #print fileNameStr trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr) #print trainingMat[i,:] #print len(trainingMat[i,:]) testFileList = listdir('testDigits') errorCount = 0.0 mTest = len(testFileList) for i in range(mTest): fileNameStr = testFileList[i] fileStr = fileNameStr.split('.')[0] classNumStr = int(fileStr.split('_')[0]) vectorUnderTest = img2vector('testDigits/%s' % fileNameStr) classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3) print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr) if (classifierResult != classNumStr): errorCount += 1.0 print "\nthe total number of errors is: %d" % errorCount print "\nthe total error rate is: %f" % (errorCount/float(mTest))
第三步:測試樣本中有10*5個圖片,同樣的,對于測試圖片,將其轉化為1*1024的向量,然后計算它與訓練樣本中各個圖片的“距離”(這里兩個向量的距離采用歐式距離),然后對距離排序,選出較小的前k個,因為這k個樣本來自訓練集,是已知其代表的數字的,所以被測試圖片所代表的數字就可以確定為這k個中出現次數最多的那個數字。
def classify0(inX, dataSet, labels, k): dataSetSize = dataSet.shape[0] #tile(A,(m,n)) print dataSet print "----------------" print tile(inX, (dataSetSize,1)) print "----------------" diffMat = tile(inX, (dataSetSize,1)) - dataSet print diffMat sqDiffMat = diffMat**2 sqDistances = sqDiffMat.sum(axis=1) distances = sqDistances**0.5 sortedDistIndicies = distances.argsort() classCount={} for i in range(k): voteIlabel = labels[sortedDistIndicies[i]] classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True) return sortedClassCount[0][0]
#-*-coding:utf-8-*- from numpy import * import operator from os import listdir def classify0(inX, dataSet, labels, k): dataSetSize = dataSet.shape[0] #tile(A,(m,n)) print dataSet print "----------------" print tile(inX, (dataSetSize,1)) print "----------------" diffMat = tile(inX, (dataSetSize,1)) - dataSet print diffMat sqDiffMat = diffMat**2 sqDistances = sqDiffMat.sum(axis=1) distances = sqDistances**0.5 sortedDistIndicies = distances.argsort() classCount={} for i in range(k): voteIlabel = labels[sortedDistIndicies[i]] classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True) return sortedClassCount[0][0] def img2vector(filename): returnVect = zeros((1,1024)) fr = open(filename) for i in range(32): lineStr = fr.readline() for j in range(32): returnVect[0,32*i+j] = int(lineStr[j]) return returnVect def handwritingClassTest(): hwLabels = [] trainingFileList = listdir('trainingDigits') print trainingFileList m = len(trainingFileList) trainingMat = zeros((m,1024)) for i in range(m): fileNameStr = trainingFileList[i] fileStr = fileNameStr.split('.')[0] classNumStr = int(fileStr.split('_')[0]) hwLabels.append(classNumStr) #print hwLabels #print fileNameStr trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr) #print trainingMat[i,:] #print len(trainingMat[i,:]) testFileList = listdir('testDigits') errorCount = 0.0 mTest = len(testFileList) for i in range(mTest): fileNameStr = testFileList[i] fileStr = fileNameStr.split('.')[0] classNumStr = int(fileStr.split('_')[0]) vectorUnderTest = img2vector('testDigits/%s' % fileNameStr) classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3) print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr) if (classifierResult != classNumStr): errorCount += 1.0 print "\nthe total number of errors is: %d" % errorCount print "\nthe total error rate is: %f" % (errorCount/float(mTest)) handwritingClassTest()
運行結果:源碼文章尾可下載
java版本
先看看訓練集和測試集:
訓練集:
測試集:
訓練集最后一列代表分類(0或者1)
代碼實現:
KNN算法主體類:
package Marchinglearning.knn2; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.PriorityQueue; /** * KNN算法主體類 */ public class KNN { /** * 設置優先級隊列的比較函數,距離越大,優先級越高 */ private Comparator<KNNNode> comparator = new Comparator<KNNNode>() { public int compare(KNNNode o1, KNNNode o2) { if (o1.getDistance() >= o2.getDistance()) { return 1; } else { return 0; } } }; /** * 獲取K個不同的隨機數 * @param k 隨機數的個數 * @param max 隨機數最大的范圍 * @return 生成的隨機數數組 */ public List<Integer> getRandKNum(int k, int max) { List<Integer> rand = new ArrayList<Integer>(k); for (int i = 0; i < k; i++) { int temp = (int) (Math.random() * max); if (!rand.contains(temp)) { rand.add(temp); } else { i--; } } return rand; } /** * 計算測試元組與訓練元組之前的距離 * @param d1 測試元組 * @param d2 訓練元組 * @return 距離值 */ public double calDistance(List<Double> d1, List<Double> d2) { System.out.println("d1:"+d1+",d2"+d2); double distance = 0.00; for (int i = 0; i < d1.size(); i++) { distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i)); } return distance; } /** * 執行KNN算法,獲取測試元組的類別 * @param datas 訓練數據集 * @param testData 測試元組 * @param k 設定的K值 * @return 測試元組的類別 */ public String knn(List<List<Double>> datas, List<Double> testData, int k) { PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator); List<Integer> randNum = getRandKNum(k, datas.size()); System.out.println("randNum:"+randNum.toString()); for (int i = 0; i < k; i++) { int index = randNum.get(i); List<Double> currData = datas.get(index); String c = currData.get(currData.size() - 1).toString(); System.out.println("currData:"+currData+",c:"+c+",testData"+testData); //計算測試元組與訓練元組之前的距離 KNNNode node = new KNNNode(index, calDistance(testData, currData), c); pq.add(node); } for (int i = 0; i < datas.size(); i++) { List<Double> t = datas.get(i); System.out.println("testData:"+testData); System.out.println("t:"+t); double distance = calDistance(testData, t); System.out.println("distance:"+distance); KNNNode top = pq.peek(); if (top.getDistance() > distance) { pq.remove(); pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString())); } } return getMostClass(pq); } /** * 獲取所得到的k個最近鄰元組的多數類 * @param pq 存儲k個最近近鄰元組的優先級隊列 * @return 多數類的名稱 */ private String getMostClass(PriorityQueue<KNNNode> pq) { Map<String, Integer> classCount = new HashMap<String, Integer>(); for (int i = 0; i < pq.size(); i++) { KNNNode node = pq.remove(); String c = node.getC(); if (classCount.containsKey(c)) { classCount.put(c, classCount.get(c) + 1); } else { classCount.put(c, 1); } } int maxIndex = -1; int maxCount = 0; Object[] classes = classCount.keySet().toArray(); for (int i = 0; i < classes.length; i++) { if (classCount.get(classes[i]) > maxCount) { maxIndex = i; maxCount = classCount.get(classes[i]); } } return classes[maxIndex].toString(); } }
KNN結點類,用來存儲最近鄰的k個元組相關的信息
package Marchinglearning.knn2; /** * KNN結點類,用來存儲最近鄰的k個元組相關的信息 */ public class KNNNode { private int index; // 元組標號 private double distance; // 與測試元組的距離 private String c; // 所屬類別 public KNNNode(int index, double distance, String c) { super(); this.index = index; this.distance = distance; this.c = c; } public int getIndex() { return index; } public void setIndex(int index) { this.index = index; } public double getDistance() { return distance; } public void setDistance(double distance) { this.distance = distance; } public String getC() { return c; } public void setC(String c) { this.c = c; } }
KNN算法測試類
package Marchinglearning.knn2; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.util.ArrayList; import java.util.List; /** * KNN算法測試類 */ public class TestKNN { /** * 從數據文件中讀取數據 * @param datas 存儲數據的集合對象 * @param path 數據文件的路徑 */ public void read(List<List<Double>> datas, String path){ try { BufferedReader br = new BufferedReader(new FileReader(new File(path))); String data = br.readLine(); List<Double> l = null; while (data != null) { String t[] = data.split(" "); l = new ArrayList<Double>(); for (int i = 0; i < t.length; i++) { l.add(Double.parseDouble(t[i])); } datas.add(l); data = br.readLine(); } } catch (Exception e) { e.printStackTrace(); } } /** * 程序執行入口 * @param args */ public static void main(String[] args) { TestKNN t = new TestKNN(); String datafile = new File("").getAbsolutePath() + File.separator +"knndata2"+File.separator + "datafile.data"; String testfile = new File("").getAbsolutePath() + File.separator +"knndata2"+File.separator +"testfile.data"; System.out.println("datafile:"+datafile); System.out.println("testfile:"+testfile); try { List<List<Double>> datas = new ArrayList<List<Double>>(); List<List<Double>> testDatas = new ArrayList<List<Double>>(); t.read(datas, datafile); t.read(testDatas, testfile); KNN knn = new KNN(); for (int i = 0; i < testDatas.size(); i++) { List<Double> test = testDatas.get(i); System.out.print("測試元組: "); for (int j = 0; j < test.size(); j++) { System.out.print(test.get(j) + " "); } System.out.print("類別為: "); System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 3))))); } } catch (Exception e) { e.printStackTrace(); } } }
運行結果為:
資源下載:
來自: http://blog.csdn.net//u011067360/article/details/45937327
本文由用戶 jopen 自行上傳分享,僅供網友學習交流。所有權歸原作者,若您的權利被侵害,請聯系管理員。
轉載本站原創文章,請注明出處,并保留原始鏈接、圖片水印。
本站是一個以用戶分享為主的開源技術平臺,歡迎各類分享!