1. 程式人生 > >對tensorflow中張量tensor的理解與tf.argmax()函式的用法

對tensorflow中張量tensor的理解與tf.argmax()函式的用法

對tensorflow中張量tensor的理解:

一維張量:

如a=[1., 2., 3., 0., 9., ],其shape為(5,),故當我們選擇維度0時(張量的維度總是從第0個維度開始),實際上是在a的最外層括號上進行操作。

我們畫圖來表示:

二維張量:

如b=[[1, 2, 3], [3, 2, 1], [4, 5, 6], [6, 5, 4]],其shape為(4, 3),當我們選擇維度0時,是在b的最外層括號上進行操作;當我們選擇維度1時,是在b的第二層括號內進行操作。

如果我們把b寫成矩陣的形式的話,當我們選擇維度0時,就是在矩陣的列上進行操作;當我們選擇維度1時,就是在矩陣的行上進行操作。

我們畫一幅圖來表示:

三維張量:

如c = tf.constant([[[1, 3, 1, 4], [5, 3, 2, 1], [1, 2, 2, 4]], [[4, 2, 3, 1], [5, 3, 1, 9], [3, 7, 2, 3]]]),其shape為(2,3,4)。類似地,當我們選擇維度0時,也是在最外層括號上進行操作;當我們選擇維度1時,是在第二層括號上進行操作;當我們選擇維度2時,是在最裡層的括號上進行操作。

為了更加直觀,我們畫一幅圖來表示:

tf.argmax()函式的用法:

函式形式:

tf.argmax(input=tensor,dimention=axis)

該函式返回指定的張量tensor中指定維度上的最大值/最小值的下標(即位置)。dimension=0則查詢tensor上維度0上的最大值(如果是一維陣列,維度0就是行,如果是二維陣列,維度0就是列)dimension=1則查詢tensor上維度1上的最大值(如果是二維陣列,維度1就是行) 。dimension = 2、3、4...,即為多維張量時,按同理推斷。

如果tensor是一個向量,那就返回一個值,如果是一個矩陣,那就返回一個向量,這個向量的每一個維度都是相對應矩陣指定的維度上的最大值元素的索引號。

以上面的a,b,c張量舉例:

import tensorflow as tf

with tf.Session() as sess:
	print("建立一個一維張量a:")
	a = tf.constant([1., 2., 3., 0., 9., ])
	print(a, a.shape)
	print("建立一個二維張量b,b是一個4X3矩陣,矩陣的每個元素是一個值,b有2個維度:")
	b = tf.constant([[1, 2, 3], [3, 2, 1], [4, 5, 6], [6, 5, 4]])
	print(b, b.shape)
	print("建立一個三維張量c,c是一個2X3矩陣,矩陣的每個元素時一個有4個元素的陣列,c有3個維度:")
	c = tf.constant([[[1, 3, 1, 4], [5, 3, 2, 1], [1, 2, 2, 4]],
					 [[4, 2, 3, 1], [5, 3, 1, 9], [3, 7, 2, 3]]])
	print(c, c.shape)
	print("查詢一維張量a的最大值的下標:")
	print(sess.run(tf.argmax(a, 0)))
	print("查詢二維張量b的每列最大值的下標:")
	print(sess.run(tf.argmax(b, 0)))
	print("查詢二維張量b的每行最大值的下標:")
	print(sess.run(tf.argmax(b, 1)))
	print("查詢三維張量c的維度0上最大值的下標:")
	print(sess.run(tf.argmax(c, 0)))
	print("查詢三維張量c的維度1上最大值的下標:")
	print(sess.run(tf.argmax(c, 1)))
	print("查詢三維張量c的維度2上最大值的下標:")
	print(sess.run(tf.argmax(c, 2)))

執行結果如下:

建立一個一維張量a:
Tensor("Const:0", shape=(5,), dtype=float32) (5,)
建立一個二維張量b,b是一個4X3矩陣,矩陣的每個元素是一個值,b有2個維度:
Tensor("Const_1:0", shape=(4, 3), dtype=int32) (4, 3)
建立一個三維張量c,c是一個2X3矩陣,矩陣的每個元素時一個有4個元素的陣列,c有3個維度:
Tensor("Const_2:0", shape=(2, 3, 4), dtype=int32) (2, 3, 4)
查詢一維張量a的最大值的下標:
4
查詢二維張量b的每列最大值的下標:
[3 2 2]
查詢二維張量b的每行最大值的下標:
[2 0 2 0]
查詢三維張量c的維度0上最大值的下標:
[[1 0 1 0]
 [0 0 0 1]
 [1 1 0 0]]
查詢三維張量c的維度1上最大值的下標:
[[1 0 1 0]
 [1 2 0 1]]
查詢三維張量c的維度2上最大值的下標:
[[3 0 3]
 [0 3 1]]

Process finished with exit code 0

對於一維陣列,dimension=0時tf.argmax函式就是對唯一的一個維度行上取最大值下標;

對於二維陣列,dimension=0時tf.argmax函式是對行上取最大值下標,dimension=1時tf.argmax函式是對列上取最大值下標;

對於三維陣列c(shape=(2,3,4)),寫出來是這樣的形式:

如上圖,各個維度的操作方向如圖所示,因此:

dimension=0時tf.argmax函式是對第一種c的寫法的縱向的列上取最大值下標,我們可以發現一共取了12個值;

dimension=1時tf.argmax函式是對第二種c的寫法的縱向的列上取最大值下標(注意是對第二層括號內),因此一共取了8個值;

dimension=2時tf.argmax函式是對第一種c的寫法的行方向上的最內部數組裡取最大值的下標,一共取了6個值。

在實際應用中,tf.argmax()往往和tf.equal()在tensorflow的模型中一起使用,用來計算模型的準確度。

如:

correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))