tensorflow實現訓練變數checkpoint的儲存與讀取
阿新 • • 發佈:2020-02-10
1.儲存變數
先建立(在tf.Session()之前)saver
saver = tf.train.Saver(tf.global_variables(),max_to_keep=1) #max_to_keep這個保證只儲存最後一次training的訓練資料
然後在訓練的迴圈裡面
checkpoint_path = os.path.join(Path,'model.ckpt') saver.save(session,checkpoint_path,global_step=step) #這裡的step是迴圈訓練的次數,也就是第幾次迭代
以下儲存的變數檔案
2.變數讀取
1.若要直接恢復所有變數可以
saver = tf.train.Saver(tf.global_variables()) moudke_file=tf.train.latest_checkpoint('PATH') saver.restore(sess,moudke_file)
PATH是存放儲存變數的路徑,會自動找到最近儲存的變數檔案
2 若想讀取其中一部分變數值
def read_checkpoint(): w = [] checkpoint_path = '/home/ximao/models/resnet3/variable_logs/model.ckpt-17000' reader = tf.train.NewCheckpointReader(checkpoint_path) var = reader.get_variable_to_shape_map() for key in var: if 'weights' in key and 'conv' in key and 'Mo' not in key: print('tensorname:',key) # # print(reader.get_tensor(key))
3. 若想恢復其中一部分變數值到新網路
(1)首先你要先獲取你想要賦值新網路變數的變數名,這裡變數名不是一個字串,而是<name,shape,dtype>這樣的一個結構,
然後把你要賦值的元素轉為張量,最後把值賦給你得到變數名 如下:
var=[v for v in weight_pruned if v.op.name=='WRN/conv1/weights'] conv1_temp=tf.convert_to_tensor(conv1,dtype=tf.float32) sess.run(tf.assign(var[0],conv1_temp))
weight_pruned 存放的是你新網路中所有的變數
以上這篇tensorflow實現訓練變數checkpoint的儲存與讀取就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。