pytorch自定義運算元
參照官方教程,實現pytorch自定義運算元。主要分為以下幾步:
- 改寫運算元為torch C++版本
- 註冊運算元
- 編譯運算元生成庫檔案
- 呼叫自定義運算元
一、改寫運算元
這裡參照官網例子,結合openCV實現仿射變換,C++程式碼如下:
點選展開warpPerspective.cpp
#include "torch/script.h" #include "opencv2/opencv.hpp" torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) { // BEGIN image_mat cv::Mat image_mat(/*rows=*/image.size(0), /*cols=*/image.size(1), /*type=*/CV_32FC1, /*data=*/image.data_ptr<float>()); // END image_mat // BEGIN warp_mat cv::Mat warp_mat(/*rows=*/warp.size(0), /*cols=*/warp.size(1), /*type=*/CV_32FC1, /*data=*/warp.data_ptr<float>()); // END warp_mat // BEGIN output_mat cv::Mat output_mat; cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{ 8, 8 }); // END output_mat // BEGIN output_tensor torch::Tensor output = torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{ 8, 8 }); return output.clone(); // END output_tensor }
二、註冊運算元
在warpPerspective.cpp檔案末尾即warp_perspective
函式後面加入如下程式碼,注意pytorch版本不同,註冊方式不一樣。 1.6.0及以後的版本在include/torch/
目錄下才有library.h檔案,可以採用TORCH_LIBRARY
。而之前的版本可以採用torch::RegisterOperators
。
//static auto registry = torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective); // torch.__version__: 1.5.0 //// torch.__version__ >= 1.6.0 torch/include/torch/library.h TORCH_LIBRARY(my_ops, m) { m.def("warp_perspective", warp_perspective); }
三、編譯運算元生成庫檔案
編譯成庫檔案有三種方式:
方式一:通過CMake編譯
方式二: 通過torch的JIT編譯
方式三:通過Setuptools編譯
方式一、CMake編譯
這裡分別在win10和Ubuntu18.04下進行編譯,CMakeLists.txt檔案如下,注意win10下需要把相關依賴庫拷貝到相應生成目錄,後面呼叫的時候才能正常執行。
點選展開CMakeLists.txt
# ref: https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html cmake_minimum_required(VERSION 3.9 FATAL_ERROR) project(warp_perspective) set(CMAKE_VERBOSE_MAKEFILE ON) # >>> build type set(CMAKE_BUILD_TYPE "Release") # 指定生成的版本 set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -Wall -g2 -ggdb") set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O3 -Wall") # <<< if(WIN32) # windows10 # Torch set(TORCH_ROOT "D:/Anaconda3/envs/Test374/Lib/site-packages/torch") # 我這裡採用虛擬環境安裝的pytorch include_directories(${TORCH_ROOT}/include) link_directories(${TORCH_ROOT}/lib/) set(TORCH_LIBRARIES "${TORCH_ROOT}/lib/*.lib") # 可以選擇需要的庫 # Opencv set(OPENCV_ROOT "D:/AI/Classify/C++/opencv") include_directories(${OPENCV_ROOT}/include) link_directories(${OPENCV_ROOT}/lib/x64/) # Define our library target add_library(warp_perspective SHARED warpPerspective.cpp) # Enable C++14 target_compile_features(warp_perspective PRIVATE cxx_std_14) # Link against Torch target_link_libraries(warp_perspective "${TORCH_LIBRARIES}") # Link against OpenCV target_link_libraries(warp_perspective opencv_world420 ) elseif(UNIX) # Ubuntu18.04 # Torch set(TORCH_ROOT "/home/zjh/anaconda3/envs/Test374/lib/python3.7/site-packages/torch") include_directories(${TORCH_ROOT}/include) link_directories(${TORCH_ROOT}/lib/) # Opencv set(OpenCV_DIR "/home/zjh/learn/libtorch/Examples/opencv") include_directories(${OpenCV_DIR}/include) link_directories(${OpenCV_DIR}/lib/Linux64/) # Define our library target add_library(warp_perspective SHARED warpPerspective.cpp) # Enable C++14 target_compile_features(warp_perspective PRIVATE cxx_std_14) # libtorch庫檔案 target_link_libraries(warp_perspective # CPU c10 torch_cpu # GPU c10_cuda torch_cuda ) # opencv庫檔案 target_link_libraries(warp_perspective opencv_core opencv_imgproc ) endif() # windows需要把相關依賴庫copy到編譯目錄下 if (MSVC) file(GLOB OPENCV_DLLS "${OPENCV_ROOT}/bin/x64/opencv_world420.dll") add_custom_command(TARGET warp_perspective POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different ${OPENCV_DLLS} $<TARGET_FILE_DIR:warp_perspective>) endif (MSVC)
方式二、jit compilation
該方式在Linux
下需要將opencv相關的庫檔案,放置在/usr/local/lib
下才能執行通過,編寫jitCompilation.py
如下,然後執行即可生成相應的庫檔案。
點選展開jitCompilation.py
import torch.utils.cpp_extension
torch.utils.cpp_extension.load(
name="warp_perspective",
sources=["warpPerspective.cpp"],
extra_ldflags=["-lopencv_core", "-lopencv_imgproc"],
is_python_module=False,
verbose=True,
extra_include_paths=["/home/learn/libtorch/opencv/include"],
)
print(torch.ops.my_ops.warp_perspective)
print(torch.ops.my_ops.warp_perspective(torch.randn(32, 32), torch.rand(3, 3)))
方式三、setupTools
編寫setup.py
,然後執行命令python setup.py build develop
生成對應的庫。
點選展開setup.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
setup(
name="warp_perspective",
ext_modules=[
CppExtension(
"warp_perspective",
["warp_perspective.cpp"],
libraries=["opencv_core", "opencv_imgproc"],
include_dirs=["/home/learn/opencv/include"],
library_dirs=["/home/learn/opencv/lib/Linux64"]
)
],
cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)},
)
四、呼叫
- windows10
利用CMake方式構建後在build資料夾下會有一個.sln檔案,用visual studio開啟後點擊生成,相應的Release資料夾下會生成warp_perspective.dll
。
import torch
print(torch.__version__)
torch.ops.load_library("./warp_perspective.dll")
print(torch.__version__)
print(torch.ops.my_ops.warp_perspective)
print(torch.ops.my_ops.warp_perspective(torch.randn(32, 32), torch.rand(3, 3)))
- Linux
編譯完成後會生成相應的動態庫(so檔案),呼叫結果如下。
注意: 直接執行可能會出現以下錯誤
-
錯誤
解決措施
手動把opencv庫放置在同一目錄下,或者在CMakeLists.txt檔案中加入if (MSVC) *** endif (MSVC)
部分
-
錯誤
解決措施:
CMakeLists.txt檔案中加入add_definitions(-D _GLIBCXX_USE_CXX11_ABI=0)
參考連結:https://discuss.pytorch.org/t/undefined-symbol-when-import-lltm-cpp-extension/32627/2
五、轉onnx
主要是把自定義運算元利用torch.onnx.register_custom_op_symbolic
函式將自定義運算元註冊進行註冊,然後匯出onnx模型即可。如果用onnxruntime呼叫匯出的模型,則會報test_custom
未定義,可以參照PyTorchCustomOperator進行改寫。
點選展開export.py
import torch
torch.ops.load_library("./testCustom.so")
class MyNet(torch.nn.Module):
def __init__(self, num_classes):
super(MyNet, self).__init__()
self.num_classes = num_classes
def forward(self, xyz, other):
return torch.ops.my_ops.test_custom(xyz, other)
def my_custom(g, xyz, other):
return g.op("cus_ops::test_custom", xyz, other)
torch.onnx.register_custom_op_symbolic("my_ops::test_custom", my_custom, 9)
if __name__ == "__main__":
net = MyNet(2)
xyz = torch.rand((2, 3))
other = torch.rand((1, 3))
print("xyz: ", xyz)
out = net(xyz, other)
print("out: ", out)
# export onnx
torch.onnx.export(net,
(xyz, other),
"./model.onnx",
input_names=["points", "cate"],
output_names=["cls_prob"],
custom_opsets={"cus_ops": 11},
dynamic_axes={
"points": {0: "channel", 1: "n_point"},
"cls_prob": {0: "channel", 1: "n"}
}
)