diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 43a07f3c1c..ac81d00e7b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -3643,6 +3643,13 @@ struct OrtCustomOp { // Returns the characteristics of the input & output tensors OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetInputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index); OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetOutputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index); + + // Returns the memory type of the input tensors. This API allows the custom op + // to place the inputs on specific devices. By default, it returns + // OrtMemTypeDefault, which means the input is placed on the default device for + // the execution provider. If the inputs need different memory types, + // this function can be overridden to return the specific memory types. 
+ OrtMemType(ORT_API_CALL* GetInputMemoryType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 2cdfb993a4..0411501336 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1649,6 +1649,7 @@ struct CustomOpBase : OrtCustomOp { OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); }; OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); }; + OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputMemoryType(index); }; OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); }; OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); }; @@ -1678,6 +1679,11 @@ struct CustomOpBase : OrtCustomOp { OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; } + + // Default implementation of GetInputMemoryType() that returns OrtMemTypeDefault + OrtMemType GetInputMemoryType(size_t /*index*/) const { + return OrtMemTypeDefault; + } }; } // namespace Ort diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 055a801b33..d5671e8103 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -248,6 +248,16 @@ common::Status CreateCustomRegistry(gsl::span op_domai .SetDomain(domain->domain_) .SinceVersion(1); + // GetInputMemoryType is a newly appended OrtCustomOp member, so a custom op built against older + // headers does not provide it. Only query it for ops that report a new enough API version. + // NOTE(review): 12 is assumed to be the last ORT_API_VERSION without this member - confirm. + if (op->version > 12) { + auto input_count = op->GetInputTypeCount(op); + for (size_t i = 0; i < input_count; i++) { + def_builder.InputMemoryType(op->GetInputMemoryType(op, i), i); + } + } + for (auto& id : 
type_constraint_ids[op]) { def_builder.TypeConstraint(id, DataTypeImpl::AllTensorTypes()); } diff --git a/onnxruntime/test/shared_lib/custom_op_utils.cc b/onnxruntime/test/shared_lib/custom_op_utils.cc index 3885f5fbc9..ac0867c73a 100644 --- a/onnxruntime/test/shared_lib/custom_op_utils.cc +++ b/onnxruntime/test/shared_lib/custom_op_utils.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "gtest/gtest.h" + #include "custom_op_utils.h" #include "core/common/common.h" @@ -49,6 +51,56 @@ void MyCustomKernel::Compute(OrtKernelContext* context) { #endif } +#ifdef USE_CUDA +// Test kernel: verifies that input 1 is in host memory, copies it to the device and passes it with input 0 to cuda_add(). +void MyCustomKernelSecondInputOnCpu::Compute(OrtKernelContext* context) { + // Setup inputs + Ort::KernelContext ctx(context); + auto input_X = ctx.GetInput(0); + auto input_Y = ctx.GetInput(1); + const float* X = input_X.GetTensorData<float>(); + const float* Y = input_Y.GetTensorData<float>(); + + // check if the second input is on CPU: CUDA reports a host pointer as either registered + // (cudaMemoryTypeHost) or unregistered (cudaMemoryTypeUnregistered) host memory. + cudaPointerAttributes attributes; + ASSERT_EQ(cudaPointerGetAttributes(&attributes, Y), cudaSuccess); + // TODO: check why the below ORT API does not work as expected: + // `auto y_mem_type = input_Y.GetTensorMemoryInfo().GetMemoryType();` + ASSERT_TRUE(attributes.type == cudaMemoryTypeHost || attributes.type == cudaMemoryTypeUnregistered); + + // copy the second input to GPU + const int64_t y_size = input_Y.GetTensorTypeAndShapeInfo().GetElementCount(); + float* Y_cuda {}; + ASSERT_EQ(cudaMalloc(&Y_cuda, y_size * sizeof(float)), cudaSuccess); + // EXPECT (not ASSERT) so that Y_cuda is still released below if the copy fails + EXPECT_EQ(cudaMemcpy(Y_cuda, Y, y_size * sizeof(float), cudaMemcpyHostToDevice), cudaSuccess); + + + // Setup output + auto dimensions = input_X.GetTensorTypeAndShapeInfo().GetShape(); + auto output = ctx.GetOutput(0, dimensions); + float* out = output.GetTensorMutableData<float>(); + + auto output_info = output.GetTensorTypeAndShapeInfo(); + int64_t size = output_info.GetElementCount(); + + // Do computation + + // Launch on stream 0 or user provided stream + cuda_add(size, out, X, Y_cuda, compute_stream_ == nullptr ? 
0 : reinterpret_cast<cudaStream_t>(compute_stream_)); + // cudaStreamSynchronize(nullptr); + // If everything is setup correctly, custom op implementations need not have such explicit synchronization logic as above. + // To make sure custom kernels and ORT CUDA kernels are implicitly synchronized: + // (1) Create your session with a compute stream passed in via SessionOptions and use the same compute + // stream to launch the custom op (OR) + // (2) Use the API KernelContext_GetGPUComputeStream() to query the CUDA compute stream being used by ORT kernels in this session + // and use the same compute stream to launch the custom op. + // Here, an example for (1) is shown (See test_inference.cc to see how this custom op is used.) + cudaFree(Y_cuda); +} +#endif + void MyCustomKernelMultipleDynamicInputs::Compute(OrtKernelContext* context) { // Setup inputs Ort::KernelContext ctx(context); diff --git a/onnxruntime/test/shared_lib/custom_op_utils.h b/onnxruntime/test/shared_lib/custom_op_utils.h index 1add6bb723..e192dc058b 100644 --- a/onnxruntime/test/shared_lib/custom_op_utils.h +++ b/onnxruntime/test/shared_lib/custom_op_utils.h @@ -25,6 +25,17 @@ struct MyCustomKernel { void* compute_stream_; }; +struct MyCustomKernelSecondInputOnCpu { + MyCustomKernelSecondInputOnCpu(const OrtKernelInfo* /*info*/, void* compute_stream) + : compute_stream_(compute_stream) { + } + + void Compute(OrtKernelContext* context); + + private: + void* compute_stream_; +}; + struct MyCustomOp : Ort::CustomOpBase<MyCustomOp, MyCustomKernel> { explicit MyCustomOp(const char* provider, void* compute_stream) : provider_(provider), compute_stream_(compute_stream) {} @@ -46,6 +57,32 @@ struct MyCustomOp : Ort::CustomOpBase { void* compute_stream_; }; +struct MyCustomOpSecondInputOnCpu : Ort::CustomOpBase<MyCustomOpSecondInputOnCpu, MyCustomKernelSecondInputOnCpu> { + explicit MyCustomOpSecondInputOnCpu(const char* provider, void* compute_stream) : provider_(provider), compute_stream_(compute_stream) {} + + void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const { return new 
MyCustomKernelSecondInputOnCpu(info, compute_stream_); }; + const char* GetName() const { return "Foo"; }; + const char* GetExecutionProviderType() const { return provider_; }; + + size_t GetInputTypeCount() const { return 2; }; + ONNXTensorElementDataType GetInputType(size_t /*index*/) const { + // Both the inputs need to be necessarily of float type + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + }; + + OrtMemType GetInputMemoryType(size_t i) const { + if (i == 1) { return OrtMemTypeCPUInput; } + return OrtMemTypeDefault; + }; + + size_t GetOutputTypeCount() const { return 1; }; + ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; + + private: + const char* provider_{"CUDAExecutionProvider"}; + void* compute_stream_; +}; + struct MyCustomKernelMultipleDynamicInputs { MyCustomKernelMultipleDynamicInputs(const OrtKernelInfo* /*info*/, void* compute_stream) : compute_stream_(compute_stream) { diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 316a035dba..b2fb1479b6 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -402,6 +402,41 @@ TEST(CApiTest, custom_op_handler) { #endif } +#ifdef USE_CUDA +// Runs CUSTOM_OP_MODEL_URI through TestInference with a custom op whose GetInputMemoryType() requests input 1 as OrtMemTypeCPUInput. +TEST(CApiTest, custom_op_set_input_memory_type) { + std::cout << "Running custom op inference" << std::endl; + + std::vector<Input> inputs(1); + Input& input = inputs[0]; + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // prepare expected outputs + std::vector<int64_t> expected_dims_y = {3, 2}; + std::vector<float> expected_values_y = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f}; + + cudaStream_t compute_stream = nullptr; + ASSERT_EQ(cudaStreamCreateWithFlags(&compute_stream, cudaStreamNonBlocking), cudaSuccess); + MyCustomOpSecondInputOnCpu custom_op{onnxruntime::kCudaExecutionProvider, compute_stream}; + + Ort::CustomOpDomain custom_op_domain(""); + 
custom_op_domain.Add(&custom_op); + + auto x_mem_type = custom_op.GetInputMemoryType(0); + auto y_mem_type = custom_op.GetInputMemoryType(1); + // EXPECT (not ASSERT) so that the stream created above is always destroyed + EXPECT_EQ(x_mem_type, OrtMemType::OrtMemTypeDefault); + EXPECT_EQ(y_mem_type, OrtMemType::OrtMemTypeCPUInput); + + TestInference<float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 1, + custom_op_domain, nullptr, nullptr, false, compute_stream); + cudaStreamDestroy(compute_stream); + +} +#endif + #if !defined(ORT_MINIMAL_BUILD) && !defined(REDUCED_OPS_BUILD) // disable test in reduced-op-build since TOPK and GRU are excluded there TEST(CApiTest, standalone_op_handler) {