Fix Reshape check (#16349)

### Fix Reshape check

3D->2D reshape by merging the first dims. 

There is a bug for the case. 

```mermaid
stateDiagram
    [768,12,64] --> Reshape
    (—1,768) --> Reshape
   Reshape --> [768,768]
```
   


The Reshape pass the upstream Reshape check, but it should not. 

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
pengwa 2023-06-15 13:50:53 +08:00 committed by GitHub
parent 097346be9d
commit 574e17ade4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 78 additions and 12 deletions

View file

@ -246,21 +246,21 @@ std::optional<ReshapeInfo> UpStreamReshapeGraphTransformer::IsSupportedForUpstre
return std::nullopt;
}
bool are_first_two_dims_concrete = utils::HasDimValue(data_shape->dim(0)) && utils::HasDimValue(data_shape->dim(1));
int64_t merged_dims_value = are_first_two_dims_concrete
? data_shape->dim(0).dim_value() * data_shape->dim(1).dim_value()
: -1;
InlinedVector<int64_t> new_shape_const_values;
optimizer_utils::AppendTensorFromInitializer(graph, *node.InputDefs()[1], new_shape_const_values, true);
if (new_shape_const_values.size() != 2 ||
!(new_shape_const_values[0] == -1 || new_shape_const_values[0] == merged_dims_value)) {
LOG_DEBUG_INFO(logger, "Skip Reshape node " + node.Name() + " due to target shape is not merging first two dims.");
if (!utils::HasDimValue(data_shape->dim(2))) {
LOG_DEBUG_INFO(logger, "Skip Reshape node " + node.Name() + " due to data shape's last dim is not concrete.");
return std::nullopt;
}
if (!utils::HasDimValue(data_shape->dim(2))) {
LOG_DEBUG_INFO(logger, "Skip Reshape node " + node.Name() + " due to the last dim size is not concrete value.");
InlinedVector<int64_t> new_shape_const_values;
optimizer_utils::AppendTensorFromInitializer(graph, *node.InputDefs()[1], new_shape_const_values, true);
if (new_shape_const_values.size() != 2) {
LOG_DEBUG_INFO(logger, "Skip Reshape node " + node.Name() + " due to target shape is rank 2.");
return std::nullopt;
}
if (new_shape_const_values[1] != data_shape->dim(2).dim_value()) {
LOG_DEBUG_INFO(logger, "Skip Reshape node " + node.Name() +
" due to target shape's last dim is not equal to data shape's last dim.");
return std::nullopt;
}

View file

@ -1831,6 +1831,72 @@ TEST(ComputeOptimizerTests, ReshapeElementwiseOps_NoPropagation1) {
}
}
/*
Test graph include multiple equivalent subgraphs as below.
graph input [128, 4, 32] (int64_t)
|
Cast initializer value: (-1, 128)
| /
Reshape
|
Identity
|
graph out [128, 128] (int64_t)
Add an Identity node because currently we don't allow Reshape generate graph output.
*/
TEST(ComputeOptimizerTests, ReshapeElementwiseOps_NoPropagation2) {
const logging::Logger* logger = &logging::LoggingManager::DefaultLogger();
auto pre_graph_checker = [](Graph& graph) -> Status {
auto op_count_pre = CountOpsInGraph(graph);
TEST_RETURN_IF_NOT(op_count_pre.size() == 3U);
TEST_RETURN_IF_NOT(op_count_pre["Cast"] == 1);
TEST_RETURN_IF_NOT(op_count_pre["Reshape"] == 1);
TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1);
return Status::OK();
};
auto post_graph_checker = [](Graph& graph) {
auto op_count_post = CountOpsInGraph(graph);
TEST_RETURN_IF_NOT(op_count_post.size() == 3U);
TEST_RETURN_IF_NOT(op_count_post["Cast"] == 1);
TEST_RETURN_IF_NOT(op_count_post["Reshape"] == 1);
TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1);
for (Node& node : graph.Nodes()) {
if (node.OpType() == "Reshape") {
const auto& input_defs = node.InputDefs();
auto producer_node = graph.GetProducerNode(input_defs[0]->Name());
TEST_RETURN_IF_NOT(producer_node != nullptr);
TEST_RETURN_IF_NOT(producer_node->OpType() == "Cast");
}
}
return Status::OK();
};
auto build_test_case = [](ModelTestBuilder& builder) {
auto* input1_arg = builder.MakeInput<int64_t>({{128, 4, 32}});
auto* cast_out = builder.MakeIntermediate();
builder.AddNode("Cast", {input1_arg}, {cast_out})
.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_INT64));
auto* shape_initializer = builder.MakeInitializer<int64_t>({2}, {-1, 128});
auto* reshape_out = builder.MakeIntermediate();
builder.AddNode("Reshape", {cast_out, shape_initializer}, {reshape_out});
auto* identity_out = builder.MakeOutput();
builder.AddNode("Identity", {reshape_out}, {identity_out});
};
const std::vector<int> opsets{12, 13, 14};
for (auto& opset_version : opsets) {
std::unique_ptr<GraphTransformer> transformer = std::make_unique<UpStreamReshapeGraphTransformer>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger, std::move(transformer),
TransformerLevel::Level1,
1, pre_graph_checker, post_graph_checker));
}
}
/*
Test graph include multiple equivalent subgraphs as below.
graph input [4, 32, 256] (int64_t) graph input () (scalar, int64_t)