Constant folding (#168)

The constant folding rewrite rule evaluates, at compile time, the nodes whose inputs are all constant, so that these computations are avoided at run time.
This commit is contained in:
Konstantinos Karanasos 2019-03-13 15:44:26 -07:00 committed by GitHub
parent ab734ec5a6
commit 2ae83c580c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 287 additions and 62 deletions

View file

@ -69,11 +69,11 @@ class GraphTransformer {
/**
@class RuleBasedGraphTransformer
Rule based graph transformer that provides an API to register rewrite rules,
Rule-based graph transformer that provides an API to register rewrite rules
and an API to apply all applicable rules to a Graph.
Represents an IGraphTransformer determined by a set of rewrite-rules.
The transformer will apply all the rewrite-rules iteratively as determined by the underlying rewriting-strategy.
Represents an IGraphTransformer determined by a set of rewrite rules.
The transformer will apply all the rewrite rules iteratively as determined by the underlying rewriting strategy.
Several rewriting strategies are possible when traversing the graph and applying rewrite rules,
each with different trade-offs. At the moment, we define one that performs top-down traversal of nodes.
@ -89,38 +89,32 @@ class RuleBasedGraphTransformer : public GraphTransformer {
: GraphTransformer(name, desc) {}
/**
Register a rewriting rule.
@TODO (revisit needed): Using OpSignature* here will ask that OpSignature should be stored globally.
Otherwise, there will be multiple addresses/pointers for the same operator or function.
To avoid this, we may use OpSignature ID as the key, which should be name_domain_version.
We will use the string type instead of the OpSchema for now. We should probably add a version as well.
Register a rewrite rule in this transformer.
*/
Status Register(const std::string& op_type, std::unique_ptr<RewriteRule> rule);
/** Check if the given op_type has any rules registered for it
@returns true if there are rules registered for this op_type.*/
bool HasRules(const std::string& op_type) const {
return op_to_rules_.find(op_type) != op_to_rules_.cend();
}
Status Register(std::unique_ptr<RewriteRule> rule);
/**
Gets the rewrite rules for the given op_type.
@returns a pointer to the vector containing all the rewrite rules registered for op_type if found. nullptr
otherwise.
Gets the list of registered rewrite rules in this rule-based transformer.
@returns a reference to the vector containing all the registered rewrite rules.
*/
const std::vector<std::unique_ptr<RewriteRule>>* GetRewriteRules(const std::string& op_type) const {
auto entry = op_to_rules_.find(op_type);
if (entry != op_to_rules_.cend())
return &entry->second;
return nullptr;
const std::vector<std::unique_ptr<RewriteRule>>& GetRewriteRules() const {
return rules_;
}
private:
using RewriteRuleSet = std::unordered_map<std::string, std::vector<std::unique_ptr<RewriteRule>>>;
protected:
/** Apply the given set of rewrite rules on the Node of this Graph.
@param[in] graph The Graph.
@param[in] node The Node to apply the rules to.
@param[in] rules The vector of RewriteRules that will be applied to the Node.
@param[out] modified Set to indicate whether the node was modified or not.
@param[out] deleted Set to indicate if the node was deleted.
@returns Status indicating success or providing error information */
common::Status ApplyRulesOnNode(Graph& graph, Node& node,
const std::vector<std::unique_ptr<RewriteRule>>& rules,
bool& modified, bool& deleted) const;
RewriteRuleSet op_to_rules_;
private:
std::vector<std::unique_ptr<RewriteRule>> rules_;
};
/**

View file

@ -44,7 +44,7 @@ class RewriteRule {
@param[out] deleted Set to indicate if the node was deleted.
@returns Status indicating success or providing error information */
common::Status CheckConditionAndApply(Graph& graph, Node& node, bool& modified, bool& deleted) {
return SatisfyCondition(node) ? Apply(graph, node, modified, deleted) : Status::OK();
return SatisfyCondition(graph, node) ? Apply(graph, node, modified, deleted) : Status::OK();
}
private:
@ -53,11 +53,14 @@ class RewriteRule {
const std::string name_;
const std::string desc_;
/** Check if the Node satisfies a condition.
/** Check if the Node of the given Graph satisfies a condition.
The rewrite rule is applied if the condition function returns true. This can include
a more complex pattern matching (conditions on the ascending or descending nodes of the
node for which this rule was triggered) or some other properties of the nodes. */
virtual bool SatisfyCondition(const Node& node) = 0;
virtual bool SatisfyCondition(const Graph& graph, const Node& node) = 0;
/** Returns true if the op type of the node is compatible with this rewrite rule. */
virtual bool OpTypeCondition(const Node& node) = 0;
/**
Apply the rewrite rule to a specific node.

View file

@ -1,5 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/graph/graph_utils.h"
#include "core/graph/graph.h"
#include "core/framework/tensorprotoutils.h"
namespace onnxruntime {
@ -106,5 +110,44 @@ bool RemoveSingleInSingleOutNode(Graph& graph, Node& node) {
return true;
}
bool HasGraphInput(const Graph& graph, const NodeArg* input) {
const std::vector<const NodeArg*>& graph_inputs = graph.GetInputsIncludingInitializers();
return std::find(graph_inputs.begin(), graph_inputs.end(), input) != graph_inputs.end();
}
bool IsConstantInputsNode(const Graph& graph, const Node& node) {
if (node.GetInputEdgesCount() > 0) {
return false;
}
const onnx::TensorProto* initializer = nullptr;
for (const auto* input_def : node.InputDefs()) {
// Important note: when an initializer appears in the graph's input, this input will not be considered constant,
// because it can be overriden by the user at runtime. For constant folding to be applied, the initializer should not
// appear in the graph's inputs (that is the only way to guarantee it will always be constant).
if (!graph.GetInitializedTensor(input_def->Name(), initializer) || HasGraphInput(graph, input_def)) {
return false;
}
}
return true;
}
/** Removes every output edge of the given node from the graph.
    @param graph The Graph that owns the node.
    @param node  The Node whose output edges are removed.
    @returns the number of edges that were removed. */
size_t RemoveNodeOutputEdges(Graph& graph, Node& node) {
  // (destination node index, source arg index, destination arg index) for each output edge.
  using EdgeEnd = std::tuple<NodeIndex, int, int>;

  // Collect the edges in a first pass and remove them in a second one
  // (presumably because RemoveEdge changes the edge set being iterated -- confirm).
  std::vector<EdgeEnd> pending_removals;
  for (auto edge_it = node.OutputEdgesBegin(); edge_it != node.OutputEdgesEnd(); ++edge_it) {
    pending_removals.emplace_back(edge_it->GetNode().Index(),
                                  edge_it->GetSrcArgIndex(),
                                  edge_it->GetDstArgIndex());
  }

  const NodeIndex src_node_index = node.Index();
  for (const EdgeEnd& edge_end : pending_removals) {
    graph.RemoveEdge(src_node_index,
                     std::get<0>(edge_end),
                     std::get<1>(edge_end),
                     std::get<2>(edge_end));
  }

  return pending_removals.size();
}
} // namespace utils
} // namespace onnxruntime

View file

@ -40,5 +40,16 @@ bool GetRepeatedNodeAttributeValues(const Node& node,
/** Remove the given single-input-single-output Node from the Graph. */
bool RemoveSingleInSingleOutNode(Graph& graph, Node& node);
/** Returns true if the graph has the given input.*/
bool HasGraphInput(const Graph& graph, const NodeArg* input);
/** Checks if the given node has only constant inputs (initializers). */
bool IsConstantInputsNode(const Graph& graph, const Node& node);
/** Remove all output edges from the given Node of the Graph.
This should probably be elevated to the Graph API eventually. */
size_t RemoveNodeOutputEdges(Graph& graph, Node& node);
} // namespace utils
} // namespace onnxruntime

View file

@ -0,0 +1,87 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/optimizer/constant_folding.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/optimizer_execution_frame.h"
#include "core/framework/op_kernel.h"
#include "core/framework/ml_value.h"
namespace onnxruntime {
/** Statically computes the given node and replaces its outputs with initializers.
    @param[in]  graph    The Graph that owns the node; receives the new initializers.
    @param[in]  node     The node to fold. It is removed from the graph on success.
    @param[out] modified Set to true when the node has been folded.
    @param[out] deleted  Set to true when the node has been removed from the graph.
    @returns Status::OK() on success, or the error reported by the node's kernel. */
Status ConstantFolding::Apply(Graph& graph, Node& node, bool& modified, bool& deleted) {
  // TODO Check if we need default transformers any more. I dont think we do...

  // Create execution frame for executing constant nodes.
  OptimizerExecutionFrame::Info info({&node}, graph.GetAllInitializedTensors());

  std::vector<int> fetch_mlvalue_idxs;
  for (const auto* node_out : node.OutputDefs()) {
    fetch_mlvalue_idxs.push_back(info.GetMLValueIndex(node_out->Name()));
  }

  OptimizerExecutionFrame frame(info, fetch_mlvalue_idxs);

  auto* kernel = info.GetKernel(node.Index());
  OpKernelContext op_kernel_context(&frame, kernel, ::onnxruntime::logging::LoggingManager::DefaultLogger());

  // Propagate kernel failures instead of folding a result that was never computed.
  ORT_RETURN_IF_ERROR(kernel->Compute(&op_kernel_context));

  // NOTE(review): if GetOutputs also reports a Status, it should be checked here as well -- confirm.
  std::vector<MLValue> fetches;
  frame.GetOutputs(fetches);

  // Go over all output node args and substitute them with the newly computed tensors, which will be
  // added to the graph as initializers.
  ORT_ENFORCE(fetches.size() == node.OutputDefs().size());
  for (size_t fetch_idx = 0; fetch_idx < fetches.size(); ++fetch_idx) {
    MLValue& mlvalue = fetches[fetch_idx];

    // Build the TensorProto that corresponds to the computed MLValue and add it as initializer to the graph.
    ONNX_NAMESPACE::TensorProto out_tensorproto;
    const auto* constant_arg_out = node.OutputDefs()[fetch_idx];
    BuildTensorProtoForInitializer(mlvalue, *constant_arg_out, out_tensorproto);

    graph.AddInitializedTensor(out_tensorproto);
  }

  // Remove the output edges of the constant node and then remove the node itself.
  utils::RemoveNodeOutputEdges(graph, node);
  graph.RemoveNode(node.Index());

  // The output nodes already have the right input arg, since we used the same name in the initializer.
  // We could remove unused graph initializers here, but Graph::Resolve() will take care of it.

  modified = deleted = true;

  return Status::OK();
}
/** The rule fires only for nodes whose op type is not excluded and whose inputs are all constant. */
bool ConstantFolding::SatisfyCondition(const Graph& graph, const Node& node) {
  if (!OpTypeCondition(node)) {
    return false;
  }
  return utils::IsConstantInputsNode(graph, node);
}
/** Returns true unless the node's op type is listed in excluded_op_types_. */
bool ConstantFolding::OpTypeCondition(const Node& node) {
  const bool is_excluded = excluded_op_types_.count(node.OpType()) != 0;
  return !is_excluded;
}
/** Fills a TensorProto from a computed MLValue so that it can be added to the graph as an initializer.
    @param[in]  mlvalue           The computed value; must hold a Tensor.
    @param[in]  constant_node_arg The NodeArg that the initializer replaces; supplies the name and element type.
    @param[out] tensorproto       Receives the name, dimensions, data type and raw data. */
void ConstantFolding::BuildTensorProtoForInitializer(const MLValue& mlvalue,
                                                     const NodeArg& constant_node_arg,
                                                     ONNX_NAMESPACE::TensorProto& tensorproto) {
  ORT_ENFORCE(mlvalue.IsTensor());
  const Tensor& out_tensor = mlvalue.Get<Tensor>();

  // The element type is taken from the NodeArg; fail with a clear error rather than dereferencing
  // a null type proto.
  const auto* type_proto = constant_node_arg.TypeAsProto();
  ORT_ENFORCE(type_proto != nullptr, "No type information for node arg ", constant_node_arg.Name());

  // Set name, dimensions, type, and data of the TensorProto.
  tensorproto.set_name(constant_node_arg.Name());

  for (const auto& dim : out_tensor.Shape().GetDims()) {
    tensorproto.add_dims(dim);
  }

  auto tensorproto_type = type_proto->tensor_type().elem_type();
  tensorproto.set_data_type(tensorproto_type);

  // The tensor buffer is copied byte-for-byte into raw_data.
  // NOTE(review): this presumably only suits fixed-size element types -- confirm that string tensors
  // cannot reach this point.
  auto tensor_shape_size = out_tensor.Shape().Size();
  auto data_size = out_tensor.DataType()->Size() * tensor_shape_size;
  tensorproto.set_raw_data(out_tensor.DataRaw(out_tensor.DataType()), data_size);
}
} // namespace onnxruntime

View file

@ -0,0 +1,41 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/rewrite_rule.h"
#include "core/framework/ml_value.h"
namespace onnxruntime {
/**
@class ConstantFolding
Rewrite rule that performs constant folding to the graph.
The rule gets applied to nodes that have only initializers as inputs. It statically computes these
nodes and replaces their output with an initializer that corresponds to the result of the computation.
*/
class ConstantFolding : public RewriteRule {
 public:
  ConstantFolding() noexcept : RewriteRule("ConstantFolding", "Constant folding") {}

 private:
  /** True when the node's op type is eligible and the node has only constant inputs. */
  bool SatisfyCondition(const Graph& graph, const Node& node) override;

  /** True when the node's op type is not in excluded_op_types_. */
  bool OpTypeCondition(const Node& node) override;

  /** Computes the node, stores its outputs as initializers and removes the node from the graph. */
  Status Apply(Graph& graph, Node& node, bool& modified, bool& deleted) override;

  /** Builds a TensorProto holding the value of the given MLValue, using the name and
      element type of the given NodeArg and the dimensions of the computed tensor. */
  void BuildTensorProtoForInitializer(const MLValue& mlvalue,
                                      const NodeArg& constant_node_arg,
                                      ONNX_NAMESPACE::TensorProto& tensorproto);

  /** Op types that are never folded. Every non-deterministic operator belongs in this set. */
  const std::unordered_set<std::string> excluded_op_types_ =
      {"RandomUniform", "RandomNormal", "RandomUniformLike", "RandomNormalLike", "Multinomial"};
};
} // namespace onnxruntime

View file

@ -44,7 +44,6 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
std::deque<onnxruntime::NodeIndex> removed_nodes;
for (auto index : order) {
auto node = graph.GetNode(index);
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level));
if (!utils::IsSupportedOptypeVersionAndDomain(*node, "Conv", 1) || node->GetOutputEdgesCount() != 1) {

View file

@ -13,7 +13,7 @@ Status ConvMulFusion::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int g
std::vector<onnxruntime::NodeIndex> removed_nodes;
for (auto& node : graph.Nodes()) {
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level));
if (!utils::IsSupportedOptypeVersionAndDomain(node, "Conv", 1) || node.GetOutputEdgesCount() != 1) {
continue;
}

View file

@ -12,7 +12,10 @@ namespace onnxruntime {
namespace {
bool IsFusableActivation(const Node& node) {
return utils::IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Relu", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", 6);
return utils::IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", 6) ||
utils::IsSupportedOptypeVersionAndDomain(node, "Relu", 6) ||
utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", 6) ||
utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", 6);
}
void HandleActivationNodeEdges(Graph& g, const Node& act, Node& fused_gemm) {

View file

@ -8,13 +8,13 @@ using namespace ::onnxruntime::common;
namespace onnxruntime {
Status GraphTransformer::Apply(Graph& graph, bool& modified) const {
// the Graph should be in a good state prior this being called, so there should be no need to call Resolve here
// The Graph should be in a good state prior to this being called, so there should be no need to call Resolve here.
// ORT_RETURN_IF_ERROR(graph.Resolve());
auto status = ApplyImpl(graph, modified, 0);
ORT_RETURN_IF_ERROR(status);
// at least currently, some transformers (InsertCastTransformer and MemcpyTransformer need this to be called
// At least currently, some transformers (InsertCastTransformer and MemcpyTransformer) need this to be called
// after they complete to put the graph back into a valid state for the next transformer.
if (modified) {
status = graph.Resolve();
@ -23,12 +23,21 @@ Status GraphTransformer::Apply(Graph& graph, bool& modified) const {
return status;
}
Status RuleBasedGraphTransformer::Register(const std::string& op_type, std::unique_ptr<RewriteRule> rule) {
if (HasRules(op_type)) {
op_to_rules_[op_type] = std::vector<std::unique_ptr<RewriteRule>>();
}
/** Registers a rewrite rule in this transformer, taking ownership of it.
    @param rule The rule to register; must not be null.
    @returns Status::OK() on success, or an INVALID_ARGUMENT status if rule is null. */
Status RuleBasedGraphTransformer::Register(std::unique_ptr<RewriteRule> rule) {
  // Registered rules are dereferenced unconditionally when they are applied, so reject null here.
  if (rule == nullptr) {
    return Status(ONNXRUNTIME, INVALID_ARGUMENT);
  }

  rules_.push_back(std::move(rule));
  return Status::OK();
}
op_to_rules_[op_type].push_back(std::move(rule));
/** Applies the given rewrite rules, in order, to one node of the graph.
    Rule application stops at the first rule that reports an error or deletes the node.
    @param[in]  graph    The Graph that owns the node.
    @param[in]  node     The Node the rules are applied to.
    @param[in]  rules    The rewrite rules to try.
    @param[out] modified Passed to each rule; forced to true if the node gets deleted.
    @param[out] deleted  Passed to each rule; reports whether the node was deleted.
    @returns the first error returned by a rule, or Status::OK(). */
Status RuleBasedGraphTransformer::ApplyRulesOnNode(Graph& graph, Node& node,
                                                   const std::vector<std::unique_ptr<RewriteRule>>& rules,
                                                   bool& modified, bool& deleted) const {
  for (const std::unique_ptr<RewriteRule>& rewrite_rule : rules) {
    ORT_RETURN_IF_ERROR(rewrite_rule->CheckConditionAndApply(graph, node, modified, deleted));
    if (!deleted) {
      continue;
    }

    // The node is gone, so no further rule can be applied to it.
    modified = true;  // should be set by rewriter but in case it wasn't...
    return Status::OK();
  }

  return Status::OK();
}
@ -42,19 +51,10 @@ Status TopDownRuleBasedTransformer::ApplyImpl(Graph& graph, bool& modified, int
return Status(ONNXRUNTIME, INVALID_ARGUMENT);
}
// Get the rules that should be fired for this node.
const std::vector<std::unique_ptr<RewriteRule>>* rules = GetRewriteRules(node->OpType());
// Apply rewrite rules on current node, then recursively apply rules to subgraphs (if any).
// Stop further rule application for the current node, if the node gets removed by a rule.
bool deleted = false;
if (rules) {
for (const auto& rule : *rules) {
ORT_RETURN_IF_ERROR(rule->CheckConditionAndApply(graph, *node, modified, deleted));
if (deleted) {
modified = true; // should be set by rewriter but in case it wasn't...
break;
}
}
}
ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, GetRewriteRules(), modified, deleted));
if (!deleted) {
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level));

View file

@ -4,6 +4,7 @@
#pragma once
#include "core/optimizer/graph_transformer.h"
#include "core/optimizer/constant_folding.h"
namespace onnxruntime {
// Manages a list of graph transformers. It is initialized with a list of graph
@ -11,7 +12,7 @@ namespace onnxruntime {
class GraphTransformerManager {
public:
explicit GraphTransformerManager(unsigned steps) noexcept : steps_(steps) {
// TODO: Register default transformers.
// Register default transformers.
}
// Register a graph transformer.

View file

@ -18,8 +18,12 @@ Status EliminateIdentity::Apply(Graph& graph, Node& node, bool& modified, bool&
return Status::OK();
}
bool EliminateIdentity::SatisfyCondition(const Node& node) {
return utils::IsSingleInSingleOutNode(node);
/** The rule fires for nodes of the expected op type that have a single input and a single output. */
bool EliminateIdentity::SatisfyCondition(const Graph& /*graph*/, const Node& node) {
  if (!OpTypeCondition(node)) {
    return false;
  }
  return utils::IsSingleInSingleOutNode(node);
}
/** Returns true if the node's op type equals included_op_type_. */
bool EliminateIdentity::OpTypeCondition(const Node& node) {
  const auto& op_type = node.OpType();
  return op_type == included_op_type_;
}
} // namespace onnxruntime

View file

@ -13,7 +13,12 @@ class EliminateIdentity : public RewriteRule {
EliminateIdentity() noexcept : RewriteRule("EliminateIdentity", "Eliminate identity node") {}
private:
bool SatisfyCondition(const Node& node) override;
/** Apply rule when op type is one of the following. */
const std::string included_op_type_ = "Identity";
bool SatisfyCondition(const Graph& graph, const Node& node) override;
bool OpTypeCondition(const Node& node) override;
Status Apply(Graph& graph, Node& node, bool& modified, bool& deleted) override;
};

View file

@ -16,7 +16,11 @@ Status EliminateSlice::Apply(Graph& graph, Node& node, bool& modified, bool& rem
return Status::OK();
}
bool EliminateSlice::SatisfyCondition(const Node& node) {
bool EliminateSlice::SatisfyCondition(const Graph& /*graph*/, const Node& node) {
if (!OpTypeCondition(node)) {
return false;
}
// At the moment, we eliminate a slice operator only if it has a single input and a single output.
if (!utils::IsSingleInSingleOutNode(node)) {
return false;
@ -50,4 +54,8 @@ bool EliminateSlice::SatisfyCondition(const Node& node) {
return true;
}
/** Returns true if the node's op type equals included_op_type_. */
bool EliminateSlice::OpTypeCondition(const Node& node) {
  const auto& op_type = node.OpType();
  return op_type == included_op_type_;
}
} // namespace onnxruntime

View file

@ -13,7 +13,12 @@ class EliminateSlice : public RewriteRule {
EliminateSlice() noexcept : RewriteRule("EliminateSlice", "Eliminate slice node") {}
private:
bool SatisfyCondition(const Node& node) override;
/** Apply rule when op type is one of the following. */
const std::string included_op_type_ = "Slice";
bool SatisfyCondition(const Graph& graph, const Node& node) override;
bool OpTypeCondition(const Node& node) override;
Status Apply(Graph& graph, Node& node, bool& modified, bool& removed) override;
};

View file

@ -32,7 +32,7 @@ namespace test {
static const std::string MODEL_FOLDER = "testdata/transform/";
// Return a map with the number of occurrences of each operator in the graph.
// Returns a map with the number of occurrences of each operator in the graph.
// Helper function to check that the graph transformations have been successfully applied.
std::map<std::string, int> CountOpsInGraph(const Graph& graph) {
std::map<std::string, int> op_to_count;
@ -53,7 +53,7 @@ TEST(GraphTransformationTests, IdentityElimination) {
std::unique_ptr<TopDownRuleBasedTransformer> rule_transformer =
std::make_unique<TopDownRuleBasedTransformer>("RuleTransformer1", "First rule transformer");
rule_transformer->Register("Identity", std::make_unique<EliminateIdentity>());
rule_transformer->Register(std::make_unique<EliminateIdentity>());
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(std::move(rule_transformer));
ASSERT_TRUE(graph_transformation_mgr.ApplyAll(graph).IsOK());
@ -72,7 +72,7 @@ TEST(GraphTransformationTests, SliceElimination) {
std::unique_ptr<TopDownRuleBasedTransformer> rule_transformer =
std::make_unique<TopDownRuleBasedTransformer>("RuleTransformer1", "First rule transformer");
rule_transformer->Register("Slice", std::make_unique<EliminateSlice>());
rule_transformer->Register(std::make_unique<EliminateSlice>());
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(std::move(rule_transformer));
ASSERT_TRUE(graph_transformation_mgr.ApplyAll(graph).IsOK());
@ -81,6 +81,27 @@ TEST(GraphTransformationTests, SliceElimination) {
ASSERT_TRUE(op_to_count["Slice"] == 3);
}
// Checks that the ConstantFolding rule removes both Unsqueeze nodes of the test model.
TEST(GraphTransformationTests, ConstantFolding) {
  string model_path = MODEL_FOLDER + "fusion/fuse-conv-bn-mul-add-unsqueeze.onnx";

  std::shared_ptr<Model> loaded_model;
  ASSERT_TRUE(Model::Load(model_path, loaded_model).IsOK());
  Graph& graph = loaded_model->MainGraph();

  // Before the transformation the graph contains two Unsqueeze nodes.
  std::map<std::string, int> ops_before = CountOpsInGraph(graph);
  ASSERT_TRUE(ops_before["Unsqueeze"] == 2);

  // Run a rule-based transformer that has constant folding as its only rule.
  auto folding_transformer =
      std::make_unique<TopDownRuleBasedTransformer>("RuleTransformer1", "First rule transformer");
  folding_transformer->Register(std::make_unique<ConstantFolding>());

  onnxruntime::GraphTransformerManager transformer_mgr{5};
  transformer_mgr.Register(std::move(folding_transformer));
  ASSERT_TRUE(transformer_mgr.ApplyAll(graph).IsOK());

  // After the transformation no Unsqueeze node is left.
  std::map<std::string, int> ops_after = CountOpsInGraph(graph);
  ASSERT_TRUE(ops_after["Unsqueeze"] == 0);
}
TEST(GraphTransformationTests, FuseConvBNMulAddUnsqueeze) {
string model_uri = MODEL_FOLDER + "fusion/fuse-conv-bn-mul-add-unsqueeze.onnx";