K近鄰(KNN)算法詳解及Python實現
KNN依然是一種監督學習算法
KNN(K Nearest Neighbors,K近鄰 )算法是機器學習所有算法中理論最簡單,最好理解的。KNN是一種基于實例的學習,通過計算新數據與訓練數據特征值之間的距離,然后選取 K(K>=1)個距離最近的鄰居進行分類判斷(投票法)或者回歸。如果K=1,那么新數據被簡單分配給其近鄰的類。KNN算法算是監督學習還是無監 督學習呢?首先來看一下監督學習和無監督學習的定義。對于監督學習,數據都有明確的label(分類針對離散分布,回歸針對連續分布),根據機器學習產生 的模型可以將新數據分到一個明確的類或得到一個預測值。對于非監督學習,數據沒有label,機器學習出的模型是從數據中提取出來的pattern(提取 決定性特征或者聚類等)。例如聚類是機器根據學習得到的模型來判斷新數據“更像”哪些原數據集合。KNN算法用于分類時,每個訓練數據都有明確的 label,也可以明確的判斷出新數據的label,KNN用于回歸時也會根據鄰居的值預測出一個明確的值,因此KNN屬于監督學習。
KNN算法的過程為:
- 選擇一種距離計算方式, 通過數據所有的特征計算新數據與已知類別數據集中的數據點的距離
- 按照距離遞增次序進行排序,選取與當前距離最小的k個點
- 對于離散分類,返回k個點出現頻率最多的類別作預測分類;對于回歸則返回k個點的加權值作為預測值
KNN算法關鍵
KNN算法的理論和過程就是那么簡單,為了使其獲得更好的學習效果,有下面幾個需要注意的地方。
-
數據的所有特征都要做可比較的量化。
若是數據特征中存在非數值的類型,必須采取手段將其量化為數值。舉個例子,若樣本特征中包含顏色(紅 黑藍)一項,顏色之間是沒有距離可言的,可通過將顏色轉換為灰度值來實現距離計算。另外,樣本有多個參數,每一個參數都有自己的定義域和取值范圍,他們對 distance計算的影響也就不一樣,如取值較大的影響力會蓋過取值較小的參數。為了公平,樣本參數必須做一些scale處理,最簡單的方式就是所有特 征的數值都采取歸一化處置。 -
需要一個distance函數以計算兩個樣本之間的距離。
距離的定義有很多,如歐氏距離、余弦距離、漢明距離、曼哈頓距離等等,關于相似性度量的方法可參考:機器學習中距離和相似性度量方法 。 一般情況下,選歐氏距離作為距離度量,但是這是只適用于連續變量。在文本分類這種非連續變量情況下,漢明距離可以用來作為度量。通常情況下,如果運用一些特殊的算法來計算度量的話,K近鄰分類精度可顯著提高,如運用大邊緣最近鄰法或者近鄰成分分析法。 -
確定K的值
K是一個自定義的常數,K的值也直接影響最后的估計,一種選擇K值得方法是使用 cross-validate(交叉驗證)誤差統計選擇法。交叉驗證的概念之前提過,就是數據樣本的一部分作為訓練樣本,一部分作為測試樣本,比如選擇 95%作為訓練樣本,剩下的用作測試樣本。通過訓練數據訓練一個機器學習模型,然后利用測試數據測試其誤差率。 cross-validate(交叉驗證)誤差統計選擇法就是比較不同K值時的交叉驗證平均誤差率,選擇誤差率最小的那個K值。例如選擇 K=1,2,3,... , 對每個K=i做100次交叉驗證,計算出平均誤差,然后比較、選出最小的那個。
KNN分類
訓練樣本是多維特征空間向量,其中每個訓練樣本帶有一個類別標簽(喜歡或者不喜歡、保留或者刪除)。分類算法常采用“多數表決”決定,即k個鄰居中 出現次數最多的那個類作為預測類。“多數表決”分類的一個缺點是出現頻率較多的樣本將會主導測試點的預測結果,那是因為他們比較大可能出現在測試點的K鄰 域而測試點的屬性又是通過K領域內的樣本計算出來的。解決這個缺點的方法之一是在進行分類時將K個鄰居到測試點的距離考慮進去。例如,若樣本到測試點距離 為d,則選1/d為該鄰居的權重(也就是得到了該鄰居所屬類的權重),接下來統計統計k個鄰居所有類標簽的權重和,值最大的那個就是新數據點的預測類標 簽。
舉例,K=5,計算出新數據點到最近的五個鄰居的舉例是(1,3,3,4,5),五個鄰居的類標簽是(yes,no,no,yes,no)
若是按照多數表決法,則新數據點類別為no(3個no,2個yes);若考慮距離權重類別則為yes(no:2/3+1/5,yes:1+1/4)。
下面的Python程序是采用KNN算法的實例(計算歐氏距離,多數表決法決斷):一個是采用KNN算法改進約會網站配對效果,另一個是采用KNN算法進行手寫識別。
約會網站配對效果改進的例子是根據男子的每年的飛行里程、視頻游戲時間比和每周冰激凌耗量三個特征來判斷其是否是海倫姑娘喜歡的類型(類別為很喜歡、一般和討厭),決策采用多數表決法。由于三個特征的取值范圍不同,這里采用的scale策略為歸一化。
使用KNN分類器的手寫識別系統 只能識別數字0到9。需要識別的數字使用圖形處理軟件,處理成具有相同的色 彩和大小 :寬髙是32像素X32像素的黑白圖像。盡管采用文本格式存儲圖像不能有效地利用內存空間,為了方便理解,這里已經將將圖像轉換為文本格式。訓練數據中每 個數字大概有200個樣本,程序中將圖像樣本格式化處理為向量,即一個把一個32x32的二進制圖像矩陣轉換為一個1x1024的向量。
from numpy import * import operator from os import listdir import matplotlib import matplotlib.pyplot as plt import pdbdef classify0(inX, dataSet, labels, k=3):
#pdb.set_trace() dataSetSize = dataSet.shape[0] diffMat = tile(inX, (dataSetSize,1)) - dataSet sqDiffMat = diffMat**2 sqDistances = sqDiffMat.sum(axis=1) distances = sqDistances**0.5 sortedDistIndicies = distances.argsort() #ascend sorted, #return the index of unsorted, that is to choose the least 3 item classCount={} for i in range(k): voteIlabel = labels[sortedDistIndicies[i]] classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1# a dict with label as key and occurrence number as value sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True) '''descend sorted according to value, ''' return sortedClassCount[0][0]
def file2matrix(filename): fr = open(filename)
#pdb.set_trace() L = fr.readlines() numberOfLines = len(L) #get the number of lines in the file returnMat = zeros((numberOfLines,3)) #prepare matrix to return classLabelVector = [] #prepare labels return index = 0 for line in L: line = line.strip() listFromLine = line.split('\t') returnMat[index,:] = listFromLine[0:3] classLabelVector.append(int(listFromLine[-1])) #classLabelVector.append((listFromLine[-1])) index += 1 fr.close() return returnMat,classLabelVector
def plotscattter(): datingDataMat,datingLabels = file2matrix('datingTestSet2.txt') #load data setfrom file fig = plt.figure() ax1 = fig.add_subplot(111) ax2 = fig.add_subplot(111) ax3 = fig.add_subplot(111) ax1.scatter(datingDataMat[:,0],datingDataMat[:,1],15.0array(datingLabels),15.0array(datingLabels))
#ax2.scatter(datingDataMat[:,0],datingDataMat[:,2],15.0*array(datingLabels),15.0*array(datingLabels)) #ax2.scatter(datingDataMat[:,1],datingDataMat[:,2],15.0*array(datingLabels),15.0*array(datingLabels)) plt.show()
def autoNorm(dataSet): minVals = dataSet.min(0) maxVals = dataSet.max(0) ranges = maxVals - minVals normDataSet = zeros(shape(dataSet)) m = dataSet.shape[0] normDataSet = dataSet - tile(minVals, (m,1)) normDataSet = normDataSet/tile(ranges, (m,1)) #element wise divide return normDataSet, ranges, minVals
def datingClassTest(hoRatio = 0.20):
#hold out 10% datingDataMat,datingLabels = file2matrix('datingTestSet2.txt') #load data setfrom file normMat, ranges, minVals = autoNorm(datingDataMat) m = normMat.shape[0] numTestVecs = int(m*hoRatio) errorCount = 0.0 for i in range(numTestVecs): classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3) print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]) if (classifierResult != datingLabels[i]): errorCount += 1.0 print "the total error rate is: %.2f%%" % (100*errorCount/float(numTestVecs)) print 'testcount is %s, errorCount is %s' %(numTestVecs,errorCount)
def classifyPerson(): ''' input a person , decide like or not, then update the DB ''' resultlist = ['not at all','little doses','large doses'] percentTats = float(raw_input('input the person\' percentage of time playing video games:')) ffMiles = float(raw_input('flier miles in a year:')) iceCream = float(raw_input('amount of iceCream consumed per year:')) datingDataMat,datingLabels = file2matrix('datingTestSet2.txt') normMat, ranges, minVals = autoNorm(datingDataMat) normPerson = (array([ffMiles,percentTats,iceCream])-minVals)/ranges result = classify0(normPerson, normMat, datingLabels, 3) print 'you will probably like this guy in:', resultlist[result -1]
#update the datingTestSet print 'update dating DB' tmp = '\t'.join([repr(ffMiles),repr(percentTats),repr(iceCream),repr(result)])+'\n' with open('datingTestSet2.txt','a') as fr: fr.write(tmp)
def img2file(filename):
#vector = zeros(1,1024) with open(filename) as fr: L=fr.readlines() vector =[int(L[i][j]) for i in range(32) for j in range(32)] return array(vector,dtype = float)
def handwritingClassTest(): hwLabels = [] trainingFileList = listdir('trainingDigits') #load the training set m = len(trainingFileList) trainingMat = zeros((m,1024)) for i in range(m): fileNameStr = trainingFileList[i] fileStr = fileNameStr.split('.')[0] #take off .txt classNumStr = int(fileStr.split('')[0]) hwLabels.append(classNumStr) trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr) testFileList = listdir('testDigits') #iterate through the test set errorCount = 0.0 mTest = len(testFileList) for i in range(mTest): fileNameStr = testFileList[i] fileStr = fileNameStr.split('.')[0] #take off .txt classNumStr = int(fileStr.split('')[0]) vectorUnderTest = img2vector('testDigits/%s' % fileNameStr) classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3) print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr) if (classifierResult != classNumStr): errorCount += 1.0 print "\nthe total number of errors is: %d" % errorCount print "\nthe total error rate is: %f" % (errorCount/float(mTest))
if name == 'main': datingClassTest()
#handwritingClassTest()</pre> <h3> KNN回歸</h3>
數據點的類別標簽是連續值時應用KNN算法就是回歸,與KNN分類算法過程相同,區別在于對K個鄰居的處理上。KNN回歸是取K個鄰居類標簽值得加 權作為新數據點的預測值。加權方法有:K個近鄰的屬性值的平均值(最差)、1/d為權重(有效的衡量鄰居的權重,使較近鄰居的權重比較遠鄰居的權重大)、 高斯函數(或者其他適當的減函數)計算權重= gaussian(distance) (距離越遠得到的值就越小,加權得到更為準確的估計。
總結
K-近鄰算法是分類數據最簡單最有效的算法,其學習基于實例,使用算法時我們必須有接近實際數據的訓練樣本數據。K-近鄰算法必須保存全部數據集, 如果訓練數據集的很大,必須使用大量的存儲空間。此外,由于必須對數據集中的每個數據計算距離值,實際使用時可能非常耗時。k-近鄰算法的另一個缺陷是它 無法給出任何數據的基礎結構信息,因此我們也無法知曉平均實例樣本和典型實例樣本具有什么特征。
來自:http://suanfazu.com/t/kjin-lin-knn-suan-fa-xiang-jie-ji-pythonshi-xian/114/1