[VitisAI] Bug fixes in model_clone (#21950)

### Description
<!-- Describe your changes. -->

VitisAI bug fixes in model clone

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Co-authored-by: Zhenze Wang <zhenzew@xilinx.com>
This commit is contained in:
zz002 2024-09-05 01:29:17 +08:00 committed by GitHub
parent cbf3c50d75
commit bf8a8e7e36
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 8 additions and 1 deletions

View file

@ -953,6 +953,7 @@ struct ProviderHost {
virtual const Node* Graph__GetNode(const Graph* p, NodeIndex node_index) const = 0;
virtual const NodeArg* Graph__GetNodeArg(const Graph* p, const std::string& name) const = 0;
virtual IOnnxRuntimeOpSchemaCollectionPtr Graph__GetSchemaRegistry(const Graph* p) const = 0;
virtual bool Graph__SetOpSchemaFromRegistryForNode(Graph* p, Node& node) = 0;
// GraphViewer
virtual void GraphViewer__operator_delete(GraphViewer* p) = 0;

View file

@ -1013,6 +1013,7 @@ struct Graph final {
Node* GetNode(NodeIndex node_index) noexcept { return g_host->Graph__GetNode(this, node_index); }
const NodeArg* GetNodeArg(const std::string& name) const { return g_host->Graph__GetNodeArg(this, name); }
IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const { return g_host->Graph__GetSchemaRegistry(this); }
bool SetOpSchemaFromRegistryForNode(Node& node) { return g_host->Graph__SetOpSchemaFromRegistryForNode(this, node); }
PROVIDER_DISALLOW_ALL(Graph)
};

View file

@ -221,7 +221,11 @@ Model* model_clone(const Model& original_model, int64_t external_data_threshold)
}
}
auto ret = Model::Create(std::move(*model_proto), file_path, &local_registries, logger);
auto status = ret->MainGraph().Resolve();
auto& graph = ret->MainGraph();
for (auto node : graph.Nodes()) {
graph.SetOpSchemaFromRegistryForNode(*graph.GetNode(node->Index()));
}
auto status = graph.Resolve();
vai_assert(status.IsOK(), status.ErrorMessage());
return ret.release();
}

View file

@ -1143,6 +1143,7 @@ struct ProviderHostImpl : ProviderHost {
const Node* Graph__GetNode(const Graph* p, NodeIndex node_index) const override { return p->GetNode(node_index); }
const NodeArg* Graph__GetNodeArg(const Graph* p, const std::string& name) const override { return p->GetNodeArg(name); }
IOnnxRuntimeOpSchemaCollectionPtr Graph__GetSchemaRegistry(const Graph* p) const override { return p->GetSchemaRegistry(); }
bool Graph__SetOpSchemaFromRegistryForNode(Graph* p, Node& node) override { return p->SetOpSchemaFromRegistryForNode(node); }
// GraphViewer (wrapped)
void GraphViewer__operator_delete(GraphViewer* p) override { delete p; }