diff --git a/cgmanifests/cgmanifest.json b/cgmanifests/cgmanifest.json index 379ff6921c..a1b44bc28e 100644 --- a/cgmanifests/cgmanifest.json +++ b/cgmanifests/cgmanifest.json @@ -46,7 +46,7 @@ "component": { "type": "git", "git": { - "commitHash": "ffd5f70370642c909222f9a4cae8400023dacbdc", + "commitHash": "fafabc96c1ba1a5f987c2402fcc2ce4d1bad5cc8", "repositoryUrl": "https://github.com/apache/tvm.git" }, "comments": "needed for TVM EP" diff --git a/cmake/external/tvm.cmake b/cmake/external/tvm.cmake index 3f425a0938..82d7114a66 100644 --- a/cmake/external/tvm.cmake +++ b/cmake/external/tvm.cmake @@ -4,7 +4,7 @@ if (onnxruntime_USE_TVM) FetchContent_Declare( tvm GIT_REPOSITORY https://github.com/apache/tvm.git - GIT_TAG ffd5f70370642c909222f9a4cae8400023dacbdc + GIT_TAG fafabc96c1ba1a5f987c2402fcc2ce4d1bad5cc8 ) FetchContent_GetProperties(tvm) diff --git a/onnxruntime/core/providers/tvm/tvm_allocator.cc b/onnxruntime/core/providers/tvm/tvm_allocator.cc index ef06e1f59a..8793676421 100644 --- a/onnxruntime/core/providers/tvm/tvm_allocator.cc +++ b/onnxruntime/core/providers/tvm/tvm_allocator.cc @@ -14,7 +14,7 @@ void* TVMAllocator::Alloc(size_t size) { void* p = nullptr; if (size > 0) { DLDataType dl_type{kDLInt, 8, 1}; - int err = TVMDeviceAllocDataSpace(ctx, size, 128, dl_type, (void**)&p); + int err = TVMDeviceAllocDataSpace(ctx, size, TVM_ALLOC_ALIGN, dl_type, (void**)&p); CHECK_EQ(err, 0); return p; } @@ -22,7 +22,7 @@ void* TVMAllocator::Alloc(size_t size) { } void TVMAllocator::Free(void* p) { - TVMDeviceFreeDataSpace(ctx, p); + TVMDeviceFreeDataSpace(ctx, p); } } // namespace tvm diff --git a/onnxruntime/core/providers/tvm/tvm_api.cc b/onnxruntime/core/providers/tvm/tvm_api.cc index ff61c6c43d..0c5e07b302 100644 --- a/onnxruntime/core/providers/tvm/tvm_api.cc +++ b/onnxruntime/core/providers/tvm/tvm_api.cc @@ -70,11 +70,19 @@ void TVM_VM_SetInputs(TvmModule& mod, std::vector& inds, std::vector& inputs) { - TvmPackedFunc set_input = mod.GetFunction("set_one_input", false); - for (size_t i = 0; i < inds.size(); ++i) - { - set_input("main", inds[i], &inputs[i]); + size_t num_total_args = inputs.size() + 1; + std::vector tvm_values(num_total_args); + std::vector tvm_type_codes(num_total_args); + ::tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data()); + const std::string func_name = "main"; + setter(0, func_name.c_str()); + for (size_t k = 0; k < num_total_args - 1; ++k) { + setter(inds[k]+1, &inputs[k]); } + + TvmPackedFunc set_input = mod.GetFunction("set_input", false); + ::tvm::runtime::TVMRetValue rv; + set_input.CallPacked(::tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), num_total_args), &rv); } void TVMGetOutputs(TvmModule& mod,