From 17f0ffd1c82f1bef0977b42515ccbc644597369d Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Tue, 1 Nov 2022 10:39:52 +0800 Subject: [PATCH] Support More Cases in NoOpElimination (#13460) Currently, NoopElimination supports only the Add node. This PR adds support for x-0, x*1, 1*x and x/1, in addition to x+0 and 0+x. With this PR, all Div(x,1) nodes and their gradients (also Div(x,1)) in Huggingface's diffusers model can be removed; these previously accounted for ~1% of total compute time. --- .../core/optimizer/noop_elimination.cc | 50 ++-- onnxruntime/core/optimizer/noop_elimination.h | 5 +- .../test/optimizer/graph_transform_test.cc | 283 ++++++++++++++++++ 3 files changed, 307 insertions(+), 31 deletions(-) diff --git a/onnxruntime/core/optimizer/noop_elimination.cc b/onnxruntime/core/optimizer/noop_elimination.cc index c01bdc42bf..b3c2991d54 100644 --- a/onnxruntime/core/optimizer/noop_elimination.cc +++ b/onnxruntime/core/optimizer/noop_elimination.cc @@ -13,14 +13,7 @@ namespace onnxruntime { /** - Eliminate no op node - handling Add op for now - Add example: - - X 0 - \ / - Add - | - Y + Eliminate no op node - supporting x+0, 0+x, x-0, x*1, 1*x and x/1 for now. */ Status NoopElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { if (graph_utils::RemoveNode(graph, node)) { @@ -31,7 +24,6 @@ Status NoopElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ } bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const { - bool input0_is_initializer = graph_utils::IsConstantInitializer(graph, node.InputDefs()[0]->Name()); bool input1_is_initializer = graph_utils::IsConstantInitializer(graph, node.InputDefs()[1]->Name()); @@ -40,7 +32,13 @@ bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, con return false; } - const auto* initializer = graph_utils::GetConstantInitializer(graph, node.InputDefs()[input0_is_initializer ? 
0 : 1]->Name()); + const auto& op_type = node.OpType(); + if ((op_type == "Sub" || op_type == "Div") && !input1_is_initializer) { + return false; + } + + const auto* initializer = + graph_utils::GetConstantInitializer(graph, node.InputDefs()[input0_is_initializer ? 0 : 1]->Name()); // if initializer_rank is bigger, the output is expected to be initializer_rank per broadcasting rule, // but it won't happen if the case is accepted, thus reject it @@ -59,42 +57,38 @@ bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, con if (add_init.size() == 0) { return true; } + + float value = 0.0f; switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - if (*add_init.data() != 0.f) { - return false; - } + value = *add_init.data(); break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - if (math::halfToFloat(add_init.data()->val) != 0.f) { - return false; - } + value = math::halfToFloat(add_init.data()->val); break; case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: - if (*add_init.data() != static_cast(0.f)) { - return false; - } + value = static_cast(*add_init.data()); break; case ONNX_NAMESPACE::TensorProto_DataType_INT32: - if (*add_init.data() != static_cast(0)) { - return false; - } + value = static_cast(*add_init.data()); break; case ONNX_NAMESPACE::TensorProto_DataType_INT64: - if (*add_init.data() != static_cast(0)) { - return false; - } + value = static_cast(*add_init.data()); break; default: return false; } - // reject node output is graph output for now - if (!graph_utils::CanRemoveNode(graph, node, logger)) { + if ((op_type == "Add" || op_type == "Sub") && value != 0.0f) { return false; } - return true; + if ((op_type == "Mul" || op_type == "Div") && value != 1.0f) { + return false; + } + + // reject node output is graph output for now + return graph_utils::CanRemoveNode(graph, node, logger); } } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/noop_elimination.h 
b/onnxruntime/core/optimizer/noop_elimination.h index 7a11046277..a956d74774 100644 --- a/onnxruntime/core/optimizer/noop_elimination.h +++ b/onnxruntime/core/optimizer/noop_elimination.h @@ -11,15 +11,14 @@ namespace onnxruntime { @Class NoopElimination Rewrite rule that eliminates the no op node. -So far only Add node with 0 as one of its inputs is eliminated. -But this class could be the placeholder for other no op nodes in future. +Support x+0, 0+x, x-0, x*1, 1*x and x/1 for now. */ class NoopElimination : public RewriteRule { public: NoopElimination() noexcept : RewriteRule("NoopElimination") {} std::vector TargetOpTypes() const noexcept override { - return {"Add"}; + return {"Add", "Sub", "Mul", "Div"}; } private: diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index aaea49e962..70a23dd622 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -214,6 +214,289 @@ TEST_F(GraphTransformationTests, NoopElimination) { op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Add"] == 1); + + auto pre_graph_checker = [&](Graph& graph) { + ASSERT_EQ(CountOpsInGraph(graph)["Add"] + CountOpsInGraph(graph)["Sub"] + CountOpsInGraph(graph)["Mul"] + + CountOpsInGraph(graph)["Div"], + 1); + }; + + auto post_graph_checker = [&](Graph& graph) { + ASSERT_EQ(CountOpsInGraph(graph)["Add"] + CountOpsInGraph(graph)["Sub"] + CountOpsInGraph(graph)["Mul"] + + CountOpsInGraph(graph)["Div"], + 0); + }; + + // x+0, float. 
+ { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{2, 3, 3, 3}}); + auto* input2_arg = builder.MakeInput({{3, 3}}); + auto* matmul_output = builder.MakeIntermediate(); + auto* initializer_arg = builder.MakeInitializer({}, {0.0f}); + auto* add_out = builder.MakeIntermediate(); + auto* identity_output = builder.MakeOutput(); + + builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output}); + builder.AddNode("Add", {matmul_output, initializer_arg}, {add_out}); + builder.AddNode("Identity", {add_out}, {identity_output}); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker); + } + + // 0+x, fp16. + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{2, 3, 3, 3}}); + auto* input2_arg = builder.MakeInput({{3, 3}}); + auto* matmul_output = builder.MakeIntermediate(); + auto* initializer_arg = builder.MakeInitializer({1}, {MLFloat16(0.0f)}); + auto* add_out = builder.MakeIntermediate(); + auto* identity_output = builder.MakeOutput(); + + builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output}); + builder.AddNode("Add", {initializer_arg, matmul_output}, {add_out}); + builder.AddNode("Identity", {add_out}, {identity_output}); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker); + } + + // x-0, double. 
+ { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{2, 3, 3, 3}}); + auto* input2_arg = builder.MakeInput({{3, 3}}); + auto* matmul_output = builder.MakeIntermediate(); + auto* initializer_arg = builder.MakeInitializer({1, 1}, {static_cast(0.0f)}); + auto* sub_out = builder.MakeIntermediate(); + auto* identity_output = builder.MakeOutput(); + + builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output}); + builder.AddNode("Sub", {matmul_output, initializer_arg}, {sub_out}); + builder.AddNode("Identity", {sub_out}, {identity_output}); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker); + } + + // x*1, int32. + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{2, 3, 3, 3}}); + auto* input2_arg = builder.MakeInput({{3, 3}}); + auto* matmul_output = builder.MakeIntermediate(); + auto* initializer_arg = builder.MakeInitializer({1, 1, 1}, {1}); + auto* mul_out = builder.MakeIntermediate(); + auto* identity_output = builder.MakeOutput(); + + builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output}); + builder.AddNode("Mul", {matmul_output, initializer_arg}, {mul_out}); + builder.AddNode("Identity", {mul_out}, {identity_output}); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker); + } + + // 1*x, int64. 
+ { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{2, 3, 3, 3}}); + auto* input2_arg = builder.MakeInput({{3, 3}}); + auto* matmul_output = builder.MakeIntermediate(); + auto* initializer_arg = builder.MakeInitializer({1, 1, 1, 1}, {static_cast(1)}); + auto* mul_out = builder.MakeIntermediate(); + auto* identity_output = builder.MakeOutput(); + + builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output}); + builder.AddNode("Mul", {initializer_arg, matmul_output}, {mul_out}); + builder.AddNode("Identity", {mul_out}, {identity_output}); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker); + } + + // x/1, float. + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{2, 3, 3, 3}}); + auto* input2_arg = builder.MakeInput({{3, 3}}); + auto* matmul_output = builder.MakeIntermediate(); + auto* initializer_arg = builder.MakeInitializer({}, {1.0f}); + auto* div_out = builder.MakeIntermediate(); + auto* identity_output = builder.MakeOutput(); + + builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output}); + builder.AddNode("Div", {matmul_output, initializer_arg}, {div_out}); + builder.AddNode("Identity", {div_out}, {identity_output}); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker); + } + + // Invalid case: x+1. 
+ { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{2, 3, 3, 3}}); + auto* input2_arg = builder.MakeInput({{3, 3}}); + auto* matmul_output = builder.MakeIntermediate(); + auto* initializer_arg = builder.MakeInitializer({}, {1.0f}); + auto* add_out = builder.MakeIntermediate(); + auto* identity_output = builder.MakeOutput(); + + builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output}); + builder.AddNode("Add", {matmul_output, initializer_arg}, {add_out}); + builder.AddNode("Identity", {add_out}, {identity_output}); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1, + pre_graph_checker, pre_graph_checker); + } + + // Invalid case: initializer rank is larger. + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{2, 3, 3, 3}}); + auto* input2_arg = builder.MakeInput({{3, 3}}); + auto* matmul_output = builder.MakeIntermediate(); + auto* initializer_arg = builder.MakeInitializer({1, 1, 1, 1, 1}, {MLFloat16(0.0f)}); + auto* add_out = builder.MakeIntermediate(); + auto* identity_output = builder.MakeOutput(); + + builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output}); + builder.AddNode("Add", {initializer_arg, matmul_output}, {add_out}); + builder.AddNode("Identity", {add_out}, {identity_output}); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1, + pre_graph_checker, pre_graph_checker); + } + + // Invalid case: 0-x. 
+ { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{2, 3, 3, 3}}); + auto* input2_arg = builder.MakeInput({{3, 3}}); + auto* matmul_output = builder.MakeIntermediate(); + auto* initializer_arg = builder.MakeInitializer({1, 1}, {static_cast(0.0f)}); + auto* sub_out = builder.MakeIntermediate(); + auto* identity_output = builder.MakeOutput(); + + builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output}); + builder.AddNode("Sub", {initializer_arg, matmul_output}, {sub_out}); + builder.AddNode("Identity", {sub_out}, {identity_output}); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1, + pre_graph_checker, pre_graph_checker); + } + + // Invalid case: x-1. + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{2, 3, 3, 3}}); + auto* input2_arg = builder.MakeInput({{3, 3}}); + auto* matmul_output = builder.MakeIntermediate(); + auto* initializer_arg = builder.MakeInitializer({1, 1}, {static_cast(1.0f)}); + auto* sub_out = builder.MakeIntermediate(); + auto* identity_output = builder.MakeOutput(); + + builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output}); + builder.AddNode("Sub", {matmul_output, initializer_arg}, {sub_out}); + builder.AddNode("Identity", {sub_out}, {identity_output}); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1, + pre_graph_checker, pre_graph_checker); + } + + // Invalid case: 0*x. 
+ { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{2, 3, 3, 3}}); + auto* input2_arg = builder.MakeInput({{3, 3}}); + auto* matmul_output = builder.MakeIntermediate(); + auto* initializer_arg = builder.MakeInitializer({1, 1, 1}, {0}); + auto* mul_out = builder.MakeIntermediate(); + auto* identity_output = builder.MakeOutput(); + + builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output}); + builder.AddNode("Mul", {initializer_arg, matmul_output}, {mul_out}); + builder.AddNode("Identity", {mul_out}, {identity_output}); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1, + pre_graph_checker, pre_graph_checker); + } + + // Invalid case: output is graph output. + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{2, 3, 3, 3}}); + auto* input2_arg = builder.MakeInput({{3, 3}}); + auto* matmul_output = builder.MakeIntermediate(); + auto* initializer_arg = builder.MakeInitializer({1, 1, 1, 1}, {static_cast(1)}); + auto* mul_out = builder.MakeOutput(); + + builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output}); + builder.AddNode("Mul", {initializer_arg, matmul_output}, {mul_out}); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1, + pre_graph_checker, pre_graph_checker); + } + + // Invalid case: 1/x. 
+ { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{2, 3, 3, 3}}); + auto* input2_arg = builder.MakeInput({{3, 3}}); + auto* matmul_output = builder.MakeIntermediate(); + auto* initializer_arg = builder.MakeInitializer({}, {1.0f}); + auto* div_out = builder.MakeIntermediate(); + auto* identity_output = builder.MakeOutput(); + + builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output}); + builder.AddNode("Div", {initializer_arg, matmul_output}, {div_out}); + builder.AddNode("Identity", {div_out}, {identity_output}); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1, + pre_graph_checker, pre_graph_checker); + } } TEST_F(GraphTransformationTests, DropoutElimination) {