mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Conv(Add|Mul|BN)Fusion as rewrite rules (#863)
* Converted ConvAddFusion, ConvMulFusion, and ConvBNFusion to rewrite rules * Extended graph_utils::RemoveNode * Introduced RewriteRuleEffect enum
This commit is contained in:
parent
0ad940027c
commit
feab3088fb
23 changed files with 489 additions and 490 deletions
|
|
@ -35,6 +35,18 @@ If the list of op types is left empty, that rule will be triggered for every op
|
|||
*/
|
||||
class RewriteRule {
|
||||
public:
|
||||
/**
|
||||
@class RewriteRuleEffect
|
||||
|
||||
Class used to indicate the effect of rule application on a graph's node.
|
||||
*/
|
||||
enum class RewriteRuleEffect : uint8_t {
|
||||
kNone, // The rewrite rule has not modified the graph.
|
||||
kUpdatedCurrentNode, // The rewrite rule updated (but did not remove) the node on which it was triggered.
|
||||
kRemovedCurrentNode, // The rewrite rule removed the node on which it was triggered.
|
||||
kModifiedRestOfGraph // The rewrite rule modified nodes other than the one it was triggered on.
|
||||
};
|
||||
|
||||
RewriteRule(const std::string& name) : name_(name) {}
|
||||
|
||||
virtual ~RewriteRule() = default;
|
||||
|
|
@ -52,11 +64,10 @@ class RewriteRule {
|
|||
/** Checks if the condition of the rule is satisfied, and if so applies the body of the rule.
|
||||
@param[in] graph The Graph.
|
||||
@param[in] node The Node to apply the rewrite to.
|
||||
@param[out] modified Set to indicate whether the node was modified or not.
|
||||
@param[out] deleted Set to indicate if the node was deleted.
|
||||
@param[out] rule_effect Enum to indicate if and how the graph was modified as a result of the rule application.
|
||||
@returns Status indicating success or providing error information */
|
||||
common::Status CheckConditionAndApply(Graph& graph, Node& node, bool& modified, bool& deleted) {
|
||||
return SatisfyCondition(graph, node) ? Apply(graph, node, modified, deleted) : Status::OK();
|
||||
common::Status CheckConditionAndApply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) {
|
||||
return SatisfyCondition(graph, node) ? Apply(graph, node, rule_effect) : Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -72,8 +83,7 @@ class RewriteRule {
|
|||
|
||||
/** This is the actual body of the rule that performs the graph transformation. The transformation happens in-place.
|
||||
The return-value of node may be different from the input-value due to rewriting.
|
||||
The value of "modified" indicates if the graph was modified or not.
|
||||
The value of "deleted" indicates if the node was deleted or not. */
|
||||
virtual common::Status Apply(Graph& graph, Node& node, bool& modified, bool& deleted) = 0;
|
||||
The value of "rule_effect" indicates whether and how the graph was modified by the rule. */
|
||||
virtual common::Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) = 0;
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -58,14 +58,16 @@ class RuleBasedGraphTransformer : public GraphTransformer {
|
|||
@param[in] graph The Graph.
|
||||
@param[in] node The Node to apply the rules to.
|
||||
@param[in] rules The vector of RewriteRules that will be applied to the Node.
|
||||
@param[out] modified Set to indicate whether the node was modified or not.
|
||||
@param[out] deleted Set to indicate if the node was deleted.
|
||||
@param[out] rule_effect Enum that indicates whether and how the graph was modified as a result of
|
||||
applying rules on this node.
|
||||
@returns Status indicating success or providing error information. */
|
||||
common::Status ApplyRulesOnNode(Graph& graph, Node& node,
|
||||
const std::vector<std::unique_ptr<RewriteRule>>& rules,
|
||||
bool& modified, bool& deleted) const;
|
||||
RewriteRule::RewriteRuleEffect& rule_effect) const;
|
||||
|
||||
private:
|
||||
using RuleEffect = RewriteRule::RewriteRuleEffect;
|
||||
|
||||
// Map that associates a node's op type with the vector of rules that are registered to be triggered for that node.
|
||||
std::unordered_map<std::string, std::vector<std::unique_ptr<RewriteRule>>> op_type_to_rules_;
|
||||
// Rules that will be evaluated regardless of the op type of the node.
|
||||
|
|
|
|||
|
|
@ -353,19 +353,24 @@ const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const s
|
|||
return iter == attrs.end() ? nullptr : &iter->second;
|
||||
}
|
||||
|
||||
bool RemoveSingleInputNode(Graph& graph, Node& node) {
|
||||
// Cannot remove a node with multiple output NodeArgs (multiple output edges is fine), neither
|
||||
// a node whose output is also a graph output.
|
||||
if (!IsSingleInSingleOutNode(node) ||
|
||||
bool RemoveNode(Graph& graph, Node& node) {
|
||||
// Cannot remove a node with implicit inputs, with multiple output NodeArgs (multiple output edges is fine),
|
||||
// or whose output is also a graph output.
|
||||
if (node.ImplicitInputDefs().size() > 0 ||
|
||||
node.OutputDefs().size() != 1 ||
|
||||
graph.IsNodeOutputsInGraphOutputs(node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If the single input comes from another node (initializers are not connected with edges to nodes).
|
||||
if (node.GetInputEdgesCount() == 1) {
|
||||
// If there is a single input edge from another node (initializers are not connected with edges to nodes).
|
||||
return RemoveNodeWithSingleNodeIn(graph, node);
|
||||
} else {
|
||||
} else if (node.InputDefs().size() == 1) {
|
||||
// If a single initializer is the only input.
|
||||
return RemoveNodeWithSingleInitializerIn(graph, node);
|
||||
} else {
|
||||
// No other node removal is supported, because there will be no way to connect its inputs to its outputs.
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -64,10 +64,14 @@ bool GetRepeatedNodeAttributeValues(const Node& node,
|
|||
Status ForAllMutableSubgraphs(Graph& main_graph, std::function<Status(Graph&)> func);
|
||||
Status ForAllSubgraphs(const Graph& main_graph, std::function<Status(const Graph&)> func);
|
||||
|
||||
/** Removes the given single-input Node from the Graph. The single input might be either
|
||||
another node or an initializer, but not an implicit input. The node should have a single
|
||||
output but can have multiple output edges. */
|
||||
bool RemoveSingleInputNode(Graph& graph, Node& node);
|
||||
/** Removes the given Node from the Graph and keeps Graph consistent by rebuilding needed connections.
|
||||
We support the removal of the Node if it has no implicit inputs and a single output (but it can have multiple
|
||||
output edges). As for the Node's inputs, we support the following cases:
|
||||
- If the Node has a single incoming node (and possibly multiple initializers), we can remove the Node and
|
||||
connect its incoming node to its outgoing nodes.
|
||||
- If the Node has a single initializer as input, we remove the Node and feed the initializer as input to its
|
||||
output nodes. */
|
||||
bool RemoveNode(Graph& graph, Node& node);
|
||||
|
||||
/** Removes all output edges from the given Node of the Graph.
|
||||
This should probably be elevated to the Graph API eventually. */
|
||||
|
|
|
|||
|
|
@ -9,136 +9,108 @@ using namespace ONNX_NAMESPACE;
|
|||
using namespace ::onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
Status ConvAddFusion::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level) const {
|
||||
std::vector<onnxruntime::NodeIndex> removed_nodes;
|
||||
for (auto& node : graph.Nodes()) {
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level));
|
||||
Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modified) {
|
||||
auto& conv_node = node;
|
||||
const auto& add_node = *conv_node.OutputNodesBegin();
|
||||
const auto& conv_inputs = conv_node.InputDefs();
|
||||
const auto& add_inputs = add_node.InputDefs();
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", 1) ||
|
||||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
|
||||
node.GetOutputEdgesCount() != 1) {
|
||||
continue;
|
||||
}
|
||||
const ONNX_NAMESPACE::TensorProto* conv_W_tensor_proto = nullptr;
|
||||
graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto);
|
||||
|
||||
const Node& next_node = *node.OutputNodesBegin();
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", 7) ||
|
||||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType() ||
|
||||
next_node.GetInputEdgesCount() != 1 ||
|
||||
graph.IsNodeOutputsInGraphOutputs(next_node)) {
|
||||
continue;
|
||||
}
|
||||
const ONNX_NAMESPACE::TensorProto* add_B_tensor_proto = nullptr;
|
||||
graph.GetInitializedTensor(add_inputs[1]->Name(), add_B_tensor_proto);
|
||||
|
||||
auto& conv_node = node;
|
||||
const Node& add_node = next_node;
|
||||
|
||||
const auto& conv_inputs = conv_node.InputDefs();
|
||||
const auto& add_inputs = add_node.InputDefs();
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* conv_W_tensor_proto = nullptr;
|
||||
graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto);
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* add_B_tensor_proto = nullptr;
|
||||
graph.GetInitializedTensor(add_inputs[1]->Name(), add_B_tensor_proto);
|
||||
|
||||
// Currently, fusion is only supported for float or double data type.
|
||||
if (!Initializer::IsSupportedDataType(add_B_tensor_proto) ||
|
||||
conv_W_tensor_proto->dims_size() < 4 ||
|
||||
add_B_tensor_proto->dims_size() != conv_W_tensor_proto->dims_size() - 1 ||
|
||||
conv_W_tensor_proto->dims(0) != add_B_tensor_proto->dims(0)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// The dimensions of add_B should be equal to 1 except first dimension.
|
||||
bool flag = false;
|
||||
for (int i = 1; i < add_B_tensor_proto->dims_size(); i++) {
|
||||
if (add_B_tensor_proto->dims(i) != 1) {
|
||||
flag = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (flag) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr;
|
||||
if (conv_inputs.size() == 3) {
|
||||
graph.GetInitializedTensor(conv_inputs[2]->Name(), conv_B_tensor_proto);
|
||||
|
||||
if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) ||
|
||||
conv_B_tensor_proto->data_type() != add_B_tensor_proto->data_type() ||
|
||||
conv_B_tensor_proto->dims_size() != 1 ||
|
||||
conv_B_tensor_proto->dims(0) != add_B_tensor_proto->dims(0)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto conv_B = std::make_unique<Initializer>(conv_B_tensor_proto);
|
||||
auto add_B = std::make_unique<Initializer>(add_B_tensor_proto);
|
||||
|
||||
if (conv_B->size() != add_B->size()) {
|
||||
continue;
|
||||
}
|
||||
// Calculate new value of initializers of conv node
|
||||
conv_B->add(*add_B);
|
||||
|
||||
// Create new initializers of conv
|
||||
ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto;
|
||||
conv_B->ToProto(&new_conv_B_tensor_proto);
|
||||
|
||||
// Replace initializers of conv node
|
||||
graph.RemoveInitializedTensor(conv_inputs[2]->Name());
|
||||
graph.AddInitializedTensor(new_conv_B_tensor_proto);
|
||||
} else {
|
||||
NodeArg* add_B_node_arg = graph.GetNodeArg(add_B_tensor_proto->name());
|
||||
if (add_B_node_arg == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Update shape of tensor proto
|
||||
ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto(*add_B_tensor_proto);
|
||||
int64_t dim = conv_W_tensor_proto->dims(0);
|
||||
new_conv_B_tensor_proto.clear_dims();
|
||||
new_conv_B_tensor_proto.add_dims(dim);
|
||||
|
||||
graph.RemoveInitializedTensor(add_B_tensor_proto->name());
|
||||
graph.AddInitializedTensor(new_conv_B_tensor_proto);
|
||||
|
||||
// Update shape of NodeArg
|
||||
TensorShapeProto shape;
|
||||
shape.add_dim()->set_dim_value(dim);
|
||||
add_B_node_arg->SetShape(shape);
|
||||
|
||||
conv_node.MutableInputDefs().push_back(add_B_node_arg);
|
||||
conv_node.MutableInputArgsCount()[2] = 1;
|
||||
}
|
||||
|
||||
// Replace the input of the node following add node
|
||||
const NodeArg* add_output_def = add_node.OutputDefs()[0];
|
||||
NodeArg* conv_output_def = conv_node.MutableOutputDefs()[0];
|
||||
for (auto it = add_node.OutputNodesBegin(); it != add_node.OutputNodesEnd(); ++it) {
|
||||
auto output_node = graph.GetNode((*it).Index());
|
||||
if (!output_node) {
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT);
|
||||
}
|
||||
auto& input_defs = output_node->MutableInputDefs();
|
||||
for (auto& def : input_defs) {
|
||||
if (def == add_output_def) {
|
||||
def = conv_output_def;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
removed_nodes.push_back(add_node.Index());
|
||||
// Currently, fusion is only supported for float or double data type.
|
||||
if (!Initializer::IsSupportedDataType(add_B_tensor_proto) ||
|
||||
conv_W_tensor_proto->dims_size() < 4 ||
|
||||
add_B_tensor_proto->dims_size() != conv_W_tensor_proto->dims_size() - 1 ||
|
||||
conv_W_tensor_proto->dims(0) != add_B_tensor_proto->dims(0)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
for (auto i : removed_nodes) {
|
||||
graph.RemoveNode(i);
|
||||
// The dimensions of add_B should be equal to 1 except first dimension.
|
||||
for (int i = 1; i < add_B_tensor_proto->dims_size(); i++) {
|
||||
if (add_B_tensor_proto->dims(i) != 1) {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
if (!removed_nodes.empty()) {
|
||||
modified = true;
|
||||
const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr;
|
||||
if (conv_inputs.size() == 3) {
|
||||
graph.GetInitializedTensor(conv_inputs[2]->Name(), conv_B_tensor_proto);
|
||||
|
||||
if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) ||
|
||||
conv_B_tensor_proto->data_type() != add_B_tensor_proto->data_type() ||
|
||||
conv_B_tensor_proto->dims_size() != 1 ||
|
||||
conv_B_tensor_proto->dims(0) != add_B_tensor_proto->dims(0)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
auto conv_B = std::make_unique<Initializer>(conv_B_tensor_proto);
|
||||
auto add_B = std::make_unique<Initializer>(add_B_tensor_proto);
|
||||
|
||||
if (conv_B->size() != add_B->size()) {
|
||||
return Status::OK();
|
||||
}
|
||||
// Calculate new value of initializers of conv node
|
||||
conv_B->add(*add_B);
|
||||
|
||||
// Create new initializers of conv
|
||||
ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto;
|
||||
conv_B->ToProto(&new_conv_B_tensor_proto);
|
||||
|
||||
// Replace initializers of conv node
|
||||
graph.RemoveInitializedTensor(conv_inputs[2]->Name());
|
||||
graph.AddInitializedTensor(new_conv_B_tensor_proto);
|
||||
} else {
|
||||
NodeArg* add_B_node_arg = graph.GetNodeArg(add_B_tensor_proto->name());
|
||||
if (add_B_node_arg == nullptr) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Update shape of tensor proto
|
||||
ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto(*add_B_tensor_proto);
|
||||
int64_t dim = conv_W_tensor_proto->dims(0);
|
||||
new_conv_B_tensor_proto.clear_dims();
|
||||
new_conv_B_tensor_proto.add_dims(dim);
|
||||
|
||||
graph.RemoveInitializedTensor(add_B_tensor_proto->name());
|
||||
graph.AddInitializedTensor(new_conv_B_tensor_proto);
|
||||
|
||||
// Update shape of NodeArg
|
||||
TensorShapeProto shape;
|
||||
shape.add_dim()->set_dim_value(dim);
|
||||
add_B_node_arg->SetShape(shape);
|
||||
|
||||
conv_node.MutableInputDefs().push_back(add_B_node_arg);
|
||||
conv_node.MutableInputArgsCount()[2] = 1;
|
||||
}
|
||||
|
||||
// Remove Add node.
|
||||
auto* add_node_to_remove = graph.GetNode(add_node.Index());
|
||||
if (graph_utils::RemoveNode(graph, *add_node_to_remove)) {
|
||||
modified = RewriteRuleEffect::kModifiedRestOfGraph;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node) {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", 1) ||
|
||||
node.GetOutputEdgesCount() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& next_node = *node.OutputNodesBegin();
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", 7) ||
|
||||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType() ||
|
||||
next_node.GetInputEdgesCount() != 1 ||
|
||||
graph.IsNodeOutputsInGraphOutputs(next_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -3,16 +3,29 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
#include "core/optimizer/rewrite_rule.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
class ConvAddFusion : public onnxruntime::GraphTransformer {
|
||||
/**
|
||||
@Class ConvAddFusion
|
||||
|
||||
Rewrite rule that fuses two Conv+Add nodes to a single Conv node.
|
||||
|
||||
It is attempted to be triggered only on nodes with op type "Conv".
|
||||
*/
|
||||
class ConvAddFusion : public RewriteRule {
|
||||
public:
|
||||
ConvAddFusion() noexcept : onnxruntime::GraphTransformer("ConvAddFusion") {}
|
||||
ConvAddFusion() noexcept : RewriteRule("ConvAddFusion") {}
|
||||
|
||||
std::vector<std::string> TargetOpTypes() const noexcept override {
|
||||
return {"Conv"};
|
||||
}
|
||||
|
||||
private:
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node) override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -9,184 +9,154 @@ using namespace ONNX_NAMESPACE;
|
|||
using namespace ::onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
Status ConvBNFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
|
||||
std::vector<onnxruntime::NodeIndex> removed_nodes;
|
||||
for (auto& node : graph.Nodes()) {
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level));
|
||||
Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) {
|
||||
auto& conv_node = node;
|
||||
const Node& bn_node = *conv_node.OutputNodesBegin();
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", 1) ||
|
||||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
|
||||
node.GetOutputEdgesCount() != 1) {
|
||||
continue;
|
||||
// Get value of attribute epsilon
|
||||
const onnxruntime::NodeAttributes& attributes = bn_node.GetAttributes();
|
||||
const ONNX_NAMESPACE::AttributeProto* attr = &(attributes.find("epsilon")->second);
|
||||
if (attr == nullptr || attr->type() != AttributeProto_AttributeType_FLOAT) {
|
||||
return Status::OK();
|
||||
}
|
||||
float epsilon = static_cast<float>(attr->f());
|
||||
|
||||
// Get initializers of BatchNormalization
|
||||
const auto& bn_inputs = bn_node.InputDefs();
|
||||
const ONNX_NAMESPACE::TensorProto* bn_scale_tensor_proto = nullptr;
|
||||
graph.GetInitializedTensor(bn_inputs[1]->Name(), bn_scale_tensor_proto);
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* bn_B_tensor_proto = nullptr;
|
||||
graph.GetInitializedTensor(bn_inputs[2]->Name(), bn_B_tensor_proto);
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* bn_mean_tensor_proto = nullptr;
|
||||
graph.GetInitializedTensor(bn_inputs[3]->Name(), bn_mean_tensor_proto);
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* bn_var_tensor_proto = nullptr;
|
||||
graph.GetInitializedTensor(bn_inputs[4]->Name(), bn_var_tensor_proto);
|
||||
|
||||
const auto& conv_inputs = conv_node.InputDefs();
|
||||
const ONNX_NAMESPACE::TensorProto* conv_W_tensor_proto = nullptr;
|
||||
graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto);
|
||||
|
||||
// Currently, fusion is only supported for float or double data type.
|
||||
if (!Initializer::IsSupportedDataType(bn_scale_tensor_proto) ||
|
||||
!Initializer::IsSupportedDataType(bn_B_tensor_proto) ||
|
||||
!Initializer::IsSupportedDataType(bn_mean_tensor_proto) ||
|
||||
!Initializer::IsSupportedDataType(bn_var_tensor_proto) ||
|
||||
!Initializer::IsSupportedDataType(conv_W_tensor_proto) ||
|
||||
bn_scale_tensor_proto->dims_size() != 1 ||
|
||||
bn_B_tensor_proto->dims_size() != 1 ||
|
||||
bn_mean_tensor_proto->dims_size() != 1 ||
|
||||
bn_var_tensor_proto->dims_size() != 1 ||
|
||||
bn_scale_tensor_proto->dims(0) != bn_B_tensor_proto->dims(0) ||
|
||||
bn_B_tensor_proto->dims(0) != bn_mean_tensor_proto->dims(0) ||
|
||||
bn_mean_tensor_proto->dims(0) != bn_var_tensor_proto->dims(0) ||
|
||||
bn_scale_tensor_proto->data_type() != bn_B_tensor_proto->data_type() ||
|
||||
bn_B_tensor_proto->data_type() != bn_mean_tensor_proto->data_type() ||
|
||||
bn_mean_tensor_proto->data_type() != bn_var_tensor_proto->data_type() ||
|
||||
conv_W_tensor_proto->data_type() != bn_scale_tensor_proto->data_type() ||
|
||||
!(conv_W_tensor_proto->dims_size() > 2 && conv_W_tensor_proto->dims(0) == bn_scale_tensor_proto->dims(0))) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
auto bn_scale = std::make_unique<Initializer>(bn_scale_tensor_proto);
|
||||
auto bn_B = std::make_unique<Initializer>(bn_B_tensor_proto);
|
||||
auto bn_mean = std::make_unique<Initializer>(bn_mean_tensor_proto);
|
||||
auto bn_var = std::make_unique<Initializer>(bn_var_tensor_proto);
|
||||
auto conv_W = std::make_unique<Initializer>(conv_W_tensor_proto);
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr;
|
||||
std::unique_ptr<Initializer> conv_B = nullptr;
|
||||
if (conv_inputs.size() == 3) {
|
||||
if (!graph.GetInitializedTensor(conv_inputs[2]->Name(), conv_B_tensor_proto)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const Node& next_node = *node.OutputNodesBegin();
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "BatchNormalization", 7) ||
|
||||
next_node.GetInputEdgesCount() != 1 ||
|
||||
graph.IsNodeOutputsInGraphOutputs(next_node) ||
|
||||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
|
||||
continue;
|
||||
if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) ||
|
||||
conv_B_tensor_proto->dims_size() != 1 ||
|
||||
conv_B_tensor_proto->dims(0) != bn_B_tensor_proto->dims(0) ||
|
||||
conv_B_tensor_proto->data_type() != bn_B_tensor_proto->data_type()) {
|
||||
return Status::OK();
|
||||
}
|
||||
conv_B = std::make_unique<Initializer>(conv_B_tensor_proto);
|
||||
}
|
||||
|
||||
auto& conv_node = node;
|
||||
const Node& bn_node = next_node;
|
||||
// Calculate new value of initializers of conv node
|
||||
bn_var->add(epsilon);
|
||||
bn_var->sqrt();
|
||||
bn_scale->div(*bn_var);
|
||||
conv_W->scale_by_axis(*bn_scale, 1);
|
||||
|
||||
// Get value of attribute group
|
||||
const onnxruntime::NodeAttributes& conv_attributes = conv_node.GetAttributes();
|
||||
const ONNX_NAMESPACE::AttributeProto* group_attr = &(conv_attributes.find("group")->second);
|
||||
if (group_attr != nullptr &&
|
||||
group_attr->type() == AttributeProto_AttributeType_INT &&
|
||||
group_attr->has_i() && group_attr->i() != 1) {
|
||||
continue;
|
||||
if (conv_inputs.size() == 3) {
|
||||
conv_B->sub(*bn_mean);
|
||||
conv_B->mul(*bn_scale);
|
||||
conv_B->add(*bn_B);
|
||||
} else {
|
||||
bn_mean->mul(*bn_scale);
|
||||
bn_B->sub(*bn_mean);
|
||||
}
|
||||
|
||||
// Create new initializers of conv
|
||||
ONNX_NAMESPACE::TensorProto new_conv_W_tensor_proto(*conv_W_tensor_proto);
|
||||
conv_W->ToProto(&new_conv_W_tensor_proto);
|
||||
|
||||
ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto;
|
||||
NodeArg* bn_B_node_arg = nullptr;
|
||||
if (conv_inputs.size() == 3) {
|
||||
conv_B->ToProto(&new_conv_B_tensor_proto);
|
||||
} else {
|
||||
bn_B->ToProto(&new_conv_B_tensor_proto);
|
||||
bn_B_node_arg = graph.GetNodeArg(bn_B_tensor_proto->name());
|
||||
if (bn_B_node_arg == nullptr) {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
// Get value of attribute epsilon
|
||||
const onnxruntime::NodeAttributes& attributes = bn_node.GetAttributes();
|
||||
const ONNX_NAMESPACE::AttributeProto* attr = &(attributes.find("epsilon")->second);
|
||||
if (attr == nullptr || attr->type() != AttributeProto_AttributeType_FLOAT) {
|
||||
continue;
|
||||
}
|
||||
float epsilon = static_cast<float>(attr->f());
|
||||
|
||||
// Get initializers of BatchNormalization
|
||||
const auto& bn_inputs = bn_node.InputDefs();
|
||||
const ONNX_NAMESPACE::TensorProto* bn_scale_tensor_proto = nullptr;
|
||||
graph.GetInitializedTensor(bn_inputs[1]->Name(), bn_scale_tensor_proto);
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* bn_B_tensor_proto = nullptr;
|
||||
graph.GetInitializedTensor(bn_inputs[2]->Name(), bn_B_tensor_proto);
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* bn_mean_tensor_proto = nullptr;
|
||||
graph.GetInitializedTensor(bn_inputs[3]->Name(), bn_mean_tensor_proto);
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* bn_var_tensor_proto = nullptr;
|
||||
graph.GetInitializedTensor(bn_inputs[4]->Name(), bn_var_tensor_proto);
|
||||
|
||||
const auto& conv_inputs = conv_node.InputDefs();
|
||||
const ONNX_NAMESPACE::TensorProto* conv_W_tensor_proto = nullptr;
|
||||
graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto);
|
||||
|
||||
// Currently, fusion is only supported for float or double data type.
|
||||
if (!Initializer::IsSupportedDataType(bn_scale_tensor_proto) ||
|
||||
!Initializer::IsSupportedDataType(bn_B_tensor_proto) ||
|
||||
!Initializer::IsSupportedDataType(bn_mean_tensor_proto) ||
|
||||
!Initializer::IsSupportedDataType(bn_var_tensor_proto) ||
|
||||
!Initializer::IsSupportedDataType(conv_W_tensor_proto) ||
|
||||
bn_scale_tensor_proto->dims_size() != 1 ||
|
||||
bn_B_tensor_proto->dims_size() != 1 ||
|
||||
bn_mean_tensor_proto->dims_size() != 1 ||
|
||||
bn_var_tensor_proto->dims_size() != 1 ||
|
||||
bn_scale_tensor_proto->dims(0) != bn_B_tensor_proto->dims(0) ||
|
||||
bn_B_tensor_proto->dims(0) != bn_mean_tensor_proto->dims(0) ||
|
||||
bn_mean_tensor_proto->dims(0) != bn_var_tensor_proto->dims(0) ||
|
||||
bn_scale_tensor_proto->data_type() != bn_B_tensor_proto->data_type() ||
|
||||
bn_B_tensor_proto->data_type() != bn_mean_tensor_proto->data_type() ||
|
||||
bn_mean_tensor_proto->data_type() != bn_var_tensor_proto->data_type() ||
|
||||
conv_W_tensor_proto->data_type() != bn_scale_tensor_proto->data_type() ||
|
||||
!(conv_W_tensor_proto->dims_size() > 2 && conv_W_tensor_proto->dims(0) == bn_scale_tensor_proto->dims(0))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto bn_scale = std::make_unique<Initializer>(bn_scale_tensor_proto);
|
||||
auto bn_B = std::make_unique<Initializer>(bn_B_tensor_proto);
|
||||
auto bn_mean = std::make_unique<Initializer>(bn_mean_tensor_proto);
|
||||
auto bn_var = std::make_unique<Initializer>(bn_var_tensor_proto);
|
||||
auto conv_W = std::make_unique<Initializer>(conv_W_tensor_proto);
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr;
|
||||
std::unique_ptr<Initializer> conv_B = nullptr;
|
||||
if (conv_inputs.size() == 3) {
|
||||
if (!graph.GetInitializedTensor(conv_inputs[2]->Name(), conv_B_tensor_proto))
|
||||
continue;
|
||||
|
||||
if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) ||
|
||||
conv_B_tensor_proto->dims_size() != 1 ||
|
||||
conv_B_tensor_proto->dims(0) != bn_B_tensor_proto->dims(0) ||
|
||||
conv_B_tensor_proto->data_type() != bn_B_tensor_proto->data_type()) {
|
||||
continue;
|
||||
}
|
||||
conv_B = std::make_unique<Initializer>(conv_B_tensor_proto);
|
||||
}
|
||||
|
||||
// Calculate new value of initializers of conv node
|
||||
bn_var->add(epsilon);
|
||||
bn_var->sqrt();
|
||||
bn_scale->div(*bn_var);
|
||||
conv_W->scale_by_axis(*bn_scale, 1);
|
||||
|
||||
if (conv_inputs.size() == 3) {
|
||||
conv_B->sub(*bn_mean);
|
||||
conv_B->mul(*bn_scale);
|
||||
conv_B->add(*bn_B);
|
||||
} else {
|
||||
bn_mean->mul(*bn_scale);
|
||||
bn_B->sub(*bn_mean);
|
||||
}
|
||||
|
||||
// Create new initializers of conv
|
||||
ONNX_NAMESPACE::TensorProto new_conv_W_tensor_proto(*conv_W_tensor_proto);
|
||||
conv_W->ToProto(&new_conv_W_tensor_proto);
|
||||
|
||||
ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto;
|
||||
NodeArg* bn_B_node_arg = nullptr;
|
||||
if (conv_inputs.size() == 3) {
|
||||
conv_B->ToProto(&new_conv_B_tensor_proto);
|
||||
} else {
|
||||
bn_B->ToProto(&new_conv_B_tensor_proto);
|
||||
bn_B_node_arg = graph.GetNodeArg(bn_B_tensor_proto->name());
|
||||
if (bn_B_node_arg == nullptr) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Replace initializers of conv node
|
||||
graph.RemoveInitializedTensor(conv_W_tensor_proto->name());
|
||||
if (conv_inputs.size() == 3) {
|
||||
// Replace initializers of conv node
|
||||
graph.RemoveInitializedTensor(conv_W_tensor_proto->name());
|
||||
if (conv_inputs.size() == 3) {
|
||||
#ifdef _MSC_VER
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 6011) // Not deferencing null pointer. conv_B_tensor_proto is set on line 93
|
||||
#endif
|
||||
graph.RemoveInitializedTensor(conv_B_tensor_proto->name());
|
||||
graph.RemoveInitializedTensor(conv_B_tensor_proto->name());
|
||||
#ifdef _MSC_VER
|
||||
#pragma warning(pop)
|
||||
#endif
|
||||
|
||||
} else {
|
||||
graph.RemoveInitializedTensor(bn_B_tensor_proto->name());
|
||||
conv_node.MutableInputDefs().push_back(bn_B_node_arg);
|
||||
conv_node.MutableInputArgsCount()[2] = 1;
|
||||
}
|
||||
graph.AddInitializedTensor(new_conv_W_tensor_proto);
|
||||
graph.AddInitializedTensor(new_conv_B_tensor_proto);
|
||||
|
||||
// Replace the input of the nodes following batch normalization node
|
||||
const NodeArg* bn_output_def = bn_node.OutputDefs()[0];
|
||||
NodeArg* conv_output_def = conv_node.MutableOutputDefs()[0];
|
||||
for (auto it = bn_node.OutputNodesBegin(); it != bn_node.OutputNodesEnd(); ++it) {
|
||||
auto output_node = graph.GetNode((*it).Index());
|
||||
if (!output_node) {
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
auto& input_defs = output_node->MutableInputDefs();
|
||||
for (auto& def : input_defs) {
|
||||
if (def == bn_output_def) {
|
||||
def = conv_output_def;
|
||||
}
|
||||
}
|
||||
}
|
||||
removed_nodes.push_back(bn_node.Index());
|
||||
} else {
|
||||
graph.RemoveInitializedTensor(bn_B_tensor_proto->name());
|
||||
conv_node.MutableInputDefs().push_back(bn_B_node_arg);
|
||||
conv_node.MutableInputArgsCount()[2] = 1;
|
||||
}
|
||||
graph.AddInitializedTensor(new_conv_W_tensor_proto);
|
||||
graph.AddInitializedTensor(new_conv_B_tensor_proto);
|
||||
|
||||
for (auto i : removed_nodes) {
|
||||
graph.RemoveNode(i);
|
||||
}
|
||||
|
||||
if (!removed_nodes.empty()) {
|
||||
modified = true;
|
||||
// Remove BN node.
|
||||
auto* bn_node_to_remove = graph.GetNode(bn_node.Index());
|
||||
if (graph_utils::RemoveNode(graph, *bn_node_to_remove)) {
|
||||
rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool ConvBNFusion::SatisfyCondition(const Graph& graph, const Node& node) {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", 1) ||
|
||||
node.GetOutputEdgesCount() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& next_node = *node.OutputNodesBegin();
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "BatchNormalization", 7) ||
|
||||
next_node.GetInputEdgesCount() != 1 ||
|
||||
graph.IsNodeOutputsInGraphOutputs(next_node) ||
|
||||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -3,15 +3,29 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
#include "core/optimizer/rewrite_rule.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
class ConvBNFusion : public onnxruntime::GraphTransformer {
|
||||
/**
|
||||
@Class ConvBNFusion
|
||||
|
||||
Rewrite rule that fuses two Conv+BN nodes to a single Conv node.
|
||||
|
||||
It is attempted to be triggered only on nodes with op type "Conv".
|
||||
*/
|
||||
class ConvBNFusion : public RewriteRule {
|
||||
public:
|
||||
ConvBNFusion() noexcept : onnxruntime::GraphTransformer("ConvBNFusion") {}
|
||||
ConvBNFusion() noexcept : RewriteRule("ConvBNFusion") {}
|
||||
|
||||
std::vector<std::string> TargetOpTypes() const noexcept override {
|
||||
return {"Conv"};
|
||||
}
|
||||
|
||||
private:
|
||||
Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level) const override;
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node) override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -9,135 +9,107 @@ using namespace ONNX_NAMESPACE;
|
|||
using namespace ::onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
Status ConvMulFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
|
||||
std::vector<onnxruntime::NodeIndex> removed_nodes;
|
||||
for (auto& node : graph.Nodes()) {
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", 1) ||
|
||||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
|
||||
node.GetOutputEdgesCount() != 1) {
|
||||
continue;
|
||||
Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) {
|
||||
auto& conv_node = node;
|
||||
const auto& mul_node = *conv_node.OutputNodesBegin();
|
||||
const auto& conv_inputs = conv_node.InputDefs();
|
||||
const auto& mul_inputs = mul_node.InputDefs();
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* conv_W_tensor_proto = nullptr;
|
||||
graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto);
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* mul_B_tensor_proto = nullptr;
|
||||
graph.GetInitializedTensor(mul_inputs[1]->Name(), mul_B_tensor_proto);
|
||||
|
||||
if (!Initializer::IsSupportedDataType(conv_W_tensor_proto) ||
|
||||
!Initializer::IsSupportedDataType(mul_B_tensor_proto) ||
|
||||
conv_W_tensor_proto->data_type() != mul_B_tensor_proto->data_type() ||
|
||||
conv_W_tensor_proto->dims_size() < 4 ||
|
||||
!(mul_B_tensor_proto->dims_size() == 0 ||
|
||||
(mul_B_tensor_proto->dims_size() == conv_W_tensor_proto->dims_size() - 1 &&
|
||||
conv_W_tensor_proto->dims(0) == mul_B_tensor_proto->dims(0)))) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// The dimensions of mul_B should be equal to 1 except first dimension.
|
||||
if (mul_B_tensor_proto->dims_size() != 0) {
|
||||
for (int i = 1; i < mul_B_tensor_proto->dims_size(); i++) {
|
||||
if (mul_B_tensor_proto->dims(i) != 1) {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const Node& next_node = *node.OutputNodesBegin();
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", 7) ||
|
||||
next_node.GetInputEdgesCount() != 1 ||
|
||||
graph.IsNodeOutputsInGraphOutputs(next_node) ||
|
||||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
|
||||
continue;
|
||||
auto conv_W = std::make_unique<Initializer>(conv_W_tensor_proto);
|
||||
auto mul_B = std::make_unique<Initializer>(mul_B_tensor_proto);
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr;
|
||||
std::unique_ptr<Initializer> conv_B = nullptr;
|
||||
const bool is_3d = conv_inputs.size() == 3;
|
||||
if (is_3d) {
|
||||
if (!graph.GetInitializedTensor(conv_inputs[2]->Name(), conv_B_tensor_proto))
|
||||
return Status::OK();
|
||||
if (conv_B_tensor_proto == nullptr)
|
||||
return Status(ONNXRUNTIME, FAIL, "Internal error in ConvMulFusion. conv_B_tensor_proto is NULL");
|
||||
if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) ||
|
||||
conv_B_tensor_proto->data_type() != mul_B_tensor_proto->data_type() ||
|
||||
conv_B_tensor_proto->dims_size() != 1 ||
|
||||
(mul_B_tensor_proto->dims_size() != 0 && conv_B_tensor_proto->dims(0) != mul_B_tensor_proto->dims(0))) {
|
||||
return Status::OK();
|
||||
}
|
||||
conv_B = std::make_unique<Initializer>(conv_B_tensor_proto);
|
||||
}
|
||||
|
||||
auto& conv_node = node;
|
||||
const Node& mul_node = next_node;
|
||||
// Calculate new value of initializers of conv node
|
||||
conv_W->scale_by_axis(*mul_B, 1);
|
||||
|
||||
const auto& conv_inputs = conv_node.InputDefs();
|
||||
const auto& mul_inputs = mul_node.InputDefs();
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* conv_W_tensor_proto = nullptr;
|
||||
graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto);
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* mul_B_tensor_proto = nullptr;
|
||||
graph.GetInitializedTensor(mul_inputs[1]->Name(), mul_B_tensor_proto);
|
||||
|
||||
if (!Initializer::IsSupportedDataType(conv_W_tensor_proto) ||
|
||||
!Initializer::IsSupportedDataType(mul_B_tensor_proto) ||
|
||||
conv_W_tensor_proto->data_type() != mul_B_tensor_proto->data_type() ||
|
||||
conv_W_tensor_proto->dims_size() < 4 ||
|
||||
!(mul_B_tensor_proto->dims_size() == 0 ||
|
||||
(mul_B_tensor_proto->dims_size() == conv_W_tensor_proto->dims_size() - 1 &&
|
||||
conv_W_tensor_proto->dims(0) == mul_B_tensor_proto->dims(0)))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// The dimensions of mul_B should be equal to 1 except first dimension.
|
||||
if (conv_inputs.size() == 3) {
|
||||
if (mul_B_tensor_proto->dims_size() != 0) {
|
||||
bool flag = false;
|
||||
for (int i = 1; i < mul_B_tensor_proto->dims_size(); i++) {
|
||||
if (mul_B_tensor_proto->dims(i) != 1) {
|
||||
flag = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (flag) {
|
||||
continue;
|
||||
}
|
||||
conv_B->mul(*mul_B);
|
||||
} else {
|
||||
conv_B->scale_by_axis(*mul_B, 0);
|
||||
}
|
||||
auto conv_W = std::make_unique<Initializer>(conv_W_tensor_proto);
|
||||
auto mul_B = std::make_unique<Initializer>(mul_B_tensor_proto);
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr;
|
||||
std::unique_ptr<Initializer> conv_B = nullptr;
|
||||
const bool is_3d = conv_inputs.size() == 3;
|
||||
if (is_3d) {
|
||||
if (!graph.GetInitializedTensor(conv_inputs[2]->Name(), conv_B_tensor_proto))
|
||||
continue;
|
||||
if (conv_B_tensor_proto == nullptr)
|
||||
return Status(ONNXRUNTIME, FAIL, "Internal error in ConvMulFusion. conv_B_tensor_proto is NULL");
|
||||
if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) ||
|
||||
conv_B_tensor_proto->data_type() != mul_B_tensor_proto->data_type() ||
|
||||
conv_B_tensor_proto->dims_size() != 1 || (mul_B_tensor_proto->dims_size() != 0 && conv_B_tensor_proto->dims(0) != mul_B_tensor_proto->dims(0))) {
|
||||
continue;
|
||||
}
|
||||
conv_B = std::make_unique<Initializer>(conv_B_tensor_proto);
|
||||
}
|
||||
|
||||
// Calculate new value of initializers of conv node
|
||||
conv_W->scale_by_axis(*mul_B, 1);
|
||||
|
||||
if (conv_inputs.size() == 3) {
|
||||
if (mul_B_tensor_proto->dims_size() != 0) {
|
||||
conv_B->mul(*mul_B);
|
||||
} else {
|
||||
conv_B->scale_by_axis(*mul_B, 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Create new initializers of conv
|
||||
ONNX_NAMESPACE::TensorProto new_conv_W_tensor_proto(*conv_W_tensor_proto);
|
||||
conv_W->ToProto(&new_conv_W_tensor_proto);
|
||||
|
||||
// Replace initializers of conv node
|
||||
graph.RemoveInitializedTensor(conv_inputs[1]->Name());
|
||||
graph.AddInitializedTensor(new_conv_W_tensor_proto);
|
||||
|
||||
if (is_3d) {
|
||||
ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto(*conv_B_tensor_proto);
|
||||
conv_B->ToProto(&new_conv_B_tensor_proto);
|
||||
graph.RemoveInitializedTensor(conv_inputs[2]->Name());
|
||||
graph.AddInitializedTensor(new_conv_B_tensor_proto);
|
||||
}
|
||||
|
||||
// Replace the input of the node following mul node
|
||||
const NodeArg* mul_output_def = mul_node.OutputDefs()[0];
|
||||
NodeArg* conv_output_def = conv_node.MutableOutputDefs()[0];
|
||||
for (auto it = mul_node.OutputNodesBegin(); it != mul_node.OutputNodesEnd(); ++it) {
|
||||
auto output_node = graph.GetNode((*it).Index());
|
||||
if (!output_node) {
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
auto& input_defs = output_node->MutableInputDefs();
|
||||
for (auto& def : input_defs) {
|
||||
if (def == mul_output_def) {
|
||||
def = conv_output_def;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
removed_nodes.push_back(mul_node.Index());
|
||||
}
|
||||
|
||||
for (auto i : removed_nodes) {
|
||||
graph.RemoveNode(i);
|
||||
// Create new initializers of conv
|
||||
ONNX_NAMESPACE::TensorProto new_conv_W_tensor_proto(*conv_W_tensor_proto);
|
||||
conv_W->ToProto(&new_conv_W_tensor_proto);
|
||||
|
||||
// Replace initializers of conv node
|
||||
graph.RemoveInitializedTensor(conv_inputs[1]->Name());
|
||||
graph.AddInitializedTensor(new_conv_W_tensor_proto);
|
||||
|
||||
if (is_3d) {
|
||||
ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto(*conv_B_tensor_proto);
|
||||
conv_B->ToProto(&new_conv_B_tensor_proto);
|
||||
graph.RemoveInitializedTensor(conv_inputs[2]->Name());
|
||||
graph.AddInitializedTensor(new_conv_B_tensor_proto);
|
||||
}
|
||||
|
||||
if (!removed_nodes.empty()) {
|
||||
modified = true;
|
||||
// Remove Mul node.
|
||||
auto* mul_node_to_remove = graph.GetNode(mul_node.Index());
|
||||
if (graph_utils::RemoveNode(graph, *mul_node_to_remove)) {
|
||||
rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node) {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", 1) ||
|
||||
node.GetOutputEdgesCount() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& next_node = *node.OutputNodesBegin();
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", 7) ||
|
||||
next_node.GetInputEdgesCount() != 1 ||
|
||||
graph.IsNodeOutputsInGraphOutputs(next_node) ||
|
||||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -2,16 +2,29 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
#include "core/optimizer/rewrite_rule.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
class ConvMulFusion : public onnxruntime::GraphTransformer {
|
||||
/**
|
||||
@Class ConvMulFusion
|
||||
|
||||
Rewrite rule that fuses two Conv+Mul nodes to a single Conv node.
|
||||
|
||||
It is attempted to be triggered only on nodes with op type "Conv".
|
||||
*/
|
||||
class ConvMulFusion : public RewriteRule {
|
||||
public:
|
||||
ConvMulFusion() noexcept : onnxruntime::GraphTransformer("ConvMulFusion") {}
|
||||
ConvMulFusion() noexcept : RewriteRule("ConvMulFusion") {}
|
||||
|
||||
std::vector<std::string> TargetOpTypes() const noexcept override {
|
||||
return {"Conv"};
|
||||
}
|
||||
|
||||
private:
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node) override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ class GraphTransformerManager {
|
|||
explicit GraphTransformerManager(unsigned steps) : steps_(steps) {
|
||||
}
|
||||
|
||||
// Register a transformer with a level and compatible providers list
|
||||
// Register a transformer with a level.
|
||||
common::Status Register(std::unique_ptr<GraphTransformer> transformer, TransformerLevel level);
|
||||
|
||||
// Apply all transformers registered for the given level on the given graph
|
||||
|
|
|
|||
|
|
@ -30,6 +30,9 @@ std::vector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(TransformerLevel
|
|||
break;
|
||||
|
||||
case TransformerLevel::Level2:
|
||||
rules.push_back(std::make_unique<ConvAddFusion>());
|
||||
rules.push_back(std::make_unique<ConvMulFusion>());
|
||||
rules.push_back(std::make_unique<ConvBNFusion>());
|
||||
break;
|
||||
default:
|
||||
ORT_ENFORCE(false, "Unsupported level" + std::to_string(static_cast<uint32_t>(level)));
|
||||
|
|
@ -93,9 +96,6 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
|
|||
transformers.emplace_back(std::make_unique<MatMulAddFusion>(l2_execution_providers));
|
||||
transformers.emplace_back(std::make_unique<ConvActivationFusion>(l2_execution_providers));
|
||||
#endif
|
||||
transformers.emplace_back(std::make_unique<ConvAddFusion>());
|
||||
transformers.emplace_back(std::make_unique<ConvMulFusion>());
|
||||
transformers.emplace_back(std::make_unique<ConvBNFusion>());
|
||||
} break;
|
||||
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -10,9 +10,9 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
Status EliminateIdentity::Apply(Graph& graph, Node& node, bool& modified, bool& deleted) {
|
||||
if (graph_utils::RemoveSingleInputNode(graph, node)) {
|
||||
modified = deleted = true;
|
||||
Status EliminateIdentity::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) {
|
||||
if (graph_utils::RemoveNode(graph, node)) {
|
||||
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class EliminateIdentity : public RewriteRule {
|
|||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node) override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, bool& modified, bool& deleted) override;
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override;
|
||||
}; // namespace onnxruntime
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/optimizer/rewrite_rule.h"
|
||||
|
||||
using namespace ::onnxruntime::common;
|
||||
|
||||
|
|
@ -22,11 +23,11 @@ Status RuleBasedGraphTransformer::Register(std::unique_ptr<RewriteRule> rule) {
|
|||
|
||||
Status RuleBasedGraphTransformer::ApplyRulesOnNode(Graph& graph, Node& node,
|
||||
const std::vector<std::unique_ptr<RewriteRule>>& rules,
|
||||
bool& modified, bool& deleted) const {
|
||||
RuleEffect& rule_effect) const {
|
||||
for (const auto& rule : rules) {
|
||||
ORT_RETURN_IF_ERROR(rule->CheckConditionAndApply(graph, node, modified, deleted));
|
||||
if (deleted) {
|
||||
modified = true; // should be set by rewriter but in case it wasn't...
|
||||
ORT_RETURN_IF_ERROR(rule->CheckConditionAndApply(graph, node, rule_effect));
|
||||
// If the current node was removed as a result of a rule, stop rule application for that node.
|
||||
if (rule_effect == RuleEffect::kRemovedCurrentNode) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
@ -39,10 +40,15 @@ Status RuleBasedGraphTransformer::ApplyImpl(Graph& graph, bool& modified, int gr
|
|||
|
||||
for (NodeIndex i : order) {
|
||||
auto* node = graph.GetNode(i);
|
||||
// A node might not be found as it might have already been deleted from one of the rules.
|
||||
if (!node) {
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Initialize the effect of rules on this node to denote that the graph has not yet been modified
|
||||
// by the rule application on the current node.
|
||||
auto rule_effect = RuleEffect::kNone;
|
||||
|
||||
if (!graph_utils::IsSupportedProvider(*node, GetCompatibleExecutionProviders())) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -50,22 +56,26 @@ Status RuleBasedGraphTransformer::ApplyImpl(Graph& graph, bool& modified, int gr
|
|||
// First apply rewrite rules that are registered for the op type of the current node; then apply rules that are
|
||||
// registered to be applied regardless of the op type; then recursively apply rules to subgraphs (if any).
|
||||
// Stop further rule application for the current node, if the node gets removed by a rule.
|
||||
bool deleted = false;
|
||||
const std::vector<std::unique_ptr<RewriteRule>>* rules = nullptr;
|
||||
|
||||
rules = GetRewriteRulesForOpType(node->OpType());
|
||||
if (rules) {
|
||||
ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, modified, deleted));
|
||||
ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, rule_effect));
|
||||
}
|
||||
|
||||
if (!deleted) {
|
||||
if (rule_effect != RuleEffect::kRemovedCurrentNode) {
|
||||
rules = GetAnyOpRewriteRules();
|
||||
if (rules) {
|
||||
ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, modified, deleted));
|
||||
ORT_RETURN_IF_ERROR(ApplyRulesOnNode(graph, *node, *rules, rule_effect));
|
||||
}
|
||||
}
|
||||
|
||||
if (!deleted) {
|
||||
// Update the modified field of the rule-based transformer.
|
||||
if (rule_effect != RuleEffect::kNone) {
|
||||
modified = true;
|
||||
}
|
||||
|
||||
if (rule_effect != RuleEffect::kRemovedCurrentNode) {
|
||||
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,9 +8,9 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
Status EliminateSlice::Apply(Graph& graph, Node& node, bool& modified, bool& removed) {
|
||||
if (graph_utils::RemoveSingleInputNode(graph, node)) {
|
||||
removed = modified = true;
|
||||
Status EliminateSlice::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) {
|
||||
if (graph_utils::RemoveNode(graph, node)) {
|
||||
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class EliminateSlice : public RewriteRule {
|
|||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node) override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, bool& modified, bool& removed) override;
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ using namespace ::onnxruntime::common;
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
Status UnsqueezeElimination::Apply(Graph& graph, Node& node, bool& modified, bool& removed) {
|
||||
Status UnsqueezeElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) {
|
||||
// Get "axes" attribute.
|
||||
const ONNX_NAMESPACE::AttributeProto* attr = graph_utils::GetNodeAttribute(node, "axes");
|
||||
if (attr == nullptr || attr->type() != AttributeProto_AttributeType_INTS) {
|
||||
|
|
@ -66,8 +66,8 @@ Status UnsqueezeElimination::Apply(Graph& graph, Node& node, bool& modified, boo
|
|||
input_def->SetShape(shape);
|
||||
|
||||
// Remove Unsqueeze node.
|
||||
if (graph_utils::RemoveSingleInputNode(graph, node)) {
|
||||
removed = modified = true;
|
||||
if (graph_utils::RemoveNode(graph, node)) {
|
||||
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class UnsqueezeElimination : public RewriteRule {
|
|||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node) override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, bool& modified, bool& deleted) override;
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -207,7 +207,7 @@ static void UpdateSubgraphWhenRemovingNode(bool include_nested = false) {
|
|||
auto& node_to_remove = *graph.GetNode(1);
|
||||
const auto& if_node = *graph.GetNode(2);
|
||||
|
||||
bool removed = graph_utils::RemoveSingleInputNode(graph, node_to_remove);
|
||||
bool removed = graph_utils::RemoveNode(graph, node_to_remove);
|
||||
ASSERT_TRUE(removed);
|
||||
|
||||
// check subgraph implicit input was updated
|
||||
|
|
@ -238,7 +238,7 @@ static void DontRemoveNodeIfItWillBreakSubgraph(bool test_nested = false) {
|
|||
auto& graph = model.MainGraph();
|
||||
auto& node_to_remove = *graph.GetNode(1);
|
||||
|
||||
bool removed = graph_utils::RemoveSingleInputNode(graph, node_to_remove);
|
||||
bool removed = graph_utils::RemoveNode(graph, node_to_remove);
|
||||
ASSERT_FALSE(removed);
|
||||
}
|
||||
|
||||
|
|
@ -287,7 +287,7 @@ TEST(GraphUtils, TestMultiEdgeRemovalNodes) {
|
|||
ASSERT_EQ(nodes[2]->GetOutputEdgesCount(), 2);
|
||||
|
||||
// Remove id_2. This leaves id_0 with 3 output edges. id_0 is now incoming node to id_3 and id_4.
|
||||
ASSERT_TRUE(graph_utils::RemoveSingleInputNode(graph, *nodes[2]));
|
||||
ASSERT_TRUE(graph_utils::RemoveNode(graph, *nodes[2]));
|
||||
ASSERT_EQ(graph.NumberOfNodes(), 4);
|
||||
ASSERT_EQ(nodes[0]->GetOutputEdgesCount(), 3);
|
||||
ASSERT_EQ(nodes[3]->InputDefs().size(), 1);
|
||||
|
|
@ -296,7 +296,7 @@ TEST(GraphUtils, TestMultiEdgeRemovalNodes) {
|
|||
ASSERT_TRUE(nodes[4]->InputDefs()[0]->Name() == "id_0_out");
|
||||
|
||||
// Remove id_0
|
||||
ASSERT_TRUE(graph_utils::RemoveSingleInputNode(graph, *nodes[0]));
|
||||
ASSERT_TRUE(graph_utils::RemoveNode(graph, *nodes[0]));
|
||||
ASSERT_EQ(graph.NumberOfNodes(), 3);
|
||||
ASSERT_TRUE(nodes[1]->InputDefs()[0]->Name() == "id_0_in");
|
||||
ASSERT_TRUE(nodes[3]->InputDefs()[0]->Name() == "id_0_in");
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ class DummyRewriteRule : public RewriteRule {
|
|||
return true;
|
||||
}
|
||||
|
||||
Status Apply(Graph& /*graph*/, Node& /*node*/, bool& /*modified*/, bool& /*deleted*/) override {
|
||||
Status Apply(Graph& /*graph*/, Node& /*node*/, RewriteRuleEffect& /*rule_effect*/) override {
|
||||
rewrite_rule_invoked_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -52,10 +52,10 @@ TEST(GraphTransformationTests, IdentityElimination) {
|
|||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Identity"] == 1);
|
||||
|
||||
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
|
||||
rule_transformer->Register(std::make_unique<EliminateIdentity>());
|
||||
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
|
||||
rule_transformer_L1->Register(std::make_unique<EliminateIdentity>());
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer), TransformerLevel::Level1);
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
|
||||
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK());
|
||||
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
|
|
@ -70,10 +70,10 @@ TEST(GraphTransformationTests, SliceElimination) {
|
|||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Slice"] == 5);
|
||||
|
||||
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
|
||||
rule_transformer->Register(std::make_unique<EliminateSlice>());
|
||||
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
|
||||
rule_transformer_L1->Register(std::make_unique<EliminateSlice>());
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer), TransformerLevel::Level1);
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
|
||||
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK());
|
||||
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
|
|
@ -129,7 +129,9 @@ TEST(GraphTransformationTests, FuseConvBNNoBias) {
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(std::make_unique<ConvBNFusion>(), TransformerLevel::Level2);
|
||||
auto rule_transformer_L2 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL2");
|
||||
rule_transformer_L2->Register(std::make_unique<ConvBNFusion>());
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer_L2), TransformerLevel::Level2);
|
||||
|
||||
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK());
|
||||
|
||||
|
|
@ -148,12 +150,15 @@ TEST(GraphTransformationTests, FuseConvBNMulAddUnsqueeze) {
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
|
||||
rule_transformer->Register(std::make_unique<UnsqueezeElimination>());
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer), TransformerLevel::Level1);
|
||||
graph_transformation_mgr.Register(std::make_unique<ConvBNFusion>(), TransformerLevel::Level2);
|
||||
graph_transformation_mgr.Register(std::make_unique<ConvMulFusion>(), TransformerLevel::Level2);
|
||||
graph_transformation_mgr.Register(std::make_unique<ConvAddFusion>(), TransformerLevel::Level2);
|
||||
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
|
||||
rule_transformer_L1->Register(std::make_unique<UnsqueezeElimination>());
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
|
||||
|
||||
auto rule_transformer_L2 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL2");
|
||||
rule_transformer_L2->Register(std::make_unique<ConvAddFusion>());
|
||||
rule_transformer_L2->Register(std::make_unique<ConvBNFusion>());
|
||||
rule_transformer_L2->Register(std::make_unique<ConvMulFusion>());
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer_L2), TransformerLevel::Level2);
|
||||
|
||||
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK());
|
||||
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK());
|
||||
|
|
@ -201,10 +206,13 @@ TEST(GraphTransformationTests, FuseConvMulNoBias) {
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
|
||||
rule_transformer->Register(std::make_unique<UnsqueezeElimination>());
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer), TransformerLevel::Level1);
|
||||
graph_transformation_mgr.Register(std::make_unique<ConvMulFusion>(), TransformerLevel::Level2);
|
||||
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
|
||||
rule_transformer_L1->Register(std::make_unique<UnsqueezeElimination>());
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
|
||||
|
||||
auto rule_transformer_L2 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL2");
|
||||
rule_transformer_L2->Register(std::make_unique<ConvMulFusion>());
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer_L2), TransformerLevel::Level2);
|
||||
|
||||
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK());
|
||||
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK());
|
||||
|
|
@ -222,10 +230,13 @@ TEST(GraphTransformationTests, FuseConvAddNoBias) {
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
|
||||
rule_transformer->Register(std::make_unique<UnsqueezeElimination>());
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer), TransformerLevel::Level1);
|
||||
graph_transformation_mgr.Register(std::make_unique<ConvAddFusion>(), TransformerLevel::Level2);
|
||||
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
|
||||
rule_transformer_L1->Register(std::make_unique<UnsqueezeElimination>());
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
|
||||
|
||||
auto rule_transformer_L2 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL2");
|
||||
rule_transformer_L2->Register(std::make_unique<ConvAddFusion>());
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer_L2), TransformerLevel::Level2);
|
||||
|
||||
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK());
|
||||
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK());
|
||||
|
|
@ -243,8 +254,10 @@ TEST(GraphTransformationTests, FuseConvAddMul3D) {
|
|||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(std::make_unique<ConvMulFusion>(), TransformerLevel::Level2);
|
||||
graph_transformation_mgr.Register(std::make_unique<ConvAddFusion>(), TransformerLevel::Level2);
|
||||
auto rule_transformer_L2 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL2");
|
||||
rule_transformer_L2->Register(std::make_unique<ConvAddFusion>());
|
||||
rule_transformer_L2->Register(std::make_unique<ConvMulFusion>());
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer_L2), TransformerLevel::Level2);
|
||||
|
||||
ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK());
|
||||
|
||||
|
|
@ -315,12 +328,11 @@ TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) {
|
|||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK());
|
||||
|
||||
std::unique_ptr<ConvBNFusion> ConvBNFusion_transformer = std::make_unique<ConvBNFusion>();
|
||||
std::unique_ptr<ConvMulFusion> ConvMulFusion_transformer = std::make_unique<ConvMulFusion>();
|
||||
std::unique_ptr<ConvAddFusion> ConvAddFusion_transformer = std::make_unique<ConvAddFusion>();
|
||||
session_object.RegisterGraphTransformer(std::move(ConvBNFusion_transformer));
|
||||
session_object.RegisterGraphTransformer(std::move(ConvMulFusion_transformer));
|
||||
session_object.RegisterGraphTransformer(std::move(ConvAddFusion_transformer));
|
||||
auto rule_transformer_L2 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL2");
|
||||
rule_transformer_L2->Register(std::make_unique<ConvAddFusion>());
|
||||
rule_transformer_L2->Register(std::make_unique<ConvBNFusion>());
|
||||
rule_transformer_L2->Register(std::make_unique<ConvMulFusion>());
|
||||
session_object.RegisterGraphTransformer(std::move(rule_transformer_L2), TransformerLevel::Level2);
|
||||
|
||||
ASSERT_TRUE(session_object.Initialize().IsOK());
|
||||
|
||||
|
|
@ -335,7 +347,8 @@ TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) {
|
|||
for (int i = 0; i < 9; ++i) {
|
||||
values_x.push_back(x_f);
|
||||
}
|
||||
CreateMLValue<MLFloat16>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_x, values_x, &ml_value_x);
|
||||
CreateMLValue<MLFloat16>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault),
|
||||
dims_x, values_x, &ml_value_x);
|
||||
feeds.insert(std::make_pair("X", ml_value_x));
|
||||
|
||||
std::vector<std::string> output_names;
|
||||
|
|
@ -355,7 +368,8 @@ TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) {
|
|||
auto& rtensor = fetches.front().Get<Tensor>();
|
||||
TensorShape expected_shape(expected_dims_prod);
|
||||
ASSERT_EQ(expected_shape, rtensor.Shape());
|
||||
const std::vector<MLFloat16> found(rtensor.template Data<MLFloat16>(), rtensor.template Data<MLFloat16>() + expected_dims_prod.size());
|
||||
const std::vector<MLFloat16> found(rtensor.template Data<MLFloat16>(),
|
||||
rtensor.template Data<MLFloat16>() + expected_dims_prod.size());
|
||||
ASSERT_EQ(expected_values_prod, found);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) {
|
|||
ASSERT_TRUE(transformers.size() != 0);
|
||||
|
||||
// Transformer name match test
|
||||
std::vector<std::string> custom_list = {"EliminateIdentity", "ConvAddFusion", "ConvMulFusion", "abc", "def"};
|
||||
std::vector<std::string> custom_list = {"EliminateIdentity", "GemmActivationFusion", "MatMulAddFusion", "abc", "def"};
|
||||
transformers = transformer_utils::GenerateTransformers(TransformerLevel::Level2, custom_list);
|
||||
ASSERT_TRUE(transformers.size() == 2);
|
||||
// validate each rule returned is present in the custom list
|
||||
|
|
@ -46,7 +46,7 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) {
|
|||
ASSERT_TRUE(std::find(custom_list.begin(), custom_list.end(), transformer->Name()) != custom_list.end());
|
||||
}
|
||||
|
||||
// Transformer name no match test. When there is no match empty list is expected.
|
||||
// Transformer name no-match test. When there is no match, empty list is expected.
|
||||
custom_list = {"EliminateIdentity"};
|
||||
transformers = transformer_utils::GenerateTransformers(TransformerLevel::Level2, custom_list);
|
||||
ASSERT_TRUE(transformers.size() == 0);
|
||||
|
|
@ -56,7 +56,7 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers_CustomList) {
|
|||
// custom list of rules and transformers
|
||||
std::string l1_rule1 = "EliminateIdentity";
|
||||
std::string l1_transformer = "ConstantFolding";
|
||||
std::string l2_transformer = "ConvAddFusion";
|
||||
std::string l2_transformer = "GemmActivationFusion";
|
||||
std::vector<std::string> custom_list = {l1_rule1, l1_transformer, l2_transformer};
|
||||
|
||||
auto transformers = transformer_utils::GenerateTransformers(TransformerLevel::Level1, custom_list);
|
||||
|
|
|
|||
Loading…
Reference in a new issue