mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Add 2 C API for ort extension (#19808)
### Description <!-- Describe your changes. --> Add 2 C API for ORT extension: - KernelInfo_GetAllocator - OrtCustomOp::GetMayInplace ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Add 2 C API for ORT extension project, which will leverage these 2 APIs for GroupQueryAttention custom op.
This commit is contained in:
parent
409b811325
commit
966fa74597
5 changed files with 63 additions and 0 deletions
|
|
@ -4600,6 +4600,16 @@ struct OrtApi {
|
|||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*/
|
||||
ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out);
|
||||
|
||||
/** \brief Get allocator from KernelInfo for a specific memory type. Please use C API ReleaseAllocator to release out object
|
||||
*
|
||||
* \param[in] info OrtKernelInfo instance
|
||||
* \param[in] mem_type OrtMemType object
|
||||
* \param[out] out A pointer to OrtAllocator
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*/
|
||||
ORT_API2_STATUS(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out);
|
||||
};
|
||||
|
||||
/*
|
||||
|
|
@ -4697,6 +4707,13 @@ struct OrtCustomOp {
|
|||
// Get start range
|
||||
int(ORT_API_CALL* GetStartVersion)(_In_ const struct OrtCustomOp* op);
|
||||
int(ORT_API_CALL* GetEndVersion)(_In_ const struct OrtCustomOp* op);
|
||||
|
||||
// Get the inplace_map that defines which output can reuse which input
|
||||
// Callers will provide 2 raw int* and pass in their address, this function will fill these 2 arrays
|
||||
// when return, output (*output_index)[i] may reuse the input (*input_index[i]).
|
||||
// The return value is the size of these 2 arrays.
|
||||
// Callers are responsible to delete these 2 arrays after use.
|
||||
size_t(ORT_API_CALL* GetMayInplace)(_Out_ int** input_index, _Out_ int** output_index);
|
||||
};
|
||||
|
||||
/*
|
||||
|
|
|
|||
|
|
@ -736,6 +736,18 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetLogger, _In_ const OrtKernelInfo* inf
|
|||
});
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out) {
|
||||
return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
|
||||
onnxruntime::AllocatorPtr allocator = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAllocator(mem_type);
|
||||
if (!allocator) {
|
||||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available");
|
||||
}
|
||||
auto p = std::make_unique<onnxruntime::OrtAllocatorImplWrappingIAllocator>(std::move(allocator));
|
||||
*out = p.release();
|
||||
return nullptr;
|
||||
});
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out) {
|
||||
if (count_or_bytes == 0) {
|
||||
*out = nullptr;
|
||||
|
|
|
|||
|
|
@ -2726,6 +2726,7 @@ static constexpr OrtApi ort_api_1_to_18 = {
|
|||
&OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO_V2,
|
||||
&OrtApis::SessionOptionsAppendExecutionProvider_VitisAI,
|
||||
&OrtApis::KernelContext_GetScratchBuffer,
|
||||
&OrtApis::KernelInfoGetAllocator,
|
||||
};
|
||||
|
||||
// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
|
||||
|
|
|
|||
|
|
@ -515,4 +515,6 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessi
|
|||
_In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);
|
||||
|
||||
ORT_API_STATUS_IMPL(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out);
|
||||
|
||||
ORT_API_STATUS_IMPL(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out);
|
||||
} // namespace OrtApis
|
||||
|
|
|
|||
|
|
@ -4007,3 +4007,34 @@ TEST(CApiTest, RunAsyncFail) {
|
|||
Ort::RunOptions run_options;
|
||||
EXPECT_THROW(session.RunAsync(run_options, input_names, input_tensors, 1, output_names, output_values, 1, CallbackFail, nullptr), std::exception);
|
||||
}
|
||||
|
||||
struct MockGQA : public OrtCustomOp {
|
||||
MockGQA() {
|
||||
OrtCustomOp::GetMayInplace = [](int** input_index, int** output_index) {
|
||||
size_t ret = 2;
|
||||
*input_index = static_cast<int*>(malloc(ret * sizeof(int)));
|
||||
(*input_index)[0] = 3;
|
||||
(*input_index)[1] = 4;
|
||||
*output_index = static_cast<int*>(malloc(ret * sizeof(int)));
|
||||
(*output_index)[0] = 1;
|
||||
(*output_index)[1] = 2;
|
||||
return ret;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
TEST(CApiTest, OrtCustomOp_GetInPlace) {
|
||||
MockGQA mock_gqa;
|
||||
int* input_index = nullptr;
|
||||
int* output_index = nullptr;
|
||||
size_t len = mock_gqa.GetMayInplace(&input_index, &output_index);
|
||||
ASSERT_NE(input_index, nullptr);
|
||||
ASSERT_NE(output_index, nullptr);
|
||||
ASSERT_EQ(input_index[0], 3);
|
||||
ASSERT_EQ(input_index[1], 4);
|
||||
ASSERT_EQ(output_index[0], 1);
|
||||
ASSERT_EQ(output_index[1], 2);
|
||||
ASSERT_EQ(len, static_cast<size_t>(2));
|
||||
free(input_index);
|
||||
free(output_index);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue