From 2c63196600c191dbc419db9bb3d04deddc18e35e Mon Sep 17 00:00:00 2001 From: RandySheriffH <48490400+RandySheriffH@users.noreply.github.com> Date: Fri, 30 Oct 2020 12:25:44 -0700 Subject: [PATCH] Custom Op on GPU (#5620) * add case for cpu custom op on gpu * format doc * restrict GPU custom op on Linux GPU CI only * separate cu file to a independent project * fix typo Co-authored-by: RandySheriffH --- cmake/onnxruntime_unittests.cmake | 4 ++-- docs/AddingCustomOp.md | 4 +++- onnxruntime/test/shared_lib/cuda_add.cu | 17 +++++++++++++++++ onnxruntime/test/shared_lib/test_inference.cc | 16 +++++++--------- 4 files changed, 29 insertions(+), 12 deletions(-) create mode 100644 onnxruntime/test/shared_lib/cuda_add.cu diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index f56ea8b5f4..1e8a540fca 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -865,12 +865,12 @@ if (onnxruntime_BUILD_SHARED_LIB) ################################################################# # test inference using shared lib set(onnxruntime_shared_lib_test_LIBS onnxruntime_mocked_allocator onnxruntime_test_utils onnxruntime_common onnx_proto) - if(NOT WIN32) list(APPEND onnxruntime_shared_lib_test_LIBS nsync_cpp) endif() if (onnxruntime_USE_CUDA) - list(APPEND onnxruntime_shared_lib_test_LIBS cudart) + add_library(onnxruntime_shared_lib_test_cuda ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_add.cu) + list(APPEND onnxruntime_shared_lib_test_LIBS onnxruntime_shared_lib_test_cuda cudart) endif() if (CMAKE_SYSTEM_NAME STREQUAL "Android") list(APPEND onnxruntime_shared_lib_test_LIBS ${android_shared_libs}) diff --git a/docs/AddingCustomOp.md b/docs/AddingCustomOp.md index 396d8caf82..baae800b15 100644 --- a/docs/AddingCustomOp.md +++ b/docs/AddingCustomOp.md @@ -11,7 +11,9 @@ You can also compile the custom ops into a shared library and use that to run a The source code for a sample custom op shared library containing two custom kernels is [here](../onnxruntime/test/testdata/custom_op_library/custom_op_library.cc). See [this](../onnxruntime/test/python/onnxruntime_test_python.py) for an example called testRegisterCustomOpsLibrary that uses the Python API to register a shared library that contains custom op kernels. -Currently, the only supported Execution Providers (EPs) for custom ops registered via this approach are the `CUDA` and the `CPU` EPs. +Currently, the only supported Execution Providers (EPs) for custom ops registered via this approach are the `CUDA` and the `CPU` EPs. + +Note that when a model being inferred on gpu, onnxruntime will insert MemcpyToHost op before a cpu custom op and append MemcpyFromHost after to make sure tensor(s) are accessible throughout calling, meaning there are no extra efforts required from custom op developer for the case. ### 2. Using RegisterCustomRegistry API * Implement your kernel and schema (if required) using the OpKernel and OpSchema APIs (headers are in the include folder). diff --git a/onnxruntime/test/shared_lib/cuda_add.cu b/onnxruntime/test/shared_lib/cuda_add.cu new file mode 100644 index 0000000000..301d82ca38 --- /dev/null +++ b/onnxruntime/test/shared_lib/cuda_add.cu @@ -0,0 +1,17 @@ +#include +#include +#include + +using namespace std; + +__global__ void cuda_add_impl(int64_t N, float* O, const float* X, const float* Y) { + auto offset = threadIdx.x; + if (offset < N) { + O[offset] = Y[offset] + X[offset]; + } +} + +void cuda_add(int64_t N, float* O, const float* X, const float* Y) { + cuda_add_impl<<<1, 256>>>(N, O, X, Y); +} + diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index db0094b78f..9aec6c0fe0 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -221,6 +221,8 @@ struct OrtTensorDimensions : std::vector { template constexpr size_t countof(T (&)[N]) { return N; } +void cuda_add(int64_t, float*, const float*, const float*); + struct MyCustomKernel { MyCustomKernel(Ort::CustomOpApi ort, const OrtKernelInfo* /*info*/) : ort_(ort) { } @@ -242,9 +244,13 @@ struct MyCustomKernel { ort_.ReleaseTensorTypeAndShapeInfo(output_info); // Do computation +#ifdef USE_CUDA + cuda_add(size, out, X, Y); +#else for (int64_t i = 0; i < size; i++) { out[i] = X[i] + Y[i]; } +#endif } private: @@ -291,15 +297,7 @@ TEST(CApiTest, custom_op_handler) { custom_op_domain.Add(&custom_op); #ifdef USE_CUDA - // The custom op kernel has a Compute() method that doesn't really use CUDA and can't be used as is - // because it uses the contents of the inputs and writes to the output of the node - // (not possible as is because they are on the device). - // For the purpose of this exercise, it is not really needed to have a Compute() method that uses CUDA. - // We only need to verify if model load succeeds == session creation succeeds == the node is assigned to the CUDA EP. - // It is enough to test for successful session creation because if the custom node wasn't assigned an EP, - // the session creation would fail. Since the custom node is only tied to the CUDA EP (in CUDA-enabled builds), - // if the session creation succeeds, it is assumed that the node got assigned to the CUDA EP. - TestInference(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 1, custom_op_domain, nullptr, nullptr, true); + TestInference(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 1, custom_op_domain, nullptr, nullptr); #else TestInference(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 0, custom_op_domain, nullptr); #endif