pytorch自定義二值化網路層方式
阿新 • • 發佈:2020-01-09
任務要求:
自定義一個層主要是定義該層的實現函式,只需要過載Function的forward和backward函式即可,如下:
import torch from torch.autograd import Function from torch.autograd import Variable
定義二值化函式
class BinarizedF(Function): def forward(self,input): self.save_for_backward(input) a = torch.ones_like(input) b = -torch.ones_like(input) output = torch.where(input>=0,a,b) return output def backward(self,output_grad): input,= self.saved_tensors input_abs = torch.abs(input) ones = torch.ones_like(input) zeros = torch.zeros_like(input) input_grad = torch.where(input_abs<=1,ones,zeros) return input_grad
定義一個module
class BinarizedModule(nn.Module): def __init__(self): super(BinarizedModule,self).__init__() self.BF = BinarizedF() def forward(self,input): print(input.shape) output =self.BF(input) return output
進行測試
a = Variable(torch.randn(4,480,640),requires_grad=True) output = BinarizedModule()(a) output.backward(torch.ones(a.size())) print(a) print(a.grad)
其中,二值化函式部分也可以按照方式寫,但是速度慢了0.05s
class BinarizedF(Function): def forward(self,input): self.save_for_backward(input) output = torch.ones_like(input) output[input<0] = -1 return output def backward(self,= self.saved_tensors input_grad = output_grad.clone() input_abs = torch.abs(input) input_grad[input_abs>1] = 0 return input_grad
以上這篇pytorch自定義二值化網路層方式就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。