diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 48be11c496..0330c9ef5e 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -32,6 +32,7 @@ #include "core/optimizer/matmul_scale_fusion.h" #include "core/optimizer/nchwc_transformer.h" #include "core/optimizer/nhwc_transformer.h" +#include "core/optimizer/noop_elimination.h" #include "core/optimizer/not_where_fusion.h" #include "core/optimizer/relu_clip_fusion.h" #include "core/optimizer/reshape_fusion.h" @@ -69,6 +70,7 @@ std::vector> GenerateRewriteRules( rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); + rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); diff --git a/onnxruntime/core/optimizer/noop_elimination.cc b/onnxruntime/core/optimizer/noop_elimination.cc new file mode 100644 index 0000000000..1421ea7416 --- /dev/null +++ b/onnxruntime/core/optimizer/noop_elimination.cc @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/noop_elimination.h" + +#include "core/common/logging/logging.h" +#include "core/graph/graph_viewer.h" +#include "core/graph/op.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { + +/** + Eliminate no op node - handling Add op for now + Add example: + + X 0 + \ / + Add + | + Y + */ +Status NoopElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { + if (graph_utils::RemoveNode(graph, node)) { + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + } + + return Status::OK(); +} + +bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const { + + bool input0_is_initializer = graph_utils::IsConstantInitializer(graph, node.InputDefs()[0]->Name()); + bool input1_is_initializer = graph_utils::IsConstantInitializer(graph, node.InputDefs()[1]->Name()); + + // reject if both or neither inputs are initializers for now + if (input0_is_initializer == input1_is_initializer) { + return false; + } + + const auto* initializer = graph_utils::GetConstantInitializer(graph, node.InputDefs()[input0_is_initializer ? 0 : 1]->Name()); + + // if initializer_rank is bigger, the output is expected to be initializer_rank per broadcasting rule, + // but it won't happen if the case is accepted, thus reject it + auto initializer_rank = initializer->dims().size(); + const auto* other_input_shape = node.InputDefs()[input0_is_initializer ? 1 : 0]->Shape(); + if (other_input_shape == nullptr || initializer_rank > other_input_shape->dim_size()) { + return false; + } + + int32_t data_type = initializer->data_type(); + Initializer add_init(*initializer, graph.ModelPath()); + if (add_init.size() > 1) { + return false; + } + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + if (*add_init.data() != 0.f) { + return false; + } + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + if (math::halfToFloat(add_init.data()->val) != 0.f) { + return false; + } + break; + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: + if (*add_init.data() != static_cast(0.f)) { + return false; + } + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + if (*add_init.data() != static_cast(0)) { + return false; + } + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + if (*add_init.data() != static_cast(0)) { + return false; + } + break; + default: + return false; + } + + // reject node output is graph output for now + if (!graph_utils::CanRemoveNode(graph, node, logger)) { + return false; + } + + return true; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/noop_elimination.h b/onnxruntime/core/optimizer/noop_elimination.h new file mode 100644 index 0000000000..7a11046277 --- /dev/null +++ b/onnxruntime/core/optimizer/noop_elimination.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { + +/** +@Class NoopElimination + +Rewrite rule that eliminates the no op node. +So far only Add node with 0 as one of its inputs is eliminated. +But this class could be the placeholder for other no op nodes in future. +*/ +class NoopElimination : public RewriteRule { + public: + NoopElimination() noexcept : RewriteRule("NoopElimination") {} + + std::vector TargetOpTypes() const noexcept override { + return {"Add"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; +}; // namespace onnxruntime + +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index d0b0a47e58..934ddf0982 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -48,6 +48,7 @@ #include "core/optimizer/matmul_integer_to_float.h" #include "core/optimizer/matmul_scale_fusion.h" #include "core/optimizer/matmul_transpose_fusion.h" +#include "core/optimizer/noop_elimination.h" #include "core/optimizer/not_where_fusion.h" #include "core/optimizer/relu_clip_fusion.h" #include "core/optimizer/reshape_fusion.h" @@ -166,6 +167,24 @@ TEST_F(GraphTransformationTests, IdentityInputIsGraphOutputNotEliminated) { ASSERT_TRUE(op_to_count["Identity"] == 1); } +TEST_F(GraphTransformationTests, NoopElimination) { + auto model_uri = MODEL_FOLDER "noop-add.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Add"] == 4); + + auto rule_transformer_L1 = std::make_unique("RuleTransformer1"); + rule_transformer_L1->Register(std::make_unique()); + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Add"] == 1); +} + TEST_F(GraphTransformationTests, DropoutElimination) { auto model_uri = MODEL_FOLDER "dropout.onnx"; std::shared_ptr model; diff --git a/onnxruntime/test/testdata/transform/noop-add.onnx b/onnxruntime/test/testdata/transform/noop-add.onnx new file mode 100644 index 0000000000..e5793f5fda Binary files /dev/null and b/onnxruntime/test/testdata/transform/noop-add.onnx differ diff --git a/onnxruntime/test/testdata/transform/noop-add.py b/onnxruntime/test/testdata/transform/noop-add.py new file mode 100644 index 0000000000..0d1bb9801b --- /dev/null +++ b/onnxruntime/test/testdata/transform/noop-add.py @@ -0,0 +1,71 @@ +import onnx +from onnx import helper +from onnx import TensorProto, OperatorSetIdProto + +opsets = [] +onnxdomain = OperatorSetIdProto() +onnxdomain.version = 12 +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +opsets.append(onnxdomain) + +msdomain = OperatorSetIdProto() +msdomain.version = 1 +msdomain.domain = 'com.microsoft' + +opsets.append(msdomain) +kwargs={} +kwargs['opset_imports'] = opsets + +def GenerateModel(model_name): + nodes = [ # subgraph + # float + helper.make_node("Identity", ["X1"], ["id_1"], "id_1"), + helper.make_node("Add", ["float_1", "id_1"], ["add_1"], "add_1"), + helper.make_node("Identity", ["add_1"], ["Y1"], "id_2"), + # float_16 + helper.make_node("Identity", ["X2"], ["id_3"], "id_3"), + helper.make_node("Add", ["float16_1", "id_3"], ["add_2"], "add_2"), + helper.make_node("Identity", ["add_2"], ["Y2"], "id_4"), + # int64 - flip the input 0 and 1 + helper.make_node("Identity", ["X3"], ["id_5"], "id_5"), + helper.make_node("Add", ["id_5", "int64_1"], ["add_3"], "add_3"), + helper.make_node("Identity", ["add_3"], ["Y3"], "id_6"), + # int64 + helper.make_node("Identity", ["X4"], ["id_7"], "id_7"), + helper.make_node("Add", ["id_7", "int64_2"], ["add_4"], "add_4"), + helper.make_node("Identity", ["add_4"], ["Y4"], "id_8"), + ] + + inputs = [ # inputs + helper.make_tensor_value_info('X1', TensorProto.FLOAT, ['M', 'K']), + helper.make_tensor_value_info('X2', TensorProto.FLOAT16, ['M', 'K']), + helper.make_tensor_value_info('X3', TensorProto.INT64, ['M', 'K']), + helper.make_tensor_value_info('X4', TensorProto.INT64, ['M', 'K']), + ] + + initializers = [ + helper.make_tensor('float_1', TensorProto.FLOAT, [1], [0.0]), + helper.make_tensor('float16_1', TensorProto.FLOAT16, [1], [0]), + # int64 - set tensor size to 0 + helper.make_tensor('int64_1', TensorProto.INT64, (), [0]), + # higher rank + helper.make_tensor('int64_2', TensorProto.INT64, [1,1,1], [0]), + ] + + graph = helper.make_graph( + nodes, + "NoopAdd", #name + inputs, + [ # outputs + helper.make_tensor_value_info('Y1', TensorProto.FLOAT, ['M', 'K']), + helper.make_tensor_value_info('Y2', TensorProto.FLOAT16, ['M', 'K']), + helper.make_tensor_value_info('Y3', TensorProto.INT64, ['M', 'K']), + helper.make_tensor_value_info('Y4', TensorProto.INT64, ['M', 'K', 1]), + ], + initializers) + + model = helper.make_model(graph, **kwargs) + onnx.save(model, model_name) + +if __name__ == "__main__": + GenerateModel('noop-add.onnx') \ No newline at end of file diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 9eb7b764c4..de393c762e 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -33,6 +33,7 @@ #include "core/optimizer/matmul_scale_fusion.h" #include "core/optimizer/matmul_transpose_fusion.h" #include "core/optimizer/nchwc_transformer.h" +#include "core/optimizer/noop_elimination.h" #include "core/optimizer/not_where_fusion.h" #include "core/optimizer/relu_clip_fusion.h" #include "core/optimizer/reshape_fusion.h" @@ -79,6 +80,7 @@ std::vector> GeneratePreTrainingTransformers( rule_transformer->Register(std::make_unique()); rule_transformer->Register(std::make_unique()); rule_transformer->Register(std::make_unique()); + rule_transformer->Register(std::make_unique()); rule_transformer->Register(std::make_unique()); rule_transformer->Register(std::make_unique()); rule_transformer->Register(std::make_unique());