From ab1713f5ccdd643603b30fb7d526b2cf69db9c45 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Wed, 24 Feb 2021 19:10:14 -0800 Subject: [PATCH] Fix regression in constant folding optimizer (#6795) --- .../core/optimizer/constant_folding.cc | 11 +++++++-- .../test/optimizer/graph_transform_test.cc | 23 ++++++++++++++++++ ...nstant_folding_remove_dangling_inputs.onnx | Bin 0 -> 344 bytes 3 files changed, 32 insertions(+), 2 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/constant_folding_remove_dangling_inputs.onnx diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index 1c53bba193..047aa6f7c1 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -182,9 +182,16 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, if (converted_to_constant) { // Remove single-output node chain for inputs of the node - for (auto p_ip_node = node->InputNodesBegin(); p_ip_node != node->InputNodesEnd(); ++p_ip_node) { - graph_utils::RemoveNodesWithOneOutputBottomUp(graph, *p_ip_node); + auto p_ip_node = node->InputNodesBegin(); + const auto p_ip_node_end = node->InputNodesEnd(); + while (p_ip_node != p_ip_node_end) { + const auto& input_node = *p_ip_node; + // Update the node iterator before removing the corresponding node because removing + // the node will invalidate the node iterator + ++p_ip_node; + graph_utils::RemoveNodesWithOneOutputBottomUp(graph, input_node); } + // Remove the output edges of the constant node and then remove the node itself. graph_utils::RemoveNodeOutputEdges(graph, *node); graph.RemoveNode(node->Index()); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 8befa2ae28..9a01db9ee4 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -363,6 +363,29 @@ TEST_F(GraphTransformationTests, ConstantFoldingWithDequantizeLinear) { VerifyConstantFoldingWithDequantizeLinear(1, 1, 1, graph, session_options, *logger_); } +TEST_F(GraphTransformationTests, ConstantFolding_RemoveDanglingInputNodesToConstantFoldedNode) { + auto model_uri = MODEL_FOLDER "fusion/constant_folding_remove_dangling_inputs.onnx"; + std::shared_ptr model; + ASSERT_TRUE(Model::Load(model_uri, model, nullptr, *logger_).IsOK()); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Shape"] == 1); // Shape node that will be constant folded + ASSERT_TRUE(op_to_count["Add"] == 1); // Input node to Shape + ASSERT_TRUE(op_to_count["RandomUniform"] == 1); // Input node to Add + + std::unique_ptr e = + onnxruntime::make_unique(CPUExecutionProviderInfo()); + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Shape"] == 0); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["RandomUniform"] == 0); +} + TEST_F(GraphTransformationTests, ShapeToInitializer) { auto model_uri = MODEL_FOLDER "shape-add.onnx"; std::shared_ptr model; diff --git a/onnxruntime/test/testdata/transform/fusion/constant_folding_remove_dangling_inputs.onnx b/onnxruntime/test/testdata/transform/fusion/constant_folding_remove_dangling_inputs.onnx new file mode 100644 index 0000000000000000000000000000000000000000..9b3d4e5d4f7b0d82d2d5609df6a89dd3057b0b71 GIT binary patch literal 344 zcmZutO-sW-6zto4tKK`na`_9THyKH3yLReKNwFNzm0%W5NywAqqurT&}z75|x= zst3Ub!whfc@a9pU1?vx^>z4(0lJkqpPl96x2|v77)m*P0tMXOXD;bM}KZ__tuUV*d zF(p%H>b4Nz5ki6xGf262HJ{5wIF|4jA!U&9;Gv7lRCqgq8G{kWnXVe6DkC#Ne|0oK z?_Di7Mbr<2TiW}X7F;|-$Ph*h#+=?LvnXmg68^V}_R&?_vKjPRAb5&kd*s}osm92$ za669nkQhCHE@01%5po6-?%i#cG8c{|7VdsAd@Iec*sgVLo=&<_3_}X6kAxBn$HINK L#^gWanuPr?*G@*- literal 0 HcmV?d00001