From ea245c94e78ef7fe6aa2ab02f2bba7e961a12388 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 24 Mar 2023 08:46:07 +1000 Subject: [PATCH] Add constant folding for simple QDQ Node Units (#15138) ### Description Currently we bail on constant folding if QDQ is enabled and we hit a DQ node. However, if we have a simple DQ -> X -> Q node unit where the DQ and X do not produce graph outputs, their output only has one consumer, and X is deterministic, we can constant fold all three nodes. Add support for this simple scenario primarily to constant fold a QDQ model that has had initializers updated by layout transformation, which results in patterns like `initializer -> DQ -> Transpose -> Q` or `initializer- > DQ -> Unsqueeze -> Q -> DQ -> Transpose -> Q` if the initializer is broadcast. ### Motivation and Context Improve end result of layout transformation on a QDQ model. --- .../core/optimizer/constant_folding.cc | 55 +++++++++++--- .../test/optimizer/graph_transform_test.cc | 71 +++++++++++++++--- ...nt_folding_qdq_node_unit.graph_output.onnx | Bin 0 -> 1335 bytes .../constant_folding_qdq_node_unit.onnx | Bin 0 -> 1933 bytes 4 files changed, 106 insertions(+), 20 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/constant_folding_qdq_node_unit.graph_output.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/constant_folding_qdq_node_unit.onnx diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index d32cb87fed..80e2bbedef 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -7,6 +7,7 @@ #include "core/optimizer/utils.h" #include "core/graph/graph_utils.h" #include "core/optimizer/optimizer_execution_frame.h" +#include "core/optimizer/utils.h" #include "core/framework/op_kernel.h" #include "core/framework/tensorprotoutils.h" @@ -106,11 +107,6 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, continue; } - // avoid to constant fold DequantizeLinear for QDQ format - if (skip_dequantize_linear_ && node->OpType().compare("DequantizeLinear") == 0) { - continue; - } - ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger)); // Updating a node may allow shape inferencing to infer output shapes of following nodes, @@ -139,15 +135,52 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, } // Check if constant folding can be applied on this node. - if (!graph_utils::IsSupportedProvider(*node, GetCompatibleExecutionProviders()) || - !optimizer_utils::IsOperationDeterministic(node->Domain(), node->OpType()) || - // constant folding does not support executing a node that includes subgraphs (control flow operators, - // such as If/Loop/Scan, fall into this category). individual nodes in the subgraph will be processed - // by the Recurse call above - node->ContainsSubgraph() || !graph_utils::AllNodeInputsAreConstant(graph, *node, constant_inputs, excluded_initializers_)) { + const auto can_constant_fold_node = [&](const Node& n, bool skip_inputs_constant_check = false) { + return graph_utils::IsSupportedProvider(n, GetCompatibleExecutionProviders()) && + optimizer_utils::IsOperationDeterministic(n.Domain(), n.OpType()) && + // constant folding does not support executing a node that includes subgraphs (control flow operators, + // such as If/Loop/Scan, fall into this category). individual nodes in the subgraph will be processed + // by the Recurse call above + !n.ContainsSubgraph() && + (skip_inputs_constant_check || + graph_utils::AllNodeInputsAreConstant(graph, n, constant_inputs, excluded_initializers_)); + }; + + if (!can_constant_fold_node(*node)) { continue; } + // if skip_dequantize_linear is true we want to maintain QDQ node units so avoid constant folding + // DequantizeLinear unless we can fold the whole QDQ node unit + if (skip_dequantize_linear_ && node->OpType() == "DequantizeLinear") { + bool can_constant_fold_qdq_node_unit = false; + + // Simplest scenario where the whole QDQ node unit of (DQ -> X -> Q) can be constant folded is if: + // - the DQ node does not produce a graph output, and its output is only consumed by X + // - X is a deterministic node with a single input and single output + // - the output from X is not a graph output and is only consumed by a Q node + if (optimizer_utils::CheckOutputEdges(graph, *node, 1)) { // DQ does not produce graph output, single consumer + const Node& node_x = *node->OutputNodesBegin(); + if (node_x.InputDefs().size() == 1 && + node_x.OutputDefs().size() == 1 && + optimizer_utils::CheckOutputEdges(graph, node_x, 1)) { + const Node& probably_q = *node_x.OutputNodesBegin(); + + if (probably_q.OpType() == "QuantizeLinear") { + // the inputs to these nodes are not const yet, but will be if we constant fold, + // so set skip_const_check to simulate that having happened + constexpr bool skip_const_check = true; + can_constant_fold_qdq_node_unit = can_constant_fold_node(node_x, skip_const_check) && + can_constant_fold_node(probably_q, skip_const_check); + } + } + } + + if (!can_constant_fold_qdq_node_unit) { + continue; + } + } + #if !defined(DISABLE_SPARSE_TENSORS) // Create execution frame for executing constant nodes. OptimizerExecutionFrame::Info info({node}, constant_inputs, graph.ModelPath(), execution_provider_, diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 813614493b..ca5011e28b 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -725,9 +725,7 @@ TEST_F(GraphTransformationTests, ConstantFoldingWithScalarShapeToInitializer) { ASSERT_TRUE(op_to_count["Add"] == 1); } -static void VerifyConstantFoldingWithDequantizeLinear(int quantize_linear_count, - int dequantize_linear_count, - int conv_count, +static void VerifyConstantFoldingWithDequantizeLinear(const std::unordered_map& expected_op_count, Graph& graph, SessionOptions& session_options, const Logger& logger) { @@ -748,9 +746,15 @@ static void VerifyConstantFoldingWithDequantizeLinear(int quantize_linear_count, ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, logger)); std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["QuantizeLinear"] == quantize_linear_count); - ASSERT_TRUE(op_to_count["DequantizeLinear"] == dequantize_linear_count); - ASSERT_TRUE(op_to_count["Conv"] == conv_count); + for (const auto& entry : expected_op_count) { + if (entry.second == 0) { + ASSERT_TRUE(op_to_count.find(entry.first) == op_to_count.end()) + << entry.first << " should not exist in the graph"; + } else { + ASSERT_TRUE(op_to_count[entry.first] == entry.second) + << entry.first << " mismatch. Expected:" << entry.second << " Got:" << op_to_count[entry.first]; + } + } } TEST_F(GraphTransformationTests, ConstantFoldingWithDequantizeLinear) { @@ -763,17 +767,66 @@ TEST_F(GraphTransformationTests, ConstantFoldingWithDequantizeLinear) { ASSERT_TRUE(op_to_count["DequantizeLinear"] == 3); ASSERT_TRUE(op_to_count["Conv"] == 1); + std::unordered_map expected_op_counts = {{"QuantizeLinear", 1}, + {"DequantizeLinear", 3}, + {"Conv", 1}}; + SessionOptions session_options; // Check DequantizeLinear aren't constant folded for default setting. - VerifyConstantFoldingWithDequantizeLinear(1, 3, 1, graph, session_options, *logger_); + VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); // set kOrtSessionOptionsDisableQuantQDQ to enable it explicitly ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "0")); - VerifyConstantFoldingWithDequantizeLinear(1, 3, 1, graph, session_options, *logger_); + VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); // set SessionOptionsEnableQuantQDQ to disable it + expected_op_counts["DequantizeLinear"] = 1; ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1")); - VerifyConstantFoldingWithDequantizeLinear(1, 1, 1, graph, session_options, *logger_); + VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); +} + +// model with 2 QDQ node units that can be constant folded as they are simple DQ -> Node -> Q where DQ and Node have +// single consumer and do not produce graph outputs. Node is deterministic. +// there are also other DQ nodes that should be ignored. +TEST_F(GraphTransformationTests, ConstantFoldingQDQNodeUnit) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["QuantizeLinear"] == 3); + ASSERT_TRUE(op_to_count["DequantizeLinear"] == 4); + ASSERT_TRUE(op_to_count["Unsqueeze"] == 1); + ASSERT_TRUE(op_to_count["Transpose"] == 1); + + SessionOptions session_options; + + // 2 QDQ node units should be constant folded and go away + std::unordered_map expected_op_counts = {{"QuantizeLinear", 1}, + {"DequantizeLinear", 2}, + {"Transpose", 0}, + {"Unsqueeze", 0}}; + + VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); +} + +// Simple QDQ Node Unit but shouldn't be constant folded as the node in the middle produces a graph output +TEST_F(GraphTransformationTests, ConstantFoldingQDQNodeUnitGraphOutput) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.graph_output.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["QuantizeLinear"] == 2); + ASSERT_TRUE(op_to_count["DequantizeLinear"] == 3); + ASSERT_TRUE(op_to_count["Unsqueeze"] == 1); + + std::unordered_map expected_op_counts = {{"QuantizeLinear", 2}, + {"DequantizeLinear", 3}, + {"Unsqueeze", 1}}; + + SessionOptions session_options; + VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); } TEST_F(GraphTransformationTests, ConstantFolding_RemoveDanglingInputNodesToConstantFoldedNode) { diff --git a/onnxruntime/test/testdata/transform/fusion/constant_folding_qdq_node_unit.graph_output.onnx b/onnxruntime/test/testdata/transform/fusion/constant_folding_qdq_node_unit.graph_output.onnx new file mode 100644 index 0000000000000000000000000000000000000000..76b64c0af17ab7e828d3ea1a020dbb5a42d46bde GIT binary patch literal 1335 zcmbVM(TdYR6zy~)jaO{Lu)EM8Y&W1tz?x=}UB!o5_QfKs1%2^lNS#p!+nFRYfptH? zPtaH2{0Kk6Kk`KPzLY()rA`zNjF#lTAAM~R~p8&oi z=*nEHcS+7DiO9e_%Q6ak_L6#?%ozA~`hw+>q>>AbkL|Tnw#0kJ1xs>rR_L_Q_cNN2 zLi3p-v`9(hN1Ha1R2D^^(q&3z)8-T z8a9S#*!+T8@vSG6GTP&!r}IQ8gT{I?)qHLQ&*3WqtHCzndRIA;% zZrRSZtY4d5JwKd5Z6$1y8N0>~3qkyraVUugN5PJH!JW#c_vVtaSs1v;YmV0624mC+ z4?HxgzZAv}=lo%KXdwP^J&e)Muxl{v*^He#KR!MV_sniWab6wV|@aUUWvR*MviDybz{SRy#75CyRaQ8z?LB@%4dj2t|L5y^}_9!sSM zV9gO&aRfFTfva!=HZV@f(E6vTQbi-hZ@zio`}yXLY1+3qV<-JM4Va{-#JS=32YvtP zuNM47(BVQR9|A!_?Bjv@tzaSSYum_OlKdDrj_!KVE7?i-fGDIuo>Oa75CJQ!l0N*_mp;MlQ? z+wPXOd2KL)6OxQ-BQ&aiMa|*e=V3^~9%Vf#0+vKPAw3>Rx={5O@EyVBe4CWsQX~?m zI*r)r=m2v$#cpf*PYh>W8Qy!r5+bAudG@I{J)FK0JPO_il5+MsOW>G`g$n!!f&$GZ zv39E>TjmhPj#C8)<6`5l8vPdp>xk{$Wx83vRyEtdpo8MJyQOX3Z7?Dt7D@+=t1Cet zFJs)Kx@RqtE8!{mcB(k8uY_|no5b3!icBk^3UD1I%%Dk`8bjEJ_W!~2NKHQd0h*Q%O8)4K`DLW?I{`wncI$hM$vBDET78QGe-p!2N6c*PO7#MZUH%;$2=R#;sm99Ln- J2Gp#Ze*y6Tiv0ip literal 0 HcmV?d00001