From 22c475520e8291ee4551ef29d87c1133b2c8d058 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Wed, 9 Mar 2022 16:17:00 -0800 Subject: [PATCH] Make QDQSelectorActionTransformer() is_int8_allowed parameter required. (#10823) Make QDQSelectorActionTransformer() is_int8_allowed parameter required. Set it to QDQIsInt8Allowed() in places it was previously set to false. --- .../core/optimizer/graph_transformer_utils.cc | 7 +++++-- .../qdq_selector_action_transformer.cc | 2 +- .../qdq_selector_action_transformer.h | 2 +- .../test/optimizer/qdq_transformer_test.cc | 18 +++++++++--------- 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index ab23df4dba..21216dedc1 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -217,7 +217,7 @@ InlinedVector> GenerateTransformers( if (!qdq_is_int8_allowed) { transformers.emplace_back(std::make_unique(cpu_ep)); } - transformers.emplace_back(std::make_unique(SatApplyContextVariant{}, qdq_is_int8_allowed)); + transformers.emplace_back(std::make_unique(qdq_is_int8_allowed)); } transformers.emplace_back(std::make_unique(cpu_ep)); @@ -295,12 +295,15 @@ InlinedVector> GenerateTransformersForMinimalB #if !defined(DISABLE_CONTRIB_OPS) const bool disable_quant_qdq = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1"; + const bool qdq_is_int8_allowed = + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQIsInt8Allowed, + QDQIsInt8Allowed() ? "1" : "0") == "1"; // runtime optimizations only support CPU EP now const InlinedHashSet cpu_ep = {onnxruntime::kCpuExecutionProvider}; if (!disable_quant_qdq) { - transformers.emplace_back(std::make_unique(apply_context)); + transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, apply_context)); } transformers.emplace_back(std::make_unique(cpu_ep, apply_context)); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index ee4c3d7d65..9e75a0fbad 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -209,7 +209,7 @@ SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed) { } // namespace QDQSelectorActionTransformer::QDQSelectorActionTransformer( - const SatApplyContextVariant& apply_context, bool is_int8_allowed) + bool is_int8_allowed, const SatApplyContextVariant& apply_context) : SelectorActionTransformer{ "QDQSelectorActionTransformer", CreateSelectorActionRegistry(is_int8_allowed), diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h index 2c48109de1..506834370f 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h @@ -22,7 +22,7 @@ Transformer that fuses QDQ and fp32 ops into quantized ops. */ class QDQSelectorActionTransformer : public SelectorActionTransformer { public: - QDQSelectorActionTransformer(const SatApplyContextVariant& apply_context = {}, bool is_int8_allowed = false); + QDQSelectorActionTransformer(bool is_int8_allowed, const SatApplyContextVariant& apply_context = {}); }; } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 3b214845c8..0956903f6b 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -75,7 +75,7 @@ void QDQTransformerConvTests() { 12 /*opset_version*/, 0.01 /*per_sample_tolerance*/, 0.01 /*relative_per_sample_tolerance*/, - std::make_unique()); + std::make_unique(QDQIsInt8Allowed())); }; test_case({1, 12, 37}, {32, 12, 5}); @@ -252,7 +252,7 @@ void QDQTransformerAveragePoolTests() { 12 /*opset_version*/, 0.01 /*per_sample_tolerance*/, 0.01 /*relative_per_sample_tolerance*/, - std::make_unique()); + std::make_unique(QDQIsInt8Allowed())); }; test_case({1, 12, 37}); @@ -301,7 +301,7 @@ void QDQTransformerGlobalAveragePoolTests() { 12 /*opset_version*/, 0.01 /*per_sample_tolerance*/, 0.01 /*relative_per_sample_tolerance*/, - std::make_unique()); + std::make_unique(QDQIsInt8Allowed())); }; test_case({1, 12, 37}); @@ -351,7 +351,7 @@ void QDQTransformerBinaryOpTests(const std::string& op_type) { 12 /*opset_version*/, 0.01 /*per_sample_tolerance*/, 0.01 /*relative_per_sample_tolerance*/, - std::make_unique()); + std::make_unique(QDQIsInt8Allowed())); }; test_case({1, 12, 37}); @@ -482,7 +482,7 @@ void QDQTransformerMatMulTests(bool has_output_q) { 12 /*opset_version*/, 0.01 /*per_sample_tolerance*/, 0.01 /*relative_per_sample_tolerance*/, - std::make_unique()); + std::make_unique(QDQIsInt8Allowed())); }; test_case({1, 2, 2}, {1, 2, 4}); @@ -637,7 +637,7 @@ void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one 12 /*opset_version*/, 0.01 /*per_sample_tolerance*/, 0.01 /*relative_per_sample_tolerance*/, - std::make_unique()); + std::make_unique(QDQIsInt8Allowed())); }; test_case({2, 2}, {2, 4}); @@ -1366,7 +1366,7 @@ void QDQTransformerLeakyReluTests() { 12 /*opset_version*/, 0.01 /*per_sample_tolerance*/, 0.01 /*relative_per_sample_tolerance*/, - std::make_unique()); + std::make_unique(QDQIsInt8Allowed())); }; test_case({1, 12, 37}); @@ -1435,7 +1435,7 @@ void QDQTransformerSigmoidTests() { 12 /*opset_version*/, 0.01 /*per_sample_tolerance*/, 0.01 /*relative_per_sample_tolerance*/, - std::make_unique()); + std::make_unique(QDQIsInt8Allowed())); }; test_case({1, 12, 37}); @@ -1804,7 +1804,7 @@ TEST(QDQTransformerTests, Concat) { 12 /*opset_version*/, 0.01f /*per_sample_tolerance*/, 0.01f /*relative_per_sample_tolerance*/, - std::make_unique()); + std::make_unique(QDQIsInt8Allowed())); }; test_case({{1, 6, 36}, {1, 3, 36}}, 1);