tensorflow 實現自定義梯度反向傳播程式碼
阿新 • • 發佈:2020-02-10
以sign函式為例:
sign函式可以對數值進行二值化,但在梯度反向傳播是不好處理,一般採用一個近似函式的梯度作為代替,如上圖的Htanh。在[-1,1]直接梯度為1,其他為0。
#使用修飾器,建立梯度反向傳播函式。其中op.input包含輸入值、輸出值,grad包含上層傳來的梯度 @tf.RegisterGradient("QuantizeGrad") def sign_grad(op,grad): input = op.inputs[0] cond = (input>=-1)&(input<=1) zeros = tf.zeros_like(grad) return tf.where(cond,grad,zeros) #使用with上下文管理器覆蓋原始的sign梯度函式 def binary(input): x = input with tf.get_default_graph().gradient_override_map({"Sign":'QuantizeGrad'}): x = tf.sign(x) return x #使用 x = binary(x)
以上這篇tensorflow 實現自定義梯度反向傳播程式碼就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。