diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index 1c53bba193..047aa6f7c1 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -182,9 +182,16 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, if (converted_to_constant) { // Remove single-output node chain for inputs of the node - for (auto p_ip_node = node->InputNodesBegin(); p_ip_node != node->InputNodesEnd(); ++p_ip_node) { - graph_utils::RemoveNodesWithOneOutputBottomUp(graph, *p_ip_node); + auto p_ip_node = node->InputNodesBegin(); + const auto p_ip_node_end = node->InputNodesEnd(); + while (p_ip_node != p_ip_node_end) { + const auto& input_node = *p_ip_node; + // Update the node iterator before removing the corresponding node because removing + // the node will invalidate the node iterator + ++p_ip_node; + graph_utils::RemoveNodesWithOneOutputBottomUp(graph, input_node); } + // Remove the output edges of the constant node and then remove the node itself. graph_utils::RemoveNodeOutputEdges(graph, *node); graph.RemoveNode(node->Index()); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 8befa2ae28..9a01db9ee4 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -363,6 +363,29 @@ TEST_F(GraphTransformationTests, ConstantFoldingWithDequantizeLinear) { VerifyConstantFoldingWithDequantizeLinear(1, 1, 1, graph, session_options, *logger_); } +TEST_F(GraphTransformationTests, ConstantFolding_RemoveDanglingInputNodesToConstantFoldedNode) { + auto model_uri = MODEL_FOLDER "fusion/constant_folding_remove_dangling_inputs.onnx"; + std::shared_ptr model; + ASSERT_TRUE(Model::Load(model_uri, model, nullptr, *logger_).IsOK()); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Shape"] == 1); // Shape node that will be constant folded + ASSERT_TRUE(op_to_count["Add"] == 1); // Input node to Shape + ASSERT_TRUE(op_to_count["RandomUniform"] == 1); // Input node to Add + + std::unique_ptr e = + onnxruntime::make_unique(CPUExecutionProviderInfo()); + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Shape"] == 0); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["RandomUniform"] == 0); +} + TEST_F(GraphTransformationTests, ShapeToInitializer) { auto model_uri = MODEL_FOLDER "shape-add.onnx"; std::shared_ptr model; diff --git a/onnxruntime/test/testdata/transform/fusion/constant_folding_remove_dangling_inputs.onnx b/onnxruntime/test/testdata/transform/fusion/constant_folding_remove_dangling_inputs.onnx new file mode 100644 index 0000000000..9b3d4e5d4f Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/constant_folding_remove_dangling_inputs.onnx differ