Allow custom ops to set input memory type (#10879)

This commit is contained in:
Fei Hu 2022-10-28 21:45:26 -07:00 committed by GitHub
parent 1b494daffa
commit 943e156f4c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 138 additions and 0 deletions

View file

@ -3643,6 +3643,13 @@ struct OrtCustomOp {
// Returns the characteristics of the input & output tensors
OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetInputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index);
OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetOutputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index);
// Returns the memory type of the input tensors. This API allows the custom op
// to place the inputs on specific devices. By default, it returns
// OrtMemTypeDefault, which means the input is placed on the default device for
// the execution provider. If the inputs need different memory types,
// this function can be overridden to return the specific memory types.
OrtMemType(ORT_API_CALL* GetInputMemoryType)(_In_ const struct OrtCustomOp* op, _In_ size_t index);
};
/*

View file

@ -1649,6 +1649,7 @@ struct CustomOpBase : OrtCustomOp {
OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
OrtCustomOp::GetInputMemoryType= [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputMemoryType(index); };
OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
@ -1678,6 +1679,11 @@ struct CustomOpBase : OrtCustomOp {
// Default output characteristic: the index argument is ignored and every
// output is reported as INPUT_OUTPUT_REQUIRED.
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
}
// Default implementation of GetInputMemoryType(): the index argument is ignored
// and OrtMemTypeDefault is returned for every input. An op type that declares
// its own GetInputMemoryType(size_t) is called instead of this one.
OrtMemType GetInputMemoryType(size_t /*index*/) const {
  return OrtMemType::OrtMemTypeDefault;
}
};
} // namespace Ort

View file

