From e83993bbafce60f3698a43933bb01ddb03610636 Mon Sep 17 00:00:00 2001 From: Aung T Naing Date: Tue, 20 Jun 2023 13:58:56 -0700 Subject: [PATCH] Added MatMul tests for QNN EP (#15956) ### Description Added test coverage for QNN EP MatMul op ### Motivation and Context Created test coverage for HTP based MatMul with broadcasting. --------- Co-authored-by: Hector Li --- .../test/providers/qnn/matmul_test.cpp | 180 ++++++++++++++++++ 1 file changed, 180 insertions(+) create mode 100644 onnxruntime/test/providers/qnn/matmul_test.cpp diff --git a/onnxruntime/test/providers/qnn/matmul_test.cpp b/onnxruntime/test/providers/qnn/matmul_test.cpp new file mode 100644 index 0000000000..cf9f637954 --- /dev/null +++ b/onnxruntime/test/providers/qnn/matmul_test.cpp @@ -0,0 +1,180 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include + +#include "test/optimizer/qdq_test_utils.h" +#include "test/providers/qnn/qnn_test_utils.h" + +#include "onnx/onnx_pb.h" + +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Returns a function that creates a graph with MatMul operator. +static GetTestModelFn BuildMatMulOpTestCase(const std::vector& input1_shape, + const std::vector& input2_shape) { + return [input1_shape, input2_shape](ModelTestBuilder& builder) { + // Random input data + auto input1 = builder.MakeInput(input1_shape, 0.0f, 10.0f); + auto input2 = builder.MakeInput(input2_shape, 0.0f, 10.0f); + + auto* output = builder.MakeOutput(); + builder.AddNode("MatMul", {input1, input2}, {output}); + }; +} + +// Returns a function that creates a graph with a QDQ AveragePool operator. +template +GetQDQTestCaseFn BuildMatMulOpQDQTestCase(const std::vector& input1_shape, + const std::vector& input2_shape) { + return [input1_shape, input2_shape](ModelTestBuilder& builder) { + float pool_output_scale = 0.0038f; + float q_scale = 0.0038f; + QuantType pool_output_zp = std::numeric_limits::max() / 2; + QuantType q_zp = std::numeric_limits::max() / 2; + + auto* input_arg = builder.MakeInput(input1_shape, -1.f, 1.f); + auto* output_arg = builder.MakeOutput(); + + using InputLimits = std::numeric_limits; + + // add QDQ input + auto* q1_output = builder.MakeIntermediate(); + auto* dq1_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input_arg, + pool_output_scale, + pool_output_zp, + q1_output); + builder.AddDequantizeLinearNode(q1_output, + q_scale, + q_zp, + dq1_output); + + // add input b initializer (NNAPI only supports case of MatMul A*B - B is an initializer) + auto* dq_2_output = builder.MakeIntermediate(); + auto* input_b = builder.MakeInitializer(input2_shape, InputLimits::min(), InputLimits::max()); + builder.AddDequantizeLinearNode(input_b, + q_scale, + q_zp, + dq_2_output); + + // add MatMul operator + auto* matmul_op_output = builder.MakeIntermediate(); + builder.AddNode("MatMul", {dq1_output, dq_2_output}, {matmul_op_output}); + + // add QDQ output + auto* q3_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(matmul_op_output, + pool_output_scale, + pool_output_zp, + q3_output); + builder.AddDequantizeLinearNode(q3_output, + q_scale, + q_zp, + output_arg); + }; +} + +// Runs an AveragePool model on the QNN CPU backend. Checks the graph node assignment, and that inference +// outputs for QNN and CPU match. +static void RunMatMulOpOpTest(const std::vector& input1_shape, + const std::vector& input2_shape, + ExpectedEPNodeAssignment expected_ep_assignment, const char* test_description, + int opset = 13) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + constexpr int expected_nodes_in_partition = 1; + RunQnnModelTest(BuildMatMulOpTestCase(input1_shape, input2_shape), + provider_options, + opset, + expected_ep_assignment, + expected_nodes_in_partition, + test_description); +} + +// Runs a QDQ AveragePool model on the QNN HTP backend. Checks the graph node assignment, and that inference +// outputs for QNN and CPU match. +template +static void RunQDQMatMulOpOpTest(const std::vector& input1_shape, + const std::vector& input2_shape, + ExpectedEPNodeAssignment expected_ep_assignment, const char* test_description, + int opset = 18, float fp32_abs_err = 1e-5f) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + constexpr int expected_nodes_in_partition = 1; + RunQnnModelTest(BuildMatMulOpQDQTestCase(input1_shape, input2_shape), + provider_options, + opset, + expected_ep_assignment, + expected_nodes_in_partition, + test_description, + fp32_abs_err); +} + +// +// CPU tests: +// + +TEST_F(QnnCPUBackendTests, TestMatMulOp) { + RunMatMulOpOpTest({2, 2} /* input_shape1 */, + {2, 2} /* input_shape2 */, + ExpectedEPNodeAssignment::All, "TestMatMulOp", 18); +} + +// QNN broadcast issue +TEST_F(QnnCPUBackendTests, DISABLED_TestMatMulOp2) { + RunMatMulOpOpTest({28, 1, 64} /* input_shape1 */, + {64, 32} /* input_shape2 */, + ExpectedEPNodeAssignment::All, "TestMatMulOp2", 18); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +TEST_F(QnnHTPBackendTests, TestMatMulOp_HTP_u8) { + RunQDQMatMulOpOpTest({2, 2} /* input_shape1 */, + {2, 2} /* input_shape2 */, + ExpectedEPNodeAssignment::All, "TestMatMulOp_HTP_u8", + 18, 0.00381f); +} + +// QNN broadcast issue +TEST_F(QnnHTPBackendTests, DISABLED_TestMatMulOp2_HTP_u8) { + RunQDQMatMulOpOpTest({28, 1, 64} /* input_shape1 */, + {64, 32} /* input_shape2 */, + ExpectedEPNodeAssignment::All, "TestMatMulOp2_HTP_u8", + 18, 0.00381f); +} + +// QNN broadcast issue +TEST_F(QnnHTPBackendTests, DISABLED_TestMatMulOp3_HTP_u8) { + RunQDQMatMulOpOpTest({28, 1, 32} /* input_shape1 */, + {32, 2} /* input_shape2 */, + ExpectedEPNodeAssignment::All, "TestMatMulOp3_HTP_u8", + 18, 0.00381f); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +} // namespace test +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD)