1. 程式人生 > 其它 >功能性模組:(7)檢測效能評估模組(precision,recall等)

功能性模組:(7)檢測效能評估模組(precision,recall等)

技術標籤:功能性模組演算法python深度學習pytorch神經網路

功能性模組:(7)檢測效能評估模組

一、模組介紹

其實每個演算法的好壞都有對應的評估標準。如果你只和老闆說檢測演算法「好」或「不好」,哈哈哈,那必然就是悲劇了:好或不好只是定性的說法,對於實際演算法來說,到底怎樣算好、怎樣算不好,應該要有定量的標準。對於檢測來說,最常用的幾個評價指標是 precision(查準率,即你檢測出來的目標中有多少是真的目標)、recall(查全率,即實際存在的目標中你的演算法能檢測出多少),還有 AP(average precision,即 PR 曲線下的面積)、mAP(各類別 AP 的平均值)等。本篇部落格就是讓小夥伴們對自己的檢測模型心裡有個底,換句話說,就是弄清楚你訓練出來的這個模型到底咋樣。

二、程式碼實現

import numpy as np
import os

def voc_ap(rec, prec, use_07_metric=False):
    """Compute the PASCAL VOC average precision (AP) from a PR curve.

    Args:
        rec: 1-D sequence of cumulative recall values, one per detection,
            ordered by descending detection confidence.
        prec: 1-D sequence of cumulative precision values aligned with ``rec``.
        use_07_metric: If true, use the VOC 2007 11-point interpolation;
            otherwise (default) integrate the area under the monotonically
            decreasing precision envelope.

    Returns:
        The average precision as a float (0.0 for an empty curve).
    """
    rec = np.asarray(rec, dtype=float)
    prec = np.asarray(prec, dtype=float)

    if use_07_metric:
        # 11-point metric: average the best precision reachable at
        # recall >= t for t = 0.0, 0.1, ..., 1.0.
        ap = 0.
        for t in np.arange(0., 1.1, 0.1):
            reached = rec >= t
            p = np.max(prec[reached]) if reached.any() else 0.
            ap = ap + p / 11.
        return ap

    # Exact area: pad the curve with sentinel values at both ends.
    mrec = np.concatenate(([0.], rec, [1.]))
    mpre = np.concatenate(([0.], prec, [0.]))

    # Precision envelope: make precision non-increasing from right to left.
    for i in range(mpre.size - 1, 0, -1):
        mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

    # Only the points where recall changes contribute area.
    i = np.where(mrec[1:] != mrec[:-1])[0]

    # Sum (delta recall) * precision.
    ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap


def _read_rows(path):
    """Parse a whitespace-separated numeric text file into a list of float rows.

    Blank lines and lines starting with ``#`` are skipped so that a trailing
    newline or a column-legend header does not abort the evaluation.
    """
    rows = []
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            rows.append([float(z) for z in line.split()])
    return rows


def ComputeMAP(gt_root, predict_root, OVTHRESH=0.5):
    """Evaluate single-class detections against ground truth (VOC style).

    Every file in ``gt_root`` holds one ``x1 y1 x2 y2`` box per line; every
    file in ``predict_root`` holds one ``score x1 y1 x2 y2`` detection per
    line. Files are paired by base name (without extension). Progress and a
    final summary are printed, and the names of images that still have an
    undetected ground-truth box are written to ``./lost.txt``.

    :param gt_root: directory containing the ground-truth files
    :param predict_root: directory containing the detector output files
    :param OVTHRESH: IoU a detection must exceed to count as a true positive
    :return: the average precision computed by ``voc_ap`` (non-07 metric)
    """
    # Sorted so the evaluation order does not depend on the file system.
    files_gt = sorted(os.listdir(gt_root))
    files_pred = sorted(os.listdir(predict_root))

    # npos counts the ground-truth boxes over all images (recall denominator).
    npos = 0
    class_recs = {}

    # Load every ground-truth file.
    for name in files_gt:
        img_name = os.path.splitext(os.path.basename(name))[0]
        file_gt = os.path.join(gt_root, os.path.basename(name))

        print("*" * 80)
        print("img name is: ", img_name)
        print("gt file is: ", file_gt)

        bbox = np.array(_read_rows(file_gt), dtype=float)
        print("bbox is: \n", bbox)

        # det[j] records whether ground-truth box j has been matched yet.
        det = [False] * len(bbox)
        npos = npos + len(bbox)
        class_recs[img_name] = {'bbox': bbox, 'det': det}

    print("*" * 80)
    print("Total npos is: ", npos)

    # Load every detection file.
    img_ids = []
    confidence = []
    BB = []

    for name in files_pred:
        img_name = os.path.splitext(os.path.basename(name))[0]
        # Use the predict_root parameter, not a global that happens to exist.
        file_pred = os.path.join(predict_root, os.path.basename(name))

        print("*" * 80)
        print("img_name is: ", img_name)
        print("pred file is: ", file_pred)

        rows = _read_rows(file_pred)
        confidence_p = [row[0] for row in rows]
        bbox_p = [row[1:] for row in rows]

        # Repeat the image name once per detection so the three lists stay aligned.
        img_ids.extend([img_name] * len(confidence_p))
        confidence.extend(confidence_p)
        BB.extend(bbox_p)

    print(img_ids)
    print(confidence)
    print(BB)

    confidence = np.array(confidence, dtype=float)
    BB = np.array(BB, dtype=float)
    if BB.size == 0:
        # Keep a 2-D shape so the row indexing below works with no detections.
        BB = BB.reshape(0, 4)

    print("*" * 80)
    print("All files loaded!")

    # Sort detections by descending confidence (stable, so ties keep load order).
    sorted_idx = np.argsort(-confidence, kind="stable")
    print("sorted idx is: ", sorted_idx)
    BB = BB[sorted_idx, :]
    img_ids = [img_ids[x] for x in sorted_idx]

    # Mark each detection as a true positive or a false positive.
    nd = len(img_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    wrong_count = 0

    for d, (img_id, bb) in enumerate(zip(img_ids, BB)):
        print("We are now test: ", img_id)

        # A detection for an image without a gt file has nothing to match,
        # so it is scored as a false positive instead of raising KeyError.
        R = class_recs.get(img_id)
        BBGT = R['bbox'] if R is not None else np.zeros((0, 4))

        # Best overlap starts at -inf so "no gt boxes" never passes the threshold.
        ovmax = -np.inf
        jmax = -1

        print("bb: \n ", bb)
        print("BBGT: \n", BBGT)
        print("BBGT size is: ", BBGT.size)

        if BBGT.size > 0:
            # Intersection rectangle with every gt box (VOC +1 pixel convention).
            ixmin = np.maximum(BBGT[:, 0], bb[0])
            iymin = np.maximum(BBGT[:, 1], bb[1])
            ixmax = np.minimum(BBGT[:, 2], bb[2])
            iymax = np.minimum(BBGT[:, 3], bb[3])
            iw = np.maximum(ixmax - ixmin + 1., 0.)
            ih = np.maximum(iymax - iymin + 1., 0.)
            inters = iw * ih

            # Union = area(detection) + area(gt) - intersection.
            uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
                   (BBGT[:, 2] - BBGT[:, 0] + 1.0) *
                   (BBGT[:, 3] - BBGT[:, 1] + 1.0) - inters)

            overlaps = inters / uni
            ovmax = np.max(overlaps)
            jmax = np.argmax(overlaps)
            print("overlaps is: ", overlaps)
            print("ovmax is: ", ovmax)
            print("jmax is: ", jmax)

        if ovmax > OVTHRESH and not R['det'][jmax]:
            # First detection to claim this gt box.
            tp[d] = 1.
            R['det'][jmax] = True
        else:
            # Either the overlap is too small or the gt box was already claimed.
            fp[d] = 1.
            wrong_count += 1

    # Cumulative counts give one precision/recall point per detection.
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    # Print the full arrays without leaving global NumPy print options changed.
    with np.printoptions(threshold=np.inf):
        print("fp is: ", fp)
        print("tp is: ", tp)

    # Recall; defined as 0 when there are no ground-truth boxes at all.
    rec = tp / float(npos) if npos else np.zeros_like(tp)
    # Precision; eps guards the division (np.float was removed in NumPy 1.24).
    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)

    ap = voc_ap(rec, prec, False)
    print("ap is: ", ap)

    n_correct = nd - wrong_count
    final_precision = float(n_correct) / nd if nd else 0.0
    final_recall = float(n_correct) / npos if npos else 0.0

    print("*" * 80)
    print("RESULTS: \n")
    print("Total %d images, %d objects" % (len(files_gt), npos))
    print("Detected Correct: %d, Wrong: %d, Miss: %d under IOU: %f" % (n_correct, wrong_count,
                                                                       npos - n_correct, OVTHRESH))
    print("Accuracy %f, Recall %f, Average Precision %f" % (final_precision, final_recall, ap))

    # Record the images that still have at least one undetected gt box.
    with open('./lost.txt', 'w') as f:
        for k, v in class_recs.items():
            if False in v['det']:
                f.write(str(k) + '.jpg' + '\n')

    return ap


if __name__ == "__main__":
    gt_root = './mini_test/gt/'
    pred_root = './mini_test/res/'
    ComputeMAP(gt_root, pred_root)

LZ就不詳細講程式碼了,註釋已經很詳細了,主要是你的gt應該是什麼樣子的呢?

  • 命名標準:img_name.txt
  • gt格式(每行一個框;下面以 # 開頭的那一行只是欄位說明,實際檔案中只需要數值行):
# x1 y1 x2 y2
965 209 1040 329 
  • res格式(每行一個檢測結果;下面以 # 開頭的那一行只是欄位說明,實際檔案中只需要數值行):
# score x1 y1 x2 y2
0.9999481 962 222 1043 331
0.9999091 635 251 747 412
0.9783503 1795 340 1836 402
0.57386667 1730 305 1748 337

這個是結果展示,程式碼中LZ為了清晰加了非常多的列印,誰讓雲端儲存不穩定呢,動不動圖片就被損壞了,哭唧唧。。。

在這裡插入圖片描述
ps:最近疫情反彈的厲害,誰能想到新冠肺炎居然堅持了一年,國外疫情也是指數性增長,這算是人類的災難,也許多年後在看現在,又會有不一樣的體會。珍惜當下,愛惜生命!