From 85ec13f58df7ecec3c81b46b11ec90fa8902f364 Mon Sep 17 00:00:00 2001 From: "Tang, Cheng" Date: Mon, 7 Jan 2019 10:55:24 -0800 Subject: [PATCH] fix tvm break (#282) --- onnxruntime/core/codegen/tvm/tvm_compiler.cc | 36 ++++++++++++-------- onnxruntime/test/tvm/tvm_basic_test.cc | 2 +- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/codegen/tvm/tvm_compiler.cc b/onnxruntime/core/codegen/tvm/tvm_compiler.cc index 133b64780f..fba99dfeda 100644 --- a/onnxruntime/core/codegen/tvm/tvm_compiler.cc +++ b/onnxruntime/core/codegen/tvm/tvm_compiler.cc @@ -35,52 +35,58 @@ class IdGenerator { }; // This is a special compiler step for the test case that sum two 1-D tensors -static void Compile1DAddToTVM(const onnxruntime::Node& node, std::unordered_map& tvm_tensors, onnxruntime::ProviderType execution_provider_type, IdGenerator& generator) { +static void Compile1DAddToTVM(const onnxruntime::Node& node, std::unordered_map& tvm_tensors, onnxruntime::ProviderType execution_provider_type, IdGenerator& generator) { ORT_ENFORCE(node.OpType() == "Add"); tvm::Array shape; shape.push_back(tvm::var("n1")); tvm::Tensor t1, t2; - auto it = tvm_tensors.find(node.InputDefs()[0]); + auto it = tvm_tensors.find(node.InputDefs()[0]->Name()); if (it == tvm_tensors.end()) { - tvm_tensors[node.InputDefs()[0]] = TVMGraph::TensorDescriptor( + tvm_tensors[node.InputDefs()[0]->Name()] = TVMGraph::TensorDescriptor( DataTypeImpl::TypeFromProto(*node.InputDefs()[0]->TypeAsProto()), execution_provider_type, tvm::placeholder(shape, tvm::Float(64), "T" + std::to_string(generator.GetNext()))); } - t1 = tvm_tensors[node.InputDefs()[0]].tvm_tensor_; - it = tvm_tensors.find(node.InputDefs()[1]); + t1 = tvm_tensors[node.InputDefs()[0]->Name()].tvm_tensor_; + it = tvm_tensors.find(node.InputDefs()[1]->Name()); if (it == tvm_tensors.end()) { - tvm_tensors[node.InputDefs()[1]] = TVMGraph::TensorDescriptor( + tvm_tensors[node.InputDefs()[1]->Name()] = TVMGraph::TensorDescriptor( DataTypeImpl::TypeFromProto(*node.InputDefs()[1]->TypeAsProto()), execution_provider_type, tvm::placeholder(shape, tvm::Float(64), "T" + std::to_string(generator.GetNext()))); } - t2 = tvm_tensors[node.InputDefs()[1]].tvm_tensor_; + t2 = tvm_tensors[node.InputDefs()[1]->Name()].tvm_tensor_; - tvm_tensors[node.OutputDefs()[0]] = TVMGraph::TensorDescriptor( + tvm_tensors[node.OutputDefs()[0]->Name()] = TVMGraph::TensorDescriptor( DataTypeImpl::TypeFromProto(*node.InputDefs()[1]->TypeAsProto()), execution_provider_type, - tvm::compute(t1->shape, [&t1, &t2](tvm::Expr i) { - return t1[i] + t2[i]; - }, - "T" + std::to_string(generator.GetNext()))); + tvm::compute( + t1->shape, [&t1, &t2](tvm::Expr i) { + return t1[i] + t2[i]; + }, + "T" + std::to_string(generator.GetNext()))); } TVMGraph CompileToTVM(const onnxruntime::Graph& graph, onnxruntime::ProviderType execution_provider_type) { TVMGraph result; - std::unordered_map tvm_tensors; + std::unordered_map tvm_tensors; IdGenerator generator; for (auto& node : graph.Nodes()) { Compile1DAddToTVM(node, tvm_tensors, execution_provider_type, generator); } for (auto& input : graph.GetInputs()) { - result.inputs_.push_back(tvm_tensors[input]); + result.inputs_.push_back(tvm_tensors[input->Name()]); + } + + // check initializer + for (auto& initializer : graph.GetAllInitializedTensors()) { + result.inputs_.push_back(tvm_tensors[initializer.first]); } auto& output = graph.GetOutputs()[0]; - result.outputs_.push_back(tvm_tensors[output]); + result.outputs_.push_back(tvm_tensors[output->Name()]); return result; } } // namespace onnxruntime diff --git a/onnxruntime/test/tvm/tvm_basic_test.cc b/onnxruntime/test/tvm/tvm_basic_test.cc index 2c2c8ef8df..6924232797 100644 --- a/onnxruntime/test/tvm/tvm_basic_test.cc +++ b/onnxruntime/test/tvm/tvm_basic_test.cc @@ -205,7 +205,7 @@ class FuseExecutionProviderX : public CPUExecutionProvider { compute_info.release_state_func = [](FunctionState state) { if (state) - delete state; + delete static_cast(state); }; //we use lambda to capture the tvm model, so we can use it to get the funciton.