diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 8d76a160b7..0f7c8c3b0c 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -66,6 +66,7 @@ #include "core/optimizer/slice_elimination.h" #include "core/optimizer/transpose_optimizer/ort_transpose_optimizer.h" #include "core/optimizer/unsqueeze_elimination.h" +#include "core/optimizer/identical_children_consolidation.h" #ifdef ENABLE_TRAINING #include "orttraining/core/optimizer/bitmask_dropout_replacement.h" #include "orttraining/core/optimizer/bias_softmax_dropout_fusion.h" @@ -193,6 +194,7 @@ InlinedVector> GenerateTransformers( // Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for // CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by // default, CSE will not merge them, because the different initializers are represented by different NodeArg. + transformers.emplace_back(std::make_unique()); transformers.emplace_back(std::make_unique()); transformers.emplace_back(std::make_unique()); transformers.emplace_back(std::make_unique(cpu_execution_provider, !disable_quant_qdq)); diff --git a/onnxruntime/core/optimizer/identical_children_consolidation.cc b/onnxruntime/core/optimizer/identical_children_consolidation.cc new file mode 100644 index 0000000000..17f01cebcd --- /dev/null +++ b/onnxruntime/core/optimizer/identical_children_consolidation.cc @@ -0,0 +1,124 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/identical_children_consolidation.h" + +#include "core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { +Status IdenticalChildrenConsolidation::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& /*logger*/) const { + GraphViewer const graph_viewer(graph); + for (auto node_index : graph_viewer.GetNodesInTopologicalOrder()) { + Node* node = graph.GetNode(node_index); + if (!IsSupportedParentNode(node)) { + continue; + } + for (auto supported_op : supported_ops.at(node->OpType())) { + for (auto twin_group : DivideIdenticalChildrenIntoGroups(graph, node, supported_op)) { + // If there is no twins in the group, skip it. + if (twin_group.size() <= 1) { + continue; + } + Node* first_twin = graph.GetNode(twin_group[0]); + for (size_t i = 1; i < twin_group.size(); i++) { + Node* other_twin = graph.GetNode(twin_group[i]); + if (graph.NodeProducesGraphOutput(*other_twin)) { + continue; + } + graph_utils::ReplaceDownstreamNodeInput(graph, *other_twin, 0, *first_twin, 0); + graph_utils::RemoveNode(graph, *other_twin); + modified = true; + } + } + } + } + return Status::OK(); +} + +bool IdenticalChildrenConsolidation::IsSupportedParentNode(const Node* node) const { + return node != nullptr && supported_ops.count(node->OpType()) != 0 && node->GetOutputEdgesCount() > 1; +} + +std::vector> IdenticalChildrenConsolidation::DivideIdenticalChildrenIntoGroups( + const Graph& graph, + Node* node, + const string_view& op) const { + unordered_map> identical_children_map; + for (auto i = node->OutputEdgesBegin(); i != node->OutputEdgesEnd(); ++i) { + if (i->GetNode().OpType() == op) { + identical_children_map[IdentityBuilder(graph, i->GetNode())].push_back(i->GetNode().Index()); + } + } + std::vector> groups; + for (auto& identical_children : identical_children_map) { + if (identical_children.first != ignore_identity) { + groups.push_back(std::move(identical_children.second)); + } + } + return groups; +} + +string_view IdenticalChildrenConsolidation::IdentityBuilder(const Graph& graph, const Node& node) const { + std::string identity; + for (const auto* input_def : node.InputDefs()) { + if (input_def->Exists() && !input_def->Name().empty()) { + auto name = input_def->Name(); + if (graph_utils::NodeArgIsConstant(graph, *input_def)) { + if (optimizer_utils::IsScalar(*input_def)) { + const auto* data = graph_utils::GetConstantInitializer(graph, name); + identity.append(constant_prefix); + Initializer value{*data, graph.ModelPath()}; + switch (static_cast(data->data_type())) { + case TensorProto::DataType::TensorProto_DataType_INT8: + identity.append(std::to_string(value.data()[0])); + break; + case TensorProto::DataType::TensorProto_DataType_INT16: + identity.append(std::to_string(value.data()[0])); + break; + case TensorProto::DataType::TensorProto_DataType_INT32: + identity.append(std::to_string(value.data()[0])); + break; + case TensorProto::DataType::TensorProto_DataType_UINT8: + identity.append(std::to_string(value.data()[0])); + break; + case TensorProto::DataType::TensorProto_DataType_UINT16: + identity.append(std::to_string(value.data()[0])); + break; + case TensorProto::DataType::TensorProto_DataType_BOOL: + identity.append(std::to_string(value.data()[0])); + break; + case TensorProto::DataType::TensorProto_DataType_INT64: + identity.append(std::to_string(value.data()[0])); + break; + case TensorProto::DataType::TensorProto_DataType_UINT32: + identity.append(std::to_string(value.data()[0])); + break; + case TensorProto::DataType::TensorProto_DataType_UINT64: + identity.append(std::to_string(value.data()[0])); + break; + case TensorProto::DataType::TensorProto_DataType_FLOAT: + identity.append(std::to_string(value.data()[0])); + break; + case TensorProto::DataType::TensorProto_DataType_DOUBLE: + identity.append(std::to_string(value.data()[0])); + break; + case TensorProto::DataType::TensorProto_DataType_STRING: + identity.append(value.data()[0]); + break; + default: + break; + } + } else { + // TODO: handle non-scalar constant inputs, using checksum or something else + return ignore_identity; + } + } else { + identity.append(name); + } + } + } + return {identity.append("####")}; +} +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/identical_children_consolidation.h b/onnxruntime/core/optimizer/identical_children_consolidation.h new file mode 100644 index 0000000000..c391470aeb --- /dev/null +++ b/onnxruntime/core/optimizer/identical_children_consolidation.h @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +using std::string_view; +using std::unordered_map; +using std::unordered_set; +using ONNX_NAMESPACE::TensorProto; + +/** + * @Class IdenticalChildrenConsolidation + * + * This transformer consolidates identical children nodes in a graph. The consolidate children + * Must have the same parent and have edges with same attributes expect different destination node. + * Currently, it only supports nodes with single input and single output and the following node + * types from supported_ops list and supported_children_ops list. + * + * For example, the following graph + * + * [supported_parent_ops] + * / \ + * [supported_children_ops] [supported_children_ops] + * | | + * [grandchildren_a] [grandchildren_b] + * + * will be transformed to: + * + * [supported_parent_ops] + * | + * [supported_children_ops] + * / \ + * [grandchildren_a] [grandchildren_b] + */ +class IdenticalChildrenConsolidation : public GraphTransformer { + public: + IdenticalChildrenConsolidation() : GraphTransformer("IdenticalChildrenConsolidation") {} + + private: + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; + bool IsSupportedParentNode(const Node* node) const; + std::vector > DivideIdenticalChildrenIntoGroups(const Graph& graph, Node* node, const string_view& op) const; + string_view IdentityBuilder(const Graph& graph, const Node& node) const; + + unordered_map > supported_ops = { + {"DequantizeLinear", {"QuantizeLinear"}}, + {"QuantizeLinear", {"DequantizeLinear"}}}; + string_view constant_prefix = "ItIsSpecialConstantPrefix_"; + string_view ignore_identity = "IgNoReD_IdEnTiTy"; +}; +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/qdq_test_utils.h b/onnxruntime/test/optimizer/qdq_test_utils.h index f9a21135fa..9f98dcd691 100644 --- a/onnxruntime/test/optimizer/qdq_test_utils.h +++ b/onnxruntime/test/optimizer/qdq_test_utils.h @@ -358,6 +358,43 @@ GetQDQTestCaseFn BuildBinaryOpTestCase(const std::vector& input_shape, }; } +template +GetQDQTestCaseFn BuildConsolidationTestCase( + const std::vector& input_shape, + const int64_t& axis) { + return [input_shape, axis](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input_shape,std::numeric_limits::min(),std::numeric_limits::max()); + InputType dq_zp = std::numeric_limits::max() / 2; + OutputType q_zp = std::numeric_limits::max() / 2; + auto* upper_dq_output = builder.MakeIntermediate(); + auto* upper_q_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input_arg, .003f, q_zp, upper_q_output); + builder.AddDequantizeLinearNode(upper_q_output, .003f, dq_zp, upper_dq_output); + + // add Split + + auto* split_output_1 = builder.MakeIntermediate(); + auto* split_output_2 = builder.MakeIntermediate(); + auto* split_output_3 = builder.MakeIntermediate(); + Node& split_node = builder.AddNode("Split", {upper_dq_output}, {split_output_1, split_output_2, split_output_3}); + split_node.AddAttribute("axis", axis); + + // add Q + auto* lower_q_output_1 = builder.MakeIntermediate(); + auto* lower_q_output_2 = builder.MakeIntermediate(); + auto* lower_q_output_3 = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(split_output_1, .003f, q_zp, lower_q_output_1); + builder.AddQuantizeLinearNode(split_output_2, .003f, q_zp, lower_q_output_2); + builder.AddQuantizeLinearNode(split_output_3, .003f, q_zp, lower_q_output_3); + auto* q_split_output_1 = builder.MakeOutput(); + auto* q_split_output_2 = builder.MakeOutput(); + auto* q_split_output_3 = builder.MakeOutput(); + builder.AddDequantizeLinearNode(lower_q_output_1, .003f, dq_zp, q_split_output_1); + builder.AddDequantizeLinearNode(lower_q_output_2, .003f, dq_zp, q_split_output_2); + builder.AddDequantizeLinearNode(lower_q_output_3, .003f, dq_zp, q_split_output_3); + }; +} + template GetQDQTestCaseFn BuildQDQSplitTestCase( const std::vector& input_shape, diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index bd3a2eae44..749f6fcf7b 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -764,6 +764,7 @@ TEST(QDQTransformerTests, Gather) { test_case({12, 37}, {24, 12}); } +// Because split isn't one the supported ops, this will stay the same TEST(QDQTransformerTests, Split) { auto test_case = [&](const std::vector& input_shape, const int64_t& axis) { auto check_graph = [&](InferenceSessionWrapper& session) { @@ -779,20 +780,55 @@ TEST(QDQTransformerTests, Split) { }; test_case({6, 18, 54}, 0); } + +// Because split isn't one the supported ops, this will stay the same +TEST(QDQTransformerTests, Split_without_IdenticalChildrenConsolidation) { + auto test_case = [&](const std::vector& input_shape, const int64_t& axis) { + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["Split"], 1); + EXPECT_EQ(op_to_count["QuantizeLinear"], 1); + EXPECT_EQ(op_to_count["DequantizeLinear"], 3); + }; + TransformerTester(BuildConsolidationTestCase(input_shape, axis), + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, 12, {}, {}, nullptr, {}, + {"IdenticalChildrenConsolidation"}); + }; + test_case({6, 18, 54}, 0); +} + +TEST(QDQTransformerTests, Split_with_IdenticalChildrenConsolidation) { + auto test_case = [&](const std::vector& input_shape, const int64_t& axis) { + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["Split"], 1); + EXPECT_EQ(op_to_count["QuantizeLinear"], 1); + EXPECT_EQ(op_to_count["DequantizeLinear"], 3); + }; + TransformerTester(BuildConsolidationTestCase(input_shape, axis), + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2); + }; + test_case({6, 18, 54}, 0); +} + TEST(QDQTransformerTests, Where) { - auto test_case = [&](const std::vector& cond_shape, const std::vector& x_shape,const std::vector& y_shape) { + auto test_case = [&](const std::vector& cond_shape, const std::vector& x_shape, const std::vector& y_shape) { auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["com.microsoft.QLinearWhere"], 1); EXPECT_EQ(op_to_count["QuantizeLinear"], 0); EXPECT_EQ(op_to_count["DequantizeLinear"], 0); }; - TransformerTester(BuildQDQWhereTestCase(cond_shape,x_shape,y_shape), + TransformerTester(BuildQDQWhereTestCase(cond_shape, x_shape, y_shape), check_graph, TransformerLevel::Level1, TransformerLevel::Level2); }; - test_case({1},{1},{1}); + test_case({1}, {1}, {1}); } TEST(QDQTransformerTests, Transpose) {