mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Not-where fusion (#7182)
* Not-where fusion * Change to rewrite rule * Add to inference transforms * Support numtiple where consumers * review comments
This commit is contained in:
parent
790fc11e60
commit
2aa89989c4
7 changed files with 261 additions and 0 deletions
|
|
@ -31,6 +31,7 @@
|
|||
#include "core/optimizer/matmul_scale_fusion.h"
|
||||
#include "core/optimizer/nchwc_transformer.h"
|
||||
#include "core/optimizer/nhwc_transformer.h"
|
||||
#include "core/optimizer/not_where_fusion.h"
|
||||
#include "core/optimizer/relu_clip_fusion.h"
|
||||
#include "core/optimizer/reshape_fusion.h"
|
||||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
|
|
@ -68,6 +69,7 @@ std::vector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
|
|||
rules.push_back(onnxruntime::make_unique<CastElimination>());
|
||||
rules.push_back(onnxruntime::make_unique<DivMulFusion>());
|
||||
rules.push_back(onnxruntime::make_unique<FuseReluClip>());
|
||||
rules.push_back(onnxruntime::make_unique<NotWhereFusion>());
|
||||
rules.push_back(onnxruntime::make_unique<ShapeToInitializer>());
|
||||
rules.push_back(onnxruntime::make_unique<ConvAddFusion>());
|
||||
rules.push_back(onnxruntime::make_unique<ConvMulFusion>());
|
||||
|
|
|
|||
137
onnxruntime/core/optimizer/not_where_fusion.cc
Normal file
137
onnxruntime/core/optimizer/not_where_fusion.cc
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/optimizer/not_where_fusion.h"
|
||||
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
Transform that fuses two Not -> Where nodes to a single Where node
|
||||
with the where inputs 1 and 2 flipped.
|
||||
Condition -> Not -> Where ->
|
||||
value0-| |
|
||||
value1----|
|
||||
|
||||
Condition -> Where ->
|
||||
value1-| |
|
||||
value0----|
|
||||
|
||||
It also fuses when not node has multiple where consumer nodes:
|
||||
|
||||
Condition -> Not -> Where ->
|
||||
| v0-| |
|
||||
| v1----|
|
||||
|----> Where ->
|
||||
v0-| |
|
||||
v1----|
|
||||
|
||||
Condition -> Where ->
|
||||
| v1-| |
|
||||
| v0----|
|
||||
|----> Where ->
|
||||
v1-| |
|
||||
v0----|
|
||||
*/
|
||||
bool NotWhereFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Where", {9})) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const Node* p_not_node = graph_utils::GetInputNode(node, 0);
|
||||
if (p_not_node == nullptr ||
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(*p_not_node, "Not", {1}) ||
|
||||
// Make sure the two nodes do not span execution providers.
|
||||
p_not_node->GetExecutionProviderType() != node.GetExecutionProviderType()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (p_not_node->GetOutputEdgesCount() > 1) {
|
||||
// all consumers of not must be where
|
||||
for (auto it = p_not_node->OutputNodesBegin(); it != p_not_node->OutputNodesEnd(); ++it) {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(*it, "Where", {9})) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!graph_utils::CanRemoveNode(graph, *p_not_node, logger)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
Status NotWhereFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
|
||||
const Node* p_not_node = graph_utils::GetInputNode(node, 0);
|
||||
|
||||
auto& not_node = *graph.GetNode(p_not_node->Index()); // get mutable next node
|
||||
NodeArg* not_input = not_node.MutableInputDefs()[0];
|
||||
|
||||
// get all node ids of consumer where nodes
|
||||
std::vector<NodeIndex> where_node_ids;
|
||||
for (auto it = p_not_node->OutputNodesBegin(); it != p_not_node->OutputNodesEnd(); ++it) {
|
||||
where_node_ids.push_back(it->Index());
|
||||
}
|
||||
|
||||
// Move input egdes from not_node to all where_node
|
||||
const Node* not_input_node = graph_utils::GetInputNode(not_node, 0);
|
||||
if (not_input_node) {
|
||||
Node& replacement = *graph.GetNode(not_input_node->Index());
|
||||
int replacement_output_idx = graph_utils::GetNodeOutputIndexFromOutputName(replacement, not_input->Name());
|
||||
// Replace inputs of all downstream where nodes with input of not_node by
|
||||
// removing not's output edges, updating input names of not's consumers,
|
||||
// and adding the edges from not's input to not's consumers.
|
||||
graph_utils::ReplaceDownstreamNodeInput(graph, not_node, 0, replacement, replacement_output_idx);
|
||||
} else { // not's input is graph input/initializer. Remove the output egdes for not_node
|
||||
graph_utils::RemoveNodeOutputEdges(graph, not_node);
|
||||
}
|
||||
|
||||
for (auto it = where_node_ids.begin(); it != where_node_ids.end(); ++it) {
|
||||
auto& where_node = *graph.GetNode(*it);
|
||||
|
||||
std::vector<NodeArg*> where_inputs = where_node.MutableInputDefs();
|
||||
|
||||
if (!not_input_node) { // not's input is graph input/initializer.
|
||||
graph_utils::ReplaceNodeInput(where_node, 0, *not_input);
|
||||
}
|
||||
|
||||
const Node* where_input1_node = graph_utils::GetInputNode(where_node, 1);
|
||||
const Node* where_input2_node = graph_utils::GetInputNode(where_node, 2);
|
||||
|
||||
int output1_idx = -1, output2_idx = -1;
|
||||
if (where_input1_node) {
|
||||
output1_idx = graph_utils::GetNodeOutputIndexFromOutputName(*where_input1_node, where_inputs[1]->Name());
|
||||
graph.RemoveEdge(where_input1_node->Index(), where_node.Index(), output1_idx, 1);
|
||||
}
|
||||
|
||||
if (where_input2_node) {
|
||||
output2_idx = graph_utils::GetNodeOutputIndexFromOutputName(*where_input2_node, where_inputs[2]->Name());
|
||||
graph.RemoveEdge(where_input2_node->Index(), where_node.Index(), output2_idx, 2);
|
||||
}
|
||||
|
||||
graph_utils::ReplaceNodeInput(where_node, 1, *where_inputs[2]);
|
||||
graph_utils::ReplaceNodeInput(where_node, 2, *where_inputs[1]);
|
||||
|
||||
if (where_input1_node) {
|
||||
graph.AddEdge(where_input1_node->Index(), where_node.Index(), output1_idx, 2);
|
||||
}
|
||||
|
||||
if (where_input2_node) {
|
||||
graph.AddEdge(where_input2_node->Index(), where_node.Index(), output2_idx, 1);
|
||||
}
|
||||
}
|
||||
|
||||
// remove not_node
|
||||
graph.RemoveNode(not_node.Index());
|
||||
|
||||
rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
36
onnxruntime/core/optimizer/not_where_fusion.h
Normal file
36
onnxruntime/core/optimizer/not_where_fusion.h
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/rewrite_rule.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
/**
|
||||
@Class NotWhereFusion
|
||||
|
||||
Rewrite rule that fuses two Not -> Where nodes to a single Where node
|
||||
with the where inputs 1 and 2 flipped.
|
||||
Condition -> Not -> Where ->
|
||||
value0-| |
|
||||
value1----|
|
||||
|
||||
Condition -> Where ->
|
||||
value1-| |
|
||||
value0----|
|
||||
*/
|
||||
class NotWhereFusion : public RewriteRule {
|
||||
public:
|
||||
NotWhereFusion() noexcept : RewriteRule("NotWhereFusion") {}
|
||||
|
||||
std::vector<std::string> TargetOpTypes() const noexcept override {
|
||||
return {"Where"};
|
||||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -46,6 +46,7 @@
|
|||
#include "core/optimizer/matmul_integer_to_float.h"
|
||||
#include "core/optimizer/matmul_scale_fusion.h"
|
||||
#include "core/optimizer/matmul_transpose_fusion.h"
|
||||
#include "core/optimizer/not_where_fusion.h"
|
||||
#include "core/optimizer/relu_clip_fusion.h"
|
||||
#include "core/optimizer/reshape_fusion.h"
|
||||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
|
|
@ -548,6 +549,26 @@ TEST_F(GraphTransformationTests, DivMulFusion) {
|
|||
ASSERT_TRUE(op_to_count["Mul"] == 2);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, NotWhereFusion) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/not_where.onnx";
|
||||
std::shared_ptr<Model> model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
|
||||
Graph& graph = model->MainGraph();
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Not"] == 4);
|
||||
ASSERT_TRUE(op_to_count["Where"] == 5);
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
auto rule_transformer_L1 = onnxruntime::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
|
||||
rule_transformer_L1->Register(onnxruntime::make_unique<NotWhereFusion>());
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Where"] == 5);
|
||||
ASSERT_TRUE(op_to_count["Not"] == 1); // can't remove Not if it is graph output/ has consumer that's not where
|
||||
}
|
||||
|
||||
#if defined(USE_CUDA) && !defined(DISABLE_CONTRIB_OPS)
|
||||
TEST_F(GraphTransformationTests, FuseCudaConvAddRelu) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/conv_add_relu.onnx";
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/not_where.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/not_where.onnx
vendored
Normal file
Binary file not shown.
63
onnxruntime/test/testdata/transform/fusion/not_where.py
vendored
Normal file
63
onnxruntime/test/testdata/transform/fusion/not_where.py
vendored
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import TensorProto, OperatorSetIdProto
|
||||
from enum import Enum
|
||||
|
||||
opsets = []
|
||||
onnxdomain = OperatorSetIdProto()
|
||||
onnxdomain.version = 12
|
||||
onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
|
||||
opsets.append(onnxdomain)
|
||||
|
||||
msdomain = OperatorSetIdProto()
|
||||
msdomain.version = 1
|
||||
msdomain.domain = 'com.microsoft'
|
||||
|
||||
opsets.append(msdomain)
|
||||
kwargs={}
|
||||
kwargs['opset_imports'] = opsets
|
||||
|
||||
def GenerateModel(model_name):
|
||||
nodes = [ # subgraph
|
||||
# float
|
||||
helper.make_node("Not", ["X"], ["not_X_1"], "not_1"),
|
||||
helper.make_node("Where", ["not_X_1", "v0", "v1"], ["Y1"], "where_1"),
|
||||
helper.make_node("Not", ["not_X_1"], ["x"], "not_2"),
|
||||
helper.make_node("Identity", ["v0"], ["v0_edge"], "identity_v0"),
|
||||
helper.make_node("Identity", ["v1"], ["v1_edge"], "identity_v1"),
|
||||
helper.make_node("Where", ["x", "v0_edge", "v1_edge"], ["Y2"], "where_2"),
|
||||
helper.make_node("Not", ["X"], ["not_X_2"], "not_3"),
|
||||
helper.make_node("Where", ["not_X_2", "v0", "v1"], ["Y3"], "where_3"),
|
||||
helper.make_node("Not", ["X"], ["not_X_3"], "not_4"),
|
||||
helper.make_node("Where", ["not_X_3", "v0", "v1"], ["Y4"], "where_4"),
|
||||
helper.make_node("Where", ["not_X_3", "v0", "v1"], ["Y5"], "where_5"),
|
||||
]
|
||||
|
||||
inputs = [ # inputs
|
||||
helper.make_tensor_value_info('X', TensorProto.BOOL, ['M', 'K']),
|
||||
]
|
||||
|
||||
initializers = [
|
||||
helper.make_tensor('v0', TensorProto.FLOAT, [1], [1.0]),
|
||||
helper.make_tensor('v1', TensorProto.FLOAT, [1], [-1.0]),
|
||||
]
|
||||
|
||||
graph = helper.make_graph(
|
||||
nodes,
|
||||
"NotWhere", #name
|
||||
inputs,
|
||||
[ # outputs
|
||||
helper.make_tensor_value_info('not_X_2', TensorProto.BOOL, ['M', 'K']),
|
||||
helper.make_tensor_value_info('Y1', TensorProto.FLOAT, ['M', 'K']),
|
||||
helper.make_tensor_value_info('Y2', TensorProto.FLOAT, ['M', 'K']),
|
||||
helper.make_tensor_value_info('Y3', TensorProto.FLOAT, ['M', 'K']),
|
||||
helper.make_tensor_value_info('Y4', TensorProto.FLOAT, ['M', 'K']),
|
||||
helper.make_tensor_value_info('Y5', TensorProto.FLOAT, ['M', 'K']),
|
||||
],
|
||||
initializers)
|
||||
|
||||
model = helper.make_model(graph, **kwargs)
|
||||
onnx.save(model, model_name)
|
||||
|
||||
if __name__ == "__main__":
|
||||
GenerateModel('not_where.onnx')
|
||||
|
|
@ -31,6 +31,7 @@
|
|||
#include "core/optimizer/matmul_scale_fusion.h"
|
||||
#include "core/optimizer/matmul_transpose_fusion.h"
|
||||
#include "core/optimizer/nchwc_transformer.h"
|
||||
#include "core/optimizer/not_where_fusion.h"
|
||||
#include "core/optimizer/relu_clip_fusion.h"
|
||||
#include "core/optimizer/reshape_fusion.h"
|
||||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
|
|
@ -76,6 +77,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
rule_transformer->Register(make_unique<CastElimination>());
|
||||
rule_transformer->Register(make_unique<DivMulFusion>());
|
||||
rule_transformer->Register(make_unique<EliminateDropout>());
|
||||
rule_transformer->Register(make_unique<NotWhereFusion>());
|
||||
rule_transformer->Register(make_unique<NonZeroShapeSetter>());
|
||||
rule_transformer->Register(make_unique<InsertSoftmaxCrossEntropyLossOutput>());
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue