1. 程式人生 > 程式設計 >Tensorflow實現在訓練好的模型上進行測試

Tensorflow實現在訓練好的模型上進行測試

Tensorflow可以使用訓練好的模型對新的資料進行測試,有兩種方法:第一種方法是呼叫模型和訓練在同一個py檔案中,中情況比較簡單;第二種是訓練過程和呼叫模型過程分別在兩個py檔案中。本文將講解第二種方法。

模型的儲存

tensorflow提供可儲存訓練模型的介面,使用起來也不是很難,直接上程式碼講解:

#網路結構
w1 = tf.Variable(tf.truncated_normal([in_units,h1_units],stddev=0.1))
b1 = tf.Variable(tf.zeros([h1_units]))
y = tf.nn.softmax(tf.matmul(w1,x) + b1)
tf.add_to_collection('network-output',y)

x = tf.placeholder(tf.float32,[None,in_units],name='x')
y_ = tf.placeholder(tf.float32,10],name='y_')
#損失函式與優化函式
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(rate).minimize(cross_entropy)

saver = tf.train.Saver()
with tf.Session() as sess: 
    sess.run(init) 
    saver.save(sess,"save/model.ckpt") 
    train_step.run({x: train_x,y_: train_y})

以上程式碼就完成了模型的儲存,值得注意的是下面這行程式碼

tf.add_to_collection('network-output',y)

這行程式碼儲存了神經網路的輸出,這個在後面使用匯入模型過程中起到關鍵作用。

模型的匯入

模型訓練並儲存後就可以匯入來評估模型在測試集上的表現,網上很多文章只用簡單的四則運算來做例子,讓人看的頭大。還是先上程式碼:

with tf.Session() as sess:
  saver = tf.train.import_meta_graph('./model.ckpt.meta')
  saver.restore(sess,'./model.ckpt')# .data檔案
  pred = tf.get_collection('network-output')[0]

  graph = tf.get_default_graph()
  x = graph.get_operation_by_name('x').outputs[0]
  y_ = graph.get_operation_by_name('y_').outputs[0]

  y = sess.run(pred,feed_dict={x: test_x,y_: test_y})

講解一下關鍵的程式碼,首先是pred = tf.get_collection('pred_network')[0],這行程式碼獲得訓練過程中網路輸出的“介面”,簡單理解就是,通過tf.get_collection() 這個方法獲取了整個網路結構。獲得網路結構後我們就需要餵它對應的資料y = sess.run(pred,y_: test_y}) 在訓練過程中我們的輸入是

x = tf.placeholder(tf.float32,name='y_')

因此匯入模型後所需的輸入也要與之對應可使用以下程式碼獲得:

  x = graph.get_operation_by_name('x').outputs[0]
  y_ = graph.get_operation_by_name('y_').outputs[0]

使用模型的最後一步就是輸入測試集,然後按照訓練好的網路進行評估

  sess.run(pred,y_: test_y})

理解下這行程式碼,sess.run() 的函式原型為

run(fetches,feed_dict=None,options=None,run_metadata=None)

Tensorflow對 feed_dict 執行fetches操作,因此在匯入模型後的運算就是,按照訓練的網路計算測試輸入的資料。

以上這篇Tensorflow實現在訓練好的模型上進行測試就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。