1. 程式人生 > 其它 >pytorch自定義運算元

pytorch自定義運算元

參照官方教程,實現pytorch自定義運算元。主要分為以下幾步:

  • 改寫運算元為torch C++版本
  • 註冊運算元
  • 編譯運算元生成庫檔案
  • 呼叫自定義運算元

一、改寫運算元

這裡參照官網例子,結合openCV實現仿射變換,C++程式碼如下:

點選展開warpPerspective.cpp
#include "torch/script.h"
#include "opencv2/opencv.hpp"

torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {
    // BEGIN image_mat
    cv::Mat image_mat(/*rows=*/image.size(0),
        /*cols=*/image.size(1),
        /*type=*/CV_32FC1,
        /*data=*/image.data_ptr<float>());
    // END image_mat

    // BEGIN warp_mat
    cv::Mat warp_mat(/*rows=*/warp.size(0),
        /*cols=*/warp.size(1),
        /*type=*/CV_32FC1,
        /*data=*/warp.data_ptr<float>());
    // END warp_mat

    // BEGIN output_mat
    cv::Mat output_mat;
    cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{ 8, 8 });
    // END output_mat

    // BEGIN output_tensor
    torch::Tensor output = torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{ 8, 8 });
    return output.clone();
    // END output_tensor
}

二、註冊運算元

在warpPerspective.cpp檔案末尾即warp_perspective函式後面加入如下程式碼,注意pytorch版本不同,註冊方式不一樣。 1.6.0及以後的版本在include/torch/目錄下才有library.h檔案,可以採用TORCH_LIBRARY。而之前的版本可以採用torch::RegisterOperators

//static auto registry = torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);  // torch.__version__: 1.5.0

//// torch.__version__ >= 1.6.0  torch/include/torch/library.h
TORCH_LIBRARY(my_ops, m) {
    m.def("warp_perspective", warp_perspective);
}

三、編譯運算元生成庫檔案

編譯成庫檔案有三種方式:

方式一:通過CMake編譯

方式二: 通過torch的JIT編譯

方式三:通過Setuptools編譯

方式一、CMake編譯

這裡分別在win10和Ubuntu18.04下進行編譯,CMakeLists.txt檔案如下,注意win10下需要把相關依賴庫拷貝到相應生成目錄,後面呼叫的時候才能正常執行。

點選展開CMakeLists.txt
# ref: https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html

cmake_minimum_required(VERSION 3.9 FATAL_ERROR)
project(warp_perspective)

set(CMAKE_VERBOSE_MAKEFILE ON)
# >>> build type 
set(CMAKE_BUILD_TYPE "Release")				# 指定生成的版本
set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -Wall -g2 -ggdb")
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O3 -Wall")
# <<<

if(WIN32)
    # windows10
    # Torch
    set(TORCH_ROOT "D:/Anaconda3/envs/Test374/Lib/site-packages/torch")  # 我這裡採用虛擬環境安裝的pytorch
    include_directories(${TORCH_ROOT}/include)
    link_directories(${TORCH_ROOT}/lib/)
    set(TORCH_LIBRARIES "${TORCH_ROOT}/lib/*.lib")   # 可以選擇需要的庫

    # Opencv
    set(OPENCV_ROOT "D:/AI/Classify/C++/opencv")
    include_directories(${OPENCV_ROOT}/include)
    link_directories(${OPENCV_ROOT}/lib/x64/)

    # Define our library target
    add_library(warp_perspective SHARED warpPerspective.cpp)

    # Enable C++14
    target_compile_features(warp_perspective PRIVATE cxx_std_14)

    # Link against Torch
    target_link_libraries(warp_perspective "${TORCH_LIBRARIES}")

    # Link against OpenCV
    target_link_libraries(warp_perspective 
    	opencv_world420
    )
elseif(UNIX)
    # Ubuntu18.04
    # Torch
    set(TORCH_ROOT "/home/zjh/anaconda3/envs/Test374/lib/python3.7/site-packages/torch")   
    include_directories(${TORCH_ROOT}/include)
    link_directories(${TORCH_ROOT}/lib/)

    # Opencv
    set(OpenCV_DIR "/home/zjh/learn/libtorch/Examples/opencv")
    include_directories(${OpenCV_DIR}/include)
    link_directories(${OpenCV_DIR}/lib/Linux64/)

    # Define our library target
    add_library(warp_perspective SHARED warpPerspective.cpp)

    # Enable C++14
    target_compile_features(warp_perspective PRIVATE cxx_std_14)

    # libtorch庫檔案
    target_link_libraries(warp_perspective 
        # CPU
        c10 
        torch_cpu
        # GPU
        c10_cuda 
        torch_cuda
    )

    # opencv庫檔案
    target_link_libraries(warp_perspective
        opencv_core 
        opencv_imgproc
    )
