mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
[VitisAI] Bug fixes in model_clone (#21950)
### Description <!-- Describe your changes. --> VitisAI bug fixes in model clone ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Co-authored-by: Zhenze Wang <zhenzew@xilinx.com>
This commit is contained in:
parent
cbf3c50d75
commit
bf8a8e7e36
4 changed files with 8 additions and 1 deletions
|
|
@ -953,6 +953,7 @@ struct ProviderHost {
|
|||
virtual const Node* Graph__GetNode(const Graph* p, NodeIndex node_index) const = 0;
|
||||
virtual const NodeArg* Graph__GetNodeArg(const Graph* p, const std::string& name) const = 0;
|
||||
virtual IOnnxRuntimeOpSchemaCollectionPtr Graph__GetSchemaRegistry(const Graph* p) const = 0;
|
||||
virtual bool Graph__SetOpSchemaFromRegistryForNode(Graph* p, Node& node) = 0;
|
||||
|
||||
// GraphViewer
|
||||
virtual void GraphViewer__operator_delete(GraphViewer* p) = 0;
|
||||
|
|
|
|||
|
|
@ -1013,6 +1013,7 @@ struct Graph final {
|
|||
Node* GetNode(NodeIndex node_index) noexcept { return g_host->Graph__GetNode(this, node_index); }
|
||||
const NodeArg* GetNodeArg(const std::string& name) const { return g_host->Graph__GetNodeArg(this, name); }
|
||||
IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const { return g_host->Graph__GetSchemaRegistry(this); }
|
||||
bool SetOpSchemaFromRegistryForNode(Node& node) { return g_host->Graph__SetOpSchemaFromRegistryForNode(this, node); }
|
||||
|
||||
PROVIDER_DISALLOW_ALL(Graph)
|
||||
};
|
||||
|
|
|
|||
|
|
@ -221,7 +221,11 @@ Model* model_clone(const Model& original_model, int64_t external_data_threshold)
|
|||
}
|
||||
}
|
||||
auto ret = Model::Create(std::move(*model_proto), file_path, &local_registries, logger);
|
||||
auto status = ret->MainGraph().Resolve();
|
||||
auto& graph = ret->MainGraph();
|
||||
for (auto node : graph.Nodes()) {
|
||||
graph.SetOpSchemaFromRegistryForNode(*graph.GetNode(node->Index()));
|
||||
}
|
||||
auto status = graph.Resolve();
|
||||
vai_assert(status.IsOK(), status.ErrorMessage());
|
||||
return ret.release();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1143,6 +1143,7 @@ struct ProviderHostImpl : ProviderHost {
|
|||
const Node* Graph__GetNode(const Graph* p, NodeIndex node_index) const override { return p->GetNode(node_index); }
|
||||
const NodeArg* Graph__GetNodeArg(const Graph* p, const std::string& name) const override { return p->GetNodeArg(name); }
|
||||
IOnnxRuntimeOpSchemaCollectionPtr Graph__GetSchemaRegistry(const Graph* p) const override { return p->GetSchemaRegistry(); }
|
||||
bool Graph__SetOpSchemaFromRegistryForNode(Graph* p, Node& node) override { return p->SetOpSchemaFromRegistryForNode(node); }
|
||||
|
||||
// GraphViewer (wrapped)
|
||||
void GraphViewer__operator_delete(GraphViewer* p) override { delete p; }
|
||||
|
|
|
|||
Loading…
Reference in a new issue