pytorch藉助tensorboard實現模型視覺化
阿新 • • 發佈:2018-11-19
補充 : 剛發現貌似sqrt操作是不支援的
python庫:
pytorch(>=0.3) , onnx, tensorboardX
原理:
Open Neural Network Exchange (ONNX)是開放生態系統的第一步,它使人工智慧開發人員可以在專案的發展過程中選擇合適的工具;ONNX為AI models提供了一種開源格式。它定義了一個可以擴充套件的計算圖模型,同時也定義了內建操作符和標準資料型別。最初我們關注的是推理(評估)所需的能力。
Caffe2, PyTorch, Microsoft Cognitive Toolkit, Apache MXNet 和其他工具都在對ONNX進行支援。在不同的框架之間實現互操作性,並簡化從研究到產品化的過程,將提高人工智慧社群的創新速度。
簡單來說就是藉助onnx將pytorch的模型存為model.proto的檔案,然後藉助於tensorboardX這個工具將model.proto轉換為tensorboar的graph.
程式碼:
#對於pytorch0.3以上的版本 import tensorboardX import torch from torchvision.models import resnet34 import torch.onnx x=torch.autograd.Variable(torch.rand(1,3,224,224)) #隨便定義一個輸入 model=resnet34() proto=torch.onnx.export(model,x,"resnet34.proto",verbose=True) #將model的結構和引數全部儲存為 resnet32.proto writer=tensorboardX.SummaryWriter("./logs/") #定義一個tensorboardX的寫物件 writer.add_graph_onnx("./resnet34.proto") #將proto格式的檔案轉換為tensorboard中的graph
對於pytorch 0.2來說可以直接來畫:
import tensorboardX import torch from torchvision.models import resnet34 import torch.onnx x=torch.autograd.Variable(torch.rand(1,3,224,224)) #隨便定義一個輸入 model=resnet34() writer=tensorboardX.SummaryWriter("./logs/") #定義一個tensorboardX的寫物件 writer.add_graph(model,x,verbose=True) #將proto格式的檔案轉換為tensorboard中的graph
效果如下 ,確實有點醜,不如tensorflow那樣五顏六色,也沒有更加詳細的操作:
拉近的圖片:
補充,剛才有人說好像max_pool2d是不支援的,我自己的測試時可以的,建議檢查一下tensorboardX和ONNX的版本,程式碼如下: 我的tensorboardX版本是1.4的,onnx版本是1.3.0
import torch
import torch.nn.functional as F
import torch.onnx
import tensorboardX
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000):
super(ResNet, self).__init__()
def forward(self, x):
#這兒就是我加的操作
x=F.max_pool2d(x,kernel_size=7)
return x
def resnet50():
"""Constructs a ResNet-50 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 6, 3])
return model
if __name__=="__main__":
x=torch.autograd.Variable(torch.rand(1,3,224,224)) #隨便定義一個輸入
model=resnet50()
proto=torch.onnx.export(model,x,"resnet50.proto",verbose=True) #將model的結構和引數全部儲存為 resnet32.proto
writer=tensorboardX.SummaryWriter("./logs/") #定義一個tensorboardX的寫物件
writer.add_graph_onnx("./resnet50.proto")