From d337fa90e74b0985bc2fa3c2d9b8cf1c27e07b3f Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Thu, 29 Apr 2021 14:40:41 -0700 Subject: [PATCH] Propagate QDQ only when scale and zp are scalar (#7492) fix crash when DeQuantizeLinear's output is graph output propagate only when scale and zp are scalar. fix bug for is_modified= is_modified || TryCancelOutDQQPair(graph, dq_node, q_node); in which TryCancelOutDQQPair wouldn't be invoked if is_modified is true --- .../qdq_transformer/qdq_propagation.cc | 103 +++++++------ onnxruntime/core/optimizer/utils.cc | 6 +- onnxruntime/core/optimizer/utils.h | 3 + .../test/optimizer/qdq_transformer_test.cc | 142 +++++++++++++++++- 4 files changed, 201 insertions(+), 53 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc index 7e72bcad29..0547abdbff 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_propagation.cc @@ -13,7 +13,7 @@ namespace onnxruntime { static constexpr size_t QDQInputCountRequired = 3; static constexpr size_t QDQInputScaleIdx = 1; -static constexpr size_t QDQInputZeroPointIdex = 2; +static constexpr size_t QDQInputZeroPointIdx = 2; static bool CanNodePropagate(const Node& node) { return graph_utils::IsSupportedOptypeVersionAndDomain(node, "MaxPool", {12}) || @@ -26,22 +26,26 @@ static bool TryCancelOutDQQPair(Graph& graph, Node& dq_node, Node& q_node) { std::vector& q_input_defs = q_node.MutableInputDefs(); if (dq_input_defs.size() != QDQInputCountRequired || q_input_defs.size() != QDQInputCountRequired || + !optimizer_utils::IsScalar(*q_input_defs[QDQInputZeroPointIdx]) || + !optimizer_utils::IsScalar(*q_input_defs[QDQInputScaleIdx]) || + !optimizer_utils::IsScalar(*dq_input_defs[QDQInputZeroPointIdx]) || + !optimizer_utils::IsScalar(*dq_input_defs[QDQInputScaleIdx]) || !graph.GetNodeOutputsInGraphOutputs(q_node).empty()) { return false; } - const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto = nullptr; - const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto = nullptr; - const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto = nullptr; - const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto = nullptr; - if (!graph_utils::NodeArgIsConstant(graph, *dq_input_defs[QDQInputScaleIdx]) || - !graph_utils::NodeArgIsConstant(graph, *q_input_defs[QDQInputScaleIdx]) || - !graph_utils::NodeArgIsConstant(graph, *dq_input_defs[QDQInputZeroPointIdex]) || - !graph_utils::NodeArgIsConstant(graph, *q_input_defs[QDQInputZeroPointIdex]) || - !graph.GetInitializedTensor(dq_input_defs[QDQInputScaleIdx]->Name(), dq_scale_tensor_proto) || - !graph.GetInitializedTensor(q_input_defs[QDQInputScaleIdx]->Name(), q_scale_tensor_proto) || - !graph.GetInitializedTensor(dq_input_defs[QDQInputZeroPointIdex]->Name(), dq_zp_tensor_proto) || - !graph.GetInitializedTensor(q_input_defs[QDQInputZeroPointIdex]->Name(), q_zp_tensor_proto)) { + const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto = + graph_utils::GetConstantInitializer(graph, dq_input_defs[QDQInputScaleIdx]->Name()); + const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto = + graph_utils::GetConstantInitializer(graph, q_input_defs[QDQInputScaleIdx]->Name()); + const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto = + graph_utils::GetConstantInitializer(graph, dq_input_defs[QDQInputZeroPointIdx]->Name()); + const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto = + graph_utils::GetConstantInitializer(graph, q_input_defs[QDQInputZeroPointIdx]->Name()); + if (nullptr == q_zp_tensor_proto || + nullptr == dq_zp_tensor_proto || + nullptr == q_scale_tensor_proto || + nullptr == dq_scale_tensor_proto) { return false; } @@ -162,32 +166,35 @@ bool QDQPropagationTransformer::PropagateDQForward(Graph& graph) const { continue; } - const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto = nullptr; - const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto = nullptr; - if (!graph_utils::NodeArgIsConstant(graph, *dq_input_defs[QDQInputScaleIdx]) || - !graph_utils::NodeArgIsConstant(graph, *dq_input_defs[QDQInputZeroPointIdex]) || - !graph.GetInitializedTensor(dq_input_defs[QDQInputScaleIdx]->Name(), dq_scale_tensor_proto) || - !graph.GetInitializedTensor(dq_input_defs[QDQInputZeroPointIdex]->Name(), dq_zp_tensor_proto)) { + if (!optimizer_utils::IsScalar(*dq_input_defs[QDQInputZeroPointIdx]) || + !optimizer_utils::IsScalar(*dq_input_defs[QDQInputScaleIdx])) { + continue; + } + + const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto = + graph_utils::GetConstantInitializer(graph, dq_input_defs[QDQInputZeroPointIdx]->Name()); + const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto = + graph_utils::GetConstantInitializer(graph, dq_input_defs[QDQInputScaleIdx]->Name()); + + if (nullptr == dq_zp_tensor_proto || nullptr == dq_scale_tensor_proto) { continue; } do { Node& next_node = *graph.GetNode(dq_node.OutputNodesBegin()->Index()); if (!CanNodePropagate(next_node)) { + // Try canceling out DQ/Q pair + if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "QuantizeLinear", {10, 13}) && + graph_utils::IsSupportedProvider(next_node, GetCompatibleExecutionProviders()) && + TryCancelOutDQQPair(graph, dq_node, next_node)) { + is_modified = true; + } + break; } SwapAdjacentNodes(graph, dq_node, next_node); is_modified = true; } while (optimizer_utils::CheckOutputEdges(graph, dq_node, 1)); - - // Cancel out DQ/Q pair - Node& q_node = *graph.GetNode(dq_node.OutputNodesBegin()->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(q_node, "QuantizeLinear", {10, 13}) || - !graph_utils::IsSupportedProvider(q_node, GetCompatibleExecutionProviders())) { - continue; - } - - is_modified = is_modified || TryCancelOutDQQPair(graph, dq_node, q_node); } return is_modified; @@ -212,16 +219,18 @@ bool QDQPropagationTransformer::PropagateQBackward(Graph& graph) const { } std::vector& q_input_defs = q_node.MutableInputDefs(); - if (q_input_defs.size() != QDQInputCountRequired) { + if (q_input_defs.size() != QDQInputCountRequired || + !optimizer_utils::IsScalar(*q_input_defs[QDQInputZeroPointIdx]) || + !optimizer_utils::IsScalar(*q_input_defs[QDQInputScaleIdx])) { continue; } - const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto = nullptr; - const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto = nullptr; - if (!graph_utils::NodeArgIsConstant(graph, *q_input_defs[QDQInputScaleIdx]) || - !graph_utils::NodeArgIsConstant(graph, *q_input_defs[QDQInputZeroPointIdex]) || - !graph.GetInitializedTensor(q_input_defs[QDQInputScaleIdx]->Name(), q_scale_tensor_proto) || - !graph.GetInitializedTensor(q_input_defs[QDQInputZeroPointIdex]->Name(), q_zp_tensor_proto)) { + const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto = + graph_utils::GetConstantInitializer(graph, q_input_defs[QDQInputZeroPointIdx]->Name()); + const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto = + graph_utils::GetConstantInitializer(graph, q_input_defs[QDQInputScaleIdx]->Name()); + + if (nullptr == q_zp_tensor_proto || nullptr == q_scale_tensor_proto) { continue; } @@ -230,27 +239,21 @@ bool QDQPropagationTransformer::PropagateQBackward(Graph& graph) const { break; } Node& prev_node = *graph.GetNode(q_node.InputNodesBegin()->Index()); - if (!optimizer_utils::CheckOutputEdges(graph, prev_node, 1) || - !CanNodePropagate(prev_node)) { + if (!optimizer_utils::CheckOutputEdges(graph, prev_node, 1)) break; + if (!CanNodePropagate(prev_node)) { + // Try canceling out DQ/Q pair + Node& dq_node = prev_node; + if (graph_utils::IsSupportedOptypeVersionAndDomain(dq_node, "DequantizeLinear", {10, 13}) && + graph_utils::IsSupportedProvider(dq_node, GetCompatibleExecutionProviders()) && + TryCancelOutDQQPair(graph, dq_node, q_node)) { + is_modified = true; + } break; } SwapAdjacentNodes(graph, prev_node, q_node); is_modified = true; } while (true); - - // Cancel out DQ/Q pair - if (q_node.InputNodesBegin() == q_node.InputNodesEnd()) { - continue; - } - Node& dq_node = *graph.GetNode(q_node.InputNodesBegin()->Index()); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(dq_node, "DequantizeLinear", {10, 13}) || - !graph_utils::IsSupportedProvider(dq_node, GetCompatibleExecutionProviders()) || - !optimizer_utils::CheckOutputEdges(graph, dq_node, 1)) { - continue; - } - - is_modified = is_modified || TryCancelOutDQQPair(graph, dq_node, q_node); } return is_modified; diff --git a/onnxruntime/core/optimizer/utils.cc b/onnxruntime/core/optimizer/utils.cc index 6a50486390..644ea2f722 100644 --- a/onnxruntime/core/optimizer/utils.cc +++ b/onnxruntime/core/optimizer/utils.cc @@ -21,10 +21,12 @@ namespace onnxruntime { namespace optimizer_utils { bool IsFloatingPointDataType(const ONNX_NAMESPACE::TensorProto& tensor_proto) { - return tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT || tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 || tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE; + return tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT || + tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 || + tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE; } -inline bool IsScalar(const NodeArg& input_arg) { +bool IsScalar(const NodeArg& input_arg) { auto shape = input_arg.Shape(); if (shape == nullptr) { // shape inferencing wasn't able to populate shape information for this NodeArg diff --git a/onnxruntime/core/optimizer/utils.h b/onnxruntime/core/optimizer/utils.h index 535edc9c77..0fbb484c68 100644 --- a/onnxruntime/core/optimizer/utils.h +++ b/onnxruntime/core/optimizer/utils.h @@ -15,6 +15,9 @@ namespace optimizer_utils { // Check if TensorProto contains a floating point type. bool IsFloatingPointDataType(const ONNX_NAMESPACE::TensorProto& tensor_proto); +// Check if NodeArg takes in a scalar tensor. +bool IsScalar(const NodeArg& input_arg); + /** Check whether a input is initializer with specified float value. @param expected_value is the expected value of the initializer. @param is_constant means whether the initializer is required to be constant. diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 04a001e98c..694dfdbdf5 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -881,8 +881,148 @@ TEST(QDQTransformerTests, QDQPropagation_QDQ_CancelOut_More) { test_case({1, 13, 13, 23}, true, true); } +TEST(QDQTransformerTests, QDQPropagation_Q_No_Parent) { + auto test_case = [&](const std::vector& input_shape, const std::vector& perms) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input_shape, -1.f, 1.f); + auto* output_arg = builder.MakeOutput(); + + // add transpose + auto* transpose_output = builder.MakeIntermediate(); + Node& transpose_node = builder.AddNode("Transpose", {input_arg}, {transpose_output}); + transpose_node.AddAttribute("perm", perms); + + // add Q + builder.AddQuantizeLinearNode(transpose_output, .0035f, 135, output_arg); + }; + + auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) { + GraphViewer graph_viewer(session.GetGraph()); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + EXPECT_EQ(graph_viewer.GetNode(node_topology_list[0])->OpType(), "QuantizeLinear"); + EXPECT_EQ(graph_viewer.GetNode(node_topology_list[1])->OpType(), "Transpose"); + }; + + TransformerTester(build_test_case, + check_mp_reshape_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/); + }; + + // Test the basic case of a single 1D/2D/3D convolution. + test_case({1, 13, 13, 23}, {0, 2, 3, 1}); +} + +TEST(QDQTransformerTests, QDQPropagation_DQ_No_Children) { + auto test_case = [&](const std::vector& input_shape, const std::vector& perms) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + auto* output_arg = builder.MakeOutput(); + + // add DQ + auto* dq_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input_arg, .0035f, 135, dq_output); + + // add transpose + Node& transpose_node = builder.AddNode("Transpose", {dq_output}, {output_arg}); + transpose_node.AddAttribute("perm", perms); + }; + + auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + GraphViewer graph_viewer(session.GetGraph()); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + EXPECT_EQ(graph_viewer.GetNode(node_topology_list[0])->OpType(), "Transpose"); + EXPECT_EQ(graph_viewer.GetNode(node_topology_list[1])->OpType(), "DequantizeLinear"); + }; + + TransformerTester(build_test_case, + check_mp_reshape_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/); + }; + + // Test the basic case of a single 1D/2D/3D convolution. + test_case({1, 13, 13, 23}, {0, 2, 3, 1}); +} + +TEST(QDQTransformerTests, QDQPropagation_Per_Layer_No_Propagation) { + auto test_case = [&](const std::vector& input_shape, const std::vector& perms) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + auto* output_arg = builder.MakeOutput(); + + // add DQ + auto* dq_output = builder.MakeIntermediate(); + auto* dq_scale = builder.Make1DInitializer(std::vector(input_shape[1], 0.0035f)); + auto* dq_zp = builder.Make1DInitializer(std::vector(input_shape[1], 135)); + builder.AddNode("DequantizeLinear", {input_arg, dq_scale, dq_zp}, {dq_output}); + + // add transpose + Node& transpose_node = builder.AddNode("Transpose", {dq_output}, {output_arg}); + transpose_node.AddAttribute("perm", perms); + }; + + auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + GraphViewer graph_viewer(session.GetGraph()); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + EXPECT_EQ(graph_viewer.GetNode(node_topology_list[0])->OpType(), "DequantizeLinear"); + EXPECT_EQ(graph_viewer.GetNode(node_topology_list[1])->OpType(), "Transpose"); + }; + + TransformerTester(build_test_case, + check_mp_reshape_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/); + }; + + // Test the basic case of a single 1D/2D/3D convolution. + test_case({1, 13, 13, 23}, {0, 2, 3, 1}); +} + +TEST(QDQTransformerTests, QDQPropagation_DQ_Q) { + auto test_case = [&](const std::vector& input_shape) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + auto* output_arg = builder.MakeOutput(); + + // add DQ + auto* dq_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input_arg, .0035f, 135, dq_output); + + // add Q + builder.AddQuantizeLinearNode(dq_output, .0035f, 135, output_arg); + }; + + auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["QuantizeLinear"], 1); + EXPECT_EQ(op_to_count["DequantizeLinear"], 1); + }; + + TransformerTester(build_test_case, + check_mp_reshape_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/); + }; + + // Test the basic case of a single 1D/2D/3D convolution. + test_case({1, 13, 13, 23}); +} + TEST(QDQTransformerTests, Concat_UInt8) { - auto test_case = [&](const std::vector>& input_shapes, int64_t axis, bool can_trans=true) { + auto test_case = [&](const std::vector>& input_shapes, int64_t axis, bool can_trans = true) { auto build_test_case = [&](ModelTestBuilder& builder) { auto input_count = input_shapes.size(); std::vector input_args;