NOTE(review): this patch was reconstructed from a copy whose line breaks and
angle-bracketed template arguments had been stripped. The restored template
arguments and the context-line indentation are inferred - confirm against upstream.

diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
index 838a03ca5f..398ce6aa97 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
@@ -51,6 +51,16 @@ bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Nod
     return false;
   }
 
+  // Reject the node group when any DQ node's output is also a graph output
+  // (the accompanying DQ_Produce_Graph_Output test expects such groups not to be fused).
+  auto does_node_produce_graph_output = [&graph_viewer](const Node* node_ptr) {
+    return graph_viewer.NodeProducesGraphOutput(*node_ptr);
+  };
+
+  if (std::any_of(dq_nodes.begin(), dq_nodes.end(), does_node_produce_graph_output)) {
+    return false;
+  }
+
   if (q_nodes.empty()) {
     return is_empty_q_nodes_allowed;
   }
diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc
index 039bce599a..b253273c5b 100644
--- a/onnxruntime/test/optimizer/qdq_transformer_test.cc
+++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc
@@ -2760,5 +2760,62 @@ TEST(QDQTransformerTests, QDQFinalCleanupTransformer_GraphInputToOutput) {
   test_case(false);
 }
 
+// A QDQ node group must not be fused when its DQ node also produces a graph output:
+// the Softmax below has to stay un-fused and both Q/DQ pairs must remain in the graph.
+TEST(QDQTransformerTests, DQ_Produce_Graph_Output) {
+  auto test_case = [&](const std::vector<int64_t>& input_shape, int64_t axis) {
+    auto build_test_case = [&](ModelTestBuilder& builder) {
+      auto* input_arg = builder.MakeInput<float>(input_shape, -5.f, 5.f);
+      auto* dq_output_arg = builder.MakeOutput();  // input DQ result doubles as a graph output
+      auto* output_arg = builder.MakeOutput();
+      // add input QDQ
+      auto* input_q_output = builder.MakeIntermediate();
+      builder.AddQuantizeLinearNode<uint8_t>(input_arg,
+                                             .105f,
+                                             127,
+                                             input_q_output);
+      builder.AddDequantizeLinearNode<uint8_t>(input_q_output,
+                                               .105f,
+                                               127,
+                                               dq_output_arg);
+
+      // add Softmax
+      auto* softmax_output = builder.MakeIntermediate();
+      auto& softmax_node = builder.AddNode("Softmax", {dq_output_arg}, {softmax_output});
+      softmax_node.AddAttribute("axis", axis);
+
+      // add output QDQ
+      auto* q_output = builder.MakeIntermediate();
+      builder.AddQuantizeLinearNode<uint8_t>(softmax_output,
+                                             1.0f / (std::numeric_limits<uint8_t>::max() + 1),
+                                             0,
+                                             q_output);
+      builder.AddDequantizeLinearNode<uint8_t>(q_output,
+                                               1.0f / (std::numeric_limits<uint8_t>::max() + 1),
+                                               0,
+                                               output_arg);
+    };
+
+    auto check_graph = [&](InferenceSessionWrapper& session) {
+      auto op_to_count = CountOpsInGraph(session.GetGraph());
+      EXPECT_EQ(op_to_count["com.microsoft.QLinearSoftmax"], 0);
+      EXPECT_EQ(op_to_count["Softmax"], 1);
+      EXPECT_EQ(op_to_count["QuantizeLinear"], 2);
+      EXPECT_EQ(op_to_count["DequantizeLinear"], 2);
+    };
+
+    TransformerTester(build_test_case,
+                      check_graph,
+                      TransformerLevel::Level1,
+                      TransformerLevel::Level2,
+                      12 /*opset_version*/,
+                      0.01 /*per_sample_tolerance*/,
+                      0.01 /*relative_per_sample_tolerance*/,
+                      std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
+  };
+
+  test_case({1, 12, 37}, -1);
+}
+
 }  // namespace test
 }  // namespace onnxruntime