1. 程式人生 > >手寫字型識別 --MNIST資料集

手寫字型識別 --MNIST資料集

Matlab 手寫字型識別

忙過這段時間後,對於上次讀取的Matlab內部資料實現的識別,我回味了一番,覺得那個實在太小。所以打算把資料換成[MNIST資料集][1]。

基礎思想還是相同的,使用TreeBagger(隨機森林)的演算法來訓練樣本,從而實現學習並且識別。這一次不會和上次那麼草率了….同時分享一些關於TreeBagger的理解。

思想
和我上一個識別花是一樣的。使用演算法訓練訓練樣本,得到一個模型model,從而使用predict函式根據模型對測試樣本進行識別。從而達到手寫字型識別的效果。這裡我使用了Google實驗室的Corinna Cortes和紐約大學柯朗研究所的Yann LeCun的建有一個手寫數字資料庫,訓練庫有60,000張手寫數字影象,測試庫有10,000張。

但是因為這個網站上的四個檔案資源似乎不多,導致下載速度很慢,所以我把他們放在我的雲盤裡,給大家下載。雲盤連結 密碼:7awp
因為裡面的內容全部是用二進位制存的,所以我在的檔案裡也順便把解壓的.m的檔案放進去了,也省得大家到處找。

程式碼實現

因為要匯入的圖片太多,一開始我使用imread時我發現,imread似乎是按照一個特定的檔名順序讀取檔案的,所以對於我這些有順序的圖片,他不能按照順序讀。所以我自己想了個方法來讀取60000張訓練樣本。

for i=1:60000
    str  = strcat('C:\Users\StevenT\Desktop\mnist資料集\train-images-idx3-ubyte\TrainImage_'
,num2str(i)); name = strcat(str,'.bmp'); name = char(name); current_img = imread(name); %將當前圖片賦值給一個變數 current_img = reshape(current_img,1,[]); %將矩陣變形 train_image(i,:) = current_img; end

之後用同樣的方法獲得10000個測試樣本。
對於測試標籤和訓練標籤的讀取,直接用textread來讀取就可以了。

lable_test = textread('C:\Users\StevenT
\Desktop\mnist資料集\t10k-labels-idx1-ubyte\test_lable.txt');

讀取到樣本和標籤之後,對樣本和標籤進行訓練。

model = TreeBagger(500,train_image,lable_train); %使用TreeBagger來對訓練樣本進行訓練,獲得一個model

result = predict(model,test_image);  %之後使用model來對測試樣本進行預測,將結果存在resultresult = cell2mat(result);   %因為result是cell類的,使用cell2mat轉換成字串

最後輸出識別率

sc=double(result) - lable_test; 
count=sum(sc(:)==48)/100.0;     %sc用來儲存相減的結果,當其等於0(ASCII裡是48)的時候就是識別正確的結果,最終得出識別率

整體程式碼

clear all;
clc;

%匯入訓練樣本

for i=1:60000
    str  = strcat('C:\Users\StevenT\Desktop\mnist資料集\train-images-idx3-ubyte\TrainImage_',num2str(i));
    name = strcat(str,'.bmp');
    name = char(name);
    current_img = imread(name);
    current_img = reshape(current_img,1,[]); %將矩陣變形
    train_image(i,:) = current_img;
end
train_image=double(train_image);

%匯入測試樣本

for i=1:10000
    str  = strcat('C:\Users\StevenT\Desktop\mnist資料集\t10k-images-idx3-ubyte\TestImage_',num2str(i));
    name = strcat(str,'.bmp');
    name = char(name);
    current_img = imread(name);
    current_img = reshape(current_img,1,[]); %將矩陣變形
    test_image(i,:) = current_img;
end
test_image=double(test_image);

lable_test = textread('C:\Users\StevenT\Desktop\mnist資料集\t10k-labels-idx1-ubyte\test_lable.txt');
lable_train = textread('C:\Users\StevenT\Desktop\mnist資料集\train-labels-idx1-ubyte\train_lable.txt');
% lable_train = lable_train(1:100);
% lable_test = lable_test(1:100);

model = TreeBagger(500,train_image,lable_train); %使用TreeBagger來對訓練樣本進行訓練,獲得一個model

result = predict(model,test_image);  %之後使用model來對測試樣本進行預測,將結果存在result內

result = cell2mat(result);   %因為result是cell類的,使用cell2mat轉換成字串

sc=double(result) - lable_test; 
count=sum(sc(:)==48)/100.0;     %sc用來儲存相減的結果,當其等於0(ASCII裡是48)的時候就是識別正確的結果,最終得出識別率

這裡請大家把地址改成自己的地址。

這是執行後的工作區
工作區

count是識別的準確率,已經達到96.82%了
我的在TreeBagger裡用了50棵決策樹,同時我也嘗試過500棵,講真跑的很久,但是準確率卻只提高了1%,所以我認識到這個決策樹的個數是沒很大影響的…(我的電腦跑的心好累)

這個例子中最重要的莫過於隨機森林TreeBagger這個函式所以我在這裡發一份部落格,我覺得挺好懂的一份(好吧其實是因為有圖)
http://www.36dsj.com/archives/21036

關於演算法的問題吧,我覺得如果不是想搞演算法開發的,還是會用就好:)

以上就是我做的一個小小的手寫識別,後面我會把我的用GUI把自己手寫的數字識別出來的小補充發出來。同時呢,正在學習SVM ,過段時間學的好的話,我會這個也發出來~~

大家共勉:)