From 574e17ade41e23f52d24f0c2b6640c006f3bfdcb Mon Sep 17 00:00:00 2001 From: pengwa Date: Thu, 15 Jun 2023 13:50:53 +0800 Subject: [PATCH] Fix Reshape check (#16349) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Fix Reshape check 3D->2D reshape by merging the first dims. There is a bug for the case. ```mermaid stateDiagram [768,12,64] --> Reshape (—1,768) --> Reshape Reshape --> [768,768] ``` The Reshape pass the upstream Reshape check, but it should not. ### Motivation and Context --- .../compute_optimizer/upstream_reshape.cc | 24 +++---- .../test/optimizer/compute_optimizer_test.cc | 66 +++++++++++++++++++ 2 files changed, 78 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc index 483be227fc..f7b48de2ca 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc @@ -246,21 +246,21 @@ std::optional UpStreamReshapeGraphTransformer::IsSupportedForUpstre return std::nullopt; } - bool are_first_two_dims_concrete = utils::HasDimValue(data_shape->dim(0)) && utils::HasDimValue(data_shape->dim(1)); - int64_t merged_dims_value = are_first_two_dims_concrete - ? data_shape->dim(0).dim_value() * data_shape->dim(1).dim_value() - : -1; - - InlinedVector new_shape_const_values; - optimizer_utils::AppendTensorFromInitializer(graph, *node.InputDefs()[1], new_shape_const_values, true); - if (new_shape_const_values.size() != 2 || - !(new_shape_const_values[0] == -1 || new_shape_const_values[0] == merged_dims_value)) { - LOG_DEBUG_INFO(logger, "Skip Reshape node " + node.Name() + " due to target shape is not merging first two dims."); + if (!utils::HasDimValue(data_shape->dim(2))) { + LOG_DEBUG_INFO(logger, "Skip Reshape node " + node.Name() + " due to data shape's last dim is not concrete."); return std::nullopt; } - if (!utils::HasDimValue(data_shape->dim(2))) { - LOG_DEBUG_INFO(logger, "Skip Reshape node " + node.Name() + " due to the last dim size is not concrete value."); + InlinedVector new_shape_const_values; + optimizer_utils::AppendTensorFromInitializer(graph, *node.InputDefs()[1], new_shape_const_values, true); + if (new_shape_const_values.size() != 2) { + LOG_DEBUG_INFO(logger, "Skip Reshape node " + node.Name() + " due to target shape is rank 2."); + return std::nullopt; + } + + if (new_shape_const_values[1] != data_shape->dim(2).dim_value()) { + LOG_DEBUG_INFO(logger, "Skip Reshape node " + node.Name() + + " due to target shape's last dim is not equal to data shape's last dim."); return std::nullopt; } diff --git a/onnxruntime/test/optimizer/compute_optimizer_test.cc b/onnxruntime/test/optimizer/compute_optimizer_test.cc index 9f33ee054e..d374492057 100644 --- a/onnxruntime/test/optimizer/compute_optimizer_test.cc +++ b/onnxruntime/test/optimizer/compute_optimizer_test.cc @@ -1831,6 +1831,72 @@ TEST(ComputeOptimizerTests, ReshapeElementwiseOps_NoPropagation1) { } } +/* +Test graph include multiple equivalent subgraphs as below. + graph input [128, 4, 32] (int64_t) + | + Cast initializer value: (-1, 128) + | / + Reshape + | + Identity + | + graph out [128, 128] (int64_t) + +Add an Identity node because currently we don't allow Reshape generate graph output. +*/ +TEST(ComputeOptimizerTests, ReshapeElementwiseOps_NoPropagation2) { + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + auto pre_graph_checker = [](Graph& graph) -> Status { + auto op_count_pre = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_pre.size() == 3U); + TEST_RETURN_IF_NOT(op_count_pre["Cast"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Reshape"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) { + auto op_count_post = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_post.size() == 3U); + TEST_RETURN_IF_NOT(op_count_post["Cast"] == 1); + TEST_RETURN_IF_NOT(op_count_post["Reshape"] == 1); + TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1); + + for (Node& node : graph.Nodes()) { + if (node.OpType() == "Reshape") { + const auto& input_defs = node.InputDefs(); + auto producer_node = graph.GetProducerNode(input_defs[0]->Name()); + TEST_RETURN_IF_NOT(producer_node != nullptr); + TEST_RETURN_IF_NOT(producer_node->OpType() == "Cast"); + } + } + return Status::OK(); + }; + + auto build_test_case = [](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{128, 4, 32}}); + auto* cast_out = builder.MakeIntermediate(); + builder.AddNode("Cast", {input1_arg}, {cast_out}) + .AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_INT64)); + + auto* shape_initializer = builder.MakeInitializer({2}, {-1, 128}); + auto* reshape_out = builder.MakeIntermediate(); + builder.AddNode("Reshape", {cast_out, shape_initializer}, {reshape_out}); + + auto* identity_out = builder.MakeOutput(); + builder.AddNode("Identity", {reshape_out}, {identity_out}); + }; + + const std::vector opsets{12, 13, 14}; + for (auto& opset_version : opsets) { + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger, std::move(transformer), + TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); + } +} + /* Test graph include multiple equivalent subgraphs as below. graph input [4, 32, 256] (int64_t) graph input () (scalar, int64_t)