1. 程式人生 > >tensorflow 基本函式(不斷更新哦)

tensorflow 基本函式(不斷更新哦)

1.  tf.split(3, group, input)  # 拆分函式
    3 表示的是在第三個維度上, group表示拆分的次數, input 表示輸入的值

import tensorflow as tf
import numpy as np

x = [[1, 2], [3, 4]]
Y = tf.split(axis=1, num_or_size_splits=2, value=x)

sess = tf.Session()
for y in Y:
    print(sess.run(y))

 

2.  tf.concat(3, input) # 串接函式
    3 表示的是在第三個維度上, input表示的是輸入,輸入一般都是列表

import tensorflow as tf


x = [[1, 2], [3, 4]]
y = tf.concat(x, axis=0)

sess = tf.Session()
print(sess.run(y))

3. tf.squeeze(input, squeeze_dims=[1, 2]) # 表示的是去除列數為1的維度, squeeze_dim 指定維度

import tensorflow as tf
import numpy as np

x = [[1, 2]]
print(np.array(x).shape)
y = tf.squeeze(x, axis=[0])

sess 
= tf.Session() print(sess.run(y))

4. tf.less_equal(a, b)  a 可以是一個列表, b表示需要比較的數,如果比b大返回false,否者返回True

import tensorflow as tf
import numpy as np

raw_gt = [1, 2, 3, 4]

y = tf.where(tf.less_equal(raw_gt, 2))


sess = tf.Session()

print(sess.run(y))

5.tf.where(input)   # 返回是真的序號,通過tf.where找出小於等於2的數的序號

import tensorflow as tf
import numpy as np

raw_gt = [1, 2, 3, 4]


y = tf.where(tf.less_equal(raw_gt, 2))

sess = tf.Session()
print(sess.run(y))

6. tf.gather   # 根據序列號對資料進行取值,輸入的是input, index

import tensorflow as tf
import numpy as np

raw_gt = [3, 4, 5, 6]


y = tf.gather(raw_gt, [[0], [1]])

sess = tf.Session()
print(sess.run(y))