From 2ae83c580c9cf88f0f416596cce5309d8aa5d5bb Mon Sep 17 00:00:00 2001 From: Konstantinos Karanasos Date: Wed, 13 Mar 2019 15:44:26 -0700 Subject: [PATCH] Constant folding (#168) Constant folding rewrite rule computes nodes that have only constant inputs at compile time and avoids these computations at run time. --- .../core/optimizer/graph_transformer.h | 50 +++++----- .../onnxruntime/core/optimizer/rewrite_rule.h | 9 +- onnxruntime/core/graph/graph_utils.cc | 43 +++++++++ onnxruntime/core/graph/graph_utils.h | 11 +++ .../core/optimizer/constant_folding.cc | 87 ++++++++++++++++++ onnxruntime/core/optimizer/constant_folding.h | 41 +++++++++ .../core/optimizer/conv_activation_fusion.cc | 1 - onnxruntime/core/optimizer/conv_mul_fusion.cc | 2 +- .../core/optimizer/gemm_activation_fusion.cc | 5 +- .../core/optimizer/graph_transformer.cc | 38 ++++---- .../core/optimizer/graph_transformer_mgr.h | 3 +- .../core/optimizer/identity_elimination.cc | 8 +- .../core/optimizer/identity_elimination.h | 7 +- .../core/optimizer/slice_elimination.cc | 10 +- .../core/optimizer/slice_elimination.h | 7 +- .../test/optimizer/graph_transform_test.cc | 27 +++++- .../fuse-conv-bn-mul-add-unsqueeze.onnx | Bin 643241 -> 642944 bytes 17 files changed, 287 insertions(+), 62 deletions(-) create mode 100644 onnxruntime/core/optimizer/constant_folding.cc create mode 100644 onnxruntime/core/optimizer/constant_folding.h diff --git a/include/onnxruntime/core/optimizer/graph_transformer.h b/include/onnxruntime/core/optimizer/graph_transformer.h index b019833b09..c048699796 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer.h +++ b/include/onnxruntime/core/optimizer/graph_transformer.h @@ -69,11 +69,11 @@ class GraphTransformer { /** @class RuleBasedGraphTransformer -Rule based graph transformer that provides an API to register rewrite rules, +Rule-based graph transformer that provides an API to register rewrite rules and an API to apply all applicable rules to a Graph. -Represents an IGraphTransformer determined by a set of rewrite-rules. -The transformer will apply all the rewrite-rules iteratively as determined by the underlying rewriting-strategy. +Represents an IGraphTransformer determined by a set of rewrite rules. +The transformer will apply all the rewrite rules iteratively as determined by the underlying rewriting strategy. Several rewriting-strategies are possible when traversing the graph and applying rewrite rules, each with different trade offs. At the moment, we define one that performs top-down traversal of nodes. @@ -89,38 +89,32 @@ class RuleBasedGraphTransformer : public GraphTransformer { : GraphTransformer(name, desc) {} /** - Register a rewriting rule. - - @TODO (revisit needed): Using OpSignature* here will ask that OpSignature should be stored globally. - Otherwise, there will be multiple addresses/pointers for the same operator or function. - To avoid this, we may use OpSignature ID as the key, which should be name_domain_version. - We will use the string type instead of the OpSchema for now. We should probably add a version as well. + Register a rewrite rule in this transformer. */ - Status Register(const std::string& op_type, std::unique_ptr rule); - - /** Check if the given op_type has any rules registered for it - @returns true if there are rules registered for this op_type.*/ - bool HasRules(const std::string& op_type) const { - return op_to_rules_.find(op_type) != op_to_rules_.cend(); - } + Status Register(std::unique_ptr rule); /** - Gets the rewrite rules for the given op_type. - @returns a pointer to the vector containing all the rewrite rules registered for op_type if found. nullptr - otherwise. + Gets the list of registered rewrite rules in this rule-based transformer. + @returns a reference to the vector containing all the registered rewrite rules. */ - const std::vector>* GetRewriteRules(const std::string& op_type) const { - auto entry = op_to_rules_.find(op_type); - if (entry != op_to_rules_.cend()) - return &entry->second; - - return nullptr; + const std::vector>& GetRewriteRules() const { + return rules_; } - private: - using RewriteRuleSet = std::unordered_map>>; + protected: + /** Apply the given set of rewrite rules on the Node of this Graph. + @param[in] graph The Graph. + @param[in] node The Node to apply the rules to. + @param[in] rules The vector of RewriteRules that will be applied to the Node. + @param[out] modified Set to indicate whether the node was modified or not. + @param[out] deleted Set to indicate if the node was deleted. + @returns Status indicating success or providing error information */ + common::Status ApplyRulesOnNode(Graph& graph, Node& node, + const std::vector>& rules, + bool& modified, bool& deleted) const; - RewriteRuleSet op_to_rules_; + private: + std::vector> rules_; }; /** diff --git a/include/onnxruntime/core/optimizer/rewrite_rule.h b/include/onnxruntime/core/optimizer/rewrite_rule.h index c7155de1bc..0d9c32ff51 100644 --- a/include/onnxruntime/core/optimizer/rewrite_rule.h +++ b/include/onnxruntime/core/optimizer/rewrite_rule.h @@ -44,7 +44,7 @@ class RewriteRule { @param[out] deleted Set to indicate if the node was deleted. @returns Status indicating success or providing error information */ common::Status CheckConditionAndApply(Graph& graph, Node& node, bool& modified, bool& deleted) { - return SatisfyCondition(node) ? Apply(graph, node, modified, deleted) : Status::OK(); + return SatisfyCondition(graph, node) ? Apply(graph, node, modified, deleted) : Status::OK(); } private: @@ -53,11 +53,14 @@ class RewriteRule { const std::string name_; const std::string desc_; - /** Check if the Node satisfies a condition. + /** Check if the Node of the given Graph satisfies a condition. The rewrite rule is applied if the condition function returns true. This can include a more complex pattern matching (conditions on the ascending or descending nodes of the node for which this rule was triggered) or some other properties of the nodes. */ - virtual bool SatisfyCondition(const Node& node) = 0; + virtual bool SatisfyCondition(const Graph& graph, const Node& node) = 0; + + /** Returns true if the op type of the node is compatible with this rewrite rule. */ + virtual bool OpTypeCondition(const Node& node) = 0; /** Apply the rewrite rule to a specific node. diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index 3419251a1e..6585ce7a60 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -1,5 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. #include "core/graph/graph_utils.h" +#include "core/graph/graph.h" +#include "core/framework/tensorprotoutils.h" namespace onnxruntime { @@ -106,5 +110,44 @@ bool RemoveSingleInSingleOutNode(Graph& graph, Node& node) { return true; } +bool HasGraphInput(const Graph& graph, const NodeArg* input) { + const std::vector& graph_inputs = graph.GetInputsIncludingInitializers(); + return std::find(graph_inputs.begin(), graph_inputs.end(), input) != graph_inputs.end(); +} + +bool IsConstantInputsNode(const Graph& graph, const Node& node) { + if (node.GetInputEdgesCount() > 0) { + return false; + } + const onnx::TensorProto* initializer = nullptr; + for (const auto* input_def : node.InputDefs()) { + // Important note: when an initializer appears in the graph's input, this input will not be considered constant, + // because it can be overriden by the user at runtime. For constant folding to be applied, the initializer should not + // appear in the graph's inputs (that is the only way to guarantee it will always be constant). + if (!graph.GetInitializedTensor(input_def->Name(), initializer) || HasGraphInput(graph, input_def)) { + return false; + } + } + return true; +} + +size_t RemoveNodeOutputEdges(Graph& graph, Node& node) { + std::vector> edges_to_remove; + for (auto it = node.OutputEdgesBegin(); it != node.OutputEdgesEnd(); ++it) { + edges_to_remove.emplace_back(std::make_tuple(it->GetNode().Index(), + it->GetSrcArgIndex(), + it->GetDstArgIndex())); + } + for (auto& edge_to_remove : edges_to_remove) { + graph.RemoveEdge(node.Index(), + std::get<0>(edge_to_remove), + std::get<1>(edge_to_remove), + std::get<2>(edge_to_remove)); + } + + return edges_to_remove.size(); +} + } // namespace utils + } // namespace onnxruntime diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h index f36ab0d763..9ab3ccddb9 100644 --- a/onnxruntime/core/graph/graph_utils.h +++ b/onnxruntime/core/graph/graph_utils.h @@ -40,5 +40,16 @@ bool GetRepeatedNodeAttributeValues(const Node& node, /** Remove the given single-input-single-output Node from the Graph. */ bool RemoveSingleInSingleOutNode(Graph& graph, Node& node); +/** Returns true if the graph has the given input.*/ +bool HasGraphInput(const Graph& graph, const NodeArg* input); + +/** Checks if the given node has only constant inputs (initializers). */ +bool IsConstantInputsNode(const Graph& graph, const Node& node); + +/** Remove all output edges from the given Node of the Graph. + This should probably be elevated to the Graph API eventually. */ +size_t RemoveNodeOutputEdges(Graph& graph, Node& node); + } // namespace utils + } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc new file mode 100644 index 0000000000..b479e2a86e --- /dev/null +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -0,0 +1,87 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/constant_folding.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/optimizer_execution_frame.h" +#include "core/framework/op_kernel.h" +#include "core/framework/ml_value.h" + +namespace onnxruntime { + +Status ConstantFolding::Apply(Graph& graph, Node& node, bool& modified, bool& deleted) { + // TODO Check if we need default transformers any more. I dont think we do... + + // Create execution frame for executing constant nodes. + OptimizerExecutionFrame::Info info({&node}, graph.GetAllInitializedTensors()); + + std::vector fetch_mlvalue_idxs; + for (const auto* node_out : node.OutputDefs()) { + fetch_mlvalue_idxs.push_back(info.GetMLValueIndex(node_out->Name())); + } + + OptimizerExecutionFrame frame(info, fetch_mlvalue_idxs); + + auto* kernel = info.GetKernel(node.Index()); + OpKernelContext op_kernel_context(&frame, kernel, ::onnxruntime::logging::LoggingManager::DefaultLogger()); + + kernel->Compute(&op_kernel_context); + + std::vector fetches; + frame.GetOutputs(fetches); + + // Go over all output node args and substitute them with the newly computed tensors, which will be + // added to the graph as initializers. + ORT_ENFORCE(fetches.size() == node.OutputDefs().size()); + for (int fetch_idx = 0; fetch_idx < fetches.size(); ++fetch_idx) { + MLValue& mlvalue = fetches[fetch_idx]; + + // Build the TensorProto that corresponds to the computed MLValue and add it as initializer to the graph. + ONNX_NAMESPACE::TensorProto out_tensorproto; + const auto* constant_arg_out = node.OutputDefs()[fetch_idx]; + BuildTensorProtoForInitializer(mlvalue, *constant_arg_out, out_tensorproto); + + graph.AddInitializedTensor(out_tensorproto); + } + + // Remove the output edges of the constant node and then remove the node itself. + utils::RemoveNodeOutputEdges(graph, node); + graph.RemoveNode(node.Index()); + + // The output nodes already have the right input arg, since we used the same name in the initializer. + // We could remove unused graph initializers here, but Graph::Resolve() will take care of it. + + modified = deleted = true; + + return Status::OK(); +} // namespace onnxruntime + +bool ConstantFolding::SatisfyCondition(const Graph& graph, const Node& node) { + return OpTypeCondition(node) && utils::IsConstantInputsNode(graph, node); +} + +bool ConstantFolding::OpTypeCondition(const Node& node) { + return excluded_op_types_.find(node.OpType()) == excluded_op_types_.end(); +} + +void ConstantFolding::BuildTensorProtoForInitializer(const MLValue& mlvalue, + const NodeArg& constant_node_arg, + ONNX_NAMESPACE::TensorProto& tensorproto) { + ORT_ENFORCE(mlvalue.IsTensor()); + const Tensor& out_tensor = mlvalue.Get(); + + // Set name, dimensions, type, and data of the TensorProto. + tensorproto.set_name(constant_node_arg.Name()); + + for (auto& dim : out_tensor.Shape().GetDims()) { + tensorproto.add_dims(dim); + } + auto tensorproto_type = constant_node_arg.TypeAsProto()->tensor_type().elem_type(); + + tensorproto.set_data_type(tensorproto_type); + auto tensor_shape_size = out_tensor.Shape().Size(); + auto data_size = out_tensor.DataType()->Size() * tensor_shape_size; + tensorproto.set_raw_data(out_tensor.DataRaw(out_tensor.DataType()), data_size); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/constant_folding.h b/onnxruntime/core/optimizer/constant_folding.h new file mode 100644 index 0000000000..1a775aceb1 --- /dev/null +++ b/onnxruntime/core/optimizer/constant_folding.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" +#include "core/framework/ml_value.h" + +namespace onnxruntime { + +/** +@class ConstantFolding + +Rewrite rule that performs constant folding to the graph. +The rule gets applied to nodes that have only initializers as inputs. It statically computes these +nodes and replaces their output with an initializer that corresponds to the result of the computation. +*/ +class ConstantFolding : public RewriteRule { + public: + ConstantFolding() noexcept : RewriteRule("ConstantFolding", "Constant folding") {} + + private: + /** Constant folding will not be applied to nodes whose op_type is included in this set. + All non-deterministic operators should be included in this set. */ + const std::unordered_set excluded_op_types_ = + {"RandomUniform", "RandomNormal", "RandomUniformLike", "RandomNormalLike", "Multinomial"}; + + bool SatisfyCondition(const Graph& graph, const Node& node) override; + + bool OpTypeCondition(const Node& node) override; + + Status Apply(Graph& graph, Node& node, bool& modified, bool& deleted) override; + + /** Create a TensorProto that has the same value as the given MLValue + and the same type and dimensions as the given NodeArg. */ + void BuildTensorProtoForInitializer(const MLValue& mlvalue, + const NodeArg& constant_node_arg, + ONNX_NAMESPACE::TensorProto& tensorproto); +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index d867495431..073ec30e3d 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -44,7 +44,6 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l std::deque removed_nodes; for (auto index : order) { auto node = graph.GetNode(index); - ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level)); if (!utils::IsSupportedOptypeVersionAndDomain(*node, "Conv", 1) || node->GetOutputEdgesCount() != 1) { diff --git a/onnxruntime/core/optimizer/conv_mul_fusion.cc b/onnxruntime/core/optimizer/conv_mul_fusion.cc index 3e3c2a71f8..43f68c03d7 100644 --- a/onnxruntime/core/optimizer/conv_mul_fusion.cc +++ b/onnxruntime/core/optimizer/conv_mul_fusion.cc @@ -13,7 +13,7 @@ Status ConvMulFusion::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int g std::vector removed_nodes; for (auto& node : graph.Nodes()) { ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level)); - + if (!utils::IsSupportedOptypeVersionAndDomain(node, "Conv", 1) || node.GetOutputEdgesCount() != 1) { continue; } diff --git a/onnxruntime/core/optimizer/gemm_activation_fusion.cc b/onnxruntime/core/optimizer/gemm_activation_fusion.cc index 1e4bd76010..733c035bd8 100644 --- a/onnxruntime/core/optimizer/gemm_activation_fusion.cc +++ b/onnxruntime/core/optimizer/gemm_activation_fusion.cc @@ -12,7 +12,10 @@ namespace onnxruntime { namespace { bool IsFusableActivation(const Node& node) { - return utils::IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Relu", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", 6); + return utils::IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", 6) || + utils::IsSupportedOptypeVersionAndDomain(node, "Relu", 6) || + utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", 6) || + utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", 6); } void HandleActivationNodeEdges(Graph& g, const Node& act, Node& fused_gemm) { diff --git a/onnxruntime/core/optimizer/graph_transformer.cc b/onnxruntime/core/optimizer/graph_transformer.cc index 8d4e112511..700daed7ce 100644 --- a/onnxruntime/core/optimizer/graph_transformer.cc +++ b/onnxruntime/core/optimizer/graph_transformer.cc @@ -8,13 +8,13 @@ using namespace ::onnxruntime::common; namespace onnxruntime { Status GraphTransformer::Apply(Graph& graph, bool& modified) const { - // the Graph should be in a good state prior this being called, so there should be no need to call Resolve here + // The Graph should be in a good state prior this being called, so there should be no need to call Resolve here. // ORT_RETURN_IF_ERROR(graph.Resolve()); auto status = ApplyImpl(graph, modified, 0); ORT_RETURN_IF_ERROR(status); - // at least currently, some transformers (InsertCastTransformer and MemcpyTransformer need this to be called + // At least currently, some transformers (InsertCastTransformer and MemcpyTransformer) need this to be called // after they complete to put the graph back into a valid state for the next transformer. if (modified) { status = graph.Resolve(); @@ -23,12 +23,21 @@ Status GraphTransformer::Apply(Graph& graph, bool& modified) const { return status; } -Status RuleBasedGraphTransformer::Register(const std::string& op_type, std::unique_ptr rule) { - if (HasRules(op_type)) { - op_to_rules_[op_type] = std::vector>(); - } +Status RuleBasedGraphTransformer::Register(std::unique_ptr rule) { + rules_.push_back(std::move(rule)); + return Status::OK(); +} - op_to_rules_[op_type].push_back(std::move(rule)); +Status RuleBasedGraphTransformer::ApplyRulesOnNode(Graph& graph, Node& node, + const std::vector>& rules, + bool& modified, bool& deleted) const { + for (const auto& rule : rules) { + ORT_RETURN_IF_ERROR(rule->CheckConditionAndApply(graph, node, modified, deleted)); + if (deleted) { + modified = true; // should be set by rewriter but in case it wasn't... + break; + } + } return Status::OK(); } @@ -42,19 +51,10 @@ Status TopDownRuleBasedTransformer::ApplyImpl(Graph& graph, bool& modified, int return Status(ONNXRUNTIME, INVALID_ARGUMENT); } - // Get the rules that should be fired for this node. - const std::vector>* rules = GetRewriteRules(node->OpType()); - + // Apply rewrite rules on current node, then recursively apply rules to subgraphs (if any). + // Stop further rule application for the current node, if the node gets removed by a rule. bool deleted = false; - if (rules) { - for (const auto& rule : *rules) { - ORT_RETURN_IF_ERROR(rule->CheckConditionAndApply(graph, *node, modified, deleted)); - if (deleted) { - modified = true; // should be set by rewriter but in case it wasn't... - break; - } - } - } + ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, GetRewriteRules(), modified, deleted)); if (!deleted) { ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level)); diff --git a/onnxruntime/core/optimizer/graph_transformer_mgr.h b/onnxruntime/core/optimizer/graph_transformer_mgr.h index d221638eaf..74cc57e368 100644 --- a/onnxruntime/core/optimizer/graph_transformer_mgr.h +++ b/onnxruntime/core/optimizer/graph_transformer_mgr.h @@ -4,6 +4,7 @@ #pragma once #include "core/optimizer/graph_transformer.h" +#include "core/optimizer/constant_folding.h" namespace onnxruntime { // Manages a list of graph transformers. It is initialized with a list of graph @@ -11,7 +12,7 @@ namespace onnxruntime { class GraphTransformerManager { public: explicit GraphTransformerManager(unsigned steps) noexcept : steps_(steps) { - // TODO: Register default transformers. + // Register default transformers. } // Register a graph transformer. diff --git a/onnxruntime/core/optimizer/identity_elimination.cc b/onnxruntime/core/optimizer/identity_elimination.cc index 2b1c168da2..1e0136b967 100644 --- a/onnxruntime/core/optimizer/identity_elimination.cc +++ b/onnxruntime/core/optimizer/identity_elimination.cc @@ -18,8 +18,12 @@ Status EliminateIdentity::Apply(Graph& graph, Node& node, bool& modified, bool& return Status::OK(); } -bool EliminateIdentity::SatisfyCondition(const Node& node) { - return utils::IsSingleInSingleOutNode(node); +bool EliminateIdentity::SatisfyCondition(const Graph& /*graph*/, const Node& node) { + return OpTypeCondition(node) && utils::IsSingleInSingleOutNode(node); +} + +bool EliminateIdentity::OpTypeCondition(const Node& node) { + return node.OpType() == included_op_type_; } } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/identity_elimination.h b/onnxruntime/core/optimizer/identity_elimination.h index 55862dd2ad..cd9a3de9cb 100644 --- a/onnxruntime/core/optimizer/identity_elimination.h +++ b/onnxruntime/core/optimizer/identity_elimination.h @@ -13,7 +13,12 @@ class EliminateIdentity : public RewriteRule { EliminateIdentity() noexcept : RewriteRule("EliminateIdentity", "Eliminate identity node") {} private: - bool SatisfyCondition(const Node& node) override; + /** Apply rule when op type is one of the following. */ + const std::string included_op_type_ = "Identity"; + + bool SatisfyCondition(const Graph& graph, const Node& node) override; + + bool OpTypeCondition(const Node& node) override; Status Apply(Graph& graph, Node& node, bool& modified, bool& deleted) override; }; diff --git a/onnxruntime/core/optimizer/slice_elimination.cc b/onnxruntime/core/optimizer/slice_elimination.cc index 13148a85ee..08be813916 100644 --- a/onnxruntime/core/optimizer/slice_elimination.cc +++ b/onnxruntime/core/optimizer/slice_elimination.cc @@ -16,7 +16,11 @@ Status EliminateSlice::Apply(Graph& graph, Node& node, bool& modified, bool& rem return Status::OK(); } -bool EliminateSlice::SatisfyCondition(const Node& node) { +bool EliminateSlice::SatisfyCondition(const Graph& /*graph*/, const Node& node) { + if (!OpTypeCondition(node)) { + return false; + } + // At the moment, we eliminate a slice operator only if it has a single input and a single output. if (!utils::IsSingleInSingleOutNode(node)) { return false; @@ -50,4 +54,8 @@ bool EliminateSlice::SatisfyCondition(const Node& node) { return true; } +bool EliminateSlice::OpTypeCondition(const Node& node) { + return node.OpType() == included_op_type_; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/slice_elimination.h b/onnxruntime/core/optimizer/slice_elimination.h index d429af6a62..1f773539f4 100644 --- a/onnxruntime/core/optimizer/slice_elimination.h +++ b/onnxruntime/core/optimizer/slice_elimination.h @@ -13,7 +13,12 @@ class EliminateSlice : public RewriteRule { EliminateSlice() noexcept : RewriteRule("EliminateSlice", "Eliminate slice node") {} private: - bool SatisfyCondition(const Node& node) override; + /** Apply rule when op type is one of the following. */ + const std::string included_op_type_ = "Slice"; + + bool SatisfyCondition(const Graph& graph, const Node& node) override; + + bool OpTypeCondition(const Node& node) override; Status Apply(Graph& graph, Node& node, bool& modified, bool& removed) override; }; diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 47af27b358..603da01cff 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -32,7 +32,7 @@ namespace test { static const std::string MODEL_FOLDER = "testdata/transform/"; -// Return a map with the number of occurrences of each operator in the graph. +// Returns a map with the number of occurrences of each operator in the graph. // Helper function to check that the graph transformations have been successfully applied. std::map CountOpsInGraph(const Graph& graph) { std::map op_to_count; @@ -53,7 +53,7 @@ TEST(GraphTransformationTests, IdentityElimination) { std::unique_ptr rule_transformer = std::make_unique("RuleTransformer1", "First rule transformer"); - rule_transformer->Register("Identity", std::make_unique()); + rule_transformer->Register(std::make_unique()); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(std::move(rule_transformer)); ASSERT_TRUE(graph_transformation_mgr.ApplyAll(graph).IsOK()); @@ -72,7 +72,7 @@ TEST(GraphTransformationTests, SliceElimination) { std::unique_ptr rule_transformer = std::make_unique("RuleTransformer1", "First rule transformer"); - rule_transformer->Register("Slice", std::make_unique()); + rule_transformer->Register(std::make_unique()); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(std::move(rule_transformer)); ASSERT_TRUE(graph_transformation_mgr.ApplyAll(graph).IsOK()); @@ -81,6 +81,27 @@ TEST(GraphTransformationTests, SliceElimination) { ASSERT_TRUE(op_to_count["Slice"] == 3); } +TEST(GraphTransformationTests, ConstantFolding) { + string model_uri = MODEL_FOLDER + "fusion/fuse-conv-bn-mul-add-unsqueeze.onnx"; + std::shared_ptr model; + ASSERT_TRUE(Model::Load(model_uri, model).IsOK()); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Unsqueeze"] == 2); + + std::unique_ptr rule_transformer = + std::make_unique("RuleTransformer1", "First rule transformer"); + + rule_transformer->Register(std::make_unique()); + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + + graph_transformation_mgr.Register(std::move(rule_transformer)); + ASSERT_TRUE(graph_transformation_mgr.ApplyAll(graph).IsOK()); + + op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Unsqueeze"] == 0); +} + TEST(GraphTransformationTests, FuseConvBNMulAddUnsqueeze) { string model_uri = MODEL_FOLDER + "fusion/fuse-conv-bn-mul-add-unsqueeze.onnx"; diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-conv-bn-mul-add-unsqueeze.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-conv-bn-mul-add-unsqueeze.onnx index ac86e07847307b1e831e39e1ff306b5b5bbe6e2d..668c7932e4db5cb510acb4a0545547014b7437c8 100644 GIT binary patch delta 121 zcmZ3vP`zQkx-18a)wg--T;W{2DTyVC@wusqdGQAE23%|)Zli3gEMu!IQ>!dSx&+8)0;1FDj6hcaq*;;7N@4f7o`@L=9CCYaY=A63W;zraWHbR Oa5OX+B3MpL0_*@;Egs1L delta 376 zcmZpeufB4jx-18?)%=C(T;W`7DTyVC@djMHU^+K7F)!YrQMOf