1. 程式人生 > 程式設計 >tensorflow 實現自定義梯度反向傳播程式碼

tensorflow 實現自定義梯度反向傳播程式碼

以sign函式為例:

tensorflow 實現自定義梯度反向傳播程式碼

sign函式可以對數值進行二值化,但在梯度反向傳播是不好處理,一般採用一個近似函式的梯度作為代替,如上圖的Htanh。在[-1,1]直接梯度為1,其他為0。

#使用修飾器,建立梯度反向傳播函式。其中op.input包含輸入值、輸出值,grad包含上層傳來的梯度
@tf.RegisterGradient("QuantizeGrad")
def sign_grad(op,grad):
 input = op.inputs[0]
 cond = (input>=-1)&(input<=1)
 zeros = tf.zeros_like(grad)
 return tf.where(cond,grad,zeros)
 
#使用with上下文管理器覆蓋原始的sign梯度函式
def binary(input):
 x = input
 with tf.get_default_graph().gradient_override_map({"Sign":'QuantizeGrad'}):
  x = tf.sign(x)
 return x
 
#使用
x = binary(x)

以上這篇tensorflow 實現自定義梯度反向傳播程式碼就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。