1. 程式人生 > 其它 >Tensorflow2.X版本keras模型輸出儲存為frozen graph使用OpenCV呼叫

Tensorflow2.X版本keras模型輸出儲存為frozen graph使用OpenCV呼叫

技術標籤:AI/ML/DL

環境:

  1. windows10 64bit
  2. python: 3.7
  3. opencv 4.2.0
  4. tensorflow: 2.1

**目的:**利用opencv中的dnn模組對tensorflow模型進行載入。

opencv的dnn模組有函式dnn.readNetFromTensorflow,根據函式文件可知是呼叫pb格式的tensorflow模型,這裡就入坑了,tensorflow儲存的檔案格式多種多種:TFLite, frozen graph, SavedModel, serving model, TFHub representation, Keras's .h5,tensorflow2.X版本之後推薦使用keras,原版keras預設儲存的模型檔案是.h5

格式的,而tf.keras模型的save方法預設儲存格式是tensorflow的SavedModel(可以通過引數save_format控制),這種方法報道模型也有一個pb檔案:
在這裡插入圖片描述
但如果使用dnn模組的相關介面去呼叫是無法正確讀入模型的,經過查閱之後發現需要儲存為frozen graph格式,目前網上搜到的大部分關於Keras模型轉Frozen grpah的教程所依賴的tensorflow版本都較老,筆者使用的tf版本為最新的2.1版本,最終找到了一個靠譜的方法:

https://leimao.github.io/blog/Save-Load-Inference-From-TF2-Frozen-Graph/

對應的程式碼在https://github.com/leimao/Frozen_Graph_TensorFlow/blob/master/TensorFlow_v2/train.py

下面筆者提供一份簡單程式碼來演示opencv如何載入tensorflow(keras)模型:

from cv2 import dnn
import cv2
import numpy as np 
import matplotlib.pyplot as plt
import os
from keras import backend as K
from keras.models import load_model
#from tensorflow_serving.session_bundle import exporter
from keras.models import model_from_config
from keras.models import Sequential,Model
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import os

print(tf.__version__)
print(cv2.__version__)
#%% opencv處理適合模型輸入的圖片
img_file = r"C:\Users\zhou-\Pictures\cat.jpg"
img_cv2 = cv2.imread(img_file)
print("[INFO]Image shape: ", img_cv2.shape)

# 主要圖片尺寸要和模型輸入匹配(mobilenet要求輸入的尺寸為224*224)
inWidth = 224
inHeight = 224
blob = cv2.dnn.blobFromImage(img_cv2,
                                scalefactor=1.0 / 255,
                                size=(inWidth, inHeight),
                                mean=(0, 0, 0),
                                swapRB=False,
                                crop=False)
# blob = np.transpose(blob, (0,2,3,1)) # 適合keras mobilenet網路輸入格式
print("[INFO]img shape: ", blob.shape)

#%% 儲存keras模型為SaveModel會報錯,相關issue見:
# https://github.com/opencv/opencv/issues/16582
model = tf.keras.applications.mobilenet.MobileNet(weights='imagenet')
# model.save('my_model', save_format='tf') # Save model to SavedModel format

# 參考https://github.com/leimao/Frozen_Graph_TensorFlow/blob/master/TensorFlow_v2/train.py

# Save model to SavedModel format
# tf.saved_model.save(model, r"./models")

# Convert Keras model to ConcreteFunction
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()

layers = [op.name for op in frozen_func.graph.get_operations()]
print("-" * 50)
print("Frozen model layers: ")
for layer in layers:
    print(layer)

print("-" * 50)
print("Frozen model inputs: ")
print(frozen_func.inputs)
print("Frozen model outputs: ")
print(frozen_func.outputs)

# Save frozen graph from frozen ConcreteFunction to hard drive
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir="./frozen_models",
                  name="frozen_graph.pb",
                  as_text=False)

net = dnn.readNetFromTensorflow('frozen_models/frozen_graph.pb')

# Run a model
net.setInput(blob)
out = net.forward()

# Get a class with a highest score.
out = out.flatten()
classId = np.argmax(out)
confidence = out[classId]

# Put efficiency information.
t, _ = net.getPerfProfile()
label = 'Inference time: %.2f ms' % (t * 1000.0 / cv2.getTickFrequency())
cv2.putText(img_cv2, label, (0, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))

# Print predicted class.
def load_imagenet_classes(file_path):
    '''
    imagenet對應的標籤資料如下所示:
    0: 'tench, Tinca tinca',
    1: 'goldfish, Carassius auratus',
    ...
    '''
    classes = []
    contents = None
    with open(file_path,'r') as f:
        contents = f.readlines()
    for cnt in contents:
        cnt = cnt.strip()
        classes.append(cnt.split(':')[1].strip().replace(',',''))
    
    return classes
        
classes = load_imagenet_classes('imagenet_classes.txt')

label = '%s: %.4f' % (classes[classId] if classes else 'Class #%d' % classId, confidence)
cv2.putText(img_cv2, label, (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))

cv2.imwrite('output-{}.png'.format(img_file.split('\\')[-1][:-4]), img_cv2)