From 7abe1fd3920686a01d490a6530f2437fb2f221ad Mon Sep 17 00:00:00 2001 From: ytaous <4484531+ytaous@users.noreply.github.com> Date: Mon, 19 Apr 2021 16:36:35 -0700 Subject: [PATCH] Identity elimination with graph output (#7312) * Identity removal * fix build * fix build * fix build * fix builld * UTs * fix UT * fix UTs * per comments * fix UTs * fix UTs * per comments Co-authored-by: Ethan Tao --- .../core/optimizer/identity_elimination.cc | 81 ++++++++++++++++-- .../test/optimizer/graph_transform_test.cc | 64 ++++++++++++++ .../test/testdata/transform/id-elim.onnx | Bin 0 -> 288 bytes .../test/testdata/transform/id-elim.py | 41 +++++++++ .../test/testdata/transform/id-scan9_sum.py | 61 +++++++++++++ .../test/testdata/transform/scan9_sum.onnx | Bin 0 -> 354 bytes .../python/orttraining_test_ortmodule_api.py | 24 ++++++ 7 files changed, 266 insertions(+), 5 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/id-elim.onnx create mode 100644 onnxruntime/test/testdata/transform/id-elim.py create mode 100644 onnxruntime/test/testdata/transform/id-scan9_sum.py create mode 100644 onnxruntime/test/testdata/transform/scan9_sum.onnx diff --git a/onnxruntime/core/optimizer/identity_elimination.cc b/onnxruntime/core/optimizer/identity_elimination.cc index 944d01928d..77a58f7c57 100644 --- a/onnxruntime/core/optimizer/identity_elimination.cc +++ b/onnxruntime/core/optimizer/identity_elimination.cc @@ -10,16 +10,87 @@ namespace onnxruntime { -Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { - if (graph_utils::RemoveNode(graph, node)) { - rule_effect = RewriteRuleEffect::kRemovedCurrentNode; - } +/** + Case to eliminate Identity node when + - the input nodearg has only one consumer, which is the Identity itself + - the input def is not a graph output + + For examples: + OK to eliminate: + + Identity output is another node, and the Identity is the only consumer of X + X ---> Identity ---> Y where Y could be graph output + + Identity input arg is not shared with other output arg of X + + (arg0) ---> Identity0 ---> Z + | + X (arg1) ---> Identity1 ---> Y + + Not OK to eliminate: + + Identity input arg, i.e., arg0, is also an input arg of other Identity + + (arg0) ---> Identity0 ---> Z + | + X (arg0) ---> Identity1 ---> Y + + Identity input def, i.e., def0, is also a graph output + + (def0) ---> Z where Z is graph output + | + X (def0/arg0) ---> Identity ---> Y + */ +Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { + if (graph.GetNodeOutputsInGraphOutputs(node).empty()) { + if (graph_utils::RemoveNode(graph, node)) { + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + } + } else { + // keep a reference of output def to the graph output + NodeArg* output = node.MutableOutputDefs()[0]; + const Node* p_input_node = graph_utils::GetInputNode(node, 0); + // get mutable input node + Node& input_node = *graph.GetNode(p_input_node->Index()); + int output_idx = graph_utils::GetNodeOutputIndexFromOutputName(input_node, node.MutableInputDefs()[0]->Name()); + // remove Identity node and its input edge + graph.RemoveNode(node.Index()); + // update input node's output def to the graph output + input_node.MutableOutputDefs()[output_idx] = output; + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + } return Status::OK(); } bool EliminateIdentity::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const { - return graph_utils::CanRemoveNode(graph, node, logger); + if (graph_utils::CanRemoveNode(graph, node, logger)) { + return true; + } + + // relax the condition if Identity is connecting to graph output + if (node.GetOutputEdgesCount() != 0 || node.OutputDefs().size() != 1 || + graph.GetNodeOutputsInGraphOutputs(node).empty()) + return false; + + const Node* p_input_node = graph_utils::GetInputNode(node, 0); + if (p_input_node == nullptr) + return false; + + // skip if the src arg is also a graph output + int src_arg_index = graph_utils::GetNodeOutputIndexFromOutputName(*p_input_node, node.InputDefs()[0]->Name()); + if (graph.IsOutput(p_input_node->OutputDefs()[src_arg_index])) + return false; + + // count how many consumers are sharing the same src arg + int count = 0; + for (auto it = p_input_node->OutputEdgesBegin(), end = p_input_node->OutputEdgesEnd(); it != end; ++it) { + if (it->GetSrcArgIndex() == src_arg_index) { + count++; + } + } + // condition not met if there are more than 1 consumer for the same src arg + if (count > 1) + return false; + + return true; } } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 5a98c9d2db..fa79304ef3 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -100,6 +100,70 @@ TEST_F(GraphTransformationTests, IdentityElimination) { ASSERT_TRUE(op_to_count["Identity"] == 0); } +TEST_F(GraphTransformationTests, IdentityEliminationWithGraphOutput) { + auto model_uri = MODEL_FOLDER "abs-id.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Identity"] == 1); + + auto rule_transformer_L1 = onnxruntime::make_unique("RuleTransformer1"); + rule_transformer_L1->Register(onnxruntime::make_unique()); + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Identity"] == 0); +} + +TEST_F(GraphTransformationTests, IdentityWithSharedNodeArgNotEliminated) { + auto model_uri = MODEL_FOLDER "id-elim.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Identity"] == 2); + ASSERT_TRUE(op_to_count["Add"] == 2); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level1); + auto rule_transformer_L1 = onnxruntime::make_unique("RuleTransformer1"); + rule_transformer_L1->Register(onnxruntime::make_unique()); + graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + // after CommonSubexpressionElimination, Add would have 1 output def and 2 edges + // each edge would share the same input node arg 0. Thus after execution, only one of the 2 outputs + // has data. Thus skip. + op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Identity"] == 2); + ASSERT_TRUE(op_to_count["Add"] == 1); +} + +TEST_F(GraphTransformationTests, IdentityInputIsGraphOutputNotEliminated) { + auto model_uri = MODEL_FOLDER "scan9_sum.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Identity"] == 1); + + // tips: to dump the subgraph, can use python tool - dump_subgraphs.py + // or click on one of the input to see the drop down graph list and view subgraph + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = onnxruntime::make_unique("RuleTransformer1"); + rule_transformer_L1->Register(onnxruntime::make_unique()); + graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + // Identity's input in subgraph is also graph output. Thus skip. + op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Identity"] == 1); +} + TEST_F(GraphTransformationTests, DropoutElimination) { auto model_uri = MODEL_FOLDER "dropout.onnx"; std::shared_ptr model; diff --git a/onnxruntime/test/testdata/transform/id-elim.onnx b/onnxruntime/test/testdata/transform/id-elim.onnx new file mode 100644 index 0000000000000000000000000000000000000000..95c7ad12ebfb406c8c829d878b92d83e92d5c811 GIT binary patch literal 288 zcmd;J7vjm!%d5~$tw_u*$Vs*O!pJ4b#Z+Mk#706ai76?DQeaYv*)b&rP09!?WdxCu z;{vM_V$UxvDJU&5lw!_IF;wF4Oi9fv$tjnmV%9)cG_9nrIEyJv7i$n8%R8j9>9Zm3ymnB&8H|alj