opencv-python (13): Loading a Caffe-trained SSD model with the DNN module
阿新 • Published: 2018-12-09
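OpenCV keeps getting more capable: its DNN module can directly load models trained with frameworks such as Caffe and TensorFlow and then use them for tasks such as recognition and detection.
To load a Caffe-trained model, OpenCV provides readNetFromCaffe(arg1, arg2). The first argument is the prototxt file that defines the network structure, and the second is the trained model (the .caffemodel file). Once the network is loaded, blobFromImage converts an image into a blob. After the blob is set as the network input, forward() runs a forward pass and returns the network's output. Detection on video works much the same way: a video is a sequence of frames, so running detection on each frame yields the detection result for the whole video.
The code is as follows: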
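To make the blob step concrete, here is a minimal sketch in plain Python (no OpenCV) of the arithmetic that blobFromImage applies when it is given a scale factor of 0.007843 and a mean of 127.5, the values used in the script in this post. The mean is subtracted from each pixel, the result is multiplied by the scale factor, and the H x W x C image is rearranged into a 1 x C x H x W layout. The helper names normalize_pixel and to_blob are hypothetical, and resizing is left out; this only illustrates the math and is not how OpenCV implements it.

```python
SCALE = 0.007843  # roughly 1 / 127.5
MEAN = 127.5      # midpoint of the 0..255 pixel range


def normalize_pixel(value, scale=SCALE, mean=MEAN):
    """Map a pixel value from 0..255 to roughly -1..1."""
    return (value - mean) * scale


def to_blob(image):
    """Convert an H x W x C nested list into a 1 x C x H x W nested list.

    Hypothetical helper: mirrors the layout change and normalization only.
    """
    height = len(image)
    width = len(image[0])
    channels = len(image[0][0])
    return [[
        [[normalize_pixel(image[y][x][c]) for x in range(width)]
         for y in range(height)]
        for c in range(channels)
    ]]


# A tiny 1 x 2 image with 3 channels per pixel.
tiny_image = [[[0, 127.5, 255], [255, 255, 255]]]
blob = to_blob(tiny_image)
print(len(blob), len(blob[0]), len(blob[0][0]), len(blob[0][0][0]))
print(blob[0][0][0][0], blob[0][1][0][0], blob[0][2][0][0])
```

With these constants a black pixel lands near -1, a mid-gray pixel at 0 and a white pixel near 1, which is the input range the MobileNet SSD model in this post appears to expect.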
from imutils.video import VideoStream
from imutils.video import FPS
import numpy as np
import argparse
import imutils
import time
import cv2

# parse the command-line arguments
ap = argparse.ArgumentParser()
ap.add_argument("-p", "--prototxt", required=True,
                help="path to Caffe 'deploy' prototxt file")
ap.add_argument("-m", "--model", required=True,
                help="path to Caffe pre-trained model")
ap.add_argument("-c", "--confidence", type=float, default=0.2,
                help="minimum probability to filter weak detections")
args = vars(ap.parse_args())

# initialize the list of class labels MobileNet SSD was trained to
# detect, then generate a set of bounding box colors for each class
CLASSES = ["background", "aeroplane", "bicycle", "bird", "boat",
           "bottle", "bus", "car", "cat", "chair", "cow", "diningtable",
           "dog", "horse", "motorbike", "person", "pottedplant", "sheep",
           "sofa", "train", "tvmonitor"]
COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))

# load the serialized model from disk
print("[INFO] loading model...")
net = cv2.dnn.readNetFromCaffe(args["prototxt"], args["model"])

# initialize video stream, allow camera sensor to warm up
# and initialize FPS counter
video_stream = VideoStream(src=0).start()
time.sleep(2.0)
fps = FPS().start()

while True:
    # grab the frame from threaded video stream and resize it
    # to a max width of 400
    frame = video_stream.read()
    frame = imutils.resize(frame, width=400)

    # grab the frame dimensions and convert it to blob
    (h, w) = frame.shape[:2]
    blob = cv2.dnn.blobFromImage(cv2.resize(frame, (300, 300)),
                                 0.007843, (300, 300), 127.5)

    # pass the blob through network
    net.setInput(blob)
    detections = net.forward()

    # loop over the detections
    for i in np.arange(0, detections.shape[2]):
        # extract the confidence
        confidence = detections[0, 0, i, 2]

        # filter weak detections
        if confidence > args["confidence"]:
            # extract index of class label, then scale the normalized
            # box coordinates back to the frame size
            idx = int(detections[0, 0, i, 1])
            box = detections[0, 0, i, 3:7] * np.array([w, h, w, h])
            (startX, startY, endX, endY) = box.astype("int")

            # draw predictions in the frame
            label = "{}: {:.2f}%".format(CLASSES[idx], confidence * 100)
            cv2.rectangle(frame, (startX, startY), (endX, endY),
                          COLORS[idx], 2)
            y = startY - 15 if startY - 15 > 15 else startY + 15
            cv2.putText(frame, label, (startX, y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, COLORS[idx], 2)

    # show the output frame
    cv2.imshow("frame", frame)
    key = cv2.waitKey(1) & 0xFF

    # if `q` is pressed, break from loop
    if key == ord('q'):
        break

    # update the fps counter
    fps.update()

# stop the timer and display FPS information
fps.stop()
print("[INFO] elapsed time: {:.2f}".format(fps.elapsed()))
print("[INFO] approx. FPS: {:.2f}".format(fps.fps()))

# cleanup
cv2.destroyAllWindows()
video_stream.stop()
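The layout of the array returned by forward() is easy to lose inside the loop. As a rough sketch, assume the usual SSD detection output, in which each row holds [image_id, class_id, confidence, x1, y1, x2, y2] with box corners normalized to 0..1; the script above reads indices 1, 2 and 3:7 in this way. Under that assumption the filtering and scaling can be written in plain Python. The function name decode_detections and the sample rows are made up for illustration.

```python
CLASSES = ["background", "aeroplane", "bicycle", "bird", "boat",
           "bottle", "bus", "car", "cat", "chair", "cow", "diningtable",
           "dog", "horse", "motorbike", "person", "pottedplant", "sheep",
           "sofa", "train", "tvmonitor"]


def decode_detections(rows, frame_w, frame_h, class_names, min_conf=0.2):
    """Turn raw detection rows into (label, confidence, pixel_box) tuples.

    Hypothetical helper. Each row is assumed to be
    [image_id, class_id, confidence, x1, y1, x2, y2], with the box
    corners expressed as fractions of the frame width and height.
    """
    results = []
    for row in rows:
        _image_id, class_id, confidence, x1, y1, x2, y2 = row
        if confidence <= min_conf:
            continue  # drop weak detections
        # scale normalized corners to pixels, truncating toward zero
        box = (int(x1 * frame_w), int(y1 * frame_h),
               int(x2 * frame_w), int(y2 * frame_h))
        results.append((class_names[int(class_id)], confidence, box))
    return results


# Made-up rows: one confident "person" and one weak "cat".
sample_rows = [
    [0, 15, 0.9, 0.25, 0.5, 0.75, 1.0],
    [0, 8, 0.1, 0.0, 0.0, 1.0, 1.0],
]
decoded = decode_detections(sample_rows, 400, 300, CLASSES)
for label, confidence, box in decoded:
    print("{}: {:.2f}% at {}".format(label, confidence * 100, box))
```

Keeping this step in a separate function makes it possible to check the thresholding and coordinate scaling without a camera or a model file.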
Run the command:
python object_detection.py -p MobileNetSSD_deploy.prototxt.txt -m MobileNetSSD_deploy.caffemodel
The result looks like this: