mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
add API function GetAliasMap and ReleaseAliasMap in OrtCustomOp (#20145)
### Description <!-- Describe your changes. --> Add API function GetAliasMap and ReleaseAliasMap in OrtCustomOp ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Add API function GetAliasMap and ReleaseAliasMap in OrtCustomOp
This commit is contained in:
parent
8396845806
commit
604b284261
5 changed files with 43 additions and 0 deletions
|
|
@ -4738,6 +4738,10 @@ struct OrtCustomOp {
|
|||
// Release the pointer input_index and output_index allocated from GetMayInplace() function.
|
||||
// If GetMayInplace() is defined, this function MUST be defined as well.
|
||||
void(ORT_API_CALL* ReleaseMayInplace)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index);
|
||||
|
||||
// Same as GetMayInplace() and ReleaseMayInplace()
|
||||
size_t(ORT_API_CALL* GetAliasMap)(_Out_ int** input_index, _Out_ int** output_index);
|
||||
void(ORT_API_CALL* ReleaseAliasMap)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index);
|
||||
};
|
||||
|
||||
/*
|
||||
|
|
|
|||
|
|
@ -2304,6 +2304,8 @@ struct CustomOpBase : OrtCustomOp {
|
|||
|
||||
OrtCustomOp::GetMayInplace = nullptr;
|
||||
OrtCustomOp::ReleaseMayInplace = nullptr;
|
||||
OrtCustomOp::GetAliasMap = nullptr;
|
||||
OrtCustomOp::ReleaseAliasMap = nullptr;
|
||||
}
|
||||
|
||||
// Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
|
||||
|
|
|
|||
|
|
@ -865,6 +865,8 @@ struct OrtLiteCustomOp : public OrtCustomOp {
|
|||
|
||||
OrtCustomOp::GetMayInplace = {};
|
||||
OrtCustomOp::ReleaseMayInplace = {};
|
||||
OrtCustomOp::GetAliasMap = {};
|
||||
OrtCustomOp::ReleaseAliasMap = {};
|
||||
}
|
||||
|
||||
const std::string op_name_;
|
||||
|
|
|
|||
|
|
@ -875,6 +875,16 @@ KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCust
|
|||
}
|
||||
}
|
||||
|
||||
if (op->version >= 18 && op->GetAliasMap != nullptr) {
|
||||
int* input_index = nullptr;
|
||||
int* output_index = nullptr;
|
||||
size_t len = op->GetAliasMap(&input_index, &output_index);
|
||||
if (len > 0) {
|
||||
for (size_t i = 0; i < len; i++) def_builder.Alias(input_index[i], output_index[i]);
|
||||
op->ReleaseAliasMap(input_index, output_index);
|
||||
}
|
||||
}
|
||||
|
||||
KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info,
|
||||
std::unique_ptr<OpKernel>& out) -> Status {
|
||||
out = std::make_unique<CustomOpKernel>(info, *op);
|
||||
|
|
|
|||
|
|
@ -4039,6 +4039,20 @@ struct MockGQA : public OrtCustomOp {
|
|||
free(input_index);
|
||||
free(output_index);
|
||||
};
|
||||
OrtCustomOp::GetAliasMap = [](int** input_index, int** output_index) {
|
||||
size_t ret = 2;
|
||||
*input_index = static_cast<int*>(malloc(ret * sizeof(int)));
|
||||
(*input_index)[0] = 5;
|
||||
(*input_index)[1] = 6;
|
||||
*output_index = static_cast<int*>(malloc(ret * sizeof(int)));
|
||||
(*output_index)[0] = 7;
|
||||
(*output_index)[1] = 8;
|
||||
return ret;
|
||||
};
|
||||
OrtCustomOp::ReleaseAliasMap = [](int* input_index, int* output_index) {
|
||||
free(input_index);
|
||||
free(output_index);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -4055,4 +4069,15 @@ TEST(CApiTest, OrtCustomOp_GetInPlace) {
|
|||
ASSERT_EQ(output_index[1], 2);
|
||||
ASSERT_EQ(len, static_cast<size_t>(2));
|
||||
mock_gqa.ReleaseMayInplace(input_index, output_index);
|
||||
|
||||
input_index = output_index = nullptr;
|
||||
len = mock_gqa.GetAliasMap(&input_index, &output_index);
|
||||
ASSERT_NE(input_index, nullptr);
|
||||
ASSERT_NE(output_index, nullptr);
|
||||
ASSERT_EQ(input_index[0], 5);
|
||||
ASSERT_EQ(input_index[1], 6);
|
||||
ASSERT_EQ(output_index[0], 7);
|
||||
ASSERT_EQ(output_index[1], 8);
|
||||
ASSERT_EQ(len, static_cast<size_t>(2));
|
||||
mock_gqa.ReleaseAliasMap(input_index, output_index);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue