diff --git a/onnxruntime/core/optimizer/common_subexpression_elimination.cc b/onnxruntime/core/optimizer/common_subexpression_elimination.cc index cabf227cf7..e94d8179e7 100644 --- a/onnxruntime/core/optimizer/common_subexpression_elimination.cc +++ b/onnxruntime/core/optimizer/common_subexpression_elimination.cc @@ -318,7 +318,13 @@ struct NodeArgPtrEquality { }; bool IsNodeSupported(const Node& node) { - return !node.ContainsSubgraph() && optimizer_utils::IsOperationDeterministic(node.Domain(), node.OpType()); + // skip control flow nodes, nodes that produce non-deterministic output, and DequantizeLinear (DQ) nodes. + // the reason for skipping DQ is that the QDQ handling looks for QDQ node groups (DQ -> fp32 node -> Q node) + // and does not allow for a DQ node to be used in multiple groups. coalescing multiple DQ nodes into one + // would result in it having multiple consumers for its output, and it being used in multiple QDQ node groups. + return !node.ContainsSubgraph() && + optimizer_utils::IsOperationDeterministic(node.Domain(), node.OpType()) && + !(node.Domain() == kOnnxDomain && node.OpType() == "DequantizeLinear"); } } // namespace diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 7e46387ce5..f07a664ba0 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -149,6 +149,24 @@ TEST_F(GraphTransformationTests, IdentityWithSharedNodeArgNotEliminated) { ASSERT_TRUE(op_to_count["Add"] == 1); } +TEST_F(GraphTransformationTests, DequantizeLinearNodeNotEliminated) { + auto model_uri = MODEL_FOLDER "qdq_with_multi_consumer_dq_nodes.fixed.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["DequantizeLinear"], 25); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), + TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + // CommonSubexpressionElimination should skip the DequantizeLinear nodes + op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["DequantizeLinear"], 25); +} + TEST_F(GraphTransformationTests, IdentityInputIsGraphOutputNotEliminated) { auto model_uri = MODEL_FOLDER "scan9_sum.onnx"; std::shared_ptr model; diff --git a/onnxruntime/test/testdata/transform/qdq_with_multi_consumer_dq_nodes.fixed.onnx b/onnxruntime/test/testdata/transform/qdq_with_multi_consumer_dq_nodes.fixed.onnx new file mode 100644 index 0000000000..a74c9d8ff5 Binary files /dev/null and b/onnxruntime/test/testdata/transform/qdq_with_multi_consumer_dq_nodes.fixed.onnx differ diff --git a/onnxruntime/test/testdata/transform/qdq_with_multi_consumer_dq_nodes.fixed.txt b/onnxruntime/test/testdata/transform/qdq_with_multi_consumer_dq_nodes.fixed.txt new file mode 100644 index 0000000000..ec0e17dda0 --- /dev/null +++ b/onnxruntime/test/testdata/transform/qdq_with_multi_consumer_dq_nodes.fixed.txt @@ -0,0 +1,6 @@ +Model was created by running + +python -m onnxruntime.tools.qdq_helpers.optimize_qdq_model \onnxruntime\test\testdata\qdq_with_multi_consumer_dq_nodes.onnx \onnxruntime\test\testdata\transform +qdq_with_multi_consumer_dq_nodes.fixed.onnx + +This results in a model that has duplicated DequantizeLinear (DQ) nodes so that each QDQ node group has no shared nodes. The CommonSubexpressionElimination optimizer could potentially combine these duplicated DQ nodes, which would break the QDQ handling. Due to this we ignore DQ nodes in that optimizer. \ No newline at end of file