PyTorch學習(三)——用50行程式碼搭建ResNet
阿新 • 發佈：2019-01-05
# ------------------------------ Build ResNet in ~50 lines of code ------------------------------
from torch import nn
import torch as t
from torch.nn import functional as F


class ResidualBlock(nn.Module):
    """Basic residual block (sub-module): two 3x3 conv + BatchNorm layers plus a shortcut.

    Args:
        inchannel: Number of input channels.
        outchannel: Number of output channels of both 3x3 convolutions.
        stride: Stride of the first 3x3 convolution (2 halves the spatial size).
        shortcut: Optional module applied to the input before the residual
            addition. When ``None`` the input is added unchanged, so it must
            already have the output's shape (same channels, stride 1).
    """

    def __init__(self, inchannel: int, outchannel: int, stride: int = 1, shortcut=None):
        super().__init__()
        # Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.
        # Each conv is directly followed by BatchNorm, so the conv bias is disabled.
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 3, stride, 1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, 3, 1, 1, bias=False),
            nn.BatchNorm2d(outchannel),
        )
        self.right = shortcut

    def forward(self, x: t.Tensor) -> t.Tensor:
        """Return ``relu(left(x) + shortcut(x))``; the shortcut is the identity if none was given."""
        out = self.left(x)
        residual = x if self.right is None else self.right(x)
        out += residual
        return F.relu(out)


class ResNet(nn.Module):
    """ResNet34 main module.

    ResNet34 stacks four stages ("layers") of 3, 4, 6 and 3 residual blocks.
    ``ResidualBlock`` implements a single block; ``_make_layer`` builds one stage.

    Args:
        num_classes: Number of output logits of the final fully-connected layer.

    ``forward`` takes an ``(N, 3, H, W)`` batch and returns ``(N, num_classes)``
    logits. Global pooling is adaptive, so the input resolution is not fixed
    to 224x224 (which yields a 7x7 map before pooling).
    """

    def __init__(self, num_classes: int = 1000):
        super().__init__()
        # Stem: 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool (4x spatial downsampling).
        self.pre = nn.Sequential(
            nn.Conv2d(3, 64, 7, 2, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1),
        )
        # Repeated stages with 3, 4, 6 and 3 residual blocks respectively;
        # stages 2-4 halve the spatial size and double the channel count.
        self.layer1 = self._make_layer(64, 64, 3)
        self.layer2 = self._make_layer(64, 128, 4, stride=2)
        self.layer3 = self._make_layer(128, 256, 6, stride=2)
        self.layer4 = self._make_layer(256, 512, 3, stride=2)
        # Fully-connected classifier head.
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, inchannel: int, outchannel: int, block_num: int, stride: int = 1) -> nn.Sequential:
        """Build one stage of ``block_num`` residual blocks.

        Only the first block changes channels/stride; it receives a 1x1 conv +
        BN projection shortcut. The remaining blocks map ``outchannel`` to
        ``outchannel`` with identity shortcuts.

        NOTE(review): the projection is created even when ``inchannel ==
        outchannel`` and ``stride == 1`` (layer1); torchvision's resnet34 uses
        an identity shortcut there. Kept as-is so parameter names and counts
        do not change.
        """
        shortcut = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 1, stride, bias=False),
            nn.BatchNorm2d(outchannel),
        )
        layers = [ResidualBlock(inchannel, outchannel, stride, shortcut)]
        layers += [ResidualBlock(outchannel, outchannel) for _ in range(1, block_num)]
        return nn.Sequential(*layers)

    def forward(self, x: t.Tensor) -> t.Tensor:
        """Compute class logits of shape ``(N, num_classes)`` for an ``(N, 3, H, W)`` batch."""
        x = self.pre(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # Global average pooling down to 1x1. A fixed 7x7 kernel would only fit
        # the 7x7 map produced by 224x224 inputs; adaptive pooling fits any size.
        x = F.adaptive_avg_pool2d(x, 1)
        x = t.flatten(x, 1)
        return self.fc(x)
# Smoke test: push one random 224x224 RGB image through an untrained ResNet34.
model = ResNet()
# Inference mode, so this forward pass does not update BatchNorm running statistics.
model.eval()
# torch.autograd.Variable is deprecated; plain tensors work directly.
# Named `x` rather than `input` to avoid shadowing the builtin.
x = t.randn(1, 3, 224, 224)
# Forward-only check: no autograd graph is needed.
with t.no_grad():
    o = model(x)
print(o)  # logits of shape (1, 1000)
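大致框架算是理解了，但是細節部分，比如卷積層的輸入輸出的大小之類的，還需要仔細研究。卷積（或池化）輸出邊長的計算公式為 $\lfloor (W + 2P - K)/S \rfloor + 1$，其中 $W$ 為輸入邊長，$K$ 為核大小，$P$ 為填充，$S$ 為步長。例如輸入為 224×224 時，7×7、步長 2、填充 3 的卷積輸出 112×112；經 3×3 最大池化（步長 2、填充 1）後為 56×56；layer2～layer4 各自第一個 block 以步長 2 依次降為 28、14、7，因此 layer4 輸出的特徵圖為 7×7。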
PyTorch學習系列(一)至(四)均摘自陳雲著《深度學習框架PyTorch入門與實踐》。