diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 84fffe23bf..b07949343c 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -31,6 +31,7 @@ #include "core/optimizer/matmul_scale_fusion.h" #include "core/optimizer/nchwc_transformer.h" #include "core/optimizer/nhwc_transformer.h" +#include "core/optimizer/not_where_fusion.h" #include "core/optimizer/relu_clip_fusion.h" #include "core/optimizer/reshape_fusion.h" #include "core/optimizer/rule_based_graph_transformer.h" @@ -68,6 +69,7 @@ std::vector> GenerateRewriteRules( rules.push_back(onnxruntime::make_unique()); rules.push_back(onnxruntime::make_unique()); rules.push_back(onnxruntime::make_unique()); + rules.push_back(onnxruntime::make_unique()); rules.push_back(onnxruntime::make_unique()); rules.push_back(onnxruntime::make_unique()); rules.push_back(onnxruntime::make_unique()); diff --git a/onnxruntime/core/optimizer/not_where_fusion.cc b/onnxruntime/core/optimizer/not_where_fusion.cc new file mode 100644 index 0000000000..e72ff4fc9a --- /dev/null +++ b/onnxruntime/core/optimizer/not_where_fusion.cc @@ -0,0 +1,137 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/not_where_fusion.h" + +#include "core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::common; +namespace onnxruntime { + +/** +Transform that fuses two Not -> Where nodes to a single Where node +with the where inputs 1 and 2 flipped. +Condition -> Not -> Where -> + value0-| | + value1----| + +Condition -> Where -> + value1-| | + value0----| + +It also fuses when not node has multiple where consumer nodes: + +Condition -> Not -> Where -> + | v0-| | + | v1----| + |----> Where -> + v0-| | + v1----| + +Condition -> Where -> + | v1-| | + | v0----| + |----> Where -> + v1-| | + v0----| + */ +bool NotWhereFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Where", {9})) { + return false; + } + + const Node* p_not_node = graph_utils::GetInputNode(node, 0); + if (p_not_node == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*p_not_node, "Not", {1}) || + // Make sure the two nodes do not span execution providers. + p_not_node->GetExecutionProviderType() != node.GetExecutionProviderType()) { + return false; + } + + if (p_not_node->GetOutputEdgesCount() > 1) { + // all consumers of not must be where + for (auto it = p_not_node->OutputNodesBegin(); it != p_not_node->OutputNodesEnd(); ++it) { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(*it, "Where", {9})) { + return false; + } + } + } + + if (!graph_utils::CanRemoveNode(graph, *p_not_node, logger)) { + return false; + } + + return true; +} + +Status NotWhereFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { + const Node* p_not_node = graph_utils::GetInputNode(node, 0); + + auto& not_node = *graph.GetNode(p_not_node->Index()); // get mutable next node + NodeArg* not_input = not_node.MutableInputDefs()[0]; + + // get all node ids of consumer where nodes + std::vector where_node_ids; + for (auto it = p_not_node->OutputNodesBegin(); it != p_not_node->OutputNodesEnd(); ++it) { + where_node_ids.push_back(it->Index()); + } + + // Move input egdes from not_node to all where_node + const Node* not_input_node = graph_utils::GetInputNode(not_node, 0); + if (not_input_node) { + Node& replacement = *graph.GetNode(not_input_node->Index()); + int replacement_output_idx = graph_utils::GetNodeOutputIndexFromOutputName(replacement, not_input->Name()); + // Replace inputs of all downstream where nodes with input of not_node by + // removing not's output edges, updating input names of not's consumers, + // and adding the edges from not's input to not's consumers. + graph_utils::ReplaceDownstreamNodeInput(graph, not_node, 0, replacement, replacement_output_idx); + } else { // not's input is graph input/initializer. Remove the output egdes for not_node + graph_utils::RemoveNodeOutputEdges(graph, not_node); + } + + for (auto it = where_node_ids.begin(); it != where_node_ids.end(); ++it) { + auto& where_node = *graph.GetNode(*it); + + std::vector where_inputs = where_node.MutableInputDefs(); + + if (!not_input_node) { // not's input is graph input/initializer. + graph_utils::ReplaceNodeInput(where_node, 0, *not_input); + } + + const Node* where_input1_node = graph_utils::GetInputNode(where_node, 1); + const Node* where_input2_node = graph_utils::GetInputNode(where_node, 2); + + int output1_idx = -1, output2_idx = -1; + if (where_input1_node) { + output1_idx = graph_utils::GetNodeOutputIndexFromOutputName(*where_input1_node, where_inputs[1]->Name()); + graph.RemoveEdge(where_input1_node->Index(), where_node.Index(), output1_idx, 1); + } + + if (where_input2_node) { + output2_idx = graph_utils::GetNodeOutputIndexFromOutputName(*where_input2_node, where_inputs[2]->Name()); + graph.RemoveEdge(where_input2_node->Index(), where_node.Index(), output2_idx, 2); + } + + graph_utils::ReplaceNodeInput(where_node, 1, *where_inputs[2]); + graph_utils::ReplaceNodeInput(where_node, 2, *where_inputs[1]); + + if (where_input1_node) { + graph.AddEdge(where_input1_node->Index(), where_node.Index(), output1_idx, 2); + } + + if (where_input2_node) { + graph.AddEdge(where_input2_node->Index(), where_node.Index(), output2_idx, 1); + } + } + + // remove not_node + graph.RemoveNode(not_node.Index()); + + rule_effect = RewriteRuleEffect::kModifiedRestOfGraph; + + return Status::OK(); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/not_where_fusion.h b/onnxruntime/core/optimizer/not_where_fusion.h new file mode 100644 index 0000000000..c99758acb7 --- /dev/null +++ b/onnxruntime/core/optimizer/not_where_fusion.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { +/** +@Class NotWhereFusion + +Rewrite rule that fuses two Not -> Where nodes to a single Where node +with the where inputs 1 and 2 flipped. +Condition -> Not -> Where -> + value0-| | + value1----| + +Condition -> Where -> + value1-| | + value0----| +*/ +class NotWhereFusion : public RewriteRule { + public: + NotWhereFusion() noexcept : RewriteRule("NotWhereFusion") {} + + std::vector TargetOpTypes() const noexcept override { + return {"Where"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index b5345da5dc..e31d819d89 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -46,6 +46,7 @@ #include "core/optimizer/matmul_integer_to_float.h" #include "core/optimizer/matmul_scale_fusion.h" #include "core/optimizer/matmul_transpose_fusion.h" +#include "core/optimizer/not_where_fusion.h" #include "core/optimizer/relu_clip_fusion.h" #include "core/optimizer/reshape_fusion.h" #include "core/optimizer/rule_based_graph_transformer.h" @@ -548,6 +549,26 @@ TEST_F(GraphTransformationTests, DivMulFusion) { ASSERT_TRUE(op_to_count["Mul"] == 2); } +TEST_F(GraphTransformationTests, NotWhereFusion) { + auto model_uri = MODEL_FOLDER "fusion/not_where.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Not"] == 4); + ASSERT_TRUE(op_to_count["Where"] == 5); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = onnxruntime::make_unique("RuleTransformer1"); + rule_transformer_L1->Register(onnxruntime::make_unique()); + graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Where"] == 5); + ASSERT_TRUE(op_to_count["Not"] == 1); // can't remove Not if it is graph output/ has consumer that's not where +} + #if defined(USE_CUDA) && !defined(DISABLE_CONTRIB_OPS) TEST_F(GraphTransformationTests, FuseCudaConvAddRelu) { auto model_uri = MODEL_FOLDER "fusion/conv_add_relu.onnx"; diff --git a/onnxruntime/test/testdata/transform/fusion/not_where.onnx b/onnxruntime/test/testdata/transform/fusion/not_where.onnx new file mode 100644 index 0000000000..05945fdeb9 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/not_where.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/not_where.py b/onnxruntime/test/testdata/transform/fusion/not_where.py new file mode 100644 index 0000000000..ba3b3a5fef --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/not_where.py @@ -0,0 +1,63 @@ +import onnx +from onnx import helper +from onnx import TensorProto, OperatorSetIdProto +from enum import Enum + +opsets = [] +onnxdomain = OperatorSetIdProto() +onnxdomain.version = 12 +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +opsets.append(onnxdomain) + +msdomain = OperatorSetIdProto() +msdomain.version = 1 +msdomain.domain = 'com.microsoft' + +opsets.append(msdomain) +kwargs={} +kwargs['opset_imports'] = opsets + +def GenerateModel(model_name): + nodes = [ # subgraph + # float + helper.make_node("Not", ["X"], ["not_X_1"], "not_1"), + helper.make_node("Where", ["not_X_1", "v0", "v1"], ["Y1"], "where_1"), + helper.make_node("Not", ["not_X_1"], ["x"], "not_2"), + helper.make_node("Identity", ["v0"], ["v0_edge"], "identity_v0"), + helper.make_node("Identity", ["v1"], ["v1_edge"], "identity_v1"), + helper.make_node("Where", ["x", "v0_edge", "v1_edge"], ["Y2"], "where_2"), + helper.make_node("Not", ["X"], ["not_X_2"], "not_3"), + helper.make_node("Where", ["not_X_2", "v0", "v1"], ["Y3"], "where_3"), + helper.make_node("Not", ["X"], ["not_X_3"], "not_4"), + helper.make_node("Where", ["not_X_3", "v0", "v1"], ["Y4"], "where_4"), + helper.make_node("Where", ["not_X_3", "v0", "v1"], ["Y5"], "where_5"), + ] + + inputs = [ # inputs + helper.make_tensor_value_info('X', TensorProto.BOOL, ['M', 'K']), + ] + + initializers = [ + helper.make_tensor('v0', TensorProto.FLOAT, [1], [1.0]), + helper.make_tensor('v1', TensorProto.FLOAT, [1], [-1.0]), + ] + + graph = helper.make_graph( + nodes, + "NotWhere", #name + inputs, + [ # outputs + helper.make_tensor_value_info('not_X_2', TensorProto.BOOL, ['M', 'K']), + helper.make_tensor_value_info('Y1', TensorProto.FLOAT, ['M', 'K']), + helper.make_tensor_value_info('Y2', TensorProto.FLOAT, ['M', 'K']), + helper.make_tensor_value_info('Y3', TensorProto.FLOAT, ['M', 'K']), + helper.make_tensor_value_info('Y4', TensorProto.FLOAT, ['M', 'K']), + helper.make_tensor_value_info('Y5', TensorProto.FLOAT, ['M', 'K']), + ], + initializers) + + model = helper.make_model(graph, **kwargs) + onnx.save(model, model_name) + +if __name__ == "__main__": + GenerateModel('not_where.onnx') \ No newline at end of file diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 535a2eca74..82a6c52925 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -31,6 +31,7 @@ #include "core/optimizer/matmul_scale_fusion.h" #include "core/optimizer/matmul_transpose_fusion.h" #include "core/optimizer/nchwc_transformer.h" +#include "core/optimizer/not_where_fusion.h" #include "core/optimizer/relu_clip_fusion.h" #include "core/optimizer/reshape_fusion.h" #include "core/optimizer/rule_based_graph_transformer.h" @@ -76,6 +77,7 @@ std::vector> GeneratePreTrainingTransformers( rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique()); + rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique());