1. 程式人生 > >DeepLab V3+ 訓練自己的資料

DeepLab V3+ 訓練自己的資料

一、前提

官網程式碼:https://github.com/tensorflow/models/tree/master/research/deeplab 

1. 依賴

DeepLab依賴於以下庫:

  • Numpy
  • Pillow 1.0
  • tf Slim (which is included in the "tensorflow/models/research/" checkout)
  • Jupyter notebook
  • Matplotlib
  • Tensorflow1.6及以上

2. 將庫新增到PYTHONPATH

在本地執行時,tensorflow / models / research /和slim目錄應該附加到PYTHONPATH。 這可以通過在 tensorflow / models / research /路徑下執行以下命令來完成:

# From tensorflow/models/research/
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim

注意:每次啟用新終端,此命令都需要執行。 如果想避免手動執行,可以將其作為新行新增到〜/ .bashrc檔案的末尾。

3. 簡單測試

可以通過執行以下命令來測試是否已成功安裝 Tensorflow DeepLab:

執行 model_test.py 進行快速測試:

# From tensorflow/models/research/
python deeplab/model_test.py

在PASCAL VOC 2012資料集上快速執行整個程式碼:

# From tensorflow/models/research/deeplab
sh local_test.sh

local_tesr.sh 指令碼用於在PASCAL VOC 2012上執行本地測試。

之後在自己資料集上進行訓練等操作就可以參照 local_test.sh 來編輯指令。開啟指令碼看一下,發現它:

(1)執行了model_test.py

(2)執行了download_and_convert_voc2012.sh

