1. 程式人生 > 實用技巧 > TensorFlow 2.0——手寫數字資料集預測

TensorFlow 2.0——手寫數字資料集(MNIST)預測

import tensorflow as tf
import numpy as np
import matplotlib.pylab as plt

# Plot configuration, applied in one call:
#   * use the SimHei font family so the Chinese legend text can be rendered;
#   * disable the Unicode minus sign, because once a Chinese font is the
#     default the "-" on the axes is otherwise not displayed.
plt.rcParams.update({
    "font.family": 'SimHei',
    'axes.unicode_minus': False,
})

# Load the MNIST handwritten-digit data: a (features, labels) pair for the
# training split and another for the test split.
mnist = tf.keras.datasets.mnist
(train_x, train_y), (test_x, test_y) = mnist.load_data()

# Lookup table for the digits 0-9: row i is the one-hot code of digit i,
# i.e. a 10x10 identity matrix.
y_hot = np.eye(10)

# One-hot encode the training labels. Indexing the table with the whole label
# array selects the matching row for every sample in a single vectorised step.
train_Y = y_hot[train_y]
print('train_Y:', train_Y, train_Y.shape)
# Training set: flatten every 28*28 image into one 784-element row.
# The cast to float64 gives the features the same dtype as the ones column
# below, which tf.concat requires of its inputs.
num_train = train_x.shape[0]
train_X1 = train_x.reshape(num_train, 784).astype(np.float64)
ones = np.ones((num_train, 1))
print('ones.shape:', ones.shape)
print('train_X1.shape:', train_X1.shape)

# Append the constant column of ones as the 785th feature so that a bias term
# can be folded into the weight matrix.
train_X = tf.concat([train_X1, ones], axis=1)
# Test set: flatten every image into a 784-element float64 row and append a
# constant column of ones, giving 785 features per sample to match w below.
num_test = test_x.shape[0]
test_X1 = test_x.reshape(num_test, 784).astype(np.float64)
ones = np.ones((num_test, 1))
test_X = tf.concat([test_X1, ones], axis=1)

# Turn the integer labels into column vectors so they line up with the
# (N, 1) prediction column produced by _accuracy().
train_y = train_y.reshape(-1, 1)
test_y = test_y.reshape(-1, 1)

# Accuracy history, one entry per reporting step (used for the final plot).
acc_train = []
acc_test = []

# Hyper-parameters.
num_iters = 1500     # number of gradient-descent iterations
learn_rate = 5e-12   # learning rate

# Trainable parameters: 785 input features x 10 output columns, initialised
# with small random values.
w = tf.Variable(np.random.randn(785, 10) * 0.0001)
print('初試w:', w, w.shape)


def _accuracy(scores, labels):
    """Return the fraction of rows whose highest-scoring column equals the label.

    Args:
        scores: (N, 10) tensor of per-class scores (pre-sigmoid values).
        labels: (N, 1) integer array of the true digits.

    Returns:
        A scalar float32 tensor in [0, 1].
    """
    # The sigmoid is monotonic, so the arg-max of the raw scores is also the
    # arg-max of the probabilities. Rounding the probabilities first would
    # collapse every sample with no output >= 0.5 onto class 0.
    predicted = tf.reshape(tf.argmax(scores, axis=1), (-1, 1))
    hits = tf.equal(predicted, tf.cast(labels, predicted.dtype))
    return tf.reduce_mean(tf.cast(hits, tf.float32))


for i in range(num_iters):
    with tf.GradientTape() as tape:
        logits = tf.matmul(train_X, w)
        # Summed binary cross-entropy over all samples and classes, computed
        # from the logits so that saturated sigmoids cannot produce log(0).
        loss = tf.reduce_sum(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=train_Y, logits=logits)
        )
    dl_dw = tape.gradient(loss, w)

    report = i % 20 == 0
    if report:
        # Score the test set with the same (pre-update) weights that produced
        # `loss`; done outside the tape because no gradient is needed.
        test_logits = tf.matmul(test_X, w)

    w.assign_sub(learn_rate * dl_dw)

    if report:
        print('i:{}, loss:{}, w:{}'.format(i, loss, w))

        # Training-set accuracy.
        acc = _accuracy(logits, train_y)
        acc_train.append(float(acc))
        print('acc:', acc)

        # Test-set accuracy.
        acc2 = _accuracy(test_logits, test_y)
        acc_test.append(float(acc2))
        print('acc2:', acc2)
        print()

# Plot the recorded training and test accuracy curves.
plt.plot(acc_train, label='訓練集正確率')
plt.plot(acc_test, label='測試集正確率')
plt.legend()
plt.show()