1. 程式人生 > >tf 數據讀取

tf 數據讀取

local for 允許 inpu 規則 any image join 轉換

tf.train.batch(
    tensors,
    batch_size,
    num_threads=1,
    capacity=32,
    enqueue_many=False,
    shapes=None,
    dynamic_pad=False,
    allow_smaller_final_batch=False,
    shared_name=None,
    name=None
)

  

  • tensors:排列的張量或詞典。
  • batch_size:從隊列中提取新的批量大小。
  • num_threads:排隊的線程數量tensors如果批次是不確定的
    num_threads > 1
  • capacity:一個整數。隊列中元素的最大數量。
  • enqueue_many:每張張量是否tensors都是一個例子。
  • shapes:(可選)每個示例的形狀。默認為推斷的形狀tensors
  • dynamic_pad:布爾值。在輸入形狀中允許可變尺寸。給定的尺寸在出列時填充,以便批次內的張量具有相同的形狀。
  • allow_smaller_final_batch:(可選)布爾值。如果True,如果隊列中剩余物品不足,則允許最終批次更小。
  • shared_name: (可選的)。如果設置,該隊列將在多個會話中以給定名稱共享。
  • name:(可選)操作的名稱。

tf.train.slice_input_producer(
    tensor_list,
    num_epochs=None,
    shuffle=True,
    seed=None,
    capacity=32,
    shared_name=None,
    name=None
)

  

  • tensor_listTensor對象列表每一個Tensortensor_list必須在第一維中具有相同的尺寸。
  • num_epochs:一個整數(可選)。如果指定,則會在生成錯誤之前生成slice_input_producer 每個切片num_epochs時間OutOfRange如果沒有指定,slice_input_producer可以循環切片無限次數。
  • shuffle:布爾值。如果為真,整數在每個時期內隨機洗牌。
  • seed:一個整數(可選)。種子使用,如果洗牌==真。
  • capacity:一個整數。設置隊列容量。
  • shared_name: (可選的)。如果設置,該隊列將在多個會話中以給定名稱共享。
  • name:操作的名稱(可選)。

核心步驟:

  1. 調用 tf.train.slice_input_producer,從 本地文件裏抽取tensor,準備放入Filename Queue(文件名隊列)中;
  2. 調用 tf.train.batch,從文件名隊列中提取tensor,使用單個或多個線程,準備放入文件隊列;
  3. 調用 tf.train.Coordinator() 來創建一個線程協調器,用來管理之後在Session中啟動的所有線程;
  4. 調用tf.train.start_queue_runners, 啟動入隊線程,由多個或單個線程,按照設定規則,把文件讀入Filename Queue中。函數返回線程ID的列表,一般情況下,系統有多少個核,就會啟動多少個入隊線程(入隊具體使用多少個線程在tf.train.batch中定義);
  5. 文件從 Filename Queue中讀入內存隊列的操作不用手動執行,由tf自動完成;
  6. 調用sess.run 來啟動數據出列和執行計算;
  7. 使用 coord.should_stop()來查詢是否應該終止所有線程,當文件隊列(queue)中的所有文件都已經讀取出列的時候,會拋出一個 OutofRangeError 的異常,這時候就應該停止Sesson中的所有線程了;
  8. 使用coord.request_stop()來發出終止所有線程的命令,使用coord.join(threads)把線程加入主線程,等待threads結束。

Queue和Coordinator操作事例:

import tensorflow as tf
import numpy as np

# 樣本個數
sample_num=5
# 設置叠代次數
epoch_num = 2
# 設置一個批次中包含樣本個數
batch_size = 3
# 計算每一輪epoch中含有的batch個數
batch_total = int(sample_num/batch_size)+1

# 生成4個數據和標簽
def generate_data(sample_num=sample_num):
    labels = np.asarray(range(0, sample_num))
    images = np.random.random([sample_num, 224, 224, 3])
    print(‘image size {},label size :{}‘.format(images.shape, labels.shape))
    return images,labels

def get_batch_data(batch_size=batch_size):
    images, label = generate_data()
    # 數據類型轉換為tf.float32
    images = tf.cast(images, tf.float32)
    label = tf.cast(label, tf.int32)

    #從tensor列表中按順序或隨機抽取一個tensor準備放入文件名稱隊列
    input_queue = tf.train.slice_input_producer([images, label], num_epochs=epoch_num, shuffle=False)

    #從文件名稱隊列中讀取文件準備放入文件隊列
    image_batch, label_batch = tf.train.batch(input_queue, batch_size=batch_size, num_threads=2, capacity=64, allow_smaller_final_batch=False)
    return image_batch, label_batch

image_batch, label_batch = get_batch_data(batch_size=batch_size)


with tf.Session() as sess:

    # 先執行初始化工作
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # 開啟一個協調器
    coord = tf.train.Coordinator()
    # 使用start_queue_runners 啟動隊列填充
    threads = tf.train.start_queue_runners(sess, coord)

    try:
        while not coord.should_stop():
            print (‘************‘)
            # 獲取每一個batch中batch_size個樣本和標簽
            image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
            print(image_batch_v.shape, label_batch_v)
    except tf.errors.OutOfRangeError:  #如果讀取到文件隊列末尾會拋出此異常
        print("done! now lets kill all the threads……")
    finally:
        # 協調器coord發出所有線程終止信號
        coord.request_stop()
        print(‘all threads are asked to stop!‘)
    coord.join(threads) #把開啟的線程加入主線程,等待threads結束
    print(‘all threads are stopped!‘)

  

輸出:

************
((3, 224, 224, 3), array([0, 1, 2], dtype=int32))
************
((3, 224, 224, 3), array([3, 4, 0], dtype=int32))
************
((3, 224, 224, 3), array([1, 2, 3], dtype=int32))
************
done! now lets kill all the threads……
all threads are asked to stop!
all threads are stopped!

  

以上程序在 tf.train.slice_input_producer 函數中設置了 num_epochs 的數量, 所以在文件隊列末尾有結束標誌,讀到這個結束標誌的時候拋出 OutofRangeError 異常,就可以結束各個線程了。

如果不設置 num_epochs 的數量,則文件隊列是無限循環的,沒有結束標誌,程序會一直執行下去。

tf 數據讀取