remove graph editor since it's not designed as expected to restrict graph access for rewrite rule. (#119)

This commit is contained in:
Ke Zhang 2018-12-06 14:10:51 -08:00 committed by GitHub
parent 194d0a98d1
commit 830c341c19
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 7 additions and 60 deletions

View file

@ -8,58 +8,6 @@
namespace onnxruntime {
/**
@class GraphEditor
The API for Graph rewrite rules.
*/
class GraphEditor {
public:
explicit GraphEditor(Graph& graph) noexcept : graph_{graph} {}
/** Add a node to this Graph */
Node& AddNode(const std::string& name,
const std::string& op_type,
const std::string& description,
const std::vector<NodeArg*>& input_args,
const std::vector<NodeArg*>& output_args,
const std::string& domain = "") {
return graph_.AddNode(name, op_type, description,
input_args, output_args, nullptr, domain);
}
/** Copy an existing node into the Graph. */
Node& AddNode(const Node& other) {
return graph_.AddNode(other);
}
/** Remove a node from the Graph. */
bool RemoveNode(NodeIndex node_index) {
return graph_.RemoveNode(node_index);
}
/** Add a control edge between two Nodes in the Graph
The <dst> node does not consume any data output by <src>, so there is no input/output edge between them,
but dst must executed after src so a control edge is required.
@param src NodeIndex from the Graph of the Node which must execute first.
@param dst NodeIndex from the Graph of the Node which must execute after src.
*/
bool AddControlEdge(NodeIndex src, NodeIndex dst) {
return graph_.AddControlEdge(src, dst);
}
/** Resolve the Graph.
@returns Status with success or error information.
@remarks Resolve must be called after modifying the Graph is completed. */
common::Status Resolve() {
return graph_.Resolve();
}
private:
ONNXRUNTIME_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphEditor);
Graph& graph_;
};
/**
@class RewriteRule
@ -94,8 +42,8 @@ class RewriteRule {
@param[in] node The Node to apply the rewrite to.
@param[out] modified Set to indicate whether the node was modified or not.
@returns Status indicating success or providing error information */
common::Status CheckConditionAndApply(GraphEditor& graph_editor, Node& node, bool& modified) {
return SatisfyCondition(node) ? Apply(graph_editor, node, modified) : Status::OK();
common::Status CheckConditionAndApply(Graph& graph, Node& node, bool& modified) {
return SatisfyCondition(node) ? Apply(graph, node, modified) : Status::OK();
}
private:
@ -115,6 +63,6 @@ class RewriteRule {
The transformation happens in-place. The return-value of node may be different
from the input-value due to rewriting.
The value of "modified" indicates if the graph was modified or not. */
virtual common::Status Apply(GraphEditor& graph_editor, Node& node, bool& modified) = 0;
virtual common::Status Apply(Graph& graph, Node& node, bool& modified) = 0;
};
} // namespace onnxruntime

View file

@ -18,7 +18,7 @@ class FunctionInliner : public onnxruntime::RewriteRule {
FunctionInliner(const std::string& name, const std::string& desc)
: RewriteRule(name, desc) {}
Status Apply(onnxruntime::GraphEditor /*graph_editor*/, onnxruntime::Node* /*node*/, bool* /*modified*/) override {
Status Apply(onnxruntime::Graph/*graph_editor*/, onnxruntime::Node* /*node*/, bool* /*modified*/) override {
return Status::OK();
}
};

View file

@ -20,7 +20,6 @@ Status TopDownRuleBasedTransformer::Apply(Graph& graph, bool& modified) const {
ONNXRUNTIME_RETURN_IF_ERROR(graph.Resolve());
GraphViewer graph_viewer(graph);
auto& order = graph_viewer.GetNodesInTopologicalOrder();
GraphEditor graph_editor(graph);
for (NodeIndex i : order) {
auto node = graph.GetNode(i);
@ -34,7 +33,7 @@ Status TopDownRuleBasedTransformer::Apply(Graph& graph, bool& modified) const {
continue;
for (const auto& rule : *rules) {
ONNXRUNTIME_RETURN_IF_ERROR(rule->CheckConditionAndApply(graph_editor, *node, modified));
ONNXRUNTIME_RETURN_IF_ERROR(rule->CheckConditionAndApply(graph, *node, modified));
}
}

View file

@ -9,7 +9,7 @@
namespace onnxruntime {
Status EliminateIdentity::Apply(GraphEditor& graph_editor, Node& node, bool& modified) {
Status EliminateIdentity::Apply(Graph& graph_editor, Node& node, bool& modified) {
std::map<const NodeArg*, NodeArg*> replacement_defs;
auto id_input = node.InputDefs()[0];
auto id_output = node.OutputDefs()[0];

View file

@ -15,7 +15,7 @@ class EliminateIdentity : public RewriteRule {
private:
bool SatisfyCondition(const Node& node) override;
Status Apply(GraphEditor& graph_editor, Node& node, bool& modified) override;
Status Apply(Graph& graph_editor, Node& node, bool& modified) override;
};
} // namespace onnxruntime