From f487cc0b2835892dcd223b9eb9093f62e64ea294 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 28 Apr 2020 00:03:16 -0700 Subject: [PATCH] Fix Reshape Fusion with graph inputs (#3729) Use NodeArg to check root input; Add a check on constant initializer --- onnxruntime/core/optimizer/reshape_fusion.cc | 17 ++++++----- onnxruntime/core/optimizer/utils.cc | 6 +++- onnxruntime/core/optimizer/utils.h | 2 +- .../test/optimizer/graph_transform_test.cc | 21 ++++++++++++++ .../transform/fusion/reshape_fusion_gen.py | 27 ++++++++++++++++++ .../reshape_fusion_with_graph_inputs.onnx | Bin 0 -> 446 bytes 6 files changed, 62 insertions(+), 11 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/reshape_fusion_with_graph_inputs.onnx diff --git a/onnxruntime/core/optimizer/reshape_fusion.cc b/onnxruntime/core/optimizer/reshape_fusion.cc index 5d39939d30..399dcf744c 100644 --- a/onnxruntime/core/optimizer/reshape_fusion.cc +++ b/onnxruntime/core/optimizer/reshape_fusion.cc @@ -46,7 +46,7 @@ each of which is a constant initializer or a Shape->Gather->Unsqueeze chain with index corresponding to the index of the argument.) Before fusion: - [Sub-graph Root Node ] + [Sub-graph Root] | / \ | Shape Shape | | | @@ -61,13 +61,14 @@ Before fusion: Reshape After fusion: - [Sub-graph Root Node] (Constant Initializer) + [Sub-graph Root] (Constant Initializer) \ [0, a, 0, b] \ / Reshape */ bool ReshapeFusion::Fuse_Subgraph1(Node& reshape, Graph& graph, const logging::Logger& logger) { - const Node* p_root = graph_utils::GetInputNode(reshape, 0); + // The root could be either a graph input or a node so use node arg to compare. + const NodeArg& root_input = *(reshape.InputDefs()[0]); const Node* p_concat = graph_utils::GetInputNode(reshape, 1); if (nullptr == p_concat) { @@ -90,11 +91,8 @@ bool ReshapeFusion::Fuse_Subgraph1(Node& reshape, Graph& graph, const logging::L enum class NodeType { Unsqueeze, Gather, Shape }; std::set> candidates_for_removal; for (int i = 0; i < concat_input_count; ++i) { - // First check if the i-th argument is an initializer. - // We do not check whether the initializer is constant. - // Some model uses constant initializer and some does not. - // Here we assume that no one will override the initializer using graph input. - if (optimizer_utils::AppendTensorFromInitializer(graph, *(concat.InputDefs()[i]), shape_value)) { + // First check if the i-th argument is a constant initializer. + if (optimizer_utils::AppendTensorFromInitializer(graph, *(concat.InputDefs()[i]), shape_value, true)) { continue; } @@ -113,7 +111,8 @@ bool ReshapeFusion::Fuse_Subgraph1(Node& reshape, Graph& graph, const logging::L const Node& gather = edges[1]->GetNode(); const Node& shape = edges[2]->GetNode(); - if (graph_utils::GetInputNode(shape, 0) != p_root) { + const NodeArg& shape_input = *(shape.InputDefs()[0]); + if (shape_input.Name() != root_input.Name()) { return false; } diff --git a/onnxruntime/core/optimizer/utils.cc b/onnxruntime/core/optimizer/utils.cc index 4aff66352a..052f5ff67f 100644 --- a/onnxruntime/core/optimizer/utils.cc +++ b/onnxruntime/core/optimizer/utils.cc @@ -141,7 +141,11 @@ bool IsAttributeWithExpectedValues(const Node& node, const std::string& attr_nam return true; } -bool AppendTensorFromInitializer(const Graph& graph, const NodeArg& input_arg, std::vector& data) { +bool AppendTensorFromInitializer(const Graph& graph, const NodeArg& input_arg, std::vector& data, bool require_constant) { + if (require_constant && !graph_utils::IsConstantInitializer(graph, input_arg.Name(), true)) { + return false; + } + const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr; if (!graph.GetInitializedTensor(input_arg.Name(), tensor_proto)) { return false; diff --git a/onnxruntime/core/optimizer/utils.h b/onnxruntime/core/optimizer/utils.h index 7887bfd360..0cafe0d49c 100644 --- a/onnxruntime/core/optimizer/utils.h +++ b/onnxruntime/core/optimizer/utils.h @@ -42,7 +42,7 @@ bool IsAttributeWithExpectedValues(const Node& node, const std::string& attr_nam /** Get values of an integer tensor from initializer, and append them to a vector. @remarks only support int32 and int64 tensor. This function does not clear vector before appending. */ -bool AppendTensorFromInitializer(const Graph& graph, const NodeArg& input_arg, std::vector& data); +bool AppendTensorFromInitializer(const Graph& graph, const NodeArg& input_arg, std::vector& data, bool require_constant = true); /** Check Shape of node input or output. @remarks when expected dim value > 0, the dim is expected to known and match the dim value. diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 7ff3f75e3d..6aae2911c0 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1116,6 +1116,27 @@ TEST_F(GraphTransformationTests, ReshapeFusionInternalReuseTest) { } } + +TEST_F(GraphTransformationTests, ReshapeFusionGraphInputsTest) { + auto model_uri = MODEL_FOLDER "fusion/reshape_fusion_with_graph_inputs.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level1); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Shape"], 1); + ASSERT_EQ(op_to_count["Gather"], 1); + ASSERT_EQ(op_to_count["Unsqueeze"], 1); + ASSERT_EQ(op_to_count["Concat"], 1); + ASSERT_EQ(op_to_count["Reshape"], 1); +} + + TEST_F(GraphTransformationTests, ExpandElimination) { auto model_uri = MODEL_FOLDER "expand_elimination.onnx"; std::shared_ptr model; diff --git a/onnxruntime/test/testdata/transform/fusion/reshape_fusion_gen.py b/onnxruntime/test/testdata/transform/fusion/reshape_fusion_gen.py index 3f01ffea8f..156d9a2147 100644 --- a/onnxruntime/test/testdata/transform/fusion/reshape_fusion_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/reshape_fusion_gen.py @@ -70,3 +70,30 @@ graph = helper.make_graph( save_model(graph, 'reshape_fusion_internal_node_is_graph_output.onnx') + + +graph = helper.make_graph( + [ # nodes + helper.make_node("Shape", ["query"], ["shape0_out"], "shape0"), + helper.make_node("Gather", ["shape0_out", "indices0"], ["gather0_out"], "gather0", axis=0), + helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), + helper.make_node("Concat", ["a", "unsqueeze0_out"], ["concat_out"], "concat", axis=0), + helper.make_node("Reshape", ["doc_word_mask", "concat_out"], ["Result"], "reshape"), + ], + "Reshape_Fusion", #name + [ # inputs + helper.make_tensor_value_info('query', TensorProto.FLOAT, [1, 50]), + helper.make_tensor_value_info('doc_word_mask', TensorProto.FLOAT, [1, 200, 50]), + ], + [ # outputs + helper.make_tensor_value_info('Result', TensorProto.FLOAT, [10, 20, 'unk']), + ], + [ # initializers + helper.make_tensor('a', TensorProto.INT64, [1], [-1]), + helper.make_tensor('indices0', TensorProto.INT64, [], [1]), + ] +) + +save_model(graph, 'reshape_fusion_with_graph_inputs.onnx') + + diff --git a/onnxruntime/test/testdata/transform/fusion/reshape_fusion_with_graph_inputs.onnx b/onnxruntime/test/testdata/transform/fusion/reshape_fusion_with_graph_inputs.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e3609a566defdc46008d3c9b724e6cf366208cd9 GIT binary patch literal 446 zcmZ8d%TB{U3}h3MEEAfREeO;L(uz|~6%v<<0|W`~ATGT`OR}O-A%SdGJidmX;2-!a z*yaJDmMq!hu{}0*XP-7WoM$b|HwLv3PjVU;x|+C6%$=JRhI90apjn<~iBwbO--~J? z%cjR`6YgHsXy{{8yk_cPdcJra0()@*2s_) zc_OMtbrW;CAy_QNpR1>e1_2t|%!0L1sv_X8SRaWHT zn39n+;_G4uWIsR6brVH6f6#*gPmYbw|nJBBc*2{!$zS$BtM1+hJt Tg=}cbFun_i**afF*2w<>5r%(D literal 0 HcmV?d00001