1. 程式人生 > 程式設計 >淺談keras儲存模型中的save()和save_weights()區別

淺談keras儲存模型中的save()和save_weights()區別

今天做了一個關於keras儲存模型的實驗,希望有助於大家瞭解keras儲存模型的區別。

我們知道keras的模型一般儲存為字尾名為h5的檔案,比如final_model.h5。同樣是h5檔案用save()和save_weight()儲存效果是不一樣的。

我們用宇宙最通用的資料集MNIST來做這個實驗,首先設計一個兩層全連線網路:

inputs = Input(shape=(784,))
x = Dense(64,activation='relu')(inputs)
x = Dense(64,activation='relu')(x)
y = Dense(10,activation='softmax')(x)
 
model = Model(inputs=inputs,outputs=y)

然後,匯入MNIST資料訓練,分別用兩種方式儲存模型,在這裡我還把未訓練的模型也儲存下來,如下:

from keras.models import Model
from keras.layers import Input,Dense
from keras.datasets import mnist
from keras.utils import np_utils
 
(x_train,y_train),(x_test,y_test) = mnist.load_data()
x_train=x_train.reshape(x_train.shape[0],-1)/255.0
x_test=x_test.reshape(x_test.shape[0],-1)/255.0
y_train=np_utils.to_categorical(y_train,num_classes=10)
y_test=np_utils.to_categorical(y_test,num_classes=10)
 
inputs = Input(shape=(784,outputs=y)
 
model.save('m1.h5')
model.summary()
model.compile(loss='categorical_crossentropy',optimizer='sgd',metrics=['accuracy'])
model.fit(x_train,y_train,batch_size=32,epochs=10)
#loss,accuracy=model.evaluate(x_test,y_test)
 
model.save('m2.h5')
model.save_weights('m3.h5')

如上可見,我一共儲存了m1.h5,m2.h5,m3.h5 這三個h5檔案。那麼,我們來看看這三個玩意兒有什麼區別。首先,看看大小:

淺談keras儲存模型中的save()和save_weights()區別

m2表示save()儲存的模型結果,它既保持了模型的圖結構,又儲存了模型的引數。所以它的size最大的。

m1表示save()儲存的訓練前的模型結果,它儲存了模型的圖結構,但應該沒有儲存模型的初始化引數,所以它的size要比m2小很多。

m3表示save_weights()儲存的模型結果,它只儲存了模型的引數,但並沒有儲存模型的圖結構。所以它的size也要比m2小很多。

通過視覺化工具,我們發現:(開啟m1和m2均可以顯示出以下結構)

淺談keras儲存模型中的save()和save_weights()區別

而開啟m3的時候,視覺化工具報錯了。由此可以論證, save_weights()是不含有模型結構資訊的。

載入模型

兩種不同方法儲存的模型檔案也需要用不同的載入方法。

from keras.models import load_model
 
model = load_model('m1.h5')
#model = load_model('m2.h5')
#model = load_model('m3.h5')
model.summary()

只有載入m3.h5的時候,這段程式碼才會報錯。其他輸出如下:

淺談keras儲存模型中的save()和save_weights()區別

可見,由save()儲存下來的h5檔案才可以直接通過load_model()開啟!

那麼,我們儲存下來的引數(m3.h5)該怎麼開啟呢?

這就稍微複雜一點了,因為m3不含有模型結構資訊,所以我們需要把模型結構再描述一遍才可以載入m3,如下:

from keras.models import Model
from keras.layers import Input,Dense
 
inputs = Input(shape=(784,outputs=y)
model.load_weights('m3.h5')

以上把m3換成m1和m2也是沒有問題的!可見,save()儲存的模型除了佔用記憶體大一點以外,其他的優點太明顯了。所以,在不怎麼缺硬碟空間的情況下,還是建議大家多用save()來存。

注意!如果要load_weights(),必須保證你描述的有引數計算結構與h5檔案中完全一致!什麼叫有引數計算結構呢?就是有引數坑,直接填進去就行了。我們把上面的非引數結構換了一下,發現h5檔案依然可以載入成功,比如將softmax換成relu,依然不影響載入。

對於keras的save()和save_weights(),完全沒問題了吧

以上這篇淺談keras儲存模型中的save()和save_weights()區別就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。