mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-07 00:13:17 +00:00
Allow custom ops to set input memory type (#10879)
This commit is contained in:
parent
1b494daffa
commit
943e156f4c
6 changed files with 138 additions and 0 deletions
|
|
@ -3643,6 +3643,13 @@ struct OrtCustomOp {
|
|||
// Returns the characteristics of the input & output tensors
|
||||
OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetInputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index);
|
||||
OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetOutputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index);
|
||||
|
||||
// Returns the memory type of the input tensors. This API allows the custom op
|
||||
// to place the inputs on specific devices. By default, it returns
|
||||
// OrtMemTypeDefault, which means the input is placed on the default device for
|
||||
// the execution provider. If the inputs need to be placed in different memory types,
|
||||
// this function can be overridden to return the specific memory types.
|
||||
OrtMemType(ORT_API_CALL* GetInputMemoryType)(_In_ const struct OrtCustomOp* op, _In_ size_t index);
|
||||
};
|
||||
|
||||
/*
|
||||
|
|
|
|||
|
|
@ -1649,6 +1649,7 @@ struct CustomOpBase : OrtCustomOp {
|
|||
|
||||
OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
|
||||
OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
|
||||
OrtCustomOp::GetInputMemoryType= [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputMemoryType(index); };
|
||||
|
||||
OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
|
||||
OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
|
||||
|
|
@ -1678,6 +1679,11 @@ struct CustomOpBase : OrtCustomOp {
|
|||
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
  // The output index is ignored: every output is reported as required.
  constexpr auto kRequired = OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
  return kRequired;
}
|
||||
|
||||
// Default implemention of GetInputMemoryType() that returns OrtMemTypeDefault
|
||||
OrtMemType GetInputMemoryType(size_t /*index*/) const {
|
||||
return OrtMemTypeDefault;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace Ort
|
||||
|
|
|
|||
|
|
@ -248,6 +248,11 @@ common::Status CreateCustomRegistry(gsl::span<OrtCustomOpDomain* const> op_domai
|
|||
.SetDomain(domain->domain_)
|
||||
.SinceVersion(1);
|
||||
|
||||
auto input_count = op->GetInputTypeCount(op);
|
||||
for (size_t i = 0; i < input_count; i++) {
|
||||
def_builder.InputMemoryType(op->GetInputMemoryType(op, i), i);
|
||||
}
|
||||
|
||||
for (auto& id : type_constraint_ids[op]) {
|
||||
def_builder.TypeConstraint(id, DataTypeImpl::AllTensorTypes());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "custom_op_utils.h"
|
||||
#include "core/common/common.h"
|
||||
|
||||
|
|
@ -49,6 +51,54 @@ void MyCustomKernel::Compute(OrtKernelContext* context) {
|
|||
#endif
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
// Compute() for the test kernel whose second input is expected to live in CPU memory.
//
// Steps:
//   1. fetch inputs 0 (X) and 1 (Y) and verify that Y's buffer is host memory,
//   2. create output 0 with the same shape as X,
//   3. copy Y into a temporary device buffer on the compute stream,
//   4. call cuda_add(size, out, X, Y_cuda, stream) on that same stream,
//   5. wait for the stream and release the temporary buffer.
//
// Failures are reported through gtest assertions; a fatal assertion returns from this function early.
void MyCustomKernelSecondInputOnCpu::Compute(OrtKernelContext* context) {
  // Setup inputs
  Ort::KernelContext ctx(context);
  auto input_X = ctx.GetInput(0);
  auto input_Y = ctx.GetInput(1);
  const float* X = input_X.GetTensorData<float>();
  const float* Y = input_Y.GetTensorData<float>();

  // Launch on stream 0 or user provided stream (a null compute_stream_ converts to the null stream, i.e. stream 0)
  cudaStream_t stream = static_cast<cudaStream_t>(compute_stream_);

  // check if the second input is on CPU
  // The CUDA memory *type* of the pointer is inspected. `attributes.device` is a device ordinal and is not
  // comparable with an OrtMemType enumerator.
  // NOTE(review): per the CUDA 11+ documentation a plain host pointer yields cudaSuccess with
  // cudaMemoryTypeUnregistered; older toolkits returned an error instead - confirm the minimum supported toolkit.
  cudaPointerAttributes attributes{};
  const cudaError_t attr_status = cudaPointerGetAttributes(&attributes, Y);
  ASSERT_EQ(attr_status, cudaSuccess) << cudaGetErrorString(attr_status);
  const bool y_is_host_memory =
      attributes.type == cudaMemoryTypeHost || attributes.type == cudaMemoryTypeUnregistered;
  // TODO: check why the below ORT API does not work as expected:
  // `auto y_mem_type = input_Y.GetTensorMemoryInfo().GetMemoryType();`
  ASSERT_TRUE(y_is_host_memory) << "second input is expected to be in CPU memory, but its cudaMemoryType is "
                                << static_cast<int>(attributes.type);

  // Setup output
  // Done before the temporary device buffer is allocated so that a failure here cannot leak that buffer.
  auto dimensions = input_X.GetTensorTypeAndShapeInfo().GetShape();
  auto output = ctx.GetOutput(0, dimensions);
  float* out = output.GetTensorMutableData<float>();

  auto output_info = output.GetTensorTypeAndShapeInfo();
  int64_t size = output_info.GetElementCount();

  // copy the second input to GPU
  const size_t y_size = input_Y.GetTensorTypeAndShapeInfo().GetElementCount();
  const size_t y_bytes = y_size * sizeof(float);
  float* Y_cuda = nullptr;
  const cudaError_t alloc_status = cudaMalloc(&Y_cuda, y_bytes);
  ASSERT_EQ(alloc_status, cudaSuccess) << cudaGetErrorString(alloc_status);

  // The copy is enqueued on the same stream as the kernel so that the two are ordered with respect to each
  // other. A blocking cudaMemcpy goes through the legacy default stream, which a stream created with
  // cudaStreamNonBlocking does not synchronize with.
  const cudaError_t copy_status = cudaMemcpyAsync(Y_cuda, Y, y_bytes, cudaMemcpyHostToDevice, stream);
  if (copy_status != cudaSuccess) {
    cudaFree(Y_cuda);
    FAIL() << cudaGetErrorString(copy_status);
  }

  // Do computation
  cuda_add(size, out, X, Y_cuda, stream);
  const cudaError_t launch_status = cudaGetLastError();

  // cudaStreamSynchronize(nullptr);
  // If everything is setup correctly, custom op implementations need not have such explicit synchronization logic as above.
  // To make sure custom kernels and ORT CUDA kernels are implicitly synchronized:
  // (1) Create your session with a compute stream passed in via SessionOptions and use the same compute
  //     stream to launch the custom op (OR)
  // (2) Use the API KernelContext_GetGPUComputeStream() to query the CUDA compute stream being used by ORT kernels in this session
  //     and use the same compute stream to launch the custom op.
  // Here, an example for (1) is shown (See test_inference.cc to see how this custom op is used.)

  // The wait below is unrelated to the ORT ordering discussed above: it only guarantees that the work
  // enqueued on `stream` has finished with the temporary buffer before that buffer is released.
  const cudaError_t sync_status = cudaStreamSynchronize(stream);
  const cudaError_t free_status = cudaFree(Y_cuda);

  // Non-fatal checks, evaluated after the cleanup so the buffer is released on every path.
  EXPECT_EQ(launch_status, cudaSuccess) << cudaGetErrorString(launch_status);
  EXPECT_EQ(sync_status, cudaSuccess) << cudaGetErrorString(sync_status);
  EXPECT_EQ(free_status, cudaSuccess) << cudaGetErrorString(free_status);
}
|
||||
#endif
|
||||
|
||||
void MyCustomKernelMultipleDynamicInputs::Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
Ort::KernelContext ctx(context);
|
||||
|
|
|
|||
|
|
@ -25,6 +25,17 @@ struct MyCustomKernel {
|
|||
void* compute_stream_;
|
||||
};
|
||||
|
||||
struct MyCustomKernelSecondInputOnCpu {
|
||||
MyCustomKernelSecondInputOnCpu(const OrtKernelInfo* /*info*/, void* compute_stream)
|
||||
: compute_stream_(compute_stream) {
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
void* compute_stream_;
|
||||
};
|
||||
|
||||
struct MyCustomOp : Ort::CustomOpBase<MyCustomOp, MyCustomKernel> {
|
||||
explicit MyCustomOp(const char* provider, void* compute_stream) : provider_(provider), compute_stream_(compute_stream) {}
|
||||
|
||||
|
|
@ -46,6 +57,32 @@ struct MyCustomOp : Ort::CustomOpBase<MyCustomOp, MyCustomKernel> {
|
|||
void* compute_stream_;
|
||||
};
|
||||
|
||||
// Custom op named "Foo" with two float inputs and one float output.
// Input 1 is requested as OrtMemTypeCPUInput; every other input index gets OrtMemTypeDefault.
struct MyCustomOpSecondInputOnCpu : Ort::CustomOpBase<MyCustomOpSecondInputOnCpu, MyCustomKernelSecondInputOnCpu> {
  explicit MyCustomOpSecondInputOnCpu(const char* provider, void* compute_stream)
      : provider_(provider), compute_stream_(compute_stream) {}

  // Returns a newly allocated kernel that receives this op's compute stream.
  void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const {
    return new MyCustomKernelSecondInputOnCpu(info, compute_stream_);
  }

  const char* GetName() const { return "Foo"; }

  // Execution provider name supplied at construction.
  const char* GetExecutionProviderType() const { return provider_; }

  size_t GetInputTypeCount() const { return 2; }

  // Both the inputs need to be necessarily of float type
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

  // Only the second input (index 1) is requested as a CPU input.
  OrtMemType GetInputMemoryType(size_t i) const {
    return i == 1 ? OrtMemTypeCPUInput : OrtMemTypeDefault;
  }

  size_t GetOutputTypeCount() const { return 1; }

  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

 private:
  const char* provider_{"CUDAExecutionProvider"};
  void* compute_stream_;
};
|
||||
|
||||
struct MyCustomKernelMultipleDynamicInputs {
|
||||
MyCustomKernelMultipleDynamicInputs(const OrtKernelInfo* /*info*/, void* compute_stream)
|
||||
: compute_stream_(compute_stream) {
|
||||
|
|
|
|||
|
|
@ -402,6 +402,39 @@ TEST(CApiTest, custom_op_handler) {
|
|||
#endif
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
// Registers MyCustomOpSecondInputOnCpu (second input requested as OrtMemTypeCPUInput), checks the memory
// types the op reports for inputs 0 and 1, and then runs CUSTOM_OP_MODEL_URI through TestInference on a
// dedicated non-blocking CUDA stream.
TEST(CApiTest, custom_op_set_input_memory_type) {
  std::cout << "Running custom op inference" << std::endl;

  std::vector<Input> inputs(1);
  Input& input = inputs[0];
  input.name = "X";
  input.dims = {3, 2};
  input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};

  // prepare expected inputs and outputs
  std::vector<int64_t> expected_dims_y = {3, 2};
  std::vector<float> expected_values_y = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f};

  // Scope guard for the stream: the fatal assertions below return early, and the stream must still be
  // destroyed on those paths. It is declared before the op and domain so it is destroyed after them.
  struct ScopedCudaStream {
    cudaStream_t handle = nullptr;
    ~ScopedCudaStream() {
      if (handle != nullptr) {
        cudaStreamDestroy(handle);
      }
    }
  } compute_stream;

  const cudaError_t create_status = cudaStreamCreateWithFlags(&compute_stream.handle, cudaStreamNonBlocking);
  ASSERT_EQ(create_status, cudaSuccess) << cudaGetErrorString(create_status);

  MyCustomOpSecondInputOnCpu custom_op{onnxruntime::kCudaExecutionProvider, compute_stream.handle};

  Ort::CustomOpDomain custom_op_domain("");
  custom_op_domain.Add(&custom_op);

  // The op must report the default memory type for input 0 and CPU input memory for input 1.
  auto x_mem_type = custom_op.GetInputMemoryType(0);
  auto y_mem_type = custom_op.GetInputMemoryType(1);
  ASSERT_EQ(x_mem_type, OrtMemType::OrtMemTypeDefault);
  ASSERT_EQ(y_mem_type, OrtMemType::OrtMemTypeCPUInput);

  TestInference<float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 1,
                       custom_op_domain, nullptr, nullptr, false, compute_stream.handle);
}
|
||||
#endif
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) && !defined(REDUCED_OPS_BUILD)
|
||||
// disable test in reduced-op-build since TOPK and GRU are excluded there
|
||||
TEST(CApiTest, standalone_op_handler) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue