diff --git a/include/onnxruntime/core/optimizer/rewrite_rule.h b/include/onnxruntime/core/optimizer/rewrite_rule.h index 4401fb0f77..f481439e5f 100644 --- a/include/onnxruntime/core/optimizer/rewrite_rule.h +++ b/include/onnxruntime/core/optimizer/rewrite_rule.h @@ -35,6 +35,18 @@ If the list of op types is left empty, that rule will be triggered for every op */ class RewriteRule { public: + /** + @class RewriteRuleEffect + + Class used to indicate the effect of rule application on a graph's node. + */ + enum class RewriteRuleEffect : uint8_t { + kNone, // The rewrite rule has not modified the graph. + kUpdatedCurrentNode, // The rewrite rule updated (but did not remove) the node on which it was triggered. + kRemovedCurrentNode, // The rewrite rule removed the node on which it was triggered. + kModifiedRestOfGraph // The rewrite rule modified nodes other than the one it was triggered on. + }; + RewriteRule(const std::string& name) : name_(name) {} virtual ~RewriteRule() = default; @@ -52,11 +64,10 @@ class RewriteRule { /** Checks if the condition of the rule is satisfied, and if so applies the body of the rule. @param[in] graph The Graph. @param[in] node The Node to apply the rewrite to. - @param[out] modified Set to indicate whether the node was modified or not. - @param[out] deleted Set to indicate if the node was deleted. + @param[out] rule_effect Enum to indicate if and how the graph was modified as a result of the rule application. @returns Status indicating success or providing error information */ - common::Status CheckConditionAndApply(Graph& graph, Node& node, bool& modified, bool& deleted) { - return SatisfyCondition(graph, node) ? Apply(graph, node, modified, deleted) : Status::OK(); + common::Status CheckConditionAndApply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) { + return SatisfyCondition(graph, node) ? Apply(graph, node, rule_effect) : Status::OK(); } private: @@ -72,8 +83,7 @@ class RewriteRule { /** This is the actual body of the rule that performs the graph transformation. The transformation happens in-place. The return-value of node may be different from the input-value due to rewriting. - The value of "modified" indicates if the graph was modified or not. - The value of "deleted" indicates if the node was deleted or not. */ - virtual common::Status Apply(Graph& graph, Node& node, bool& modified, bool& deleted) = 0; + The value of "rule_effect" indicates whether and how the graph was modified by the rule. */ + virtual common::Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) = 0; }; } // namespace onnxruntime diff --git a/include/onnxruntime/core/optimizer/rule_based_graph_transformer.h b/include/onnxruntime/core/optimizer/rule_based_graph_transformer.h index b20abadec3..4a3fe8159a 100644 --- a/include/onnxruntime/core/optimizer/rule_based_graph_transformer.h +++ b/include/onnxruntime/core/optimizer/rule_based_graph_transformer.h @@ -58,14 +58,16 @@ class RuleBasedGraphTransformer : public GraphTransformer { @param[in] graph The Graph. @param[in] node The Node to apply the rules to. @param[in] rules The vector of RewriteRules that will be applied to the Node. - @param[out] modified Set to indicate whether the node was modified or not. - @param[out] deleted Set to indicate if the node was deleted. + @param[out] rule_effect Enum that indicates whether and how the graph was modified as a result of + applying rules on this node. @returns Status indicating success or providing error information. */ common::Status ApplyRulesOnNode(Graph& graph, Node& node, const std::vector>& rules, - bool& modified, bool& deleted) const; + RewriteRule::RewriteRuleEffect& rule_effect) const; private: + using RuleEffect = RewriteRule::RewriteRuleEffect; + // Map that associates a node's op type with the vector of rules that are registered to be triggered for that node. std::unordered_map>> op_type_to_rules_; // Rules that will be evaluated regardless of the op type of the node. diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index 2a01bdacb4..02e6b72567 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -353,19 +353,24 @@ const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const s return iter == attrs.end() ? nullptr : &iter->second; } -bool RemoveSingleInputNode(Graph& graph, Node& node) { - // Cannot remove a node with multiple output NodeArgs (multiple output edges is fine), neither - // a node whose output is also a graph output. - if (!IsSingleInSingleOutNode(node) || +bool RemoveNode(Graph& graph, Node& node) { + // Cannot remove a node with implicit inputs, with multiple output NodeArgs (multiple output edges is fine), + // or whose output is also a graph output. + if (node.ImplicitInputDefs().size() > 0 || + node.OutputDefs().size() != 1 || graph.IsNodeOutputsInGraphOutputs(node)) { return false; } - // If the single input comes from another node (initializers are not connected with edges to nodes). if (node.GetInputEdgesCount() == 1) { + // If there is a single input edge from another node (initializers are not connected with edges to nodes). return RemoveNodeWithSingleNodeIn(graph, node); - } else { + } else if (node.InputDefs().size() == 1) { + // If a single initializer is the only input. return RemoveNodeWithSingleInitializerIn(graph, node); + } else { + // No other node removal is supported, because there will be no way to connect its inputs to its outputs. + return false; } } diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h index 1f3593b013..d2976f187b 100644 --- a/onnxruntime/core/graph/graph_utils.h +++ b/onnxruntime/core/graph/graph_utils.h @@ -64,10 +64,14 @@ bool GetRepeatedNodeAttributeValues(const Node& node, Status ForAllMutableSubgraphs(Graph& main_graph, std::function func); Status ForAllSubgraphs(const Graph& main_graph, std::function func); -/** Removes the given single-input Node from the Graph. The single input might be either - another node or an initializer, but not an implicit input. The node should have a single - output but can have multiple output edges. */ -bool RemoveSingleInputNode(Graph& graph, Node& node); +/** Removes the given Node from the Graph and keeps Graph consistent by rebuilding needed connections. + We support the removal of the Node if it has no implicit inputs and a single output (but it can have multiple + output edges). As for the Node's inputs, we support the following cases: + - If the Node has a single incoming node (and possibly multiple initializers), we can remove the Node and + connect its incoming node to its outgoing nodes. + - If the Node has a single initializer as input, we remove the Node and feed the initializer as input to its + output nodes. */ +bool RemoveNode(Graph& graph, Node& node); /** Removes all output edges from the given Node of the Graph. This should probably be elevated to the Graph API eventually. */ diff --git a/onnxruntime/core/optimizer/conv_add_fusion.cc b/onnxruntime/core/optimizer/conv_add_fusion.cc index 6bfa4af65f..1cb3958f63 100644 --- a/onnxruntime/core/optimizer/conv_add_fusion.cc +++ b/onnxruntime/core/optimizer/conv_add_fusion.cc @@ -9,136 +9,108 @@ using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; namespace onnxruntime { -Status ConvAddFusion::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level) const { - std::vector removed_nodes; - for (auto& node : graph.Nodes()) { - ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level)); +Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modified) { + auto& conv_node = node; + const auto& add_node = *conv_node.OutputNodesBegin(); + const auto& conv_inputs = conv_node.InputDefs(); + const auto& add_inputs = add_node.InputDefs(); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", 1) || - !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || - node.GetOutputEdgesCount() != 1) { - continue; - } + const ONNX_NAMESPACE::TensorProto* conv_W_tensor_proto = nullptr; + graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto); - const Node& next_node = *node.OutputNodesBegin(); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", 7) || - next_node.GetExecutionProviderType() != node.GetExecutionProviderType() || - next_node.GetInputEdgesCount() != 1 || - graph.IsNodeOutputsInGraphOutputs(next_node)) { - continue; - } + const ONNX_NAMESPACE::TensorProto* add_B_tensor_proto = nullptr; + graph.GetInitializedTensor(add_inputs[1]->Name(), add_B_tensor_proto); - auto& conv_node = node; - const Node& add_node = next_node; - - const auto& conv_inputs = conv_node.InputDefs(); - const auto& add_inputs = add_node.InputDefs(); - - const ONNX_NAMESPACE::TensorProto* conv_W_tensor_proto = nullptr; - graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto); - - const ONNX_NAMESPACE::TensorProto* add_B_tensor_proto = nullptr; - graph.GetInitializedTensor(add_inputs[1]->Name(), add_B_tensor_proto); - - // Currently, fusion is only supported for float or double data type. - if (!Initializer::IsSupportedDataType(add_B_tensor_proto) || - conv_W_tensor_proto->dims_size() < 4 || - add_B_tensor_proto->dims_size() != conv_W_tensor_proto->dims_size() - 1 || - conv_W_tensor_proto->dims(0) != add_B_tensor_proto->dims(0)) { - continue; - } - - // The dimensions of add_B should be equal to 1 except first dimension. - bool flag = false; - for (int i = 1; i < add_B_tensor_proto->dims_size(); i++) { - if (add_B_tensor_proto->dims(i) != 1) { - flag = true; - break; - } - } - - if (flag) { - continue; - } - - const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr; - if (conv_inputs.size() == 3) { - graph.GetInitializedTensor(conv_inputs[2]->Name(), conv_B_tensor_proto); - - if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) || - conv_B_tensor_proto->data_type() != add_B_tensor_proto->data_type() || - conv_B_tensor_proto->dims_size() != 1 || - conv_B_tensor_proto->dims(0) != add_B_tensor_proto->dims(0)) { - continue; - } - - auto conv_B = std::make_unique(conv_B_tensor_proto); - auto add_B = std::make_unique(add_B_tensor_proto); - - if (conv_B->size() != add_B->size()) { - continue; - } - // Calculate new value of initializers of conv node - conv_B->add(*add_B); - - // Create new initializers of conv - ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto; - conv_B->ToProto(&new_conv_B_tensor_proto); - - // Replace initializers of conv node - graph.RemoveInitializedTensor(conv_inputs[2]->Name()); - graph.AddInitializedTensor(new_conv_B_tensor_proto); - } else { - NodeArg* add_B_node_arg = graph.GetNodeArg(add_B_tensor_proto->name()); - if (add_B_node_arg == nullptr) { - continue; - } - - // Update shape of tensor proto - ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto(*add_B_tensor_proto); - int64_t dim = conv_W_tensor_proto->dims(0); - new_conv_B_tensor_proto.clear_dims(); - new_conv_B_tensor_proto.add_dims(dim); - - graph.RemoveInitializedTensor(add_B_tensor_proto->name()); - graph.AddInitializedTensor(new_conv_B_tensor_proto); - - // Update shape of NodeArg - TensorShapeProto shape; - shape.add_dim()->set_dim_value(dim); - add_B_node_arg->SetShape(shape); - - conv_node.MutableInputDefs().push_back(add_B_node_arg); - conv_node.MutableInputArgsCount()[2] = 1; - } - - // Replace the input of the node following add node - const NodeArg* add_output_def = add_node.OutputDefs()[0]; - NodeArg* conv_output_def = conv_node.MutableOutputDefs()[0]; - for (auto it = add_node.OutputNodesBegin(); it != add_node.OutputNodesEnd(); ++it) { - auto output_node = graph.GetNode((*it).Index()); - if (!output_node) { - return Status(ONNXRUNTIME, INVALID_ARGUMENT); - } - auto& input_defs = output_node->MutableInputDefs(); - for (auto& def : input_defs) { - if (def == add_output_def) { - def = conv_output_def; - } - } - } - - removed_nodes.push_back(add_node.Index()); + // Currently, fusion is only supported for float or double data type. + if (!Initializer::IsSupportedDataType(add_B_tensor_proto) || + conv_W_tensor_proto->dims_size() < 4 || + add_B_tensor_proto->dims_size() != conv_W_tensor_proto->dims_size() - 1 || + conv_W_tensor_proto->dims(0) != add_B_tensor_proto->dims(0)) { + return Status::OK(); } - for (auto i : removed_nodes) { - graph.RemoveNode(i); + // The dimensions of add_B should be equal to 1 except first dimension. + for (int i = 1; i < add_B_tensor_proto->dims_size(); i++) { + if (add_B_tensor_proto->dims(i) != 1) { + return Status::OK(); + } } - if (!removed_nodes.empty()) { - modified = true; + const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr; + if (conv_inputs.size() == 3) { + graph.GetInitializedTensor(conv_inputs[2]->Name(), conv_B_tensor_proto); + + if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) || + conv_B_tensor_proto->data_type() != add_B_tensor_proto->data_type() || + conv_B_tensor_proto->dims_size() != 1 || + conv_B_tensor_proto->dims(0) != add_B_tensor_proto->dims(0)) { + return Status::OK(); + } + + auto conv_B = std::make_unique(conv_B_tensor_proto); + auto add_B = std::make_unique(add_B_tensor_proto); + + if (conv_B->size() != add_B->size()) { + return Status::OK(); + } + // Calculate new value of initializers of conv node + conv_B->add(*add_B); + + // Create new initializers of conv + ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto; + conv_B->ToProto(&new_conv_B_tensor_proto); + + // Replace initializers of conv node + graph.RemoveInitializedTensor(conv_inputs[2]->Name()); + graph.AddInitializedTensor(new_conv_B_tensor_proto); + } else { + NodeArg* add_B_node_arg = graph.GetNodeArg(add_B_tensor_proto->name()); + if (add_B_node_arg == nullptr) { + return Status::OK(); + } + + // Update shape of tensor proto + ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto(*add_B_tensor_proto); + int64_t dim = conv_W_tensor_proto->dims(0); + new_conv_B_tensor_proto.clear_dims(); + new_conv_B_tensor_proto.add_dims(dim); + + graph.RemoveInitializedTensor(add_B_tensor_proto->name()); + graph.AddInitializedTensor(new_conv_B_tensor_proto); + + // Update shape of NodeArg + TensorShapeProto shape; + shape.add_dim()->set_dim_value(dim); + add_B_node_arg->SetShape(shape); + + conv_node.MutableInputDefs().push_back(add_B_node_arg); + conv_node.MutableInputArgsCount()[2] = 1; + } + + // Remove Add node. + auto* add_node_to_remove = graph.GetNode(add_node.Index()); + if (graph_utils::RemoveNode(graph, *add_node_to_remove)) { + modified = RewriteRuleEffect::kModifiedRestOfGraph; } return Status::OK(); } + +bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node) { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", 1) || + node.GetOutputEdgesCount() != 1) { + return false; + } + + const auto& next_node = *node.OutputNodesBegin(); + if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", 7) || + next_node.GetExecutionProviderType() != node.GetExecutionProviderType() || + next_node.GetInputEdgesCount() != 1 || + graph.IsNodeOutputsInGraphOutputs(next_node)) { + return false; + } + + return true; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/conv_add_fusion.h b/onnxruntime/core/optimizer/conv_add_fusion.h index 717e013cac..3fe4e92b5a 100644 --- a/onnxruntime/core/optimizer/conv_add_fusion.h +++ b/onnxruntime/core/optimizer/conv_add_fusion.h @@ -3,16 +3,29 @@ #pragma once -#include "core/optimizer/graph_transformer.h" +#include "core/optimizer/rewrite_rule.h" namespace onnxruntime { -class ConvAddFusion : public onnxruntime::GraphTransformer { +/** +@Class ConvAddFusion + +Rewrite rule that fuses two Conv+Add nodes to a single Conv node. + +It is attempted to be triggered only on nodes with op type "Conv". +*/ +class ConvAddFusion : public RewriteRule { public: - ConvAddFusion() noexcept : onnxruntime::GraphTransformer("ConvAddFusion") {} + ConvAddFusion() noexcept : RewriteRule("ConvAddFusion") {} + + std::vector TargetOpTypes() const noexcept override { + return {"Conv"}; + } private: - Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override; + bool SatisfyCondition(const Graph& graph, const Node& node) override; + + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/conv_bn_fusion.cc b/onnxruntime/core/optimizer/conv_bn_fusion.cc index a64706ec9f..6f297ed2b0 100644 --- a/onnxruntime/core/optimizer/conv_bn_fusion.cc +++ b/onnxruntime/core/optimizer/conv_bn_fusion.cc @@ -9,184 +9,154 @@ using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; namespace onnxruntime { -Status ConvBNFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const { - std::vector removed_nodes; - for (auto& node : graph.Nodes()) { - ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level)); +Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) { + auto& conv_node = node; + const Node& bn_node = *conv_node.OutputNodesBegin(); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", 1) || - !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || - node.GetOutputEdgesCount() != 1) { - continue; + // Get value of attribute epsilon + const onnxruntime::NodeAttributes& attributes = bn_node.GetAttributes(); + const ONNX_NAMESPACE::AttributeProto* attr = &(attributes.find("epsilon")->second); + if (attr == nullptr || attr->type() != AttributeProto_AttributeType_FLOAT) { + return Status::OK(); + } + float epsilon = static_cast(attr->f()); + + // Get initializers of BatchNormalization + const auto& bn_inputs = bn_node.InputDefs(); + const ONNX_NAMESPACE::TensorProto* bn_scale_tensor_proto = nullptr; + graph.GetInitializedTensor(bn_inputs[1]->Name(), bn_scale_tensor_proto); + + const ONNX_NAMESPACE::TensorProto* bn_B_tensor_proto = nullptr; + graph.GetInitializedTensor(bn_inputs[2]->Name(), bn_B_tensor_proto); + + const ONNX_NAMESPACE::TensorProto* bn_mean_tensor_proto = nullptr; + graph.GetInitializedTensor(bn_inputs[3]->Name(), bn_mean_tensor_proto); + + const ONNX_NAMESPACE::TensorProto* bn_var_tensor_proto = nullptr; + graph.GetInitializedTensor(bn_inputs[4]->Name(), bn_var_tensor_proto); + + const auto& conv_inputs = conv_node.InputDefs(); + const ONNX_NAMESPACE::TensorProto* conv_W_tensor_proto = nullptr; + graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto); + + // Currently, fusion is only supported for float or double data type. + if (!Initializer::IsSupportedDataType(bn_scale_tensor_proto) || + !Initializer::IsSupportedDataType(bn_B_tensor_proto) || + !Initializer::IsSupportedDataType(bn_mean_tensor_proto) || + !Initializer::IsSupportedDataType(bn_var_tensor_proto) || + !Initializer::IsSupportedDataType(conv_W_tensor_proto) || + bn_scale_tensor_proto->dims_size() != 1 || + bn_B_tensor_proto->dims_size() != 1 || + bn_mean_tensor_proto->dims_size() != 1 || + bn_var_tensor_proto->dims_size() != 1 || + bn_scale_tensor_proto->dims(0) != bn_B_tensor_proto->dims(0) || + bn_B_tensor_proto->dims(0) != bn_mean_tensor_proto->dims(0) || + bn_mean_tensor_proto->dims(0) != bn_var_tensor_proto->dims(0) || + bn_scale_tensor_proto->data_type() != bn_B_tensor_proto->data_type() || + bn_B_tensor_proto->data_type() != bn_mean_tensor_proto->data_type() || + bn_mean_tensor_proto->data_type() != bn_var_tensor_proto->data_type() || + conv_W_tensor_proto->data_type() != bn_scale_tensor_proto->data_type() || + !(conv_W_tensor_proto->dims_size() > 2 && conv_W_tensor_proto->dims(0) == bn_scale_tensor_proto->dims(0))) { + return Status::OK(); + } + + auto bn_scale = std::make_unique(bn_scale_tensor_proto); + auto bn_B = std::make_unique(bn_B_tensor_proto); + auto bn_mean = std::make_unique(bn_mean_tensor_proto); + auto bn_var = std::make_unique(bn_var_tensor_proto); + auto conv_W = std::make_unique(conv_W_tensor_proto); + + const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr; + std::unique_ptr conv_B = nullptr; + if (conv_inputs.size() == 3) { + if (!graph.GetInitializedTensor(conv_inputs[2]->Name(), conv_B_tensor_proto)) { + return Status::OK(); } - const Node& next_node = *node.OutputNodesBegin(); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "BatchNormalization", 7) || - next_node.GetInputEdgesCount() != 1 || - graph.IsNodeOutputsInGraphOutputs(next_node) || - next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { - continue; + if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) || + conv_B_tensor_proto->dims_size() != 1 || + conv_B_tensor_proto->dims(0) != bn_B_tensor_proto->dims(0) || + conv_B_tensor_proto->data_type() != bn_B_tensor_proto->data_type()) { + return Status::OK(); } + conv_B = std::make_unique(conv_B_tensor_proto); + } - auto& conv_node = node; - const Node& bn_node = next_node; + // Calculate new value of initializers of conv node + bn_var->add(epsilon); + bn_var->sqrt(); + bn_scale->div(*bn_var); + conv_W->scale_by_axis(*bn_scale, 1); - // Get value of attribute group - const onnxruntime::NodeAttributes& conv_attributes = conv_node.GetAttributes(); - const ONNX_NAMESPACE::AttributeProto* group_attr = &(conv_attributes.find("group")->second); - if (group_attr != nullptr && - group_attr->type() == AttributeProto_AttributeType_INT && - group_attr->has_i() && group_attr->i() != 1) { - continue; + if (conv_inputs.size() == 3) { + conv_B->sub(*bn_mean); + conv_B->mul(*bn_scale); + conv_B->add(*bn_B); + } else { + bn_mean->mul(*bn_scale); + bn_B->sub(*bn_mean); + } + + // Create new initializers of conv + ONNX_NAMESPACE::TensorProto new_conv_W_tensor_proto(*conv_W_tensor_proto); + conv_W->ToProto(&new_conv_W_tensor_proto); + + ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto; + NodeArg* bn_B_node_arg = nullptr; + if (conv_inputs.size() == 3) { + conv_B->ToProto(&new_conv_B_tensor_proto); + } else { + bn_B->ToProto(&new_conv_B_tensor_proto); + bn_B_node_arg = graph.GetNodeArg(bn_B_tensor_proto->name()); + if (bn_B_node_arg == nullptr) { + return Status::OK(); } + } - // Get value of attribute epsilon - const onnxruntime::NodeAttributes& attributes = bn_node.GetAttributes(); - const ONNX_NAMESPACE::AttributeProto* attr = &(attributes.find("epsilon")->second); - if (attr == nullptr || attr->type() != AttributeProto_AttributeType_FLOAT) { - continue; - } - float epsilon = static_cast(attr->f()); - - // Get initializers of BatchNormalization - const auto& bn_inputs = bn_node.InputDefs(); - const ONNX_NAMESPACE::TensorProto* bn_scale_tensor_proto = nullptr; - graph.GetInitializedTensor(bn_inputs[1]->Name(), bn_scale_tensor_proto); - - const ONNX_NAMESPACE::TensorProto* bn_B_tensor_proto = nullptr; - graph.GetInitializedTensor(bn_inputs[2]->Name(), bn_B_tensor_proto); - - const ONNX_NAMESPACE::TensorProto* bn_mean_tensor_proto = nullptr; - graph.GetInitializedTensor(bn_inputs[3]->Name(), bn_mean_tensor_proto); - - const ONNX_NAMESPACE::TensorProto* bn_var_tensor_proto = nullptr; - graph.GetInitializedTensor(bn_inputs[4]->Name(), bn_var_tensor_proto); - - const auto& conv_inputs = conv_node.InputDefs(); - const ONNX_NAMESPACE::TensorProto* conv_W_tensor_proto = nullptr; - graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto); - - // Currently, fusion is only supported for float or double data type. - if (!Initializer::IsSupportedDataType(bn_scale_tensor_proto) || - !Initializer::IsSupportedDataType(bn_B_tensor_proto) || - !Initializer::IsSupportedDataType(bn_mean_tensor_proto) || - !Initializer::IsSupportedDataType(bn_var_tensor_proto) || - !Initializer::IsSupportedDataType(conv_W_tensor_proto) || - bn_scale_tensor_proto->dims_size() != 1 || - bn_B_tensor_proto->dims_size() != 1 || - bn_mean_tensor_proto->dims_size() != 1 || - bn_var_tensor_proto->dims_size() != 1 || - bn_scale_tensor_proto->dims(0) != bn_B_tensor_proto->dims(0) || - bn_B_tensor_proto->dims(0) != bn_mean_tensor_proto->dims(0) || - bn_mean_tensor_proto->dims(0) != bn_var_tensor_proto->dims(0) || - bn_scale_tensor_proto->data_type() != bn_B_tensor_proto->data_type() || - bn_B_tensor_proto->data_type() != bn_mean_tensor_proto->data_type() || - bn_mean_tensor_proto->data_type() != bn_var_tensor_proto->data_type() || - conv_W_tensor_proto->data_type() != bn_scale_tensor_proto->data_type() || - !(conv_W_tensor_proto->dims_size() > 2 && conv_W_tensor_proto->dims(0) == bn_scale_tensor_proto->dims(0))) { - continue; - } - - auto bn_scale = std::make_unique(bn_scale_tensor_proto); - auto bn_B = std::make_unique(bn_B_tensor_proto); - auto bn_mean = std::make_unique(bn_mean_tensor_proto); - auto bn_var = std::make_unique(bn_var_tensor_proto); - auto conv_W = std::make_unique(conv_W_tensor_proto); - - const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr; - std::unique_ptr conv_B = nullptr; - if (conv_inputs.size() == 3) { - if (!graph.GetInitializedTensor(conv_inputs[2]->Name(), conv_B_tensor_proto)) - continue; - - if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) || - conv_B_tensor_proto->dims_size() != 1 || - conv_B_tensor_proto->dims(0) != bn_B_tensor_proto->dims(0) || - conv_B_tensor_proto->data_type() != bn_B_tensor_proto->data_type()) { - continue; - } - conv_B = std::make_unique(conv_B_tensor_proto); - } - - // Calculate new value of initializers of conv node - bn_var->add(epsilon); - bn_var->sqrt(); - bn_scale->div(*bn_var); - conv_W->scale_by_axis(*bn_scale, 1); - - if (conv_inputs.size() == 3) { - conv_B->sub(*bn_mean); - conv_B->mul(*bn_scale); - conv_B->add(*bn_B); - } else { - bn_mean->mul(*bn_scale); - bn_B->sub(*bn_mean); - } - - // Create new initializers of conv - ONNX_NAMESPACE::TensorProto new_conv_W_tensor_proto(*conv_W_tensor_proto); - conv_W->ToProto(&new_conv_W_tensor_proto); - - ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto; - NodeArg* bn_B_node_arg = nullptr; - if (conv_inputs.size() == 3) { - conv_B->ToProto(&new_conv_B_tensor_proto); - } else { - bn_B->ToProto(&new_conv_B_tensor_proto); - bn_B_node_arg = graph.GetNodeArg(bn_B_tensor_proto->name()); - if (bn_B_node_arg == nullptr) { - continue; - } - } - - // Replace initializers of conv node - graph.RemoveInitializedTensor(conv_W_tensor_proto->name()); - if (conv_inputs.size() == 3) { + // Replace initializers of conv node + graph.RemoveInitializedTensor(conv_W_tensor_proto->name()); + if (conv_inputs.size() == 3) { #ifdef _MSC_VER #pragma warning(push) #pragma warning(disable : 6011) // Not deferencing null pointer. conv_B_tensor_proto is set on line 93 #endif - graph.RemoveInitializedTensor(conv_B_tensor_proto->name()); + graph.RemoveInitializedTensor(conv_B_tensor_proto->name()); #ifdef _MSC_VER #pragma warning(pop) #endif - } else { - graph.RemoveInitializedTensor(bn_B_tensor_proto->name()); - conv_node.MutableInputDefs().push_back(bn_B_node_arg); - conv_node.MutableInputArgsCount()[2] = 1; - } - graph.AddInitializedTensor(new_conv_W_tensor_proto); - graph.AddInitializedTensor(new_conv_B_tensor_proto); - - // Replace the input of the nodes following batch normalization node - const NodeArg* bn_output_def = bn_node.OutputDefs()[0]; - NodeArg* conv_output_def = conv_node.MutableOutputDefs()[0]; - for (auto it = bn_node.OutputNodesBegin(); it != bn_node.OutputNodesEnd(); ++it) { - auto output_node = graph.GetNode((*it).Index()); - if (!output_node) { - return Status(ONNXRUNTIME, INVALID_ARGUMENT); - } - - auto& input_defs = output_node->MutableInputDefs(); - for (auto& def : input_defs) { - if (def == bn_output_def) { - def = conv_output_def; - } - } - } - removed_nodes.push_back(bn_node.Index()); + } else { + graph.RemoveInitializedTensor(bn_B_tensor_proto->name()); + conv_node.MutableInputDefs().push_back(bn_B_node_arg); + conv_node.MutableInputArgsCount()[2] = 1; } + graph.AddInitializedTensor(new_conv_W_tensor_proto); + graph.AddInitializedTensor(new_conv_B_tensor_proto); - for (auto i : removed_nodes) { - graph.RemoveNode(i); - } - - if (!removed_nodes.empty()) { - modified = true; + // Remove BN node. + auto* bn_node_to_remove = graph.GetNode(bn_node.Index()); + if (graph_utils::RemoveNode(graph, *bn_node_to_remove)) { + rule_effect = RewriteRuleEffect::kModifiedRestOfGraph; } return Status::OK(); } +bool ConvBNFusion::SatisfyCondition(const Graph& graph, const Node& node) { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", 1) || + node.GetOutputEdgesCount() != 1) { + return false; + } + + const auto& next_node = *node.OutputNodesBegin(); + if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "BatchNormalization", 7) || + next_node.GetInputEdgesCount() != 1 || + graph.IsNodeOutputsInGraphOutputs(next_node) || + next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { + return false; + } + + return true; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/conv_bn_fusion.h b/onnxruntime/core/optimizer/conv_bn_fusion.h index b605956f1c..e23095bfdf 100644 --- a/onnxruntime/core/optimizer/conv_bn_fusion.h +++ b/onnxruntime/core/optimizer/conv_bn_fusion.h @@ -3,15 +3,29 @@ #pragma once -#include "core/optimizer/graph_transformer.h" +#include "core/optimizer/rewrite_rule.h" namespace onnxruntime { -class ConvBNFusion : public onnxruntime::GraphTransformer { +/** +@Class ConvBNFusion + +Rewrite rule that fuses two Conv+BN nodes to a single Conv node. + +It is attempted to be triggered only on nodes with op type "Conv". +*/ +class ConvBNFusion : public RewriteRule { public: - ConvBNFusion() noexcept : onnxruntime::GraphTransformer("ConvBNFusion") {} + ConvBNFusion() noexcept : RewriteRule("ConvBNFusion") {} + + std::vector TargetOpTypes() const noexcept override { + return {"Conv"}; + } private: - Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level) const override; + bool SatisfyCondition(const Graph& graph, const Node& node) override; + + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override; }; + } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/conv_mul_fusion.cc b/onnxruntime/core/optimizer/conv_mul_fusion.cc index dc67fd90e8..fcc2515163 100644 --- a/onnxruntime/core/optimizer/conv_mul_fusion.cc +++ b/onnxruntime/core/optimizer/conv_mul_fusion.cc @@ -9,135 +9,107 @@ using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; namespace onnxruntime { -Status ConvMulFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const { - std::vector removed_nodes; - for (auto& node : graph.Nodes()) { - ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level)); - - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", 1) || - !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || - node.GetOutputEdgesCount() != 1) { - continue; +Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) { + auto& conv_node = node; + const auto& mul_node = *conv_node.OutputNodesBegin(); + const auto& conv_inputs = conv_node.InputDefs(); + const auto& mul_inputs = mul_node.InputDefs(); + + const ONNX_NAMESPACE::TensorProto* conv_W_tensor_proto = nullptr; + graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto); + + const ONNX_NAMESPACE::TensorProto* mul_B_tensor_proto = nullptr; + graph.GetInitializedTensor(mul_inputs[1]->Name(), mul_B_tensor_proto); + + if (!Initializer::IsSupportedDataType(conv_W_tensor_proto) || + !Initializer::IsSupportedDataType(mul_B_tensor_proto) || + conv_W_tensor_proto->data_type() != mul_B_tensor_proto->data_type() || + conv_W_tensor_proto->dims_size() < 4 || + !(mul_B_tensor_proto->dims_size() == 0 || + (mul_B_tensor_proto->dims_size() == conv_W_tensor_proto->dims_size() - 1 && + conv_W_tensor_proto->dims(0) == mul_B_tensor_proto->dims(0)))) { + return Status::OK(); + } + + // The dimensions of mul_B should be equal to 1 except first dimension. + if (mul_B_tensor_proto->dims_size() != 0) { + for (int i = 1; i < mul_B_tensor_proto->dims_size(); i++) { + if (mul_B_tensor_proto->dims(i) != 1) { + return Status::OK(); + } } + } - const Node& next_node = *node.OutputNodesBegin(); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", 7) || - next_node.GetInputEdgesCount() != 1 || - graph.IsNodeOutputsInGraphOutputs(next_node) || - next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { - continue; + auto conv_W = std::make_unique(conv_W_tensor_proto); + auto mul_B = std::make_unique(mul_B_tensor_proto); + + const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr; + std::unique_ptr conv_B = nullptr; + const bool is_3d = conv_inputs.size() == 3; + if (is_3d) { + if (!graph.GetInitializedTensor(conv_inputs[2]->Name(), conv_B_tensor_proto)) + return Status::OK(); + if (conv_B_tensor_proto == nullptr) + return Status(ONNXRUNTIME, FAIL, "Internal error in ConvMulFusion. conv_B_tensor_proto is NULL"); + if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) || + conv_B_tensor_proto->data_type() != mul_B_tensor_proto->data_type() || + conv_B_tensor_proto->dims_size() != 1 || + (mul_B_tensor_proto->dims_size() != 0 && conv_B_tensor_proto->dims(0) != mul_B_tensor_proto->dims(0))) { + return Status::OK(); } + conv_B = std::make_unique(conv_B_tensor_proto); + } - auto& conv_node = node; - const Node& mul_node = next_node; + // Calculate new value of initializers of conv node + conv_W->scale_by_axis(*mul_B, 1); - const auto& conv_inputs = conv_node.InputDefs(); - const auto& mul_inputs = mul_node.InputDefs(); - - const ONNX_NAMESPACE::TensorProto* conv_W_tensor_proto = nullptr; - graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto); - - const ONNX_NAMESPACE::TensorProto* mul_B_tensor_proto = nullptr; - graph.GetInitializedTensor(mul_inputs[1]->Name(), mul_B_tensor_proto); - - if (!Initializer::IsSupportedDataType(conv_W_tensor_proto) || - !Initializer::IsSupportedDataType(mul_B_tensor_proto) || - conv_W_tensor_proto->data_type() != mul_B_tensor_proto->data_type() || - conv_W_tensor_proto->dims_size() < 4 || - !(mul_B_tensor_proto->dims_size() == 0 || - (mul_B_tensor_proto->dims_size() == conv_W_tensor_proto->dims_size() - 1 && - conv_W_tensor_proto->dims(0) == mul_B_tensor_proto->dims(0)))) { - continue; - } - - // The dimensions of mul_B should be equal to 1 except first dimension. + if (conv_inputs.size() == 3) { if (mul_B_tensor_proto->dims_size() != 0) { - bool flag = false; - for (int i = 1; i < mul_B_tensor_proto->dims_size(); i++) { - if (mul_B_tensor_proto->dims(i) != 1) { - flag = true; - break; - } - } - - if (flag) { - continue; - } + conv_B->mul(*mul_B); + } else { + conv_B->scale_by_axis(*mul_B, 0); } - auto conv_W = std::make_unique(conv_W_tensor_proto); - auto mul_B = std::make_unique(mul_B_tensor_proto); - - const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr; - std::unique_ptr conv_B = nullptr; - const bool is_3d = conv_inputs.size() == 3; - if (is_3d) { - if (!graph.GetInitializedTensor(conv_inputs[2]->Name(), conv_B_tensor_proto)) - continue; - if (conv_B_tensor_proto == nullptr) - return Status(ONNXRUNTIME, FAIL, "Internal error in ConvMulFusion. conv_B_tensor_proto is NULL"); - if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) || - conv_B_tensor_proto->data_type() != mul_B_tensor_proto->data_type() || - conv_B_tensor_proto->dims_size() != 1 || (mul_B_tensor_proto->dims_size() != 0 && conv_B_tensor_proto->dims(0) != mul_B_tensor_proto->dims(0))) { - continue; - } - conv_B = std::make_unique(conv_B_tensor_proto); - } - - // Calculate new value of initializers of conv node - conv_W->scale_by_axis(*mul_B, 1); - - if (conv_inputs.size() == 3) { - if (mul_B_tensor_proto->dims_size() != 0) { - conv_B->mul(*mul_B); - } else { - conv_B->scale_by_axis(*mul_B, 0); - } - } - - // Create new initializers of conv - ONNX_NAMESPACE::TensorProto new_conv_W_tensor_proto(*conv_W_tensor_proto); - conv_W->ToProto(&new_conv_W_tensor_proto); - - // Replace initializers of conv node - graph.RemoveInitializedTensor(conv_inputs[1]->Name()); - graph.AddInitializedTensor(new_conv_W_tensor_proto); - - if (is_3d) { - ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto(*conv_B_tensor_proto); - conv_B->ToProto(&new_conv_B_tensor_proto); - graph.RemoveInitializedTensor(conv_inputs[2]->Name()); - graph.AddInitializedTensor(new_conv_B_tensor_proto); - } - - // Replace the input of the node following mul node - const NodeArg* mul_output_def = mul_node.OutputDefs()[0]; - NodeArg* conv_output_def = conv_node.MutableOutputDefs()[0]; - for (auto it = mul_node.OutputNodesBegin(); it != mul_node.OutputNodesEnd(); ++it) { - auto output_node = graph.GetNode((*it).Index()); - if (!output_node) { - return Status(ONNXRUNTIME, INVALID_ARGUMENT); - } - - auto& input_defs = output_node->MutableInputDefs(); - for (auto& def : input_defs) { - if (def == mul_output_def) { - def = conv_output_def; - } - } - } - - removed_nodes.push_back(mul_node.Index()); } - for (auto i : removed_nodes) { - graph.RemoveNode(i); + // Create new initializers of conv + ONNX_NAMESPACE::TensorProto new_conv_W_tensor_proto(*conv_W_tensor_proto); + conv_W->ToProto(&new_conv_W_tensor_proto); + + // Replace initializers of conv node + graph.RemoveInitializedTensor(conv_inputs[1]->Name()); + graph.AddInitializedTensor(new_conv_W_tensor_proto); + + if (is_3d) { + ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto(*conv_B_tensor_proto); + conv_B->ToProto(&new_conv_B_tensor_proto); + graph.RemoveInitializedTensor(conv_inputs[2]->Name()); + graph.AddInitializedTensor(new_conv_B_tensor_proto); } - if (!removed_nodes.empty()) { - modified = true; + // Remove Mul node. + auto* mul_node_to_remove = graph.GetNode(mul_node.Index()); + if (graph_utils::RemoveNode(graph, *mul_node_to_remove)) { + rule_effect = RewriteRuleEffect::kModifiedRestOfGraph; } return Status::OK(); } +bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node) { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", 1) || + node.GetOutputEdgesCount() != 1) { + return false; + } + + const auto& next_node = *node.OutputNodesBegin(); + if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", 7) || + next_node.GetInputEdgesCount() != 1 || + graph.IsNodeOutputsInGraphOutputs(next_node) || + next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { + return false; + } + + return true; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/conv_mul_fusion.h b/onnxruntime/core/optimizer/conv_mul_fusion.h index 6452360bd1..62a39b6245 100644 --- a/onnxruntime/core/optimizer/conv_mul_fusion.h +++ b/onnxruntime/core/optimizer/conv_mul_fusion.h @@ -2,16 +2,29 @@ // Licensed under the MIT License. #pragma once -#include "core/optimizer/graph_transformer.h" +#include "core/optimizer/rewrite_rule.h" namespace onnxruntime { -class ConvMulFusion : public onnxruntime::GraphTransformer { +/** +@Class ConvMulFusion + +Rewrite rule that fuses two Conv+Mul nodes to a single Conv node. + +It is attempted to be triggered only on nodes with op type "Conv". +*/ +class ConvMulFusion : public RewriteRule { public: - ConvMulFusion() noexcept : onnxruntime::GraphTransformer("ConvMulFusion") {} + ConvMulFusion() noexcept : RewriteRule("ConvMulFusion") {} + + std::vector TargetOpTypes() const noexcept override { + return {"Conv"}; + } private: - Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override; + bool SatisfyCondition(const Graph& graph, const Node& node) override; + + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_mgr.h b/onnxruntime/core/optimizer/graph_transformer_mgr.h index aaf3667d9b..7d330df241 100644 --- a/onnxruntime/core/optimizer/graph_transformer_mgr.h +++ b/onnxruntime/core/optimizer/graph_transformer_mgr.h @@ -17,7 +17,7 @@ class GraphTransformerManager { explicit GraphTransformerManager(unsigned steps) : steps_(steps) { } - // Register a transformer with a level and compatible providers list + // Register a transformer with a level. common::Status Register(std::unique_ptr transformer, TransformerLevel level); // Apply all transformers registered for the given level on the given graph diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 23f9718004..e99eacb0b5 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -30,6 +30,9 @@ std::vector> GenerateRewriteRules(TransformerLevel break; case TransformerLevel::Level2: + rules.push_back(std::make_unique()); + rules.push_back(std::make_unique()); + rules.push_back(std::make_unique()); break; default: ORT_ENFORCE(false, "Unsupported level" + std::to_string(static_cast(level))); @@ -93,9 +96,6 @@ std::vector> GenerateTransformers(TransformerL transformers.emplace_back(std::make_unique(l2_execution_providers)); transformers.emplace_back(std::make_unique(l2_execution_providers)); #endif - transformers.emplace_back(std::make_unique()); - transformers.emplace_back(std::make_unique()); - transformers.emplace_back(std::make_unique()); } break; default: diff --git a/onnxruntime/core/optimizer/identity_elimination.cc b/onnxruntime/core/optimizer/identity_elimination.cc index 1ba4eb64c3..236b98f588 100644 --- a/onnxruntime/core/optimizer/identity_elimination.cc +++ b/onnxruntime/core/optimizer/identity_elimination.cc @@ -10,9 +10,9 @@ namespace onnxruntime { -Status EliminateIdentity::Apply(Graph& graph, Node& node, bool& modified, bool& deleted) { - if (graph_utils::RemoveSingleInputNode(graph, node)) { - modified = deleted = true; +Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) { + if (graph_utils::RemoveNode(graph, node)) { + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; } return Status::OK(); diff --git a/onnxruntime/core/optimizer/identity_elimination.h b/onnxruntime/core/optimizer/identity_elimination.h index 5f9d56dbe2..b90d2164e0 100644 --- a/onnxruntime/core/optimizer/identity_elimination.h +++ b/onnxruntime/core/optimizer/identity_elimination.h @@ -25,7 +25,7 @@ class EliminateIdentity : public RewriteRule { private: bool SatisfyCondition(const Graph& graph, const Node& node) override; - Status Apply(Graph& graph, Node& node, bool& modified, bool& deleted) override; + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override; }; // namespace onnxruntime } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/rule_based_graph_transformer.cc b/onnxruntime/core/optimizer/rule_based_graph_transformer.cc index e8b65a0a21..7bafbed87b 100644 --- a/onnxruntime/core/optimizer/rule_based_graph_transformer.cc +++ b/onnxruntime/core/optimizer/rule_based_graph_transformer.cc @@ -3,6 +3,7 @@ #include "core/optimizer/rule_based_graph_transformer.h" #include "core/graph/graph_utils.h" +#include "core/optimizer/rewrite_rule.h" using namespace ::onnxruntime::common; @@ -22,11 +23,11 @@ Status RuleBasedGraphTransformer::Register(std::unique_ptr rule) { Status RuleBasedGraphTransformer::ApplyRulesOnNode(Graph& graph, Node& node, const std::vector>& rules, - bool& modified, bool& deleted) const { + RuleEffect& rule_effect) const { for (const auto& rule : rules) { - ORT_RETURN_IF_ERROR(rule->CheckConditionAndApply(graph, node, modified, deleted)); - if (deleted) { - modified = true; // should be set by rewriter but in case it wasn't... + ORT_RETURN_IF_ERROR(rule->CheckConditionAndApply(graph, node, rule_effect)); + // If the current node was removed as a result of a rule, stop rule application for that node. + if (rule_effect == RuleEffect::kRemovedCurrentNode) { break; } } @@ -39,10 +40,15 @@ Status RuleBasedGraphTransformer::ApplyImpl(Graph& graph, bool& modified, int gr for (NodeIndex i : order) { auto* node = graph.GetNode(i); + // A node might not be found as it might have already been deleted from one of the rules. if (!node) { - return Status(ONNXRUNTIME, INVALID_ARGUMENT); + continue; } + // Initialize the effect of rules on this node to denote that the graph has not yet been modified + // by the rule application on the current node. + auto rule_effect = RuleEffect::kNone; + if (!graph_utils::IsSupportedProvider(*node, GetCompatibleExecutionProviders())) { continue; } @@ -50,22 +56,26 @@ Status RuleBasedGraphTransformer::ApplyImpl(Graph& graph, bool& modified, int gr // First apply rewrite rules that are registered for the op type of the current node; then apply rules that are // registered to be applied regardless of the op type; then recursively apply rules to subgraphs (if any). // Stop further rule application for the current node, if the node gets removed by a rule. - bool deleted = false; const std::vector>* rules = nullptr; rules = GetRewriteRulesForOpType(node->OpType()); if (rules) { - ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, modified, deleted)); + ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, rule_effect)); } - if (!deleted) { + if (rule_effect != RuleEffect::kRemovedCurrentNode) { rules = GetAnyOpRewriteRules(); if (rules) { - ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, modified, deleted)); + ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, rule_effect)); } } - if (!deleted) { + // Update the modified field of the rule-based transformer. + if (rule_effect != RuleEffect::kNone) { + modified = true; + } + + if (rule_effect != RuleEffect::kRemovedCurrentNode) { ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level)); } } diff --git a/onnxruntime/core/optimizer/slice_elimination.cc b/onnxruntime/core/optimizer/slice_elimination.cc index f3ab618618..eb9a2ea368 100644 --- a/onnxruntime/core/optimizer/slice_elimination.cc +++ b/onnxruntime/core/optimizer/slice_elimination.cc @@ -8,9 +8,9 @@ namespace onnxruntime { -Status EliminateSlice::Apply(Graph& graph, Node& node, bool& modified, bool& removed) { - if (graph_utils::RemoveSingleInputNode(graph, node)) { - removed = modified = true; +Status EliminateSlice::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) { + if (graph_utils::RemoveNode(graph, node)) { + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; } return Status::OK(); diff --git a/onnxruntime/core/optimizer/slice_elimination.h b/onnxruntime/core/optimizer/slice_elimination.h index b43af73209..28d689c558 100644 --- a/onnxruntime/core/optimizer/slice_elimination.h +++ b/onnxruntime/core/optimizer/slice_elimination.h @@ -25,7 +25,7 @@ class EliminateSlice : public RewriteRule { private: bool SatisfyCondition(const Graph& graph, const Node& node) override; - Status Apply(Graph& graph, Node& node, bool& modified, bool& removed) override; + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/unsqueeze_elimination.cc b/onnxruntime/core/optimizer/unsqueeze_elimination.cc index a53e52d650..8b58f3bd58 100644 --- a/onnxruntime/core/optimizer/unsqueeze_elimination.cc +++ b/onnxruntime/core/optimizer/unsqueeze_elimination.cc @@ -10,7 +10,7 @@ using namespace ::onnxruntime::common; namespace onnxruntime { -Status UnsqueezeElimination::Apply(Graph& graph, Node& node, bool& modified, bool& removed) { +Status UnsqueezeElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) { // Get "axes" attribute. const ONNX_NAMESPACE::AttributeProto* attr = graph_utils::GetNodeAttribute(node, "axes"); if (attr == nullptr || attr->type() != AttributeProto_AttributeType_INTS) { @@ -66,8 +66,8 @@ Status UnsqueezeElimination::Apply(Graph& graph, Node& node, bool& modified, boo input_def->SetShape(shape); // Remove Unsqueeze node. - if (graph_utils::RemoveSingleInputNode(graph, node)) { - removed = modified = true; + if (graph_utils::RemoveNode(graph, node)) { + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; } return Status::OK(); diff --git a/onnxruntime/core/optimizer/unsqueeze_elimination.h b/onnxruntime/core/optimizer/unsqueeze_elimination.h index f4107fba36..e8e4dad400 100644 --- a/onnxruntime/core/optimizer/unsqueeze_elimination.h +++ b/onnxruntime/core/optimizer/unsqueeze_elimination.h @@ -25,7 +25,7 @@ class UnsqueezeElimination : public RewriteRule { private: bool SatisfyCondition(const Graph& graph, const Node& node) override; - Status Apply(Graph& graph, Node& node, bool& modified, bool& deleted) override; + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override; }; } // namespace onnxruntime diff --git a/onnxruntime/test/ir/utils_test.cc b/onnxruntime/test/ir/utils_test.cc index 0cc0d98a0e..4c36b3113e 100644 --- a/onnxruntime/test/ir/utils_test.cc +++ b/onnxruntime/test/ir/utils_test.cc @@ -207,7 +207,7 @@ static void UpdateSubgraphWhenRemovingNode(bool include_nested = false) { auto& node_to_remove = *graph.GetNode(1); const auto& if_node = *graph.GetNode(2); - bool removed = graph_utils::RemoveSingleInputNode(graph, node_to_remove); + bool removed = graph_utils::RemoveNode(graph, node_to_remove); ASSERT_TRUE(removed); // check subgraph implicit input was updated @@ -238,7 +238,7 @@ static void DontRemoveNodeIfItWillBreakSubgraph(bool test_nested = false) { auto& graph = model.MainGraph(); auto& node_to_remove = *graph.GetNode(1); - bool removed = graph_utils::RemoveSingleInputNode(graph, node_to_remove); + bool removed = graph_utils::RemoveNode(graph, node_to_remove); ASSERT_FALSE(removed); } @@ -287,7 +287,7 @@ TEST(GraphUtils, TestMultiEdgeRemovalNodes) { ASSERT_EQ(nodes[2]->GetOutputEdgesCount(), 2); // Remove id_2. This leaves id_0 with 3 output edges. id_0 is now incoming node to id_3 and id_4. - ASSERT_TRUE(graph_utils::RemoveSingleInputNode(graph, *nodes[2])); + ASSERT_TRUE(graph_utils::RemoveNode(graph, *nodes[2])); ASSERT_EQ(graph.NumberOfNodes(), 4); ASSERT_EQ(nodes[0]->GetOutputEdgesCount(), 3); ASSERT_EQ(nodes[3]->InputDefs().size(), 1); @@ -296,7 +296,7 @@ TEST(GraphUtils, TestMultiEdgeRemovalNodes) { ASSERT_TRUE(nodes[4]->InputDefs()[0]->Name() == "id_0_out"); // Remove id_0 - ASSERT_TRUE(graph_utils::RemoveSingleInputNode(graph, *nodes[0])); + ASSERT_TRUE(graph_utils::RemoveNode(graph, *nodes[0])); ASSERT_EQ(graph.NumberOfNodes(), 3); ASSERT_TRUE(nodes[1]->InputDefs()[0]->Name() == "id_0_in"); ASSERT_TRUE(nodes[3]->InputDefs()[0]->Name() == "id_0_in"); diff --git a/onnxruntime/test/optimizer/dummy_graph_transformer.h b/onnxruntime/test/optimizer/dummy_graph_transformer.h index b89f6c860b..1bff4af37f 100644 --- a/onnxruntime/test/optimizer/dummy_graph_transformer.h +++ b/onnxruntime/test/optimizer/dummy_graph_transformer.h @@ -47,7 +47,7 @@ class DummyRewriteRule : public RewriteRule { return true; } - Status Apply(Graph& /*graph*/, Node& /*node*/, bool& /*modified*/, bool& /*deleted*/) override { + Status Apply(Graph& /*graph*/, Node& /*node*/, RewriteRuleEffect& /*rule_effect*/) override { rewrite_rule_invoked_ = true; return Status::OK(); } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index a2d96d6730..580c1f16c8 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -52,10 +52,10 @@ TEST(GraphTransformationTests, IdentityElimination) { std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Identity"] == 1); - auto rule_transformer = std::make_unique("RuleTransformer1"); - rule_transformer->Register(std::make_unique()); + auto rule_transformer_L1 = std::make_unique("RuleTransformer1"); + rule_transformer_L1->Register(std::make_unique()); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - graph_transformation_mgr.Register(std::move(rule_transformer), TransformerLevel::Level1); + graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); op_to_count = CountOpsInGraph(graph); @@ -70,10 +70,10 @@ TEST(GraphTransformationTests, SliceElimination) { std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Slice"] == 5); - auto rule_transformer = std::make_unique("RuleTransformer1"); - rule_transformer->Register(std::make_unique()); + auto rule_transformer_L1 = std::make_unique("RuleTransformer1"); + rule_transformer_L1->Register(std::make_unique()); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - graph_transformation_mgr.Register(std::move(rule_transformer), TransformerLevel::Level1); + graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); op_to_count = CountOpsInGraph(graph); @@ -129,7 +129,9 @@ TEST(GraphTransformationTests, FuseConvBNNoBias) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2); + auto rule_transformer_L2 = std::make_unique("RuleTransformerL2"); + rule_transformer_L2->Register(std::make_unique()); + graph_transformation_mgr.Register(std::move(rule_transformer_L2), TransformerLevel::Level2); ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK()); @@ -148,12 +150,15 @@ TEST(GraphTransformationTests, FuseConvBNMulAddUnsqueeze) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - auto rule_transformer = std::make_unique("RuleTransformer1"); - rule_transformer->Register(std::make_unique()); - graph_transformation_mgr.Register(std::move(rule_transformer), TransformerLevel::Level1); - graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2); - graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2); - graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2); + auto rule_transformer_L1 = std::make_unique("RuleTransformer1"); + rule_transformer_L1->Register(std::make_unique()); + graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); + + auto rule_transformer_L2 = std::make_unique("RuleTransformerL2"); + rule_transformer_L2->Register(std::make_unique()); + rule_transformer_L2->Register(std::make_unique()); + rule_transformer_L2->Register(std::make_unique()); + graph_transformation_mgr.Register(std::move(rule_transformer_L2), TransformerLevel::Level2); ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK()); @@ -201,10 +206,13 @@ TEST(GraphTransformationTests, FuseConvMulNoBias) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - auto rule_transformer = std::make_unique("RuleTransformer1"); - rule_transformer->Register(std::make_unique()); - graph_transformation_mgr.Register(std::move(rule_transformer), TransformerLevel::Level1); - graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2); + auto rule_transformer_L1 = std::make_unique("RuleTransformer1"); + rule_transformer_L1->Register(std::make_unique()); + graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); + + auto rule_transformer_L2 = std::make_unique("RuleTransformerL2"); + rule_transformer_L2->Register(std::make_unique()); + graph_transformation_mgr.Register(std::move(rule_transformer_L2), TransformerLevel::Level2); ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK()); @@ -222,10 +230,13 @@ TEST(GraphTransformationTests, FuseConvAddNoBias) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - auto rule_transformer = std::make_unique("RuleTransformer1"); - rule_transformer->Register(std::make_unique()); - graph_transformation_mgr.Register(std::move(rule_transformer), TransformerLevel::Level1); - graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2); + auto rule_transformer_L1 = std::make_unique("RuleTransformer1"); + rule_transformer_L1->Register(std::make_unique()); + graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); + + auto rule_transformer_L2 = std::make_unique("RuleTransformerL2"); + rule_transformer_L2->Register(std::make_unique()); + graph_transformation_mgr.Register(std::move(rule_transformer_L2), TransformerLevel::Level2); ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK()); ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK()); @@ -243,8 +254,10 @@ TEST(GraphTransformationTests, FuseConvAddMul3D) { Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2); - graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2); + auto rule_transformer_L2 = std::make_unique("RuleTransformerL2"); + rule_transformer_L2->Register(std::make_unique()); + rule_transformer_L2->Register(std::make_unique()); + graph_transformation_mgr.Register(std::move(rule_transformer_L2), TransformerLevel::Level2); ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK()); @@ -315,12 +328,11 @@ TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) { std::shared_ptr p_model; ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); - std::unique_ptr ConvBNFusion_transformer = std::make_unique(); - std::unique_ptr ConvMulFusion_transformer = std::make_unique(); - std::unique_ptr ConvAddFusion_transformer = std::make_unique(); - session_object.RegisterGraphTransformer(std::move(ConvBNFusion_transformer)); - session_object.RegisterGraphTransformer(std::move(ConvMulFusion_transformer)); - session_object.RegisterGraphTransformer(std::move(ConvAddFusion_transformer)); + auto rule_transformer_L2 = std::make_unique("RuleTransformerL2"); + rule_transformer_L2->Register(std::make_unique()); + rule_transformer_L2->Register(std::make_unique()); + rule_transformer_L2->Register(std::make_unique()); + session_object.RegisterGraphTransformer(std::move(rule_transformer_L2), TransformerLevel::Level2); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -335,7 +347,8 @@ TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) { for (int i = 0; i < 9; ++i) { values_x.push_back(x_f); } - CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_x, values_x, &ml_value_x); + CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), + dims_x, values_x, &ml_value_x); feeds.insert(std::make_pair("X", ml_value_x)); std::vector output_names; @@ -355,7 +368,8 @@ TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) { auto& rtensor = fetches.front().Get(); TensorShape expected_shape(expected_dims_prod); ASSERT_EQ(expected_shape, rtensor.Shape()); - const std::vector found(rtensor.template Data(), rtensor.template Data() + expected_dims_prod.size()); + const std::vector found(rtensor.template Data(), + rtensor.template Data() + expected_dims_prod.size()); ASSERT_EQ(expected_values_prod, found); } diff --git a/onnxruntime/test/optimizer/graph_transform_utils_test.cc b/onnxruntime/test/optimizer/graph_transform_utils_test.cc index 8d3e2a0d99..63f7e27da1 100644 --- a/onnxruntime/test/optimizer/graph_transform_utils_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_utils_test.cc @@ -38,7 +38,7 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) { ASSERT_TRUE(transformers.size() != 0); // Transformer name match test - std::vector custom_list = {"EliminateIdentity", "ConvAddFusion", "ConvMulFusion", "abc", "def"}; + std::vector custom_list = {"EliminateIdentity", "GemmActivationFusion", "MatMulAddFusion", "abc", "def"}; transformers = transformer_utils::GenerateTransformers(TransformerLevel::Level2, custom_list); ASSERT_TRUE(transformers.size() == 2); // validate each rule returned is present in the custom list @@ -46,7 +46,7 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) { ASSERT_TRUE(std::find(custom_list.begin(), custom_list.end(), transformer->Name()) != custom_list.end()); } - // Transformer name no match test. When there is no match empty list is expected. + // Transformer name no-match test. When there is no match, empty list is expected. custom_list = {"EliminateIdentity"}; transformers = transformer_utils::GenerateTransformers(TransformerLevel::Level2, custom_list); ASSERT_TRUE(transformers.size() == 0); @@ -56,7 +56,7 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers_CustomList) { // custom list of rules and transformers std::string l1_rule1 = "EliminateIdentity"; std::string l1_transformer = "ConstantFolding"; - std::string l2_transformer = "ConvAddFusion"; + std::string l2_transformer = "GemmActivationFusion"; std::vector custom_list = {l1_rule1, l1_transformer, l2_transformer}; auto transformers = transformer_utils::GenerateTransformers(TransformerLevel::Level1, custom_list);