@ -248,6 +248,11 @@ common::Status CreateCustomRegistry(gsl::span<OrtCustomOpDomain* const> op_domai
.SetDomain(domain->domain_)
.SinceVersion(1);
auto input_count = op->GetInputTypeCount(op);
for (size_t i = 0; i < input_count; i++) {
def_builder.InputMemoryType(op->GetInputMemoryType(op, i), i);
}
for (auto& id : type_constraint_ids[op]) {
def_builder.TypeConstraint(id, DataTypeImpl::AllTensorTypes());
}

View file

@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "custom_op_utils.h"
#include "core/common/common.h"
@ -49,6 +51,54 @@ void MyCustomKernel::Compute(OrtKernelContext* context) {
#endif
}
#ifdef USE_CUDA
// Computes out = cuda_add(X, Y) for a custom op whose second input (Y) is expected to
// live in host memory.
//
// Steps:
//   1. Verify that Y's pointer refers to host memory (test assertion).
//   2. Create the output with the same shape as X.
//   3. Stage Y into a temporary device buffer on the kernel's stream.
//   4. Launch cuda_add on that stream and release the temporary buffer.
void MyCustomKernelSecondInputOnCpu::Compute(OrtKernelContext* context) {
  // Setup inputs
  Ort::KernelContext ctx(context);
  auto input_X = ctx.GetInput(0);
  auto input_Y = ctx.GetInput(1);
  const float* X = input_X.GetTensorData<float>();
  const float* Y = input_Y.GetTensorData<float>();

  // Check that the second input is on the CPU.
  // cudaPointerAttributes::device is a device ordinal, not a memory type, so the check
  // is made on cudaPointerAttributes::type instead. Plain (pageable) host memory is
  // reported as cudaMemoryTypeUnregistered by CUDA 11+, pinned host memory as
  // cudaMemoryTypeHost. Older runtimes return cudaErrorInvalidValue for a plain host
  // pointer, which is accepted here as "host memory" as well.
  // TODO: check why the below ORT API does not work as expected:
  // `auto y_mem_type = input_Y.GetTensorMemoryInfo().GetMemoryType();`
  cudaPointerAttributes attributes{};
  const cudaError_t attr_status = cudaPointerGetAttributes(&attributes, Y);
  if (attr_status == cudaSuccess) {
    ASSERT_TRUE(attributes.type == cudaMemoryTypeUnregistered || attributes.type == cudaMemoryTypeHost)
        << "second input is expected to be in host memory";
  } else {
    ASSERT_EQ(attr_status, cudaErrorInvalidValue);
    (void)cudaGetLastError();  // reset the error state left behind by the query
  }

  // Setup output first, so that a failure here cannot leak the temporary device buffer
  // allocated below.
  auto dimensions = input_X.GetTensorTypeAndShapeInfo().GetShape();
  auto output = ctx.GetOutput(0, dimensions);
  float* out = output.GetTensorMutableData<float>();
  auto output_info = output.GetTensorTypeAndShapeInfo();
  int64_t size = output_info.GetElementCount();

  // Launch on stream 0 or user provided stream (a null compute_stream_ maps to stream 0).
  cudaStream_t stream = static_cast<cudaStream_t>(compute_stream_);

  // Copy the second input to GPU. The copy is issued on the same stream as the kernel
  // so that the kernel is ordered after it even when the stream is non-blocking.
  const int64_t y_size = input_Y.GetTensorTypeAndShapeInfo().GetElementCount();
  const size_t y_bytes = static_cast<size_t>(y_size) * sizeof(float);
  float* Y_cuda{};
  ASSERT_EQ(cudaMalloc(&Y_cuda, y_bytes), cudaSuccess);
  const cudaError_t copy_status = cudaMemcpyAsync(Y_cuda, Y, y_bytes, cudaMemcpyHostToDevice, stream);
  if (copy_status != cudaSuccess) {
    cudaFree(Y_cuda);
    FAIL() << "copying the second input to the GPU failed: " << cudaGetErrorString(copy_status);
  }

  // Do computation
  cuda_add(size, out, X, Y_cuda, stream);

  // cudaStreamSynchronize(nullptr);
  // If everything is setup correctly, custom op implementations need not have such explicit synchronization logic as above.
  // To make sure custom kernels and ORT CUDA kernels are implicitly synchronized:
  // (1) Create your session with a compute stream passed in via SessionOptions and use the same compute
  //     stream to launch the custom op (OR)
  // (2) Use the API KernelContext_GetGPUComputeStream() to query the CUDA compute stream being used by ORT kernels in this session
  //     and use the same compute stream to launch the custom op.
  // Here, an example for (1) is shown (See test_inference.cc to see how this custom op is used.)

  cudaFree(Y_cuda);
}
#endif
void MyCustomKernelMultipleDynamicInputs::Compute(OrtKernelContext* context) {
// Setup inputs
Ort::KernelContext ctx(context);

View file

@ -25,6 +25,17 @@ struct MyCustomKernel {
void* compute_stream_;
};
// Kernel object created by MyCustomOpSecondInputOnCpu::CreateKernel().
// It only remembers the stream handle passed at construction; the kernel info
// argument is unused. Compute() is defined out of line.
struct MyCustomKernelSecondInputOnCpu {
  MyCustomKernelSecondInputOnCpu(const OrtKernelInfo* /*info*/, void* compute_stream)
      : compute_stream_{compute_stream} {}

  // Entry point invoked with the kernel context of the current run.
  void Compute(OrtKernelContext* context);

 private:
  // Opaque stream handle supplied by the creator of the kernel (may be null).
  void* compute_stream_;
};
struct MyCustomOp : Ort::CustomOpBase<MyCustomOp, MyCustomKernel> {
explicit MyCustomOp(const char* provider, void* compute_stream) : provider_(provider), compute_stream_(compute_stream) {}
@ -46,6 +57,32 @@ struct MyCustomOp : Ort::CustomOpBase<MyCustomOp, MyCustomKernel> {
void* compute_stream_;
};
// Custom op "Foo" with two float inputs and one float output whose second
// input (index 1) is requested as OrtMemTypeCPUInput; every other input uses
// OrtMemTypeDefault. Kernels are MyCustomKernelSecondInputOnCpu instances that
// receive the stream handle stored in this op.
struct MyCustomOpSecondInputOnCpu : Ort::CustomOpBase<MyCustomOpSecondInputOnCpu, MyCustomKernelSecondInputOnCpu> {
  explicit MyCustomOpSecondInputOnCpu(const char* provider, void* compute_stream)
      : provider_{provider}, compute_stream_{compute_stream} {}

  // Allocates a new kernel; the returned pointer is handed back as void*.
  void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const {
    return new MyCustomKernelSecondInputOnCpu(info, compute_stream_);
  }

  const char* GetName() const { return "Foo"; }

  const char* GetExecutionProviderType() const { return provider_; }

  size_t GetInputTypeCount() const { return 2; }

  // Both the inputs need to be necessarily of float type
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

  // Only the input at index 1 is requested as CPU input.
  OrtMemType GetInputMemoryType(size_t i) const {
    return i == 1 ? OrtMemTypeCPUInput : OrtMemTypeDefault;
  }

  size_t GetOutputTypeCount() const { return 1; }

  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

 private:
  const char* provider_{"CUDAExecutionProvider"};
  void* compute_stream_;
};
struct MyCustomKernelMultipleDynamicInputs {
MyCustomKernelMultipleDynamicInputs(const OrtKernelInfo* /*info*/, void* compute_stream)
: compute_stream_(compute_stream) {

View file

@ -402,6 +402,39 @@ TEST(CApiTest, custom_op_handler) {
#endif
}
#ifdef USE_CUDA
// Registers MyCustomOpSecondInputOnCpu for the CUDA execution provider, checks the
// memory types it reports for its two inputs (default for input 0, CPU input for
// input 1), and runs the custom-op model on a dedicated non-blocking CUDA stream.
TEST(CApiTest, custom_op_set_input_memory_type) {
  std::cout << "Running custom op inference" << std::endl;

  std::vector<Input> inputs(1);
  Input& input = inputs[0];
  input.name = "X";
  input.dims = {3, 2};
  input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};

  // prepare expected inputs and outputs
  std::vector<int64_t> expected_dims_y = {3, 2};
  std::vector<float> expected_values_y = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f};

  // Owns the CUDA stream so that it is destroyed on every exit path, including the
  // early return performed by a failing ASSERT_* and an exception from TestInference.
  struct ScopedCudaStream {
    cudaStream_t handle{nullptr};
    ~ScopedCudaStream() {
      if (handle != nullptr) {
        cudaStreamDestroy(handle);
      }
    }
  } stream_owner;

  // Fail right away if the stream cannot be created instead of silently running
  // the test on the null (default) stream.
  ASSERT_EQ(cudaStreamCreateWithFlags(&stream_owner.handle, cudaStreamNonBlocking), cudaSuccess);
  cudaStream_t compute_stream = stream_owner.handle;

  MyCustomOpSecondInputOnCpu custom_op{onnxruntime::kCudaExecutionProvider, compute_stream};
  Ort::CustomOpDomain custom_op_domain("");
  custom_op_domain.Add(&custom_op);

  auto x_mem_type = custom_op.GetInputMemoryType(0);
  auto y_mem_type = custom_op.GetInputMemoryType(1);
  ASSERT_EQ(x_mem_type, OrtMemType::OrtMemTypeDefault);
  ASSERT_EQ(y_mem_type, OrtMemType::OrtMemTypeCPUInput);

  TestInference<float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 1,
                       custom_op_domain, nullptr, nullptr, false, compute_stream);
}
#endif
#if !defined(ORT_MINIMAL_BUILD) && !defined(REDUCED_OPS_BUILD)
// disable test in reduced-op-build since TOPK and GRU are excluded there
TEST(CApiTest, standalone_op_handler) {