功能性模組:(7)檢測效能評估模組(precision,recall等)
阿新 • • 發佈:2021-01-13
技術標籤:功能性模組演算法python深度學習pytorch神經網路
功能性模組:(7)檢測效能評估模組
一、模組介紹
其實每個演算法的好壞都是有對應的評估標準的,如果你和老闆說檢測演算法好或者不好,哈哈哈,那必然就是悲劇了。好或者不好是一個定性的說法,對於實際演算法來說,到底怎麼樣演算法算好?怎麼樣演算法算不好?這些應該是有個定量的標準。對於檢測來說,可能最常用的幾個評價指標就是precision(查準率,就是你檢測出來的目標有多少是真的目標),recall(查全率,就是實際的目標你的演算法能檢測出來多少),還有ap,map等。本篇部落格其實就是讓小夥伴們對自己的檢測模型心裡有一個底,換句話說這個模型你訓練出來到底咋樣?
二、程式碼實現
import numpy as np
import os
def voc_ap(rec, prec, use_07_metric=False):
"""Compute VOC AP given precision and recall. If use_07_metric is true, uses
the VOC 07 11-point method (default:False).
"""
if use_07_metric:
# 11 point metric
ap = 0.
for t in np.arange(0., 1.1, 0.1):
if np.sum(rec >= t) == 0:
p = 0
else:
p = np.max(prec[rec >= t])
ap = ap + p / 11.
else:
# correct AP calculation
# first append sentinel values at the end
mrec = np.concatenate(([0.], rec, [1.]))
mpre = np.concatenate(([0.], prec, [0.]))
# compute the precision envelope
for i in range(mpre.size - 1, 0, -1):
mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
# to calculate area under PR curve, look for points
# where X axis (recall) changes value
i = np.where(mrec[1:] != mrec[:-1])[0]
# and sum (\Delta recall) * prec
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
return ap
def ComputeMAP(gt_root, predict_root, OVTHRESH=0.5):
"""
:param gt_root: 生成gt檔案的根目錄
:param predict_root: 演算法跑出的根目錄
:param overthresh: 設定的閾值
:return:
"""
# 獲取所有的檔案
files_gt = os.listdir(gt_root)
files_pred = os.listdir(predict_root)
files_gt.sort()
# 這個變數的目的是什麼?儲存gt中真正的框的數量
npos = 0
class_recs = {}
# 遍歷所有gt檔案
for file_gt in files_gt:
img_name = os.path.splitext(os.path.basename(file_gt))[0]
file_gt = os.path.join(gt_root, os.path.basename(file_gt))
print("*" * 80)
print("img name is: ", img_name)
print("gt file is: ", file_gt)
# 處理gt檔案
with open(file_gt, 'r') as f:
lines = f.readlines()
splitlines = [x.strip().split(' ') for x in lines]
bbox = np.array([[float(z) for z in x[:]] for x in splitlines])
print("bbox is: \n", bbox)
det = [False] * len(bbox)
npos = npos + len(bbox)
class_recs[img_name] = {'bbox': bbox, 'det': det}
print("*" * 80)
print("Total npos is: ", npos)
# 遍歷所有的檢測結果
img_ids = []
confidence = []
BB = []
for file_pred in files_pred:
img_name = os.path.splitext(os.path.basename(file_pred))[0]
file_pred = os.path.join(pred_root, os.path.basename(file_pred))
print("*" * 80)
print("img_name is: ", img_name)
print("pred file is: ", file_pred)
with open(file_pred, 'r') as f:
lines = f.readlines()
splitlines = [x.strip().split(" ") for x in lines]
confidence_p = [float(x[0]) for x in splitlines]
bbox_p = [[float(z) for z in x[1:]] for x in splitlines]
# 根據confidence_p的長度,複製對應的img_name的str,生成對應長度的list
# ['20160220082030T28_H', '20160220082030T28_H', '20160220082030T28_H', '20160220082030T28_H']
img_ids.extend([img_name] * len(confidence_p))
confidence.extend(confidence_p)
BB.extend(bbox_p)
print(img_ids)
print(confidence)
print(BB)
confidence = np.array(confidence)
BB = np.array(BB)
print("*" * 80)
print("All files loaded!")
# 按照confidence的降序進行排列
sorted_idx = np.argsort(-confidence)
print("sorted idx is: ", sorted_idx)
BB = BB[sorted_idx, :]
img_ids = [img_ids[x] for x in sorted_idx]
# 計算對應的TPs 和 FPs
nd = len(img_ids)
tp = np.zeros(nd)
fp = np.zeros(nd)
wrong_count = 0
for d in range(nd):
print("We are now test: ", img_ids[d])
# 取出對應影象的gt
R = class_recs[img_ids[d]]
# 檢測的結果
bb = BB[d, :].astype(float)
# 假設重疊面積初始為-inf
ovmax = -np.inf
BBGT = R['bbox'].astype(float)
print("bb: \n ", bb)
print("BBGT: \n", BBGT)
print("BBGT size is: ", BBGT.size)
if BBGT.size > 0:
# 計算覆蓋的部分
ixmin = np.maximum(BBGT[:, 0], bb[0])
iymin = np.maximum(BBGT[:, 1], bb[1])
ixmax = np.minimum(BBGT[:, 2], bb[2])
iymax = np.minimum(BBGT[:, 3], bb[3])
iw = np.maximum(ixmax - ixmin + 1., 0.)
ih = np.maximum(iymax - iymin + 1., 0.)
# 計算交叉的面積
inters = iw * ih
# 計算iou吧
uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.)
+ (BBGT[:, 2] - BBGT[:, 0] + 1.0) * (BBGT[:, 3] - BBGT[:, 1] + 1.0)
- inters)
overlaps = inters / uni
ovmax = np.max(overlaps)
jmax = np.argmax(overlaps)
print("overlaps is: ", overlaps)
print("ovmax is: ", ovmax)
print("jmax is: ", jmax)
if ovmax > OVTHRESH:
# 如果檢測的這個標記還沒有啟用,預設是False
if not R['det'][jmax]:
tp[d] = 1.
R['det'][jmax] = 1
else:
fp[d] = 1.
wrong_count += 1
else:
fp[d] = 1.
wrong_count += 1
np.set_printoptions(threshold=np.inf)
# 計算 precision 和 recall
fp = np.cumsum(fp)
tp = np.cumsum(tp)
print("fp is: ", fp)
print("tp is: ", tp)
# 召回率(查全率)
rec = tp / float(npos)
# 精確率(查準率)
prec = tp / np.maximum(tp + fp, np.finfo(np.float).eps)
ap = voc_ap(rec, prec, False)
print("ap is: ", ap)
print("*" * 80)
print("RESULTS: \n")
print("Total %d images, %d objects" % (len(files_gt), npos))
print("Detected Correct: %d, Wrong: %d, Miss: %d under IOU: %f"
% (nd - wrong_count, wrong_count, npos - (nd - wrong_count), OVTHRESH))
print("Accuracy %f, Recall %f, Average Precision %f"
% (float(nd - wrong_count) / (nd), float(nd - wrong_count) / (npos), ap))
# 記錄漏檢的檔案
f = open('./lost.txt', 'w')
for k, v in class_recs.items():
if False in v['det']:
f.write(str(k) + '.jpg' + '\n')
f.close()
if __name__ == "__main__":
gt_root = './mini_test/gt/'
pred_root = './mini_test/res/'
ComputeMAP(gt_root, pred_root)
LZ就不詳細講程式碼了,註釋已經很詳細了,主要是你的gt應該是什麼樣子的呢?
- 命名標準:img_name.txt
- gt格式:
# x1 y1 x2 y2
965 209 1040 329
- res格式:
# score x1 y1 x2 y2
0.9999481 962 222 1043 331
0.9999091 635 251 747 412
0.9783503 1795 340 1836 402
0.57386667 1730 305 1748 337
這個是結果展示,程式碼中LZ為了清晰加了非常多的列印,誰讓雲端儲存不穩定呢,動不動圖片就被損壞了,哭唧唧。。。
ps:最近疫情反彈的厲害,誰能想到新冠肺炎居然堅持了一年,國外疫情也是指數性增長,這算是人類的災難,也許多年後在看現在,又會有不一樣的體會。珍惜當下,愛惜生命!