(3)從model_zoo(http://download.tensorflow.org/models)下載了模型deeplabv3_pascal_train_aug

(4)執行了train.py

(5)執行了eval.py

(6)執行了vis.py

(7)執行了export_model.py

建議仔細閱讀上面提到的指令碼和程式,為以後訓練自己的資料提供參考。

錯誤1:

執行 model_test.py 進行快速測試時出錯:

參考 https://github.com/tensorflow/models/issues/5523

 將 model_test.py 中140行左右的:

self.assertListEqual(scales_to_model_results.keys(),

修改為:

self.assertListEqual(list(scales_to_model_results.keys()),

錯誤2:

測試程式需要執行eval.py,我在這裡出現了一個錯誤:

即:InvalidArgumentError (see above for traceback): assertion failed: [`predictions` out of bound] [Condition x < y did not hold element-wise:] [x (mean_iou/confusion_matrix/control_dependency_1:0) = ] [0 0 0...] [y (mean_iou/ToInt64_2:0) = ] [21]

參考https://github.com/tensorflow/models/issues/4203中trobr的說法:

修改 eval.py 中第145 行左右:

將:

metric_map = {}

metric_map[predictions_tag] = tf.metrics.mean_iou(

        predictions, labels, dataset.num_classes, weights=weights)

修改為:   也就是中間插入了幾行

 metric_map = {}

    # insert by trobr

    indices = tf.squeeze(tf.where(tf.less_equal(

        labels, dataset.num_classes - 1)), 1)

    labels = tf.cast(tf.gather(labels, indices), tf.int32)

    predictions = tf.gather(predictions, indices)

    # end of insert

    metric_map[predictions_tag] = tf.metrics.mean_iou(

        predictions, labels, dataset.num_classes, weights=weights)

二、資料準備 

參照VOC2012的檔案結構,把自己的資料和資料夾準備好。

參考download_and_convert_voc2012.sh進行資料轉化。

1. label圖修改(也許需要)

label圖應該是沒有色彩的,類別的畫素標記應該是0,1,2,3......

注意:不要把 ignore_label background 混淆,ignore_label是沒有做標註的,不在預測範圍內的,ignore_label是不參與計算loss的。我們在mask中將 ignore_label 的灰度值標記為255,而background 標記為0

如果是voc2012這種有colormap的標籤圖,可以利用remove_gt_colormap.py先去掉colormap:

# from research/deeplab/datasets
python remove_gt_colormap.py \
  --original_gt_folder="/path/SegmentationClass" \
  --output_dir="/path/SegmentationClassRaw"

其中, original_gt_folder是原始標籤圖資料夾,output_dir是要輸出的標籤圖資料夾的位置。

2. 資料轉換為tfrecord

# from research/deeplab/datasets
python build_voc2012_data.py \
  --image_folder="/path/JPEGImages" \
  --semantic_segmentation_folder="/path/SegmentationClassRaw" \
  --list_folder="/path/ImageSets/Segmentation" \
  --image_format="jpg" \
  --output_dir="/path/tfrecord"

其中,image_folder是jpg原圖資料夾,semantic_segmentation_folder是轉化後label圖資料夾,list_folder是train.txt、val.txt、trainval.txt所在的資料夾,output_dir是輸出資料存放的資料夾。

轉換後的資料儲存到tfrecord(tfrecord資料夾事先建好)。

三、訓練準備

1. 修改segmentation_dataset.py(註冊資料集)

(1)在這段程式碼註冊資料集,使我的資料集 voc_turbulent 擁有姓名:

_DATASETS_INFORMATION = {

    'cityscapes': _CITYSCAPES_INFORMATION,

    'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,

    'ade20k': _ADE20K_INFORMATION,

    'voc_turbulent': _VOC_TURBULENT_INFORMATION,

}

(2)參照程式碼中其他資料集形式,加入個人資料集描述配置:

訓練、檢測資料的數量修改好,類別數量也根據實際修改。

_VOC_TURBULENT_INFORMATION = DatasetDescriptor(

    splits_to_sizes={

        'train': 1413,

        'trainval': 2826,

        'val': 1413,

    },

    num_classes=21,

    ignore_label=255,

)

2. 修改train_utils.py

檔案修改如下:

exclude_list = ['global_step'] 

修改為:

exclude_list = ['global_step', 'logits'] 

作用是在使用預訓練權重時候,不載入該logit層。訓練自己的資料集時,此處進行修改。

四、訓練

模型從官網下載:https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md

python deeplab/train.py \
  --logtostderr \
  --train_split="train" \
  --model_variant="xception_65" \
  --dataset="voc_turbulent" \#前面註冊的資料集名稱
  --atrous_rates=6 \
  --atrous_rates=12 \
  --atrous_rates=18 \
  --output_stride=16 \
  --decoder_output_stride=4 \
  --train_crop_size=513 \
  --train_crop_size=513 \
  --training_number_of_steps=90000  \
  --base_learning_rate=0.0001 \
  --num_clones=3 \#3塊顯示卡
  --train_batch_size=9 \#得是顯示卡數量的倍數哈
  --fine_tune_batch_norm=false \
  --initialize_last_layer=False \
  --last_layers_contain_logits_only=True \
  --tf_initial_checkpoint="/path/deeplabv3_pascal_train_aug/model.ckpt" \
  --train_logdir="/path/exp/train_on_train_set/train" \
  --dataset_dir="/path/tfrecord"

注意:

(1)學習率

 (2)batch size

(3)模型選擇及引數

(4)crop size 

(5)關於initialize_last_layer和last_layers_contain_logits_only

五、驗證 

python deeplab/eval.py \
  --logtostderr \
  --eval_split="val" \
  --model_variant="xception_65" \
  --dataset="voc_turbulent" \
  --num_clones=3 \#3塊顯示卡
  --atrous_rates=6 \
  --atrous_rates=12 \
  --atrous_rates=18 \
  --output_stride=16 \
  --decoder_output_stride=4 \
  --eval_crop_size=513 \
  --eval_crop_size=513 \
  --checkpoint_dir="/path/exp/train_on_train_set/train" \
  --eval_logdir="/path/exp/train_on_train_set/eval" \
  --dataset_dir="/path/tfrecord" \
  --max_number_of_evaluations=1

結果不是很好: 

六、視覺化

python deeplab/vis.py \
  --logtostderr \
  --vis_split="val" \
  --model_variant="xception_65" \
  --dataset="voc_turbulent" \
  --num_clones=3 \
  --atrous_rates=6 \
  --atrous_rates=12 \
  --atrous_rates=18 \
  --output_stride=16 \
  --decoder_output_stride=4 \
  --vis_crop_size=513 \
  --vis_crop_size=513 \
  --checkpoint_dir="/xxx/exp/train_on_train_set/train" \
  --vis_logdir="/xxx/exp/train_on_train_set/vis" \
  --dataset_dir="/xxx/tfrecord" \
  --max_number_of_iterations=1

七、預測單張圖片 

在deeplab_demo.ipynb的基礎上做些修改,為方便使用,給出網盤連結,使用時修改路徑即可。

連結:https://pan.baidu.com/s/16iffY6WkOwjRezttAuulFQ 
提取碼:06fy 
效果如下: