1. 程式人生 > 其它 >pytorch載入模型時出現.....xxx.pth is a zip archive (did you mean to use torch.jit.load()?)

pytorch載入模型時出現.....xxx.pth is a zip archive (did you mean to use torch.jit.load()?)

技術標籤:PythonComputer Visionpytorchpython

Bug原因: 訓練和測試的torch版本不一致。訓練的時候是1.x,測試的時候是1.m。
解決辦法: 先在1.x版本下載入模型,然後在儲存的時候設定use_new_zipfile_serialization=False 就行了。

#torch_version==1.x
import torch
from models import net
checkpoint = 'xxx.pth'

model = net()
model.load_state_dict(torch.load(checkpoint))
model.
eval() torch.save(model.state_dict(), model_path, use_new_zipfile_serialization=False)