fix tvm break (#282)

This commit is contained in:
Tang, Cheng 2019-01-07 10:55:24 -08:00 committed by GitHub
parent ec2cf59baa
commit 85ec13f58d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 16 deletions

View file

@ -35,52 +35,58 @@ class IdGenerator {
};
// This is a special compiler step for the test case that sum two 1-D tensors
static void Compile1DAddToTVM(const onnxruntime::Node& node, std::unordered_map<const onnxruntime::NodeArg*, TVMGraph::TensorDescriptor>& tvm_tensors, onnxruntime::ProviderType execution_provider_type, IdGenerator& generator) {
static void Compile1DAddToTVM(const onnxruntime::Node& node, std::unordered_map<std::string, TVMGraph::TensorDescriptor>& tvm_tensors, onnxruntime::ProviderType execution_provider_type, IdGenerator& generator) {
ORT_ENFORCE(node.OpType() == "Add");
tvm::Array<tvm::Expr> shape;
shape.push_back(tvm::var("n1"));
tvm::Tensor t1, t2;
auto it = tvm_tensors.find(node.InputDefs()[0]);
auto it = tvm_tensors.find(node.InputDefs()[0]->Name());
if (it == tvm_tensors.end()) {
tvm_tensors[node.InputDefs()[0]] = TVMGraph::TensorDescriptor(
tvm_tensors[node.InputDefs()[0]->Name()] = TVMGraph::TensorDescriptor(
DataTypeImpl::TypeFromProto(*node.InputDefs()[0]->TypeAsProto()),
execution_provider_type,
tvm::placeholder(shape, tvm::Float(64), "T" + std::to_string(generator.GetNext())));
}
t1 = tvm_tensors[node.InputDefs()[0]].tvm_tensor_;
it = tvm_tensors.find(node.InputDefs()[1]);
t1 = tvm_tensors[node.InputDefs()[0]->Name()].tvm_tensor_;
it = tvm_tensors.find(node.InputDefs()[1]->Name());
if (it == tvm_tensors.end()) {
tvm_tensors[node.InputDefs()[1]] = TVMGraph::TensorDescriptor(
tvm_tensors[node.InputDefs()[1]->Name()] = TVMGraph::TensorDescriptor(
DataTypeImpl::TypeFromProto(*node.InputDefs()[1]->TypeAsProto()),
execution_provider_type,
tvm::placeholder(shape, tvm::Float(64), "T" + std::to_string(generator.GetNext())));
}
t2 = tvm_tensors[node.InputDefs()[1]].tvm_tensor_;
t2 = tvm_tensors[node.InputDefs()[1]->Name()].tvm_tensor_;
tvm_tensors[node.OutputDefs()[0]] = TVMGraph::TensorDescriptor(
tvm_tensors[node.OutputDefs()[0]->Name()] = TVMGraph::TensorDescriptor(
DataTypeImpl::TypeFromProto(*node.InputDefs()[1]->TypeAsProto()),
execution_provider_type,
tvm::compute(t1->shape, [&t1, &t2](tvm::Expr i) {
return t1[i] + t2[i];
},
"T" + std::to_string(generator.GetNext())));
tvm::compute(
t1->shape, [&t1, &t2](tvm::Expr i) {
return t1[i] + t2[i];
},
"T" + std::to_string(generator.GetNext())));
}
TVMGraph CompileToTVM(const onnxruntime::Graph& graph, onnxruntime::ProviderType execution_provider_type) {
TVMGraph result;
std::unordered_map<const onnxruntime::NodeArg*, TVMGraph::TensorDescriptor> tvm_tensors;
std::unordered_map<std::string, TVMGraph::TensorDescriptor> tvm_tensors;
IdGenerator generator;
for (auto& node : graph.Nodes()) {
Compile1DAddToTVM(node, tvm_tensors, execution_provider_type, generator);
}
for (auto& input : graph.GetInputs()) {
result.inputs_.push_back(tvm_tensors[input]);
result.inputs_.push_back(tvm_tensors[input->Name()]);
}
// check initializer
for (auto& initializer : graph.GetAllInitializedTensors()) {
result.inputs_.push_back(tvm_tensors[initializer.first]);
}
auto& output = graph.GetOutputs()[0];
result.outputs_.push_back(tvm_tensors[output]);
result.outputs_.push_back(tvm_tensors[output->Name()]);
return result;
}
} // namespace onnxruntime

View file

@ -205,7 +205,7 @@ class FuseExecutionProviderX : public CPUExecutionProvider {
compute_info.release_state_func = [](FunctionState state) {
if (state)
delete state;
delete static_cast<TVMFuncState*>(state);
};
//we use lambda to capture the tvm model, so we can use it to get the funciton.