PyTorch模型中間層特徵的提取
阿新 • 發佈:2019-01-03
定義一個特徵提取的類:
# Intermediate-layer feature extraction
class FeatureExtractor(nn.Module):
    """Run ``submodule``'s direct children in order and collect selected outputs.

    Only the top-level children registered on ``submodule`` are executed, one
    after another in registration order; ``submodule.forward`` itself is never
    called. The result therefore matches the model's real forward pass only when
    that pass is a plain chain of its children with a flatten before the child
    named ``"fc"`` (true for torchvision-style ResNets; verify for other models).

    Args:
        submodule: model whose direct children are executed sequentially.
        extracted_layers: names of the children whose outputs are collected.
        verbose: if True, print each child's name as it is executed.
    """

    def __init__(self, submodule, extracted_layers, verbose=False):
        super().__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers
        self.verbose = verbose

    def forward(self, x):
        """Feed ``x`` through every child and return the requested outputs.

        Returns:
            list of tensors, one per child whose name is in
            ``extracted_layers``, in the order the children are executed.
        """
        outputs = []
        for name, module in self.submodule.named_children():
            # Use ==, not `is`: identity comparison of strings is unreliable
            # and would silently skip the flatten for non-interned names.
            if name == "fc":
                # Keep the batch dimension and flatten everything else for the
                # linear head; reshape also copes with non-contiguous tensors.
                x = x.reshape(x.size(0), -1)
            x = module(x)
            if self.verbose:
                print(name)
            if name in self.extracted_layers:
                outputs.append(x)
        return outputs
# Input data: take the first batch (batch_size=1) from the test set.
# Recent PyTorch DataLoader iterators no longer provide a `.next()` alias, so
# use the builtin next(). torch.autograd.Variable and volatile=True have been
# no-ops since PyTorch 0.4 (volatile only triggers a warning); to get the old
# "no gradient tracking" behaviour, run the forward pass under torch.no_grad().
test_loader = DataLoader(test_dataset, batch_size=1)
img, label = next(iter(test_loader))
# Feature output: rebuild ResNet-18, load the trained weights, and collect the
# outputs of its "conv1", "layer1" and "avgpool" children for `img`.
myresnet = resnet18(pretrained=False)
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# machine, and keeps the features on the CPU so they can be turned into numpy.
myresnet.load_state_dict(torch.load('cafir_resnet18_1.pkl', map_location="cpu"))
# eval() makes BatchNorm use its stored running statistics. In train mode
# BatchNorm would use (and overwrite running stats with) per-batch statistics,
# and with a single sample and a 1x1 feature map it raises outright -- e.g.
# ResNet-18's layer4 on 32x32 inputs (presumably CIFAR here; confirm).
myresnet.eval()
exact_list = ["conv1", "layer1", "avgpool"]
myexactor = FeatureExtractor(myresnet, exact_list)
# Inference only: no autograd graph is needed.
with torch.no_grad():
    x = myexactor(img)
# Visualize extracted features: show up to 64 channels of the first extracted
# output (presumably "conv1") for the first sample in the batch, on an 8x8 grid.
import matplotlib.pyplot as plt

# x[0] is indexed as (batch, channel, height, width) below; take sample 0.
# detach().cpu() works whether or not the tensor tracks gradients or lives on
# a GPU, and the conversion is done once instead of on every loop iteration.
feature_maps = x[0].detach().cpu().numpy()[0]
# Never index past the available channels (the grid holds at most 64).
n_maps = min(64, feature_maps.shape[0])
for i in range(n_maps):
    ax = plt.subplot(8, 8, i + 1)
    # Each subplot is one channel (feature map) of a single sample.
    ax.set_title('Channel #{}'.format(i))
    ax.axis('off')
    ax.imshow(feature_maps[i, :, :], cmap='jet')
plt.show()