Fix regression in constant folding optimizer (#6795)

This commit is contained in:
Hariharan Seshadri 2021-02-24 19:10:14 -08:00 committed by GitHub
parent 40fa40f3ce
commit ab1713f5cc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 32 additions and 2 deletions

View file

@ -182,9 +182,16 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
if (converted_to_constant) {
// Remove single-output node chain for inputs of the node
for (auto p_ip_node = node->InputNodesBegin(); p_ip_node != node->InputNodesEnd(); ++p_ip_node) {
graph_utils::RemoveNodesWithOneOutputBottomUp(graph, *p_ip_node);
auto p_ip_node = node->InputNodesBegin();
const auto p_ip_node_end = node->InputNodesEnd();
while (p_ip_node != p_ip_node_end) {
const auto& input_node = *p_ip_node;
// Update the node iterator before removing the corresponding node because removing
// the node will invalidate the node iterator
++p_ip_node;
graph_utils::RemoveNodesWithOneOutputBottomUp(graph, input_node);
}
// Remove the output edges of the constant node and then remove the node itself.
graph_utils::RemoveNodeOutputEdges(graph, *node);
graph.RemoveNode(node->Index());

View file

@ -363,6 +363,29 @@ TEST_F(GraphTransformationTests, ConstantFoldingWithDequantizeLinear) {
VerifyConstantFoldingWithDequantizeLinear(1, 1, 1, graph, session_options, *logger_);
}
TEST_F(GraphTransformationTests, ConstantFolding_RemoveDanglingInputNodesToConstantFoldedNode) {
auto model_uri = MODEL_FOLDER "fusion/constant_folding_remove_dangling_inputs.onnx";
std::shared_ptr<Model> model;
ASSERT_TRUE(Model::Load(model_uri, model, nullptr, *logger_).IsOK());
Graph& graph = model->MainGraph();
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Shape"] == 1); // Shape node that will be constant folded
ASSERT_TRUE(op_to_count["Add"] == 1); // Input node to Shape
ASSERT_TRUE(op_to_count["RandomUniform"] == 1); // Input node to Add
std::unique_ptr<CPUExecutionProvider> e =
onnxruntime::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1);
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Shape"] == 0);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["RandomUniform"] == 0);
}
TEST_F(GraphTransformationTests, ShapeToInitializer) {
auto model_uri = MODEL_FOLDER "shape-add.onnx";
std::shared_ptr<Model> model;