diff --git a/include/onnxruntime/core/graph/rewrite_rule.h b/include/onnxruntime/core/graph/rewrite_rule.h index d98c3779d9..db979bdfd2 100644 --- a/include/onnxruntime/core/graph/rewrite_rule.h +++ b/include/onnxruntime/core/graph/rewrite_rule.h @@ -8,58 +8,6 @@ namespace onnxruntime { -/** -@class GraphEditor -The API for Graph rewrite rules. -*/ -class GraphEditor { - public: - explicit GraphEditor(Graph& graph) noexcept : graph_{graph} {} - - /** Add a node to this Graph */ - Node& AddNode(const std::string& name, - const std::string& op_type, - const std::string& description, - const std::vector& input_args, - const std::vector& output_args, - const std::string& domain = "") { - return graph_.AddNode(name, op_type, description, - input_args, output_args, nullptr, domain); - } - - /** Copy an existing node into the Graph. */ - Node& AddNode(const Node& other) { - return graph_.AddNode(other); - } - - /** Remove a node from the Graph. */ - bool RemoveNode(NodeIndex node_index) { - return graph_.RemoveNode(node_index); - } - - /** Add a control edge between two Nodes in the Graph - The node does not consume any data output by , so there is no input/output edge between them, - but dst must executed after src so a control edge is required. - @param src NodeIndex from the Graph of the Node which must execute first. - @param dst NodeIndex from the Graph of the Node which must execute after src. - */ - bool AddControlEdge(NodeIndex src, NodeIndex dst) { - return graph_.AddControlEdge(src, dst); - } - - /** Resolve the Graph. - @returns Status with success or error information. - @remarks Resolve must be called after modifying the Graph is completed. */ - common::Status Resolve() { - return graph_.Resolve(); - } - - private: - ONNXRUNTIME_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphEditor); - - Graph& graph_; -}; - /** @class RewriteRule @@ -94,8 +42,8 @@ class RewriteRule { @param[in] node The Node to apply the rewrite to. @param[out] modified Set to indicate whether the node was modified or not. @returns Status indicating success or providing error information */ - common::Status CheckConditionAndApply(GraphEditor& graph_editor, Node& node, bool& modified) { - return SatisfyCondition(node) ? Apply(graph_editor, node, modified) : Status::OK(); + common::Status CheckConditionAndApply(Graph& graph, Node& node, bool& modified) { + return SatisfyCondition(node) ? Apply(graph, node, modified) : Status::OK(); } private: @@ -115,6 +63,6 @@ class RewriteRule { The transformation happens in-place. The return-value of node may be different from the input-value due to rewriting. The value of "modified" indicates if the graph was modified or not. */ - virtual common::Status Apply(GraphEditor& graph_editor, Node& node, bool& modified) = 0; + virtual common::Status Apply(Graph& graph, Node& node, bool& modified) = 0; }; } // namespace onnxruntime diff --git a/onnxruntime/core/graph/function_inliner.h b/onnxruntime/core/graph/function_inliner.h index 1d277836c7..e3b5f927db 100644 --- a/onnxruntime/core/graph/function_inliner.h +++ b/onnxruntime/core/graph/function_inliner.h @@ -18,7 +18,7 @@ class FunctionInliner : public onnxruntime::RewriteRule { FunctionInliner(const std::string& name, const std::string& desc) : RewriteRule(name, desc) {} - Status Apply(onnxruntime::GraphEditor /*graph_editor*/, onnxruntime::Node* /*node*/, bool* /*modified*/) override { + Status Apply(onnxruntime::Graph/*graph_editor*/, onnxruntime::Node* /*node*/, bool* /*modified*/) override { return Status::OK(); } }; diff --git a/onnxruntime/core/graph/graph_transformer.cc b/onnxruntime/core/graph/graph_transformer.cc index 54fc8b96a8..bcb5fa03bf 100644 --- a/onnxruntime/core/graph/graph_transformer.cc +++ b/onnxruntime/core/graph/graph_transformer.cc @@ -20,7 +20,6 @@ Status TopDownRuleBasedTransformer::Apply(Graph& graph, bool& modified) const { ONNXRUNTIME_RETURN_IF_ERROR(graph.Resolve()); GraphViewer graph_viewer(graph); auto& order = graph_viewer.GetNodesInTopologicalOrder(); - GraphEditor graph_editor(graph); for (NodeIndex i : order) { auto node = graph.GetNode(i); @@ -34,7 +33,7 @@ Status TopDownRuleBasedTransformer::Apply(Graph& graph, bool& modified) const { continue; for (const auto& rule : *rules) { - ONNXRUNTIME_RETURN_IF_ERROR(rule->CheckConditionAndApply(graph_editor, *node, modified)); + ONNXRUNTIME_RETURN_IF_ERROR(rule->CheckConditionAndApply(graph, *node, modified)); } } diff --git a/onnxruntime/core/graph/identity_elimination.cc b/onnxruntime/core/graph/identity_elimination.cc index af9bf2c7be..969d136626 100644 --- a/onnxruntime/core/graph/identity_elimination.cc +++ b/onnxruntime/core/graph/identity_elimination.cc @@ -9,7 +9,7 @@ namespace onnxruntime { -Status EliminateIdentity::Apply(GraphEditor& graph_editor, Node& node, bool& modified) { +Status EliminateIdentity::Apply(Graph& graph_editor, Node& node, bool& modified) { std::map replacement_defs; auto id_input = node.InputDefs()[0]; auto id_output = node.OutputDefs()[0]; diff --git a/onnxruntime/core/graph/identity_elimination.h b/onnxruntime/core/graph/identity_elimination.h index d1c7ab240f..57b16b5655 100644 --- a/onnxruntime/core/graph/identity_elimination.h +++ b/onnxruntime/core/graph/identity_elimination.h @@ -15,7 +15,7 @@ class EliminateIdentity : public RewriteRule { private: bool SatisfyCondition(const Node& node) override; - Status Apply(GraphEditor& graph_editor, Node& node, bool& modified) override; + Status Apply(Graph& graph_editor, Node& node, bool& modified) override; }; } // namespace onnxruntime