From 4f303412539bcc51f0d8a09e2a510b799afe558d Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Tue, 30 Mar 2021 09:00:48 -0700 Subject: [PATCH] Check the count of DequantizeLinear for matmul (#7174) --- .../optimizer/qdq_transformer/qdq_matmul.cc | 2 +- .../test/optimizer/qdq_transformer_test.cc | 33 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_matmul.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_matmul.cc index 6b7a578556..770b435151 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_matmul.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_matmul.cc @@ -13,7 +13,7 @@ class QDQMatMulTransformer : public QDQOperatorTransformer { QDQMatMulTransformer(Node& node, Graph& graph) : QDQOperatorTransformer(node, graph) {} bool Transform(const std::vector& parents, const std::vector& children) override { - if (children.size() != 1) { + if (parents.size() != 2 || children.size() != 1) { return false; } diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index d056a4b7b8..0cbbeacd24 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -231,6 +231,39 @@ TEST(QDQTransformerTests, MatMul) { test_case({22, 11, 13, 15}, {15, 13}); } +TEST(QDQTransformerTests, MatMul_No_Fusion) { + auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, -1.f, 1.f); + auto* input2_arg = builder.MakeInput(input2_shape, -1.f, 1.f); + auto* output_arg = builder.MakeOutput(); + + // add QDQ + MatMul + auto* matmul_output = builder.MakeIntermediate(); + auto* dq_matmul_output1 = AddQDQNodePair(builder, input1_arg, .004f, 129); + builder.AddNode("MatMul", {dq_matmul_output1, input2_arg}, {matmul_output}); + + // add Q + builder.AddQuantizeLinearNode(matmul_output, .0039f, 135, output_arg); + }; + + auto check_matmul_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count["QLinearMatMul"], 0); + EXPECT_EQ(op_to_count["QuantizeLinear"], 2); + EXPECT_EQ(op_to_count["DequantizeLinear"], 1); + }; + + TransformerTester(build_test_case, check_matmul_graph, TransformerLevel::Level1, TransformerLevel::Level2); + }; + + // Test the basic case of a single 1D/2D/3D convolution. + test_case({12, 37}, {37, 12}); + test_case({23, 13, 13}, {13, 13}); + test_case({22, 11, 13, 15}, {15, 13}); +} + #endif // DISABLE_CONTRIB_OPS } // namespace test