將自己的資料集製作成TFRecord格式教程
阿新 • • 發佈:2020-02-17
在使用TensorFlow訓練神經網路時,首先面臨的問題是:網路的輸入
此篇文章,教大家將自己的資料集製作成TFRecord格式,feed進網路,除了TFRecord格式,TensorFlow也支援其他格
式的資料,此處就不再介紹了。建議大家使用TFRecord格式,在後面可以通過api進行多執行緒的讀取檔案佇列。
1. 原本的資料集
此時,我有兩類圖片,分別是xiansu100,xiansu60,每一類中有10張圖片。
2.製作成TFRecord格式
tfrecord會根據你選擇輸入檔案的類,自動給每一類打上同樣的標籤。如在本例中,只有0,1 兩類,想知道資料夾名與label關係的,可以自己儲存起來。
#生成整數型的屬性 def _int64_feature(value): return tf.train.Feature(int64_list = tf.train.Int64List(value = [value])) #生成字串型別的屬性 def _bytes_feature(value): return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value])) #製作TFRecord格式 def createTFRecord(filename,mapfile): class_map = {} data_dir = '/home/wc/DataSet/traffic/testTFRecord/' classes = {'xiansu60','xiansu100'} #輸出TFRecord檔案的地址 writer = tf.python_io.TFRecordWriter(filename) for index,name in enumerate(classes): class_path=data_dir+name+'/' class_map[index] = name for img_name in os.listdir(class_path): img_path = class_path + img_name #每個圖片的地址 img = Image.open(img_path) img= img.resize((224,224)) img_raw = img.tobytes() #將圖片轉化成二進位制格式 example = tf.train.Example(features = tf.train.Features(feature = { 'label':_int64_feature(index),'image_raw': _bytes_feature(img_raw) })) writer.write(example.SerializeToString()) writer.close() txtfile = open(mapfile,'w+') for key in class_map.keys(): txtfile.writelines(str(key)+":"+class_map[key]+"\n") txtfile.close()
此段程式碼,執行完後會產生生成的.tfrecord檔案。
3. 讀取TFRecord的資料,進行解析,此時使用了檔案佇列以及多執行緒
#讀取train.tfrecord中的資料 def read_and_decode(filename): #建立一個reader來讀取TFRecord檔案中的樣例 reader = tf.TFRecordReader() #建立一個佇列來維護輸入檔案列表 filename_queue = tf.train.string_input_producer([filename],shuffle=False,num_epochs = 1) #從檔案中讀出一個樣例,也可以使用read_up_to一次讀取多個樣例 _,serialized_example = reader.read(filename_queue) # print _,serialized_example #解析讀入的一個樣例,如果需要解析多個,可以用parse_example features = tf.parse_single_example( serialized_example,features = {'label':tf.FixedLenFeature([],tf.int64),'image_raw': tf.FixedLenFeature([],tf.string),}) #將字串解析成影象對應的畫素陣列 img = tf.decode_raw(features['image_raw'],tf.uint8) img = tf.reshape(img,[224,224,3]) #reshape為128*128*3通道圖片 img = tf.image.per_image_standardization(img) labels = tf.cast(features['label'],tf.int32) return img,labels
4. 將圖片幾個一打包,形成batch
def createBatch(filename,batchsize): images,labels = read_and_decode(filename) min_after_dequeue = 10 capacity = min_after_dequeue + 3 * batchsize image_batch,label_batch = tf.train.shuffle_batch([images,labels],batch_size=batchsize,capacity=capacity,min_after_dequeue=min_after_dequeue ) label_batch = tf.one_hot(label_batch,depth=2) return image_batch,label_batch
5.主函式
if __name__ =="__main__": #訓練圖片兩張為一個batch,進行訓練,測試圖片一起進行測試 mapfile = "/home/wc/DataSet/traffic/testTFRecord/classmap.txt" train_filename = "/home/wc/DataSet/traffic/testTFRecord/train.tfrecords" # createTFRecord(train_filename,mapfile) test_filename = "/home/wc/DataSet/traffic/testTFRecord/test.tfrecords" # createTFRecord(test_filename,mapfile) image_batch,label_batch = createBatch(filename = train_filename,batchsize = 2) test_images,test_labels = createBatch(filename = test_filename,batchsize = 20) with tf.Session() as sess: initop = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer()) sess.run(initop) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess = sess,coord = coord) try: step = 0 while 1: _image_batch,_label_batch = sess.run([image_batch,label_batch]) step += 1 print step print (_label_batch) except tf.errors.OutOfRangeError: print (" trainData done!") try: step = 0 while 1: _test_images,_test_labels = sess.run([test_images,test_labels]) step += 1 print step # print _image_batch.shape print (_test_labels) except tf.errors.OutOfRangeError: print (" TEST done!") coord.request_stop() coord.join(threads)
此時,生成的batch,就可以feed進網路了。
以上這篇將自己的資料集製作成TFRecord格式教程就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。