mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Constant folding (#168)
Constant folding rewrite rule computes nodes that have only constant inputs at compile time and avoids these computations at run time.
This commit is contained in:
parent
ab734ec5a6
commit
2ae83c580c
17 changed files with 287 additions and 62 deletions
|
|
@ -69,11 +69,11 @@ class GraphTransformer {
|
|||
/**
|
||||
@class RuleBasedGraphTransformer
|
||||
|
||||
Rule based graph transformer that provides an API to register rewrite rules,
|
||||
Rule-based graph transformer that provides an API to register rewrite rules
|
||||
and an API to apply all applicable rules to a Graph.
|
||||
|
||||
Represents an IGraphTransformer determined by a set of rewrite-rules.
|
||||
The transformer will apply all the rewrite-rules iteratively as determined by the underlying rewriting-strategy.
|
||||
Represents an IGraphTransformer determined by a set of rewrite rules.
|
||||
The transformer will apply all the rewrite rules iteratively as determined by the underlying rewriting strategy.
|
||||
Several rewriting-strategies are possible when traversing the graph and applying rewrite rules,
|
||||
each with different trade offs. At the moment, we define one that performs top-down traversal of nodes.
|
||||
|
||||
|
|
@ -89,38 +89,32 @@ class RuleBasedGraphTransformer : public GraphTransformer {
|
|||
: GraphTransformer(name, desc) {}
|
||||
|
||||
/**
|
||||
Register a rewriting rule.
|
||||
|
||||
@TODO (revisit needed): Using OpSignature* here will ask that OpSignature should be stored globally.
|
||||
Otherwise, there will be multiple addresses/pointers for the same operator or function.
|
||||
To avoid this, we may use OpSignature ID as the key, which should be name_domain_version.
|
||||
We will use the string type instead of the OpSchema for now. We should probably add a version as well.
|
||||
Register a rewrite rule in this transformer.
|
||||
*/
|
||||
Status Register(const std::string& op_type, std::unique_ptr<RewriteRule> rule);
|
||||
|
||||
/** Check if the given op_type has any rules registered for it
|
||||
@returns true if there are rules registered for this op_type.*/
|
||||
bool HasRules(const std::string& op_type) const {
|
||||
return op_to_rules_.find(op_type) != op_to_rules_.cend();
|
||||
}
|
||||
Status Register(std::unique_ptr<RewriteRule> rule);
|
||||
|
||||
/**
|
||||
Gets the rewrite rules for the given op_type.
|
||||
@returns a pointer to the vector containing all the rewrite rules registered for op_type if found. nullptr
|
||||
otherwise.
|
||||
Gets the list of registered rewrite rules in this rule-based transformer.
|
||||
@returns a reference to the vector containing all the registered rewrite rules.
|
||||
*/
|
||||
const std::vector<std::unique_ptr<RewriteRule>>* GetRewriteRules(const std::string& op_type) const {
|
||||
auto entry = op_to_rules_.find(op_type);
|
||||
if (entry != op_to_rules_.cend())
|
||||
return &entry->second;
|
||||
|
||||
return nullptr;
|
||||
const std::vector<std::unique_ptr<RewriteRule>>& GetRewriteRules() const {
|
||||
return rules_;
|
||||
}
|
||||
|
||||
private:
|
||||
using RewriteRuleSet = std::unordered_map<std::string, std::vector<std::unique_ptr<RewriteRule>>>;
|
||||
protected:
|
||||
/** Apply the given set of rewrite rules on the Node of this Graph.
|
||||
@param[in] graph The Graph.
|
||||
@param[in] node The Node to apply the rules to.
|
||||
@param[in] rules The vector of RewriteRules that will be applied to the Node.
|
||||
@param[out] modified Set to indicate whether the node was modified or not.
|
||||
@param[out] deleted Set to indicate if the node was deleted.
|
||||
@returns Status indicating success or providing error information */
|
||||
common::Status ApplyRulesOnNode(Graph& graph, Node& node,
|
||||
const std::vector<std::unique_ptr<RewriteRule>>& rules,
|
||||
bool& modified, bool& deleted) const;
|
||||
|
||||
RewriteRuleSet op_to_rules_;
|
||||
private:
|
||||
std::vector<std::unique_ptr<RewriteRule>> rules_;
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class RewriteRule {
|
|||
@param[out] deleted Set to indicate if the node was deleted.
|
||||
@returns Status indicating success or providing error information */
|
||||
common::Status CheckConditionAndApply(Graph& graph, Node& node, bool& modified, bool& deleted) {
|
||||
return SatisfyCondition(node) ? Apply(graph, node, modified, deleted) : Status::OK();
|
||||
return SatisfyCondition(graph, node) ? Apply(graph, node, modified, deleted) : Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -53,11 +53,14 @@ class RewriteRule {
|
|||
const std::string name_;
|
||||
const std::string desc_;
|
||||
|
||||
/** Check if the Node satisfies a condition.
|
||||
/** Check if the Node of the given Graph satisfies a condition.
|
||||
The rewrite rule is applied if the condition function returns true. This can include
|
||||
a more complex pattern matching (conditions on the ascending or descending nodes of the
|
||||
node for which this rule was triggered) or some other properties of the nodes. */
|
||||
virtual bool SatisfyCondition(const Node& node) = 0;
|
||||
virtual bool SatisfyCondition(const Graph& graph, const Node& node) = 0;
|
||||
|
||||
/** Returns true if the op type of the node is compatible with this rewrite rule. */
|
||||
virtual bool OpTypeCondition(const Node& node) = 0;
|
||||
|
||||
/**
|
||||
Apply the rewrite rule to a specific node.
|
||||
|
|
|
|||
|
|
@ -1,5 +1,9 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/graph/graph.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -106,5 +110,44 @@ bool RemoveSingleInSingleOutNode(Graph& graph, Node& node) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool HasGraphInput(const Graph& graph, const NodeArg* input) {
|
||||
const std::vector<const NodeArg*>& graph_inputs = graph.GetInputsIncludingInitializers();
|
||||
return std::find(graph_inputs.begin(), graph_inputs.end(), input) != graph_inputs.end();
|
||||
}
|
||||
|
||||
bool IsConstantInputsNode(const Graph& graph, const Node& node) {
|
||||
if (node.GetInputEdgesCount() > 0) {
|
||||
return false;
|
||||
}
|
||||
const onnx::TensorProto* initializer = nullptr;
|
||||
for (const auto* input_def : node.InputDefs()) {
|
||||
// Important note: when an initializer appears in the graph's input, this input will not be considered constant,
|
||||
// because it can be overriden by the user at runtime. For constant folding to be applied, the initializer should not
|
||||
// appear in the graph's inputs (that is the only way to guarantee it will always be constant).
|
||||
if (!graph.GetInitializedTensor(input_def->Name(), initializer) || HasGraphInput(graph, input_def)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t RemoveNodeOutputEdges(Graph& graph, Node& node) {
|
||||
std::vector<std::tuple<NodeIndex, int, int>> edges_to_remove;
|
||||
for (auto it = node.OutputEdgesBegin(); it != node.OutputEdgesEnd(); ++it) {
|
||||
edges_to_remove.emplace_back(std::make_tuple(it->GetNode().Index(),
|
||||
it->GetSrcArgIndex(),
|
||||
it->GetDstArgIndex()));
|
||||
}
|
||||
for (auto& edge_to_remove : edges_to_remove) {
|
||||
graph.RemoveEdge(node.Index(),
|
||||
std::get<0>(edge_to_remove),
|
||||
std::get<1>(edge_to_remove),
|
||||
std::get<2>(edge_to_remove));
|
||||
}
|
||||
|
||||
return edges_to_remove.size();
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -40,5 +40,16 @@ bool GetRepeatedNodeAttributeValues(const Node& node,
|
|||
/** Remove the given single-input-single-output Node from the Graph. */
|
||||
bool RemoveSingleInSingleOutNode(Graph& graph, Node& node);
|
||||
|
||||
/** Returns true if the graph has the given input.*/
|
||||
bool HasGraphInput(const Graph& graph, const NodeArg* input);
|
||||
|
||||
/** Checks if the given node has only constant inputs (initializers). */
|
||||
bool IsConstantInputsNode(const Graph& graph, const Node& node);
|
||||
|
||||
/** Remove all output edges from the given Node of the Graph.
|
||||
This should probably be elevated to the Graph API eventually. */
|
||||
size_t RemoveNodeOutputEdges(Graph& graph, Node& node);
|
||||
|
||||
} // namespace utils
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
87
onnxruntime/core/optimizer/constant_folding.cc
Normal file
87
onnxruntime/core/optimizer/constant_folding.cc
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/optimizer/constant_folding.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/optimizer/optimizer_execution_frame.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/framework/ml_value.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
Status ConstantFolding::Apply(Graph& graph, Node& node, bool& modified, bool& deleted) {
|
||||
// TODO Check if we need default transformers any more. I dont think we do...
|
||||
|
||||
// Create execution frame for executing constant nodes.
|
||||
OptimizerExecutionFrame::Info info({&node}, graph.GetAllInitializedTensors());
|
||||
|
||||
std::vector<int> fetch_mlvalue_idxs;
|
||||
for (const auto* node_out : node.OutputDefs()) {
|
||||
fetch_mlvalue_idxs.push_back(info.GetMLValueIndex(node_out->Name()));
|
||||
}
|
||||
|
||||
OptimizerExecutionFrame frame(info, fetch_mlvalue_idxs);
|
||||
|
||||
auto* kernel = info.GetKernel(node.Index());
|
||||
OpKernelContext op_kernel_context(&frame, kernel, ::onnxruntime::logging::LoggingManager::DefaultLogger());
|
||||
|
||||
kernel->Compute(&op_kernel_context);
|
||||
|
||||
std::vector<MLValue> fetches;
|
||||
frame.GetOutputs(fetches);
|
||||
|
||||
// Go over all output node args and substitute them with the newly computed tensors, which will be
|
||||
// added to the graph as initializers.
|
||||
ORT_ENFORCE(fetches.size() == node.OutputDefs().size());
|
||||
for (int fetch_idx = 0; fetch_idx < fetches.size(); ++fetch_idx) {
|
||||
MLValue& mlvalue = fetches[fetch_idx];
|
||||
|
||||
// Build the TensorProto that corresponds to the computed MLValue and add it as initializer to the graph.
|
||||
ONNX_NAMESPACE::TensorProto out_tensorproto;
|
||||
const auto* constant_arg_out = node.OutputDefs()[fetch_idx];
|
||||
BuildTensorProtoForInitializer(mlvalue, *constant_arg_out, out_tensorproto);
|
||||
|
||||
graph.AddInitializedTensor(out_tensorproto);
|
||||
}
|
||||
|
||||
// Remove the output edges of the constant node and then remove the node itself.
|
||||
utils::RemoveNodeOutputEdges(graph, node);
|
||||
graph.RemoveNode(node.Index());
|
||||
|
||||
// The output nodes already have the right input arg, since we used the same name in the initializer.
|
||||
// We could remove unused graph initializers here, but Graph::Resolve() will take care of it.
|
||||
|
||||
modified = deleted = true;
|
||||
|
||||
return Status::OK();
|
||||
} // namespace onnxruntime
|
||||
|
||||
bool ConstantFolding::SatisfyCondition(const Graph& graph, const Node& node) {
|
||||
return OpTypeCondition(node) && utils::IsConstantInputsNode(graph, node);
|
||||
}
|
||||
|
||||
bool ConstantFolding::OpTypeCondition(const Node& node) {
|
||||
return excluded_op_types_.find(node.OpType()) == excluded_op_types_.end();
|
||||
}
|
||||
|
||||
void ConstantFolding::BuildTensorProtoForInitializer(const MLValue& mlvalue,
|
||||
const NodeArg& constant_node_arg,
|
||||
ONNX_NAMESPACE::TensorProto& tensorproto) {
|
||||
ORT_ENFORCE(mlvalue.IsTensor());
|
||||
const Tensor& out_tensor = mlvalue.Get<Tensor>();
|
||||
|
||||
// Set name, dimensions, type, and data of the TensorProto.
|
||||
tensorproto.set_name(constant_node_arg.Name());
|
||||
|
||||
for (auto& dim : out_tensor.Shape().GetDims()) {
|
||||
tensorproto.add_dims(dim);
|
||||
}
|
||||
auto tensorproto_type = constant_node_arg.TypeAsProto()->tensor_type().elem_type();
|
||||
|
||||
tensorproto.set_data_type(tensorproto_type);
|
||||
auto tensor_shape_size = out_tensor.Shape().Size();
|
||||
auto data_size = out_tensor.DataType()->Size() * tensor_shape_size;
|
||||
tensorproto.set_raw_data(out_tensor.DataRaw(out_tensor.DataType()), data_size);
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
41
onnxruntime/core/optimizer/constant_folding.h
Normal file
41
onnxruntime/core/optimizer/constant_folding.h
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/rewrite_rule.h"
|
||||
#include "core/framework/ml_value.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
@class ConstantFolding
|
||||
|
||||
Rewrite rule that performs constant folding to the graph.
|
||||
The rule gets applied to nodes that have only initializers as inputs. It statically computes these
|
||||
nodes and replaces their output with an initializer that corresponds to the result of the computation.
|
||||
*/
|
||||
class ConstantFolding : public RewriteRule {
|
||||
public:
|
||||
ConstantFolding() noexcept : RewriteRule("ConstantFolding", "Constant folding") {}
|
||||
|
||||
private:
|
||||
/** Constant folding will not be applied to nodes whose op_type is included in this set.
|
||||
All non-deterministic operators should be included in this set. */
|
||||
const std::unordered_set<std::string> excluded_op_types_ =
|
||||
{"RandomUniform", "RandomNormal", "RandomUniformLike", "RandomNormalLike", "Multinomial"};
|
||||
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node) override;
|
||||
|
||||
bool OpTypeCondition(const Node& node) override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, bool& modified, bool& deleted) override;
|
||||
|
||||
/** Create a TensorProto that has the same value as the given MLValue
|
||||
and the same type and dimensions as the given NodeArg. */
|
||||
void BuildTensorProtoForInitializer(const MLValue& mlvalue,
|
||||
const NodeArg& constant_node_arg,
|
||||
ONNX_NAMESPACE::TensorProto& tensorproto);
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -44,7 +44,6 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
std::deque<onnxruntime::NodeIndex> removed_nodes;
|
||||
for (auto index : order) {
|
||||
auto node = graph.GetNode(index);
|
||||
|
||||
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level));
|
||||
|
||||
if (!utils::IsSupportedOptypeVersionAndDomain(*node, "Conv", 1) || node->GetOutputEdgesCount() != 1) {
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ Status ConvMulFusion::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int g
|
|||
std::vector<onnxruntime::NodeIndex> removed_nodes;
|
||||
for (auto& node : graph.Nodes()) {
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level));
|
||||
|
||||
|
||||
if (!utils::IsSupportedOptypeVersionAndDomain(node, "Conv", 1) || node.GetOutputEdgesCount() != 1) {
|
||||
continue;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,7 +12,10 @@ namespace onnxruntime {
|
|||
|
||||
namespace {
|
||||
bool IsFusableActivation(const Node& node) {
|
||||
return utils::IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Relu", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", 6);
|
||||
return utils::IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", 6) ||
|
||||
utils::IsSupportedOptypeVersionAndDomain(node, "Relu", 6) ||
|
||||
utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", 6) ||
|
||||
utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", 6);
|
||||
}
|
||||
|
||||
void HandleActivationNodeEdges(Graph& g, const Node& act, Node& fused_gemm) {
|
||||
|
|
|
|||
|
|
@ -8,13 +8,13 @@ using namespace ::onnxruntime::common;
|
|||
namespace onnxruntime {
|
||||
|
||||
Status GraphTransformer::Apply(Graph& graph, bool& modified) const {
|
||||
// the Graph should be in a good state prior this being called, so there should be no need to call Resolve here
|
||||
// The Graph should be in a good state prior this being called, so there should be no need to call Resolve here.
|
||||
// ORT_RETURN_IF_ERROR(graph.Resolve());
|
||||
|
||||
auto status = ApplyImpl(graph, modified, 0);
|
||||
ORT_RETURN_IF_ERROR(status);
|
||||
|
||||
// at least currently, some transformers (InsertCastTransformer and MemcpyTransformer need this to be called
|
||||
// At least currently, some transformers (InsertCastTransformer and MemcpyTransformer) need this to be called
|
||||
// after they complete to put the graph back into a valid state for the next transformer.
|
||||
if (modified) {
|
||||
status = graph.Resolve();
|
||||
|
|
@ -23,12 +23,21 @@ Status GraphTransformer::Apply(Graph& graph, bool& modified) const {
|
|||
return status;
|
||||
}
|
||||
|
||||
Status RuleBasedGraphTransformer::Register(const std::string& op_type, std::unique_ptr<RewriteRule> rule) {
|
||||
if (HasRules(op_type)) {
|
||||
op_to_rules_[op_type] = std::vector<std::unique_ptr<RewriteRule>>();
|
||||
}
|
||||
Status RuleBasedGraphTransformer::Register(std::unique_ptr<RewriteRule> rule) {
|
||||
rules_.push_back(std::move(rule));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
op_to_rules_[op_type].push_back(std::move(rule));
|
||||
Status RuleBasedGraphTransformer::ApplyRulesOnNode(Graph& graph, Node& node,
|
||||
const std::vector<std::unique_ptr<RewriteRule>>& rules,
|
||||
bool& modified, bool& deleted) const {
|
||||
for (const auto& rule : rules) {
|
||||
ORT_RETURN_IF_ERROR(rule->CheckConditionAndApply(graph, node, modified, deleted));
|
||||
if (deleted) {
|
||||
modified = true; // should be set by rewriter but in case it wasn't...
|
||||
break;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -42,19 +51,10 @@ Status TopDownRuleBasedTransformer::ApplyImpl(Graph& graph, bool& modified, int
|
|||
return Status(ONNXRUNTIME, INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
// Get the rules that should be fired for this node.
|
||||
const std::vector<std::unique_ptr<RewriteRule>>* rules = GetRewriteRules(node->OpType());
|
||||
|
||||
// Apply rewrite rules on current node, then recursively apply rules to subgraphs (if any).
|
||||
// Stop further rule application for the current node, if the node gets removed by a rule.
|
||||
bool deleted = false;
|
||||
if (rules) {
|
||||
for (const auto& rule : *rules) {
|
||||
ORT_RETURN_IF_ERROR(rule->CheckConditionAndApply(graph, *node, modified, deleted));
|
||||
if (deleted) {
|
||||
modified = true; // should be set by rewriter but in case it wasn't...
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, GetRewriteRules(), modified, deleted));
|
||||
|
||||
if (!deleted) {
|
||||
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level));
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
#include "core/optimizer/constant_folding.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
// Manages a list of graph transformers. It is initialized with a list of graph
|
||||
|
|
@ -11,7 +12,7 @@ namespace onnxruntime {
|
|||
class GraphTransformerManager {
|
||||
public:
|
||||
explicit GraphTransformerManager(unsigned steps) noexcept : steps_(steps) {
|
||||
// TODO: Register default transformers.
|
||||
// Register default transformers.
|
||||
}
|
||||
|
||||
// Register a graph transformer.
|
||||
|
|
|
|||
|
|
@ -18,8 +18,12 @@ Status EliminateIdentity::Apply(Graph& graph, Node& node, bool& modified, bool&
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool EliminateIdentity::SatisfyCondition(const Node& node) {
|
||||
return utils::IsSingleInSingleOutNode(node);
|
||||
bool EliminateIdentity::SatisfyCondition(const Graph& /*graph*/, const Node& node) {
|
||||
return OpTypeCondition(node) && utils::IsSingleInSingleOutNode(node);
|
||||
}
|
||||
|
||||
bool EliminateIdentity::OpTypeCondition(const Node& node) {
|
||||
return node.OpType() == included_op_type_;
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -13,7 +13,12 @@ class EliminateIdentity : public RewriteRule {
|
|||
EliminateIdentity() noexcept : RewriteRule("EliminateIdentity", "Eliminate identity node") {}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Node& node) override;
|
||||
/** Apply rule when op type is one of the following. */
|
||||
const std::string included_op_type_ = "Identity";
|
||||
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node) override;
|
||||
|
||||
bool OpTypeCondition(const Node& node) override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, bool& modified, bool& deleted) override;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -16,7 +16,11 @@ Status EliminateSlice::Apply(Graph& graph, Node& node, bool& modified, bool& rem
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool EliminateSlice::SatisfyCondition(const Node& node) {
|
||||
bool EliminateSlice::SatisfyCondition(const Graph& /*graph*/, const Node& node) {
|
||||
if (!OpTypeCondition(node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// At the moment, we eliminate a slice operator only if it has a single input and a single output.
|
||||
if (!utils::IsSingleInSingleOutNode(node)) {
|
||||
return false;
|
||||
|
|
@ -50,4 +54,8 @@ bool EliminateSlice::SatisfyCondition(const Node& node) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool EliminateSlice::OpTypeCondition(const Node& node) {
|
||||
return node.OpType() == included_op_type_;
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -13,7 +13,12 @@ class EliminateSlice : public RewriteRule {
|
|||
EliminateSlice() noexcept : RewriteRule("EliminateSlice", "Eliminate slice node") {}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Node& node) override;
|
||||
/** Apply rule when op type is one of the following. */
|
||||
const std::string included_op_type_ = "Slice";
|
||||
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node) override;
|
||||
|
||||
bool OpTypeCondition(const Node& node) override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, bool& modified, bool& removed) override;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ namespace test {
|
|||
|
||||
static const std::string MODEL_FOLDER = "testdata/transform/";
|
||||
|
||||
// Return a map with the number of occurrences of each operator in the graph.
|
||||
// Returns a map with the number of occurrences of each operator in the graph.
|
||||
// Helper function to check that the graph transformations have been successfully applied.
|
||||
std::map<std::string, int> CountOpsInGraph(const Graph& graph) {
|
||||
std::map<std::string, int> op_to_count;
|
||||
|
|
@ -53,7 +53,7 @@ TEST(GraphTransformationTests, IdentityElimination) {
|
|||
|
||||
std::unique_ptr<TopDownRuleBasedTransformer> rule_transformer =
|
||||
std::make_unique<TopDownRuleBasedTransformer>("RuleTransformer1", "First rule transformer");
|
||||
rule_transformer->Register("Identity", std::make_unique<EliminateIdentity>());
|
||||
rule_transformer->Register(std::make_unique<EliminateIdentity>());
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer));
|
||||
ASSERT_TRUE(graph_transformation_mgr.ApplyAll(graph).IsOK());
|
||||
|
|
@ -72,7 +72,7 @@ TEST(GraphTransformationTests, SliceElimination) {
|
|||
|
||||
std::unique_ptr<TopDownRuleBasedTransformer> rule_transformer =
|
||||
std::make_unique<TopDownRuleBasedTransformer>("RuleTransformer1", "First rule transformer");
|
||||
rule_transformer->Register("Slice", std::make_unique<EliminateSlice>());
|
||||
rule_transformer->Register(std::make_unique<EliminateSlice>());
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer));
|
||||
ASSERT_TRUE(graph_transformation_mgr.ApplyAll(graph).IsOK());
|
||||
|
|
@ -81,6 +81,27 @@ TEST(GraphTransformationTests, SliceElimination) {
|
|||
ASSERT_TRUE(op_to_count["Slice"] == 3);
|
||||
}
|
||||
|
||||
TEST(GraphTransformationTests, ConstantFolding) {
|
||||
string model_uri = MODEL_FOLDER + "fusion/fuse-conv-bn-mul-add-unsqueeze.onnx";
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_TRUE(Model::Load(model_uri, model).IsOK());
|
||||
Graph& graph = model->MainGraph();
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Unsqueeze"] == 2);
|
||||
|
||||
std::unique_ptr<TopDownRuleBasedTransformer> rule_transformer =
|
||||
std::make_unique<TopDownRuleBasedTransformer>("RuleTransformer1", "First rule transformer");
|
||||
|
||||
rule_transformer->Register(std::make_unique<ConstantFolding>());
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer));
|
||||
ASSERT_TRUE(graph_transformation_mgr.ApplyAll(graph).IsOK());
|
||||
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Unsqueeze"] == 0);
|
||||
}
|
||||
|
||||
TEST(GraphTransformationTests, FuseConvBNMulAddUnsqueeze) {
|
||||
string model_uri = MODEL_FOLDER + "fusion/fuse-conv-bn-mul-add-unsqueeze.onnx";
|
||||
|
||||
|
|
|
|||
Binary file not shown.
Loading…
Reference in a new issue