From f7bf5a19baf0a7caa9cca7dc08bf192e392a14e4 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Thu, 12 Sep 2024 17:18:50 -0700 Subject: [PATCH] [QNN EP] Ensure QNN EP rejects nodes with I/O of dynamic shape (#22066) ### Description Updates QNN EP to properly reject nodes that have inputs or outputs with dynamic shapes. ### Motivation and Context Currently, QNN EP does not properly offload subgraphs with dynamic shapes to the CPU EP. This PR ensures that QNN EP rejects nodes that consume or generate I/O with dynamic shapes. --- .../qnn/builder/qnn_model_wrapper.cc | 4 +- .../test/providers/qnn/qnn_basic_test.cc | 57 +++++++++++++++++++ .../test/providers/qnn/qnn_test_utils.cc | 4 +- .../test/providers/qnn/qnn_test_utils.h | 6 +- 4 files changed, 68 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index 3c029fda9c..2c7f3c8b22 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -308,8 +308,10 @@ bool QnnModelWrapper::GetOnnxShape(const NodeArg& node_arg, std::vectordim()) { + if (!dim.has_dim_value()) { + return false; // Do not support dynamic shapes. + } shape.push_back(SafeInt(dim.dim_value())); } diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 9d19c36dc9..c4367aeb52 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -948,6 +948,63 @@ TEST_F(QnnHTPBackendTests, Float32ModelWithFP16PrecisionTest) { 0.008f); } +// Test that QNN EP only handles nodes with static shapes and rejects nodes with dynamic shape I/O. +TEST_F(QnnHTPBackendTests, EPRejectsDynamicShapesF32) { + // Local function that builds a model in which the last two nodes use dynamic shapes. 
+ auto model_build_fn = [](ModelTestBuilder& builder) { + NodeArg* input1 = builder.MakeInput(std::vector{1, 2, 8, 8}, + GetFloatDataInRange(0.0f, 1.0f, 128)); + NodeArg* input2 = builder.MakeInput(std::vector{3}, std::vector{1, 2, 49}); + + // Add a Conv with known shapes. QNN EP should support it. + NodeArg* weight = builder.MakeInitializer(std::vector{2, 2, 2, 2}, + GetFloatDataInRange(-0.3f, 0.3f, 16)); + NodeArg* bias = builder.MakeInitializer(std::vector{2}, {0.0f, 1.0f}); + + auto* conv_output = builder.MakeIntermediate(); + builder.AddNode("Conv", {input1, weight, bias}, {conv_output}); + + // Add a Reshape to a dynamic shape. QNN EP should reject this node. + auto* reshape_output = builder.MakeIntermediate(); + builder.AddNode("Reshape", {conv_output, input2}, {reshape_output}); + + // Add a Softmax. QNN EP should reject this node because its input has a dynamic shape. + NodeArg* output = builder.MakeOutput(); + builder.AddNode("Softmax", {reshape_output}, {output}); + }; + + // Local function that checks that the nodes with dynamic shape I/O were assigned to CPU EP. + std::function ep_graph_checker = [](const Graph& graph) { + for (const Node& node : graph.Nodes()) { + const std::string& ep_name = node.GetExecutionProviderType(); + const std::string& op_type = node.OpType(); + if (op_type == "Reshape" || op_type == "Softmax") { + EXPECT_EQ(ep_name, kCpuExecutionProvider); + } else { + EXPECT_EQ(ep_name, kQnnExecutionProvider); + } + } + }; + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + provider_options["enable_htp_fp16_precision"] = "1"; // QNN EP will use fp16 precision. + // CPU EP will use fp32, so we can relax accuracy requirements. 
+ + RunQnnModelTest(model_build_fn, + provider_options, + /*opset*/ 19, + ExpectedEPNodeAssignment::Some, + /*abs_err*/ 1e-4f, + logging::Severity::kERROR, + /*verify_output*/ true, + &ep_graph_checker); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index afaa5a341d..8a4f7f2a1f 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -98,10 +98,12 @@ void TryEnableQNNSaver(ProviderOptions& qnn_options) { void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions provider_options, int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, - float fp32_abs_err, logging::Severity log_severity, bool verify_outputs) { + float fp32_abs_err, logging::Severity log_severity, bool verify_outputs, + std::function<void(const Graph&)>* ep_graph_checker) { EPVerificationParams verification_params; verification_params.ep_node_assignment = expected_ep_assignment; verification_params.fp32_abs_err = fp32_abs_err; + verification_params.graph_verifier = ep_graph_checker; // Add kMSDomain to cover contrib op like Gelu const std::unordered_map<std::string, int> domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index 3a6753e9b6..bb77c92668 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -1033,12 +1033,16 @@ inline GetTestQDQModelFn BuildQDQOpTestCase( * \param expected_ep_assignment How many nodes are expected to be assigned to QNN (All, Some, or None). * \param fp32_abs_err The acceptable error between CPU EP and QNN EP. * \param log_severity The logger's minimum severity level. + * \param verify_outputs True to verify that the outputs match (within tolerance). 
+ * \param ep_graph_checker Function called on the Graph generated for the EP's session. Used to check node + * EP assignment. */ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions provider_options, int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, float fp32_abs_err = 1e-5f, logging::Severity log_severity = logging::Severity::kERROR, - bool verify_outputs = true); + bool verify_outputs = true, + std::function<void(const Graph&)>* ep_graph_checker = nullptr); enum class BackendSupport { SUPPORT_UNKNOWN,