Add 2 C API for ort extension (#19808)

### Description
<!-- Describe your changes. -->
Add 2 C API for ORT extension:
- KernelInfo_GetAllocator
- OrtCustomOp::GetMayInplace


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Add 2 C API for ORT extension project, which will leverage these 2 APIs
for GroupQueryAttention custom op.
This commit is contained in:
cao lei 2024-03-14 06:00:41 -07:00 committed by GitHub
parent 409b811325
commit 966fa74597
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 63 additions and 0 deletions

View file

@ -4600,6 +4600,16 @@ struct OrtApi {
* \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out);
/** \brief Get allocator from KernelInfo for a specific memory type. Please use C API ReleaseAllocator to release out object
*
* \param[in] info OrtKernelInfo instance
* \param[in] mem_type OrtMemType object
* \param[out] out A pointer to OrtAllocator
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_API2_STATUS(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out);
};
/*
@ -4697,6 +4707,13 @@ struct OrtCustomOp {
// Get start range
int(ORT_API_CALL* GetStartVersion)(_In_ const struct OrtCustomOp* op);
int(ORT_API_CALL* GetEndVersion)(_In_ const struct OrtCustomOp* op);
// Get the inplace_map that defines which output can reuse which input
// Callers will provide 2 raw int* and pass in their address, this function will fill these 2 arrays
// when return, output (*output_index)[i] may reuse the input (*input_index[i]).
// The return value is the size of these 2 arrays.
// Callers are responsible to delete these 2 arrays after use.
size_t(ORT_API_CALL* GetMayInplace)(_Out_ int** input_index, _Out_ int** output_index);
};
/*

View file

@ -736,6 +736,18 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetLogger, _In_ const OrtKernelInfo* inf
});
}
ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out) {
return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
onnxruntime::AllocatorPtr allocator = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAllocator(mem_type);
if (!allocator) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available");
}
auto p = std::make_unique<onnxruntime::OrtAllocatorImplWrappingIAllocator>(std::move(allocator));
*out = p.release();
return nullptr;
});
}
ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out) {
if (count_or_bytes == 0) {
*out = nullptr;

View file

@ -2726,6 +2726,7 @@ static constexpr OrtApi ort_api_1_to_18 = {
&OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO_V2,
&OrtApis::SessionOptionsAppendExecutionProvider_VitisAI,
&OrtApis::KernelContext_GetScratchBuffer,
&OrtApis::KernelInfoGetAllocator,
};
// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.

View file

@ -515,4 +515,6 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessi
_In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);
ORT_API_STATUS_IMPL(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out);
ORT_API_STATUS_IMPL(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out);
} // namespace OrtApis

View file

@ -4007,3 +4007,34 @@ TEST(CApiTest, RunAsyncFail) {
Ort::RunOptions run_options;
EXPECT_THROW(session.RunAsync(run_options, input_names, input_tensors, 1, output_names, output_values, 1, CallbackFail, nullptr), std::exception);
}
struct MockGQA : public OrtCustomOp {
MockGQA() {
OrtCustomOp::GetMayInplace = [](int** input_index, int** output_index) {
size_t ret = 2;
*input_index = static_cast<int*>(malloc(ret * sizeof(int)));
(*input_index)[0] = 3;
(*input_index)[1] = 4;
*output_index = static_cast<int*>(malloc(ret * sizeof(int)));
(*output_index)[0] = 1;
(*output_index)[1] = 2;
return ret;
};
}
};
TEST(CApiTest, OrtCustomOp_GetInPlace) {
MockGQA mock_gqa;
int* input_index = nullptr;
int* output_index = nullptr;
size_t len = mock_gqa.GetMayInplace(&input_index, &output_index);
ASSERT_NE(input_index, nullptr);
ASSERT_NE(output_index, nullptr);
ASSERT_EQ(input_index[0], 3);
ASSERT_EQ(input_index[1], 4);
ASSERT_EQ(output_index[0], 1);
ASSERT_EQ(output_index[1], 2);
ASSERT_EQ(len, static_cast<size_t>(2));
free(input_index);
free(output_index);
}