diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 11c94655e3..aa94e05a26 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1165,7 +1165,13 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() endif() -onnxruntime_add_shared_library_module(custom_op_library ${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.cc) +if (onnxruntime_USE_CUDA) + onnxruntime_add_shared_library_module(custom_op_library ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu + ${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.cc) + target_include_directories(custom_op_library PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +else() + onnxruntime_add_shared_library_module(custom_op_library ${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.cc) +endif() target_include_directories(custom_op_library PRIVATE ${REPO_ROOT}/include) if(UNIX) if (APPLE) @@ -1175,8 +1181,10 @@ if(UNIX) endif() else() set(ONNXRUNTIME_CUSTOM_OP_LIB_LINK_FLAG "-DEF:${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.def") - target_compile_options(custom_op_library PRIVATE "$<$:-Xcompiler /wd26409>" - "$<$>:/wd26409>") + if (NOT onnxruntime_USE_CUDA) + target_compile_options(custom_op_library PRIVATE "$<$:-Xcompiler /wd26409>" + "$<$>:/wd26409>") + endif() endif() set_property(TARGET custom_op_library APPEND_STRING PROPERTY LINK_FLAGS ${ONNXRUNTIME_CUSTOM_OP_LIB_LINK_FLAG}) diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs index e60627158c..6e3c9c5f89 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs @@ -565,6 +565,12 @@ namespace Microsoft.ML.OnnxRuntime.Tests string libFullPath = Path.Combine(Directory.GetCurrentDirectory(), libName); Assert.True(File.Exists(libFullPath), $"Expected lib {libFullPath} does not exist."); + var ortEnvInstance = OrtEnv.Instance(); + string[] providers = ortEnvInstance.GetAvailableProviders(); + if (Array.Exists(providers, provider => provider == "CUDAExecutionProvider")) { + option.AppendExecutionProvider_CUDA(0); + } + IntPtr libraryHandle = IntPtr.Zero; try { diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 0290258aeb..2fdc0b506d 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -1129,6 +1129,9 @@ public class InferenceTest { try (OrtEnvironment env = OrtEnvironment.getEnvironment("testLoadCustomLibrary"); SessionOptions options = new SessionOptions()) { options.registerCustomOpLibrary(customLibraryName); + if (OnnxRuntime.extractCUDA()) { + options.addCUDA(); + } try (OrtSession session = env.createSession(customOpLibraryTestModel, options)) { Map container = new HashMap<>(); diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 70878c53d3..00a42f1af1 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -778,8 +778,13 @@ lib_name = "./libcustom_op_library.so"; #endif void* library_handle = nullptr; +#ifdef USE_CUDA + TestInference(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y, + expected_values_y, 1, nullptr, lib_name.c_str(), &library_handle); +#else TestInference(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y, expected_values_y, 0, nullptr, lib_name.c_str(), &library_handle); +#endif #ifdef _WIN32 bool success = ::FreeLibrary(reinterpret_cast(library_handle)); diff --git a/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc b/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc index 9bb26b2690..2165596178 100644 --- a/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc +++ b/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc @@ -8,6 +8,12 @@ #include #include +#ifdef USE_CUDA +#include +template +void cuda_add(int64_t, T3*, const T1*, const T2*, cudaStream_t compute_stream); +#endif + static const char* c_OpDomain = "test.customop"; struct OrtCustomOpDomainDeleter { @@ -63,9 +69,14 @@ struct KernelOne { ort_.ReleaseTensorTypeAndShapeInfo(output_info); // Do computation +#ifdef USE_CUDA + cudaStream_t stream = reinterpret_cast(ort_.KernelContext_GetGPUComputeStream(context)); + cuda_add(size, out, X, Y, stream); +#else for (int64_t i = 0; i < size; i++) { out[i] = X[i] + Y[i]; } +#endif } private: @@ -112,6 +123,10 @@ struct CustomOpOne : Ort::CustomOpBase { const char* GetName() const { return "CustomOpOne"; }; +#ifdef USE_CUDA + const char* GetExecutionProviderType() const { return "CUDAExecutionProvider"; }; +#endif + size_t GetInputTypeCount() const { return 2; }; ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };