淺談Tensorflow載入Vgg預訓練模型的幾個注意事項
寫這個部落格的關鍵Bug: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16,bfloat16,float32,float64。本部落格將圍繞 載入圖片 和 儲存圖片到本地 來詳細解釋和解決上述的Bug及其引出來的一系列Bug。
載入圖片
首先,造成上述Bug的程式碼如下所示
image_path = "data/test.jpg" # 本地的測試圖片 image_raw = tf.gfile.GFile(image_path,'rb').read() # 一定要tf.float(),否則會報錯 image_decoded = tf.image.decode_jpeg(image_raw) # 擴充套件圖片的維度,從三維變成四維,符合Vgg19的輸入介面 image_expand_dim = tf.expand_dims(image_decoded,0) # 定義Vgg19模型 vgg19 = VGG19(data_path) net = vgg19.feed_forward(image_expand_dim,'vgg19') print(net)
上述程式碼是載入Vgg19預訓練模型,並傳入圖片得到所有層的特徵圖,具體的程式碼實現和原理講解可參考我的另一篇部落格:Tensorflow載入Vgg預訓練模型。那麼,為什麼程式碼會出現: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16,float64,這個Bug呢?
這句英文翻譯過來是指:傳遞的值型別是uint8,但是接受的引數型別必須是float的那幾種。故原因就是傳入值的資料型別錯了,那麼如何解決這個Bug呢,很簡單
image_path = "data/test.jpg" # 本地的測試圖片 image_raw = tf.gfile.GFile(image_path,'rb').read() # 一定要tf.float(),否則會報錯 image_decoded = tf.to_float(tf.image.decode_jpeg(image_raw)) # 擴充套件圖片的維度,從三維變成四維,符合Vgg19的輸入介面 image_expand_dim = tf.expand_dims(image_decoded,'vgg19') print(net)
這兩個程式碼塊唯一的變動就是:image_decoded結果在輸出前加了一個tf.float(),將其轉換為float型別。
在tensorflow API中,tf.image.decode_jpeg()預設讀取的圖片資料格式為unit8,而不是float。uint8資料的範圍在(0,255)中,正好符合圖片的畫素範圍(0,255)。但是,儲存在本地的Vgg19預訓練模型的資料介面為float,所以才造成了本文開頭的Bug。
這裡還要提一點,若是使用PIL的方法來載入圖片,則不會出現上述的Bug,因為通過PIL得到的圖片格式是float,而不是uint8,故不需要轉換。
很多同學可能會疑惑,若是強行改變了原圖片的資料格式,從uint8型別轉變成float,會不會導致資料改變或者出錯?故我做了下面這個實驗:
image_path = "data/3.jpg" image_raw = tf.gfile.GFile(image_path,'rb').read() image_unit8 = tf.image.decode_jpeg(image_raw) image_float = tf.to_float(image_unit8) with tf.Session() as sess: image_unit8_,image_float_ = sess.run([image_unit8,image_float]) print("image_unit8_",image_unit8_) print("image_float_ ",image_float_ )
程式碼結果如下:
image_unit8_ [180,192,204],[183,195,207],[186,198,210],...,[191,205,218],[190,204,217]],image_float_ [180.,192.,204.],[183.,195.,207.],[186.,198.,210.],[191.,205.,218.],[190.,204.,217.]],
可以看到,資料根本沒有變化,只是後面多加了個小數點,變得只有型別,而沒有強制改變值,故同學們不需要過度擔心。
儲存圖片到本地
在載入圖片的時候,為了使用儲存在本地的預訓練Vgg19模型,我們需要將讀取的圖片由uint8格式轉換成float格式。那若是我們想將已經轉換為float格式的圖片再儲存到本地,該怎麼做呢?
首先,我們根據上述的文字的意思讀取圖片,並且將其轉換為float格式,在將讀取的圖片再次儲存到本地之前,我們首先視覺化一下轉換格式後的圖片,程式碼如下:
import tensorflow as tf from matplotlib import pyplot as plt image_path = "data/boat.jpg" image_raw = tf.gfile.GFile(image_path,'rb').read() image_decoded = tf.image.decode_jpeg(image_raw) image_decoded = tf.to_float(image_decoded) with tf.Session() as sess: image_decoded_ = sess.run(image_decoded) plt.imshow(image_decoded_) plt.show()
生成的圖片如下圖所示:
左邊是原圖,右邊是轉換為float格式的圖片,可見將圖片轉換為float格式,雖然數值沒有造成太大影響,但是若想將圖片儲存到本地就會出現問題。
說了這麼多,只為了說一點,在儲存圖片到本地之前,需要將其格式從float轉回uint8,否則會造成一系列錯誤:圖片顯示異常,API報錯等。正確的儲存程式碼如下:
save_path = "data/boat_copy.jpg" image_uint = tf.cast(image_decoded,tf.uint8) with tf.Session() as sess: with open(save_path,'wb') as img: image_saved = sess.run(tf.image.encode_jpeg(image_uint)) img.write(image_saved)
其中只有一句話最關鍵,即 tf.cast(image_decoded,tf.uint8)。
以上這篇淺談Tensorflow載入Vgg預訓練模型的幾個注意事項就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。