endif()

# windows需要把相關依賴庫copy到編譯目錄下
if (MSVC)
  file(GLOB OPENCV_DLLS "${OPENCV_ROOT}/bin/x64/opencv_world420.dll")
  add_custom_command(TARGET warp_perspective
                     POST_BUILD
                     COMMAND ${CMAKE_COMMAND} -E copy_if_different
                     ${OPENCV_DLLS}
                     $<TARGET_FILE_DIR:warp_perspective>)
endif (MSVC)

方式二、jit compilation

該方式在Linux下需要將opencv相關的庫檔案,放置在/usr/local/lib下才能執行通過,編寫jitCompilation.py如下,然後執行即可生成相應的庫檔案。

點選展開jitCompilation.py
import torch.utils.cpp_extension

torch.utils.cpp_extension.load(
    name="warp_perspective",
    sources=["warpPerspective.cpp"],
    extra_ldflags=["-lopencv_core", "-lopencv_imgproc"],
    is_python_module=False,
    verbose=True,
    extra_include_paths=["/home/learn/libtorch/opencv/include"],
)
print(torch.ops.my_ops.warp_perspective)

print(torch.ops.my_ops.warp_perspective(torch.randn(32, 32), torch.rand(3, 3)))

方式三、setupTools

編寫setup.py,然後執行命令python setup.py build develop生成對應的庫。

點選展開setup.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name="warp_perspective",
    ext_modules=[
        CppExtension(
            "warp_perspective",
            ["warp_perspective.cpp"],
            libraries=["opencv_core", "opencv_imgproc"],
            include_dirs=["/home/learn/opencv/include"],
            library_dirs=["/home/learn/opencv/lib/Linux64"]
        )
    ],
    cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)},
)

四、呼叫

  • windows10

利用CMake方式構建後在build資料夾下會有一個.sln檔案,用visual studio開啟後點擊生成,相應的Release資料夾下會生成warp_perspective.dll

import torch
print(torch.__version__)
torch.ops.load_library("./warp_perspective.dll")
print(torch.__version__)
print(torch.ops.my_ops.warp_perspective)
print(torch.ops.my_ops.warp_perspective(torch.randn(32, 32), torch.rand(3, 3)))
  • Linux

編譯完成後會生成相應的動態庫(so檔案),呼叫結果如下。

注意: 直接執行可能會出現以下錯誤

五、轉onnx

主要是把自定義運算元利用torch.onnx.register_custom_op_symbolic函式將自定義運算元註冊進行註冊,然後匯出onnx模型即可。如果用onnxruntime呼叫匯出的模型,則會報test_custom未定義,可以參照PyTorchCustomOperator進行改寫。

點選展開export.py
import torch
torch.ops.load_library("./testCustom.so")


class MyNet(torch.nn.Module):
    def __init__(self, num_classes):
        super(MyNet, self).__init__()
        self.num_classes = num_classes

    def forward(self, xyz, other):
        return torch.ops.my_ops.test_custom(xyz, other)


def my_custom(g, xyz, other):
    return g.op("cus_ops::test_custom", xyz, other)
torch.onnx.register_custom_op_symbolic("my_ops::test_custom", my_custom, 9)


if __name__ == "__main__":
    net = MyNet(2)
    xyz = torch.rand((2, 3))
    other = torch.rand((1, 3))

    print("xyz: ", xyz)
    out = net(xyz, other)
    print("out: ", out)

    # export onnx
    torch.onnx.export(net,
            (xyz, other),
            "./model.onnx",
            input_names=["points", "cate"],
            output_names=["cls_prob"],
            custom_opsets={"cus_ops": 11},
            dynamic_axes={
                "points": {0: "channel", 1: "n_point"},
                "cls_prob": {0: "channel", 1: "n"}
            }
            )
參考連結: https://blog.csdn.net/Artyze/article/details/107642358

參考連結

PyTorchCustomOperator
register-a-custom-operator