1. 程式人生 > 實用技巧 >神經網路學習-tensorflow2.0-tensor的合併與分割

神經網路學習-tensorflow2.0-tensor的合併與分割

1.tf.concat([a,b],axis=a):在第a維度上將tensor a與b進行合併。

example:

input:

a=tf.ones([2,5,6])
b=tf.ones([3,5,6])
c=tf.concat([a,b],axis=0)
print(c.shape)

output:

(5, 5, 6)

2.tf.stack([a,b],axis=a):創造新的維度a,在第a維度上將tensor a與b進行合併,原來在此位置的維度及其之後的維度向右移動。

(*注意*)合併時要求兩個tensor的現有的所有維度值都相等。

example:

input:

a=tf.random.normal([4,28,28,3])

b=tf.random.normal([4,28,28,3])
c=tf.stack([a,b],axis=2)
print(c.shape)

output:

(4, 28, 2, 28, 3)

tf.unstack(tensor,axis=a):可將原tensor拆分成多個新的tensor,這多個新的tensor數量等於維度a的值,且相較於原來的tensor消去了一個維度a。

example:

input:

a=tf.random.normal([2,28,28,3])
b=tf.random.normal([2,28,28,3])
c=tf.stack([a,b],axis=2)
d,e=tf.unstack(c,axis=0)

print(c.shape)
print(d.shape)

output:

(4, 28, 2, 28, 3)

(28, 2, 28, 3)

input:

a=tf.random.normal([2,28,28,3])
b=tf.random.normal([2,28,28,3])
c=tf.stack([a,b],axis=2)
d,e,f=tf.unstack(c,axis=-1)
print(c.shape)
print(d.shape)

output:

(2, 28, 2, 28, 3)
(2, 28, 2, 28)

3.tf.split(tensor,axis=a,num_or_size_splits=m)或tf.split(tensor,axis=a,num_or_size_splits=[m,n,k ...]):將tensor在第a維度上等分為m份或將其等分為在a維度上數值為m,n,k ...的若干個tensor,其中中括號中的數字和必須與原tensor在a維度上的數值相等。

example:

input:

a=tf.random.normal([2,28,28,3])
b=tf.random.normal([2,28,28,3])
c=tf.stack([a,b],axis=2)
re=tf.split(c,axis=-1,num_or_size_splits=[1,2])
re1=tf.split(c,axis=1,num_or_size_splits=2)
print(c.shape)
print(re[0].shape,'\n',re[1].shape)
print(re1[0].shape,'\n',re1[1].shape)

output:

(2, 28, 2, 28, 3)
(2, 28, 2, 28, 1)
(2, 28, 2, 28, 2)
(2, 14, 2, 28, 3)
(2, 14, 2, 28, 3)