Removing Double QDQ from Graphs (#14024)

### Description
When two Q/DQ pairs appear back to back, we want to delete the redundant
middle DQ and Q nodes, leaving a single Q/DQ pair.
ex:
Q->DQ->Q->DQ  =====> Q->DQ



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
Jian Chen 2023-01-16 22:06:57 -05:00 committed by GitHub
parent fb801d58b1
commit d95249f516
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 394 additions and 9 deletions

View file

@ -47,6 +47,12 @@ static const char* const kOrtSessionOptionsConfigSetDenormalAsZero = "session.se
// Its default value is "0"
static const char* const kOrtSessionOptionsDisableQuantQDQ = "session.disable_quant_qdq";
// It controls whether to run the Double QDQ remover and Identical Children Consolidation optimizers.
// "0": do not disable. ORT removes the middle two nodes (DQ->Q) from a Q->(DQ->Q)->DQ chain.
// "1": disable. ORT does not remove the middle two nodes (DQ->Q) from a Q->(DQ->Q)->DQ chain.
// Its default value is "0"
static const char* const kOrtSessionOptionsDisableDoubleQDQRemover = "session.disable_double_qdq_remover";
// If set to "1", enables the removal of QuantizeLinear/DequantizeLinear node pairs once all QDQ handling has been
// completed. e.g. If after all QDQ handling has completed and we have -> FloatOp -> Q -> DQ -> FloatOp -> the
// Q -> DQ could potentially be removed. This will provide a performance benefit by avoiding going from float to

View file

@ -0,0 +1,167 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/optimizer/double_qdq_pairs_remover.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
namespace onnxruntime {
/**
 * Visits every node in topological order and, whenever IsNodeRemovable() accepts a node, splices that
 * node and its single consumer out of the graph so that the node's producer feeds the consumer's
 * consumer directly (Q -> DQ -> Q -> DQ becomes Q -> DQ).
 *
 * @param graph Graph that is rewritten in place.
 * @param modified Set to true when at least one pair has been removed; otherwise left untouched.
 * @return Always Status::OK().
 */
Status DoubleQDQPairsRemover::ApplyImpl(
    Graph& graph,
    bool& modified,
    int /*graph_level*/,
    const logging::Logger& /*logger*/) const {
  const GraphViewer viewer(graph);
  for (const auto& middle_dq_index : viewer.GetNodesInTopologicalOrder()) {
    NodeIndex leading_q_index = 0;
    NodeIndex middle_q_index = 0;
    NodeIndex trailing_dq_index = 0;
    if (!IsNodeRemovable(graph, middle_dq_index, leading_q_index, middle_q_index, trailing_dq_index)) {
      continue;
    }

    // Disconnect the three edges that hold the chain together.
    graph.RemoveEdge(leading_q_index, middle_dq_index, 0, 0);
    graph.RemoveEdge(middle_dq_index, middle_q_index, 0, 0);
    graph.RemoveEdge(middle_q_index, trailing_dq_index, 0, 0);

    // The trailing node now reads the value that used to enter the removed pair.
    Node& trailing_dq = *graph.GetNode(trailing_dq_index);
    NodeArg& quantized_value = *graph.GetNode(middle_dq_index)->MutableInputDefs()[0];
    graph_utils::ReplaceNodeInput(trailing_dq, 0, quantized_value);
    graph.AddEdge(leading_q_index, trailing_dq_index, 0, 0);

    // Both middle nodes are fully disconnected at this point and can be dropped.
    graph.RemoveNode(middle_q_index);
    graph.RemoveNode(middle_dq_index);
    modified = true;
  }
  return Status::OK();
}
/**
 * Checks whether the node at self_index is the DequantizeLinear in the middle of a
 * QuantizeLinear -> DequantizeLinear -> QuantizeLinear -> DequantizeLinear chain whose two middle
 * nodes can be removed.
 *
 * NOTE: this is not a pure predicate. When it returns true it has already replaced the scale and
 * zero-point inputs of the parent and grandchild nodes with newly added initializers holding the
 * merged quantization parameters.
 *
 * @param graph Graph that owns the nodes; mutated only when the function returns true.
 * @param self_index Index of the candidate node.
 * @param parent_index Receives the index of the node feeding the candidate.
 * @param child_index Receives the index of the node consuming the candidate's output.
 * @param grandchild_index Receives the index of the node consuming the child's output.
 * @return true when the chain matched and the outer nodes were updated. On a false return the
 *         index outputs may have been partially written and must not be relied upon.
 */
bool DoubleQDQPairsRemover::IsNodeRemovable(
    Graph& graph,
    const NodeIndex& self_index,
    NodeIndex& parent_index,
    NodeIndex& child_index,
    NodeIndex& grandchild_index) {
  // self: a DQ with exactly one incoming and one outgoing edge, all three inputs present,
  // and whose output is not a graph output.
  Node* self = graph.GetNode(self_index);
  if (self == nullptr) {
    return false;
  }
  if (self->OpType() != "DequantizeLinear") {
    return false;
  }
  if (self->GetInputEdgesCount() != 1 || self->GetOutputEdgesCount() != 1) {
    return false;
  }
  if (self->InputDefs().size() != InputIndex::TOTAL_COUNT) {
    return false;
  }
  if (graph.NodeProducesGraphOutput(*self)) {
    return false;
  }

  // The zero-point type is either "tensor(uint8)" or "tensor(int8)".
  const auto& self_zp_type = *self->InputDefs()[InputIndex::ZERO_POINT_ID]->Type();

  // child: a Q with a single consumer, the same zero-point type as self, and no graph output.
  child_index = self->OutputEdgesBegin()->GetNode().Index();
  const Node* child = graph.GetNode(child_index);
  const bool child_matches = child != nullptr &&
                             child->OpType() == "QuantizeLinear" &&
                             child->GetOutputEdgesCount() == 1 &&
                             child->InputDefs().size() == InputIndex::TOTAL_COUNT &&
                             *child->InputDefs()[InputIndex::ZERO_POINT_ID]->Type() == self_zp_type &&
                             !graph.NodeProducesGraphOutput(*child);
  if (!child_matches) {
    return false;
  }

  // parent: a Q whose only consumer is self and which does not produce a graph output.
  parent_index = self->InputEdgesBegin()->GetNode().Index();
  Node* parent = graph.GetNode(parent_index);
  const bool parent_matches = parent != nullptr &&
                              parent->GetOutputEdgesCount() == 1 &&
                              parent->OpType() == "QuantizeLinear" &&
                              !graph.NodeProducesGraphOutput(*parent);
  if (!parent_matches) {
    return false;
  }

  // grandchild: a DQ.
  grandchild_index = child->OutputEdgesBegin()->GetNode().Index();
  Node* grandchild = graph.GetNode(grandchild_index);
  if (grandchild == nullptr) {
    return false;
  }
  if (grandchild->OpType() != "DequantizeLinear") {
    return false;
  }

  // Both the outer-left pair (parent, self) and the outer-right pair (child, grandchild)
  // must pass the shared Q/DQ pair check.
  const auto get_constant_initializer = [&graph](const std::string& initializer_name) {
    return graph.GetConstantInitializer(initializer_name, true);
  };
  if (!QDQ::IsQDQPairSupported(*parent, *self, get_constant_initializer, graph.ModelPath())) {
    return false;
  }
  if (!QDQ::IsQDQPairSupported(*child, *grandchild, get_constant_initializer, graph.ModelPath())) {
    return false;
  }

  // Computes the merged parameters for the given zero-point type and writes them onto the two
  // surviving nodes. The argument only selects the zero-point type; its value is overwritten.
  const auto retarget_outer_nodes = [&](auto new_zero_point) -> bool {
    float new_scale = 0.0f;
    if (!FindNewZeroPointAndScale(graph, *self, *child, new_scale, new_zero_point)) {
      return false;
    }
    ApplyNewInputValue(graph, *grandchild, InputIndex::SCALE_ID, new_scale);
    ApplyNewInputValue(graph, *parent, InputIndex::SCALE_ID, new_scale);
    ApplyNewInputValue(graph, *grandchild, InputIndex::ZERO_POINT_ID, new_zero_point);
    ApplyNewInputValue(graph, *parent, InputIndex::ZERO_POINT_ID, new_zero_point);
    return true;
  };

  if (self_zp_type == "tensor(uint8)") {
    return retarget_outer_nodes(uint8_t{0});
  }
  return retarget_outer_nodes(int8_t{0});
}
template <typename T>
bool DoubleQDQPairsRemover::FindNewZeroPointAndScale(const Graph& graph, const Node& node1, const Node& node2, float& new_scale, T& new_zero_point) {
// if Q/DQ scale and zero point are not constant, return false
const ONNX_NAMESPACE::TensorProto* node1_scale_tensor_proto =
graph_utils::GetConstantInitializer(graph, node1.InputDefs()[InputIndex::SCALE_ID]->Name());
const ONNX_NAMESPACE::TensorProto* node2_scale_tensor_proto =
graph_utils::GetConstantInitializer(graph, node2.InputDefs()[InputIndex::SCALE_ID]->Name());
const ONNX_NAMESPACE::TensorProto* node1_zp_tensor_proto =
graph_utils::GetConstantInitializer(graph, node1.InputDefs()[InputIndex::ZERO_POINT_ID]->Name());
const ONNX_NAMESPACE::TensorProto* node2_zp_tensor_proto =
graph_utils::GetConstantInitializer(graph, node2.InputDefs()[InputIndex::ZERO_POINT_ID]->Name());
Initializer zero_point_init_1{*node1_zp_tensor_proto, graph.ModelPath()};
Initializer zero_point_init_2{*node2_zp_tensor_proto, graph.ModelPath()};
Initializer scale_init_1{*node1_scale_tensor_proto, graph.ModelPath()};
Initializer scale_init_2{*node2_scale_tensor_proto, graph.ModelPath()};
if (zero_point_init_1.data_type() != zero_point_init_2.data_type() ||
scale_init_1.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT ||
scale_init_2.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
return false;
}
T zero_point_1 = zero_point_init_1.data<T>()[0];
T zero_point_2 = zero_point_init_2.data<T>()[0];
const float scale_1 = scale_init_1.data<float>()[0];
const float scale_2 = scale_init_2.data<float>()[0];
T q_min = std::numeric_limits<T>::min();
T q_max = std::numeric_limits<T>::max();
float real_min1 = gsl::narrow_cast<float>(q_min - zero_point_1) * scale_1;
float real_max1 = gsl::narrow_cast<float>(q_max - zero_point_1) * scale_1;
float real_min2 = gsl::narrow_cast<float>(q_min - zero_point_2) * scale_2;
float real_max2 = gsl::narrow_cast<float>(q_max - zero_point_2) * scale_2;
const float real_min = std::max(real_min1, real_min2);
const float real_max = std::min(real_max1, real_max2);
new_scale = (real_max - real_min) / gsl::narrow_cast<float>(q_max - q_min);
new_zero_point = gsl::narrow_cast<T>(std::round(gsl::narrow_cast<float>(q_min) - real_min / new_scale));
return true;
}
/**
 * Points input `index` of `node` at a freshly added initializer that is a copy of the current
 * constant initializer with its first element overwritten by `value`.
 *
 * The original initializer is left in the graph untouched; the copy gets a generated name based on
 * "DoubleQDQRemoved_" followed by the original input name.
 *
 * @tparam T Element type used to write the value into the initializer data.
 * @param graph Graph that receives the new initializer.
 * @param node Node whose input is redirected.
 * @param index Which input of the node to replace.
 * @param value New value for element 0 of the initializer.
 */
template <typename T>
void DoubleQDQPairsRemover::ApplyNewInputValue(Graph& graph, Node& node, const InputIndex& index, T value) {
  const std::string& current_input_name = node.InputDefs()[index]->Name();
  const auto* current_proto = graph_utils::GetConstantInitializer(graph, current_input_name);

  // Load the existing data, patch the first element, and serialize it into a copy of the proto.
  Initializer patched_data{*current_proto, graph.ModelPath()};
  TensorProto patched_proto(*current_proto);
  patched_data.data<T>()[0] = value;
  patched_data.ToProto(patched_proto);

  // Give the copy its own name so the original initializer is not modified.
  patched_proto.set_name(graph.GenerateNodeArgName("DoubleQDQRemoved_" + current_input_name));

  NodeArg& patched_input = graph_utils::AddInitializer(graph, patched_proto);
  graph_utils::ReplaceNodeInput(node, index, patched_input);
}
} // namespace onnxruntime

View file

@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/optimizer/graph_transformer.h"
#include "core/optimizer/qdq_transformer/qdq_util.h"
namespace onnxruntime {
using ONNX_NAMESPACE::TensorProto;
using ONNX_NAMESPACE::TensorProto_DataType;
using QDQ::InputIndex;
/**
* @Class DoubleQDQPairsRemover
* @brief Remove one pair of Q-DQ from Double Q-DQ pairs.
*
* For a chain Q -> DQ -> Q -> DQ the two middle nodes (DQ -> Q) are removed and the scale and
* zero point of the two remaining nodes are replaced with merged values, leaving Q -> DQ.
*/
class DoubleQDQPairsRemover : public GraphTransformer {
public:
// Registers the transformer under the name "DoubleQDQPairsRemover".
DoubleQDQPairsRemover() : GraphTransformer("DoubleQDQPairsRemover", {}) {}
private:
// Walks the graph in topological order and removes the middle DQ/Q nodes of every chain accepted
// by IsNodeRemovable(). Sets `modified` to true when at least one pair was removed.
// graph_level and logger are unused by the implementation.
Status ApplyImpl(
Graph& graph,
bool& modified,
int graph_level,
const logging::Logger& logger) const override;
// Returns true when the node at self_index is the DequantizeLinear in the middle of a
// Q -> DQ -> Q -> DQ chain that can be collapsed. On success parent_index, child_index and
// grandchild_index identify the surrounding nodes.
// NOTE: not a pure predicate - when it returns true, the scale and zero-point inputs of the
// parent and grandchild nodes have already been replaced with the merged values.
static bool IsNodeRemovable(
Graph& graph,
const NodeIndex& self_index,
NodeIndex& parent_index,
NodeIndex& child_index,
NodeIndex& grandchild_index);
// Reads the constant scale and zero point of node1 and node2 and computes a scale/zero point
// describing the intersection of the real-value ranges the two parameter sets can represent.
// T is the element type used to read the zero points. Returns false when the zero-point types
// differ or a scale is not float.
template <typename T>
static bool FindNewZeroPointAndScale(
const Graph& graph,
const Node& node1,
const Node& node2,
float& new_scale,
T& new_zero_point);
// Replaces input `index` of `node` with a new initializer (named with a "DoubleQDQRemoved_"
// prefix) that copies the current constant initializer and sets its first element to `value`.
template <typename T>
static void ApplyNewInputValue(
Graph& graph,
Node& node,
const InputIndex& index,
T value);
};
} // namespace onnxruntime

View file

@ -33,6 +33,7 @@
#include "core/optimizer/conv_bn_fusion.h"
#include "core/optimizer/conv_mul_fusion.h"
#include "core/optimizer/div_mul_fusion.h"
#include "core/optimizer/double_qdq_pairs_remover.h"
#include "core/optimizer/dropout_elimination.h"
#include "core/optimizer/dynamic_quantize_matmul_fusion.h"
#include "core/optimizer/embed_layer_norm_fusion.h"
@ -45,6 +46,7 @@
#include "core/optimizer/gemm_activation_fusion.h"
#include "core/optimizer/gemm_sum_fusion.h"
#include "core/optimizer/gemm_transpose_fusion.h"
#include "core/optimizer/identical_children_consolidation.h"
#include "core/optimizer/identity_elimination.h"
#include "core/optimizer/layer_norm_fusion.h"
#include "core/optimizer/matmul_add_fusion.h"
@ -66,7 +68,6 @@
#include "core/optimizer/slice_elimination.h"
#include "core/optimizer/transpose_optimizer/ort_transpose_optimizer.h"
#include "core/optimizer/unsqueeze_elimination.h"
#include "core/optimizer/identical_children_consolidation.h"
#ifdef ENABLE_TRAINING_CORE
#include "orttraining/core/optimizer/bitmask_dropout_replacement.h"
#include "orttraining/core/optimizer/bias_softmax_dropout_fusion.h"
@ -191,12 +192,17 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::move(rule_transformer));
}
// We need to remove the duplicated QDQ Pairs before all other GraphTransformation.
// no filtering on execution provider for L1 optimizations as they only use official ONNX operators
// Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for
// CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by
// default, CSE will not merge them, because the different initializers are represented by different NodeArg.
transformers.emplace_back(std::make_unique<IdenticalChildrenConsolidation>());
if (session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableDoubleQDQRemover, "0") == "0"){
transformers.emplace_back(std::make_unique<IdenticalChildrenConsolidation>());
transformers.emplace_back(std::make_unique<DoubleQDQPairsRemover>());
}
transformers.emplace_back(std::make_unique<ConstantSharing>());
transformers.emplace_back(std::make_unique<CommonSubexpressionElimination>());
transformers.emplace_back(std::make_unique<ConstantFolding>(cpu_execution_provider, !disable_quant_qdq));

