mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-28 03:20:58 +00:00
Fix regression in constant folding optimizer (#6795)
This commit is contained in:
parent
40fa40f3ce
commit
ab1713f5cc
3 changed files with 32 additions and 2 deletions
|
|
@ -182,9 +182,16 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
|
||||
if (converted_to_constant) {
|
||||
// Remove single-output node chain for inputs of the node
|
||||
for (auto p_ip_node = node->InputNodesBegin(); p_ip_node != node->InputNodesEnd(); ++p_ip_node) {
|
||||
graph_utils::RemoveNodesWithOneOutputBottomUp(graph, *p_ip_node);
|
||||
auto p_ip_node = node->InputNodesBegin();
|
||||
const auto p_ip_node_end = node->InputNodesEnd();
|
||||
while (p_ip_node != p_ip_node_end) {
|
||||
const auto& input_node = *p_ip_node;
|
||||
// Update the node iterator before removing the corresponding node because removing
|
||||
// the node will invalidate the node iterator
|
||||
++p_ip_node;
|
||||
graph_utils::RemoveNodesWithOneOutputBottomUp(graph, input_node);
|
||||
}
|
||||
|
||||
// Remove the output edges of the constant node and then remove the node itself.
|
||||
graph_utils::RemoveNodeOutputEdges(graph, *node);
|
||||
graph.RemoveNode(node->Index());
|
||||
|
|
|
|||
|
|
@ -363,6 +363,29 @@ TEST_F(GraphTransformationTests, ConstantFoldingWithDequantizeLinear) {
|
|||
VerifyConstantFoldingWithDequantizeLinear(1, 1, 1, graph, session_options, *logger_);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, ConstantFolding_RemoveDanglingInputNodesToConstantFoldedNode) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/constant_folding_remove_dangling_inputs.onnx";
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_TRUE(Model::Load(model_uri, model, nullptr, *logger_).IsOK());
|
||||
Graph& graph = model->MainGraph();
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Shape"] == 1); // Shape node that will be constant folded
|
||||
ASSERT_TRUE(op_to_count["Add"] == 1); // Input node to Shape
|
||||
ASSERT_TRUE(op_to_count["RandomUniform"] == 1); // Input node to Add
|
||||
|
||||
std::unique_ptr<CPUExecutionProvider> e =
|
||||
onnxruntime::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1);
|
||||
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Shape"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 0);
|
||||
ASSERT_TRUE(op_to_count["RandomUniform"] == 0);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, ShapeToInitializer) {
|
||||
auto model_uri = MODEL_FOLDER "shape-add.onnx";
|
||||
std::shared_ptr<Model> model;
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/constant_folding_remove_dangling_inputs.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/constant_folding_remove_dangling_inputs.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue