機器學習--k-近鄰(kNN)算法

jopen 9年前發布 | 16K 次閱讀 算法

  一、基本原理

        存在一個樣本數據集合(也稱訓練樣本集),并且樣本集中每個數據都存在標簽。輸入沒有標簽的新數據后,將新數據的每個特征與樣本集中數據對應的特征進行比較,然后算法提取樣本集中特征最相似數據(最近鄰)的分類標簽。

我們一般只選擇樣本集中前k(k通常是不大于20的整數)個最相似的數據,最后選擇k個最相似數據中出現次數最多的分類,作為新數據的分類。

二、算法流程

        1)計算已知類別數據集中的點與當前點之間的距離;

        2)按照距離遞增次序排序;
</div>

        3)選取與當前點距離最小的k個點;
</div>

        4)確定前k個點所在類別的出現頻率;
</div>

        5)返回前k個點出現頻率最高的類別作為當前點的預測分類。
</div>

三、算法的特點

        優點:精度高、對異常值不敏感、無數據輸入假定。

        缺點:計算復雜度高、空間復雜度高。

        適用數據范圍:數值型和標稱型。

四、python代碼實現

1、創建數據集

def create_data_set():
    group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels

2、實施KNN算法

##############################

功能:將每組數據劃分到某個類中

輸入變量:inx, data_set,labels,k

分類的向量,樣本數據,標簽,k個近鄰的樣本

輸出變量:sorted_class_count[0][0] 選擇最近的類別標簽

######################## </p>

def classify0(inx, data_set, labels, k):
    data_set_size = data_set.shape[0]  # 獲得數組的行數

    # 利用tile(inx, (data_set_size, 1)),在原來的基礎上再構造data_set_size*1的inx
    # 每行數據相當于某個矢量點的坐標
    # 對每行數據進行求和,得到一個data_set_size*1的矩陣
    # 最后計算歐式距離
    diff_mat = tile(inx, (data_set_size, 1))-data_set
    sq_diff_mat = diff_mat**2
    sq_distances = sq_diff_mat.sum(axis=1)
    distances = sq_distances**0.5

     # argsort函數返回的是數組值從小到大的索引值
    sorted_dist_indicies = distances.argsort()

    class_count = {}
    for i in xrange(k):
        vote_label = labels[sorted_dist_indicies[i]]

        # get相當于一條if...else...語句
        # 如果參數vote_label不在字典中則返回參數0,如果vote_label在字典中則返回vote_label對應的value值
        class_count[vote_label] = class_count.get(vote_label, 0) + 1

    # items以列表方式返回字典中的鍵值對,iteritems以迭代器對象返回鍵值對,而鍵值對以元組方式存儲,即這種方式[(), ()]
    # operator.itemgetter(0)獲取對象的第0個域的值,即返回的是key值
    # operator.itemgetter(1)獲取對象的第1個域的值,即返回的是value值
    # operator.itemgetter定義了一個函數,通過該函數作用到對象上才能獲取值
    # reverse=True是按降序排序
    sorted_class_count = sorted(class_count.iteritems(), key=operator.itemgetter(1), reverse=True)

    return sorted_class_count[0][0]

</div>

3、代碼測試

def main():
    group, labels = create_data_set()
    sorted_class_labels = classify0([0, 0], group, labels, 3)
    print 'sorted_class_labels=', sorted_class_labels

if __name__ == '__main__':
    main()

來自:http://blog.csdn.net/zyl1042635242/article/details/45081065

 本文由用戶 jopen 自行上傳分享,僅供網友學習交流。所有權歸原作者,若您的權利被侵害,請聯系管理員。
 轉載本站原創文章,請注明出處,并保留原始鏈接、圖片水印。
 本站是一個以用戶分享為主的開源技術平臺,歡迎各類分享!