mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
fix tvm break (#282)
This commit is contained in:
parent
ec2cf59baa
commit
85ec13f58d
2 changed files with 22 additions and 16 deletions
|
|
@ -35,52 +35,58 @@ class IdGenerator {
|
|||
};
|
||||
|
||||
// This is a special compiler step for the test case that sum two 1-D tensors
|
||||
static void Compile1DAddToTVM(const onnxruntime::Node& node, std::unordered_map<const onnxruntime::NodeArg*, TVMGraph::TensorDescriptor>& tvm_tensors, onnxruntime::ProviderType execution_provider_type, IdGenerator& generator) {
|
||||
static void Compile1DAddToTVM(const onnxruntime::Node& node, std::unordered_map<std::string, TVMGraph::TensorDescriptor>& tvm_tensors, onnxruntime::ProviderType execution_provider_type, IdGenerator& generator) {
|
||||
ORT_ENFORCE(node.OpType() == "Add");
|
||||
tvm::Array<tvm::Expr> shape;
|
||||
shape.push_back(tvm::var("n1"));
|
||||
|
||||
tvm::Tensor t1, t2;
|
||||
auto it = tvm_tensors.find(node.InputDefs()[0]);
|
||||
auto it = tvm_tensors.find(node.InputDefs()[0]->Name());
|
||||
if (it == tvm_tensors.end()) {
|
||||
tvm_tensors[node.InputDefs()[0]] = TVMGraph::TensorDescriptor(
|
||||
tvm_tensors[node.InputDefs()[0]->Name()] = TVMGraph::TensorDescriptor(
|
||||
DataTypeImpl::TypeFromProto(*node.InputDefs()[0]->TypeAsProto()),
|
||||
execution_provider_type,
|
||||
tvm::placeholder(shape, tvm::Float(64), "T" + std::to_string(generator.GetNext())));
|
||||
}
|
||||
t1 = tvm_tensors[node.InputDefs()[0]].tvm_tensor_;
|
||||
it = tvm_tensors.find(node.InputDefs()[1]);
|
||||
t1 = tvm_tensors[node.InputDefs()[0]->Name()].tvm_tensor_;
|
||||
it = tvm_tensors.find(node.InputDefs()[1]->Name());
|
||||
if (it == tvm_tensors.end()) {
|
||||
tvm_tensors[node.InputDefs()[1]] = TVMGraph::TensorDescriptor(
|
||||
tvm_tensors[node.InputDefs()[1]->Name()] = TVMGraph::TensorDescriptor(
|
||||
DataTypeImpl::TypeFromProto(*node.InputDefs()[1]->TypeAsProto()),
|
||||
execution_provider_type,
|
||||
tvm::placeholder(shape, tvm::Float(64), "T" + std::to_string(generator.GetNext())));
|
||||
}
|
||||
t2 = tvm_tensors[node.InputDefs()[1]].tvm_tensor_;
|
||||
t2 = tvm_tensors[node.InputDefs()[1]->Name()].tvm_tensor_;
|
||||
|
||||
tvm_tensors[node.OutputDefs()[0]] = TVMGraph::TensorDescriptor(
|
||||
tvm_tensors[node.OutputDefs()[0]->Name()] = TVMGraph::TensorDescriptor(
|
||||
DataTypeImpl::TypeFromProto(*node.InputDefs()[1]->TypeAsProto()),
|
||||
execution_provider_type,
|
||||
tvm::compute(t1->shape, [&t1, &t2](tvm::Expr i) {
|
||||
return t1[i] + t2[i];
|
||||
},
|
||||
"T" + std::to_string(generator.GetNext())));
|
||||
tvm::compute(
|
||||
t1->shape, [&t1, &t2](tvm::Expr i) {
|
||||
return t1[i] + t2[i];
|
||||
},
|
||||
"T" + std::to_string(generator.GetNext())));
|
||||
}
|
||||
|
||||
TVMGraph CompileToTVM(const onnxruntime::Graph& graph, onnxruntime::ProviderType execution_provider_type) {
|
||||
TVMGraph result;
|
||||
std::unordered_map<const onnxruntime::NodeArg*, TVMGraph::TensorDescriptor> tvm_tensors;
|
||||
std::unordered_map<std::string, TVMGraph::TensorDescriptor> tvm_tensors;
|
||||
IdGenerator generator;
|
||||
for (auto& node : graph.Nodes()) {
|
||||
Compile1DAddToTVM(node, tvm_tensors, execution_provider_type, generator);
|
||||
}
|
||||
|
||||
for (auto& input : graph.GetInputs()) {
|
||||
result.inputs_.push_back(tvm_tensors[input]);
|
||||
result.inputs_.push_back(tvm_tensors[input->Name()]);
|
||||
}
|
||||
|
||||
// check initializer
|
||||
for (auto& initializer : graph.GetAllInitializedTensors()) {
|
||||
result.inputs_.push_back(tvm_tensors[initializer.first]);
|
||||
}
|
||||
|
||||
auto& output = graph.GetOutputs()[0];
|
||||
result.outputs_.push_back(tvm_tensors[output]);
|
||||
result.outputs_.push_back(tvm_tensors[output->Name()]);
|
||||
return result;
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -205,7 +205,7 @@ class FuseExecutionProviderX : public CPUExecutionProvider {
|
|||
|
||||
compute_info.release_state_func = [](FunctionState state) {
|
||||
if (state)
|
||||
delete state;
|
||||
delete static_cast<TVMFuncState*>(state);
|
||||
};
|
||||
|
||||
//we use lambda to capture the tvm model, so we can use it to get the funciton.
|
||||
|
|
|
|||
Loading…
Reference in a new issue