FastRCNN 訓練自己資料集(二)——修改讀寫介面
這裡樓主講解了如何修改Fast RCNN訓練自己的資料集,首先請確保你已經安裝好了Fast RCNN的環境,具體的編配編制操作請參考我的上一篇文章。首先可以看到fast rcnn的工程目錄下有個Lib目錄
這裡下面存在3個目錄分別是:
- datasets
- fast_rcnn
- roi_data_layer
- utils
在這裡修改讀寫資料的介面主要是datasets目錄下,fast_rcnn下面主要存放的是Python的訓練和測試指令碼,以及訓練的配置檔案,roi_data_layer下面存放的主要是一些ROI處理操作,utils下面存放的是一些通用操作比如非極大值nms,以及計算bounding box的重疊率等常用功能
1.構建自己的IMDB子類
1.1檔案概述
可有看到datasets目錄下主要有三個檔案,分別是
- factory.py
- imdb.py
- pascal_voc.py
factory.py 學過設計模式的應該知道這是個工廠類,用類生成imdb類並且返回資料庫共網路訓練和測試使用
imdb.py 這裡是資料庫讀寫類的基類,分裝了許多db的操作,但是具體的一些檔案讀寫需要繼承繼續讀寫
pascal_voc.py Ross在這裡用pascal_voc.py這個類來操作
1.2 讀取檔案函式分析
接下來我來介紹一下pasca_voc.py這個檔案,我們主要是基於這個檔案進行修改,裡面有幾個重要的函式需要修改
- def init(self, image_set, year, devkit_path=None)
這個是初始化函式,它對應著的是pascal_voc的資料集訪問格式,其實我們將其介面修改的更簡單一點 - def image_path_at(self, i)
根據第i個影象樣本返回其對應的path,其呼叫了image_path_from_index(self, index)作為其具體實現 - def image_path_from_index(self, index)
實現了 image_path的具體功能 - def _load_image_set_index(self)
載入了樣本的list檔案 - def _get_default_path(self)
獲得資料集地址 - def gt_roidb(self)
讀取並返回ground_truth的db - def selective_search_roidb
讀取並返回ROI的db - def _load_selective_search_roidb(self, gt_roidb)
載入預選框的檔案 - def selective_search_IJCV_roidb(self)
在這裡呼叫讀取Ground_truth和ROI db並將db合併 - def _load_selective_search_IJCV_roidb(self, gt_roidb)
這裡是專門讀取作者在IJCV上用的dataset - def _load_pascal_annotation(self, index)
這個函式是讀取gt的具體實現 - def _write_voc_results_file(self, all_boxes)
voc的檢測結果寫入到檔案 - def _do_matlab_eval(self, comp_id, output_dir='output')
根據matlab的evluation介面來做結果的分析 - def evaluate_detections
其呼叫了_do_matlab_eval - def competition_mode
設定competitoin_mode,加了一些噪點
1.3訓練資料集格式
在我的檢測任務裡,我主要是從道路卡口資料中檢測車,因此我這裡只有background 和car兩類物體,為了操作方便,我不像pascal_voc資料集裡面一樣每個影象用一個xml來標註多類,先說一下我的資料格式
這裡是所有樣本的影象列表
我的GroundTruth資料的格式,第一個為影象路徑,之後1代表目標物的個數, 後面的座標代表左上右下的座標,座標的位置從1開始
這裡我要特別提醒一下大家,一定要注意座標格式,一定要注意座標格式,一定要注意座標格式,重要的事情說三遍!!!,要不然你會範很多錯誤都會是因為座標不一致引起的報錯
1.4修改讀取介面
這裡是原始的pascal_voc的init函式,在這裡,由於我們自己的資料集往往比voc的資料集要更簡單的一些,在作者額程式碼裡面用了很多的路徑拼接,我們不用去迎合他的格式,將這些操作簡單化即可,在這裡我會一一列舉每個我修改過的函式。這裡按照檔案中的順序排列。
原始初始化函式:
def __init__(self, image_set, year, devkit_path=None):
datasets.imdb.__init__(self, 'voc_' + year + '_' + image_set)
self._year = year
self._image_set = image_set
self._devkit_path = self._get_default_path() if devkit_path is None \
else devkit_path
self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)
self._classes = ('__background__', # always index 0
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor')
self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
self._image_ext = '.jpg'
self._image_index = self._load_image_set_index()
# Default to roidb handler
self._roidb_handler = self.selective_search_roidb
# PASCAL specific config options
self.config = {'cleanup' : True,
'use_salt' : True,
'top_k' : 2000}
assert os.path.exists(self._devkit_path), \
'VOCdevkit path does not exist: {}'.format(self._devkit_path)
assert os.path.exists(self._data_path), \
'Path does not exist: {}'.format(self._data_path)
修改後的初始化函式:
def __init__(self, image_set, devkit_path=None):
datasets.imdb.__init__(self, image_set)#imageset 為train test
self._image_set = image_set
self._devkit_path = devkit_path
self._data_path = os.path.join(self._devkit_path)
self._classes = ('__background__','car')#包含的類
self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))#構成字典{'__background__':'0','car':'1'}
self._image_index = self._load_image_set_index('ImageList_Version_S_AddData.txt')#新增檔案列表
# Default to roidb handler
self._roidb_handler = self.selective_search_roidb
# PASCAL specific config options
self.config = {'cleanup' : True,
'use_salt' : True,
'top_k' : 2000}
assert os.path.exists(self._devkit_path), \
'VOCdevkit path does not exist: {}'.format(self._devkit_path)
assert os.path.exists(self._data_path), \
'Path does not exist: {}'.format(self._data_path)
原始的image_path_from_index:
def image_path_from_index(self, index):
"""
Construct an image path from the image's "index" identifier.
"""
image_path = os.path.join(self._data_path, 'JPEGImages',
index + self._image_ext)
assert os.path.exists(image_path), \
'Path does not exist: {}'.format(image_path)
return image_path
修改後的image_path_from_index:
def image_path_from_index(self, index):#根據_image_index獲取影象路徑
"""
Construct an image path from the image's "index" identifier.
"""
image_path = os.path.join(self._data_path, index)
assert os.path.exists(image_path), \
'Path does not exist: {}'.format(image_path)
return image_path
原始的 _load_image_set_index:
def _load_image_set_index(self):
"""
Load the indexes listed in this dataset's image set file.
"""
# Example path to image set file:
# self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt
image_set_file = os.path.join(self._data_path, 'ImageSets', 'Main',
self._image_set + '.txt')
assert os.path.exists(image_set_file), \
'Path does not exist: {}'.format(image_set_file)
with open(image_set_file) as f:
image_index = [x.strip() for x in f.readlines()]
return image_index
修改後的 _load_image_set_index:
def _load_image_set_index(self, imagelist):#已經修改
"""
Load the indexes listed in this dataset's image set file.
"""
# Example path to image set file:
# self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt
#/home/chenjie/KakouTrainForFRCNN_1/DataSet/KakouTrainFRCNN_ImageList.txt
image_set_file = os.path.join(self._data_path, imagelist)# load ImageList that only contain ImageFileName
assert os.path.exists(image_set_file), \
'Path does not exist: {}'.format(image_set_file)
with open(image_set_file) as f:
image_index = [x.strip() for x in f.readlines()]
return image_index
函式 _get_default_path,我直接刪除了
原始的gt_roidb:
def gt_roidb(self):
"""
Return the database of ground-truth regions of interest.
This function loads/saves from/to a cache file to speed up future calls.
"""
cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
if os.path.exists(cache_file):
with open(cache_file, 'rb') as fid:
roidb = cPickle.load(fid)
print '{} gt roidb loaded from {}'.format(self.name, cache_file)
return roidb
gt_roidb = [self._load_pascal_annotation(index)
for index in self.image_index]
with open(cache_file, 'wb') as fid:
cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
print 'wrote gt roidb to {}'.format(cache_file)
return gt_roidb
修改後的gt_roidb:
def gt_roidb(self):
"""
Return the database of ground-truth regions of interest.
This function loads/saves from/to a cache file to speed up future calls.
"""
cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
if os.path.exists(cache_file):#若存在cache file則直接從cache file中讀取
with open(cache_file, 'rb') as fid:
roidb = cPickle.load(fid)
print '{} gt roidb loaded from {}'.format(self.name, cache_file)
return roidb
gt_roidb = self._load_annotation() #已經修改,直接讀入整個GT檔案
with open(cache_file, 'wb') as fid:
cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
print 'wrote gt roidb to {}'.format(cache_file)
return gt_roidb
原始的selective_search_roidb(self):
def selective_search_roidb(self):
"""
Return the database of selective search regions of interest.
Ground-truth ROIs are also included.
This function loads/saves from/to a cache file to speed up future calls.
"""
cache_file = os.path.join(self.cache_path,
self.name + '_selective_search_roidb.pkl')
if os.path.exists(cache_file):
with open(cache_file, 'rb') as fid:
roidb = cPickle.load(fid)
print '{} ss roidb loaded from {}'.format(self.name, cache_file)
return roidb
if int(self._year) == 2007 or self._image_set != 'test':
gt_roidb = self.gt_roidb()
ss_roidb = self._load_selective_search_roidb(gt_roidb)
roidb = datasets.imdb.merge_roidbs(gt_roidb, ss_roidb)
else:
roidb = self._load_selective_search_roidb(None)
with open(cache_file, 'wb') as fid:
cPickle.dump(roidb, fid, cPickle.HIGHEST_PROTOCOL)
print 'wrote ss roidb to {}'.format(cache_file)
return roidb
修改後的selective_search_roidb(self):
這裡有個pkl檔案我需要特別說明一下,如果你再次訓練的時候修改了資料庫,比如新增或者刪除了一些樣本,但是你的資料庫名字函式原來那個,比如我這裡訓練的資料庫叫KakouTrain,必須要在data/cache/目錄下把資料庫的快取檔案.pkl給刪除掉,否則其不會重新讀取相應的資料庫,而是直接從之前讀入然後快取的pkl檔案中讀取進來,這樣修改的資料庫並沒有進入網路,而是載入了老版本的資料。
def selective_search_roidb(self):#已經修改
"""
Return the database of selective search regions of interest.
Ground-truth ROIs are also included.
This function loads/saves from/to a cache file to speed up future calls.
"""
cache_file = os.path.join(self.cache_path,self.name + '_selective_search_roidb.pkl')
if os.path.exists(cache_file): #若存在cache_file則讀取相對應的.pkl檔案
with open(cache_file, 'rb') as fid:
roidb = cPickle.load(fid)
print '{} ss roidb loaded from {}'.format(self.name, cache_file)
return roidb
if self._image_set !='KakouTest':
gt_roidb = self.gt_roidb()
ss_roidb = self._load_selective_search_roidb(gt_roidb)
roidb = datasets.imdb.merge_roidbs(gt_roidb, ss_roidb)
else:
roidb = self._load_selective_search_roidb(None)
with open(cache_file, 'wb') as fid:
cPickle.dump(roidb, fid, cPickle.HIGHEST_PROTOCOL)
print 'wrote ss roidb to {}'.format(cache_file)
return roidb
原始的_load_selective_search_roidb(self, gt_roidb):
def _load_selective_search_roidb(self, gt_roidb):
filename = os.path.abspath(os.path.join(self.cache_path, '..',
'selective_search_data',
self.name + '.mat'))
assert os.path.exists(filename), \
'Selective search data not found at: {}'.format(filename)
raw_data = sio.loadmat(filename)['boxes'].ravel()
box_list = []
for i in xrange(raw_data.shape[0]):
box_list.append(raw_data[i][:, (1, 0, 3, 2)] - 1)
return self.create_roidb_from_box_list(box_list, gt_roidb)
修改後的_load_selective_search_roidb(self, gt_roidb):
這裡原作者用的是Selective_search,但是我用的是EdgeBox的方法來提取Mat,我沒有修改函式名,只是把輸入的Mat檔案給替換了,Edgebox實際的效果比selective_search要好,速度也要更快,具體的EdgeBox程式碼大家可以在Ross的tutorial中看到地址。
注意,這裡非常關鍵!!!!!,由於Selective_Search中的OP返回的坐���順序需要調整,並不是左上右下的順序,可以看到在下面box_list.append()中有一個(1,0,3,2)的操作,不管你用哪種OP方法,輸入的座標都應該是x1 y1 x2 y2,不要弄成w h 那種格式,也不要調換順序。座標-1,預設座標從0開始,樓主提醒各位,一定要非常注意座標順序,大小,邊界,格式問題,否則你會被錯誤折騰死的!!!
def _load_selective_search_roidb(self, gt_roidb):#已經修改
#filename = os.path.abspath(os.path.join(self.cache_path, '..','selective_search_data',self.name + '.mat'))
filename = os.path.join(self._data_path, 'EdgeBox_Version_S_AddData.mat')#這裡輸入相對應的預選框檔案路徑
assert os.path.exists(filename), \
'Selective search data not found at: {}'.format(filename)
raw_data = sio.loadmat(filename)['boxes'].ravel()
box_list = []
for i in xrange(raw_data.shape[0]):
#box_list.append(raw_data[i][:,(1, 0, 3, 2)] - 1)#原來的Psacalvoc調換了列,我這裡box的順序是x1 ,y1,x2,y2 由EdgeBox格式為x1,y1,w,h經過修改
box_list.append(raw_data[i][:,:] -1)
return self.create_roidb_from_box_list(box_list, gt_roidb)
原始的_load_selective_search_IJCV_roidb,我沒用這個資料集,因此不修改這個函式
原始的_load_pascal_annotation(self, index):
def _load_pascal_annotation(self, index):
"""
Load image and bounding boxes info from XML file in the PASCAL VOC
format.
"""
filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
# print 'Loading: {}'.format(filename)
def get_data_from_tag(node, tag):
return node.getElementsByTagName(tag)[0].childNodes[0].data
with open(filename) as f:
data = minidom.parseString(f.read())
objs = data.getElementsByTagName('object')
num_objs = len(objs)
boxes = np.zeros((num_objs, 4), dtype=np.uint16)
gt_classes = np.zeros((num_objs), dtype=np.int32)
overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
# Load object bounding boxes into a data frame.
for ix, obj in enumerate(objs):
# Make pixel indexes 0-based
x1 = float(get_data_from_tag(obj, 'xmin')) - 1
y1 = float(get_data_from_tag(obj, 'ymin')) - 1
x2 = float(get_data_from_tag(obj, 'xmax')) - 1
y2 = float(get_data_from_tag(obj, 'ymax')) - 1
cls = self._class_to_ind[
str(get_data_from_tag(obj, "name")).lower().strip()]
boxes[ix, :] = [x1, y1, x2, y2]
gt_classes[ix] = cls
overlaps[ix, cls] = 1.0
overlaps = scipy.sparse.csr_matrix(overlaps)
return {'boxes' : boxes,
'gt_classes': gt_classes,
'gt_overlaps' : overlaps,
'flipped' : False}
修改後的_load_pascal_annotation(self, index):
def _load_annotation(self):
"""
Load image and bounding boxes info from annotation
format.
"""
#,此函式作用讀入GT檔案,我的檔案的格式 CarTrainingDataForFRCNN_1\Images\2015011100035366101A000131.jpg 1 147 65 443 361
gt_roidb = []
annotationfile = os.path.join(self._data_path, 'ImageList_Version_S_GT_AddData.txt')
f = open(annotationfile)
split_line = f.readline().strip().split()
num = 1
while(split_line):
num_objs = int(split_line[1])
boxes = np.zeros((num_objs, 4), dtype=np.uint16)
gt_classes = np.zeros((num_objs), dtype=np.int32)
overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
for i in range(num_objs):
x1 = float( split_line[2 + i * 4])
y1 = float (split_line[3 + i * 4])
x2 = float (split_line[4 + i * 4])
y2 = float (split_line[5 + i * 4])
cls = self._class_to_ind['car']
boxes[i,:] = [x1, y1, x2, y2]
gt_classes[i] = cls
overlaps[i,cls] = 1.0
overlaps = scipy.sparse.csr_matrix(overlaps)
gt_roidb.append({'boxes' : boxes, 'gt_classes': gt_classes, 'gt_overlaps' : overlaps, 'flipped' : False})
split_line = f.readline().strip().split()
f.close()
return gt_roidb
之後的這幾個函式我都沒有修改,檢測結果,我是修改了demo.py這個檔案,直接生成txt檔案,然後用python OpenCV直接視覺化,沒有用著裡面的介面,感覺太麻煩了,先怎麼方便怎麼來
- _write_voc_results_file(self, all_boxes)
- _do_matlab_eval(self, comp_id, output_dir='output')
- evaluate_detections(self, all_boxes, output_dir)
- competition_mode(self, on)
記得在最後的__main__下面也修改相應的路徑
d = datasets.pascal_voc('trainval', '2007')
改成
d = datasets.kakou('KakouTrain', '/home/chenjie/KakouTrainForFRCNN_1')
並且同時在檔案的開頭import 裡面也做修改
import datasets.pascal_voc
改成
import datasets.kakou
OK,在這裡我們已經完成了整個的讀取介面的改寫,主要是將GT和預選框Mat檔案讀取並返回
2.修改factory.py
當網路訓練時會呼叫factory裡面的get方法獲得相應的imdb,
首先在檔案頭import 把pascal_voc改成kakou
在這個檔案作者生成了多個數據庫的路徑,我們自己資料庫只要給定根路徑即可,修改主要有以下4個
- 因此將裡面的def _selective_search_IJCV_top_k函式整個註釋掉
- 函式之後有兩個多級的for迴圈,也將其註釋
- 直接定義imageset和devkit
- 修改get_imdb函式
原始的factory.py:
__sets = {}
import datasets.pascal_voc
import numpy as np
def _selective_search_IJCV_top_k(split, year, top_k):
"""Return an imdb that uses the top k proposals from the selective search
IJCV code.
"""
imdb = datasets.pascal_voc(split, year)
imdb.roidb_handler = imdb.selective_search_IJCV_roidb
imdb.config['top_k'] = top_k
return imdb
# Set up voc_<year>_<split> using selective search "fast" mode
for year in ['2007', '2012']:
for split in ['train', 'val', 'trainval', 'test']:
name = 'voc_{}_{}'.format(year, split)
__sets[name] = (lambda split=split, year=year:
datasets.pascal_voc(split, year))
# Set up voc_<year>_<split>_top_<k> using selective search "quality" mode
# but only returning the first k boxes
for top_k in np.arange(1000, 11000, 1000):
for year in ['2007', '2012']:
for split in ['train', 'val', 'trainval', 'test']:
name = 'voc_{}_{}_top_{:d}'.format(year, split, top_k)
__sets[name] = (lambda split=split, year=year, top_k=top_k:
_selective_search_IJCV_top_k(split, year, top_k))
def get_imdb(name):
"""Get an imdb (image database) by name."""
if