K-NN近鄰演算法詳解
阿新 • • 發佈:2018-11-07
K-近鄰演算法屬於一種監督學習分類演算法,該方法的思路是:如果一個樣本在特徵空間中的k個最相似(即特徵空間中最鄰近)的樣本中的大多數屬於某一個類別,則該樣本也屬於這個類別。
(1) 需要進行分類,分類的依據是什麼呢,每個物體都有它的特徵點,這個就是分類的依據,特徵點可以是很多,越多分類就越精確。
(2) 機器學習就是從樣本中學習分類的方式,那麼就需要輸入我們的樣本,也就是已經分好類的樣本,比如特徵點是A , B2個特徵,輸入的樣本甲乙丙丁,分別為[[1.0, 1.1], [1.0, 1.0], [5.0, 1.1], [5.0, 1.0]]。 那麼就開始輸入目標值,當然也要給特徵了,最終的目標就是看特徵接近A的多還是B的多,如果把這些當做座標,幾個特徵點就是幾緯座標,那麼就是座標之間的距離。那麼問題來了,要怎麼看接近A的多還是B的多。
工作原理:存在一個樣本資料集合,也稱為訓練樣本集,並且樣本集中每個資料都存在標籤,即我們知道樣本集中每一資料與所屬分類的對應關係。輸入沒有標籤的新資料後,將新資料的每個特徵與樣本集中資料對應的特徵進行比較,然後演算法提取樣本集中特徵最相似的資料,這就是k-近鄰演算法中k的出處,通常k是不大於20的整數。
#!/usr/bin/env python # -*- coding:utf-8 -*- # 科學計算包 import numpy # 運算子模組 import operator # 資料樣本和分類模擬 # 手動建立一個數據源矩陣group 和資料來源的分類結果labels def createDataSet(): group = numpy.array([[1.0, 1.1], [1.0, 1.0], [5.0, 1.1], [5.0, 1.0]]) lables = ['A', 'A', 'B', 'B'] return group, lables # 進行KNN 演算法 # newInput為輸入的目標,dataSet是樣本的矩陣,labels是分類,k是需要取的個數 def kNNClassify(newInput, dataSet, lables, k): # 讀取矩陣的行數,也就是樣本數量 numSamples = dataSet.shape[0] print("numSamples = ", numSamples) # 變成和dataSet一樣的行數,行數 = 原來 * numSamples,列數 = 原來 * 1,然後每個特徵點和樣本的點進行相減 # (numSamples, 1)表示矩陣newInput變為三維,重複次數一次 diff = numpy.tile(newInput, (numSamples, 1)) - dataSet print("diff = ", diff) # 平方 squaredDiff = diff ** 2 print("squaredDiff = ", squaredDiff) # axis = 0 按列求和,axis = 1 按行求和 squaredDist = numpy.sum(squaredDiff, axis=1) print("squaredDist = ", squaredDist) # 開根號,計算距離 distance = squaredDist ** 0.5 print("distance = ", distance) # 按大小逆序排序 sortedDistIndices = numpy.argsort(distance) print("sortedDistIndices = ", sortedDistIndices) classCount = {} for i in range(k): # 返回距離(key)對應類別(value) voteLabel = labels[sortedDistIndices[i]] print("voteLabel = ", voteLabel) if voteLabel in classCount.keys(): value = classCount[voteLabel] classCount[voteLabel] = value + 1 else: classCount[voteLabel] = 1 print("classCount: ", classCount) # 返回佔有率最大的 sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True) print("sortedClassCount = ", sortedClassCount) return sortedClassCount[0][0] if __name__ == '__main__': dataSet, labels = createDataSet() testX = numpy.array([0, 0]) k = 3 outputLabel = kNNClassify(testX, dataSet, labels, k) print("Your input is:", testX, "and classified to class: ", outputLabel)
總結:K-近鄰演算法是分類資料中最簡單最有效的演算法,是基於例項的學習,使用演算法時我們必須要有接近實際資料的訓練樣本資料。K-近鄰演算法必須儲存全部資料集,如果訓練資料集很大,必須使用大量的儲存空間。此外,由於必須對資料集中的每個資料計算距離值,實際使用時可能非常耗時;
k-近鄰演算法的另一個缺陷是讓無法給出任何資料的基礎結構資訊,因此我們也無法知曉平均例項樣本和典型例項樣本具有什麼特徵。