Handle all float types in IsQDQPairSupported (#19085)

### Description
This makes detection of identical QDQ scales work with float16 and
bfloat16 rather than failing.


### Motivation and Context
This addresses failures in customer models.
This commit is contained in:
Jeff Bloomfield 2024-01-11 15:16:44 -08:00 committed by GitHub
parent 8a0a972f39
commit 08cf4fbcad
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 53 additions and 31 deletions

View file

@ -54,9 +54,26 @@ bool IsQDQPairSupported(
Initializer dq_zp(*dq_zp_tensor_proto, model_path);
Initializer dq_scale(*dq_scale_tensor_proto, model_path);
return q_zp.data_type() == dq_zp.data_type() &&
SpanEq(q_zp.DataAsByteSpan(), dq_zp.DataAsByteSpan()) &&
*q_scale.data<float>() == *dq_scale.data<float>();
// Reject the pair unless the Q and DQ zero points have the same element type and
// identical bytes, and the Q and DQ scales have the same element type.
// The scale check must compare Q against DQ (not Q against itself): the switch
// below reads both scales through the element type reported by q_scale.
if (q_zp.data_type() != dq_zp.data_type() ||
    q_scale.data_type() != dq_scale.data_type() ||
    !SpanEq(q_zp.DataAsByteSpan(), dq_zp.DataAsByteSpan())) {
  return false;
}

// Scale element types are equal here; compare the first scale element of each
// initializer using that shared element type.
switch (q_scale.data_type()) {
  case ONNX_NAMESPACE::TensorProto::FLOAT:
    return *q_scale.data<float>() == *dq_scale.data<float>();
  case ONNX_NAMESPACE::TensorProto::FLOAT16:
    return *q_scale.data<MLFloat16>() == *dq_scale.data<MLFloat16>();
  case ONNX_NAMESPACE::TensorProto::BFLOAT16:
    return *q_scale.data<BFloat16>() == *dq_scale.data<BFloat16>();
  default:
    // Any other scale element type is not handled: trip the assert in debug
    // builds and report the pair as unsupported otherwise.
    assert(false);
    return false;
}
}
bool IsDQSupported(const Node& dq_node, const GetConstantInitializerFn& get_const_initializer) {

View file

@ -4602,38 +4602,43 @@ TEST_F(GraphTransformationTests, GeluApproximation_SessionOptionConfig) {
}
// Test DoubleQDQPairsRemover to remove unnecessary DQ->Q nodes in the middle
TEST_F(GraphTransformationTests, DoublQDQRemover_RemoveDupQDQ) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "qdq_optimization/dup_qdq.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();
TEST_F(GraphTransformationTests, DoublQDQRemover_RemoveDupQDQ_Float16) {
auto RunTest = [this](const ORTCHAR_T* model_uri) {
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<DoubleQDQPairsRemover>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<DoubleQDQPairsRemover>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
EXPECT_EQ(op_to_count["QuantizeLinear"], 3);
EXPECT_EQ(op_to_count["DequantizeLinear"], 4);
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
EXPECT_EQ(op_to_count["QuantizeLinear"], 3);
EXPECT_EQ(op_to_count["DequantizeLinear"], 4);
std::string dq_scale_name_before_reshape_node;
std::string zp_name_before_reshape_node;
std::string dq_scale_name_after_reshape_node;
std::string zp_name_after_reshape_node;
for (auto& node : graph.Nodes()) {
if (node.Name() == "dq_2") {
dq_scale_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name();
zp_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name();
std::string dq_scale_name_before_reshape_node;
std::string zp_name_before_reshape_node;
std::string dq_scale_name_after_reshape_node;
std::string zp_name_after_reshape_node;
for (auto& node : graph.Nodes()) {
if (node.Name() == "dq_2") {
dq_scale_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name();
zp_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name();
}
if (node.Name() == "q_3") {
dq_scale_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name();
zp_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name();
}
}
if (node.Name() == "q_3") {
dq_scale_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name();
zp_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name();
}
}
EXPECT_EQ(dq_scale_name_before_reshape_node.empty(), false);
EXPECT_EQ(zp_name_before_reshape_node.empty(), false);
EXPECT_EQ(dq_scale_name_before_reshape_node, dq_scale_name_after_reshape_node);
EXPECT_EQ(zp_name_before_reshape_node, zp_name_after_reshape_node);
EXPECT_EQ(dq_scale_name_before_reshape_node.empty(), false);
EXPECT_EQ(zp_name_before_reshape_node.empty(), false);
EXPECT_EQ(dq_scale_name_before_reshape_node, dq_scale_name_after_reshape_node);
EXPECT_EQ(zp_name_before_reshape_node, zp_name_after_reshape_node);
};
RunTest(MODEL_FOLDER "qdq_optimization/dup_qdq.onnx");
RunTest(MODEL_FOLDER "qdq_optimization/dup_qdq_float16.onnx");
RunTest(MODEL_FOLDER "qdq_optimization/dup_qdq_bfloat16.onnx");
}
// Test Gelu -> FastGelu