mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
Handle all float types in IsQDQPairSupported (#19085)
### Description
This makes detection of identical QDQ scales work with float16 and bfloat16 rather than failing.

### Motivation and Context
This addresses failures in customer models.
This commit is contained in:
parent
8a0a972f39
commit
08cf4fbcad
4 changed files with 53 additions and 31 deletions
|
|
@ -54,9 +54,26 @@ bool IsQDQPairSupported(
|
|||
Initializer dq_zp(*dq_zp_tensor_proto, model_path);
|
||||
Initializer dq_scale(*dq_scale_tensor_proto, model_path);
|
||||
|
||||
return q_zp.data_type() == dq_zp.data_type() &&
|
||||
SpanEq(q_zp.DataAsByteSpan(), dq_zp.DataAsByteSpan()) &&
|
||||
*q_scale.data<float>() == *dq_scale.data<float>();
|
||||
// Reject the pair up front unless the Q and DQ zero points have the same data
// type and identical bytes, and the Q and DQ scales have the same data type.
// The scale check must compare q_scale against dq_scale (not against itself):
// the switch that follows selects the element type from q_scale alone and
// reads both scales through data<T>() with that single type.
if (q_zp.data_type() != dq_zp.data_type() ||
    q_scale.data_type() != dq_scale.data_type() ||
    !SpanEq(q_zp.DataAsByteSpan(), dq_zp.DataAsByteSpan())) {
  return false;
}
|
||||
|
||||
switch (q_scale.data_type()) {
|
||||
case ONNX_NAMESPACE::TensorProto::FLOAT:
|
||||
return *q_scale.data<float>() == *dq_scale.data<float>();
|
||||
|
||||
case ONNX_NAMESPACE::TensorProto::FLOAT16:
|
||||
return *q_scale.data<MLFloat16>() == *dq_scale.data<MLFloat16>();
|
||||
|
||||
case ONNX_NAMESPACE::TensorProto::BFLOAT16:
|
||||
return *q_scale.data<BFloat16>() == *dq_scale.data<BFloat16>();
|
||||
|
||||
default:
|
||||
assert(false);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool IsDQSupported(const Node& dq_node, const GetConstantInitializerFn& get_const_initializer) {
|
||||
|
|
|
|||
|
|
@ -4602,38 +4602,43 @@ TEST_F(GraphTransformationTests, GeluApproximation_SessionOptionConfig) {
|
|||
}
|
||||
|
||||
// Test DoubleQDQPairsRemover to remove unnecessary DQ->Q nodes in the middle
|
||||
TEST_F(GraphTransformationTests, DoublQDQRemover_RemoveDupQDQ) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "qdq_optimization/dup_qdq.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
TEST_F(GraphTransformationTests, DoublQDQRemover_RemoveDupQDQ_Float16) {
|
||||
auto RunTest = [this](const ORTCHAR_T* model_uri) {
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<DoubleQDQPairsRemover>(), TransformerLevel::Level1));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<DoubleQDQPairsRemover>(), TransformerLevel::Level1));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
EXPECT_EQ(op_to_count["QuantizeLinear"], 3);
|
||||
EXPECT_EQ(op_to_count["DequantizeLinear"], 4);
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
EXPECT_EQ(op_to_count["QuantizeLinear"], 3);
|
||||
EXPECT_EQ(op_to_count["DequantizeLinear"], 4);
|
||||
|
||||
std::string dq_scale_name_before_reshape_node;
|
||||
std::string zp_name_before_reshape_node;
|
||||
std::string dq_scale_name_after_reshape_node;
|
||||
std::string zp_name_after_reshape_node;
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.Name() == "dq_2") {
|
||||
dq_scale_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name();
|
||||
zp_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name();
|
||||
std::string dq_scale_name_before_reshape_node;
|
||||
std::string zp_name_before_reshape_node;
|
||||
std::string dq_scale_name_after_reshape_node;
|
||||
std::string zp_name_after_reshape_node;
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.Name() == "dq_2") {
|
||||
dq_scale_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name();
|
||||
zp_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name();
|
||||
}
|
||||
if (node.Name() == "q_3") {
|
||||
dq_scale_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name();
|
||||
zp_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name();
|
||||
}
|
||||
}
|
||||
if (node.Name() == "q_3") {
|
||||
dq_scale_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name();
|
||||
zp_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name();
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(dq_scale_name_before_reshape_node.empty(), false);
|
||||
EXPECT_EQ(zp_name_before_reshape_node.empty(), false);
|
||||
EXPECT_EQ(dq_scale_name_before_reshape_node, dq_scale_name_after_reshape_node);
|
||||
EXPECT_EQ(zp_name_before_reshape_node, zp_name_after_reshape_node);
|
||||
EXPECT_EQ(dq_scale_name_before_reshape_node.empty(), false);
|
||||
EXPECT_EQ(zp_name_before_reshape_node.empty(), false);
|
||||
EXPECT_EQ(dq_scale_name_before_reshape_node, dq_scale_name_after_reshape_node);
|
||||
EXPECT_EQ(zp_name_before_reshape_node, zp_name_after_reshape_node);
|
||||
};
|
||||
|
||||
RunTest(MODEL_FOLDER "qdq_optimization/dup_qdq.onnx");
|
||||
RunTest(MODEL_FOLDER "qdq_optimization/dup_qdq_float16.onnx");
|
||||
RunTest(MODEL_FOLDER "qdq_optimization/dup_qdq_bfloat16.onnx");
|
||||
}
|
||||
|
||||
// Test Gelu -> FastGelu
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/qdq_optimization/dup_qdq_bfloat16.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/qdq_optimization/dup_qdq_bfloat16.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/qdq_optimization/dup_qdq_float16.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/qdq_optimization/dup_qdq_float16.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue