1. 程式人生 > >Tensorflow深度學習之七:再談mnist手寫數字識別程式

Tensorflow深度學習之七:再談mnist手寫數字識別程式

之前學習的第一個深度學習的程式就是mnist手寫字型的識別,那個時候對於很多概念不是很理解,現在回過頭再看當時的程式碼,理解了很多,現將加了註釋的程式碼貼上,與大家分享。(本人還是在學習Tensorflow的初始階段,如果有什麼地方理解有誤,還請大家不吝指出。)

from tensorflow.examples.tutorials.mnist import input_data

# 下載mnist資料集至當前目錄下的MNIST_data資料夾,並讀取資料。
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# 輸出訓練集,測試集,驗證集圖片的shape和標籤的shape。
print(mnist.train.images.shape, mnist.train.labels.shape) print(mnist.test.images.shape, mnist.test.labels.shape) print(mnist.validation.images.shape, mnist.validation.labels.shape) import tensorflow as tf # 預設會話。 sess = tf.InteractiveSession() # 定義一個placeholder,用於儲存輸入的圖片的資訊。 # 由於圖片中的數值是0~1之間的浮點數,所以x的資料型別也應是tf.float32。
# 第二個引數表示x的維度,其中None表示不限制輸入的數量,之後的引數便是輸入的資料的維度, # 這裡的784表示輸入的是一個長度為784的一維向量。 x = tf.placeholder(tf.float32, [None, 784]) # 定義權重變數,該變數是一個784x10的矩陣,這裡將初始權重全部賦值為0。 W = tf.Variable(tf.zeros([784, 10])) # 定義偏置值變數,該變數是一個由10個元素組成的向量,同樣這裡的偏置值變數的初始值也全部被賦值為0。 b = tf.Variable(tf.zeros([10])) # 定義softmax層,輸入是一個10個元素組成的向量。
# tf.matmul是用於矩陣相乘的函式。 y = tf.nn.softmax(tf.matmul(x, W)+b) # 再次定義一個placeholder,用於儲存真實的圖片標籤。 y_ = tf.placeholder(tf.float32, [None, 10]) # 定義交叉熵,這是本程式需要使用的loss函式,我們的目的是使得這個loss函式儘可能的小。 cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) # 我們使用梯度下降的方法來優化引數,這裡把學習率設定為0.5,我們需要優化的函式是cross_entropy,即我們的loss函式。 train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 初始化所有的全域性變數。 tf.global_variables_initializer().run() # 從這裡開始訓練我們建立好的模型。 # 從上面的公式中可以看出,我們建立的是一個全連線的模型,本質上是對矩陣乘法的優化。 # 我們訓練1000次。 for i in range(1000): # 使用mnist自帶的方法隨機產生100個數據。 batch_xs, batch_ys = mnist.train.next_batch(100) # 將這100個數據分別feed給上面我們定義的兩個placeholder,由於訓練模型。 train_step.run({x:batch_xs,y_:batch_ys}) # 建立評估模型正確預測的Graph。 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) # 定義準確率的計算公式。 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # 將測試集資料傳遞給兩個placeholder,然後執行上述定義的準確率的公式,最後輸出準確率的值。 print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))

執行結果如下:(因為使用了梯度下降的方法,因此每一次執行的結果或有不同,一般結果在0.92左右)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
(55000, 784) (55000, 10)
(10000, 784) (10000, 10)
(5000, 784) (5000, 10)
0.9184

相關推薦

no