Add example of registering custom cuda op as shared lib (#10025)

This commit is contained in:
Ye Wang 2022-01-05 09:22:15 -08:00 committed by GitHub
parent 2078210a1c
commit 2803a9465d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 40 additions and 3 deletions

View file

@ -1165,7 +1165,13 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
endif()
endif()
onnxruntime_add_shared_library_module(custom_op_library ${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.cc)
if (onnxruntime_USE_CUDA)
onnxruntime_add_shared_library_module(custom_op_library ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu
${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.cc)
target_include_directories(custom_op_library PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
else()
onnxruntime_add_shared_library_module(custom_op_library ${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.cc)
endif()
target_include_directories(custom_op_library PRIVATE ${REPO_ROOT}/include)
if(UNIX)
if (APPLE)
@ -1175,8 +1181,10 @@ if(UNIX)
endif()
else()
set(ONNXRUNTIME_CUSTOM_OP_LIB_LINK_FLAG "-DEF:${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.def")
target_compile_options(custom_op_library PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler /wd26409>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd26409>")
if (NOT onnxruntime_USE_CUDA)
target_compile_options(custom_op_library PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler /wd26409>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd26409>")
endif()
endif()
set_property(TARGET custom_op_library APPEND_STRING PROPERTY LINK_FLAGS ${ONNXRUNTIME_CUSTOM_OP_LIB_LINK_FLAG})

View file

@ -565,6 +565,12 @@ namespace Microsoft.ML.OnnxRuntime.Tests
string libFullPath = Path.Combine(Directory.GetCurrentDirectory(), libName);
Assert.True(File.Exists(libFullPath), $"Expected lib {libFullPath} does not exist.");
var ortEnvInstance = OrtEnv.Instance();
string[] providers = ortEnvInstance.GetAvailableProviders();
if (Array.Exists(providers, provider => provider == "CUDAExecutionProvider")) {
option.AppendExecutionProvider_CUDA(0);
}
IntPtr libraryHandle = IntPtr.Zero;
try
{

View file

@ -1129,6 +1129,9 @@ public class InferenceTest {
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testLoadCustomLibrary");
SessionOptions options = new SessionOptions()) {
options.registerCustomOpLibrary(customLibraryName);
if (OnnxRuntime.extractCUDA()) {
options.addCUDA();
}
try (OrtSession session = env.createSession(customOpLibraryTestModel, options)) {
Map<String, OnnxTensor> container = new HashMap<>();

View file

@ -778,8 +778,13 @@ lib_name = "./libcustom_op_library.so";
#endif
void* library_handle = nullptr;
#ifdef USE_CUDA
TestInference<int32_t>(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y,
expected_values_y, 1, nullptr, lib_name.c_str(), &library_handle);
#else
TestInference<int32_t>(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y,
expected_values_y, 0, nullptr, lib_name.c_str(), &library_handle);
#endif
#ifdef _WIN32
bool success = ::FreeLibrary(reinterpret_cast<HMODULE>(library_handle));

View file

@ -8,6 +8,12 @@
#include <cmath>
#include <mutex>
#ifdef USE_CUDA
#include <cuda_runtime.h>
template <typename T1, typename T2, typename T3>
void cuda_add(int64_t, T3*, const T1*, const T2*, cudaStream_t compute_stream);
#endif
static const char* c_OpDomain = "test.customop";
struct OrtCustomOpDomainDeleter {
@ -63,9 +69,14 @@ struct KernelOne {
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
// Do computation
#ifdef USE_CUDA
cudaStream_t stream = reinterpret_cast<cudaStream_t>(ort_.KernelContext_GetGPUComputeStream(context));
cuda_add(size, out, X, Y, stream);
#else
for (int64_t i = 0; i < size; i++) {
out[i] = X[i] + Y[i];
}
#endif
}
private:
@ -112,6 +123,10 @@ struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
const char* GetName() const { return "CustomOpOne"; };
#ifdef USE_CUDA
const char* GetExecutionProviderType() const { return "CUDAExecutionProvider"; };
#endif
size_t GetInputTypeCount() const { return 2; };
ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };