
TensorFlow Deep Learning, Part 33: tf.scatter_update

1. tf.scatter_update

tf.scatter_update(
    ref,
    indices,
    updates,
    use_locking=True,
    name=None
)

   Applies sparse updates to a variable reference.

   This operation computes:

	# Scalar indices
	ref[indices, ...] = updates[...]

	# Vector indices (for each i)
	ref[indices[i], ...] = updates[i, ...]

	# High rank indices (for each i, ..., j)
	ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]
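To make the three cases concrete, here is a minimal NumPy sketch of the same semantics. It is not TensorFlow's implementation: `scatter_update_np` is a hypothetical helper, and it relies on NumPy integer-array indexing on the first axis, where `ref[indices]` selects a block of shape `indices.shape + ref.shape[1:]`.

```python
import numpy as np


def scatter_update_np(ref, indices, updates):
    """Hypothetical NumPy stand-in for tf.scatter_update (illustration only).

    Writes `updates` into the first dimension of `ref` in place and returns
    `ref`, mirroring ref[indices[i, ..., j], ...] = updates[i, ..., j, ...].
    """
    indices = np.asarray(indices)
    updates = np.asarray(updates, dtype=ref.dtype)
    # Integer-array indexing on axis 0 selects indices.shape + ref.shape[1:]
    ref[indices] = updates
    return ref


ref = np.zeros((4, 2), dtype=np.int32)

# Scalar index: updates has shape ref.shape[1:] == (2,)
scatter_update_np(ref, 0, [1, 1])

# Vector indices: updates has shape (2,) + (2,) == (2, 2)
scatter_update_np(ref, [1, 2], [[2, 2], [3, 3]])

# High-rank indices: indices (1, 1), so updates has shape (1, 1, 2)
scatter_update_np(ref, [[3]], [[[4, 4]]])

print(ref)
# [[1 1]
#  [2 2]
#  [3 3]
#  [4 4]]
```

Returning `ref` from the helper mimics the TensorFlow op returning the updated variable, so calls can be nested or chained.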

   This operation outputs ref after the update is done, which makes it easier to chain operations that need to use the updated value.

   If a value in ref is to be updated more than once because indices contains duplicate entries, the order in which those updates are applied is undefined.
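Why the result is undefined can be seen by applying the updates one at a time in two different orders. The sketch below is plain Python with a hypothetical helper `apply_in_order`; it only models the ordering problem and does not reproduce what TensorFlow actually does on any particular device.

```python
def apply_in_order(ref, indices, updates, order):
    """Hypothetical helper: apply ref[indices[k]] = updates[k] sequentially,
    visiting k in the given order, and return a new list."""
    out = list(ref)
    for k in order:
        out[indices[k]] = updates[k]
    return out


ref = [0, 0, 0]
indices = [1, 1]    # index 1 appears twice
updates = [10, 20]

forward = apply_in_order(ref, indices, updates, order=[0, 1])
backward = apply_in_order(ref, indices, updates, order=[1, 0])

print(forward)   # [0, 20, 0]
print(backward)  # [0, 10, 0]
```

Both orders are legitimate executions, so ref[1] could end up as either 10 or 20. If a deterministic result matters, deduplicate indices before calling the op.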

   Requires updates.shape = indices.shape + ref.shape[1:].
Source: official TensorFlow website.
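The shape requirement is simple tuple concatenation. This sketch uses a hypothetical helper name, `required_updates_shape`, and checks it against the two examples in section 3 plus a scalar and a higher-rank case.

```python
def required_updates_shape(indices_shape, ref_shape):
    """Hypothetical helper: the updates shape tf.scatter_update requires,
    i.e. indices.shape + ref.shape[1:]."""
    return tuple(indices_shape) + tuple(ref_shape)[1:]


# The two examples from section 3
print(required_updates_shape((1,), (2, 4)))       # (1, 4)
print(required_updates_shape((1,), (4, 4, 3)))    # (1, 4, 3)

# Scalar indices and higher-rank indices
print(required_updates_shape((), (4, 4, 3)))      # (4, 3)
print(required_updates_shape((2, 5), (10, 3)))    # (2, 5, 3)
```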

2. Parameters

ref: A Variable.
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into the first dimension of ref.
updates: A Tensor. Must have the same type as ref. A tensor of updated values to store in ref.
use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
name: A name for the operation (optional).
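The dtype constraints above can be expressed as a quick pre-flight check. This is a hypothetical helper, `scatter_update_arg_errors`, that uses NumPy arrays as stand-ins for tensors; it is not part of TensorFlow.

```python
import numpy as np


def scatter_update_arg_errors(ref, indices, updates):
    """Hypothetical helper: list violations of the documented dtype rules
    for tf.scatter_update, using NumPy arrays as stand-ins for tensors."""
    ref, indices, updates = np.asarray(ref), np.asarray(indices), np.asarray(updates)
    errors = []
    # indices: must be int32 or int64
    if indices.dtype not in (np.int32, np.int64):
        errors.append(f"indices must be int32 or int64, got {indices.dtype}")
    # updates: must have the same type as ref
    if updates.dtype != ref.dtype:
        errors.append(f"updates dtype {updates.dtype} does not match ref dtype {ref.dtype}")
    return errors


ref = np.zeros((2, 4), dtype=np.int32)
good = scatter_update_arg_errors(ref, np.array([0], dtype=np.int32),
                                 np.array([[1, 98, 20, 102]], dtype=np.int32))
bad = scatter_update_arg_errors(ref, np.array([0.0]),
                                np.zeros((1, 4), dtype=np.float32))

print(good)  # []
print(bad)   # two messages: float64 indices, float32 updates vs int32 ref
```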

Returns
   Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.
   (Note: the returned value is a Variable, not a Tensor as returned by tf.scatter_nd_update.)

3. Code
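   Given indices into the first dimension of the matrix to be updated, together with the new data, the function overwrites the selected slices with that data.
   The requirement updates.shape = indices.shape + ref.shape[1:] shows the main difference from tf.scatter_nd_update: tf.scatter_update only indexes the first dimension of the matrix, whereas tf.scatter_nd_update can index into several leading dimensions (the first K, where K is the size of the last dimension of its indices).
   Once this formula is clear, the function itself is easy to understand.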
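The difference can be sketched in NumPy without TensorFlow. The first write mimics tf.scatter_update, where each index selects a whole slice along the first dimension; the second mimics the tf.scatter_nd_update convention, where the last axis of indices holds a coordinate into the first K dimensions of ref (here K = 2 = ref.ndim, so single elements are addressed). This models only the indexing rules, not the TensorFlow ops themselves.

```python
import numpy as np

ref = np.zeros((2, 3), dtype=np.int32)

# scatter_update-style: each index picks a whole row (first dimension only).
# indices.shape == (1,), so updates.shape == (1,) + ref.shape[1:] == (1, 3)
rows = ref.copy()
rows[np.array([1])] = np.array([[7, 8, 9]])

# scatter_nd_update-style: each row of nd_indices is a full (row, col)
# coordinate. nd_indices.shape == (2, 2) -> K == 2, updates.shape == (2,)
nd_indices = np.array([[0, 2], [1, 0]])
elems = ref.copy()
elems[tuple(nd_indices.T)] = np.array([5, 6])

print(rows)
# [[0 0 0]
#  [7 8 9]]
print(elems)
# [[0 0 5]
#  [6 0 0]]
```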

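import tensorflow as tf
import tensorflow.contrib.eager as tfe  # TF 1.x only; tf.contrib was removed in TF 2.x

tf.enable_eager_execution()

# ref has shape (2, 4), so ref.shape[1:] == (4,)
ref = tfe.Variable(initial_value=[[0, 0, 0, 0], [0, 0, 0, 0]])
# Update row 0 only; indices.shape == (1,)
indices = tf.constant([0])
# updates.shape must equal indices.shape + ref.shape[1:] == (1, 4)
updates = tf.constant([[1, 98, 20, 102]])
update = tf.scatter_update(ref, indices, updates)

print(update)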

   Output:

<tf.Variable '' shape=(2, 4) dtype=int32, numpy=
array([[  1,  98,  20, 102],
       [  0,   0,   0,   0]])>

   Updating a 3-D tensor:

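import tensorflow as tf
import tensorflow.contrib.eager as tfe  # TF 1.x only; tf.contrib was removed in TF 2.x
import numpy as np

tf.enable_eager_execution()

# ref has shape (4, 4, 3), so ref.shape[1:] == (4, 3)
ref = tfe.Variable(np.zeros(shape=[4, 4, 3], dtype=np.float32))
# Replace the first 4x3 slice, ref[0]
indices = tf.constant([0])
# updates.shape == indices.shape + ref.shape[1:] == (1, 4, 3)
updates = tf.constant([np.random.random(size=[4, 3])], dtype=tf.float32)
update = tf.scatter_update(ref, indices, updates)

print(update)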

   Output (the first slice holds random values, so the numbers will differ between runs):

<tf.Variable '' shape=(4, 4, 3) dtype=float32, numpy=
array([[[0.13318056, 0.83603495, 0.8232899 ],
        [0.02964316, 0.8545541 , 0.27696434],
        [0.4880769 , 0.23017927, 0.64292145],
        [0.07073301, 0.10755321, 0.347981  ]],

       [[0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        ]],

       [[0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        ]],

       [[0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        ]]], dtype=float32)>