mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
[TVM EP] update set input to remove excess copying inside TVM (#11247)
* update TVM * small fixes * update TVM with new set_input and NDArray API * use set_input instead of set_one_input Co-authored-by: Valery Chernov <valery.chernov@deelvin.com>
This commit is contained in:
parent
084165c748
commit
5ae461ec0a
4 changed files with 16 additions and 8 deletions
|
|
@ -46,7 +46,7 @@
|
|||
"component": {
|
||||
"type": "git",
|
||||
"git": {
|
||||
"commitHash": "ffd5f70370642c909222f9a4cae8400023dacbdc",
|
||||
"commitHash": "fafabc96c1ba1a5f987c2402fcc2ce4d1bad5cc8",
|
||||
"repositoryUrl": "https://github.com/apache/tvm.git"
|
||||
},
|
||||
"comments": "needed for TVM EP"
|
||||
|
|
|
|||
2
cmake/external/tvm.cmake
vendored
2
cmake/external/tvm.cmake
vendored
|
|
@ -4,7 +4,7 @@ if (onnxruntime_USE_TVM)
|
|||
FetchContent_Declare(
|
||||
tvm
|
||||
GIT_REPOSITORY https://github.com/apache/tvm.git
|
||||
GIT_TAG ffd5f70370642c909222f9a4cae8400023dacbdc
|
||||
GIT_TAG fafabc96c1ba1a5f987c2402fcc2ce4d1bad5cc8
|
||||
)
|
||||
|
||||
FetchContent_GetProperties(tvm)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ void* TVMAllocator::Alloc(size_t size) {
|
|||
void* p = nullptr;
|
||||
if (size > 0) {
|
||||
DLDataType dl_type{kDLInt, 8, 1};
|
||||
int err = TVMDeviceAllocDataSpace(ctx, size, 128, dl_type, (void**)&p);
|
||||
int err = TVMDeviceAllocDataSpace(ctx, size, TVM_ALLOC_ALIGN, dl_type, (void**)&p);
|
||||
CHECK_EQ(err, 0);
|
||||
return p;
|
||||
}
|
||||
|
|
@ -22,7 +22,7 @@ void* TVMAllocator::Alloc(size_t size) {
|
|||
}
|
||||
|
||||
void TVMAllocator::Free(void* p) {
|
||||
TVMDeviceFreeDataSpace(ctx, p);
|
||||
TVMDeviceFreeDataSpace(ctx, p);
|
||||
}
|
||||
|
||||
} // namespace tvm
|
||||
|
|
|
|||
|
|
@ -70,11 +70,19 @@ void TVM_VM_SetInputs(TvmModule& mod,
|
|||
std::vector<size_t>& inds,
|
||||
std::vector<DLTensor>& inputs)
|
||||
{
|
||||
TvmPackedFunc set_input = mod.GetFunction("set_one_input", false);
|
||||
for (size_t i = 0; i < inds.size(); ++i)
|
||||
{
|
||||
set_input("main", inds[i], &inputs[i]);
|
||||
size_t num_total_args = inputs.size() + 1;
|
||||
std::vector<TVMValue> tvm_values(num_total_args);
|
||||
std::vector<int> tvm_type_codes(num_total_args);
|
||||
::tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data());
|
||||
const std::string func_name = "main";
|
||||
setter(0, func_name.c_str());
|
||||
for (size_t k = 0; k < num_total_args - 1; ++k) {
|
||||
setter(inds[k]+1, &inputs[k]);
|
||||
}
|
||||
|
||||
TvmPackedFunc set_input = mod.GetFunction("set_input", false);
|
||||
::tvm::runtime::TVMRetValue rv;
|
||||
set_input.CallPacked(::tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), num_total_args), &rv);
|
||||
}
|
||||
|
||||
void TVMGetOutputs(TvmModule& mod,
|
||||
|
|
|
|||
Loading…
Reference in a new issue