1. 程式人生 > >機器學習--k-近鄰演算法(kNN)實現手寫數字識別

機器學習--k-近鄰演算法(kNN)實現手寫數字識別

這裡的手寫數字以0,1的形式儲存在文字檔案中,大小是32x32.目錄trainingDigits有1934個樣本。0-9每個數字大約有200個樣本,命名規則如下:

下劃線前的數字代表是樣本0-9的數字,下劃線後的數字代表是當前數字的第多少個樣本。

目錄testDigits下有946個樣本。這個資料集可以在網上下載。

首先將32x32的二進位制影象矩陣轉換為1x1024的向量。

def img2vector(filename):
    returnVect = zeros((1, 1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0, 32*i+j] = int(lineStr[j])
    return returnVect
def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    #下面的四行程式碼計算距離
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    sqDiffMat = diffMat ** 2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances ** 0.5
    #對距離進行排序
    sortedDistIndicies = distances.argsort()
    classCount = {}
    #確定前k個較小距離的類別
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    #獲得最大頻率的類別
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

現在,可以檢測一下kNN分類器的效果了。

def handwritingClassTest():
    hwLabels = []
    #獲取目錄內容
    trainingFileList = listdir('digits/trainingDigits')
    m = len(trainingFileList)
    traningMat = zeros((m, 1024))
    for i in range(m):
        #從檔名解析分類數字
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        traningMat[i, :] = img2vector('digits/trainingDigits/%s' % fileNameStr)
    testFileList = listdir('digits/testDigits')
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('digits/testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, traningMat, hwLabels, 3)
        print("the classifier came back with: %d, the real answer is: %d\n" % (classifierResult, classNumStr))
        if(classifierResult != classNumStr):
            errorCount += 1.0
    print("the total number of errors is: %d\n" % errorCount)
    print("the total error rate is: %f" % (errorCount/float(mTest)))

將上面的幾段程式碼儲存為kNN.py,然後在終端執行如下操作:

最後的輸出如下: