diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index d32cb87fed..80e2bbedef 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -7,6 +7,7 @@ #include "core/optimizer/utils.h" #include "core/graph/graph_utils.h" #include "core/optimizer/optimizer_execution_frame.h" +#include "core/optimizer/utils.h" #include "core/framework/op_kernel.h" #include "core/framework/tensorprotoutils.h" @@ -106,11 +107,6 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, continue; } - // avoid to constant fold DequantizeLinear for QDQ format - if (skip_dequantize_linear_ && node->OpType().compare("DequantizeLinear") == 0) { - continue; - } - ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger)); // Updating a node may allow shape inferencing to infer output shapes of following nodes, @@ -139,15 +135,52 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, } // Check if constant folding can be applied on this node. - if (!graph_utils::IsSupportedProvider(*node, GetCompatibleExecutionProviders()) || - !optimizer_utils::IsOperationDeterministic(node->Domain(), node->OpType()) || - // constant folding does not support executing a node that includes subgraphs (control flow operators, - // such as If/Loop/Scan, fall into this category). individual nodes in the subgraph will be processed - // by the Recurse call above - node->ContainsSubgraph() || !graph_utils::AllNodeInputsAreConstant(graph, *node, constant_inputs, excluded_initializers_)) { + const auto can_constant_fold_node = [&](const Node& n, bool skip_inputs_constant_check = false) { + return graph_utils::IsSupportedProvider(n, GetCompatibleExecutionProviders()) && + optimizer_utils::IsOperationDeterministic(n.Domain(), n.OpType()) && + // constant folding does not support executing a node that includes subgraphs (control flow operators, + // such as If/Loop/Scan, fall into this category). individual nodes in the subgraph will be processed + // by the Recurse call above + !n.ContainsSubgraph() && + (skip_inputs_constant_check || + graph_utils::AllNodeInputsAreConstant(graph, n, constant_inputs, excluded_initializers_)); + }; + + if (!can_constant_fold_node(*node)) { continue; } + // if skip_dequantize_linear is true we want to maintain QDQ node units so avoid constant folding + // DequantizeLinear unless we can fold the whole QDQ node unit + if (skip_dequantize_linear_ && node->OpType() == "DequantizeLinear") { + bool can_constant_fold_qdq_node_unit = false; + + // Simplest scenario where the whole QDQ node unit of (DQ -> X -> Q) can be constant folded is if: + // - the DQ node does not produce a graph output, and its output is only consumed by X + // - X is a deterministic node with a single input and single output + // - the output from X is not a graph output and is only consumed by a Q node + if (optimizer_utils::CheckOutputEdges(graph, *node, 1)) { // DQ does not produce graph output, single consumer + const Node& node_x = *node->OutputNodesBegin(); + if (node_x.InputDefs().size() == 1 && + node_x.OutputDefs().size() == 1 && + optimizer_utils::CheckOutputEdges(graph, node_x, 1)) { + const Node& probably_q = *node_x.OutputNodesBegin(); + + if (probably_q.OpType() == "QuantizeLinear") { + // the inputs to these nodes are not const yet, but will be if we constant fold, + // so set skip_const_check to simulate that having happened + constexpr bool skip_const_check = true; + can_constant_fold_qdq_node_unit = can_constant_fold_node(node_x, skip_const_check) && + can_constant_fold_node(probably_q, skip_const_check); + } + } + } + + if (!can_constant_fold_qdq_node_unit) { + continue; + } + } + #if !defined(DISABLE_SPARSE_TENSORS) // Create execution frame for executing constant nodes. OptimizerExecutionFrame::Info info({node}, constant_inputs, graph.ModelPath(), execution_provider_, diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 813614493b..ca5011e28b 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -725,9 +725,7 @@ TEST_F(GraphTransformationTests, ConstantFoldingWithScalarShapeToInitializer) { ASSERT_TRUE(op_to_count["Add"] == 1); } -static void VerifyConstantFoldingWithDequantizeLinear(int quantize_linear_count, - int dequantize_linear_count, - int conv_count, +static void VerifyConstantFoldingWithDequantizeLinear(const std::unordered_map& expected_op_count, Graph& graph, SessionOptions& session_options, const Logger& logger) { @@ -748,9 +746,15 @@ static void VerifyConstantFoldingWithDequantizeLinear(int quantize_linear_count, ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, logger)); std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["QuantizeLinear"] == quantize_linear_count); - ASSERT_TRUE(op_to_count["DequantizeLinear"] == dequantize_linear_count); - ASSERT_TRUE(op_to_count["Conv"] == conv_count); + for (const auto& entry : expected_op_count) { + if (entry.second == 0) { + ASSERT_TRUE(op_to_count.find(entry.first) == op_to_count.end()) + << entry.first << " should not exist in the graph"; + } else { + ASSERT_TRUE(op_to_count[entry.first] == entry.second) + << entry.first << " mismatch. Expected:" << entry.second << " Got:" << op_to_count[entry.first]; + } + } } TEST_F(GraphTransformationTests, ConstantFoldingWithDequantizeLinear) { @@ -763,17 +767,66 @@ TEST_F(GraphTransformationTests, ConstantFoldingWithDequantizeLinear) { ASSERT_TRUE(op_to_count["DequantizeLinear"] == 3); ASSERT_TRUE(op_to_count["Conv"] == 1); + std::unordered_map expected_op_counts = {{"QuantizeLinear", 1}, + {"DequantizeLinear", 3}, + {"Conv", 1}}; + SessionOptions session_options; // Check DequantizeLinear aren't constant folded for default setting. - VerifyConstantFoldingWithDequantizeLinear(1, 3, 1, graph, session_options, *logger_); + VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); // set kOrtSessionOptionsDisableQuantQDQ to enable it explicitly ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "0")); - VerifyConstantFoldingWithDequantizeLinear(1, 3, 1, graph, session_options, *logger_); + VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); // set SessionOptionsEnableQuantQDQ to disable it + expected_op_counts["DequantizeLinear"] = 1; ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1")); - VerifyConstantFoldingWithDequantizeLinear(1, 1, 1, graph, session_options, *logger_); + VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); +} + +// model with 2 QDQ node units that can be constant folded as they are simple DQ -> Node -> Q where DQ and Node have +// single consumer and do not produce graph outputs. Node is deterministic. +// there are also other DQ nodes that should be ignored. +TEST_F(GraphTransformationTests, ConstantFoldingQDQNodeUnit) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["QuantizeLinear"] == 3); + ASSERT_TRUE(op_to_count["DequantizeLinear"] == 4); + ASSERT_TRUE(op_to_count["Unsqueeze"] == 1); + ASSERT_TRUE(op_to_count["Transpose"] == 1); + + SessionOptions session_options; + + // 2 QDQ node units should be constant folded and go away + std::unordered_map expected_op_counts = {{"QuantizeLinear", 1}, + {"DequantizeLinear", 2}, + {"Transpose", 0}, + {"Unsqueeze", 0}}; + + VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); +} + +// Simple QDQ Node Unit but shouldn't be constant folded as the node in the middle produces a graph output +TEST_F(GraphTransformationTests, ConstantFoldingQDQNodeUnitGraphOutput) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/constant_folding_qdq_node_unit.graph_output.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["QuantizeLinear"] == 2); + ASSERT_TRUE(op_to_count["DequantizeLinear"] == 3); + ASSERT_TRUE(op_to_count["Unsqueeze"] == 1); + + std::unordered_map expected_op_counts = {{"QuantizeLinear", 2}, + {"DequantizeLinear", 3}, + {"Unsqueeze", 1}}; + + SessionOptions session_options; + VerifyConstantFoldingWithDequantizeLinear(expected_op_counts, graph, session_options, *logger_); } TEST_F(GraphTransformationTests, ConstantFolding_RemoveDanglingInputNodesToConstantFoldedNode) { diff --git a/onnxruntime/test/testdata/transform/fusion/constant_folding_qdq_node_unit.graph_output.onnx b/onnxruntime/test/testdata/transform/fusion/constant_folding_qdq_node_unit.graph_output.onnx new file mode 100644 index 0000000000..76b64c0af1 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/constant_folding_qdq_node_unit.graph_output.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/constant_folding_qdq_node_unit.onnx b/onnxruntime/test/testdata/transform/fusion/constant_folding_qdq_node_unit.onnx new file mode 100644 index 0000000000..9943af1080 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/constant_folding_qdq_node_unit.onnx differ