Eliminate no op node - add 0 (#7798)

* eliminate add 0

* typo

* rank check

* fix build

Co-authored-by: Ethan Tao <ettao@OrtTrainingDev4.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
ytaous 2021-05-25 13:01:34 -07:00 committed by GitHub
parent 9241d76396
commit ff655175ff
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 221 additions and 0 deletions

View file

@ -32,6 +32,7 @@
#include "core/optimizer/matmul_scale_fusion.h"
#include "core/optimizer/nchwc_transformer.h"
#include "core/optimizer/nhwc_transformer.h"
#include "core/optimizer/noop_elimination.h"
#include "core/optimizer/not_where_fusion.h"
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/reshape_fusion.h"
@ -69,6 +70,7 @@ std::vector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
rules.push_back(std::make_unique<EliminateDropout>());
rules.push_back(std::make_unique<ExpandElimination>());
rules.push_back(std::make_unique<CastElimination>());
rules.push_back(std::make_unique<NoopElimination>());
rules.push_back(std::make_unique<DivMulFusion>());
rules.push_back(std::make_unique<FuseReluClip>());
rules.push_back(std::make_unique<GemmTransposeFusion>());

View file

@ -0,0 +1,96 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/optimizer/noop_elimination.h"
#include "core/common/logging/logging.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/op.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/rewrite_rule.h"
namespace onnxruntime {
/**
Eliminate no op node - handling Add op for now
Add example:
X 0
\ /
Add
|
Y
*/
Status NoopElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
if (graph_utils::RemoveNode(graph, node)) {
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
}
return Status::OK();
}
bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
bool input0_is_initializer = graph_utils::IsConstantInitializer(graph, node.InputDefs()[0]->Name());
bool input1_is_initializer = graph_utils::IsConstantInitializer(graph, node.InputDefs()[1]->Name());
// reject if both or neither inputs are initializers for now
if (input0_is_initializer == input1_is_initializer) {
return false;
}
const auto* initializer = graph_utils::GetConstantInitializer(graph, node.InputDefs()[input0_is_initializer ? 0 : 1]->Name());
// if initializer_rank is bigger, the output is expected to be initializer_rank per broadcasting rule,
// but it won't happen if the case is accepted, thus reject it
auto initializer_rank = initializer->dims().size();
const auto* other_input_shape = node.InputDefs()[input0_is_initializer ? 1 : 0]->Shape();
if (other_input_shape == nullptr || initializer_rank > other_input_shape->dim_size()) {
return false;
}
int32_t data_type = initializer->data_type();
Initializer add_init(*initializer, graph.ModelPath());
if (add_init.size() > 1) {
return false;
}
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
if (*add_init.data<float>() != 0.f) {
return false;
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
if (math::halfToFloat(add_init.data<MLFloat16>()->val) != 0.f) {
return false;
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
if (*add_init.data<double>() != static_cast<double>(0.f)) {
return false;
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
if (*add_init.data<int32_t>() != static_cast<int32_t>(0)) {
return false;
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
if (*add_init.data<int64_t>() != static_cast<int64_t>(0)) {
return false;
}
break;
default:
return false;
}
// reject node output is graph output for now
if (!graph_utils::CanRemoveNode(graph, node, logger)) {
return false;
}
return true;
}
} // namespace onnxruntime

View file

@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/rewrite_rule.h"
namespace onnxruntime {
/**
@Class NoopElimination
Rewrite rule that eliminates the no op node.
So far only Add node with 0 as one of its inputs is eliminated.
But this class could be the placeholder for other no op nodes in future.
*/
class NoopElimination : public RewriteRule {
public:
NoopElimination() noexcept : RewriteRule("NoopElimination") {}
std::vector<std::string> TargetOpTypes() const noexcept override {
return {"Add"};
}
private:
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
}; // namespace onnxruntime
} // namespace onnxruntime

View file

@ -48,6 +48,7 @@
#include "core/optimizer/matmul_integer_to_float.h"
#include "core/optimizer/matmul_scale_fusion.h"
#include "core/optimizer/matmul_transpose_fusion.h"
#include "core/optimizer/noop_elimination.h"
#include "core/optimizer/not_where_fusion.h"
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/reshape_fusion.h"
@ -166,6 +167,24 @@ TEST_F(GraphTransformationTests, IdentityInputIsGraphOutputNotEliminated) {
ASSERT_TRUE(op_to_count["Identity"] == 1);
}
TEST_F(GraphTransformationTests, NoopElimination) {
auto model_uri = MODEL_FOLDER "noop-add.onnx";
std::shared_ptr<Model> model;
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
Graph& graph = model->MainGraph();
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Add"] == 4);
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
rule_transformer_L1->Register(std::make_unique<NoopElimination>());
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Add"] == 1);
}
TEST_F(GraphTransformationTests, DropoutElimination) {
auto model_uri = MODEL_FOLDER "dropout.onnx";
std::shared_ptr<Model> model;

Binary file not shown.

View file

@ -0,0 +1,71 @@
import onnx
from onnx import helper
from onnx import TensorProto, OperatorSetIdProto
opsets = []
onnxdomain = OperatorSetIdProto()
onnxdomain.version = 12
onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
opsets.append(onnxdomain)
msdomain = OperatorSetIdProto()
msdomain.version = 1
msdomain.domain = 'com.microsoft'
opsets.append(msdomain)
kwargs={}
kwargs['opset_imports'] = opsets
def GenerateModel(model_name):
nodes = [ # subgraph
# float
helper.make_node("Identity", ["X1"], ["id_1"], "id_1"),
helper.make_node("Add", ["float_1", "id_1"], ["add_1"], "add_1"),
helper.make_node("Identity", ["add_1"], ["Y1"], "id_2"),
# float_16
helper.make_node("Identity", ["X2"], ["id_3"], "id_3"),
helper.make_node("Add", ["float16_1", "id_3"], ["add_2"], "add_2"),
helper.make_node("Identity", ["add_2"], ["Y2"], "id_4"),
# int64 - flip the input 0 and 1
helper.make_node("Identity", ["X3"], ["id_5"], "id_5"),
helper.make_node("Add", ["id_5", "int64_1"], ["add_3"], "add_3"),
helper.make_node("Identity", ["add_3"], ["Y3"], "id_6"),
# int64
helper.make_node("Identity", ["X4"], ["id_7"], "id_7"),
helper.make_node("Add", ["id_7", "int64_2"], ["add_4"], "add_4"),
helper.make_node("Identity", ["add_4"], ["Y4"], "id_8"),
]
inputs = [ # inputs
helper.make_tensor_value_info('X1', TensorProto.FLOAT, ['M', 'K']),
helper.make_tensor_value_info('X2', TensorProto.FLOAT16, ['M', 'K']),
helper.make_tensor_value_info('X3', TensorProto.INT64, ['M', 'K']),
helper.make_tensor_value_info('X4', TensorProto.INT64, ['M', 'K']),
]
initializers = [
helper.make_tensor('float_1', TensorProto.FLOAT, [1], [0.0]),
helper.make_tensor('float16_1', TensorProto.FLOAT16, [1], [0]),
# int64 - set tensor size to 0
helper.make_tensor('int64_1', TensorProto.INT64, (), [0]),
# higher rank
helper.make_tensor('int64_2', TensorProto.INT64, [1,1,1], [0]),
]
graph = helper.make_graph(
nodes,
"NoopAdd", #name
inputs,
[ # outputs
helper.make_tensor_value_info('Y1', TensorProto.FLOAT, ['M', 'K']),
helper.make_tensor_value_info('Y2', TensorProto.FLOAT16, ['M', 'K']),
helper.make_tensor_value_info('Y3', TensorProto.INT64, ['M', 'K']),
helper.make_tensor_value_info('Y4', TensorProto.INT64, ['M', 'K', 1]),
],
initializers)
model = helper.make_model(graph, **kwargs)
onnx.save(model, model_name)
if __name__ == "__main__":
GenerateModel('noop-add.onnx')

View file

@ -33,6 +33,7 @@
#include "core/optimizer/matmul_scale_fusion.h"
#include "core/optimizer/matmul_transpose_fusion.h"
#include "core/optimizer/nchwc_transformer.h"
#include "core/optimizer/noop_elimination.h"
#include "core/optimizer/not_where_fusion.h"
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/reshape_fusion.h"
@ -79,6 +80,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
rule_transformer->Register(std::make_unique<UnsqueezeElimination>());
rule_transformer->Register(std::make_unique<ExpandElimination>());
rule_transformer->Register(std::make_unique<CastElimination>());
rule_transformer->Register(std::make_unique<NoopElimination>());
rule_transformer->Register(std::make_unique<DivMulFusion>());
rule_transformer->Register(std::make_unique<EliminateDropout>());
rule_transformer->Register(std::make_unique<GemmTransposeFusion>());