From d95249f516e83a7f4465604e8ef1c4f2ed02635c Mon Sep 17 00:00:00 2001
From: Jian Chen
Date: Mon, 16 Jan 2023 22:06:57 -0500
Subject: [PATCH] Removing Double QDQ from Graphs (#14024)

### Description
When two Q/DQ pairs appear back to back, we want to delete the inner DQ and
Q nodes so that a single Q/DQ pair remains.

ex: Q->DQ->Q->DQ =====> Q->DQ

### Motivation and Context
---
 .../onnxruntime_session_options_config_keys.h |   6 +
 .../optimizer/double_qdq_pairs_remover.cc     | 167 ++++++++++++++++++
 .../core/optimizer/double_qdq_pairs_remover.h |  53 ++++++
 .../core/optimizer/graph_transformer_utils.cc |  10 +-
 .../test/optimizer/graph_transform_test.cc    |  24 +++
 onnxruntime/test/optimizer/qdq_test_utils.h   |  40 +++++
 .../test/optimizer/qdq_transformer_test.cc    | 103 ++++++++++-
 7 files changed, 394 insertions(+), 9 deletions(-)
 create mode 100644 onnxruntime/core/optimizer/double_qdq_pairs_remover.cc
 create mode 100644 onnxruntime/core/optimizer/double_qdq_pairs_remover.h

diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
index 9384435cc6..92482d71f6 100644
--- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
+++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
@@ -47,6 +47,12 @@ static const char* const kOrtSessionOptionsConfigSetDenormalAsZero = "session.se
 // Its default value is "0"
 static const char* const kOrtSessionOptionsDisableQuantQDQ = "session.disable_quant_qdq";
 
+// It controls whether to disable the Double QDQ remover and Identical Children Consolidation
+// "0": do not disable. ORT removes the middle 2 nodes of a Q->(DQ->Q)->DQ chain
+// "1": disable.
ORT doesn't remove the middle 2 Nodes from a Q->(DQ->Q)->DQ pairs
+// Its default value is "0"
+static const char* const kOrtSessionOptionsDisableDoubleQDQRemover = "session.disable_double_qdq_remover";
+
 // If set to "1", enables the removal of QuantizeLinear/DequantizeLinear node pairs once all QDQ handling has been
 // completed. e.g. If after all QDQ handling has completed and we have -> FloatOp -> Q -> DQ -> FloatOp -> the
 // Q -> DQ could potentially be removed. This will provide a performance benefit by avoiding going from float to
diff --git a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc
new file mode 100644
index 0000000000..8dd446d82b
--- /dev/null
+++ b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc
@@ -0,0 +1,190 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include "core/optimizer/double_qdq_pairs_remover.h"
+
+#include "core/graph/graph_utils.h"
+#include "core/optimizer/initializer.h"
+
+namespace onnxruntime {
+
+// Visits nodes in topological order; for each inner DQ accepted by IsNodeRemovable, unlinks and deletes the
+// inner DQ -> Q so that the outer Q feeds the outer DQ directly. Sets `modified` when anything was removed.
+Status DoubleQDQPairsRemover::ApplyImpl(
+    Graph& graph,
+    bool& modified,
+    int /*graph_level*/,
+    const logging::Logger& /*logger*/) const {
+  const GraphViewer graph_viewer(graph);
+  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
+
+  for (const auto& self_index : node_topology_list) {
+    NodeIndex parent_index = 0;
+    NodeIndex child_index = 0;
+    NodeIndex grandchild_index = 0;
+    if (IsNodeRemovable(graph, self_index, parent_index, child_index, grandchild_index)) {
+      // Unlink parent -> self -> child -> grandchild, then connect parent's output straight to grandchild.
+      graph.RemoveEdge(parent_index, self_index, 0, 0);
+      graph.RemoveEdge(self_index, child_index, 0, 0);
+      graph.RemoveEdge(child_index, grandchild_index, 0, 0);
+      graph_utils::ReplaceNodeInput(*graph.GetNode(grandchild_index), 0, *graph.GetNode(self_index)->MutableInputDefs()[0]);
+      graph.AddEdge(parent_index, grandchild_index, 0, 0);
+      graph.RemoveNode(child_index);
+      graph.RemoveNode(self_index);
+      modified = true;
+    }
+  }
+  return Status::OK();
+}
+
+// Decides whether `self_index` is the inner DQ of a Q -> DQ -> Q -> DQ chain that can be collapsed and reports
+// the parent (outer Q), child (inner Q) and grandchild (outer DQ) indices. When it returns true it has
+// already replaced the outer Q's and outer DQ's scale and zero-point inputs with the merged values.
+bool DoubleQDQPairsRemover::IsNodeRemovable(
+    Graph& graph,
+    const NodeIndex& self_index,
+    NodeIndex& parent_index,
+    NodeIndex& child_index,
+    NodeIndex& grandchild_index) {
+  // Check if the self is a DQ, and have one parent and one child, and cannot be a graph output
+  Node* self = graph.GetNode(self_index);
+  if (self == nullptr ||
+      self->OpType() != "DequantizeLinear" ||
+      self->GetInputEdgesCount() != 1 ||
+      self->GetOutputEdgesCount() != 1 ||
+      self->InputDefs().size() != InputIndex::TOTAL_COUNT ||
+      graph.NodeProducesGraphOutput(*self)) {
+    return false;
+  }
+
+  // Type is either "tensor(uint8)" or "tensor(int8)"
+  const auto self_zp_type = *self->InputDefs()[InputIndex::ZERO_POINT_ID]->Type();
+  // child should be a Q, and have only one child, have the same type as self, and cannot be a graph output
+  child_index = self->OutputEdgesBegin()->GetNode().Index();
+  const Node* child = graph.GetNode(child_index);
+  if (child == nullptr ||
+      child->OpType() != "QuantizeLinear" ||
+      child->GetOutputEdgesCount() != 1 ||
+      child->InputDefs().size() != InputIndex::TOTAL_COUNT ||
+      *child->InputDefs()[InputIndex::ZERO_POINT_ID]->Type() != self_zp_type ||
+      graph.NodeProducesGraphOutput(*child)) {
+    return false;
+  }
+
+  // parent should be a Q, and have only one output, and cannot be a graph output
+  parent_index = self->InputEdgesBegin()->GetNode().Index();
+  Node* parent = graph.GetNode(parent_index);
+  if (parent == nullptr ||
+      parent->GetOutputEdgesCount() != 1 ||
+      parent->OpType() != "QuantizeLinear" ||
+      graph.NodeProducesGraphOutput(*parent)) {
+    return false;
+  }
+
+  // grandchild should be a DQ
+  grandchild_index = child->OutputEdgesBegin()->GetNode().Index();
+  Node* grandchild = graph.GetNode(grandchild_index);
+  if (grandchild == nullptr ||
+      grandchild->OpType() != "DequantizeLinear") {
+    return false;
+  }
+  const auto get_constant_initializer = [&graph](const std::string& initializer_name) {
+    return graph.GetConstantInitializer(initializer_name, true);
+  };
+  if (!QDQ::IsQDQPairSupported(*parent, *self, get_constant_initializer, graph.ModelPath()) ||
+      !QDQ::IsQDQPairSupported(*child, *grandchild, get_constant_initializer, graph.ModelPath())) {
+    return false;
+  }
+  float new_scale = 0.0f;
+  if (self_zp_type == "tensor(uint8)") {
+    uint8_t new_zero_point = 0;
+    if (!FindNewZeroPointAndScale(graph, *self, *child, new_scale, new_zero_point)) {
+      return false;
+    }
+    ApplyNewInputValue(graph, *grandchild, InputIndex::SCALE_ID, new_scale);
+    ApplyNewInputValue(graph, *parent, InputIndex::SCALE_ID, new_scale);
+    ApplyNewInputValue(graph, *grandchild, InputIndex::ZERO_POINT_ID, new_zero_point);
+    ApplyNewInputValue(graph, *parent, InputIndex::ZERO_POINT_ID, new_zero_point);
+  } else if (self_zp_type == "tensor(int8)") {
+    int8_t new_zero_point = 0;
+    if (!FindNewZeroPointAndScale(graph, *self, *child, new_scale, new_zero_point)) {
+      return false;
+    }
+    ApplyNewInputValue(graph, *grandchild, InputIndex::SCALE_ID, new_scale);
+    ApplyNewInputValue(graph, *parent, InputIndex::SCALE_ID, new_scale);
+    ApplyNewInputValue(graph, *grandchild, InputIndex::ZERO_POINT_ID, new_zero_point);
+    ApplyNewInputValue(graph, *parent, InputIndex::ZERO_POINT_ID, new_zero_point);
+  } else {
+    // Only 8-bit zero points are handled here; any other type must not be reinterpreted as int8.
+    return false;
+  }
+  return true;
+}
+
+// Computes the scale and zero point describing the overlap of the real-value ranges representable by
+// node1 and node2 (read from their constant scale / zero-point initializers). Returns false if an
+// initializer is missing, a scale is not float, the zero-point types differ, or the ranges do not overlap.
+template <typename T>
+bool DoubleQDQPairsRemover::FindNewZeroPointAndScale(const Graph& graph, const Node& node1, const Node& node2, float& new_scale, T& new_zero_point) {
+  // if Q/DQ scale and zero point are not constant, return false
+  const ONNX_NAMESPACE::TensorProto* node1_scale_tensor_proto =
+      graph_utils::GetConstantInitializer(graph, node1.InputDefs()[InputIndex::SCALE_ID]->Name());
+  const ONNX_NAMESPACE::TensorProto* node2_scale_tensor_proto =
+      graph_utils::GetConstantInitializer(graph, node2.InputDefs()[InputIndex::SCALE_ID]->Name());
+  const ONNX_NAMESPACE::TensorProto* node1_zp_tensor_proto =
+      graph_utils::GetConstantInitializer(graph, node1.InputDefs()[InputIndex::ZERO_POINT_ID]->Name());
+  const ONNX_NAMESPACE::TensorProto* node2_zp_tensor_proto =
+      graph_utils::GetConstantInitializer(graph, node2.InputDefs()[InputIndex::ZERO_POINT_ID]->Name());
+  if (node1_scale_tensor_proto == nullptr || node2_scale_tensor_proto == nullptr ||
+      node1_zp_tensor_proto == nullptr || node2_zp_tensor_proto == nullptr) {
+    return false;
+  }
+  Initializer zero_point_init_1{*node1_zp_tensor_proto, graph.ModelPath()};
+  Initializer zero_point_init_2{*node2_zp_tensor_proto, graph.ModelPath()};
+  Initializer scale_init_1{*node1_scale_tensor_proto, graph.ModelPath()};
+  Initializer scale_init_2{*node2_scale_tensor_proto, graph.ModelPath()};
+  if (zero_point_init_1.data_type() != zero_point_init_2.data_type() ||
+      scale_init_1.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT ||
+      scale_init_2.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
+    return false;
+  }
+
+  T zero_point_1 = zero_point_init_1.data<T>()[0];
+  T zero_point_2 = zero_point_init_2.data<T>()[0];
+  const float scale_1 = scale_init_1.data<float>()[0];
+  const float scale_2 = scale_init_2.data<float>()[0];
+  T q_min = std::numeric_limits<T>::min();
+  T q_max = std::numeric_limits<T>::max();
+
+  float real_min1 = gsl::narrow_cast<float>(q_min - zero_point_1) * scale_1;
+  float real_max1 = gsl::narrow_cast<float>(q_max - zero_point_1) * scale_1;
+  float real_min2 = gsl::narrow_cast<float>(q_min - zero_point_2) * scale_2;
+  float real_max2 = gsl::narrow_cast<float>(q_max - zero_point_2) * scale_2;
+
+  const float real_min = std::max(real_min1, real_min2);
+  const float real_max = std::min(real_max1, real_max2);
+  // Disjoint (or merely touching) ranges would give a zero or negative scale and a division by zero
+  // below, so refuse to merge them. The negated comparison also rejects NaN.
+  if (!(real_max > real_min)) {
+    return false;
+  }
+
+  new_scale = (real_max - real_min) / gsl::narrow_cast<float>(q_max - q_min);
+  new_zero_point = gsl::narrow_cast<T>(std::round(gsl::narrow_cast<float>(q_min) - real_min / new_scale));
+  return true;
+}
+
+// Gives input `index` of `node` a newly added initializer: a copy of the current constant initializer
+// with element 0 set to `value`. The existing initializer itself is not modified.
+template <typename T>
+void DoubleQDQPairsRemover::ApplyNewInputValue(Graph& graph, Node& node, const InputIndex& index, T value) {
+  const auto* input_tensor = graph_utils::GetConstantInitializer(graph, node.InputDefs()[index]->Name());
+  Initializer input_init{*input_tensor, graph.ModelPath()};
+  TensorProto new_input_tensor(*input_tensor);
+  input_init.data<T>()[0] = value;
+  input_init.ToProto(new_input_tensor);
+  auto new_name = graph.GenerateNodeArgName("DoubleQDQRemoved_" + node.InputDefs()[index]->Name());
+  new_input_tensor.set_name(new_name);
+  NodeArg& new_input = graph_utils::AddInitializer(graph, new_input_tensor);
+  graph_utils::ReplaceNodeInput(node, index, new_input);
+}
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/double_qdq_pairs_remover.h b/onnxruntime/core/optimizer/double_qdq_pairs_remover.h
new file mode 100644
index 0000000000..294cd842d4
--- /dev/null
+++ b/onnxruntime/core/optimizer/double_qdq_pairs_remover.h
@@ -0,0 +1,53 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/optimizer/graph_transformer.h"
+#include "core/optimizer/qdq_transformer/qdq_util.h"
+
+namespace onnxruntime {
+
+using ONNX_NAMESPACE::TensorProto;
+using ONNX_NAMESPACE::TensorProto_DataType;
+using QDQ::InputIndex;
+
+/**
+ * @Class DoubleQDQPairsRemover
+ * @brief Collapse Q -> DQ -> Q -> DQ into Q -> DQ by removing the inner DQ -> Q pair.
+ */
+class DoubleQDQPairsRemover : public GraphTransformer {
+ public:
+  DoubleQDQPairsRemover() : GraphTransformer("DoubleQDQPairsRemover", {}) {}
+
+ private:
+  Status ApplyImpl(
+      Graph& graph,
+      bool& modified,
+      int graph_level,
+      const logging::Logger& logger) const override;
+
+  static bool IsNodeRemovable(
+      Graph& graph,
+      const NodeIndex& self_index,   // candidate inner DQ
+      NodeIndex& parent_index,       // out: outer Q feeding self
+      NodeIndex& child_index,        // out: inner Q consuming self
+      NodeIndex& grandchild_index);  // out: outer DQ consuming child
+
+  template <typename T>  // T: zero-point element type (uint8_t / int8_t at the call sites)
+  static bool FindNewZeroPointAndScale(
+      const Graph& graph,
+      const Node& node1,
+      const Node& node2,
+      float& new_scale,    // out: merged scale
+      T& new_zero_point);  // out: merged zero point
+
+  template <typename T>  // T: element type of the replaced initializer (float scale or 8-bit zero point)
+  static void ApplyNewInputValue(
+      Graph& graph,
+      Node& node,
+      const InputIndex& index,  // which input of node is replaced
+      T value);                 // written to element 0 of the new initializer
+};
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index fe211b0681..fdee3c19f2 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -33,6 +33,7 @@
 #include "core/optimizer/conv_bn_fusion.h"
 #include "core/optimizer/conv_mul_fusion.h"
 #include "core/optimizer/div_mul_fusion.h"
+#include "core/optimizer/double_qdq_pairs_remover.h"
 #include "core/optimizer/dropout_elimination.h"
 #include "core/optimizer/dynamic_quantize_matmul_fusion.h"
 #include "core/optimizer/embed_layer_norm_fusion.h"
@@ -45,6 +46,7 @@
 #include "core/optimizer/gemm_activation_fusion.h"
 #include "core/optimizer/gemm_sum_fusion.h"
 #include "core/optimizer/gemm_transpose_fusion.h"
+#include "core/optimizer/identical_children_consolidation.h"
 #include "core/optimizer/identity_elimination.h"
 #include "core/optimizer/layer_norm_fusion.h"
 #include "core/optimizer/matmul_add_fusion.h"
@@ -66,7 +68,6 @@
 #include "core/optimizer/slice_elimination.h"
 #include "core/optimizer/transpose_optimizer/ort_transpose_optimizer.h"
 #include "core/optimizer/unsqueeze_elimination.h"
-#include "core/optimizer/identical_children_consolidation.h"
 #ifdef ENABLE_TRAINING_CORE
 #include "orttraining/core/optimizer/bitmask_dropout_replacement.h"
 #include "orttraining/core/optimizer/bias_softmax_dropout_fusion.h"
@@ -191,12 +192,17 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
         transformers.emplace_back(std::move(rule_transformer));
       }
 
+      // We need to remove the duplicated QDQ Pairs before all other GraphTransformation.
+
       // no filtering on execution provider for L1 optimizations as they only use official ONNX operators
 
       // Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for
       // CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by
       // default, CSE will not merge them, because the different initializers are represented by different NodeArg.
-      transformers.emplace_back(std::make_unique<IdenticalChildrenConsolidation>());
+      if (session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableDoubleQDQRemover, "0") == "0") {
+        transformers.emplace_back(std::make_unique<DoubleQDQPairsRemover>());
+        transformers.emplace_back(std::make_unique<IdenticalChildrenConsolidation>());
+      }
       transformers.emplace_back(std::make_unique<ConstantSharing>());
       transformers.emplace_back(std::make_unique<CommonSubexpressionElimination>());
       transformers.emplace_back(std::make_unique<ConstantFolding>(cpu_execution_provider, !disable_quant_qdq));
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index d1d8f435b3..b37953b508 100755
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -3555,6 +3555,30 @@ static void VerifyGeluApproximation(bool is_enabled, SessionOptions& session_opt
   EXPECT_EQ(has_gelu_approximation, is_enabled);
 }
 
+// Test session option configuration for DoubleQDQPairsRemover
+TEST_F(GraphTransformationTests, DoubleQDQRemover_SessionOptionConfig) {
+  auto verify_session_config = [&](bool is_enabled, SessionOptions& session_option) {
+    std::unique_ptr<CPUExecutionProvider> cpu_ep =
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
+    bool has_double_qdq_remover = false;
+    auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_option, *cpu_ep.get(), {});
+    for (auto& transformer : transformers) {
+      if (transformer->Name() == "DoubleQDQPairsRemover") {
+        has_double_qdq_remover = true;
+      }
+    }
+    EXPECT_EQ(has_double_qdq_remover, is_enabled);
+  };
+  SessionOptions session_options;
+  // DoubleQDQPairsRemover is enabled by default.
+  verify_session_config(true, session_options);
+
+  ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsDisableDoubleQDQRemover, "1"));
+  verify_session_config(false, session_options);
+
+  ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsDisableDoubleQDQRemover, "0"));
+  verify_session_config(true, session_options);
+}
+
 // Test session option configuration for GeluApproximation
 TEST_F(GraphTransformationTests, GeluApproximation_SessionOptionConfig) {
   SessionOptions session_options;
diff --git a/onnxruntime/test/optimizer/qdq_test_utils.h b/onnxruntime/test/optimizer/qdq_test_utils.h
index 9f98dcd691..cb19a1e69e 100644
--- a/onnxruntime/test/optimizer/qdq_test_utils.h
+++ b/onnxruntime/test/optimizer/qdq_test_utils.h
@@ -395,6 +395,46 @@ GetQDQTestCaseFn BuildConsolidationTestCase(
   };
 }
 
+template <typename Type1, typename Type2, typename Type3, typename Type4>
+GetQDQTestCaseFn BuildDoubleQDQTestCases(Type1 zp_1, Type2 zp_2, Type3 zp_3, Type4 zp_4,
+                                         float scale_1, float scale_2, float scale_3, float scale_4) {
+  return [=](ModelTestBuilder& builder) {
+    auto* input_arg = builder.MakeInput<float>(  // NOTE(review): <float> here and below was stripped in extraction; confirm
+        {11, 22, 33, 44},
+        std::numeric_limits<float>::min() * (scale_1 + scale_3) / 2,
+        std::numeric_limits<float>::max() * (scale_1 + scale_3) / 2);
+    NodeArg* q1_output = builder.MakeIntermediate();
+    NodeArg* dq1_output = builder.MakeIntermediate();
+    NodeArg* q2_output = builder.MakeIntermediate();
+    NodeArg* dq2_output = builder.MakeOutput();
+    builder.AddQuantizeLinearNode(input_arg, scale_1, zp_1, q1_output);
+    builder.AddDequantizeLinearNode(q1_output, scale_2, zp_2, dq1_output);
+    builder.AddQuantizeLinearNode(dq1_output, scale_3, zp_3, q2_output);
+    builder.AddDequantizeLinearNode(q2_output, scale_4, zp_4, dq2_output);
+  };
+}
+
+template <typename T>
+GetQDQTestCaseFn BuildDoubleQDQWithoutLastOutput(int output_index) {
+  return [=](ModelTestBuilder& builder) {
+    auto* input_arg = builder.MakeInput<float>({2, 3, 4}, std::numeric_limits<float>::min(), std::numeric_limits<float>::max());
+    T zp = (std::numeric_limits<T>::max() - std::numeric_limits<T>::min()) / 2;
+    float scale = 0.003f;
+    std::vector<NodeArg*> outputs(4);
+    for (auto i = 0; i < 4; i++) {
+      if (output_index == i) {  // exactly one of the four Q/DQ outputs is made the graph output
+        outputs[i] = builder.MakeOutput();
+      } else {
+        outputs[i] = builder.MakeIntermediate();
+      }
+    }
+    builder.AddQuantizeLinearNode(input_arg, scale, zp, outputs[0]);
+    builder.AddDequantizeLinearNode(outputs[0], scale, zp, outputs[1]);
+    builder.AddQuantizeLinearNode(outputs[1], scale, zp, outputs[2]);
+    builder.AddDequantizeLinearNode(outputs[2], scale, zp, outputs[3]);
+  };
+}
+
 template <typename InputType, typename OutputType>
 GetQDQTestCaseFn BuildQDQSplitTestCase(
     const std::vector<int64_t>& input_shape,
diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc
index 749f6fcf7b..039bce599a 100644
--- a/onnxruntime/test/optimizer/qdq_transformer_test.cc
+++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc
@@ -9,6 +9,7 @@
 #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
 #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
 #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
+#include "core/optimizer/utils.h"
 #include "core/providers/partitioning_utils.h"
 #include "core/session/onnxruntime_session_options_config_keys.h"
 #include "core/session/environment.h"
@@ -764,6 +765,99 @@ TEST(QDQTransformerTests, Gather) {
   test_case({12, 37}, {24, 12});
 }
 
+TEST(QDQTransformerTests, DoubleQDQ) {
+  constexpr uint8_t good_u8_1 = 80;
+  constexpr uint8_t good_u8_2 = 40;
+  constexpr uint8_t bad_u8 = 13;
+
+  constexpr int8_t good_s8_1 = 99;
+  constexpr int8_t good_s8_2 = -112;
+  constexpr int8_t bad_s8 = 42;
+
+  constexpr float good_float_point_1 = 4.0f;
+  constexpr float good_float_point_2 = 8.0f;
+  constexpr float bad_float_point = 1.11f;
+
+  std::function<void(InferenceSessionWrapper& session)> expect_succeed = [&](InferenceSessionWrapper& session) {
+    auto op_to_count = CountOpsInGraph(session.GetGraph());
+    EXPECT_EQ(op_to_count["QuantizeLinear"], 1);
+    EXPECT_EQ(op_to_count["DequantizeLinear"], 1);
+  };
+  std::function<void(InferenceSessionWrapper& session)> expect_fail = [&](InferenceSessionWrapper& session) {
+    auto op_to_count = CountOpsInGraph(session.GetGraph());
+    EXPECT_EQ(op_to_count["QuantizeLinear"], 2);
+    EXPECT_EQ(op_to_count["DequantizeLinear"], 2);
+  };
+
+  auto test_case_all_u8 = [&](bool succeed,
+                              uint8_t zp_1, uint8_t zp_2, uint8_t zp_3, uint8_t zp_4,
+                              float scale_1, float scale_2, float scale_3, float scale_4) {
+    TransformerTester(
+        BuildDoubleQDQTestCases(zp_1, zp_2, zp_3, zp_4, scale_1, scale_2, scale_3, scale_4),
+        succeed ? expect_succeed : expect_fail,
+        TransformerLevel::Default,
+        TransformerLevel::Level1,
+        12,
+        (scale_1 + scale_3) / 2,
+        0.01);
+  };
+
+  auto test_case_all_s8 = [&](bool succeed,
+                              int8_t zp_1, int8_t zp_2, int8_t zp_3, int8_t zp_4,
+                              float scale_1, float scale_2, float scale_3, float scale_4) {
+    TransformerTester(
+        BuildDoubleQDQTestCases(zp_1, zp_2, zp_3, zp_4, scale_1, scale_2, scale_3, scale_4),
+        succeed ? expect_succeed : expect_fail,
+        TransformerLevel::Default,
+        TransformerLevel::Level1,
+        12,
+        (scale_1 + scale_3) / 2,
+        0.01);
+  };
+
+  auto test_case_2u8_2s8_failed = [&](uint8_t zp_1, uint8_t zp_2, int8_t zp_3, int8_t zp_4,
+                                      float scale_1, float scale_2, float scale_3, float scale_4) {
+    TransformerTester(
+        BuildDoubleQDQTestCases(zp_1, zp_2, zp_3, zp_4, scale_1, scale_2, scale_3, scale_4),
+        expect_fail,
+        TransformerLevel::Default,
+        TransformerLevel::Level1);
+  };
+
+  // all unsigned type
+  test_case_all_u8(true, good_u8_1, good_u8_1, good_u8_2, good_u8_2, good_float_point_1, good_float_point_1, good_float_point_2, good_float_point_2);
+  // all signed type
+  test_case_all_s8(true, good_s8_1, good_s8_1, good_s8_2, good_s8_2, good_float_point_1, good_float_point_1, good_float_point_2, good_float_point_2);
+  // 2 signed, 2 unsigned
+  test_case_2u8_2s8_failed(good_u8_1, good_u8_1, good_s8_2, good_s8_2, good_float_point_1, good_float_point_1, good_float_point_2, good_float_point_2);
+  // different zero point within a pair
+  test_case_all_u8(false, good_u8_1, bad_u8, good_u8_2, good_u8_2, good_float_point_1, good_float_point_1, good_float_point_2, good_float_point_2);
+  test_case_all_u8(false, good_u8_1, good_u8_1, good_u8_2, bad_u8, good_float_point_1, good_float_point_1, good_float_point_2, good_float_point_2);
+  test_case_all_s8(false, good_s8_1, bad_s8, good_s8_2, good_s8_2, good_float_point_1, good_float_point_1, good_float_point_2, good_float_point_2);
+  test_case_all_s8(false, good_s8_1, good_s8_1, good_s8_2, bad_s8, good_float_point_1, good_float_point_1, good_float_point_2, good_float_point_2);
+  // different scale within a pair
+  test_case_all_u8(false, good_u8_1, good_u8_1, good_u8_2, good_u8_2, good_float_point_1, bad_float_point, good_float_point_2, good_float_point_2);
+  test_case_all_u8(false, good_u8_1, good_u8_1, good_u8_2, good_u8_2, good_float_point_1, good_float_point_1, bad_float_point, good_float_point_2);
+}
+
+TEST(QDQTransformerTests, DoubleQDQ_Without_Last_Node_Being_Output) {
+  auto test_case = [&](int output_index, int expected_Q_count, int expected_DQ_count) {
+    auto graph = [&](InferenceSessionWrapper& session) {
+      auto op_to_count = CountOpsInGraph(session.GetGraph());
+      EXPECT_EQ(op_to_count["QuantizeLinear"], expected_Q_count);
+      EXPECT_EQ(op_to_count["DequantizeLinear"], expected_DQ_count);
+    };
+    TransformerTester(
+        BuildDoubleQDQWithoutLastOutput<uint8_t>(output_index),  // NOTE(review): <uint8_t> reconstructed; the argument was stripped in extraction
+        graph,
+        TransformerLevel::Default,
+        TransformerLevel::Level1);
+  };
+  test_case(0, 2, 2);  // an intermediate Q/DQ output is a graph output: nothing may be removed
+  test_case(1, 2, 2);
+  test_case(2, 2, 2);
+  test_case(3, 1, 1);  // only the last DQ is a graph output: the inner pair is removed
+}
 // Because split isn't one the supported ops, this will stay the same
 TEST(QDQTransformerTests, Split) {
   auto test_case = [&](const std::vector<int64_t>& input_shape, const int64_t& axis) {
@@ -2585,13 +2679,8 @@ TEST(QDQTransformerTests, QDQFinalCleanupTransformer_BasicDQQCleanUp) {
     auto check_graph = [&](const InferenceSessionWrapper& session) {
       const auto ops_in_order = GetNodeOpTypesInTopologicalOrder(session.GetGraph());
       const auto expected_ops_in_order = [&]() -> std::vector<std::string> {
-        if (use_matching_qdq_params) {
-          // DQ/Q cleanup removes middle DQ/Q
-          return {"QuantizeLinear", "DequantizeLinear"};
-        }
-
-        // removes nothing
-        return {"QuantizeLinear", "DequantizeLinear", "QuantizeLinear", "DequantizeLinear"};
+        // In either case both DQ and Q will be removed and fused due to DoubleQDQPairsRemover
+        return {"QuantizeLinear", "DequantizeLinear"};
       }();
 
       EXPECT_EQ(ops_in_order, expected_ops_in_order);