Add rewrite rule to handle Relu + Clip (#1105)

* Remove Relu if followed by Clip. Update Clip 'min' if necessary.
Add unit test.

* Rename to match behaviour a little better.

* Update to match latest RewriteRule interface
This commit is contained in:
Scott McKay 2019-05-24 19:42:02 -07:00 committed by GitHub
parent b54a292ba2
commit f6df36b68b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 168 additions and 15 deletions

View file

@ -12,6 +12,7 @@
#include "core/optimizer/gemm_activation_fusion.h"
#include "core/optimizer/matmul_add_fusion.h"
#include "core/optimizer/dropout_elimination.h"
#include "core/optimizer/relu_clip_fusion.h"
namespace onnxruntime {
@ -30,6 +31,7 @@ std::vector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(TransformerLevel
rules.push_back(std::make_unique<EliminateSlice>());
rules.push_back(std::make_unique<UnsqueezeElimination>());
rules.push_back(std::make_unique<EliminateDropout>());
rules.push_back(std::make_unique<FuseReluClip>());
break;
case TransformerLevel::Level2:
@ -114,21 +116,21 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
}
return transformers;
}
std::vector<std::unique_ptr<GraphTransformer>> filtered_list;
// If the rule-based transformer is not empty, it should be included in the custom transformer list below.
if (rule_transformer != nullptr) {
filtered_list.emplace_back(std::move(rule_transformer));
}
// pick custom transformers enabled for this session
for (const auto& t_name : transformers_and_rules_to_enable) {
std::for_each(transformers.begin(), transformers.end(),
[&](std::unique_ptr<GraphTransformer>& item) {
if ((item != nullptr) && (item->Name() == t_name)) {
filtered_list.push_back(std::move(item));
}
});
}
return filtered_list;
std::vector<std::unique_ptr<GraphTransformer>> filtered_list;
// If the rule-based transformer is not empty, it should be included in the custom transformer list below.
if (rule_transformer != nullptr) {
filtered_list.emplace_back(std::move(rule_transformer));
}
// pick custom transformers enabled for this session
for (const auto& t_name : transformers_and_rules_to_enable) {
std::for_each(transformers.begin(), transformers.end(),
[&](std::unique_ptr<GraphTransformer>& item) {
if ((item != nullptr) && (item->Name() == t_name)) {
filtered_list.push_back(std::move(item));
}
});
}
return filtered_list;
}
} // namespace transformer_utils

View file

@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/optimizer/relu_clip_fusion.h"
#include "core/graph/graph.h"
#include "core/graph/graph_utils.h"
#include "core/graph/op.h"
namespace onnxruntime {
Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
// get the following Clip node before we delete the Relu node
const auto& next_node = *node.OutputNodesBegin();
if (graph_utils::RemoveNode(graph, node)) {
// update the following Clip node if the 'min' is < 0.f to set it to 0.f
// this essentially fuses the Relu and Clip
// if the Clip 'min' is >= 0.f no change is required as Relu would have set the min to 0.f
if (graph_utils::GetNodeAttribute(next_node, "min")->f() < 0.f) {
auto* mutable_next_node = graph.GetNode(next_node.Index());
mutable_next_node->ClearAttribute("min");
mutable_next_node->AddAttribute("min", 0.f);
}
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
}
return Status::OK();
}
bool FuseReluClip::SatisfyCondition(const Graph& graph, const Node& node) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6})) {
return false;
}
if (!graph_utils::IsSingleInSingleOutNode(node) ||
graph.IsNodeOutputsInGraphOutputs(node)) {
return false;
}
// If the Relu is followed by a Clip node the Relu is redundant and can be removed
// as Clip will apply the minimum. If the Clip 'min' value is < 0 we need
// to update it to 0 to apply what the Relu would have done. We do that in Apply.
const auto& next_node = *node.OutputNodesBegin();
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6}) ||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
return false;
}
return true;
}
} // namespace onnxruntime

View file

@ -0,0 +1,29 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/rewrite_rule.h"
namespace onnxruntime {
/**
@Class FuseReluClip
Rewrite rule that merges a Relu operator with a following Clip operator.
*/
class FuseReluClip : public RewriteRule {
public:
FuseReluClip() noexcept : RewriteRule("FuseReluClip") {}
std::vector<std::string> TargetOpTypes() const noexcept override {
return {"Relu"};
}
private:
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
};
} // namespace onnxruntime

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "core/session/inference_session.h"
#include "core/graph/graph_utils.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/model.h"
#include "core/optimizer/graph_transformer.h"
@ -16,6 +17,7 @@
#include "core/optimizer/conv_activation_fusion.h"
#include "core/optimizer/matmul_add_fusion.h"
#include "core/optimizer/gemm_activation_fusion.h"
#include "core/optimizer/relu_clip_fusion.h"
#include "core/framework/data_types.h"
#include "core/framework/ml_value.h"
#include "core/util/math.h"
@ -418,5 +420,72 @@ TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) {
ASSERT_EQ(expected_values_prod, found);
}
TEST(GraphTransformationTests, ReluClipFusion) {
Model model("ReluClipFusion");
auto& graph = model.MainGraph();
std::vector<NodeArg*> inputs;
std::vector<NodeArg*> outputs;
TypeProto input_tensor_type;
input_tensor_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
input_tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
// 3 paths in the model, each with Relu followed by Clip
// One has a Clip with min of 0 (remove Relu)
// One have a Clip with a min > 1 (remove Relu)
// One has a Clip with min < 0 (remove Relu and update Clip 'min' to 0)
auto& input0 = graph.GetOrCreateNodeArg("input_0", &input_tensor_type);
auto& input1 = graph.GetOrCreateNodeArg("input_1", &input_tensor_type);
auto& input2 = graph.GetOrCreateNodeArg("input_2", &input_tensor_type);
auto& relu0_output = graph.GetOrCreateNodeArg("relu0_output", &input_tensor_type);
auto& relu1_output = graph.GetOrCreateNodeArg("relu1_output", &input_tensor_type);
auto& relu2_output = graph.GetOrCreateNodeArg("relu2_output", &input_tensor_type);
auto& clip0_output = graph.GetOrCreateNodeArg("clip0_output", &input_tensor_type);
auto& clip1_output = graph.GetOrCreateNodeArg("clip1_output", &input_tensor_type);
auto& clip2_output = graph.GetOrCreateNodeArg("clip2_output", &input_tensor_type);
graph.AddNode("relu0", "Relu", "Relu to eliminate", {&input0}, {&relu0_output});
graph.AddNode("relu1", "Relu", "Relu to not eliminate", {&input1}, {&relu1_output});
graph.AddNode("relu2", "Relu", "Relu to eliminate and update 'min' of following Clip", {&input2}, {&relu2_output});
auto& clip0 = graph.AddNode("clip0", "Clip", "Clip with min 0", {&relu0_output}, {&clip0_output});
clip0.AddAttribute("min", 0.f);
clip0.AddAttribute("max", 1.f);
auto& clip1 = graph.AddNode("clip1", "Clip", "Clip with min 1", {&relu1_output}, {&clip1_output});
clip1.AddAttribute("min", 1.f);
clip1.AddAttribute("max", 1.f);
auto& clip2 = graph.AddNode("clip2", "Clip", "Clip with min -1", {&relu2_output}, {&clip2_output});
clip2.AddAttribute("min", -1.f);
clip2.AddAttribute("max", 1.f);
auto status = graph.Resolve();
EXPECT_EQ(status, Status::OK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Relu"] == 3);
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
rule_transformer_L1->Register(std::make_unique<FuseReluClip>());
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK());
op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Relu"] == 0);
// make sure the Clip nodes were updated to have a 'min' >= 0
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Clip") {
auto* min = graph_utils::GetNodeAttribute(node, "min");
ASSERT_TRUE(min->f() >= 0.f);
}
}
}
} // namespace test
} // namespace onnxruntime