mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Consolidate Identical Children Nodes (#14026)
### Description When a Q node has multiple identical DQ children, we want to keep only one DQ. The single remaining DQ channels its output to the consumers of the deleted DQ children, e.g. Q->N(DQ) => Q->DQ. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
d0c5ffd5f7
commit
babc1323e3
5 changed files with 257 additions and 3 deletions
|
|
@ -66,6 +66,7 @@
|
|||
#include "core/optimizer/slice_elimination.h"
|
||||
#include "core/optimizer/transpose_optimizer/ort_transpose_optimizer.h"
|
||||
#include "core/optimizer/unsqueeze_elimination.h"
|
||||
#include "core/optimizer/identical_children_consolidation.h"
|
||||
#ifdef ENABLE_TRAINING
|
||||
#include "orttraining/core/optimizer/bitmask_dropout_replacement.h"
|
||||
#include "orttraining/core/optimizer/bias_softmax_dropout_fusion.h"
|
||||
|
|
@ -193,6 +194,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
// Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for
|
||||
// CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by
|
||||
// default, CSE will not merge them, because the different initializers are represented by different NodeArg.
|
||||
transformers.emplace_back(std::make_unique<IdenticalChildrenConsolidation>());
|
||||
transformers.emplace_back(std::make_unique<ConstantSharing>());
|
||||
transformers.emplace_back(std::make_unique<CommonSubexpressionElimination>());
|
||||
transformers.emplace_back(std::make_unique<ConstantFolding>(cpu_execution_provider, !disable_quant_qdq));
|
||||
|
|
|
|||
124
onnxruntime/core/optimizer/identical_children_consolidation.cc
Normal file
124
onnxruntime/core/optimizer/identical_children_consolidation.cc
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/optimizer/identical_children_consolidation.h"
|
||||
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
/**
 * Visits every node in topological order and, for each supported parent, collapses each
 * group of identical children into the first child of that group.
 *
 * @param graph     Graph to transform in place.
 * @param modified  Set to true as soon as at least one duplicate child has been rewired.
 * @return          Always Status::OK().
 */
Status IdenticalChildrenConsolidation::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/,
                                                 const logging::Logger& /*logger*/) const {
  const GraphViewer graph_viewer(graph);
  for (const auto node_index : graph_viewer.GetNodesInTopologicalOrder()) {
    // GetNode may yield nullptr (e.g. for a node removed earlier in this pass);
    // IsSupportedParentNode rejects that case.
    Node* node = graph.GetNode(node_index);
    if (!IsSupportedParentNode(node)) {
      continue;
    }

    for (const auto& supported_op : supported_ops.at(node->OpType())) {
      // Bind by const reference: the groups are only read, so copying each vector is wasted work.
      for (const auto& twin_group : DivideIdenticalChildrenIntoGroups(graph, node, supported_op)) {
        // If there are no twins in the group, skip it.
        if (twin_group.size() <= 1) {
          continue;
        }

        Node* first_twin = graph.GetNode(twin_group[0]);
        if (first_twin == nullptr) {
          continue;
        }

        for (size_t i = 1; i < twin_group.size(); i++) {
          Node* other_twin = graph.GetNode(twin_group[i]);
          // A child fed by the parent through more than one edge shows up several times in the
          // group. Never merge a node into itself, and skip anything that is already gone.
          if (other_twin == nullptr || other_twin == first_twin) {
            continue;
          }
          // A twin that produces a graph output must survive, otherwise that output disappears.
          if (graph.NodeProducesGraphOutput(*other_twin)) {
            continue;
          }
          // Point the duplicate's consumers at the kept twin (index 0 on both sides), then drop it.
          graph_utils::ReplaceDownstreamNodeInput(graph, *other_twin, 0, *first_twin, 0);
          modified = true;
          graph_utils::RemoveNode(graph, *other_twin);
        }
      }
    }
  }
  return Status::OK();
}
|
||||
|
||||
bool IdenticalChildrenConsolidation::IsSupportedParentNode(const Node* node) const {
|
||||
return node != nullptr && supported_ops.count(node->OpType()) != 0 && node->GetOutputEdgesCount() > 1;
|
||||
}
|
||||
|
||||
std::vector<std::vector<NodeIndex>> IdenticalChildrenConsolidation::DivideIdenticalChildrenIntoGroups(
|
||||
const Graph& graph,
|
||||
Node* node,
|
||||
const string_view& op) const {
|
||||
unordered_map<string_view, std::vector<NodeIndex>> identical_children_map;
|
||||
for (auto i = node->OutputEdgesBegin(); i != node->OutputEdgesEnd(); ++i) {
|
||||
if (i->GetNode().OpType() == op) {
|
||||
identical_children_map[IdentityBuilder(graph, i->GetNode())].push_back(i->GetNode().Index());
|
||||
}
|
||||
}
|
||||
std::vector<std::vector<NodeIndex>> groups;
|
||||
for (auto& identical_children : identical_children_map) {
|
||||
if (identical_children.first != ignore_identity) {
|
||||
groups.push_back(std::move(identical_children.second));
|
||||
}
|
||||
}
|
||||
return groups;
|
||||
}
|
||||
|
||||
/**
 * Builds a fingerprint of `node` from its inputs so that two children with equal fingerprints
 * can be treated as interchangeable.
 *
 * Each input slot contributes one field followed by the "####" separator:
 *   - a missing/unnamed input contributes an empty field (the slot position is still recorded),
 *   - a non-constant input contributes its name,
 *   - a scalar constant contributes `constant_prefix`, its element type and its value.
 *
 * @return `ignore_identity` when the node cannot be fingerprinted safely (non-scalar constant,
 *         unresolved initializer, or a scalar type not handled below); otherwise a view of the
 *         fingerprint. The view refers to a per-thread pool of strings and remains valid for the
 *         lifetime of the calling thread.
 */
string_view IdenticalChildrenConsolidation::IdentityBuilder(const Graph& graph, const Node& node) const {
  // Appends the exact object representation of a value. std::to_string prints floating point
  // values with six decimals only, which would make e.g. 1e-7 and 2e-7 indistinguishable.
  const auto append_exact = [](std::string& out, const auto& scalar) {
    out.append(reinterpret_cast<const char*>(&scalar), sizeof(scalar));
  };

  std::string identity;
  for (const auto* input_def : node.InputDefs()) {
    if (input_def->Exists() && !input_def->Name().empty()) {
      const auto& name = input_def->Name();
      if (graph_utils::NodeArgIsConstant(graph, *input_def)) {
        if (!optimizer_utils::IsScalar(*input_def)) {
          // TODO: handle non-scalar constant inputs, using checksum or something else
          return ignore_identity;
        }
        const auto* data = graph_utils::GetConstantInitializer(graph, name);
        if (data == nullptr) {
          return ignore_identity;
        }
        identity.append(constant_prefix);
        // Record the element type as well: equal values of different types are not the same constant.
        identity.append(std::to_string(data->data_type()));
        identity.append(":");
        Initializer value{*data, graph.ModelPath()};
        switch (static_cast<TensorProto::DataType>(data->data_type())) {
          case TensorProto::DataType::TensorProto_DataType_INT8:
            identity.append(std::to_string(value.data<int8_t>()[0]));
            break;
          case TensorProto::DataType::TensorProto_DataType_INT16:
            identity.append(std::to_string(value.data<int16_t>()[0]));
            break;
          case TensorProto::DataType::TensorProto_DataType_INT32:
            identity.append(std::to_string(value.data<int32_t>()[0]));
            break;
          case TensorProto::DataType::TensorProto_DataType_UINT8:
            identity.append(std::to_string(value.data<uint8_t>()[0]));
            break;
          case TensorProto::DataType::TensorProto_DataType_UINT16:
            identity.append(std::to_string(value.data<uint16_t>()[0]));
            break;
          case TensorProto::DataType::TensorProto_DataType_BOOL:
            identity.append(std::to_string(value.data<bool>()[0]));
            break;
          case TensorProto::DataType::TensorProto_DataType_INT64:
            identity.append(std::to_string(value.data<int64_t>()[0]));
            break;
          case TensorProto::DataType::TensorProto_DataType_UINT32:
            identity.append(std::to_string(value.data<uint32_t>()[0]));
            break;
          case TensorProto::DataType::TensorProto_DataType_UINT64:
            identity.append(std::to_string(value.data<uint64_t>()[0]));
            break;
          case TensorProto::DataType::TensorProto_DataType_FLOAT:
            append_exact(identity, value.data<float>()[0]);
            break;
          case TensorProto::DataType::TensorProto_DataType_DOUBLE:
            append_exact(identity, value.data<double>()[0]);
            break;
          case TensorProto::DataType::TensorProto_DataType_STRING:
            identity.append(value.data<std::string>()[0]);
            break;
          default:
            // The value cannot be encoded here (e.g. 16-bit float types). Treating such
            // constants as equal would merge nodes that differ, so opt out instead.
            return ignore_identity;
        }
      } else {
        identity.append(name);
      }
    }
    // Close the field for this slot, present or not, so that neighbouring fields cannot run
    // together and the position of every input is part of the identity.
    identity.append("####");
  }

  // The result type is a view, so the characters must outlive this call. Elements of an
  // unordered_set are not relocated by later insertions, which keeps earlier views valid.
  // NOTE(review): the pool only grows; returning std::string from this method (and keying the
  // caller's map by std::string) would remove the need for it, but that changes the header.
  thread_local unordered_set<std::string> interned_identities;
  return *interned_identities.insert(std::move(identity)).first;
}
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
using std::string_view;
|
||||
using std::unordered_map;
|
||||
using std::unordered_set;
|
||||
using ONNX_NAMESPACE::TensorProto;
|
||||
|
||||
/**
|
||||
* @Class IdenticalChildrenConsolidation
|
||||
*
|
||||
* This transformer consolidates identical children nodes in a graph. The consolidate children
|
||||
* Must have the same parent and have edges with same attributes expect different destination node.
|
||||
* Currently, it only supports nodes with single input and single output and the following node
|
||||
* types from supported_ops list and supported_children_ops list.
|
||||
*
|
||||
* For example, the following graph
|
||||
*
|
||||
* [supported_parent_ops]
|
||||
* / \
|
||||
* [supported_children_ops] [supported_children_ops]
|
||||
* | |
|
||||
* [grandchildren_a] [grandchildren_b]
|
||||
*
|
||||
* will be transformed to:
|
||||
*
|
||||
* [supported_parent_ops]
|
||||
* |
|
||||
* [supported_children_ops]
|
||||
* / \
|
||||
* [grandchildren_a] [grandchildren_b]
|
||||
*/
|
||||
class IdenticalChildrenConsolidation : public GraphTransformer {
|
||||
public:
|
||||
IdenticalChildrenConsolidation() : GraphTransformer("IdenticalChildrenConsolidation") {}
|
||||
|
||||
private:
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
bool IsSupportedParentNode(const Node* node) const;
|
||||
std::vector<std::vector<NodeIndex> > DivideIdenticalChildrenIntoGroups(const Graph& graph, Node* node, const string_view& op) const;
|
||||
string_view IdentityBuilder(const Graph& graph, const Node& node) const;
|
||||
|
||||
unordered_map<string_view, unordered_set<string_view> > supported_ops = {
|
||||
{"DequantizeLinear", {"QuantizeLinear"}},
|
||||
{"QuantizeLinear", {"DequantizeLinear"}}};
|
||||
string_view constant_prefix = "ItIsSpecialConstantPrefix_";
|
||||
string_view ignore_identity = "IgNoReD_IdEnTiTy";
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -358,6 +358,43 @@ GetQDQTestCaseFn BuildBinaryOpTestCase(const std::vector<int64_t>& input_shape,
|
|||
};
|
||||
}
|
||||
|
||||
template <typename InputType, typename OutputType>
|
||||
GetQDQTestCaseFn BuildConsolidationTestCase(
|
||||
const std::vector<int64_t>& input_shape,
|
||||
const int64_t& axis) {
|
||||
return [input_shape, axis](ModelTestBuilder& builder) {
|
||||
auto* input_arg = builder.MakeInput<float>(input_shape,std::numeric_limits<float>::min(),std::numeric_limits<float>::max());
|
||||
InputType dq_zp = std::numeric_limits<InputType>::max() / 2;
|
||||
OutputType q_zp = std::numeric_limits<OutputType>::max() / 2;
|
||||
auto* upper_dq_output = builder.MakeIntermediate();
|
||||
auto* upper_q_output = builder.MakeIntermediate();
|
||||
builder.AddQuantizeLinearNode<InputType>(input_arg, .003f, q_zp, upper_q_output);
|
||||
builder.AddDequantizeLinearNode<InputType>(upper_q_output, .003f, dq_zp, upper_dq_output);
|
||||
|
||||
// add Split
|
||||
|
||||
auto* split_output_1 = builder.MakeIntermediate();
|
||||
auto* split_output_2 = builder.MakeIntermediate();
|
||||
auto* split_output_3 = builder.MakeIntermediate();
|
||||
Node& split_node = builder.AddNode("Split", {upper_dq_output}, {split_output_1, split_output_2, split_output_3});
|
||||
split_node.AddAttribute("axis", axis);
|
||||
|
||||
// add Q
|
||||
auto* lower_q_output_1 = builder.MakeIntermediate();
|
||||
auto* lower_q_output_2 = builder.MakeIntermediate();
|
||||
auto* lower_q_output_3 = builder.MakeIntermediate();
|
||||
builder.AddQuantizeLinearNode<OutputType>(split_output_1, .003f, q_zp, lower_q_output_1);
|
||||
builder.AddQuantizeLinearNode<OutputType>(split_output_2, .003f, q_zp, lower_q_output_2);
|
||||
builder.AddQuantizeLinearNode<OutputType>(split_output_3, .003f, q_zp, lower_q_output_3);
|
||||
auto* q_split_output_1 = builder.MakeOutput();
|
||||
auto* q_split_output_2 = builder.MakeOutput();
|
||||
auto* q_split_output_3 = builder.MakeOutput();
|
||||
builder.AddDequantizeLinearNode<OutputType>(lower_q_output_1, .003f, dq_zp, q_split_output_1);
|
||||
builder.AddDequantizeLinearNode<OutputType>(lower_q_output_2, .003f, dq_zp, q_split_output_2);
|
||||
builder.AddDequantizeLinearNode<OutputType>(lower_q_output_3, .003f, dq_zp, q_split_output_3);
|
||||
};
|
||||
}
|
||||
|
||||
template <typename InputType, typename OutputType>
|
||||
GetQDQTestCaseFn BuildQDQSplitTestCase(
|
||||
const std::vector<int64_t>& input_shape,
|
||||
|
|
|
|||
|
|
@ -764,6 +764,7 @@ TEST(QDQTransformerTests, Gather) {
|
|||
test_case({12, 37}, {24, 12});
|
||||
}
|
||||
|
||||
// Because Split isn't one of the supported ops, this will stay the same
|
||||
TEST(QDQTransformerTests, Split) {
|
||||
auto test_case = [&](const std::vector<int64_t>& input_shape, const int64_t& axis) {
|
||||
auto check_graph = [&](InferenceSessionWrapper& session) {
|
||||
|
|
@ -779,20 +780,55 @@ TEST(QDQTransformerTests, Split) {
|
|||
};
|
||||
test_case({6, 18, 54}, 0);
|
||||
}
|
||||
|
||||
// Because split isn't one the supported ops, this will stay the same
|
||||
TEST(QDQTransformerTests, Split_without_IdenticalChildrenConsolidation) {
|
||||
auto test_case = [&](const std::vector<int64_t>& input_shape, const int64_t& axis) {
|
||||
auto check_graph = [&](InferenceSessionWrapper& session) {
|
||||
auto op_to_count = CountOpsInGraph(session.GetGraph());
|
||||
EXPECT_EQ(op_to_count["Split"], 1);
|
||||
EXPECT_EQ(op_to_count["QuantizeLinear"], 1);
|
||||
EXPECT_EQ(op_to_count["DequantizeLinear"], 3);
|
||||
};
|
||||
TransformerTester(BuildConsolidationTestCase<int8_t, int8_t>(input_shape, axis),
|
||||
check_graph,
|
||||
TransformerLevel::Level1,
|
||||
TransformerLevel::Level2, 12, {}, {}, nullptr, {},
|
||||
{"IdenticalChildrenConsolidation"});
|
||||
};
|
||||
test_case({6, 18, 54}, 0);
|
||||
}
|
||||
|
||||
// Runs the consolidation test model through the Level1 -> Level2 transformers with the
// default TransformerTester arguments and checks the resulting op counts.
TEST(QDQTransformerTests, Split_with_IdenticalChildrenConsolidation) {
  auto verify_op_counts = [&](InferenceSessionWrapper& session) {
    auto op_to_count = CountOpsInGraph(session.GetGraph());
    EXPECT_EQ(op_to_count["Split"], 1);
    EXPECT_EQ(op_to_count["QuantizeLinear"], 1);
    EXPECT_EQ(op_to_count["DequantizeLinear"], 3);
  };

  auto run_case = [&](const std::vector<int64_t>& input_shape, const int64_t& axis) {
    TransformerTester(BuildConsolidationTestCase<int8_t, int8_t>(input_shape, axis),
                      verify_op_counts,
                      TransformerLevel::Level1,
                      TransformerLevel::Level2);
  };

  run_case({6, 18, 54}, 0);
}
|
||||
|
||||
// Checks that a QDQ-wrapped Where is fused: after the Level1 -> Level2 transformers the graph
// is expected to hold one com.microsoft.QLinearWhere and no Q/DQ nodes.
// The superseded, unformatted variants of the lambda header, the TransformerTester call and the
// test_case invocation have been dropped; keeping both variants does not compile.
TEST(QDQTransformerTests, Where) {
  auto test_case = [&](const std::vector<int64_t>& cond_shape, const std::vector<int64_t>& x_shape, const std::vector<int64_t>& y_shape) {
    auto check_graph = [&](InferenceSessionWrapper& session) {
      auto op_to_count = CountOpsInGraph(session.GetGraph());
      EXPECT_EQ(op_to_count["com.microsoft.QLinearWhere"], 1);
      EXPECT_EQ(op_to_count["QuantizeLinear"], 0);
      EXPECT_EQ(op_to_count["DequantizeLinear"], 0);
    };
    TransformerTester(BuildQDQWhereTestCase<int8_t>(cond_shape, x_shape, y_shape),
                      check_graph,
                      TransformerLevel::Level1,
                      TransformerLevel::Level2);
  };
  test_case({1}, {1}, {1});
}
|
||||
|
||||
TEST(QDQTransformerTests, Transpose) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue