From f6df36b68b59b88ac2f6f6fa98a171fde28b8793 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 24 May 2019 19:42:02 -0700 Subject: [PATCH] Add rewrite rule to handle Relu + Clip (#1105) * Remove Relu if followed by Clip. Update Clip 'min' if necessary. Add unit test. * Rename to match behaviour a little better. * Update to match latest RewriteRule interface --- .../core/optimizer/graph_transformer_utils.cc | 32 +++++---- .../core/optimizer/relu_clip_fusion.cc | 53 ++++++++++++++ onnxruntime/core/optimizer/relu_clip_fusion.h | 29 ++++++++ .../test/optimizer/graph_transform_test.cc | 69 +++++++++++++++++++ 4 files changed, 168 insertions(+), 15 deletions(-) create mode 100644 onnxruntime/core/optimizer/relu_clip_fusion.cc create mode 100644 onnxruntime/core/optimizer/relu_clip_fusion.h diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 593bd58c2e..5c03fc3538 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -12,6 +12,7 @@ #include "core/optimizer/gemm_activation_fusion.h" #include "core/optimizer/matmul_add_fusion.h" #include "core/optimizer/dropout_elimination.h" +#include "core/optimizer/relu_clip_fusion.h" namespace onnxruntime { @@ -30,6 +31,7 @@ std::vector> GenerateRewriteRules(TransformerLevel rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); + rules.push_back(std::make_unique()); break; case TransformerLevel::Level2: @@ -114,21 +116,21 @@ std::vector> GenerateTransformers(TransformerL } return transformers; } - std::vector> filtered_list; - // If the rule-based transformer is not empty, it should be included in the custom transformer list below. - if (rule_transformer != nullptr) { - filtered_list.emplace_back(std::move(rule_transformer)); - } - // pick custom transformers enabled for this session - for (const auto& t_name : transformers_and_rules_to_enable) { - std::for_each(transformers.begin(), transformers.end(), - [&](std::unique_ptr& item) { - if ((item != nullptr) && (item->Name() == t_name)) { - filtered_list.push_back(std::move(item)); - } - }); - } - return filtered_list; + std::vector> filtered_list; + // If the rule-based transformer is not empty, it should be included in the custom transformer list below. + if (rule_transformer != nullptr) { + filtered_list.emplace_back(std::move(rule_transformer)); + } + // pick custom transformers enabled for this session + for (const auto& t_name : transformers_and_rules_to_enable) { + std::for_each(transformers.begin(), transformers.end(), + [&](std::unique_ptr& item) { + if ((item != nullptr) && (item->Name() == t_name)) { + filtered_list.push_back(std::move(item)); + } + }); + } + return filtered_list; } } // namespace transformer_utils diff --git a/onnxruntime/core/optimizer/relu_clip_fusion.cc b/onnxruntime/core/optimizer/relu_clip_fusion.cc new file mode 100644 index 0000000000..4e82916d4e --- /dev/null +++ b/onnxruntime/core/optimizer/relu_clip_fusion.cc @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/relu_clip_fusion.h" +#include "core/graph/graph.h" +#include "core/graph/graph_utils.h" +#include "core/graph/op.h" + +namespace onnxruntime { + +Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const { + // get the following Clip node before we delete the Relu node + const auto& next_node = *node.OutputNodesBegin(); + + if (graph_utils::RemoveNode(graph, node)) { + // update the following Clip node if the 'min' is < 0.f to set it to 0.f + // this essentially fuses the Relu and Clip + // if the Clip 'min' is >= 0.f no change is required as Relu would have set the min to 0.f + if (graph_utils::GetNodeAttribute(next_node, "min")->f() < 0.f) { + auto* mutable_next_node = graph.GetNode(next_node.Index()); + mutable_next_node->ClearAttribute("min"); + mutable_next_node->AddAttribute("min", 0.f); + } + + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + } + + return Status::OK(); +} + +bool FuseReluClip::SatisfyCondition(const Graph& graph, const Node& node) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6})) { + return false; + } + + if (!graph_utils::IsSingleInSingleOutNode(node) || + graph.IsNodeOutputsInGraphOutputs(node)) { + return false; + } + + // If the Relu is followed by a Clip node the Relu is redundant and can be removed + // as Clip will apply the minimum. If the Clip 'min' value is < 0 we need + // to update it to 0 to apply what the Relu would have done. We do that in Apply. + const auto& next_node = *node.OutputNodesBegin(); + if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6}) || + next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { + return false; + } + + return true; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/relu_clip_fusion.h b/onnxruntime/core/optimizer/relu_clip_fusion.h new file mode 100644 index 0000000000..2b90e9c3d8 --- /dev/null +++ b/onnxruntime/core/optimizer/relu_clip_fusion.h @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { + +/** +@Class FuseReluClip + +Rewrite rule that merges a Relu operator with a following Clip operator. +*/ +class FuseReluClip : public RewriteRule { + public: + FuseReluClip() noexcept : RewriteRule("FuseReluClip") {} + + std::vector TargetOpTypes() const noexcept override { + return {"Relu"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node) const override; + + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 1d2facbbd7..e8a33e5fb6 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/session/inference_session.h" +#include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" #include "core/optimizer/graph_transformer.h" @@ -16,6 +17,7 @@ #include "core/optimizer/conv_activation_fusion.h" #include "core/optimizer/matmul_add_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" +#include "core/optimizer/relu_clip_fusion.h" #include "core/framework/data_types.h" #include "core/framework/ml_value.h" #include "core/util/math.h" @@ -418,5 +420,72 @@ TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) { ASSERT_EQ(expected_values_prod, found); } +TEST(GraphTransformationTests, ReluClipFusion) { + Model model("ReluClipFusion"); + auto& graph = model.MainGraph(); + + std::vector inputs; + std::vector outputs; + + TypeProto input_tensor_type; + input_tensor_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + input_tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + + // 3 paths in the model, each with Relu followed by Clip + // One has a Clip with min of 0 (remove Relu) + // One have a Clip with a min > 1 (remove Relu) + // One has a Clip with min < 0 (remove Relu and update Clip 'min' to 0) + auto& input0 = graph.GetOrCreateNodeArg("input_0", &input_tensor_type); + auto& input1 = graph.GetOrCreateNodeArg("input_1", &input_tensor_type); + auto& input2 = graph.GetOrCreateNodeArg("input_2", &input_tensor_type); + + auto& relu0_output = graph.GetOrCreateNodeArg("relu0_output", &input_tensor_type); + auto& relu1_output = graph.GetOrCreateNodeArg("relu1_output", &input_tensor_type); + auto& relu2_output = graph.GetOrCreateNodeArg("relu2_output", &input_tensor_type); + + auto& clip0_output = graph.GetOrCreateNodeArg("clip0_output", &input_tensor_type); + auto& clip1_output = graph.GetOrCreateNodeArg("clip1_output", &input_tensor_type); + auto& clip2_output = graph.GetOrCreateNodeArg("clip2_output", &input_tensor_type); + + graph.AddNode("relu0", "Relu", "Relu to eliminate", {&input0}, {&relu0_output}); + graph.AddNode("relu1", "Relu", "Relu to not eliminate", {&input1}, {&relu1_output}); + graph.AddNode("relu2", "Relu", "Relu to eliminate and update 'min' of following Clip", {&input2}, {&relu2_output}); + + auto& clip0 = graph.AddNode("clip0", "Clip", "Clip with min 0", {&relu0_output}, {&clip0_output}); + clip0.AddAttribute("min", 0.f); + clip0.AddAttribute("max", 1.f); + + auto& clip1 = graph.AddNode("clip1", "Clip", "Clip with min 1", {&relu1_output}, {&clip1_output}); + clip1.AddAttribute("min", 1.f); + clip1.AddAttribute("max", 1.f); + + auto& clip2 = graph.AddNode("clip2", "Clip", "Clip with min -1", {&relu2_output}, {&clip2_output}); + clip2.AddAttribute("min", -1.f); + clip2.AddAttribute("max", 1.f); + + auto status = graph.Resolve(); + EXPECT_EQ(status, Status::OK()); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Relu"] == 3); + + auto rule_transformer_L1 = std::make_unique("RuleTransformer1"); + rule_transformer_L1->Register(std::make_unique()); + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); + + op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Relu"] == 0); + + // make sure the Clip nodes were updated to have a 'min' >= 0 + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Clip") { + auto* min = graph_utils::GetNodeAttribute(node, "min"); + ASSERT_TRUE(min->f() >= 0.f); + } + } +} + } // namespace test } // namespace onnxruntime