1. 程式人生 > >tensorflow object detection api訓練自己的資料集

tensorflow object detection api訓練自己的資料集

tensorflow object detection API 創造一些精確的機器學習模型用於定位和識別一幅影象裡的多元目標仍然是一個計算機視覺領域的核心挑戰。tensorflow object detection API是一個開源的基於tensorflow的框架,使得建立,訓練以及應用目標檢測模型變得簡單。在谷歌我們已經確定發現這個程式碼對我們的計算機視覺研究需要很有用,我們希望這個對你也會很有用。 1. 安裝tensorflow以及下載object detection api 安裝tensorflow: 對於CPU版本:pip install tensorflow 對於GPU版本:pip install tensorflow-gpu 升級tensorflow到最新版1.4.0:pip install --upgrade tensorflow-gpu   安裝必須庫: sudo pip install pillow sudo pip install lxml sudo pip install jupyter sudo pip install matplotlib protobuf編譯:在tensorflow/models/research/目錄下 protoc object_detection/protos/*.proto --python_out=. 新增pythonpath,在tensorflow/models/research/目錄下 export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim 測試安裝: python object_detection/builders/model_builder_test.py 下載object detection api: 2.執行演示檔案:object_detection_tutorial.ipynb
2.訓練資料集準備 在model下新建資料夾dataset,將我使用的pascal voc格式資料集(VOC3000)轉換為TFRecord格式,並存放在dataset資料夾下: 將create_pascal_tf_record.py檔案複製到dataset資料夾下: (1)修改第55行:YEARS = ['VOC2007', 'VOC2012','VOC3000', 'merged'] (2)修改第58行:def dict_to_tf_example(data, 改為def dict_to_tf_example(year,data, (3)修改第84行:img_path = os.path.join(data['folder'], image_subdirectory, data['filename']) 改為img_path = os.path.join(year,image_subdirectory, data['filename']) (4)修改第152行:years = ['VOC2007', 'VOC2012'] 改為years = ['VOC2007', 'VOC2012','VOC3000'] (5)修改第163行:examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main', 'aeroplane_' + FLAGS.set + '.txt') 改為 examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main', FLAGS.set + '.txt') (6)修改第175行:tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict, FLAGS.ignore_difficult_instances) 改為tf_example = dict_to_tf_example(year, data, FLAGS.data_dir, label_map_dict, FLAGS.ignore_difficult_instances) 以上涉及到路徑需要根據自己資料集調整。 執行以下命令,就可以得到用於訓練和驗證的tf_record檔案: python data/create_pascal3000_tf_record.py --data_dir=/data/models/research/object_detection/dataset/VOCdevkit --label_map_path=/data/models/research/object_detection/dataset/pascal_label_map.pbtxt --year=VOC3000 --set=train --output_path=/data/models/research/object_detection/dataset/pascal_train.record python data/create_pascal3000_tf_record.py --data_dir=/data/models/research/object_detection/dataset/VOCdevkit --label_map_path=/data/models/research/object_detection/dataset/pascal_label_map.pbtxt --year=VOC3000 --set=val --output_path=/data/models/research/object_detection/dataset/pascal_val.record 3.解壓SSDMobilenet模型(下載API的時候已經下載好了) tar -xvf ssd_mobilenet_v1_coco_2017_11_08.tar.gz  得到如下檔案:

將資料夾裡面的model.ckpt.*的三個檔案copy到dataset資料夾。 4.修改config檔案。 將檔案object_detection/samples/configs/ssd_mobilenet_v1_pets.config複製到dataset. 修改: (1)num_classes修改為自己的類別數目,我的是10 (2)修改路徑。(5處) fine_tune_checkpoint: "/data/models/research/object_detection/dataset/model.ckpt" input_path: "/data/models/research/object_detection/dataset/pascal_train.record" label_map_path: "/data/models/research/object_detection/dataset/pascal_label_map.pbtxt" input_path: "/data/models/research/object_detection/dataset/pascal_val.record" label_map_path: "/data/models/research/object_detection/dataset/pascal_label_map.pbtxt" 儲存config檔案,重新命名為ssd_mobilenet_v1_pascal.config
。我的dataset資料夾如圖所示。
5.開始訓練(這裡我換用了另一個模型faster_rcnn_inception_resnet) python train.py --logtostderr --train_dir=/home/amax/guo/models/object_detection/dataset/output --pipeline_config_path=/home/amax/guo/models/object_detection/dataset/faster_rcnn_inception_resnet/faster_rcnn_inception_resnet_v2_atrous_pets.config
6.評估模型 在dataset資料夾下新建evaluation資料夾 python eval.py --logtostderr --checkpoint_dir=/home/amax/guo/models/object_detection/dataset/output --pipeline_config_path=/home/amax/guo/models/object_detection/dataset/faster_rcnn_inception_resnet/faster_rcnn_inception_resnet_v2_atrous_pets.config --eval_dir=/home/amax/guo/models/object_detection/dataset/evaluation 報錯:ImportError: No module named nets 解決辦法:匯入slim模組 import sys sys.path.append('/data/models/research/slim')
7.檢視結果 tensorboard --logdir=/home/amax/guo/models/object_detection/dataset

8.生成可以被呼叫的模型 python object_detection/export_inference_graph.py --input_type image_tensor --pipeline_config_path /home/amax/guo/models/object_detection/dataset/faster_rcnn_inception_resnet/faster_rcnn_inception_resnet_v2_atrous_pets.config --trained_checkpoint_prefix /home/amax/guo/models/object_detection/dataset/output/model.ckpt-10000 --output_directory /home/amax/guo/models/object_detection/dataset/savedModel 生成的模型如圖所示:
9.呼叫生成的模型 修改object_detection_tutorial.py PATH_TO_CKPT ='/home/amax/guo/models/object_detection/dataset/savedModel/frozen_inference_graph.pb' PATH_TO_LABELS='/home/amax/guo/models/object_detection/dataset/pascal_label_map.pbtxt' NUM_CLASSES = 10 結果如下: