[TVM EP] update set input to remove excess copying inside TVM (#11247)

* update TVM

* small fixes

* update TVM with new set_input and NDArray API

* use set_input instead of set_one_input

Co-authored-by: Valery Chernov <valery.chernov@deelvin.com>
This commit is contained in:
Valery Chernov 2022-05-05 15:25:02 +03:00 committed by GitHub
parent 084165c748
commit 5ae461ec0a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 16 additions and 8 deletions

View file

@ -46,7 +46,7 @@
"component": {
"type": "git",
"git": {
"commitHash": "ffd5f70370642c909222f9a4cae8400023dacbdc",
"commitHash": "fafabc96c1ba1a5f987c2402fcc2ce4d1bad5cc8",
"repositoryUrl": "https://github.com/apache/tvm.git"
},
"comments": "needed for TVM EP"

View file

@ -4,7 +4,7 @@ if (onnxruntime_USE_TVM)
FetchContent_Declare(
tvm
GIT_REPOSITORY https://github.com/apache/tvm.git
GIT_TAG ffd5f70370642c909222f9a4cae8400023dacbdc
GIT_TAG fafabc96c1ba1a5f987c2402fcc2ce4d1bad5cc8
)
FetchContent_GetProperties(tvm)

View file

@ -14,7 +14,7 @@ void* TVMAllocator::Alloc(size_t size) {
void* p = nullptr;
if (size > 0) {
DLDataType dl_type{kDLInt, 8, 1};
int err = TVMDeviceAllocDataSpace(ctx, size, 128, dl_type, (void**)&p);
int err = TVMDeviceAllocDataSpace(ctx, size, TVM_ALLOC_ALIGN, dl_type, (void**)&p);
CHECK_EQ(err, 0);
return p;
}
@ -22,7 +22,7 @@ void* TVMAllocator::Alloc(size_t size) {
}
void TVMAllocator::Free(void* p) {
TVMDeviceFreeDataSpace(ctx, p);
TVMDeviceFreeDataSpace(ctx, p);
}
} // namespace tvm

View file

@ -70,11 +70,19 @@ void TVM_VM_SetInputs(TvmModule& mod,
std::vector<size_t>& inds,
std::vector<DLTensor>& inputs)
{
TvmPackedFunc set_input = mod.GetFunction("set_one_input", false);
for (size_t i = 0; i < inds.size(); ++i)
{
set_input("main", inds[i], &inputs[i]);
size_t num_total_args = inputs.size() + 1;
std::vector<TVMValue> tvm_values(num_total_args);
std::vector<int> tvm_type_codes(num_total_args);
::tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data());
const std::string func_name = "main";
setter(0, func_name.c_str());
for (size_t k = 0; k < num_total_args - 1; ++k) {
setter(inds[k]+1, &inputs[k]);
}
TvmPackedFunc set_input = mod.GetFunction("set_input", false);
::tvm::runtime::TVMRetValue rv;
set_input.CallPacked(::tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), num_total_args), &rv);
}
void TVMGetOutputs(TvmModule& mod,