pytorch 使用載入訓練好的模型做inference
阿新 • • 發佈:2020-02-21
前提: 模型引數和結構是分別儲存的
1、 構建模型(# load model graph)
model = MODEL()
2、載入模型引數(# load model state_dict)
model.load_state_dict ( { k.replace('module.',''):v for k,v in torch.load(config.model_path,map_location=config.device).items() } ) model = self.model.to(config.device) * config.device 指定使用哪塊GPU或者CPU *k.replace('module.',''):v 防止torch.DataParallel訓練的模型出現載入錯誤
(解決RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cuda:1問題)
3、設定當前階段為inference(# predict)
model.eval()
以上這篇pytorch 使用載入訓練好的模型做inference就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。