機器學習--k-近鄰演算法(kNN)實現手寫數字識別
阿新 • • 發佈:2018-12-11
這裡的手寫數字以0,1的形式儲存在文字檔案中,大小是32x32.目錄trainingDigits有1934個樣本。0-9每個數字大約有200個樣本,命名規則如下:
下劃線前的數字代表是樣本0-9的數字,下劃線後的數字代表是當前數字的第多少個樣本。
目錄testDigits下有946個樣本。這個資料集可以在網上下載。
首先將32x32的二進位制影象矩陣轉換為1x1024的向量。
def img2vector(filename): returnVect = zeros((1, 1024)) fr = open(filename) for i in range(32): lineStr = fr.readline() for j in range(32): returnVect[0, 32*i+j] = int(lineStr[j]) return returnVect
def classify0(inX, dataSet, labels, k): dataSetSize = dataSet.shape[0] #下面的四行程式碼計算距離 diffMat = tile(inX, (dataSetSize, 1)) - dataSet sqDiffMat = diffMat ** 2 sqDistances = sqDiffMat.sum(axis=1) distances = sqDistances ** 0.5 #對距離進行排序 sortedDistIndicies = distances.argsort() classCount = {} #確定前k個較小距離的類別 for i in range(k): voteIlabel = labels[sortedDistIndicies[i]] classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1 #獲得最大頻率的類別 sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True) return sortedClassCount[0][0]
現在,可以檢測一下kNN分類器的效果了。
def handwritingClassTest(): hwLabels = [] #獲取目錄內容 trainingFileList = listdir('digits/trainingDigits') m = len(trainingFileList) traningMat = zeros((m, 1024)) for i in range(m): #從檔名解析分類數字 fileNameStr = trainingFileList[i] fileStr = fileNameStr.split('.')[0] classNumStr = int(fileStr.split('_')[0]) hwLabels.append(classNumStr) traningMat[i, :] = img2vector('digits/trainingDigits/%s' % fileNameStr) testFileList = listdir('digits/testDigits') errorCount = 0.0 mTest = len(testFileList) for i in range(mTest): fileNameStr = testFileList[i] fileStr = fileNameStr.split('.')[0] classNumStr = int(fileStr.split('_')[0]) vectorUnderTest = img2vector('digits/testDigits/%s' % fileNameStr) classifierResult = classify0(vectorUnderTest, traningMat, hwLabels, 3) print("the classifier came back with: %d, the real answer is: %d\n" % (classifierResult, classNumStr)) if(classifierResult != classNumStr): errorCount += 1.0 print("the total number of errors is: %d\n" % errorCount) print("the total error rate is: %f" % (errorCount/float(mTest)))
將上面的幾段程式碼儲存為kNN.py,然後在終端執行如下操作:
最後的輸出如下: