k近鄰法:一個簡單的預測模型
k近鄰法的思想
你打算預測我會在大選中投票給誰。假設你對我一無所知,一個明智的方法是看看我的鄰居們都投票給誰。當然,你可能還知道我的年齡、收入、有幾個孩子,等等,根據我的行為受這些維度影響的程度,你可以觀察在這些維度上最接近我的鄰居們而不是我所有的鄰居會得到更好的預測結果。這就是最近鄰分類(nearest neighbors classification)方法背后的思想。
k近鄰法的優缺點
k近鄰法是最簡單的預測模型之一,它沒有多少數學上的假設,也不要求任何復雜的數學處理,它所要求的僅僅是:
-
某種距離的概念
-
彼此接近的點具有相似性質的假設(否則用近鄰來預測結果就是不合理的)
k近鄰法有意忽略了大量信息,對每個新的數據點的預測只依賴少量最接近它的點。
此外,它不能解釋為什么。例如,基于我鄰居的投票行為來預測我的投票并不能告訴你我為什么要這樣投票,而基于我的收入、婚姻等因素來預測我的投票行為的模型則能揭示我投票的原因。
案例:最喜歡的編程語言
假設我們有一份數據,這份數據是各個城市經緯度及該城市最受歡迎的編程語言的集合。數據以列表存儲。
cities = [(-86.75,33.5666666666667,'Python'),(-88.25,30.6833333333333,'Python'),(-112.016666666667,33.4333333333333,'Java')......]
數據可視化
我們將每種語言及其經度(x)、緯度(y)按如下的格式存儲到字典中:鍵是語言,值是成對的數據。
plots={"Java":([],[]),"python":([],[]),"R":([],[])}
每種語言用不同的符號和顏色標記:
markers = { "Java" : "o", "Python" : "s", "R" : "^" } colors = { "Java" : "r", "Python" : "b", "R" : "g" }
將cities列表中的數據存放到plots字典中:
for (longitude, latitude), language in cities: plots[language][0].append(longitude) plots[language][1].append(latitude)
我們可以用items方法來返回字典元素的列表:
In [1]: plots.items() Out[1]: [('Python',([-86.75, -88.25, -118.15......],[33.5666666666667, 30.6833333333333, 33.8166666666667.....])),('R'......),('Java'.....)]
這樣我們就可以用for語句來循環遍歷字典元素,為每種語言創建一個散點序列:
import matplotlib.pyplot as pltfor language, (x, y) in plots.items(): plt.scatter(x, y, color=colors[language], marker=markers[language], label=language, zorder=10)
plt.legend(loc=0) #讓matplotlib選擇一個位置 plt.axis([-130,-60,20,55]) #設置坐標軸 plt.title("most popular language") #設置圖標標題 put.show()</pre>
最終效果如下:
k近鄰法的python實現
下面的點是cities列表中的第一個點,這個城市最受歡迎的編程語言是python:
In [2]: cities[0] Out[2]: ([-86.75, 33.5666666666667], 'Python')假設我們不知道這個城市最受歡迎的語言是什么。根據k近鄰法的思想,為了預測結果,(1)我們首先需要知道這個點即這個城市與其他所有點的距離,(2)然后找到離這個點最近的某個點或幾個點最受歡迎的編程語言是什么,以此作為預測結果,(3)如果是幾個點,我們需要計算哪種語言出現的次數最多,以此作為預測結果。
思路清楚了,讓我們來一步步實現吧。
先將除其他城市存放在other_cities列表中(這里用列表解析式遍歷所有城市,找到與該城市不同的所有城市):
other_cities=[other_city for other_city in cities if other_city != (cities[0][0],cities[0][1])]按其他城市與待預測城市之間的距離從近到遠對other_cities列表進行排序:
from linear_algebra import distanceby_distance=sorted(other_cities, key= lambda point_label: distance(point_label[0], cities[0][0]))</pre>
找到最近的一個城市:
In [3]: k_nearestlabels=[label for , label in by_distance[:1]]In [4]: k_nearest_labels Out[4]: ['Python']</pre>
當然,我們也可以找到最近的3個城市、5個城市或7個城市。
In [5]: k_nearestlabels=[label for , label in by_distance[:3]]In [6]: k_nearest_labels Out[6]: ['Python', 'R', 'Python']
In [7]: k_nearestlabels=[label for , label in by_distance[:5]]
In [8]: k_nearest_labels Out[8]: ['Python', 'R', 'Python', 'Java', 'R']
In [9]: k_nearestlabels=[label for , label in by_distance[:7]]
In [10]: k_nearest_labels Out[10]: ['Python', 'R', 'Python', 'Java', 'R', 'Python', 'Java']</pre>
這時,我們需要一個計數器找到出現次數最多的語言:
def majority_vote(labels): """假設labels已經從最近到最遠排序""" vote_counts = Counter(labels) #Counter返回的是字典,以label為鍵,出現次數為值 winner, winner_count = vote_counts.most_common(1)[0] #most_common方法可以找出vote_counts中出現次數前1(前幾由括號內參數指定)的鍵和值,以元組組織 num_winners = len([count for count in vote_counts.values() if count == winner_count]) #計算vote_counts中前1的出現次數出現了幾次,有幾個勝出者if num_winners == 1: return winner # 如果只有一個勝出者,直接返回 else: return majority_vote(labels[:-1]) # 如果有幾個勝出者,排除lavels中最遠的點,再試一次</pre>
在計算距離時,我們僅僅計算了一個點與其他點的距離。由于所有點的邏輯都是一樣的,這種情況下,我們可以構造函數來執行同類操作。
def knn_classify(k, labeled_points, new_point): """k決定取最近的幾個點;labeled_points指帶標簽的點,即(point, label)的數據對,是除待預測的點之外所有的點;new_ponit為待預測的點""" # 將帶標簽的點,即其他點從近到遠排序 by_distance = sorted(labeled_points, key=lambda point_label: distance(point_label[0], new_point)) # 找到最近的k個點 k_nearest_labels = [label for _, label in by_distance[:k]] # 對每個點進行計數 return majority_vote(k_nearest_labels)現在,讓我們看看嘗試利用近鄰城市來預測每個城市的偏愛語言會得到什么結果:
for k in [1, 3, 5, 7]: num_correct = 0 for location, actual_language in cities: other_cities = [other_city for other_city in cities if other_city != (location, actual_language)] predicted_language = knn_classify(k, other_cities, location) if predicted_language == actual_language: num_correct += 1 print(k, "neighbor[s]:", num_correct, "correct out of", len(cities))可以看到,k=3時的預測準確率最高,59%的時間能給出正確答案:
(1, 'neighbor[s]:', 40, 'correct out of', 75)
(3, 'neighbor[s]:', 44, 'correct out of', 75)
(5, 'neighbor[s]:', 41, 'correct out of', 75)
(7, 'neighbor[s]:', 35, 'correct out of', 75)
參考資料:
Joel Grus《數據科學入門》第12章
來自:https://segmentfault.com/a/1190000008766556