mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
Add example of registering custom cuda op as shared lib (#10025)
This commit is contained in:
parent
2078210a1c
commit
2803a9465d
5 changed files with 40 additions and 3 deletions
|
|
@ -1165,7 +1165,13 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
|
|||
endif()
|
||||
endif()
|
||||
|
||||
onnxruntime_add_shared_library_module(custom_op_library ${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.cc)
|
||||
if (onnxruntime_USE_CUDA)
|
||||
onnxruntime_add_shared_library_module(custom_op_library ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu
|
||||
${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.cc)
|
||||
target_include_directories(custom_op_library PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
|
||||
else()
|
||||
onnxruntime_add_shared_library_module(custom_op_library ${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.cc)
|
||||
endif()
|
||||
target_include_directories(custom_op_library PRIVATE ${REPO_ROOT}/include)
|
||||
if(UNIX)
|
||||
if (APPLE)
|
||||
|
|
@ -1175,8 +1181,10 @@ if(UNIX)
|
|||
endif()
|
||||
else()
|
||||
set(ONNXRUNTIME_CUSTOM_OP_LIB_LINK_FLAG "-DEF:${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.def")
|
||||
target_compile_options(custom_op_library PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler /wd26409>"
|
||||
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd26409>")
|
||||
if (NOT onnxruntime_USE_CUDA)
|
||||
target_compile_options(custom_op_library PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler /wd26409>"
|
||||
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd26409>")
|
||||
endif()
|
||||
endif()
|
||||
set_property(TARGET custom_op_library APPEND_STRING PROPERTY LINK_FLAGS ${ONNXRUNTIME_CUSTOM_OP_LIB_LINK_FLAG})
|
||||
|
||||
|
|
|
|||
|
|
@ -565,6 +565,12 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
string libFullPath = Path.Combine(Directory.GetCurrentDirectory(), libName);
|
||||
Assert.True(File.Exists(libFullPath), $"Expected lib {libFullPath} does not exist.");
|
||||
|
||||
var ortEnvInstance = OrtEnv.Instance();
|
||||
string[] providers = ortEnvInstance.GetAvailableProviders();
|
||||
if (Array.Exists(providers, provider => provider == "CUDAExecutionProvider")) {
|
||||
option.AppendExecutionProvider_CUDA(0);
|
||||
}
|
||||
|
||||
IntPtr libraryHandle = IntPtr.Zero;
|
||||
try
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1129,6 +1129,9 @@ public class InferenceTest {
|
|||
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testLoadCustomLibrary");
|
||||
SessionOptions options = new SessionOptions()) {
|
||||
options.registerCustomOpLibrary(customLibraryName);
|
||||
if (OnnxRuntime.extractCUDA()) {
|
||||
options.addCUDA();
|
||||
}
|
||||
try (OrtSession session = env.createSession(customOpLibraryTestModel, options)) {
|
||||
Map<String, OnnxTensor> container = new HashMap<>();
|
||||
|
||||
|
|
|
|||
|
|
@ -778,8 +778,13 @@ lib_name = "./libcustom_op_library.so";
|
|||
#endif
|
||||
|
||||
void* library_handle = nullptr;
|
||||
#ifdef USE_CUDA
|
||||
TestInference<int32_t>(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y,
|
||||
expected_values_y, 1, nullptr, lib_name.c_str(), &library_handle);
|
||||
#else
|
||||
TestInference<int32_t>(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y,
|
||||
expected_values_y, 0, nullptr, lib_name.c_str(), &library_handle);
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
bool success = ::FreeLibrary(reinterpret_cast<HMODULE>(library_handle));
|
||||
|
|
|
|||
|
|
@ -8,6 +8,12 @@
|
|||
#include <cmath>
|
||||
#include <mutex>
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
template <typename T1, typename T2, typename T3>
|
||||
void cuda_add(int64_t, T3*, const T1*, const T2*, cudaStream_t compute_stream);
|
||||
#endif
|
||||
|
||||
static const char* c_OpDomain = "test.customop";
|
||||
|
||||
struct OrtCustomOpDomainDeleter {
|
||||
|
|
@ -63,9 +69,14 @@ struct KernelOne {
|
|||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
|
||||
// Do computation
|
||||
#ifdef USE_CUDA
|
||||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(ort_.KernelContext_GetGPUComputeStream(context));
|
||||
cuda_add(size, out, X, Y, stream);
|
||||
#else
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
out[i] = X[i] + Y[i];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -112,6 +123,10 @@ struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
|
|||
|
||||
const char* GetName() const { return "CustomOpOne"; };
|
||||
|
||||
#ifdef USE_CUDA
|
||||
const char* GetExecutionProviderType() const { return "CUDAExecutionProvider"; };
|
||||
#endif
|
||||
|
||||
size_t GetInputTypeCount() const { return 2; };
|
||||
ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue