1. 程式人生 > 程式設計 >淺談Tensorflow載入Vgg預訓練模型的幾個注意事項

淺談Tensorflow載入Vgg預訓練模型的幾個注意事項

寫這個部落格的關鍵Bug: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16,bfloat16,float32,float64。本部落格將圍繞 載入圖片 和 儲存圖片到本地 來詳細解釋和解決上述的Bug及其引出來的一系列Bug。

載入圖片

首先,造成上述Bug的程式碼如下所示

image_path = "data/test.jpg" # 本地的測試圖片
 
image_raw = tf.gfile.GFile(image_path,'rb').read()
# 一定要tf.float(),否則會報錯
image_decoded = tf.image.decode_jpeg(image_raw)
 
# 擴充套件圖片的維度,從三維變成四維,符合Vgg19的輸入介面
image_expand_dim = tf.expand_dims(image_decoded,0)
 
# 定義Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim,'vgg19')
print(net)

上述程式碼是載入Vgg19預訓練模型,並傳入圖片得到所有層的特徵圖,具體的程式碼實現和原理講解可參考我的另一篇部落格:Tensorflow載入Vgg預訓練模型。那麼,為什麼程式碼會出現: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16,float64,這個Bug呢?

這句英文翻譯過來是指:傳遞的值型別是uint8,但是接受的引數型別必須是float的那幾種。故原因就是傳入值的資料型別錯了,那麼如何解決這個Bug呢,很簡單

image_path = "data/test.jpg" # 本地的測試圖片
 
image_raw = tf.gfile.GFile(image_path,'rb').read()
# 一定要tf.float(),否則會報錯
image_decoded = tf.to_float(tf.image.decode_jpeg(image_raw))
 
# 擴充套件圖片的維度,從三維變成四維,符合Vgg19的輸入介面
image_expand_dim = tf.expand_dims(image_decoded,'vgg19')
print(net)

這兩個程式碼塊唯一的變動就是:image_decoded結果在輸出前加了一個tf.float(),將其轉換為float型別。

在tensorflow API中,tf.image.decode_jpeg()預設讀取的圖片資料格式為unit8,而不是float。uint8資料的範圍在(0,255)中,正好符合圖片的畫素範圍(0,255)。但是,儲存在本地的Vgg19預訓練模型的資料介面為float,所以才造成了本文開頭的Bug。

這裡還要提一點,若是使用PIL的方法來載入圖片,則不會出現上述的Bug,因為通過PIL得到的圖片格式是float,而不是uint8,故不需要轉換。

很多同學可能會疑惑,若是強行改變了原圖片的資料格式,從uint8型別轉變成float,會不會導致資料改變或者出錯?故我做了下面這個實驗:

image_path = "data/3.jpg"
image_raw = tf.gfile.GFile(image_path,'rb').read()
image_unit8 = tf.image.decode_jpeg(image_raw)
image_float = tf.to_float(image_unit8)
 
with tf.Session() as sess:
 image_unit8_,image_float_ = sess.run([image_unit8,image_float])
 
print("image_unit8_",image_unit8_)
print("image_float_ ",image_float_ )

程式碼結果如下:

 image_unit8_
 [180,192,204],[183,195,207],[186,198,210],...,[191,205,218],[190,204,217]],image_float_ 
 [180.,192.,204.],[183.,195.,207.],[186.,198.,210.],[191.,205.,218.],[190.,204.,217.]],

可以看到,資料根本沒有變化,只是後面多加了個小數點,變得只有型別,而沒有強制改變值,故同學們不需要過度擔心。

儲存圖片到本地

在載入圖片的時候,為了使用儲存在本地的預訓練Vgg19模型,我們需要將讀取的圖片由uint8格式轉換成float格式。那若是我們想將已經轉換為float格式的圖片再儲存到本地,該怎麼做呢?

首先,我們根據上述的文字的意思讀取圖片,並且將其轉換為float格式,在將讀取的圖片再次儲存到本地之前,我們首先視覺化一下轉換格式後的圖片,程式碼如下:

import tensorflow as tf
from matplotlib import pyplot as plt
image_path = "data/boat.jpg"
 
image_raw = tf.gfile.GFile(image_path,'rb').read()
image_decoded = tf.image.decode_jpeg(image_raw)
image_decoded = tf.to_float(image_decoded)
 
with tf.Session() as sess:
 image_decoded_ = sess.run(image_decoded)
 plt.imshow(image_decoded_)
 plt.show()

生成的圖片如下圖所示:

淺談Tensorflow載入Vgg預訓練模型的幾個注意事項

左邊是原圖,右邊是轉換為float格式的圖片,可見將圖片轉換為float格式,雖然數值沒有造成太大影響,但是若想將圖片儲存到本地就會出現問題。

說了這麼多,只為了說一點,在儲存圖片到本地之前,需要將其格式從float轉回uint8,否則會造成一系列錯誤:圖片顯示異常,API報錯等。正確的儲存程式碼如下:

save_path = "data/boat_copy.jpg"
image_uint = tf.cast(image_decoded,tf.uint8)
with tf.Session() as sess:
 with open(save_path,'wb') as img:
 image_saved = sess.run(tf.image.encode_jpeg(image_uint))
 img.write(image_saved)

其中只有一句話最關鍵,即 tf.cast(image_decoded,tf.uint8)。

以上這篇淺談Tensorflow載入Vgg預訓練模型的幾個注意事項就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。