mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-11 00:49:31 +00:00
remove graph editor since it's not designed as expected to restrict graph access for rewrite rule. (#119)
This commit is contained in:
parent
194d0a98d1
commit
830c341c19
5 changed files with 7 additions and 60 deletions
|
|
@ -8,58 +8,6 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
@class GraphEditor
|
||||
The API for Graph rewrite rules.
|
||||
*/
|
||||
class GraphEditor {
|
||||
public:
|
||||
explicit GraphEditor(Graph& graph) noexcept : graph_{graph} {}
|
||||
|
||||
/** Add a node to this Graph */
|
||||
Node& AddNode(const std::string& name,
|
||||
const std::string& op_type,
|
||||
const std::string& description,
|
||||
const std::vector<NodeArg*>& input_args,
|
||||
const std::vector<NodeArg*>& output_args,
|
||||
const std::string& domain = "") {
|
||||
return graph_.AddNode(name, op_type, description,
|
||||
input_args, output_args, nullptr, domain);
|
||||
}
|
||||
|
||||
/** Copy an existing node into the Graph. */
|
||||
Node& AddNode(const Node& other) {
|
||||
return graph_.AddNode(other);
|
||||
}
|
||||
|
||||
/** Remove a node from the Graph. */
|
||||
bool RemoveNode(NodeIndex node_index) {
|
||||
return graph_.RemoveNode(node_index);
|
||||
}
|
||||
|
||||
/** Add a control edge between two Nodes in the Graph
|
||||
The <dst> node does not consume any data output by <src>, so there is no input/output edge between them,
|
||||
but dst must executed after src so a control edge is required.
|
||||
@param src NodeIndex from the Graph of the Node which must execute first.
|
||||
@param dst NodeIndex from the Graph of the Node which must execute after src.
|
||||
*/
|
||||
bool AddControlEdge(NodeIndex src, NodeIndex dst) {
|
||||
return graph_.AddControlEdge(src, dst);
|
||||
}
|
||||
|
||||
/** Resolve the Graph.
|
||||
@returns Status with success or error information.
|
||||
@remarks Resolve must be called after modifying the Graph is completed. */
|
||||
common::Status Resolve() {
|
||||
return graph_.Resolve();
|
||||
}
|
||||
|
||||
private:
|
||||
ONNXRUNTIME_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphEditor);
|
||||
|
||||
Graph& graph_;
|
||||
};
|
||||
|
||||
/**
|
||||
@class RewriteRule
|
||||
|
||||
|
|
@ -94,8 +42,8 @@ class RewriteRule {
|
|||
@param[in] node The Node to apply the rewrite to.
|
||||
@param[out] modified Set to indicate whether the node was modified or not.
|
||||
@returns Status indicating success or providing error information */
|
||||
common::Status CheckConditionAndApply(GraphEditor& graph_editor, Node& node, bool& modified) {
|
||||
return SatisfyCondition(node) ? Apply(graph_editor, node, modified) : Status::OK();
|
||||
common::Status CheckConditionAndApply(Graph& graph, Node& node, bool& modified) {
|
||||
return SatisfyCondition(node) ? Apply(graph, node, modified) : Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -115,6 +63,6 @@ class RewriteRule {
|
|||
The transformation happens in-place. The return-value of node may be different
|
||||
from the input-value due to rewriting.
|
||||
The value of "modified" indicates if the graph was modified or not. */
|
||||
virtual common::Status Apply(GraphEditor& graph_editor, Node& node, bool& modified) = 0;
|
||||
virtual common::Status Apply(Graph& graph, Node& node, bool& modified) = 0;
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class FunctionInliner : public onnxruntime::RewriteRule {
|
|||
FunctionInliner(const std::string& name, const std::string& desc)
|
||||
: RewriteRule(name, desc) {}
|
||||
|
||||
Status Apply(onnxruntime::GraphEditor /*graph_editor*/, onnxruntime::Node* /*node*/, bool* /*modified*/) override {
|
||||
Status Apply(onnxruntime::Graph/*graph_editor*/, onnxruntime::Node* /*node*/, bool* /*modified*/) override {
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ Status TopDownRuleBasedTransformer::Apply(Graph& graph, bool& modified) const {
|
|||
ONNXRUNTIME_RETURN_IF_ERROR(graph.Resolve());
|
||||
GraphViewer graph_viewer(graph);
|
||||
auto& order = graph_viewer.GetNodesInTopologicalOrder();
|
||||
GraphEditor graph_editor(graph);
|
||||
|
||||
for (NodeIndex i : order) {
|
||||
auto node = graph.GetNode(i);
|
||||
|
|
@ -34,7 +33,7 @@ Status TopDownRuleBasedTransformer::Apply(Graph& graph, bool& modified) const {
|
|||
continue;
|
||||
|
||||
for (const auto& rule : *rules) {
|
||||
ONNXRUNTIME_RETURN_IF_ERROR(rule->CheckConditionAndApply(graph_editor, *node, modified));
|
||||
ONNXRUNTIME_RETURN_IF_ERROR(rule->CheckConditionAndApply(graph, *node, modified));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
Status EliminateIdentity::Apply(GraphEditor& graph_editor, Node& node, bool& modified) {
|
||||
Status EliminateIdentity::Apply(Graph& graph_editor, Node& node, bool& modified) {
|
||||
std::map<const NodeArg*, NodeArg*> replacement_defs;
|
||||
auto id_input = node.InputDefs()[0];
|
||||
auto id_output = node.OutputDefs()[0];
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ class EliminateIdentity : public RewriteRule {
|
|||
private:
|
||||
bool SatisfyCondition(const Node& node) override;
|
||||
|
||||
Status Apply(GraphEditor& graph_editor, Node& node, bool& modified) override;
|
||||
Status Apply(Graph& graph_editor, Node& node, bool& modified) override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue