diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 95b5fe849a..4f062efb09 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -953,6 +953,7 @@ struct ProviderHost { virtual const Node* Graph__GetNode(const Graph* p, NodeIndex node_index) const = 0; virtual const NodeArg* Graph__GetNodeArg(const Graph* p, const std::string& name) const = 0; virtual IOnnxRuntimeOpSchemaCollectionPtr Graph__GetSchemaRegistry(const Graph* p) const = 0; + virtual bool Graph__SetOpSchemaFromRegistryForNode(Graph* p, Node& node) = 0; // GraphViewer virtual void GraphViewer__operator_delete(GraphViewer* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 5b052bdc24..63ef36b0a7 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -1013,6 +1013,7 @@ struct Graph final { Node* GetNode(NodeIndex node_index) noexcept { return g_host->Graph__GetNode(this, node_index); } const NodeArg* GetNodeArg(const std::string& name) const { return g_host->Graph__GetNodeArg(this, name); } IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const { return g_host->Graph__GetSchemaRegistry(this); } + bool SetOpSchemaFromRegistryForNode(Node& node) { return g_host->Graph__SetOpSchemaFromRegistryForNode(this, node); } PROVIDER_DISALLOW_ALL(Graph) }; diff --git a/onnxruntime/core/providers/vitisai/imp/graph.cc b/onnxruntime/core/providers/vitisai/imp/graph.cc index 683a7c6e2a..191d26f3ab 100644 --- a/onnxruntime/core/providers/vitisai/imp/graph.cc +++ b/onnxruntime/core/providers/vitisai/imp/graph.cc @@ -221,7 +221,11 @@ Model* model_clone(const Model& original_model, int64_t external_data_threshold) } } auto ret = Model::Create(std::move(*model_proto), file_path, &local_registries, logger); - auto status = ret->MainGraph().Resolve(); + auto& graph = ret->MainGraph(); + for (auto node : graph.Nodes()) { + graph.SetOpSchemaFromRegistryForNode(*graph.GetNode(node->Index())); + } + auto status = graph.Resolve(); vai_assert(status.IsOK(), status.ErrorMessage()); return ret.release(); } diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 1cb39f0521..8e807c3751 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1143,6 +1143,7 @@ struct ProviderHostImpl : ProviderHost { const Node* Graph__GetNode(const Graph* p, NodeIndex node_index) const override { return p->GetNode(node_index); } const NodeArg* Graph__GetNodeArg(const Graph* p, const std::string& name) const override { return p->GetNodeArg(name); } IOnnxRuntimeOpSchemaCollectionPtr Graph__GetSchemaRegistry(const Graph* p) const override { return p->GetSchemaRegistry(); } + bool Graph__SetOpSchemaFromRegistryForNode(Graph* p, Node& node) override { return p->SetOpSchemaFromRegistryForNode(node); } // GraphViewer (wrapped) void GraphViewer__operator_delete(GraphViewer* p) override { delete p; }