add API function GetAliasMap and ReleaseAliasMap in OrtCustomOp (#20145)

### Description
<!-- Describe your changes. -->
Add API function GetAliasMap and ReleaseAliasMap in OrtCustomOp 


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Add API function GetAliasMap and ReleaseAliasMap in OrtCustomOp
This commit is contained in:
cao lei 2024-03-29 13:49:56 -07:00 committed by GitHub
parent 8396845806
commit 604b284261
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 43 additions and 0 deletions

View file

@ -4738,6 +4738,10 @@ struct OrtCustomOp {
// Release the pointer input_index and output_index allocated from GetMayInplace() function.
// If GetMayInplace() is defined, this function MUST be defined as well.
void(ORT_API_CALL* ReleaseMayInplace)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index);
// Same as GetMayInplace() and ReleaseMayInplace()
size_t(ORT_API_CALL* GetAliasMap)(_Out_ int** input_index, _Out_ int** output_index);
void(ORT_API_CALL* ReleaseAliasMap)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index);
};
/*

View file

@ -2304,6 +2304,8 @@ struct CustomOpBase : OrtCustomOp {
OrtCustomOp::GetMayInplace = nullptr;
OrtCustomOp::ReleaseMayInplace = nullptr;
OrtCustomOp::GetAliasMap = nullptr;
OrtCustomOp::ReleaseAliasMap = nullptr;
}
// Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider

View file

@ -865,6 +865,8 @@ struct OrtLiteCustomOp : public OrtCustomOp {
OrtCustomOp::GetMayInplace = {};
OrtCustomOp::ReleaseMayInplace = {};
OrtCustomOp::GetAliasMap = {};
OrtCustomOp::ReleaseAliasMap = {};
}
const std::string op_name_;

View file

@ -875,6 +875,16 @@ KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCust
}
}
if (op->version >= 18 && op->GetAliasMap != nullptr) {
int* input_index = nullptr;
int* output_index = nullptr;
size_t len = op->GetAliasMap(&input_index, &output_index);
if (len > 0) {
for (size_t i = 0; i < len; i++) def_builder.Alias(input_index[i], output_index[i]);
op->ReleaseAliasMap(input_index, output_index);
}
}
KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info,
std::unique_ptr<OpKernel>& out) -> Status {
out = std::make_unique<CustomOpKernel>(info, *op);

View file

@ -4039,6 +4039,20 @@ struct MockGQA : public OrtCustomOp {
free(input_index);
free(output_index);
};
OrtCustomOp::GetAliasMap = [](int** input_index, int** output_index) {
size_t ret = 2;
*input_index = static_cast<int*>(malloc(ret * sizeof(int)));
(*input_index)[0] = 5;
(*input_index)[1] = 6;
*output_index = static_cast<int*>(malloc(ret * sizeof(int)));
(*output_index)[0] = 7;
(*output_index)[1] = 8;
return ret;
};
OrtCustomOp::ReleaseAliasMap = [](int* input_index, int* output_index) {
free(input_index);
free(output_index);
};
}
};
@ -4055,4 +4069,15 @@ TEST(CApiTest, OrtCustomOp_GetInPlace) {
ASSERT_EQ(output_index[1], 2);
ASSERT_EQ(len, static_cast<size_t>(2));
mock_gqa.ReleaseMayInplace(input_index, output_index);
input_index = output_index = nullptr;
len = mock_gqa.GetAliasMap(&input_index, &output_index);
ASSERT_NE(input_index, nullptr);
ASSERT_NE(output_index, nullptr);
ASSERT_EQ(input_index[0], 5);
ASSERT_EQ(input_index[1], 6);
ASSERT_EQ(output_index[0], 7);
ASSERT_EQ(output_index[1], 8);
ASSERT_EQ(len, static_cast<size_t>(2));
mock_gqa.ReleaseAliasMap(input_index, output_index);
}