1. 程式人生 > 程式設計 >pytorch 使用載入訓練好的模型做inference

pytorch 使用載入訓練好的模型做inference

前提: 模型引數和結構是分別儲存的

1、 構建模型(# load model graph)

model = MODEL()

2、載入模型引數(# load model state_dict)

 model.load_state_dict
 (
 {

 k.replace('module.',''):v for k,v in

 torch.load(config.model_path,map_location=config.device).items()

 }
 )
 
model = self.model.to(config.device)

* config.device 指定使用哪塊GPU或者CPU  

*k.replace('module.',''):v 防止torch.DataParallel訓練的模型出現載入錯誤

(解決RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cuda:1問題)

3、設定當前階段為inference(# predict)

model.eval()

以上這篇pytorch 使用載入訓練好的模型做inference就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。