Not-where fusion (#7182)

* Not-where fusion

* Change to rewrite rule

* Add to inference transforms

* Support numtiple where consumers

* review comments
This commit is contained in:
ashbhandare 2021-04-06 16:12:26 -07:00 committed by GitHub
parent 790fc11e60
commit 2aa89989c4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 261 additions and 0 deletions

View file

@ -31,6 +31,7 @@
#include "core/optimizer/matmul_scale_fusion.h"
#include "core/optimizer/nchwc_transformer.h"
#include "core/optimizer/nhwc_transformer.h"
#include "core/optimizer/not_where_fusion.h"
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/rule_based_graph_transformer.h"
@ -68,6 +69,7 @@ std::vector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
rules.push_back(onnxruntime::make_unique<CastElimination>());
rules.push_back(onnxruntime::make_unique<DivMulFusion>());
rules.push_back(onnxruntime::make_unique<FuseReluClip>());
rules.push_back(onnxruntime::make_unique<NotWhereFusion>());
rules.push_back(onnxruntime::make_unique<ShapeToInitializer>());
rules.push_back(onnxruntime::make_unique<ConvAddFusion>());
rules.push_back(onnxruntime::make_unique<ConvMulFusion>());

View file

@ -0,0 +1,137 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/optimizer/not_where_fusion.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/utils.h"
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {
/**
Transform that fuses two Not -> Where nodes to a single Where node
with the where inputs 1 and 2 flipped.
Condition -> Not -> Where ->
value0-| |
value1----|
Condition -> Where ->
value1-| |
value0----|
It also fuses when not node has multiple where consumer nodes:
Condition -> Not -> Where ->
| v0-| |
| v1----|
|----> Where ->
v0-| |
v1----|
Condition -> Where ->
| v1-| |
| v0----|
|----> Where ->
v1-| |
v0----|
*/
bool NotWhereFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Where", {9})) {
return false;
}
const Node* p_not_node = graph_utils::GetInputNode(node, 0);
if (p_not_node == nullptr ||
!graph_utils::IsSupportedOptypeVersionAndDomain(*p_not_node, "Not", {1}) ||
// Make sure the two nodes do not span execution providers.
p_not_node->GetExecutionProviderType() != node.GetExecutionProviderType()) {
return false;
}
if (p_not_node->GetOutputEdgesCount() > 1) {
// all consumers of not must be where
for (auto it = p_not_node->OutputNodesBegin(); it != p_not_node->OutputNodesEnd(); ++it) {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(*it, "Where", {9})) {
return false;
}
}
}
if (!graph_utils::CanRemoveNode(graph, *p_not_node, logger)) {
return false;
}
return true;
}
Status NotWhereFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
const Node* p_not_node = graph_utils::GetInputNode(node, 0);
auto& not_node = *graph.GetNode(p_not_node->Index()); // get mutable next node
NodeArg* not_input = not_node.MutableInputDefs()[0];
// get all node ids of consumer where nodes
std::vector<NodeIndex> where_node_ids;
for (auto it = p_not_node->OutputNodesBegin(); it != p_not_node->OutputNodesEnd(); ++it) {
where_node_ids.push_back(it->Index());
}
// Move input egdes from not_node to all where_node
const Node* not_input_node = graph_utils::GetInputNode(not_node, 0);
if (not_input_node) {
Node& replacement = *graph.GetNode(not_input_node->Index());
int replacement_output_idx = graph_utils::GetNodeOutputIndexFromOutputName(replacement, not_input->Name());
// Replace inputs of all downstream where nodes with input of not_node by
// removing not's output edges, updating input names of not's consumers,
// and adding the edges from not's input to not's consumers.
graph_utils::ReplaceDownstreamNodeInput(graph, not_node, 0, replacement, replacement_output_idx);
} else { // not's input is graph input/initializer. Remove the output egdes for not_node
graph_utils::RemoveNodeOutputEdges(graph, not_node);
}
for (auto it = where_node_ids.begin(); it != where_node_ids.end(); ++it) {
auto& where_node = *graph.GetNode(*it);
std::vector<NodeArg*> where_inputs = where_node.MutableInputDefs();
if (!not_input_node) { // not's input is graph input/initializer.
graph_utils::ReplaceNodeInput(where_node, 0, *not_input);
}
const Node* where_input1_node = graph_utils::GetInputNode(where_node, 1);
const Node* where_input2_node = graph_utils::GetInputNode(where_node, 2);
int output1_idx = -1, output2_idx = -1;
if (where_input1_node) {
output1_idx = graph_utils::GetNodeOutputIndexFromOutputName(*where_input1_node, where_inputs[1]->Name());
graph.RemoveEdge(where_input1_node->Index(), where_node.Index(), output1_idx, 1);
}
if (where_input2_node) {
output2_idx = graph_utils::GetNodeOutputIndexFromOutputName(*where_input2_node, where_inputs[2]->Name());
graph.RemoveEdge(where_input2_node->Index(), where_node.Index(), output2_idx, 2);
}
graph_utils::ReplaceNodeInput(where_node, 1, *where_inputs[2]);
graph_utils::ReplaceNodeInput(where_node, 2, *where_inputs[1]);
if (where_input1_node) {
graph.AddEdge(where_input1_node->Index(), where_node.Index(), output1_idx, 2);
}
if (where_input2_node) {
graph.AddEdge(where_input2_node->Index(), where_node.Index(), output2_idx, 1);
}
}
// remove not_node
graph.RemoveNode(not_node.Index());
rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;
return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/rewrite_rule.h"
namespace onnxruntime {
/**
@Class NotWhereFusion
Rewrite rule that fuses two Not -> Where nodes to a single Where node
with the where inputs 1 and 2 flipped.
Condition -> Not -> Where ->
value0-| |
value1----|
Condition -> Where ->
value1-| |
value0----|
*/
class NotWhereFusion : public RewriteRule {
public:
NotWhereFusion() noexcept : RewriteRule("NotWhereFusion") {}
std::vector<std::string> TargetOpTypes() const noexcept override {
return {"Where"};
}
private:
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -46,6 +46,7 @@
#include "core/optimizer/matmul_integer_to_float.h"
#include "core/optimizer/matmul_scale_fusion.h"
#include "core/optimizer/matmul_transpose_fusion.h"
#include "core/optimizer/not_where_fusion.h"
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/rule_based_graph_transformer.h"
@ -548,6 +549,26 @@ TEST_F(GraphTransformationTests, DivMulFusion) {
ASSERT_TRUE(op_to_count["Mul"] == 2);
}
TEST_F(GraphTransformationTests, NotWhereFusion) {
auto model_uri = MODEL_FOLDER "fusion/not_where.onnx";
std::shared_ptr<Model> model;
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
Graph& graph = model->MainGraph();
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Not"] == 4);
ASSERT_TRUE(op_to_count["Where"] == 5);
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
auto rule_transformer_L1 = onnxruntime::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
rule_transformer_L1->Register(onnxruntime::make_unique<NotWhereFusion>());
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Where"] == 5);
ASSERT_TRUE(op_to_count["Not"] == 1); // can't remove Not if it is graph output/ has consumer that's not where
}
#if defined(USE_CUDA) && !defined(DISABLE_CONTRIB_OPS)
TEST_F(GraphTransformationTests, FuseCudaConvAddRelu) {
auto model_uri = MODEL_FOLDER "fusion/conv_add_relu.onnx";

Binary file not shown.

View file

@ -0,0 +1,63 @@
import onnx
from onnx import helper
from onnx import TensorProto, OperatorSetIdProto
from enum import Enum
opsets = []
onnxdomain = OperatorSetIdProto()
onnxdomain.version = 12
onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
opsets.append(onnxdomain)
msdomain = OperatorSetIdProto()
msdomain.version = 1
msdomain.domain = 'com.microsoft'
opsets.append(msdomain)
kwargs={}
kwargs['opset_imports'] = opsets
def GenerateModel(model_name):
nodes = [ # subgraph
# float
helper.make_node("Not", ["X"], ["not_X_1"], "not_1"),
helper.make_node("Where", ["not_X_1", "v0", "v1"], ["Y1"], "where_1"),
helper.make_node("Not", ["not_X_1"], ["x"], "not_2"),
helper.make_node("Identity", ["v0"], ["v0_edge"], "identity_v0"),
helper.make_node("Identity", ["v1"], ["v1_edge"], "identity_v1"),
helper.make_node("Where", ["x", "v0_edge", "v1_edge"], ["Y2"], "where_2"),
helper.make_node("Not", ["X"], ["not_X_2"], "not_3"),
helper.make_node("Where", ["not_X_2", "v0", "v1"], ["Y3"], "where_3"),
helper.make_node("Not", ["X"], ["not_X_3"], "not_4"),
helper.make_node("Where", ["not_X_3", "v0", "v1"], ["Y4"], "where_4"),
helper.make_node("Where", ["not_X_3", "v0", "v1"], ["Y5"], "where_5"),
]
inputs = [ # inputs
helper.make_tensor_value_info('X', TensorProto.BOOL, ['M', 'K']),
]
initializers = [
helper.make_tensor('v0', TensorProto.FLOAT, [1], [1.0]),
helper.make_tensor('v1', TensorProto.FLOAT, [1], [-1.0]),
]
graph = helper.make_graph(
nodes,
"NotWhere", #name
inputs,
[ # outputs
helper.make_tensor_value_info('not_X_2', TensorProto.BOOL, ['M', 'K']),
helper.make_tensor_value_info('Y1', TensorProto.FLOAT, ['M', 'K']),
helper.make_tensor_value_info('Y2', TensorProto.FLOAT, ['M', 'K']),
helper.make_tensor_value_info('Y3', TensorProto.FLOAT, ['M', 'K']),
helper.make_tensor_value_info('Y4', TensorProto.FLOAT, ['M', 'K']),
helper.make_tensor_value_info('Y5', TensorProto.FLOAT, ['M', 'K']),
],
initializers)
model = helper.make_model(graph, **kwargs)
onnx.save(model, model_name)
if __name__ == "__main__":
GenerateModel('not_where.onnx')

View file

@ -31,6 +31,7 @@
#include "core/optimizer/matmul_scale_fusion.h"
#include "core/optimizer/matmul_transpose_fusion.h"
#include "core/optimizer/nchwc_transformer.h"
#include "core/optimizer/not_where_fusion.h"
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/rule_based_graph_transformer.h"
@ -76,6 +77,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
rule_transformer->Register(make_unique<CastElimination>());
rule_transformer->Register(make_unique<DivMulFusion>());
rule_transformer->Register(make_unique<EliminateDropout>());
rule_transformer->Register(make_unique<NotWhereFusion>());
rule_transformer->Register(make_unique<NonZeroShapeSetter>());
rule_transformer->Register(make_unique<InsertSoftmaxCrossEntropyLossOutput>());