1. 程式人生 > 程式設計 >tensorflow實現訓練變數checkpoint的儲存與讀取

tensorflow實現訓練變數checkpoint的儲存與讀取

1.儲存變數

先建立(在tf.Session()之前)saver

saver = tf.train.Saver(tf.global_variables(),max_to_keep=1)  #max_to_keep這個保證只儲存最後一次training的訓練資料

然後在訓練的迴圈裡面

checkpoint_path = os.path.join(Path,'model.ckpt') saver.save(session,checkpoint_path,global_step=step) #這裡的step是迴圈訓練的次數,也就是第幾次迭代

以下儲存的變數檔案

tensorflow實現訓練變數checkpoint的儲存與讀取

2.變數讀取

1.若要直接恢復所有變數可以

saver = tf.train.Saver(tf.global_variables())
moudke_file=tf.train.latest_checkpoint('PATH')
saver.restore(sess,moudke_file)

PATH是存放儲存變數的路徑,會自動找到最近儲存的變數檔案

2 若想讀取其中一部分變數值

def read_checkpoint():
  w = []
  checkpoint_path = '/home/ximao/models/resnet3/variable_logs/model.ckpt-17000'
  reader = tf.train.NewCheckpointReader(checkpoint_path)
  var = reader.get_variable_to_shape_map()
  for key in var:
    if 'weights' in key and 'conv' in key and 'Mo' not in key:
      print('tensorname:',key)
  #   # print(reader.get_tensor(key))

3. 若想恢復其中一部分變數值到新網路

(1)首先你要先獲取你想要賦值新網路變數的變數名,這裡變數名不是一個字串,而是<name,shape,dtype>這樣的一個結構,

然後把你要賦值的元素轉為張量,最後把值賦給你得到變數名 如下:

var=[v for v in weight_pruned if v.op.name=='WRN/conv1/weights']
conv1_temp=tf.convert_to_tensor(conv1,dtype=tf.float32)
sess.run(tf.assign(var[0],conv1_temp))

weight_pruned 存放的是你新網路中所有的變數

以上這篇tensorflow實現訓練變數checkpoint的儲存與讀取就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。