From 9ef11f1c6a2b8ef3378e06af554f0361a00cbd6a Mon Sep 17 00:00:00 2001 From: Hector Li Date: Mon, 10 Apr 2023 10:36:57 -0700 Subject: [PATCH] [QNN EP] Qnn batchnorm Op support (#15222) ### Description Support BatchNorm Op in Qnn EP Node Unit group support for BatchNorm, Exp ops ### Motivation and Context Enable more models. --- .../selectors_actions/qdq_selectors.cc | 24 ++++ .../selectors_actions/qdq_selectors.h | 21 ++++ .../selectors_actions/shared/utils.cc | 14 ++- .../qnn/builder/op_builder_factory.cc | 4 + .../qnn/builder/op_builder_factory.h | 2 + .../qnn/builder/opbuilder/base_op_builder.cc | 7 +- .../qnn/builder/opbuilder/base_op_builder.h | 30 ++++- .../opbuilder/batch_norm_op_builder.cc | 94 ++++++++++++++ .../qnn/builder/opbuilder/pool_op_builder.cc | 1 - .../providers/qnn/qnn_execution_provider.cc | 8 ++ .../test/providers/qnn/batch_norm_htp_test.cc | 118 ++++++++++++++++++ 11 files changed, 316 insertions(+), 7 deletions(-) create mode 100644 onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc create mode 100644 onnxruntime/test/providers/qnn/batch_norm_htp_test.cc diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index f9d070044f..ebe5612b66 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -349,6 +349,30 @@ bool InstanceNormalizationNodeGroupSelector::Check(const GraphViewer& graph_view (dt_bias == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32); } +bool BatchNormalizationNodeGroupSelector::Check(const GraphViewer& graph_viewer, + const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const { + if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes)) { + return false; + } + + int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_scale = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (dt_input != dt_output) { + return false; + } + + if (dt_input == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) { + if (!int8_allowed_ || dt_scale != dt_input) { + return false; + } + } + + return true; +} + } // namespace QDQ } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 2efc9419ed..85fd56b1d3 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -147,6 +147,20 @@ class InstanceNormalizationNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; }; +// DQ nodes for X, W and optionally B, not used for mean, var -> node -> Q +class BatchNormalizationNodeGroupSelector : public NodeGroupSelector { + public: + // default to 'true' + BatchNormalizationNodeGroupSelector(bool int8_allowed = true) : int8_allowed_(int8_allowed) {} + + private: + bool Check(const GraphViewer& graph_viewer, const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const override; + + bool int8_allowed_; +}; + /* * NodeSelector instances for use in the QDQ::SelectorActionTransformer. */ @@ -246,6 +260,13 @@ class InstanceNormalizationSelector : public BaseSelector { : BaseSelector(std::make_unique()) {} }; +// DQ nodes for X, W and optionally B, (mean, var not required) -> node -> Q +class BatchNormalizationSelector : public BaseSelector { + public: + BatchNormalizationSelector(bool int8_allowed = false) + : BaseSelector(std::make_unique(int8_allowed)) {} +}; + } // namespace QDQ } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index 6bc40344ef..0e343bbc60 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -53,7 +53,8 @@ static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() { {"Slice", {}}, {"Softmax", {}}, {"Sqrt", {}}, - {"Tanh", {}}}; + {"Tanh", {}}, + {"Exp", {}}}; } static const OpVersionsAndSelector::OpVersionsMap GetBinaryOpVersionsMap() { return {{"Add", {}}, @@ -80,6 +81,9 @@ static const OpVersionsAndSelector::OpVersionsMap GetGemmOpVersionsMap() { static const OpVersionsAndSelector::OpVersionsMap GetInstanceNormalizationOpVersionsMap() { return {{"InstanceNormalization", {}}}; } +static const OpVersionsAndSelector::OpVersionsMap GetBatchNormalizationOpVersionsMap() { + return {{"BatchNormalization", {}}}; +} /* Selector rules registration related */ void RegisterMiscSelectors(Selectors& qdq_selectors) { @@ -146,6 +150,13 @@ void RegisterInstanceNormalizationSelector(Selectors& qdq_selectors) { std::move(selector)); } +void RegisterBatchNormalizationSelector(Selectors& qdq_selectors) { + /* register selector for BatchNormalization op */ + std::unique_ptr selector = std::make_unique(); + qdq_selectors.RegisterSelector(GetBatchNormalizationOpVersionsMap(), + std::move(selector)); +} + void SelectorManager::CreateSelectors() { RegisterMiscSelectors(qdq_selectors_); RegisterUnarySelectors(qdq_selectors_); @@ -156,6 +167,7 @@ void SelectorManager::CreateSelectors() { RegisterMatMulSelector(qdq_selectors_); RegisterGemmSelector(qdq_selectors_); RegisterInstanceNormalizationSelector(qdq_selectors_); + RegisterBatchNormalizationSelector(qdq_selectors_); } void SelectorManager::InitializeSelectorsMap() { diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc index 881d441fde..f2a6695b3a 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc @@ -131,6 +131,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() { { CreateInstanceNormOpBuilder("InstanceNormalization", *this); } + + { + CreateBatchNormOpBuilder("BatchNormalization", *this); + } } const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) { diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h index 2afffb2305..73a4e3344e 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h @@ -82,5 +82,7 @@ void CreateInstanceNormOpBuilder(const std::string& op_type, OpBuilderRegistrati void CreateReduceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateBatchNormOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); + } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index a8dd4f3ed3..1dde122bd9 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -147,7 +147,8 @@ Status BaseOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, ORT_UNUSED_PARAMETER(do_op_validation); const auto& inputs = node_unit.Inputs(); - for (size_t input_i = 0; input_i < inputs.size(); ++input_i) { + const auto input_count = GetInputCountQnnRequired(node_unit); + for (size_t input_i = 0; input_i < input_count; ++input_i) { ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[input_i], logger, is_quantized_model, input_names)); } @@ -180,7 +181,6 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper, // Add output // Output part is common for all Ops, only difference is the Op attribute const auto& outputs = node_unit.Outputs(); - auto output_size = outputs.size(); std::vector output_names; struct CastNodeInfo { std::string node_name; @@ -189,7 +189,8 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper, }; std::vector cast_node_info_vec; - for (size_t output_i = 0; output_i < output_size && output_i < output_count_; ++output_i) { + const auto output_count = GetOutputCountQnnRequired(node_unit); + for (size_t output_i = 0; output_i < output_count; ++output_i) { const auto& output_name = outputs[output_i].node_arg.Name(); Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h index 5f507aad43..818434c7b7 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h @@ -162,7 +162,8 @@ class BaseOpBuilder : public IOpBuilder { {"ConvTranspose", "TransposeConv2d"}, {"Tile", "Tile"}, {"TopK", "TopK"}, - {"InstanceNormalization", "InstanceNorm"}}; + {"InstanceNormalization", "InstanceNorm"}, + {"BatchNormalization", "Batchnorm"}}; auto it = onnx_op_type_to_qnn_op_type.find(onnx_op_type); ORT_ENFORCE(it != onnx_op_type_to_qnn_op_type.end()); return it->second; @@ -252,7 +253,32 @@ class BaseOpBuilder : public IOpBuilder { int32_t& default_axis_value) const; Qnn_TensorType_t GetInputTensorType(const QnnModelWrapper& qnn_model_wrapper, const std::string& input_name) const; - mutable size_t output_count_ = std::numeric_limits::max(); + size_t GetInputCountQnnRequired(const NodeUnit& node_unit) const { + auto input_output_cout = GetInputOutputCountQnnRequired(node_unit.OpType()); + + return 0 == input_output_cout.first ? node_unit.Inputs().size() : input_output_cout.first; + } + + size_t GetOutputCountQnnRequired(const NodeUnit& node_unit) const { + auto input_output_cout = GetInputOutputCountQnnRequired(node_unit.OpType()); + + return 0 == input_output_cout.second ? node_unit.Outputs().size() : input_output_cout.second; + } + + private: + static const std::pair GetInputOutputCountQnnRequired(std::string onnx_op_type) { + static const std::unordered_map> input_output_count_qnn_required = { + {"GlobalAveragePool", {0, 1}}, + {"MaxPool", {0, 1}}, + {"BatchNormalization", {3, 1}}}; + + auto pos = input_output_count_qnn_required.find(onnx_op_type); + if (pos == input_output_count_qnn_required.end()) { + return std::make_pair< size_t, size_t>(0, 0); + } else { + return pos->second; + } + } private: std::string op_builder_type_; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc new file mode 100644 index 0000000000..43a0ca9eac --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/framework/tensorprotoutils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" + +#include "base_op_builder.h" + +#include + +namespace onnxruntime { +namespace qnn { +class BatchNormOpBuilder : public BaseOpBuilder { + public: + BatchNormOpBuilder() : BaseOpBuilder("BatchNormOpBuilder") {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(BatchNormOpBuilder); + + Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + bool is_quantized_model) const override final ORT_MUST_USE_RESULT; +}; + +// BatchNorm is sensitive with data layout, no special validation so far +// The nodes from 1st call of GetCapability do not get layout transformer applied, it's still NCHW +// The nodes from 2nd call of GetCapability get layout transformer applied, it's NHWC +Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + bool is_quantized_model) const { + if (node_unit.Domain() == kMSInternalNHWCDomain) { + // It's useless to fallback the node after layout transformation because CPU EP can't support it anyway + // Still do it here so hopefully QNN Op validation API can tell us some details why it's not supported + return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, is_quantized_model, true); + } else { + NodeAttrHelper node_helper(node_unit); + const float default_epsilon = 1e-05f; + const float epsilon = node_helper.Get("epsilon", 1e-05f); // Default is 1e-05 according to ONNX spec. + ORT_RETURN_IF(abs(epsilon - default_epsilon) > default_epsilon, "QNN BatchNorm doesn't support epsilon."); + + const auto float_elem_type = ONNX_NAMESPACE::Utils::DataTypeUtils::ToType("float"); + + const auto& inputs = node_unit.Inputs(); + ORT_ENFORCE(inputs.size() == 5, "5 input expected per BatchNorm Onnx Spec."); + // Check input type is float for CPU. + ONNX_NAMESPACE::DataType input_data_type = inputs[0].node_arg.Type(); + ORT_RETURN_IF(!is_quantized_model && input_data_type != float_elem_type, "QNN BatchNorm CPU only support float32."); + + std::vector input_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, input_shape), "Cannot get shape of input 0."); + const size_t input_rank = input_shape.size(); + + ORT_RETURN_IF(input_rank <= 2 || input_rank > 4, + "QNN BatchNorm only supports input ranks of size 3 or 4."); + + const uint32_t num_channels = input_shape[1]; + + std::vector scale_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[1].node_arg, scale_shape), "Cannot get shape of input 1 (scale)."); + ORT_RETURN_IF(scale_shape.size() != 1 || scale_shape[0] != num_channels, + "QNN BatchNorm input 1 (scale) must have 1D shape [channel]."); + + std::vector bias_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[2].node_arg, bias_shape), "Cannot get shape of input 2 (bias)."); + ORT_RETURN_IF(bias_shape.size() != 1 || bias_shape[0] != num_channels, + "QNN BatchNorm input 2 (bias) must have 1D shape [channel]."); + + std::vector mean_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[3].node_arg, mean_shape), "Cannot get shape of input 3 (mean)."); + ORT_RETURN_IF(mean_shape.size() != 1 || mean_shape[0] != num_channels, + "QNN BatchNorm input 3 (mean) must have 1D shape [channel]."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[3].node_arg.Name()), "QNN BatchNorm doesn't support dynamic mean."); + + std::vector var_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[4].node_arg, var_shape), "Cannot get shape of input 4 (var)."); + ORT_RETURN_IF(var_shape.size() != 1 || var_shape[0] != num_channels, + "QNN BatchNorm input 4 (var) must have 1D shape [channel]."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[4].node_arg.Name()), "QNN BatchNorm doesn't support dynamic var."); + + ORT_RETURN_IF(node_unit.Outputs().size() > 1, "QNN BatchNorm only support 1 output."); + } + + return Status::OK(); +} + +void CreateBatchNormOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.AddOpBuilder(op_type, std::make_unique()); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc index 6f8b379184..2adbc3b606 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc @@ -215,7 +215,6 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra qnn_model_wrapper.AddParamWrapper(std::move(count_pad_for_edges_param)); } - output_count_ = 1; ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), std::move(param_tensor_names), diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index c1a4c77d07..a2bc2b60c6 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -204,9 +204,17 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer); const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map, node_unit_holder.size(), logger); + if (supported_nodes.empty()) { LOGS(logger, INFO) << "Number of partitions supported by QNN EP: 0"; return result; + } else if (supported_nodes.size() == 1) { + const auto* node = *supported_nodes.begin(); + if (node->OpType() == "QuantizeLinear" || node->OpType() == "DequantizeLinear") { + LOGS(logger, INFO) << "It doesn't make sense just run a Q/DQ node on HTP."; + LOGS(logger, INFO) << "Number of partitions supported by QNN EP: 0"; + return result; + } } const auto gen_metadef_name = [&]() { diff --git a/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc b/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc new file mode 100644 index 0000000000..4e7c815014 --- /dev/null +++ b/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc @@ -0,0 +1,118 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include "core/graph/graph.h" + +#include "test/optimizer/qdq_test_utils.h" +#include "test/providers/qnn/qnn_test_utils.h" + +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +// Creates the graph: +// _______________________ +// input_u8 -> DQ -> | | +// scale_u8 (initializer) -> DQ -> | | +// bias_u8 (initializer) -> DQ -> | BatchNormalization | -> Q -> output_u8 +// mean_u8 (initializer) -> DQ -> | | +// var_u8 (initializer) -> DQ -> |_______________________| +// +// Currently used to test QNN EP. +template +GetQDQTestCaseFn BuildQDQBatchNormTestCase(const std::vector& input_shape) { + return [input_shape](ModelTestBuilder& builder) { + const int64_t num_channels = input_shape[1]; + const InputQType quant_zero_point = 0; + const float quant_scale = 1.0f; + + auto* input = builder.MakeInput(input_shape, static_cast(-1), + static_cast(1)); + auto* dq_input = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input, 0.0039f, quant_zero_point, dq_input); + + auto* dq_scale_output = builder.MakeIntermediate(); + auto* scale = builder.MakeInitializer({num_channels}, static_cast(1), static_cast(127)); + builder.AddDequantizeLinearNode(scale, 0.0028f, quant_zero_point, dq_scale_output); + + auto* dq_bias_output = builder.MakeIntermediate(); + auto* bias = builder.MakeInitializer({num_channels}, static_cast(0), static_cast(0)); + builder.AddDequantizeLinearNode(bias, quant_scale, quant_zero_point, dq_bias_output); + + auto* dq_mean_output = builder.MakeIntermediate(); + auto* mean = builder.MakeInitializer({num_channels}, static_cast(0), static_cast(0)); + builder.AddDequantizeLinearNode(mean, quant_scale, quant_zero_point, dq_mean_output); + + auto* dq_var_output = builder.MakeIntermediate(); + auto* var = builder.MakeInitializer({num_channels}, static_cast(255), static_cast(255)); + builder.AddDequantizeLinearNode(var, 0.003921f, 0, dq_var_output); + + auto* batchnorm_output = builder.MakeIntermediate(); + builder.AddNode("BatchNormalization", {dq_input, dq_scale_output, dq_bias_output, dq_mean_output, dq_var_output}, {batchnorm_output}); + + auto* q_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(batchnorm_output, 0.00377f, quant_zero_point, q_output); + + auto* final_output = builder.MakeOutput(); + builder.AddDequantizeLinearNode(q_output, 0.00377f, + quant_zero_point, + final_output); + }; +} + +/** + * Runs an BatchNormalization model on the QNN HTP backend. Checks the graph node assignment, and that inference + * outputs for QNN and CPU match. + * + * \param input_shape The input's shape. + * \param test_description Description of the test for error reporting. + * \param expected_ep_assignment How many nodes are expected to be assigned to QNN (All, Some, or None). + * \param num_modes_in_graph The number of expected nodes in the graph. + */ +static void RunBatchNormQDQTest(const std::vector& input_shape, const char* test_description, + ExpectedEPNodeAssignment expected_ep_assignment, int num_nodes_in_graph) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + // Runs model with DQ-> InstanceNorm -> Q and compares the outputs of the CPU and QNN EPs. + RunQnnModelTest(BuildQDQBatchNormTestCase(input_shape), + provider_options, + 11, + expected_ep_assignment, + num_nodes_in_graph, + test_description); +} + +// Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit. +// Use an input of rank 3. +TEST_F(QnnHTPBackendTests, TestQDQBatchNorm1D) { + RunBatchNormQDQTest({1, 2, 3}, "TestQDQBatchNorm1D", ExpectedEPNodeAssignment::All, 1); +} + +// Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit. +// Use an input of rank 4. +TEST_F(QnnHTPBackendTests, TestQDQBatchNorm2D) { + RunBatchNormQDQTest({2, 3, 4, 5}, "TestQDQBatchNorm2D", ExpectedEPNodeAssignment::All, 1); +} + +// Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit. +// Use an input of rank 5. QNN BatchNormalization doesn't support 5D on HTP +TEST_F(QnnHTPBackendTests, TestQDQBatchNorm3D) { + RunBatchNormQDQTest({1, 2, 3, 4, 5}, "TestQDQBatchNorm3D", ExpectedEPNodeAssignment::None, 8); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +} // namespace test +} // namespace onnxruntime + +#endif \ No newline at end of file