Make QDQSelectorActionTransformer() is_int8_allowed parameter required. (#10823)

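Make the is_int8_allowed parameter of the QDQSelectorActionTransformer constructor required (it is now the first parameter, ahead of apply_context).
Set it to QDQIsInt8Allowed() at call sites that previously relied on its default value of false.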
This commit is contained in:
Edward Chen 2022-03-09 16:17:00 -08:00 committed by GitHub
parent cc6bc34c8c
commit 22c475520e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 16 additions and 13 deletions

View file

@ -217,7 +217,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
if (!qdq_is_int8_allowed) {
transformers.emplace_back(std::make_unique<QDQS8ToU8Transformer>(cpu_ep));
}
transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(SatApplyContextVariant{}, qdq_is_int8_allowed));
transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(qdq_is_int8_allowed));
}
transformers.emplace_back(std::make_unique<GemmActivationFusion>(cpu_ep));
@ -295,12 +295,15 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
#if !defined(DISABLE_CONTRIB_OPS)
const bool disable_quant_qdq =
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1";
const bool qdq_is_int8_allowed =
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQIsInt8Allowed,
QDQIsInt8Allowed() ? "1" : "0") == "1";
// runtime optimizations only support CPU EP now
const InlinedHashSet<std::string_view> cpu_ep = {onnxruntime::kCpuExecutionProvider};
if (!disable_quant_qdq) {
transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(apply_context));
transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(qdq_is_int8_allowed, apply_context));
}
transformers.emplace_back(std::make_unique<ConvActivationFusion>(cpu_ep, apply_context));

View file

@ -209,7 +209,7 @@ SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed) {
} // namespace
QDQSelectorActionTransformer::QDQSelectorActionTransformer(
const SatApplyContextVariant& apply_context, bool is_int8_allowed)
bool is_int8_allowed, const SatApplyContextVariant& apply_context)
: SelectorActionTransformer{
"QDQSelectorActionTransformer",
CreateSelectorActionRegistry(is_int8_allowed),

View file

@ -22,7 +22,7 @@ Transformer that fuses QDQ and fp32 ops into quantized ops.
*/
class QDQSelectorActionTransformer : public SelectorActionTransformer {
public:
QDQSelectorActionTransformer(const SatApplyContextVariant& apply_context = {}, bool is_int8_allowed = false);
QDQSelectorActionTransformer(bool is_int8_allowed, const SatApplyContextVariant& apply_context = {});
};
} // namespace onnxruntime

View file

@ -75,7 +75,7 @@ void QDQTransformerConvTests() {
12 /*opset_version*/,
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>());
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
};
test_case({1, 12, 37}, {32, 12, 5});
@ -252,7 +252,7 @@ void QDQTransformerAveragePoolTests() {
12 /*opset_version*/,
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>());
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
};
test_case({1, 12, 37});
@ -301,7 +301,7 @@ void QDQTransformerGlobalAveragePoolTests() {
12 /*opset_version*/,
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>());
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
};
test_case({1, 12, 37});
@ -351,7 +351,7 @@ void QDQTransformerBinaryOpTests(const std::string& op_type) {
12 /*opset_version*/,
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>());
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
};
test_case({1, 12, 37});
@ -482,7 +482,7 @@ void QDQTransformerMatMulTests(bool has_output_q) {
12 /*opset_version*/,
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>());
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
};
test_case({1, 2, 2}, {1, 2, 4});
@ -637,7 +637,7 @@ void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one
12 /*opset_version*/,
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>());
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
};
test_case({2, 2}, {2, 4});
@ -1366,7 +1366,7 @@ void QDQTransformerLeakyReluTests() {
12 /*opset_version*/,
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>());
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
};
test_case({1, 12, 37});
@ -1435,7 +1435,7 @@ void QDQTransformerSigmoidTests() {
12 /*opset_version*/,
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>());
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
};
test_case({1, 12, 37});
@ -1804,7 +1804,7 @@ TEST(QDQTransformerTests, Concat) {
12 /*opset_version*/,
0.01f /*per_sample_tolerance*/,
0.01f /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>());
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
};
test_case({{1, 6, 36}, {1, 3, 36}}, 1);