mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Fix Reshape check (#16349)
### Fix Reshape check
3D->2D reshape by merging the first dims.
There is a bug for the case.
```mermaid
stateDiagram
[768,12,64] --> Reshape
(—1,768) --> Reshape
Reshape --> [768,768]
```
The Reshape pass the upstream Reshape check, but it should not.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
097346be9d
commit
574e17ade4
2 changed files with 78 additions and 12 deletions
|
|
@ -246,21 +246,21 @@ std::optional<ReshapeInfo> UpStreamReshapeGraphTransformer::IsSupportedForUpstre
|
|||
return std::nullopt;
|
||||
}
|
||||
|
||||
bool are_first_two_dims_concrete = utils::HasDimValue(data_shape->dim(0)) && utils::HasDimValue(data_shape->dim(1));
|
||||
int64_t merged_dims_value = are_first_two_dims_concrete
|
||||
? data_shape->dim(0).dim_value() * data_shape->dim(1).dim_value()
|
||||
: -1;
|
||||
|
||||
InlinedVector<int64_t> new_shape_const_values;
|
||||
optimizer_utils::AppendTensorFromInitializer(graph, *node.InputDefs()[1], new_shape_const_values, true);
|
||||
if (new_shape_const_values.size() != 2 ||
|
||||
!(new_shape_const_values[0] == -1 || new_shape_const_values[0] == merged_dims_value)) {
|
||||
LOG_DEBUG_INFO(logger, "Skip Reshape node " + node.Name() + " due to target shape is not merging first two dims.");
|
||||
if (!utils::HasDimValue(data_shape->dim(2))) {
|
||||
LOG_DEBUG_INFO(logger, "Skip Reshape node " + node.Name() + " due to data shape's last dim is not concrete.");
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (!utils::HasDimValue(data_shape->dim(2))) {
|
||||
LOG_DEBUG_INFO(logger, "Skip Reshape node " + node.Name() + " due to the last dim size is not concrete value.");
|
||||
InlinedVector<int64_t> new_shape_const_values;
|
||||
optimizer_utils::AppendTensorFromInitializer(graph, *node.InputDefs()[1], new_shape_const_values, true);
|
||||
if (new_shape_const_values.size() != 2) {
|
||||
LOG_DEBUG_INFO(logger, "Skip Reshape node " + node.Name() + " due to target shape is rank 2.");
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (new_shape_const_values[1] != data_shape->dim(2).dim_value()) {
|
||||
LOG_DEBUG_INFO(logger, "Skip Reshape node " + node.Name() +
|
||||
" due to target shape's last dim is not equal to data shape's last dim.");
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1831,6 +1831,72 @@ TEST(ComputeOptimizerTests, ReshapeElementwiseOps_NoPropagation1) {
|
|||
}
|
||||
}
|
||||
|
||||
/*
|
||||
Test graph include multiple equivalent subgraphs as below.
|
||||
graph input [128, 4, 32] (int64_t)
|
||||
|
|
||||
Cast initializer value: (-1, 128)
|
||||
| /
|
||||
Reshape
|
||||
|
|
||||
Identity
|
||||
|
|
||||
graph out [128, 128] (int64_t)
|
||||
|
||||
Add an Identity node because currently we don't allow Reshape generate graph output.
|
||||
*/
|
||||
TEST(ComputeOptimizerTests, ReshapeElementwiseOps_NoPropagation2) {
|
||||
const logging::Logger* logger = &logging::LoggingManager::DefaultLogger();
|
||||
auto pre_graph_checker = [](Graph& graph) -> Status {
|
||||
auto op_count_pre = CountOpsInGraph(graph);
|
||||
TEST_RETURN_IF_NOT(op_count_pre.size() == 3U);
|
||||
TEST_RETURN_IF_NOT(op_count_pre["Cast"] == 1);
|
||||
TEST_RETURN_IF_NOT(op_count_pre["Reshape"] == 1);
|
||||
TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1);
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
auto post_graph_checker = [](Graph& graph) {
|
||||
auto op_count_post = CountOpsInGraph(graph);
|
||||
TEST_RETURN_IF_NOT(op_count_post.size() == 3U);
|
||||
TEST_RETURN_IF_NOT(op_count_post["Cast"] == 1);
|
||||
TEST_RETURN_IF_NOT(op_count_post["Reshape"] == 1);
|
||||
TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1);
|
||||
|
||||
for (Node& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Reshape") {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
auto producer_node = graph.GetProducerNode(input_defs[0]->Name());
|
||||
TEST_RETURN_IF_NOT(producer_node != nullptr);
|
||||
TEST_RETURN_IF_NOT(producer_node->OpType() == "Cast");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
auto build_test_case = [](ModelTestBuilder& builder) {
|
||||
auto* input1_arg = builder.MakeInput<int64_t>({{128, 4, 32}});
|
||||
auto* cast_out = builder.MakeIntermediate();
|
||||
builder.AddNode("Cast", {input1_arg}, {cast_out})
|
||||
.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_INT64));
|
||||
|
||||
auto* shape_initializer = builder.MakeInitializer<int64_t>({2}, {-1, 128});
|
||||
auto* reshape_out = builder.MakeIntermediate();
|
||||
builder.AddNode("Reshape", {cast_out, shape_initializer}, {reshape_out});
|
||||
|
||||
auto* identity_out = builder.MakeOutput();
|
||||
builder.AddNode("Identity", {reshape_out}, {identity_out});
|
||||
};
|
||||
|
||||
const std::vector<int> opsets{12, 13, 14};
|
||||
for (auto& opset_version : opsets) {
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<UpStreamReshapeGraphTransformer>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger, std::move(transformer),
|
||||
TransformerLevel::Level1,
|
||||
1, pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
Test graph include multiple equivalent subgraphs as below.
|
||||
graph input [4, 32, 256] (int64_t) graph input () (scalar, int64_t)
|
||||
|
|
|
|||
Loading…
Reference in a new issue