View file

@ -3555,6 +3555,30 @@ static void VerifyGeluApproximation(bool is_enabled, SessionOptions& session_opt
EXPECT_EQ(has_gelu_approximation, is_enabled);
}
// Checks that kOrtSessionOptionsDisableDoubleQDQRemover decides whether DoubleQDQPairsRemover is part of
// the Level1 transformer list: present by default and for "0", absent for "1".
TEST_F(GraphTransformationTests, DoubleQDQRemover_SessionOptionConfig) {
  const auto verify_session_config = [](bool is_enabled, SessionOptions& session_option) {
    auto cpu_ep = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
    auto transformers =
        optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_option, *cpu_ep, {});

    // Scan the generated list for the transformer by its registered name.
    bool has_double_qdq_remover = false;
    for (const auto& transformer : transformers) {
      has_double_qdq_remover = has_double_qdq_remover || transformer->Name() == "DoubleQDQPairsRemover";
    }
    EXPECT_EQ(has_double_qdq_remover, is_enabled);
  };

  SessionOptions session_options;

  // No entry set: the remover is enabled by default.
  verify_session_config(true, session_options);

  // Explicitly disabled.
  ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsDisableDoubleQDQRemover, "1"));
  verify_session_config(false, session_options);

  // Explicitly re-enabled.
  ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsDisableDoubleQDQRemover, "0"));
  verify_session_config(true, session_options);
}
// Test session option configuration for GeluApproximation
TEST_F(GraphTransformationTests, GeluApproximation_SessionOptionConfig) {
SessionOptions session_options;

View file

@ -395,6 +395,46 @@ GetQDQTestCaseFn BuildConsolidationTestCase(
};
}
// Builds a model of the form input -> Q1 -> DQ1 -> Q2 -> DQ2 -> output, where every node gets its own
// zero point (zp_N, of type TypeN) and scale (scale_N).
// The float input is generated within the limits of Type1 multiplied by the mean of scale_1 and scale_3.
template <typename Type1, typename Type2, typename Type3, typename Type4>
GetQDQTestCaseFn BuildDoubleQDQTestCases(Type1 zp_1, Type2 zp_2, Type3 zp_3, Type4 zp_4,
                                         float scale_1, float scale_2, float scale_3, float scale_4) {
  return [=](ModelTestBuilder& builder) {
    const auto input_lower_bound = std::numeric_limits<Type1>::min() * (scale_1 + scale_3) / 2;
    const auto input_upper_bound = std::numeric_limits<Type1>::max() * (scale_1 + scale_3) / 2;
    auto* model_input = builder.MakeInput<float>({11, 22, 33, 44}, input_lower_bound, input_upper_bound);

    // Only the last DQ output is a graph output; everything in between is intermediate.
    NodeArg* first_q_out = builder.MakeIntermediate();
    NodeArg* first_dq_out = builder.MakeIntermediate();
    NodeArg* second_q_out = builder.MakeIntermediate();
    NodeArg* second_dq_out = builder.MakeOutput();

    builder.AddQuantizeLinearNode<Type1>(model_input, scale_1, zp_1, first_q_out);
    builder.AddDequantizeLinearNode<Type2>(first_q_out, scale_2, zp_2, first_dq_out);
    builder.AddQuantizeLinearNode<Type3>(first_dq_out, scale_3, zp_3, second_q_out);
    builder.AddDequantizeLinearNode<Type4>(second_q_out, scale_4, zp_4, second_dq_out);
  };
}
// Builds input -> Q -> DQ -> Q -> DQ where all four nodes share the same scale and zero point, and the
// output of the node at position `output_index` (0..3) is the graph output; the other three outputs are
// intermediates. An index outside 0..3 yields no graph output at all.
template <typename T>
GetQDQTestCaseFn BuildDoubleQDQWithoutLastOutput(int output_index) {
  return [=](ModelTestBuilder& builder) {
    auto* model_input = builder.MakeInput<float>({2, 3, 4}, std::numeric_limits<float>::min(), std::numeric_limits<float>::max());
    T zp = (std::numeric_limits<T>::max() - std::numeric_limits<T>::min()) / 2;
    float scale = 0.003f;

    // One NodeArg per node, created in chain order.
    std::vector<NodeArg*> chain_outputs;
    chain_outputs.reserve(4);
    for (int position = 0; position < 4; ++position) {
      chain_outputs.push_back(output_index == position ? builder.MakeOutput() : builder.MakeIntermediate());
    }

    builder.AddQuantizeLinearNode<T>(model_input, scale, zp, chain_outputs[0]);
    builder.AddDequantizeLinearNode<T>(chain_outputs[0], scale, zp, chain_outputs[1]);
    builder.AddQuantizeLinearNode<T>(chain_outputs[1], scale, zp, chain_outputs[2]);
    builder.AddDequantizeLinearNode<T>(chain_outputs[2], scale, zp, chain_outputs[3]);
  };
}
template <typename InputType, typename OutputType>
GetQDQTestCaseFn BuildQDQSplitTestCase(
const std::vector<int64_t>& input_shape,

View file

@ -9,6 +9,7 @@
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
#include "core/optimizer/utils.h"
#include "core/providers/partitioning_utils.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/session/environment.h"
@ -764,6 +765,99 @@ TEST(QDQTransformerTests, Gather) {
test_case({12, 37}, {24, 12});
}
// Exercises DoubleQDQPairsRemover on Q -> DQ -> Q -> DQ chains.
// The chain is expected to collapse to a single Q/DQ pair only when all four nodes use the same
// zero-point type and each Q/DQ pair agrees on its own scale and zero point.
TEST(QDQTransformerTests, DoubleQDQ) {
  // Zero points: "a" is used by the first pair, "b" by the second, "mismatch" breaks a pair.
  constexpr uint8_t u8_zp_a = 80;
  constexpr uint8_t u8_zp_b = 40;
  constexpr uint8_t u8_zp_mismatch = 13;
  constexpr int8_t s8_zp_a = 99;
  constexpr int8_t s8_zp_b = -112;
  constexpr int8_t s8_zp_mismatch = 42;
  // Scales, following the same convention.
  constexpr float scale_a = 4.0f;
  constexpr float scale_b = 8.0f;
  constexpr float scale_mismatch = 1.11f;

  using GraphCheckFn = std::function<void(InferenceSessionWrapper & session)>;

  // The middle DQ/Q were removed: one Q and one DQ remain.
  GraphCheckFn expect_succeed = [&](InferenceSessionWrapper& session) {
    auto op_to_count = CountOpsInGraph(session.GetGraph());
    EXPECT_EQ(op_to_count["QuantizeLinear"], 1);
    EXPECT_EQ(op_to_count["DequantizeLinear"], 1);
  };
  // Nothing was removed: all four nodes remain.
  GraphCheckFn expect_fail = [&](InferenceSessionWrapper& session) {
    auto op_to_count = CountOpsInGraph(session.GetGraph());
    EXPECT_EQ(op_to_count["QuantizeLinear"], 2);
    EXPECT_EQ(op_to_count["DequantizeLinear"], 2);
  };

  // Shared runner for the single-type cases: opset 12 with a per-sample tolerance derived from the
  // two Q scales and a relative tolerance of 0.01.
  const auto run_same_type_case = [&](const GetQDQTestCaseFn& build_case, bool succeed,
                                      float scale_1, float scale_3) {
    TransformerTester(
        build_case,
        succeed ? expect_succeed : expect_fail,
        TransformerLevel::Default,
        TransformerLevel::Level1,
        12,
        (scale_1 + scale_3) / 2,
        0.01);
  };

  auto test_case_all_u8 = [&](bool succeed,
                              uint8_t zp_1, uint8_t zp_2, uint8_t zp_3, uint8_t zp_4,
                              float scale_1, float scale_2, float scale_3, float scale_4) {
    run_same_type_case(
        BuildDoubleQDQTestCases<uint8_t, uint8_t, uint8_t, uint8_t>(zp_1, zp_2, zp_3, zp_4,
                                                                    scale_1, scale_2, scale_3, scale_4),
        succeed, scale_1, scale_3);
  };

  auto test_case_all_s8 = [&](bool succeed,
                              int8_t zp_1, int8_t zp_2, int8_t zp_3, int8_t zp_4,
                              float scale_1, float scale_2, float scale_3, float scale_4) {
    run_same_type_case(
        BuildDoubleQDQTestCases<int8_t, int8_t, int8_t, int8_t>(zp_1, zp_2, zp_3, zp_4,
                                                                scale_1, scale_2, scale_3, scale_4),
        succeed, scale_1, scale_3);
  };

  // Mixed-type chains are always expected to be left alone; default opset and tolerances are used.
  auto test_case_2u8_2s8_failed = [&](uint8_t zp_1, uint8_t zp_2, int8_t zp_3, int8_t zp_4,
                                      float scale_1, float scale_2, float scale_3, float scale_4) {
    TransformerTester(
        BuildDoubleQDQTestCases<uint8_t, uint8_t, int8_t, int8_t>(zp_1, zp_2, zp_3, zp_4,
                                                                  scale_1, scale_2, scale_3, scale_4),
        expect_fail,
        TransformerLevel::Default,
        TransformerLevel::Level1);
  };

  // all unsigned type
  test_case_all_u8(true, u8_zp_a, u8_zp_a, u8_zp_b, u8_zp_b, scale_a, scale_a, scale_b, scale_b);
  // all signed type
  test_case_all_s8(true, s8_zp_a, s8_zp_a, s8_zp_b, s8_zp_b, scale_a, scale_a, scale_b, scale_b);
  // 2 signed, 2 unsigned
  test_case_2u8_2s8_failed(u8_zp_a, u8_zp_a, s8_zp_b, s8_zp_b, scale_a, scale_a, scale_b, scale_b);
  // different zero point within a pair
  test_case_all_u8(false, u8_zp_a, u8_zp_mismatch, u8_zp_b, u8_zp_b, scale_a, scale_a, scale_b, scale_b);
  test_case_all_u8(false, u8_zp_a, u8_zp_a, u8_zp_b, u8_zp_mismatch, scale_a, scale_a, scale_b, scale_b);
  test_case_all_s8(false, s8_zp_a, s8_zp_mismatch, s8_zp_b, s8_zp_b, scale_a, scale_a, scale_b, scale_b);
  test_case_all_s8(false, s8_zp_a, s8_zp_a, s8_zp_b, s8_zp_mismatch, scale_a, scale_a, scale_b, scale_b);
  // different scale within a pair
  test_case_all_u8(false, u8_zp_a, u8_zp_a, u8_zp_b, u8_zp_b, scale_a, scale_mismatch, scale_b, scale_b);
  test_case_all_u8(false, u8_zp_a, u8_zp_a, u8_zp_b, u8_zp_b, scale_a, scale_a, scale_mismatch, scale_b);
}
// A Q -> DQ -> Q -> DQ chain may only be collapsed when the graph output is produced by the last DQ.
// If any earlier node produces the graph output, all four nodes must be kept.
TEST(QDQTransformerTests, DoubleQDQ_Without_Last_Node_Being_Output) {
  auto test_case = [&](int output_index, int expected_Q_count, int expected_DQ_count) {
    auto check_op_counts = [&](InferenceSessionWrapper& session) {
      auto op_to_count = CountOpsInGraph(session.GetGraph());
      EXPECT_EQ(op_to_count["QuantizeLinear"], expected_Q_count);
      EXPECT_EQ(op_to_count["DequantizeLinear"], expected_DQ_count);
    };
    TransformerTester(
        BuildDoubleQDQWithoutLastOutput<uint8_t>(output_index),
        check_op_counts,
        TransformerLevel::Default,
        TransformerLevel::Level1);
  };

  // Positions 0..2 are not the last node, so nothing is removed; position 3 allows the collapse.
  for (int output_index = 0; output_index < 4; ++output_index) {
    const int expected_pairs = (output_index == 3) ? 1 : 2;
    test_case(output_index, expected_pairs, expected_pairs);
  }
}
// Because split isn't one the supported ops, this will stay the same
TEST(QDQTransformerTests, Split) {
auto test_case = [&](const std::vector<int64_t>& input_shape, const int64_t& axis) {
@ -2585,13 +2679,8 @@ TEST(QDQTransformerTests, QDQFinalCleanupTransformer_BasicDQQCleanUp) {
auto check_graph = [&](const InferenceSessionWrapper& session) {
const auto ops_in_order = GetNodeOpTypesInTopologicalOrder(session.GetGraph());
const auto expected_ops_in_order = [&]() -> std::vector<std::string> {
if (use_matching_qdq_params) {
// DQ/Q cleanup removes middle DQ/Q
return {"QuantizeLinear", "DequantizeLinear"};
}
// removes nothing
return {"QuantizeLinear", "DequantizeLinear", "QuantizeLinear", "DequantizeLinear"};
// In either case both DQ and Q will be removed and fused due to DoubleQDQPairsRemover
return {"QuantizeLinear", "DequantizeLinear"};
}();
EXPECT_EQ(ops_in_order, expected_ops_in_order);