1. 程式人生 > >K-NN近鄰演算法詳解

K-NN近鄰演算法詳解

 K-近鄰演算法屬於一種監督學習分類演算法,該方法的思路是:如果一個樣本在特徵空間中的k個最相似(即特徵空間中最鄰近)的樣本中的大多數屬於某一個類別,則該樣本也屬於這個類別。

 (1) 需要進行分類,分類的依據是什麼呢,每個物體都有它的特徵點,這個就是分類的依據,特徵點可以是很多,越多分類就越精確。

 (2) 機器學習就是從樣本中學習分類的方式,那麼就需要輸入我們的樣本,也就是已經分好類的樣本,比如特徵點是A , B2個特徵,輸入的樣本甲乙丙丁,分別為[[1.0, 1.1], [1.0, 1.0], [5.0, 1.1], [5.0, 1.0]]。 那麼就開始輸入目標值,當然也要給特徵了,最終的目標就是看特徵接近A的多還是B的多,如果把這些當做座標,幾個特徵點就是幾緯座標,那麼就是座標之間的距離。那麼問題來了,要怎麼看接近A的多還是B的多。

工作原理:存在一個樣本資料集合,也稱為訓練樣本集,並且樣本集中每個資料都存在標籤,即我們知道樣本集中每一資料與所屬分類的對應關係。輸入沒有標籤的新資料後,將新資料的每個特徵與樣本集中資料對應的特徵進行比較,然後演算法提取樣本集中特徵最相似的資料,這就是k-近鄰演算法中k的出處,通常k是不大於20的整數。

#!/usr/bin/env python 
# -*- coding:utf-8 -*-

# 科學計算包
import numpy
# 運算子模組
import operator

# 資料樣本和分類模擬
# 手動建立一個數據源矩陣group 和資料來源的分類結果labels
def createDataSet():
    group = numpy.array([[1.0, 1.1], [1.0, 1.0], [5.0, 1.1], [5.0, 1.0]])
    lables = ['A', 'A', 'B', 'B']
    return group, lables

# 進行KNN 演算法
# newInput為輸入的目標,dataSet是樣本的矩陣,labels是分類,k是需要取的個數
def kNNClassify(newInput, dataSet, lables, k):
    # 讀取矩陣的行數,也就是樣本數量
    numSamples = dataSet.shape[0]
    print("numSamples = ", numSamples)

    # 變成和dataSet一樣的行數,行數 = 原來 * numSamples,列數 = 原來 * 1,然後每個特徵點和樣本的點進行相減
    # (numSamples, 1)表示矩陣newInput變為三維,重複次數一次
    diff = numpy.tile(newInput, (numSamples, 1)) - dataSet
    print("diff = ", diff)

    # 平方
    squaredDiff = diff ** 2
    print("squaredDiff = ", squaredDiff)

    # axis = 0 按列求和,axis = 1 按行求和
    squaredDist = numpy.sum(squaredDiff, axis=1)
    print("squaredDist = ", squaredDist)

    # 開根號,計算距離
    distance = squaredDist ** 0.5
    print("distance = ", distance)

    # 按大小逆序排序
    sortedDistIndices = numpy.argsort(distance)
    print("sortedDistIndices = ", sortedDistIndices)

    classCount = {}
    for i in range(k):
        # 返回距離(key)對應類別(value)
        voteLabel = labels[sortedDistIndices[i]]
        print("voteLabel = ", voteLabel)

        if voteLabel in classCount.keys():
            value = classCount[voteLabel]
            classCount[voteLabel] = value + 1
        else:
            classCount[voteLabel] = 1

    print("classCount: ", classCount)
    # 返回佔有率最大的
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    print("sortedClassCount = ", sortedClassCount)

    return sortedClassCount[0][0]

if __name__ == '__main__':
    dataSet, labels = createDataSet()

    testX = numpy.array([0, 0])
    k = 3
    outputLabel = kNNClassify(testX, dataSet, labels, k)
    print("Your input is:", testX, "and classified to class: ", outputLabel)


總結:K-近鄰演算法是分類資料中最簡單最有效的演算法,是基於例項的學習,使用演算法時我們必須要有接近實際資料的訓練樣本資料。K-近鄰演算法必須儲存全部資料集,如果訓練資料集很大,必須使用大量的儲存空間。此外,由於必須對資料集中的每個資料計算距離值,實際使用時可能非常耗時;

          k-近鄰演算法的另一個缺陷是讓無法給出任何資料的基礎結構資訊,因此我們也無法知曉平均例項樣本和典型例項樣本具有什麼特徵。