1. 程式人生 > 程式設計 >pytorch自定義二值化網路層方式

pytorch自定義二值化網路層方式

任務要求:

自定義一個層主要是定義該層的實現函式,只需要過載Function的forward和backward函式即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定義二值化函式

class BinarizedF(Function):
  def forward(self,input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self,output_grad):
    input,= self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones,zeros)
    return input_grad

定義一個module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule,self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

進行測試

a = Variable(torch.randn(4,480,640),requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中,二值化函式部分也可以按照方式寫,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self,input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self,= self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上這篇pytorch自定義二值化網路層方式就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。