mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Removing Double QDQ from Graphs (#14024)
### Description When there are 2 QDQ pair back to back, we want to delete the 1 Q and 1 DQ nodes. ex: Q->DQ->Q->DQ =====> Q->DQ ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
fb801d58b1
commit
d95249f516
7 changed files with 394 additions and 9 deletions
|
|
@ -47,6 +47,12 @@ static const char* const kOrtSessionOptionsConfigSetDenormalAsZero = "session.se
|
|||
// Its default value is "0"
|
||||
static const char* const kOrtSessionOptionsDisableQuantQDQ = "session.disable_quant_qdq";
|
||||
|
||||
// It controls whether to enable Double QDQ remover and Identical Children Consolidation
|
||||
// "0": not to disable. ORT does remove the middle 2 Nodes from a Q->(QD->Q)->QD pairs
|
||||
// "1": disable. ORT doesn't remove the middle 2 Nodes from a Q->(QD->Q)->QD pairs
|
||||
// Its default value is "0"
|
||||
static const char* const kOrtSessionOptionsDisableDoubleQDQRemover = "session.disable_double_qdq_remover";
|
||||
|
||||
// If set to "1", enables the removal of QuantizeLinear/DequantizeLinear node pairs once all QDQ handling has been
|
||||
// completed. e.g. If after all QDQ handling has completed and we have -> FloatOp -> Q -> DQ -> FloatOp -> the
|
||||
// Q -> DQ could potentially be removed. This will provide a performance benefit by avoiding going from float to
|
||||
|
|
|
|||
167
onnxruntime/core/optimizer/double_qdq_pairs_remover.cc
Normal file
167
onnxruntime/core/optimizer/double_qdq_pairs_remover.cc
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#include "core/optimizer/double_qdq_pairs_remover.h"
|
||||
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
Status DoubleQDQPairsRemover::ApplyImpl(
|
||||
Graph& graph,
|
||||
bool& modified,
|
||||
int /*graph_level*/,
|
||||
const logging::Logger& /*logger*/) const {
|
||||
const GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
for (const auto& self_index : node_topology_list) {
|
||||
NodeIndex parent_index = 0;
|
||||
NodeIndex child_index = 0;
|
||||
NodeIndex grandchild_index = 0;
|
||||
if (IsNodeRemovable(graph, self_index, parent_index, child_index, grandchild_index)) {
|
||||
graph.RemoveEdge(parent_index, self_index, 0, 0);
|
||||
graph.RemoveEdge(self_index, child_index, 0, 0);
|
||||
graph.RemoveEdge(child_index, grandchild_index, 0, 0);
|
||||
graph_utils::ReplaceNodeInput(*graph.GetNode(grandchild_index), 0, *graph.GetNode(self_index)->MutableInputDefs()[0]);
|
||||
graph.AddEdge(parent_index, grandchild_index, 0, 0);
|
||||
graph.RemoveNode(child_index);
|
||||
graph.RemoveNode(self_index);
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool DoubleQDQPairsRemover::IsNodeRemovable(
|
||||
Graph& graph,
|
||||
const NodeIndex& self_index,
|
||||
NodeIndex& parent_index,
|
||||
NodeIndex& child_index,
|
||||
NodeIndex& grandchild_index) {
|
||||
// Check if the self is a DQ, and have one parent and one child, and cannot be a graph output
|
||||
Node* self = graph.GetNode(self_index);
|
||||
if (self == nullptr ||
|
||||
self->OpType() != "DequantizeLinear" ||
|
||||
self->GetInputEdgesCount() != 1 ||
|
||||
self->GetOutputEdgesCount() != 1 ||
|
||||
self->InputDefs().size() != InputIndex::TOTAL_COUNT ||
|
||||
graph.NodeProducesGraphOutput(*self)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Type is either "tensor(uint8)" or "tensor(int8)"
|
||||
const auto self_zp_type = *self->InputDefs()[InputIndex::ZERO_POINT_ID]->Type();
|
||||
// child should be a Q, and have only one child, have the same type as self, and cannot be a graph output
|
||||
child_index = self->OutputEdgesBegin()->GetNode().Index();
|
||||
const Node* child = graph.GetNode(child_index);
|
||||
if (child == nullptr ||
|
||||
child->OpType() != "QuantizeLinear" ||
|
||||
child->GetOutputEdgesCount() != 1 ||
|
||||
child->InputDefs().size() != InputIndex::TOTAL_COUNT ||
|
||||
*child->InputDefs()[InputIndex::ZERO_POINT_ID]->Type() != self_zp_type ||
|
||||
graph.NodeProducesGraphOutput(*child)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// parent should be a Q, and have only one output, and cannot be a graph output
|
||||
parent_index = self->InputEdgesBegin()->GetNode().Index();
|
||||
Node* parent = graph.GetNode(parent_index);
|
||||
if (parent == nullptr ||
|
||||
parent->GetOutputEdgesCount() != 1 ||
|
||||
parent->OpType() != "QuantizeLinear" ||
|
||||
graph.NodeProducesGraphOutput(*parent)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// grandchild should be a DQ
|
||||
grandchild_index = child->OutputEdgesBegin()->GetNode().Index();
|
||||
Node* grandchild = graph.GetNode(grandchild_index);
|
||||
if (grandchild == nullptr ||
|
||||
grandchild->OpType() != "DequantizeLinear") {
|
||||
return false;
|
||||
}
|
||||
const auto get_constant_initializer = [&graph](const std::string& initializer_name) {
|
||||
return graph.GetConstantInitializer(initializer_name, true);
|
||||
};
|
||||
if (!QDQ::IsQDQPairSupported(*parent, *self, get_constant_initializer, graph.ModelPath()) ||
|
||||
!QDQ::IsQDQPairSupported(*child, *grandchild, get_constant_initializer, graph.ModelPath())) {
|
||||
return false;
|
||||
}
|
||||
float new_scale = 0.0f;
|
||||
if (self_zp_type == "tensor(uint8)") {
|
||||
uint8_t new_zero_point = 0;
|
||||
if (!FindNewZeroPointAndScale(graph, *self, *child, new_scale, new_zero_point)) {
|
||||
return false;
|
||||
}
|
||||
ApplyNewInputValue(graph, *grandchild, InputIndex::SCALE_ID, new_scale);
|
||||
ApplyNewInputValue(graph, *parent, InputIndex::SCALE_ID, new_scale);
|
||||
ApplyNewInputValue(graph, *grandchild, InputIndex::ZERO_POINT_ID, new_zero_point);
|
||||
ApplyNewInputValue(graph, *parent, InputIndex::ZERO_POINT_ID, new_zero_point);
|
||||
} else {
|
||||
int8_t new_zero_point = 0;
|
||||
if (!FindNewZeroPointAndScale(graph, *self, *child, new_scale, new_zero_point)) {
|
||||
return false;
|
||||
}
|
||||
ApplyNewInputValue(graph, *grandchild, InputIndex::SCALE_ID, new_scale);
|
||||
ApplyNewInputValue(graph, *parent, InputIndex::SCALE_ID, new_scale);
|
||||
ApplyNewInputValue(graph, *grandchild, InputIndex::ZERO_POINT_ID, new_zero_point);
|
||||
ApplyNewInputValue(graph, *parent, InputIndex::ZERO_POINT_ID, new_zero_point);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool DoubleQDQPairsRemover::FindNewZeroPointAndScale(const Graph& graph, const Node& node1, const Node& node2, float& new_scale, T& new_zero_point) {
|
||||
// if Q/DQ scale and zero point are not constant, return false
|
||||
const ONNX_NAMESPACE::TensorProto* node1_scale_tensor_proto =
|
||||
graph_utils::GetConstantInitializer(graph, node1.InputDefs()[InputIndex::SCALE_ID]->Name());
|
||||
const ONNX_NAMESPACE::TensorProto* node2_scale_tensor_proto =
|
||||
graph_utils::GetConstantInitializer(graph, node2.InputDefs()[InputIndex::SCALE_ID]->Name());
|
||||
const ONNX_NAMESPACE::TensorProto* node1_zp_tensor_proto =
|
||||
graph_utils::GetConstantInitializer(graph, node1.InputDefs()[InputIndex::ZERO_POINT_ID]->Name());
|
||||
const ONNX_NAMESPACE::TensorProto* node2_zp_tensor_proto =
|
||||
graph_utils::GetConstantInitializer(graph, node2.InputDefs()[InputIndex::ZERO_POINT_ID]->Name());
|
||||
Initializer zero_point_init_1{*node1_zp_tensor_proto, graph.ModelPath()};
|
||||
Initializer zero_point_init_2{*node2_zp_tensor_proto, graph.ModelPath()};
|
||||
Initializer scale_init_1{*node1_scale_tensor_proto, graph.ModelPath()};
|
||||
Initializer scale_init_2{*node2_scale_tensor_proto, graph.ModelPath()};
|
||||
if (zero_point_init_1.data_type() != zero_point_init_2.data_type() ||
|
||||
scale_init_1.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT ||
|
||||
scale_init_2.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
return false;
|
||||
}
|
||||
|
||||
T zero_point_1 = zero_point_init_1.data<T>()[0];
|
||||
T zero_point_2 = zero_point_init_2.data<T>()[0];
|
||||
const float scale_1 = scale_init_1.data<float>()[0];
|
||||
const float scale_2 = scale_init_2.data<float>()[0];
|
||||
T q_min = std::numeric_limits<T>::min();
|
||||
T q_max = std::numeric_limits<T>::max();
|
||||
|
||||
float real_min1 = gsl::narrow_cast<float>(q_min - zero_point_1) * scale_1;
|
||||
float real_max1 = gsl::narrow_cast<float>(q_max - zero_point_1) * scale_1;
|
||||
float real_min2 = gsl::narrow_cast<float>(q_min - zero_point_2) * scale_2;
|
||||
float real_max2 = gsl::narrow_cast<float>(q_max - zero_point_2) * scale_2;
|
||||
|
||||
const float real_min = std::max(real_min1, real_min2);
|
||||
const float real_max = std::min(real_max1, real_max2);
|
||||
|
||||
new_scale = (real_max - real_min) / gsl::narrow_cast<float>(q_max - q_min);
|
||||
new_zero_point = gsl::narrow_cast<T>(std::round(gsl::narrow_cast<float>(q_min) - real_min / new_scale));
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DoubleQDQPairsRemover::ApplyNewInputValue(Graph& graph, Node& node, const InputIndex& index, T value) {
|
||||
const auto* input_tensor = graph_utils::GetConstantInitializer(graph, node.InputDefs()[index]->Name());
|
||||
Initializer input_init{*input_tensor, graph.ModelPath()};
|
||||
TensorProto new_input_tensor(*input_tensor);
|
||||
input_init.data<T>()[0] = value;
|
||||
input_init.ToProto(new_input_tensor);
|
||||
auto new_name = graph.GenerateNodeArgName("DoubleQDQRemoved_" + node.InputDefs()[index]->Name());
|
||||
new_input_tensor.set_name(new_name);
|
||||
NodeArg& new_input = graph_utils::AddInitializer(graph, new_input_tensor);
|
||||
graph_utils::ReplaceNodeInput(node, index, new_input);
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
53
onnxruntime/core/optimizer/double_qdq_pairs_remover.h
Normal file
53
onnxruntime/core/optimizer/double_qdq_pairs_remover.h
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
#include "core/optimizer/qdq_transformer/qdq_util.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
using ONNX_NAMESPACE::TensorProto;
|
||||
using ONNX_NAMESPACE::TensorProto_DataType;
|
||||
using QDQ::InputIndex;
|
||||
|
||||
/**
|
||||
* @Class DoubleQDQPairsRemover
|
||||
* @brief Remove one pair of Q-DQ from Double Q-DQ pairs.
|
||||
*/
|
||||
class DoubleQDQPairsRemover : public GraphTransformer {
|
||||
public:
|
||||
DoubleQDQPairsRemover() : GraphTransformer("DoubleQDQPairsRemover", {}) {}
|
||||
|
||||
private:
|
||||
Status ApplyImpl(
|
||||
Graph& graph,
|
||||
bool& modified,
|
||||
int graph_level,
|
||||
const logging::Logger& logger) const override;
|
||||
|
||||
static bool IsNodeRemovable(
|
||||
Graph& graph,
|
||||
const NodeIndex& self_index,
|
||||
NodeIndex& parent_index,
|
||||
NodeIndex& child_index,
|
||||
NodeIndex& grandchild_index);
|
||||
|
||||
template <typename T>
|
||||
static bool FindNewZeroPointAndScale(
|
||||
const Graph& graph,
|
||||
const Node& node1,
|
||||
const Node& node2,
|
||||
float& new_scale,
|
||||
T& new_zero_point);
|
||||
|
||||
template <typename T>
|
||||
static void ApplyNewInputValue(
|
||||
Graph& graph,
|
||||
Node& node,
|
||||
const InputIndex& index,
|
||||
T value);
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -33,6 +33,7 @@
|
|||
#include "core/optimizer/conv_bn_fusion.h"
|
||||
#include "core/optimizer/conv_mul_fusion.h"
|
||||
#include "core/optimizer/div_mul_fusion.h"
|
||||
#include "core/optimizer/double_qdq_pairs_remover.h"
|
||||
#include "core/optimizer/dropout_elimination.h"
|
||||
#include "core/optimizer/dynamic_quantize_matmul_fusion.h"
|
||||
#include "core/optimizer/embed_layer_norm_fusion.h"
|
||||
|
|
@ -45,6 +46,7 @@
|
|||
#include "core/optimizer/gemm_activation_fusion.h"
|
||||
#include "core/optimizer/gemm_sum_fusion.h"
|
||||
#include "core/optimizer/gemm_transpose_fusion.h"
|
||||
#include "core/optimizer/identical_children_consolidation.h"
|
||||
#include "core/optimizer/identity_elimination.h"
|
||||
#include "core/optimizer/layer_norm_fusion.h"
|
||||
#include "core/optimizer/matmul_add_fusion.h"
|
||||
|
|
@ -66,7 +68,6 @@
|
|||
#include "core/optimizer/slice_elimination.h"
|
||||
#include "core/optimizer/transpose_optimizer/ort_transpose_optimizer.h"
|
||||
#include "core/optimizer/unsqueeze_elimination.h"
|
||||
#include "core/optimizer/identical_children_consolidation.h"
|
||||
#ifdef ENABLE_TRAINING_CORE
|
||||
#include "orttraining/core/optimizer/bitmask_dropout_replacement.h"
|
||||
#include "orttraining/core/optimizer/bias_softmax_dropout_fusion.h"
|
||||
|
|
@ -191,12 +192,17 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
transformers.emplace_back(std::move(rule_transformer));
|
||||
}
|
||||
|
||||
// We need to remove the duplicated QDQ Pairs before all other GraphTransformation.
|
||||
|
||||
// no filtering on execution provider for L1 optimizations as they only use official ONNX operators
|
||||
|
||||
// Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for
|
||||
// CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by
|
||||
// default, CSE will not merge them, because the different initializers are represented by different NodeArg.
|
||||
transformers.emplace_back(std::make_unique<IdenticalChildrenConsolidation>());
|
||||
if (session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableDoubleQDQRemover, "0") == "0"){
|
||||
transformers.emplace_back(std::make_unique<IdenticalChildrenConsolidation>());
|
||||
transformers.emplace_back(std::make_unique<DoubleQDQPairsRemover>());
|
||||
}
|
||||
transformers.emplace_back(std::make_unique<ConstantSharing>());
|
||||
transformers.emplace_back(std::make_unique<CommonSubexpressionElimination>());
|
||||
transformers.emplace_back(std::make_unique<ConstantFolding>(cpu_execution_provider, !disable_quant_qdq));
|
||||
|
|
|
|||
|
|
@ -3555,6 +3555,30 @@ static void VerifyGeluApproximation(bool is_enabled, SessionOptions& session_opt
|
|||
EXPECT_EQ(has_gelu_approximation, is_enabled);
|
||||
}
|
||||
|
||||
// Test session option configuration for DoubleQDQPairsRemover
|
||||
TEST_F(GraphTransformationTests, DoubleQDQRemover_SessionOptionConfig) {
|
||||
auto verify_session_config = [&](bool is_enabled, SessionOptions& session_option) {
|
||||
std::unique_ptr<CPUExecutionProvider> cpu_ep = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
|
||||
bool has_double_qdq_remover = false;
|
||||
auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_option, *cpu_ep.get(), {});
|
||||
for (auto& transformer : transformers) {
|
||||
if (transformer->Name() == "DoubleQDQPairsRemover") {
|
||||
has_double_qdq_remover = true;
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(has_double_qdq_remover, is_enabled);
|
||||
};
|
||||
SessionOptions session_options;
|
||||
// DoubleQDQPairsRemover is enabled by default.
|
||||
verify_session_config(true, session_options);
|
||||
|
||||
ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsDisableDoubleQDQRemover, "1"));
|
||||
verify_session_config(false, session_options);
|
||||
|
||||
ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsDisableDoubleQDQRemover, "0"));
|
||||
verify_session_config(true, session_options);
|
||||
}
|
||||
|
||||
// Test session option configuration for GeluApproximation
|
||||
TEST_F(GraphTransformationTests, GeluApproximation_SessionOptionConfig) {
|
||||
SessionOptions session_options;
|
||||
|
|
|
|||
|
|
@ -395,6 +395,46 @@ GetQDQTestCaseFn BuildConsolidationTestCase(
|
|||
};
|
||||
}
|
||||
|
||||
template <typename Type1, typename Type2, typename Type3, typename Type4>
|
||||
GetQDQTestCaseFn BuildDoubleQDQTestCases(Type1 zp_1, Type2 zp_2, Type3 zp_3, Type4 zp_4,
|
||||
float scale_1, float scale_2, float scale_3, float scale_4) {
|
||||
return [=](ModelTestBuilder& builder) {
|
||||
auto* input_arg = builder.MakeInput<float>(
|
||||
{11, 22, 33, 44},
|
||||
std::numeric_limits<Type1>::min() * (scale_1 + scale_3) / 2,
|
||||
std::numeric_limits<Type1>::max() * (scale_1 + scale_3) / 2);
|
||||
NodeArg* q1_output = builder.MakeIntermediate();
|
||||
NodeArg* dq1_output = builder.MakeIntermediate();
|
||||
NodeArg* q2_output = builder.MakeIntermediate();
|
||||
NodeArg* dq2_output = builder.MakeOutput();
|
||||
builder.AddQuantizeLinearNode<Type1>(input_arg, scale_1, zp_1, q1_output);
|
||||
builder.AddDequantizeLinearNode<Type2>(q1_output, scale_2, zp_2, dq1_output);
|
||||
builder.AddQuantizeLinearNode<Type3>(dq1_output, scale_3, zp_3, q2_output);
|
||||
builder.AddDequantizeLinearNode<Type4>(q2_output, scale_4, zp_4, dq2_output);
|
||||
};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
GetQDQTestCaseFn BuildDoubleQDQWithoutLastOutput(int output_index) {
|
||||
return [=](ModelTestBuilder& builder) {
|
||||
auto* input_arg = builder.MakeInput<float>({2, 3, 4}, std::numeric_limits<float>::min(), std::numeric_limits<float>::max());
|
||||
T zp = (std::numeric_limits<T>::max() - std::numeric_limits<T>::min()) / 2;
|
||||
float scale = 0.003f;
|
||||
std::vector<NodeArg*> outputs(4);
|
||||
for (auto i = 0; i < 4; i++) {
|
||||
if (output_index == i) {
|
||||
outputs[i] = builder.MakeOutput();
|
||||
} else {
|
||||
outputs[i] = builder.MakeIntermediate();
|
||||
}
|
||||
}
|
||||
builder.AddQuantizeLinearNode<T>(input_arg, scale, zp, outputs[0]);
|
||||
builder.AddDequantizeLinearNode<T>(outputs[0], scale, zp, outputs[1]);
|
||||
builder.AddQuantizeLinearNode<T>(outputs[1], scale, zp, outputs[2]);
|
||||
builder.AddDequantizeLinearNode<T>(outputs[2], scale, zp, outputs[3]);
|
||||
};
|
||||
}
|
||||
|
||||
template <typename InputType, typename OutputType>
|
||||
GetQDQTestCaseFn BuildQDQSplitTestCase(
|
||||
const std::vector<int64_t>& input_shape,
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/providers/partitioning_utils.h"
|
||||
#include "core/session/onnxruntime_session_options_config_keys.h"
|
||||
#include "core/session/environment.h"
|
||||
|
|
@ -764,6 +765,99 @@ TEST(QDQTransformerTests, Gather) {
|
|||
test_case({12, 37}, {24, 12});
|
||||
}
|
||||
|
||||
TEST(QDQTransformerTests, DoubleQDQ) {
|
||||
constexpr uint8_t good_u8_1 = 80;
|
||||
constexpr uint8_t good_u8_2 = 40;
|
||||
constexpr uint8_t bad_u8 = 13;
|
||||
|
||||
constexpr int8_t good_s8_1 = 99;
|
||||
constexpr int8_t good_s8_2 = -112;
|
||||
constexpr int8_t bad_s8 = 42;
|
||||
|
||||
constexpr float good_float_point_1 = 4.0f;
|
||||
constexpr float good_float_point_2 = 8.0f;
|
||||
constexpr float bad_float_point = 1.11f;
|
||||
|
||||
std::function<void(InferenceSessionWrapper & session)> expect_succeed = [&](InferenceSessionWrapper& session) {
|
||||
auto op_to_count = CountOpsInGraph(session.GetGraph());
|
||||
EXPECT_EQ(op_to_count["QuantizeLinear"], 1);
|
||||
EXPECT_EQ(op_to_count["DequantizeLinear"], 1);
|
||||
};
|
||||
std::function<void(InferenceSessionWrapper & session)> expect_fail = [&](InferenceSessionWrapper& session) {
|
||||
auto op_to_count = CountOpsInGraph(session.GetGraph());
|
||||
EXPECT_EQ(op_to_count["QuantizeLinear"], 2);
|
||||
EXPECT_EQ(op_to_count["DequantizeLinear"], 2);
|
||||
};
|
||||
|
||||
auto test_case_all_u8 = [&](bool succeed,
|
||||
uint8_t zp_1, uint8_t zp_2, uint8_t zp_3, uint8_t zp_4,
|
||||
float scale_1, float scale_2, float scale_3, float scale_4) {
|
||||
TransformerTester(
|
||||
BuildDoubleQDQTestCases<uint8_t, uint8_t, uint8_t, uint8_t>(zp_1, zp_2, zp_3, zp_4, scale_1, scale_2, scale_3, scale_4),
|
||||
succeed ? expect_succeed : expect_fail,
|
||||
TransformerLevel::Default,
|
||||
TransformerLevel::Level1,
|
||||
12,
|
||||
(scale_1 + scale_3) / 2,
|
||||
0.01);
|
||||
};
|
||||
|
||||
auto test_case_all_s8 = [&](bool succeed,
|
||||
int8_t zp_1, int8_t zp_2, int8_t zp_3, int8_t zp_4,
|
||||
float scale_1, float scale_2, float scale_3, float scale_4) {
|
||||
TransformerTester(
|
||||
BuildDoubleQDQTestCases<int8_t, int8_t, int8_t, int8_t>(zp_1, zp_2, zp_3, zp_4, scale_1, scale_2, scale_3, scale_4),
|
||||
succeed ? expect_succeed : expect_fail,
|
||||
TransformerLevel::Default,
|
||||
TransformerLevel::Level1,
|
||||
12,
|
||||
(scale_1 + scale_3) / 2,
|
||||
0.01);
|
||||
};
|
||||
|
||||
auto test_case_2u8_2s8_failed = [&](uint8_t zp_1, uint8_t zp_2, int8_t zp_3, int8_t zp_4,
|
||||
float scale_1, float scale_2, float scale_3, float scale_4) {
|
||||
TransformerTester(
|
||||
BuildDoubleQDQTestCases<uint8_t, uint8_t, int8_t, int8_t>(zp_1, zp_2, zp_3, zp_4, scale_1, scale_2, scale_3, scale_4),
|
||||
expect_fail,
|
||||
TransformerLevel::Default,
|
||||
TransformerLevel::Level1);
|
||||
};
|
||||
|
||||
// all unsigned type
|
||||
test_case_all_u8(true, good_u8_1, good_u8_1, good_u8_2, good_u8_2, good_float_point_1, good_float_point_1, good_float_point_2, good_float_point_2);
|
||||
// all signed type
|
||||
test_case_all_s8(true, good_s8_1, good_s8_1, good_s8_2, good_s8_2, good_float_point_1, good_float_point_1, good_float_point_2, good_float_point_2);
|
||||
// 2 signed, 2 unsigned
|
||||
test_case_2u8_2s8_failed(good_u8_1, good_u8_1, good_s8_2, good_s8_2, good_float_point_1, good_float_point_1, good_float_point_2, good_float_point_2);
|
||||
// different zero point within a pair
|
||||
test_case_all_u8(false, good_u8_1, bad_u8, good_u8_2, good_u8_2, good_float_point_1, good_float_point_1, good_float_point_2, good_float_point_2);
|
||||
test_case_all_u8(false, good_u8_1, good_u8_1, good_u8_2, bad_u8, good_float_point_1, good_float_point_1, good_float_point_2, good_float_point_2);
|
||||
test_case_all_s8(false, good_s8_1, bad_s8, good_s8_2, good_s8_2, good_float_point_1, good_float_point_1, good_float_point_2, good_float_point_2);
|
||||
test_case_all_s8(false, good_s8_1, good_s8_1, good_s8_2, bad_s8, good_float_point_1, good_float_point_1, good_float_point_2, good_float_point_2);
|
||||
// different scale within a pair
|
||||
test_case_all_u8(false, good_u8_1, good_u8_1, good_u8_2, good_u8_2, good_float_point_1, bad_float_point, good_float_point_2, good_float_point_2);
|
||||
test_case_all_u8(false, good_u8_1, good_u8_1, good_u8_2, good_u8_2, good_float_point_1, good_float_point_1, bad_float_point, good_float_point_2);
|
||||
}
|
||||
|
||||
TEST(QDQTransformerTests, DoubleQDQ_Without_Last_Node_Being_Output) {
|
||||
auto test_case = [&](int output_index, int expected_Q_count, int expected_DQ_count) {
|
||||
auto graph = [&](InferenceSessionWrapper& session) {
|
||||
auto op_to_count = CountOpsInGraph(session.GetGraph());
|
||||
EXPECT_EQ(op_to_count["QuantizeLinear"], expected_Q_count);
|
||||
EXPECT_EQ(op_to_count["DequantizeLinear"], expected_DQ_count);
|
||||
};
|
||||
TransformerTester(
|
||||
BuildDoubleQDQWithoutLastOutput<uint8_t>(output_index),
|
||||
graph,
|
||||
TransformerLevel::Default,
|
||||
TransformerLevel::Level1);
|
||||
};
|
||||
test_case(0, 2, 2);
|
||||
test_case(1, 2, 2);
|
||||
test_case(2, 2, 2);
|
||||
test_case(3, 1, 1);
|
||||
}
|
||||
// Because split isn't one the supported ops, this will stay the same
|
||||
TEST(QDQTransformerTests, Split) {
|
||||
auto test_case = [&](const std::vector<int64_t>& input_shape, const int64_t& axis) {
|
||||
|
|
@ -2585,13 +2679,8 @@ TEST(QDQTransformerTests, QDQFinalCleanupTransformer_BasicDQQCleanUp) {
|
|||
auto check_graph = [&](const InferenceSessionWrapper& session) {
|
||||
const auto ops_in_order = GetNodeOpTypesInTopologicalOrder(session.GetGraph());
|
||||
const auto expected_ops_in_order = [&]() -> std::vector<std::string> {
|
||||
if (use_matching_qdq_params) {
|
||||
// DQ/Q cleanup removes middle DQ/Q
|
||||
return {"QuantizeLinear", "DequantizeLinear"};
|
||||
}
|
||||
|
||||
// removes nothing
|
||||
return {"QuantizeLinear", "DequantizeLinear", "QuantizeLinear", "DequantizeLinear"};
|
||||
// In either case both DQ and Q will be removed and fused due to DoubleQDQPairsRemover
|
||||
return {"QuantizeLinear", "DequantizeLinear"};
|
||||
}();
|
||||
|
||||
EXPECT_EQ(ops_in_order, expected_ops_in_order);
|
||||
|
|
|
|||
Loading…
Reference in a new issue