mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-02 23:39:58 +00:00
Add rewrite rule to handle Relu + Clip (#1105)
* Remove Relu if followed by Clip. Update Clip 'min' if necessary. Add unit test. * Rename to match behaviour a little better. * Update to match latest RewriteRule interface
This commit is contained in:
parent
b54a292ba2
commit
f6df36b68b
4 changed files with 168 additions and 15 deletions
|
|
@ -12,6 +12,7 @@
|
|||
#include "core/optimizer/gemm_activation_fusion.h"
|
||||
#include "core/optimizer/matmul_add_fusion.h"
|
||||
#include "core/optimizer/dropout_elimination.h"
|
||||
#include "core/optimizer/relu_clip_fusion.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -30,6 +31,7 @@ std::vector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(TransformerLevel
|
|||
rules.push_back(std::make_unique<EliminateSlice>());
|
||||
rules.push_back(std::make_unique<UnsqueezeElimination>());
|
||||
rules.push_back(std::make_unique<EliminateDropout>());
|
||||
rules.push_back(std::make_unique<FuseReluClip>());
|
||||
break;
|
||||
|
||||
case TransformerLevel::Level2:
|
||||
|
|
@ -114,21 +116,21 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
|
|||
}
|
||||
return transformers;
|
||||
}
|
||||
std::vector<std::unique_ptr<GraphTransformer>> filtered_list;
|
||||
// If the rule-based transformer is not empty, it should be included in the custom transformer list below.
|
||||
if (rule_transformer != nullptr) {
|
||||
filtered_list.emplace_back(std::move(rule_transformer));
|
||||
}
|
||||
// pick custom transformers enabled for this session
|
||||
for (const auto& t_name : transformers_and_rules_to_enable) {
|
||||
std::for_each(transformers.begin(), transformers.end(),
|
||||
[&](std::unique_ptr<GraphTransformer>& item) {
|
||||
if ((item != nullptr) && (item->Name() == t_name)) {
|
||||
filtered_list.push_back(std::move(item));
|
||||
}
|
||||
});
|
||||
}
|
||||
return filtered_list;
|
||||
std::vector<std::unique_ptr<GraphTransformer>> filtered_list;
|
||||
// If the rule-based transformer is not empty, it should be included in the custom transformer list below.
|
||||
if (rule_transformer != nullptr) {
|
||||
filtered_list.emplace_back(std::move(rule_transformer));
|
||||
}
|
||||
// pick custom transformers enabled for this session
|
||||
for (const auto& t_name : transformers_and_rules_to_enable) {
|
||||
std::for_each(transformers.begin(), transformers.end(),
|
||||
[&](std::unique_ptr<GraphTransformer>& item) {
|
||||
if ((item != nullptr) && (item->Name() == t_name)) {
|
||||
filtered_list.push_back(std::move(item));
|
||||
}
|
||||
});
|
||||
}
|
||||
return filtered_list;
|
||||
}
|
||||
|
||||
} // namespace transformer_utils
|
||||
|
|
|
|||
53
onnxruntime/core/optimizer/relu_clip_fusion.cc
Normal file
53
onnxruntime/core/optimizer/relu_clip_fusion.cc
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/optimizer/relu_clip_fusion.h"
|
||||
#include "core/graph/graph.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/graph/op.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
 * Removes the Relu `node` and folds its effect into the Clip node that consumes its output.
 *
 * Relu clamps values at 0, so after the Relu is removed the Clip must have an effective
 * 'min' of at least 0. If the Clip 'min' is already >= 0 nothing needs to change.
 *
 * @param graph        Graph that owns `node`; the Relu is removed from it on success.
 * @param node         The Relu node (SatisfyCondition guarantees a single consumer).
 * @param rule_effect  Set to kRemovedCurrentNode when the Relu was removed; otherwise untouched.
 * @return Status::OK() in all cases.
 */
Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
  // Grab the following Clip node before the Relu node is deleted.
  const auto& next_node = *node.OutputNodesBegin();

  if (graph_utils::RemoveNode(graph, node)) {
    // 'min' is an optional attribute of Clip (opset 6); when it is absent the operator uses
    // the lowest float as the bound, which is < 0. A missing attribute therefore has to be
    // handled like a negative 'min' instead of being dereferenced.
    // NOTE(review): assumes GetNodeAttribute returns nullptr for a missing attribute - confirm.
    const auto* min_attr = graph_utils::GetNodeAttribute(next_node, "min");
    const bool min_below_zero = (min_attr == nullptr) || (min_attr->f() < 0.f);

    if (min_below_zero) {
      // Replace (or create) 'min' so that Clip applies the lower bound Relu would have applied.
      auto* mutable_next_node = graph.GetNode(next_node.Index());
      mutable_next_node->ClearAttribute("min");
      mutable_next_node->AddAttribute("min", 0.f);
    }

    rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
  }

  return Status::OK();
}
|
||||
|
||||
bool FuseReluClip::SatisfyCondition(const Graph& graph, const Node& node) const {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6})) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!graph_utils::IsSingleInSingleOutNode(node) ||
|
||||
graph.IsNodeOutputsInGraphOutputs(node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If the Relu is followed by a Clip node the Relu is redundant and can be removed
|
||||
// as Clip will apply the minimum. If the Clip 'min' value is < 0 we need
|
||||
// to update it to 0 to apply what the Relu would have done. We do that in Apply.
|
||||
const auto& next_node = *node.OutputNodesBegin();
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6}) ||
|
||||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
29
onnxruntime/core/optimizer/relu_clip_fusion.h
Normal file
29
onnxruntime/core/optimizer/relu_clip_fusion.h
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/rewrite_rule.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
@Class FuseReluClip
|
||||
|
||||
Rewrite rule that merges a Relu operator with a following Clip operator.
|
||||
*/
|
||||
class FuseReluClip : public RewriteRule {
|
||||
public:
|
||||
FuseReluClip() noexcept : RewriteRule("FuseReluClip") {}
|
||||
|
||||
std::vector<std::string> TargetOpTypes() const noexcept override {
|
||||
return {"Relu"};
|
||||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/session/inference_session.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/graph/model.h"
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
|
|
@ -16,6 +17,7 @@
|
|||
#include "core/optimizer/conv_activation_fusion.h"
|
||||
#include "core/optimizer/matmul_add_fusion.h"
|
||||
#include "core/optimizer/gemm_activation_fusion.h"
|
||||
#include "core/optimizer/relu_clip_fusion.h"
|
||||
#include "core/framework/data_types.h"
|
||||
#include "core/framework/ml_value.h"
|
||||
#include "core/util/math.h"
|
||||
|
|
@ -418,5 +420,72 @@ TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) {
|
|||
ASSERT_EQ(expected_values_prod, found);
|
||||
}
|
||||
|
||||
// Builds three independent Relu -> Clip chains, applies the FuseReluClip rewrite rule and checks
// that every Relu is removed and that every remaining Clip has an explicit 'min' >= 0.
TEST(GraphTransformationTests, ReluClipFusion) {
  Model model("ReluClipFusion");
  auto& graph = model.MainGraph();

  TypeProto input_tensor_type;
  input_tensor_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
  input_tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);

  // 3 paths in the model, each with Relu followed by Clip
  //  - one has a Clip with a min of 0 (remove Relu)
  //  - one has a Clip with a min > 0  (remove Relu)
  //  - one has a Clip with a min < 0  (remove Relu and update Clip 'min' to 0)
  auto& input0 = graph.GetOrCreateNodeArg("input_0", &input_tensor_type);
  auto& input1 = graph.GetOrCreateNodeArg("input_1", &input_tensor_type);
  auto& input2 = graph.GetOrCreateNodeArg("input_2", &input_tensor_type);

  auto& relu0_output = graph.GetOrCreateNodeArg("relu0_output", &input_tensor_type);
  auto& relu1_output = graph.GetOrCreateNodeArg("relu1_output", &input_tensor_type);
  auto& relu2_output = graph.GetOrCreateNodeArg("relu2_output", &input_tensor_type);

  auto& clip0_output = graph.GetOrCreateNodeArg("clip0_output", &input_tensor_type);
  auto& clip1_output = graph.GetOrCreateNodeArg("clip1_output", &input_tensor_type);
  auto& clip2_output = graph.GetOrCreateNodeArg("clip2_output", &input_tensor_type);

  graph.AddNode("relu0", "Relu", "Relu to eliminate", {&input0}, {&relu0_output});
  // All three Relu nodes are expected to be removed (see the assertion below), so the
  // description of relu1 says so as well.
  graph.AddNode("relu1", "Relu", "Relu to eliminate, following Clip has min > 0", {&input1}, {&relu1_output});
  graph.AddNode("relu2", "Relu", "Relu to eliminate and update 'min' of following Clip", {&input2}, {&relu2_output});

  auto& clip0 = graph.AddNode("clip0", "Clip", "Clip with min 0", {&relu0_output}, {&clip0_output});
  clip0.AddAttribute("min", 0.f);
  clip0.AddAttribute("max", 1.f);

  auto& clip1 = graph.AddNode("clip1", "Clip", "Clip with min 1", {&relu1_output}, {&clip1_output});
  clip1.AddAttribute("min", 1.f);
  clip1.AddAttribute("max", 1.f);

  auto& clip2 = graph.AddNode("clip2", "Clip", "Clip with min -1", {&relu2_output}, {&clip2_output});
  clip2.AddAttribute("min", -1.f);
  clip2.AddAttribute("max", 1.f);

  // Nothing below is meaningful on an unresolved graph, so stop here on failure.
  auto status = graph.Resolve();
  ASSERT_TRUE(status.IsOK());

  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_TRUE(op_to_count["Relu"] == 3);

  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
  rule_transformer_L1->Register(std::make_unique<FuseReluClip>());
  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
  ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK());

  op_to_count = CountOpsInGraph(graph);
  ASSERT_TRUE(op_to_count["Relu"] == 0);

  // make sure the Clip nodes were updated to have a 'min' >= 0
  for (auto& node : graph.Nodes()) {
    if (node.OpType() == "Clip") {
      auto* min = graph_utils::GetNodeAttribute(node, "min");
      // Every Clip in this test was given an explicit 'min'; fail cleanly instead of
      // dereferencing a null pointer if it went missing.
      ASSERT_TRUE(min != nullptr);
      ASSERT_TRUE(min->f() >= 0.f);
    }
  }
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue