Conv(Add|Mul|BN)Fusion as rewrite rules (#863)

* Converted ConvAddFusion, ConvMulFusion, and ConvBNFusion to rewrite rules
* Extended graph_utils::RemoveNode
* Introduced RewriteRuleEffect enum
This commit is contained in:
Konstantinos Karanasos 2019-05-01 13:23:29 -07:00 committed by GitHub
parent 0ad940027c
commit feab3088fb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 489 additions and 490 deletions

View file

@ -35,6 +35,18 @@ If the list of op types is left empty, that rule will be triggered for every op
*/
class RewriteRule {
public:
/**
@enum RewriteRuleEffect
Enum used to indicate the effect that applying a rewrite rule had on the graph, expressed relative to the
node on which the rule was triggered. It is reported through the rule_effect out-parameter of
CheckConditionAndApply() and Apply().
*/
enum class RewriteRuleEffect : uint8_t {
kNone, // The rewrite rule has not modified the graph.
kUpdatedCurrentNode, // The rewrite rule updated (but did not remove) the node on which it was triggered.
kRemovedCurrentNode, // The rewrite rule removed the node on which it was triggered.
kModifiedRestOfGraph // The rewrite rule modified nodes other than the one it was triggered on.
};
RewriteRule(const std::string& name) : name_(name) {}
virtual ~RewriteRule() = default;
@ -52,11 +64,10 @@ class RewriteRule {
/** Checks if the condition of the rule is satisfied, and if so applies the body of the rule.
@param[in] graph The Graph.
@param[in] node The Node to apply the rewrite to.
@param[out] modified Set to indicate whether the node was modified or not.
@param[out] deleted Set to indicate if the node was deleted.
@param[out] rule_effect Enum to indicate if and how the graph was modified as a result of the rule application.
@returns Status indicating success or providing error information */
common::Status CheckConditionAndApply(Graph& graph, Node& node, bool& modified, bool& deleted) {
return SatisfyCondition(graph, node) ? Apply(graph, node, modified, deleted) : Status::OK();
common::Status CheckConditionAndApply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) {
return SatisfyCondition(graph, node) ? Apply(graph, node, rule_effect) : Status::OK();
}
private:
@ -72,8 +83,7 @@ class RewriteRule {
/** This is the actual body of the rule that performs the graph transformation. The transformation happens in-place.
The return-value of node may be different from the input-value due to rewriting.
The value of "modified" indicates if the graph was modified or not.
The value of "deleted" indicates if the node was deleted or not. */
virtual common::Status Apply(Graph& graph, Node& node, bool& modified, bool& deleted) = 0;
The value of "rule_effect" indicates whether and how the graph was modified by the rule. */
virtual common::Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) = 0;
};
} // namespace onnxruntime

View file

@ -58,14 +58,16 @@ class RuleBasedGraphTransformer : public GraphTransformer {
@param[in] graph The Graph.
@param[in] node The Node to apply the rules to.
@param[in] rules The vector of RewriteRules that will be applied to the Node.
@param[out] modified Set to indicate whether the node was modified or not.
@param[out] deleted Set to indicate if the node was deleted.
@param[out] rule_effect Enum that indicates whether and how the graph was modified as a result of
applying rules on this node.
@returns Status indicating success or providing error information. */
common::Status ApplyRulesOnNode(Graph& graph, Node& node,
const std::vector<std::unique_ptr<RewriteRule>>& rules,
bool& modified, bool& deleted) const;
RewriteRule::RewriteRuleEffect& rule_effect) const;
private:
using RuleEffect = RewriteRule::RewriteRuleEffect;
// Map that associates a node's op type with the vector of rules that are registered to be triggered for that node.
std::unordered_map<std::string, std::vector<std::unique_ptr<RewriteRule>>> op_type_to_rules_;
// Rules that will be evaluated regardless of the op type of the node.

View file

@ -353,19 +353,24 @@ const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const s
return iter == attrs.end() ? nullptr : &iter->second;
}
bool RemoveSingleInputNode(Graph& graph, Node& node) {
// Cannot remove a node with multiple output NodeArgs (multiple output edges is fine), neither
// a node whose output is also a graph output.
if (!IsSingleInSingleOutNode(node) ||
bool RemoveNode(Graph& graph, Node& node) {
// Cannot remove a node with implicit inputs, with multiple output NodeArgs (multiple output edges is fine),
// or whose output is also a graph output.
if (node.ImplicitInputDefs().size() > 0 ||
node.OutputDefs().size() != 1 ||
graph.IsNodeOutputsInGraphOutputs(node)) {
return false;
}
// If the single input comes from another node (initializers are not connected with edges to nodes).
if (node.GetInputEdgesCount() == 1) {
// If there is a single input edge from another node (initializers are not connected with edges to nodes).
return RemoveNodeWithSingleNodeIn(graph, node);
} else {
} else if (node.InputDefs().size() == 1) {
// If a single initializer is the only input.
return RemoveNodeWithSingleInitializerIn(graph, node);
} else {
// No other node removal is supported, because there will be no way to connect its inputs to its outputs.
return false;
}
}

View file

@ -64,10 +64,14 @@ bool GetRepeatedNodeAttributeValues(const Node& node,
Status ForAllMutableSubgraphs(Graph& main_graph, std::function<Status(Graph&)> func);
Status ForAllSubgraphs(const Graph& main_graph, std::function<Status(const Graph&)> func);
/** Removes the given single-input Node from the Graph. The single input might be either
another node or an initializer, but not an implicit input. The node should have a single
output but can have multiple output edges. */
bool RemoveSingleInputNode(Graph& graph, Node& node);
/** Removes the given Node from the Graph and keeps Graph consistent by rebuilding needed connections.
We support the removal of the Node if it has no implicit inputs and a single output (but it can have multiple
output edges). As for the Node's inputs, we support the following cases:
- If the Node has a single incoming node (and possibly multiple initializers), we can remove the Node and
connect its incoming node to its outgoing nodes.
- If the Node has a single initializer as input, we remove the Node and feed the initializer as input to its
output nodes. */
bool RemoveNode(Graph& graph, Node& node);
/** Removes all output edges from the given Node of the Graph.
This should probably be elevated to the Graph API eventually. */

View file

@ -9,136 +9,108 @@ using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
namespace onnxruntime {
Status ConvAddFusion::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level) const {
std::vector<onnxruntime::NodeIndex> removed_nodes;
for (auto& node : graph.Nodes()) {
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level));
Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modified) {
auto& conv_node = node;
const auto& add_node = *conv_node.OutputNodesBegin();
const auto& conv_inputs = conv_node.InputDefs();
const auto& add_inputs = add_node.InputDefs();
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", 1) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
node.GetOutputEdgesCount() != 1) {
continue;
}
const ONNX_NAMESPACE::TensorProto* conv_W_tensor_proto = nullptr;
graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto);
const Node& next_node = *node.OutputNodesBegin();
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", 7) ||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType() ||
next_node.GetInputEdgesCount() != 1 ||
graph.IsNodeOutputsInGraphOutputs(next_node)) {
continue;
}
const ONNX_NAMESPACE::TensorProto* add_B_tensor_proto = nullptr;
graph.GetInitializedTensor(add_inputs[1]->Name(), add_B_tensor_proto);
auto& conv_node = node;
const Node& add_node = next_node;
const auto& conv_inputs = conv_node.InputDefs();
const auto& add_inputs = add_node.InputDefs();
const ONNX_NAMESPACE::TensorProto* conv_W_tensor_proto = nullptr;
graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto);
const ONNX_NAMESPACE::TensorProto* add_B_tensor_proto = nullptr;
graph.GetInitializedTensor(add_inputs[1]->Name(), add_B_tensor_proto);
// Currently, fusion is only supported for float or double data type.
if (!Initializer::IsSupportedDataType(add_B_tensor_proto) ||
conv_W_tensor_proto->dims_size() < 4 ||
add_B_tensor_proto->dims_size() != conv_W_tensor_proto->dims_size() - 1 ||
conv_W_tensor_proto->dims(0) != add_B_tensor_proto->dims(0)) {
continue;
}
// The dimensions of add_B should be equal to 1 except first dimension.
bool flag = false;
for (int i = 1; i < add_B_tensor_proto->dims_size(); i++) {
if (add_B_tensor_proto->dims(i) != 1) {
flag = true;
break;
}
}
if (flag) {
continue;
}
const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr;
if (conv_inputs.size() == 3) {
graph.GetInitializedTensor(conv_inputs[2]->Name(), conv_B_tensor_proto);
if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) ||
conv_B_tensor_proto->data_type() != add_B_tensor_proto->data_type() ||
conv_B_tensor_proto->dims_size() != 1 ||
conv_B_tensor_proto->dims(0) != add_B_tensor_proto->dims(0)) {
continue;
}
auto conv_B = std::make_unique<Initializer>(conv_B_tensor_proto);
auto add_B = std::make_unique<Initializer>(add_B_tensor_proto);
if (conv_B->size() != add_B->size()) {
continue;
}
// Calculate new value of initializers of conv node
conv_B->add(*add_B);
// Create new initializers of conv
ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto;
conv_B->ToProto(&new_conv_B_tensor_proto);
// Replace initializers of conv node
graph.RemoveInitializedTensor(conv_inputs[2]->Name());
graph.AddInitializedTensor(new_conv_B_tensor_proto);
} else {
NodeArg* add_B_node_arg = graph.GetNodeArg(add_B_tensor_proto->name());
if (add_B_node_arg == nullptr) {
continue;
}
// Update shape of tensor proto
ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto(*add_B_tensor_proto);
int64_t dim = conv_W_tensor_proto->dims(0);
new_conv_B_tensor_proto.clear_dims();
new_conv_B_tensor_proto.add_dims(dim);
graph.RemoveInitializedTensor(add_B_tensor_proto->name());
graph.AddInitializedTensor(new_conv_B_tensor_proto);
// Update shape of NodeArg
TensorShapeProto shape;
shape.add_dim()->set_dim_value(dim);
add_B_node_arg->SetShape(shape);
conv_node.MutableInputDefs().push_back(add_B_node_arg);
conv_node.MutableInputArgsCount()[2] = 1;
}
// Replace the input of the node following add node
const NodeArg* add_output_def = add_node.OutputDefs()[0];
NodeArg* conv_output_def = conv_node.MutableOutputDefs()[0];
for (auto it = add_node.OutputNodesBegin(); it != add_node.OutputNodesEnd(); ++it) {
auto output_node = graph.GetNode((*it).Index());
if (!output_node) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT);
}
auto& input_defs = output_node->MutableInputDefs();
for (auto& def : input_defs) {
if (def == add_output_def) {
def = conv_output_def;
}
}
}
removed_nodes.push_back(add_node.Index());
// Currently, fusion is only supported for float or double data type.
if (!Initializer::IsSupportedDataType(add_B_tensor_proto) ||
conv_W_tensor_proto->dims_size() < 4 ||
add_B_tensor_proto->dims_size() != conv_W_tensor_proto->dims_size() - 1 ||
conv_W_tensor_proto->dims(0) != add_B_tensor_proto->dims(0)) {
return Status::OK();
}
for (auto i : removed_nodes) {
graph.RemoveNode(i);
// The dimensions of add_B should be equal to 1 except first dimension.
for (int i = 1; i < add_B_tensor_proto->dims_size(); i++) {
if (add_B_tensor_proto->dims(i) != 1) {
return Status::OK();
}
}
if (!removed_nodes.empty()) {
modified = true;
const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr;
if (conv_inputs.size() == 3) {
graph.GetInitializedTensor(conv_inputs[2]->Name(), conv_B_tensor_proto);
if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) ||
conv_B_tensor_proto->data_type() != add_B_tensor_proto->data_type() ||
conv_B_tensor_proto->dims_size() != 1 ||
conv_B_tensor_proto->dims(0) != add_B_tensor_proto->dims(0)) {
return Status::OK();
}
auto conv_B = std::make_unique<Initializer>(conv_B_tensor_proto);
auto add_B = std::make_unique<Initializer>(add_B_tensor_proto);
if (conv_B->size() != add_B->size()) {
return Status::OK();
}
// Calculate new value of initializers of conv node
conv_B->add(*add_B);
// Create new initializers of conv
ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto;
conv_B->ToProto(&new_conv_B_tensor_proto);
// Replace initializers of conv node
graph.RemoveInitializedTensor(conv_inputs[2]->Name());
graph.AddInitializedTensor(new_conv_B_tensor_proto);
} else {
NodeArg* add_B_node_arg = graph.GetNodeArg(add_B_tensor_proto->name());
if (add_B_node_arg == nullptr) {
return Status::OK();
}
// Update shape of tensor proto
ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto(*add_B_tensor_proto);
int64_t dim = conv_W_tensor_proto->dims(0);
new_conv_B_tensor_proto.clear_dims();
new_conv_B_tensor_proto.add_dims(dim);
graph.RemoveInitializedTensor(add_B_tensor_proto->name());
graph.AddInitializedTensor(new_conv_B_tensor_proto);
// Update shape of NodeArg
TensorShapeProto shape;
shape.add_dim()->set_dim_value(dim);
add_B_node_arg->SetShape(shape);
conv_node.MutableInputDefs().push_back(add_B_node_arg);
conv_node.MutableInputArgsCount()[2] = 1;
}
// Remove Add node.
auto* add_node_to_remove = graph.GetNode(add_node.Index());
if (graph_utils::RemoveNode(graph, *add_node_to_remove)) {
modified = RewriteRuleEffect::kModifiedRestOfGraph;
}
return Status::OK();
}
/** Checks whether the Conv+Add fusion may be attempted on the given node.
    @param[in] graph The Graph that contains the node.
    @param[in] node The candidate node on which the rule was triggered.
    @returns true if node is a supported Conv with exactly one output edge and the node it feeds
             is a supported Add that passes the checks below; false otherwise. */
bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node) {
  // The triggering node has to be a supported Conv...
  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", 1)) {
    return false;
  }

  // ...with exactly one output edge, so that dereferencing OutputNodesBegin() below is valid.
  if (node.GetOutputEdgesCount() != 1) {
    return false;
  }

  const auto& consumer = *node.OutputNodesBegin();

  // The consumer must be a supported Add assigned to the same execution provider as the Conv,
  // with a single input edge, and its output must not be a graph output.
  return graph_utils::IsSupportedOptypeVersionAndDomain(consumer, "Add", 7) &&
         consumer.GetExecutionProviderType() == node.GetExecutionProviderType() &&
         consumer.GetInputEdgesCount() == 1 &&
         !graph.IsNodeOutputsInGraphOutputs(consumer);
}
} // namespace onnxruntime

View file

@ -3,16 +3,29 @@
#pragma once
#include "core/optimizer/graph_transformer.h"
#include "core/optimizer/rewrite_rule.h"
namespace onnxruntime {
class ConvAddFusion : public onnxruntime::GraphTransformer {
/**
@class ConvAddFusion
Rewrite rule that fuses a Conv node and the Add node that follows it into a single Conv node.
The rule is only triggered on nodes with op type "Conv".
*/
class ConvAddFusion : public RewriteRule {
public:
ConvAddFusion() noexcept : onnxruntime::GraphTransformer("ConvAddFusion") {}
ConvAddFusion() noexcept : RewriteRule("ConvAddFusion") {}
std::vector<std::string> TargetOpTypes() const noexcept override {
return {"Conv"};
}
private:
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
bool SatisfyCondition(const Graph& graph, const Node& node) override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override;
};
} // namespace onnxruntime

View file

@ -9,184 +9,154 @@ using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
namespace onnxruntime {
Status ConvBNFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
std::vector<onnxruntime::NodeIndex> removed_nodes;
for (auto& node : graph.Nodes()) {
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level));
Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) {
auto& conv_node = node;
const Node& bn_node = *conv_node.OutputNodesBegin();
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", 1) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
node.GetOutputEdgesCount() != 1) {
continue;
// Get value of attribute epsilon
const onnxruntime::NodeAttributes& attributes = bn_node.GetAttributes();
const ONNX_NAMESPACE::AttributeProto* attr = &(attributes.find("epsilon")->second);
if (attr == nullptr || attr->type() != AttributeProto_AttributeType_FLOAT) {
return Status::OK();
}
float epsilon = static_cast<float>(attr->f());
// Get initializers of BatchNormalization
const auto& bn_inputs = bn_node.InputDefs();
const ONNX_NAMESPACE::TensorProto* bn_scale_tensor_proto = nullptr;
graph.GetInitializedTensor(bn_inputs[1]->Name(), bn_scale_tensor_proto);
const ONNX_NAMESPACE::TensorProto* bn_B_tensor_proto = nullptr;
graph.GetInitializedTensor(bn_inputs[2]->Name(), bn_B_tensor_proto);
const ONNX_NAMESPACE::TensorProto* bn_mean_tensor_proto = nullptr;
graph.GetInitializedTensor(bn_inputs[3]->Name(), bn_mean_tensor_proto);
const ONNX_NAMESPACE::TensorProto* bn_var_tensor_proto = nullptr;
graph.GetInitializedTensor(bn_inputs[4]->Name(), bn_var_tensor_proto);
const auto& conv_inputs = conv_node.InputDefs();
const ONNX_NAMESPACE::TensorProto* conv_W_tensor_proto = nullptr;
graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto);
// Currently, fusion is only supported for float or double data type.
if (!Initializer::IsSupportedDataType(bn_scale_tensor_proto) ||
!Initializer::IsSupportedDataType(bn_B_tensor_proto) ||
!Initializer::IsSupportedDataType(bn_mean_tensor_proto) ||
!Initializer::IsSupportedDataType(bn_var_tensor_proto) ||
!Initializer::IsSupportedDataType(conv_W_tensor_proto) ||
bn_scale_tensor_proto->dims_size() != 1 ||
bn_B_tensor_proto->dims_size() != 1 ||
bn_mean_tensor_proto->dims_size() != 1 ||
bn_var_tensor_proto->dims_size() != 1 ||
bn_scale_tensor_proto->dims(0) != bn_B_tensor_proto->dims(0) ||
bn_B_tensor_proto->dims(0) != bn_mean_tensor_proto->dims(0) ||
bn_mean_tensor_proto->dims(0) != bn_var_tensor_proto->dims(0) ||
bn_scale_tensor_proto->data_type() != bn_B_tensor_proto->data_type() ||
bn_B_tensor_proto->data_type() != bn_mean_tensor_proto->data_type() ||
bn_mean_tensor_proto->data_type() != bn_var_tensor_proto->data_type() ||
conv_W_tensor_proto->data_type() != bn_scale_tensor_proto->data_type() ||
!(conv_W_tensor_proto->dims_size() > 2 && conv_W_tensor_proto->dims(0) == bn_scale_tensor_proto->dims(0))) {
return Status::OK();
}
auto bn_scale = std::make_unique<Initializer>(bn_scale_tensor_proto);
auto bn_B = std::make_unique<Initializer>(bn_B_tensor_proto);
auto bn_mean = std::make_unique<Initializer>(bn_mean_tensor_proto);
auto bn_var = std::make_unique<Initializer>(bn_var_tensor_proto);
auto conv_W = std::make_unique<Initializer>(conv_W_tensor_proto);
const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr;
std::unique_ptr<Initializer> conv_B = nullptr;
if (conv_inputs.size() == 3) {
if (!graph.GetInitializedTensor(conv_inputs[2]->Name(), conv_B_tensor_proto)) {
return Status::OK();
}
const Node& next_node = *node.OutputNodesBegin();
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "BatchNormalization", 7) ||
next_node.GetInputEdgesCount() != 1 ||
graph.IsNodeOutputsInGraphOutputs(next_node) ||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
continue;
if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) ||
conv_B_tensor_proto->dims_size() != 1 ||
conv_B_tensor_proto->dims(0) != bn_B_tensor_proto->dims(0) ||
conv_B_tensor_proto->data_type() != bn_B_tensor_proto->data_type()) {
return Status::OK();
}
conv_B = std::make_unique<Initializer>(conv_B_tensor_proto);
}
auto& conv_node = node;
const Node& bn_node = next_node;
// Calculate new value of initializers of conv node
bn_var->add(epsilon);
bn_var->sqrt();
bn_scale->div(*bn_var);
conv_W->scale_by_axis(*bn_scale, 1);
// Get value of attribute group
const onnxruntime::NodeAttributes& conv_attributes = conv_node.GetAttributes();
const ONNX_NAMESPACE::AttributeProto* group_attr = &(conv_attributes.find("group")->second);
if (group_attr != nullptr &&
group_attr->type() == AttributeProto_AttributeType_INT &&
group_attr->has_i() && group_attr->i() != 1) {
continue;
if (conv_inputs.size() == 3) {
conv_B->sub(*bn_mean);
conv_B->mul(*bn_scale);
conv_B->add(*bn_B);
} else {
bn_mean->mul(*bn_scale);
bn_B->sub(*bn_mean);
}
// Create new initializers of conv
ONNX_NAMESPACE::TensorProto new_conv_W_tensor_proto(*conv_W_tensor_proto);
conv_W->ToProto(&new_conv_W_tensor_proto);
ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto;
NodeArg* bn_B_node_arg = nullptr;
if (conv_inputs.size() == 3) {
conv_B->ToProto(&new_conv_B_tensor_proto);
} else {
bn_B->ToProto(&new_conv_B_tensor_proto);
bn_B_node_arg = graph.GetNodeArg(bn_B_tensor_proto->name());
if (bn_B_node_arg == nullptr) {
return Status::OK();
}
}
// Get value of attribute epsilon
const onnxruntime::NodeAttributes& attributes = bn_node.GetAttributes();
const ONNX_NAMESPACE::AttributeProto* attr = &(attributes.find("epsilon")->second);
if (attr == nullptr || attr->type() != AttributeProto_AttributeType_FLOAT) {
continue;
}
float epsilon = static_cast<float>(attr->f());
// Get initializers of BatchNormalization
const auto& bn_inputs = bn_node.InputDefs();
const ONNX_NAMESPACE::TensorProto* bn_scale_tensor_proto = nullptr;
graph.GetInitializedTensor(bn_inputs[1]->Name(), bn_scale_tensor_proto);
const ONNX_NAMESPACE::TensorProto* bn_B_tensor_proto = nullptr;
graph.GetInitializedTensor(bn_inputs[2]->Name(), bn_B_tensor_proto);
const ONNX_NAMESPACE::TensorProto* bn_mean_tensor_proto = nullptr;
graph.GetInitializedTensor(bn_inputs[3]->Name(), bn_mean_tensor_proto);
const ONNX_NAMESPACE::TensorProto* bn_var_tensor_proto = nullptr;
graph.GetInitializedTensor(bn_inputs[4]->Name(), bn_var_tensor_proto);
const auto& conv_inputs = conv_node.InputDefs();
const ONNX_NAMESPACE::TensorProto* conv_W_tensor_proto = nullptr;
graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto);
// Currently, fusion is only supported for float or double data type.
if (!Initializer::IsSupportedDataType(bn_scale_tensor_proto) ||
!Initializer::IsSupportedDataType(bn_B_tensor_proto) ||
!Initializer::IsSupportedDataType(bn_mean_tensor_proto) ||
!Initializer::IsSupportedDataType(bn_var_tensor_proto) ||
!Initializer::IsSupportedDataType(conv_W_tensor_proto) ||
bn_scale_tensor_proto->dims_size() != 1 ||
bn_B_tensor_proto->dims_size() != 1 ||
bn_mean_tensor_proto->dims_size() != 1 ||
bn_var_tensor_proto->dims_size() != 1 ||
bn_scale_tensor_proto->dims(0) != bn_B_tensor_proto->dims(0) ||
bn_B_tensor_proto->dims(0) != bn_mean_tensor_proto->dims(0) ||
bn_mean_tensor_proto->dims(0) != bn_var_tensor_proto->dims(0) ||
bn_scale_tensor_proto->data_type() != bn_B_tensor_proto->data_type() ||
bn_B_tensor_proto->data_type() != bn_mean_tensor_proto->data_type() ||
bn_mean_tensor_proto->data_type() != bn_var_tensor_proto->data_type() ||
conv_W_tensor_proto->data_type() != bn_scale_tensor_proto->data_type() ||
!(conv_W_tensor_proto->dims_size() > 2 && conv_W_tensor_proto->dims(0) == bn_scale_tensor_proto->dims(0))) {
continue;
}
auto bn_scale = std::make_unique<Initializer>(bn_scale_tensor_proto);
auto bn_B = std::make_unique<Initializer>(bn_B_tensor_proto);
auto bn_mean = std::make_unique<Initializer>(bn_mean_tensor_proto);
auto bn_var = std::make_unique<Initializer>(bn_var_tensor_proto);
auto conv_W = std::make_unique<Initializer>(conv_W_tensor_proto);
const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr;
std::unique_ptr<Initializer> conv_B = nullptr;
if (conv_inputs.size() == 3) {
if (!graph.GetInitializedTensor(conv_inputs[2]->Name(), conv_B_tensor_proto))
continue;
if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) ||
conv_B_tensor_proto->dims_size() != 1 ||
conv_B_tensor_proto->dims(0) != bn_B_tensor_proto->dims(0) ||
conv_B_tensor_proto->data_type() != bn_B_tensor_proto->data_type()) {
continue;
}
conv_B = std::make_unique<Initializer>(conv_B_tensor_proto);
}
// Calculate new value of initializers of conv node
bn_var->add(epsilon);
bn_var->sqrt();
bn_scale->div(*bn_var);
conv_W->scale_by_axis(*bn_scale, 1);
if (conv_inputs.size() == 3) {
conv_B->sub(*bn_mean);
conv_B->mul(*bn_scale);
conv_B->add(*bn_B);
} else {
bn_mean->mul(*bn_scale);
bn_B->sub(*bn_mean);
}
// Create new initializers of conv
ONNX_NAMESPACE::TensorProto new_conv_W_tensor_proto(*conv_W_tensor_proto);
conv_W->ToProto(&new_conv_W_tensor_proto);
ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto;
NodeArg* bn_B_node_arg = nullptr;
if (conv_inputs.size() == 3) {
conv_B->ToProto(&new_conv_B_tensor_proto);
} else {
bn_B->ToProto(&new_conv_B_tensor_proto);
bn_B_node_arg = graph.GetNodeArg(bn_B_tensor_proto->name());
if (bn_B_node_arg == nullptr) {
continue;
}
}
// Replace initializers of conv node
graph.RemoveInitializedTensor(conv_W_tensor_proto->name());
if (conv_inputs.size() == 3) {
// Replace initializers of conv node
graph.RemoveInitializedTensor(conv_W_tensor_proto->name());
if (conv_inputs.size() == 3) {
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 6011) // Not deferencing null pointer. conv_B_tensor_proto is set on line 93
#endif
graph.RemoveInitializedTensor(conv_B_tensor_proto->name());
graph.RemoveInitializedTensor(conv_B_tensor_proto->name());
#ifdef _MSC_VER
#pragma warning(pop)
#endif
} else {
graph.RemoveInitializedTensor(bn_B_tensor_proto->name());
conv_node.MutableInputDefs().push_back(bn_B_node_arg);
conv_node.MutableInputArgsCount()[2] = 1;
}
graph.AddInitializedTensor(new_conv_W_tensor_proto);
graph.AddInitializedTensor(new_conv_B_tensor_proto);
// Replace the input of the nodes following batch normalization node
const NodeArg* bn_output_def = bn_node.OutputDefs()[0];
NodeArg* conv_output_def = conv_node.MutableOutputDefs()[0];
for (auto it = bn_node.OutputNodesBegin(); it != bn_node.OutputNodesEnd(); ++it) {
auto output_node = graph.GetNode((*it).Index());
if (!output_node) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT);
}
auto& input_defs = output_node->MutableInputDefs();
for (auto& def : input_defs) {
if (def == bn_output_def) {
def = conv_output_def;
}
}
}
removed_nodes.push_back(bn_node.Index());
} else {
graph.RemoveInitializedTensor(bn_B_tensor_proto->name());
conv_node.MutableInputDefs().push_back(bn_B_node_arg);
conv_node.MutableInputArgsCount()[2] = 1;
}
graph.AddInitializedTensor(new_conv_W_tensor_proto);
graph.AddInitializedTensor(new_conv_B_tensor_proto);
for (auto i : removed_nodes) {
graph.RemoveNode(i);
}
if (!removed_nodes.empty()) {
modified = true;
// Remove BN node.
auto* bn_node_to_remove = graph.GetNode(bn_node.Index());
if (graph_utils::RemoveNode(graph, *bn_node_to_remove)) {
rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;
}
return Status::OK();
}
/** Checks whether the Conv+BatchNormalization fusion may be attempted on the given node.
    @param[in] graph The Graph that contains the node.
    @param[in] node The candidate node on which the rule was triggered.
    @returns true if node is a supported Conv with exactly one output edge and the node it feeds
             is a supported BatchNormalization that passes the checks below; false otherwise. */
bool ConvBNFusion::SatisfyCondition(const Graph& graph, const Node& node) {
  // The triggering node has to be a supported Conv...
  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", 1)) {
    return false;
  }

  // ...with exactly one output edge, so that dereferencing OutputNodesBegin() below is valid.
  if (node.GetOutputEdgesCount() != 1) {
    return false;
  }

  const auto& consumer = *node.OutputNodesBegin();

  // The consumer must be a supported BatchNormalization with a single input edge, its output must not
  // be a graph output, and it must be assigned to the same execution provider as the Conv.
  return graph_utils::IsSupportedOptypeVersionAndDomain(consumer, "BatchNormalization", 7) &&
         consumer.GetInputEdgesCount() == 1 &&
         !graph.IsNodeOutputsInGraphOutputs(consumer) &&
         consumer.GetExecutionProviderType() == node.GetExecutionProviderType();
}
} // namespace onnxruntime

View file

@ -3,15 +3,29 @@
#pragma once
#include "core/optimizer/graph_transformer.h"
#include "core/optimizer/rewrite_rule.h"
namespace onnxruntime {
class ConvBNFusion : public onnxruntime::GraphTransformer {
/**
@class ConvBNFusion
Rewrite rule that fuses a Conv node and the BatchNormalization node that follows it into a single Conv node.
The rule is only triggered on nodes with op type "Conv".
*/
class ConvBNFusion : public RewriteRule {
public:
ConvBNFusion() noexcept : onnxruntime::GraphTransformer("ConvBNFusion") {}
ConvBNFusion() noexcept : RewriteRule("ConvBNFusion") {}
std::vector<std::string> TargetOpTypes() const noexcept override {
return {"Conv"};
}
private:
Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level) const override;
bool SatisfyCondition(const Graph& graph, const Node& node) override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override;
};
} // namespace onnxruntime

View file

@ -9,135 +9,107 @@ using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
namespace onnxruntime {
Status ConvMulFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
std::vector<onnxruntime::NodeIndex> removed_nodes;
for (auto& node : graph.Nodes()) {
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", 1) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
node.GetOutputEdgesCount() != 1) {
continue;
Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) {
auto& conv_node = node;
const auto& mul_node = *conv_node.OutputNodesBegin();
const auto& conv_inputs = conv_node.InputDefs();
const auto& mul_inputs = mul_node.InputDefs();
const ONNX_NAMESPACE::TensorProto* conv_W_tensor_proto = nullptr;
graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto);
const ONNX_NAMESPACE::TensorProto* mul_B_tensor_proto = nullptr;
graph.GetInitializedTensor(mul_inputs[1]->Name(), mul_B_tensor_proto);
if (!Initializer::IsSupportedDataType(conv_W_tensor_proto) ||
!Initializer::IsSupportedDataType(mul_B_tensor_proto) ||
conv_W_tensor_proto->data_type() != mul_B_tensor_proto->data_type() ||
conv_W_tensor_proto->dims_size() < 4 ||
!(mul_B_tensor_proto->dims_size() == 0 ||
(mul_B_tensor_proto->dims_size() == conv_W_tensor_proto->dims_size() - 1 &&
conv_W_tensor_proto->dims(0) == mul_B_tensor_proto->dims(0)))) {
return Status::OK();
}
// The dimensions of mul_B should be equal to 1 except first dimension.
if (mul_B_tensor_proto->dims_size() != 0) {
for (int i = 1; i < mul_B_tensor_proto->dims_size(); i++) {
if (mul_B_tensor_proto->dims(i) != 1) {
return Status::OK();
}
}
}
const Node& next_node = *node.OutputNodesBegin();
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", 7) ||
next_node.GetInputEdgesCount() != 1 ||
graph.IsNodeOutputsInGraphOutputs(next_node) ||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
continue;
auto conv_W = std::make_unique<Initializer>(conv_W_tensor_proto);
auto mul_B = std::make_unique<Initializer>(mul_B_tensor_proto);
const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr;
std::unique_ptr<Initializer> conv_B = nullptr;
const bool is_3d = conv_inputs.size() == 3;
if (is_3d) {
if (!graph.GetInitializedTensor(conv_inputs[2]->Name(), conv_B_tensor_proto))
return Status::OK();
if (conv_B_tensor_proto == nullptr)
return Status(ONNXRUNTIME, FAIL, "Internal error in ConvMulFusion. conv_B_tensor_proto is NULL");
if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) ||
conv_B_tensor_proto->data_type() != mul_B_tensor_proto->data_type() ||
conv_B_tensor_proto->dims_size() != 1 ||
(mul_B_tensor_proto->dims_size() != 0 && conv_B_tensor_proto->dims(0) != mul_B_tensor_proto->dims(0))) {
return Status::OK();
}
conv_B = std::make_unique<Initializer>(conv_B_tensor_proto);
}
auto& conv_node = node;
const Node& mul_node = next_node;
// Calculate new value of initializers of conv node
conv_W->scale_by_axis(*mul_B, 1);
const auto& conv_inputs = conv_node.InputDefs();
const auto& mul_inputs = mul_node.InputDefs();
const ONNX_NAMESPACE::TensorProto* conv_W_tensor_proto = nullptr;
graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto);
const ONNX_NAMESPACE::TensorProto* mul_B_tensor_proto = nullptr;
graph.GetInitializedTensor(mul_inputs[1]->Name(), mul_B_tensor_proto);
if (!Initializer::IsSupportedDataType(conv_W_tensor_proto) ||
!Initializer::IsSupportedDataType(mul_B_tensor_proto) ||
conv_W_tensor_proto->data_type() != mul_B_tensor_proto->data_type() ||
conv_W_tensor_proto->dims_size() < 4 ||
!(mul_B_tensor_proto->dims_size() == 0 ||
(mul_B_tensor_proto->dims_size() == conv_W_tensor_proto->dims_size() - 1 &&
conv_W_tensor_proto->dims(0) == mul_B_tensor_proto->dims(0)))) {
continue;
}
// The dimensions of mul_B should be equal to 1 except first dimension.
if (conv_inputs.size() == 3) {
if (mul_B_tensor_proto->dims_size() != 0) {
bool flag = false;
for (int i = 1; i < mul_B_tensor_proto->dims_size(); i++) {
if (mul_B_tensor_proto->dims(i) != 1) {
flag = true;
break;
}
}
if (flag) {
continue;
}
conv_B->mul(*mul_B);
} else {
conv_B->scale_by_axis(*mul_B, 0);
}
auto conv_W = std::make_unique<Initializer>(conv_W_tensor_proto);
auto mul_B = std::make_unique<Initializer>(mul_B_tensor_proto);
const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr;
std::unique_ptr<Initializer> conv_B = nullptr;
const bool is_3d = conv_inputs.size() == 3;
if (is_3d) {
if (!graph.GetInitializedTensor(conv_inputs[2]->Name(), conv_B_tensor_proto))
continue;
if (conv_B_tensor_proto == nullptr)
return Status(ONNXRUNTIME, FAIL, "Internal error in ConvMulFusion. conv_B_tensor_proto is NULL");
if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) ||
conv_B_tensor_proto->data_type() != mul_B_tensor_proto->data_type() ||
conv_B_tensor_proto->dims_size() != 1 || (mul_B_tensor_proto->dims_size() != 0 && conv_B_tensor_proto->dims(0) != mul_B_tensor_proto->dims(0))) {
continue;
}
conv_B = std::make_unique<Initializer>(conv_B_tensor_proto);
}
// Calculate new value of initializers of conv node
conv_W->scale_by_axis(*mul_B, 1);
if (conv_inputs.size() == 3) {
if (mul_B_tensor_proto->dims_size() != 0) {
conv_B->mul(*mul_B);
} else {
conv_B->scale_by_axis(*mul_B, 0);
}
}
// Create new initializers of conv
ONNX_NAMESPACE::TensorProto new_conv_W_tensor_proto(*conv_W_tensor_proto);
conv_W->ToProto(&new_conv_W_tensor_proto);
// Replace initializers of conv node
graph.RemoveInitializedTensor(conv_inputs[1]->Name());
graph.AddInitializedTensor(new_conv_W_tensor_proto);
if (is_3d) {
ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto(*conv_B_tensor_proto);
conv_B->ToProto(&new_conv_B_tensor_proto);
graph.RemoveInitializedTensor(conv_inputs[2]->Name());
graph.AddInitializedTensor(new_conv_B_tensor_proto);
}
// Replace the input of the node following mul node
const NodeArg* mul_output_def = mul_node.OutputDefs()[0];
NodeArg* conv_output_def = conv_node.MutableOutputDefs()[0];
for (auto it = mul_node.OutputNodesBegin(); it != mul_node.OutputNodesEnd(); ++it) {
auto output_node = graph.GetNode((*it).Index());
if (!output_node) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT);
}
auto& input_defs = output_node->MutableInputDefs();
for (auto& def : input_defs) {
if (def == mul_output_def) {
def = conv_output_def;
}
}
}
removed_nodes.push_back(mul_node.Index());
}
for (auto i : removed_nodes) {
graph.RemoveNode(i);
// Create new initializers of conv
ONNX_NAMESPACE::TensorProto new_conv_W_tensor_proto(*conv_W_tensor_proto);
conv_W->ToProto(&new_conv_W_tensor_proto);
// Replace initializers of conv node
graph.RemoveInitializedTensor(conv_inputs[1]->Name());
graph.AddInitializedTensor(new_conv_W_tensor_proto);
if (is_3d) {
ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto(*conv_B_tensor_proto);
conv_B->ToProto(&new_conv_B_tensor_proto);
graph.RemoveInitializedTensor(conv_inputs[2]->Name());
graph.AddInitializedTensor(new_conv_B_tensor_proto);
}
if (!removed_nodes.empty()) {
modified = true;
// Remove Mul node.
auto* mul_node_to_remove = graph.GetNode(mul_node.Index());
if (graph_utils::RemoveNode(graph, *mul_node_to_remove)) {
rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;
}
return Status::OK();
}
/**
Checks whether the Conv+Mul fusion may be attempted on the given node.
@param[in] graph The Graph containing the node.
@param[in] node The candidate node on which the rule was triggered.
@returns true only if `node` is a supported Conv with exactly one output edge, and the node it feeds
         is a supported Mul with exactly one input edge, whose outputs are not graph outputs and which
         is assigned to the same execution provider as `node`. */
bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node) {
  // The triggering node has to be a Conv with a single outgoing edge.
  const bool is_supported_conv = graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", 1);
  if (!is_supported_conv || node.GetOutputEdgesCount() != 1) {
    return false;
  }

  // Exactly one output edge exists at this point, so the first output node is the only consumer.
  const auto& consumer = *node.OutputNodesBegin();

  if (!graph_utils::IsSupportedOptypeVersionAndDomain(consumer, "Mul", 7)) {
    return false;
  }
  // The consumer may have no incoming edge other than the one coming from the Conv.
  if (consumer.GetInputEdgesCount() != 1) {
    return false;
  }
  // Reject consumers whose outputs are also outputs of the graph.
  if (graph.IsNodeOutputsInGraphOutputs(consumer)) {
    return false;
  }
  // Both nodes have to be assigned to the same execution provider.
  return consumer.GetExecutionProviderType() == node.GetExecutionProviderType();
}
} // namespace onnxruntime

View file

@ -2,16 +2,29 @@
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/graph_transformer.h"
#include "core/optimizer/rewrite_rule.h"
namespace onnxruntime {
class ConvMulFusion : public onnxruntime::GraphTransformer {
/**
@class ConvMulFusion
Rewrite rule that fuses a Conv node followed by a Mul node into a single Conv node.
The rule is only attempted on nodes with op type "Conv".
*/
class ConvMulFusion : public RewriteRule {
public:
ConvMulFusion() noexcept : onnxruntime::GraphTransformer("ConvMulFusion") {}
ConvMulFusion() noexcept : RewriteRule("ConvMulFusion") {}
std::vector<std::string> TargetOpTypes() const noexcept override {
return {"Conv"};
}
private:
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
bool SatisfyCondition(const Graph& graph, const Node& node) override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override;
};
} // namespace onnxruntime

View file

@ -17,7 +17,7 @@ class GraphTransformerManager {
explicit GraphTransformerManager(unsigned steps) : steps_(steps) {
}
// Register a transformer with a level and compatible providers list
// Register a transformer with a level.
common::Status Register(std::unique_ptr<GraphTransformer> transformer, TransformerLevel level);
// Apply all transformers registered for the given level on the given graph

View file

@ -30,6 +30,9 @@ std::vector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(TransformerLevel
break;
case TransformerLevel::Level2:
rules.push_back(std::make_unique<ConvAddFusion>());
rules.push_back(std::make_unique<ConvMulFusion>());
rules.push_back(std::make_unique<ConvBNFusion>());
break;
default:
ORT_ENFORCE(false, "Unsupported level" + std::to_string(static_cast<uint32_t>(level)));
@ -93,9 +96,6 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
transformers.emplace_back(std::make_unique<MatMulAddFusion>(l2_execution_providers));
transformers.emplace_back(std::make_unique<ConvActivationFusion>(l2_execution_providers));
#endif
transformers.emplace_back(std::make_unique<ConvAddFusion>());
transformers.emplace_back(std::make_unique<ConvMulFusion>());
transformers.emplace_back(std::make_unique<ConvBNFusion>());
} break;
default:

View file

@ -10,9 +10,9 @@
namespace onnxruntime {
Status EliminateIdentity::Apply(Graph& graph, Node& node, bool& modified, bool& deleted) {
if (graph_utils::RemoveSingleInputNode(graph, node)) {
modified = deleted = true;
Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) {
if (graph_utils::RemoveNode(graph, node)) {
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
}
return Status::OK();

View file

@ -25,7 +25,7 @@ class EliminateIdentity : public RewriteRule {
private:
bool SatisfyCondition(const Graph& graph, const Node& node) override;
Status Apply(Graph& graph, Node& node, bool& modified, bool& deleted) override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override;
}; // namespace onnxruntime
} // namespace onnxruntime

View file

@ -3,6 +3,7 @@
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/rewrite_rule.h"
using namespace ::onnxruntime::common;
@ -22,11 +23,11 @@ Status RuleBasedGraphTransformer::Register(std::unique_ptr<RewriteRule> rule) {
Status RuleBasedGraphTransformer::ApplyRulesOnNode(Graph& graph, Node& node,
const std::vector<std::unique_ptr<RewriteRule>>& rules,
bool& modified, bool& deleted) const {
RuleEffect& rule_effect) const {
for (const auto& rule : rules) {
ORT_RETURN_IF_ERROR(rule->CheckConditionAndApply(graph, node, modified, deleted));
if (deleted) {
modified = true; // should be set by rewriter but in case it wasn't...
ORT_RETURN_IF_ERROR(rule->CheckConditionAndApply(graph, node, rule_effect));
// If the current node was removed as a result of a rule, stop rule application for that node.
if (rule_effect == RuleEffect::kRemovedCurrentNode) {
break;
}
}
@ -39,10 +40,15 @@ Status RuleBasedGraphTransformer::ApplyImpl(Graph& graph, bool& modified, int gr
for (NodeIndex i : order) {
auto* node = graph.GetNode(i);
// A node might not be found, as it may have already been deleted by one of the rules.
if (!node) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT);
continue;
}
// Initialize the effect of rules on this node to denote that the graph has not yet been modified
// by the rule application on the current node.
auto rule_effect = RuleEffect::kNone;
if (!graph_utils::IsSupportedProvider(*node, GetCompatibleExecutionProviders())) {
continue;
}
@ -50,22 +56,26 @@ Status RuleBasedGraphTransformer::ApplyImpl(Graph& graph, bool& modified, int gr
// First apply rewrite rules that are registered for the op type of the current node; then apply rules that are
// registered to be applied regardless of the op type; then recursively apply rules to subgraphs (if any).
// Stop further rule application for the current node, if the node gets removed by a rule.
bool deleted = false;
const std::vector<std::unique_ptr<RewriteRule>>* rules = nullptr;
rules = GetRewriteRulesForOpType(node->OpType());
if (rules) {
ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, modified, deleted));
ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, rule_effect));
}
if (!deleted) {
if (rule_effect != RuleEffect::kRemovedCurrentNode) {
rules = GetAnyOpRewriteRules();
if (rules) {
ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, modified, deleted));
ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, rule_effect));
}
}
if (!deleted) {
// Update the modified field of the rule-based transformer.
if (rule_effect != RuleEffect::kNone) {
modified = true;
}
if (rule_effect != RuleEffect::kRemovedCurrentNode) {
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level));
}
}

View file

@ -8,9 +8,9 @@
namespace onnxruntime {
Status EliminateSlice::Apply(Graph& graph, Node& node, bool& modified, bool& removed) {
if (graph_utils::RemoveSingleInputNode(graph, node)) {
removed = modified = true;
Status EliminateSlice::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) {
if (graph_utils::RemoveNode(graph, node)) {
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
}
return Status::OK();

View file

@ -25,7 +25,7 @@ class EliminateSlice : public RewriteRule {
private:
bool SatisfyCondition(const Graph& graph, const Node& node) override;
Status Apply(Graph& graph, Node& node, bool& modified, bool& removed) override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override;
};
} // namespace onnxruntime

View file

@ -10,7 +10,7 @@ using namespace ::onnxruntime::common;
namespace onnxruntime {
Status UnsqueezeElimination::Apply(Graph& graph, Node& node, bool& modified, bool& removed) {
Status UnsqueezeElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) {
// Get "axes" attribute.
const ONNX_NAMESPACE::AttributeProto* attr = graph_utils::GetNodeAttribute(node, "axes");
if (attr == nullptr || attr->type() != AttributeProto_AttributeType_INTS) {
@ -66,8 +66,8 @@ Status UnsqueezeElimination::Apply(Graph& graph, Node& node, bool& modified, boo
input_def->SetShape(shape);
// Remove Unsqueeze node.
if (graph_utils::RemoveSingleInputNode(graph, node)) {
removed = modified = true;
if (graph_utils::RemoveNode(graph, node)) {
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
}
return Status::OK();

View file

@ -25,7 +25,7 @@ class UnsqueezeElimination : public RewriteRule {
private:
bool SatisfyCondition(const Graph& graph, const Node& node) override;
Status Apply(Graph& graph, Node& node, bool& modified, bool& deleted) override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override;
};
} // namespace onnxruntime

View file

@ -207,7 +207,7 @@ static void UpdateSubgraphWhenRemovingNode(bool include_nested = false) {
auto& node_to_remove = *graph.GetNode(1);
const auto& if_node = *graph.GetNode(2);
bool removed = graph_utils::RemoveSingleInputNode(graph, node_to_remove);
bool removed = graph_utils::RemoveNode(graph, node_to_remove);
ASSERT_TRUE(removed);
// check subgraph implicit input was updated
@ -238,7 +238,7 @@ static void DontRemoveNodeIfItWillBreakSubgraph(bool test_nested = false) {
auto& graph = model.MainGraph();
auto& node_to_remove = *graph.GetNode(1);
bool removed = graph_utils::RemoveSingleInputNode(graph, node_to_remove);
bool removed = graph_utils::RemoveNode(graph, node_to_remove);
ASSERT_FALSE(removed);
}
@ -287,7 +287,7 @@ TEST(GraphUtils, TestMultiEdgeRemovalNodes) {
ASSERT_EQ(nodes[2]->GetOutputEdgesCount(), 2);
// Remove id_2. This leaves id_0 with 3 output edges. id_0 is now incoming node to id_3 and id_4.
ASSERT_TRUE(graph_utils::RemoveSingleInputNode(graph, *nodes[2]));
ASSERT_TRUE(graph_utils::RemoveNode(graph, *nodes[2]));
ASSERT_EQ(graph.NumberOfNodes(), 4);
ASSERT_EQ(nodes[0]->GetOutputEdgesCount(), 3);
ASSERT_EQ(nodes[3]->InputDefs().size(), 1);
@ -296,7 +296,7 @@ TEST(GraphUtils, TestMultiEdgeRemovalNodes) {
ASSERT_TRUE(nodes[4]->InputDefs()[0]->Name() == "id_0_out");
// Remove id_0
ASSERT_TRUE(graph_utils::RemoveSingleInputNode(graph, *nodes[0]));
ASSERT_TRUE(graph_utils::RemoveNode(graph, *nodes[0]));
ASSERT_EQ(graph.NumberOfNodes(), 3);
ASSERT_TRUE(nodes[1]->InputDefs()[0]->Name() == "id_0_in");
ASSERT_TRUE(nodes[3]->InputDefs()[0]->Name() == "id_0_in");

View file

@ -47,7 +47,7 @@ class DummyRewriteRule : public RewriteRule {
return true;
}
Status Apply(Graph& /*graph*/, Node& /*node*/, bool& /*modified*/, bool& /*deleted*/) override {
Status Apply(Graph& /*graph*/, Node& /*node*/, RewriteRuleEffect& /*rule_effect*/) override {
rewrite_rule_invoked_ = true;
return Status::OK();
}

View file

@ -52,10 +52,10 @@ TEST(GraphTransformationTests, IdentityElimination) {
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Identity"] == 1);
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
rule_transformer->Register(std::make_unique<EliminateIdentity>());
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
rule_transformer_L1->Register(std::make_unique<EliminateIdentity>());
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(std::move(rule_transformer), TransformerLevel::Level1);
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK());
op_to_count = CountOpsInGraph(graph);
@ -70,10 +70,10 @@ TEST(GraphTransformationTests, SliceElimination) {
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Slice"] == 5);
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
rule_transformer->Register(std::make_unique<EliminateSlice>());
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
rule_transformer_L1->Register(std::make_unique<EliminateSlice>());
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(std::move(rule_transformer), TransformerLevel::Level1);
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK());
op_to_count = CountOpsInGraph(graph);
@ -129,7 +129,9 @@ TEST(GraphTransformationTests, FuseConvBNNoBias) {
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(std::make_unique<ConvBNFusion>(), TransformerLevel::Level2);
auto rule_transformer_L2 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL2");
rule_transformer_L2->Register(std::make_unique<ConvBNFusion>());
graph_transformation_mgr.Register(std::move(rule_transformer_L2), TransformerLevel::Level2);
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK());
@ -148,12 +150,15 @@ TEST(GraphTransformationTests, FuseConvBNMulAddUnsqueeze) {
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
rule_transformer->Register(std::make_unique<UnsqueezeElimination>());
graph_transformation_mgr.Register(std::move(rule_transformer), TransformerLevel::Level1);
graph_transformation_mgr.Register(std::make_unique<ConvBNFusion>(), TransformerLevel::Level2);
graph_transformation_mgr.Register(std::make_unique<ConvMulFusion>(), TransformerLevel::Level2);
graph_transformation_mgr.Register(std::make_unique<ConvAddFusion>(), TransformerLevel::Level2);
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
rule_transformer_L1->Register(std::make_unique<UnsqueezeElimination>());
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
auto rule_transformer_L2 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL2");
rule_transformer_L2->Register(std::make_unique<ConvAddFusion>());
rule_transformer_L2->Register(std::make_unique<ConvBNFusion>());
rule_transformer_L2->Register(std::make_unique<ConvMulFusion>());
graph_transformation_mgr.Register(std::move(rule_transformer_L2), TransformerLevel::Level2);
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK());
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK());
@ -201,10 +206,13 @@ TEST(GraphTransformationTests, FuseConvMulNoBias) {
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
rule_transformer->Register(std::make_unique<UnsqueezeElimination>());
graph_transformation_mgr.Register(std::move(rule_transformer), TransformerLevel::Level1);
graph_transformation_mgr.Register(std::make_unique<ConvMulFusion>(), TransformerLevel::Level2);
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
rule_transformer_L1->Register(std::make_unique<UnsqueezeElimination>());
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
auto rule_transformer_L2 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL2");
rule_transformer_L2->Register(std::make_unique<ConvMulFusion>());
graph_transformation_mgr.Register(std::move(rule_transformer_L2), TransformerLevel::Level2);
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK());
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK());
@ -222,10 +230,13 @@ TEST(GraphTransformationTests, FuseConvAddNoBias) {
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
rule_transformer->Register(std::make_unique<UnsqueezeElimination>());
graph_transformation_mgr.Register(std::move(rule_transformer), TransformerLevel::Level1);
graph_transformation_mgr.Register(std::make_unique<ConvAddFusion>(), TransformerLevel::Level2);
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
rule_transformer_L1->Register(std::make_unique<UnsqueezeElimination>());
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
auto rule_transformer_L2 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL2");
rule_transformer_L2->Register(std::make_unique<ConvAddFusion>());
graph_transformation_mgr.Register(std::move(rule_transformer_L2), TransformerLevel::Level2);
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK());
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK());
@ -243,8 +254,10 @@ TEST(GraphTransformationTests, FuseConvAddMul3D) {
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(std::make_unique<ConvMulFusion>(), TransformerLevel::Level2);
graph_transformation_mgr.Register(std::make_unique<ConvAddFusion>(), TransformerLevel::Level2);
auto rule_transformer_L2 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL2");
rule_transformer_L2->Register(std::make_unique<ConvAddFusion>());
rule_transformer_L2->Register(std::make_unique<ConvMulFusion>());
graph_transformation_mgr.Register(std::move(rule_transformer_L2), TransformerLevel::Level2);
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK());
@ -315,12 +328,11 @@ TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) {
std::shared_ptr<Model> p_model;
ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK());
std::unique_ptr<ConvBNFusion> ConvBNFusion_transformer = std::make_unique<ConvBNFusion>();
std::unique_ptr<ConvMulFusion> ConvMulFusion_transformer = std::make_unique<ConvMulFusion>();
std::unique_ptr<ConvAddFusion> ConvAddFusion_transformer = std::make_unique<ConvAddFusion>();
session_object.RegisterGraphTransformer(std::move(ConvBNFusion_transformer));
session_object.RegisterGraphTransformer(std::move(ConvMulFusion_transformer));
session_object.RegisterGraphTransformer(std::move(ConvAddFusion_transformer));
auto rule_transformer_L2 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL2");
rule_transformer_L2->Register(std::make_unique<ConvAddFusion>());
rule_transformer_L2->Register(std::make_unique<ConvBNFusion>());
rule_transformer_L2->Register(std::make_unique<ConvMulFusion>());
session_object.RegisterGraphTransformer(std::move(rule_transformer_L2), TransformerLevel::Level2);
ASSERT_TRUE(session_object.Initialize().IsOK());
@ -335,7 +347,8 @@ TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) {
for (int i = 0; i < 9; ++i) {
values_x.push_back(x_f);
}
CreateMLValue<MLFloat16>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_x, values_x, &ml_value_x);
CreateMLValue<MLFloat16>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault),
dims_x, values_x, &ml_value_x);
feeds.insert(std::make_pair("X", ml_value_x));
std::vector<std::string> output_names;
@ -355,7 +368,8 @@ TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) {
auto& rtensor = fetches.front().Get<Tensor>();
TensorShape expected_shape(expected_dims_prod);
ASSERT_EQ(expected_shape, rtensor.Shape());
const std::vector<MLFloat16> found(rtensor.template Data<MLFloat16>(), rtensor.template Data<MLFloat16>() + expected_dims_prod.size());
const std::vector<MLFloat16> found(rtensor.template Data<MLFloat16>(),
rtensor.template Data<MLFloat16>() + expected_dims_prod.size());
ASSERT_EQ(expected_values_prod, found);
}

View file

@ -38,7 +38,7 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) {
ASSERT_TRUE(transformers.size() != 0);
// Transformer name match test
std::vector<std::string> custom_list = {"EliminateIdentity", "ConvAddFusion", "ConvMulFusion", "abc", "def"};
std::vector<std::string> custom_list = {"EliminateIdentity", "GemmActivationFusion", "MatMulAddFusion", "abc", "def"};
transformers = transformer_utils::GenerateTransformers(TransformerLevel::Level2, custom_list);
ASSERT_TRUE(transformers.size() == 2);
// validate each rule returned is present in the custom list
@ -46,7 +46,7 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) {
ASSERT_TRUE(std::find(custom_list.begin(), custom_list.end(), transformer->Name()) != custom_list.end());
}
// Transformer name no match test. When there is no match empty list is expected.
// Transformer name no-match test. When there is no match, an empty list is expected.
custom_list = {"EliminateIdentity"};
transformers = transformer_utils::GenerateTransformers(TransformerLevel::Level2, custom_list);
ASSERT_TRUE(transformers.size() == 0);
@ -56,7 +56,7 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers_CustomList) {
// custom list of rules and transformers
std::string l1_rule1 = "EliminateIdentity";
std::string l1_transformer = "ConstantFolding";
std::string l2_transformer = "ConvAddFusion";
std::string l2_transformer = "GemmActivationFusion";
std::vector<std::string> custom_list = {l1_rule1, l1_transformer, l2_transformer};
auto transformers = transformer_utils::GenerateTransformers(TransformerLevel::Level1, custom_list);