Tensorflow2.X版本keras模型輸出儲存為frozen graph使用OpenCV呼叫
阿新 • • 發佈:2020-12-10
技術標籤:AI/ML/DL
環境:
- windows10 64bit
- python: 3.7
- opencv 4.2.0
- tensorflow: 2.1
**目的:**利用opencv中的dnn模組對tensorflow模型進行載入。
opencv的dnn模組有函式dnn.readNetFromTensorflow
,根據函式文件可知是呼叫pb格式的tensorflow模型,這裡就入坑了,tensorflow儲存的檔案格式多種多種:TFLite, frozen graph, SavedModel, serving model, TFHub representation, Keras's .h5
,tensorflow2.X版本之後推薦使用keras,原版keras預設儲存的模型檔案是.h5
tf.keras
模型的save方法預設儲存格式是tensorflow的SavedModel
(可以通過引數save_format
控制),這種方法報道模型也有一個pb
檔案:但如果使用dnn模組的相關介面去呼叫是無法正確讀入模型的,經過查閱之後發現需要儲存為
frozen graph
格式,目前網上搜到的大部分關於Keras模型轉Frozen grpah的教程所依賴的tensorflow版本都較老,筆者使用的tf版本為最新的2.1版本,最終找到了一個靠譜的方法:
https://leimao.github.io/blog/Save-Load-Inference-From-TF2-Frozen-Graph/
對應的程式碼在https://github.com/leimao/Frozen_Graph_TensorFlow/blob/master/TensorFlow_v2/train.py
下面筆者提供一份簡單程式碼來演示opencv如何載入tensorflow(keras)模型:
from cv2 import dnn import cv2 import numpy as np import matplotlib.pyplot as plt import os from keras import backend as K from keras.models import load_model #from tensorflow_serving.session_bundle import exporter from keras.models import model_from_config from keras.models import Sequential,Model import tensorflow as tf from tensorflow.python.tools import freeze_graph from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 import os print(tf.__version__) print(cv2.__version__) #%% opencv處理適合模型輸入的圖片 img_file = r"C:\Users\zhou-\Pictures\cat.jpg" img_cv2 = cv2.imread(img_file) print("[INFO]Image shape: ", img_cv2.shape) # 主要圖片尺寸要和模型輸入匹配(mobilenet要求輸入的尺寸為224*224) inWidth = 224 inHeight = 224 blob = cv2.dnn.blobFromImage(img_cv2, scalefactor=1.0 / 255, size=(inWidth, inHeight), mean=(0, 0, 0), swapRB=False, crop=False) # blob = np.transpose(blob, (0,2,3,1)) # 適合keras mobilenet網路輸入格式 print("[INFO]img shape: ", blob.shape) #%% 儲存keras模型為SaveModel會報錯,相關issue見: # https://github.com/opencv/opencv/issues/16582 model = tf.keras.applications.mobilenet.MobileNet(weights='imagenet') # model.save('my_model', save_format='tf') # Save model to SavedModel format # 參考https://github.com/leimao/Frozen_Graph_TensorFlow/blob/master/TensorFlow_v2/train.py # Save model to SavedModel format # tf.saved_model.save(model, r"./models") # Convert Keras model to ConcreteFunction full_model = tf.function(lambda x: model(x)) full_model = full_model.get_concrete_function( tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)) # Get frozen ConcreteFunction frozen_func = convert_variables_to_constants_v2(full_model) frozen_func.graph.as_graph_def() layers = [op.name for op in frozen_func.graph.get_operations()] print("-" * 50) print("Frozen model layers: ") for layer in layers: print(layer) print("-" * 50) print("Frozen model inputs: ") print(frozen_func.inputs) print("Frozen model outputs: ") print(frozen_func.outputs) # Save frozen graph from frozen ConcreteFunction to hard drive tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir="./frozen_models", name="frozen_graph.pb", as_text=False) net = dnn.readNetFromTensorflow('frozen_models/frozen_graph.pb') # Run a model net.setInput(blob) out = net.forward() # Get a class with a highest score. out = out.flatten() classId = np.argmax(out) confidence = out[classId] # Put efficiency information. t, _ = net.getPerfProfile() label = 'Inference time: %.2f ms' % (t * 1000.0 / cv2.getTickFrequency()) cv2.putText(img_cv2, label, (0, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0)) # Print predicted class. def load_imagenet_classes(file_path): ''' imagenet對應的標籤資料如下所示: 0: 'tench, Tinca tinca', 1: 'goldfish, Carassius auratus', ... ''' classes = [] contents = None with open(file_path,'r') as f: contents = f.readlines() for cnt in contents: cnt = cnt.strip() classes.append(cnt.split(':')[1].strip().replace(',','')) return classes classes = load_imagenet_classes('imagenet_classes.txt') label = '%s: %.4f' % (classes[classId] if classes else 'Class #%d' % classId, confidence) cv2.putText(img_cv2, label, (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0)) cv2.imwrite('output-{}.png'.format(img_file.split('\\')[-1][:-4]), img_cv2)