1. 程式人生 > 程式設計 >解決Pytorch 載入訓練好的模型 遇到的error問題

解決Pytorch 載入訓練好的模型 遇到的error問題

這是一個非常愚蠢的錯誤

debug的時候要好好看error資訊

提醒自己切記好好對待error!切記!切記!

-----------------------分割線----------------

pytorch 已經非常友好了 儲存模型和載入模型都只需要一條簡單的命令

#儲存整個網路和引數
torch.save(your_net,'save_name.pkl')
#載入儲存的模型
net = torch.load('save_name.pkl')

因為我比較懶我就想直接把整個網路都儲存下來,然後在test檔案中直接load一下不就好了?

就遭受了這樣的錯誤。看錯了error資訊,把‘Net'看成‘net'。報錯沒有屬性‘net'?這個不是我自己寫的變數名麼?

-----------------瞎搗鼓1h後(呵呵呵)----------------

回頭看error,沒有屬性‘Net',Net???

我當下明白過來,應該是test檔案中沒有把它import進來,test中就沒有任何關於Net的資訊。我直接把定義的Net複製進了test.py,就順利載入了訓練好的模型。

但是我也有一個疑問,我理解的把整個模型儲存難道不是把它的結構都儲存下來了麼?為什麼還要再把這個網路import一次?來自python、pytorch、面向物件程式設計三次元小白的疑惑,先存個疑,搞懂了再來回答。

接下來試試只儲存網路引數

#只儲存網路引數
torch.save(your_net.state_dict(),'save_name.pkl')
#載入儲存的模型
net.load_state_dict(torch.load('save_name.pkl'))

儲存網路引數

重新定義網路

報錯

想死。。。

仔細看了報錯資訊,以我小白的理解,我感覺儲存下來的可能只是單純的資料,而不是一個物件(沒有方法可以操作),或者該物件沒有.copy()方法,所以沒有辦法進行.copy(),那肯定是儲存哪裡出錯了。然後發現儲存部分程式碼寫錯了,改成

print一下 net.state_dict和net.state_dict(),前者輸出的是網路結構,後者才是網路的引數。

試著回答之前的問題,第二種儲存模型的方法只儲存了網路的引數(包括卷積層和全連線層每次的weight,bias),所以再載入模型的時候需要先定義網路無可厚非,就像訓練時候定義網路那樣定義就可以;而第一種儲存整個網路的方法,儲存了一個網路的例項(包括它的所有結構和引數),net是Net的一個例項,那為什麼還要有Class Net的定義呢,還是回答不了。。

那就繼續存疑,保持探究精神吧。。

以上這篇解決Pytorch 載入訓練好的模型 遇到的error問題就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。