mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
(MaximKalininMS) Fix Reshape Fusion and Crash in Reshape (#3777)
* Fix a crash in Reshape: Reshape doesn't handle a 0 input dimension properly, which leads to a division by zero.
* Fix reshape fusion: https://github.com/microsoft/onnxruntime/pull/3554 introduced a bug: initializers can now come before Shape->Gather->Unsqueeze chains; if those initializers have more than 1 element, the expected dimensions in the chains are now incorrect.

Authored-by: Max Kalinin <makalini@microsoft.com>
This commit is contained in:
parent
15eca74d15
commit
3fab8ebfe9
9 changed files with 204 additions and 3 deletions
|
|
@ -121,7 +121,7 @@ bool ReshapeFusion::Fuse_Subgraph1(Node& reshape, Graph& graph, const logging::L
|
|||
return false;
|
||||
}
|
||||
|
||||
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(gather.InputDefs()[1]), int64_t(i), false)) {
|
||||
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(gather.InputDefs()[1]), int64_t(shape_value.size()), false)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class ReshapeHelper {
|
|||
|
||||
if (unknown_dim != -1) {
|
||||
// calculate unknown dimension
|
||||
ORT_ENFORCE((input_shape.Size() % size) == 0,
|
||||
ORT_ENFORCE(size != 0 && (input_shape.Size() % size) == 0,
|
||||
"The input tensor cannot be reshaped to the requested shape. Input shape:", input_shape, ", requested shape:", TensorShape(requested_shape));
|
||||
requested_shape[unknown_dim] = input_shape.Size() / size;
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -1153,6 +1153,91 @@ TEST_F(GraphTransformationTests, ReshapeFusionGraphInputsTest) {
|
|||
ASSERT_EQ(op_to_count["Concat"], 1);
|
||||
ASSERT_EQ(op_to_count["Reshape"], 1);
|
||||
}
|
||||
|
||||
// Negative test for ReshapeFusion when a multi-element initializer precedes a
// Shape->Gather->Unsqueeze chain in the Concat that feeds Reshape.
// The generator script shown in this commit builds the model with initializer
// 'a' = [1, 200] (two values) and Gather index 1. After 'a', the chain would
// have to select dimension 2 for the fusion to be valid, so the graph must be
// left untouched.
TEST_F(GraphTransformationTests, ReshapeFusionMultipleValuesInInitializerDoesntApplyTest) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/reshape_fusion_multiple_values_in_initializer_tensor_1.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK());
|
||||
Graph& graph = p_model->MainGraph();
|
||||
// Snapshot of per-op-type node counts before any transformer runs.
std::map<std::string, int> op_to_count_orig = CountOpsInGraph(graph);
|
||||
|
||||
// NOTE(review): 5 is presumably the maximum number of transformation steps -- confirm against GraphTransformerManager.
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<ReshapeFusion>(), TransformerLevel::Level1);
|
||||
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
|
||||
// The optimization does not apply.
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
// Identical op counts before and after mean no node was fused or removed.
ASSERT_EQ(op_to_count_orig, op_to_count);
|
||||
}
|
||||
|
||||
// Positive counterpart of the "DoesntApply" test: the generator script shown in
// this commit builds the model with initializer 'a' = [1, 200] and Gather
// index 2, i.e. the index that follows the two values contributed by 'a'.
// ReshapeFusion is expected to collapse the Shape/Gather/Unsqueeze/Concat
// subgraph into a single constant shape initializer consumed by Reshape.
TEST_F(GraphTransformationTests, ReshapeFusionMultipleValuesInInitializerAppliesTest) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/reshape_fusion_multiple_values_in_initializer_tensor_2.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK());
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
// NOTE(review): 5 is presumably the maximum number of transformation steps -- confirm against GraphTransformerManager.
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<ReshapeFusion>(), TransformerLevel::Level1);
|
||||
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
|
||||
// After fusion only the Reshape node may remain; every shape-computing node must be gone.
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_EQ(op_to_count["Shape"], 0);
|
||||
ASSERT_EQ(op_to_count["Gather"], 0);
|
||||
ASSERT_EQ(op_to_count["Unsqueeze"], 0);
|
||||
ASSERT_EQ(op_to_count["Concat"], 0);
|
||||
ASSERT_EQ(op_to_count["Reshape"], 1);
|
||||
// Inspect the shape input (input 1) of the surviving Reshape node.
for (const Node& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Reshape") {
|
||||
// The shape input must now be a constant initializer rather than a computed value.
const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name());
|
||||
ASSERT_TRUE(tensor_proto != nullptr);
|
||||
|
||||
auto initializer = onnxruntime::make_unique<Initializer>(*tensor_proto, graph.ModelPath());
|
||||
EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
EXPECT_EQ(initializer->size(), 3);
|
||||
|
||||
// Expected fused shape: the two values of 'a' followed by 0.
// NOTE(review): 0 presumably means "copy the matching input dimension" per ONNX Reshape semantics -- confirm.
const int64_t* val = initializer->data<int64_t>();
|
||||
EXPECT_EQ(val[0], 1);
|
||||
EXPECT_EQ(val[1], 200);
|
||||
EXPECT_EQ(val[2], 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Negative test for ReshapeFusion: in the model built by the generator script
// shown in this commit, the Shape node reads graph input "AnotherInput" while
// Reshape reshapes "SubgraphRoot". The shape chain therefore does not describe
// the tensor being reshaped, and the graph must be left untouched.
TEST_F(GraphTransformationTests, ReshapeFusionAnotherGraphInput) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/reshape_fusion_input_is_graph_input.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK());
|
||||
Graph& graph = p_model->MainGraph();
|
||||
// Snapshot of per-op-type node counts before any transformer runs.
std::map<std::string, int> op_to_count_orig = CountOpsInGraph(graph);
|
||||
|
||||
// NOTE(review): 5 is presumably the maximum number of transformation steps -- confirm against GraphTransformerManager.
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<ReshapeFusion>(), TransformerLevel::Level1);
|
||||
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
|
||||
// The optimization does not apply.
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
// Identical op counts before and after mean no node was fused or removed.
ASSERT_EQ(op_to_count_orig, op_to_count);
|
||||
}
|
||||
|
||||
// Negative test for ReshapeFusion: in the model built by the generator script
// shown in this commit, tensor 'a' is declared both as a graph input and as an
// initializer. As the test name indicates, such an initializer is overridable
// (presumably by a caller-supplied value at run time -- confirm), so it cannot
// be folded into a constant shape and the graph must be left untouched.
TEST_F(GraphTransformationTests, ReshapeFusionOverridableInitializer) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/reshape_fusion_overridable_initializer.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK());
|
||||
Graph& graph = p_model->MainGraph();
|
||||
// Snapshot of per-op-type node counts before any transformer runs.
std::map<std::string, int> op_to_count_orig = CountOpsInGraph(graph);
|
||||
|
||||
// NOTE(review): 5 is presumably the maximum number of transformation steps -- confirm against GraphTransformerManager.
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<ReshapeFusion>(), TransformerLevel::Level1);
|
||||
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
|
||||
// The optimization does not apply.
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
// Identical op counts before and after mean no node was fused or removed.
ASSERT_EQ(op_to_count_orig, op_to_count);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, ExpandElimination) {
|
||||
auto model_uri = MODEL_FOLDER "expand_elimination.onnx";
|
||||
|
|
|
|||
|
|
@ -33,6 +33,32 @@ TEST(TensorOpTest, ReshapeWithEmptyDim) {
|
|||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT doesn't support empty dimension
|
||||
}
|
||||
|
||||
// Reshape of an empty tensor (shape {0, 10}, no elements) to a fully specified
// target shape {0, 10, 1}. No dimension has to be inferred, so the operator is
// expected to succeed and produce an empty output.
TEST(TensorOpTest, ReshapeWithEmptyInput) {
|
||||
OpTester test("Reshape");
|
||||
test.AddInput<float>("data", {0, 10}, std::vector<float>());
|
||||
// NOTE(review): the trailing `false` presumably marks "shape" as a non-initializer (runtime) input -- confirm against OpTester::AddInput.
test.AddInput<int64_t>("shape", {3}, {0, 10, 1}, false);
|
||||
test.AddOutput<float>("reshaped", {0, 10, 1}, std::vector<float>());
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT doesn't support empty dimension
|
||||
}
|
||||
|
||||
// Reshape of an empty tensor (shape {1, 0}) when the requested shape contains
// an unknown (-1) dimension. Covers both the case where the unknown dimension
// cannot be inferred and the case where it can.
TEST(TensorOpTest, ReshapeWithEmptyInputAndDynamicShape) {
|
||||
// Case 1: requested shape {1, 0, -1}. The known dimensions multiply to 0, so
// the -1 dimension cannot be computed; the operator must report an error
// instead of dividing by zero.
{
|
||||
OpTester test("Reshape");
|
||||
test.AddInput<float>("data", {1, 0}, std::vector<float>());
|
||||
test.AddInput<int64_t>("shape", {3}, {1, 0, -1}, false);
|
||||
test.AddOutput<float>("reshaped", {1, 0, 1}, {});
|
||||
test.Run(OpTester::ExpectResult::kExpectFailure, "The input tensor cannot be reshaped to the requested shape", {kTensorrtExecutionProvider}); // TensorRT doesn't support empty dimension
|
||||
}
|
||||
|
||||
// Case 2: requested shape {1, 1, -1}. The known dimensions multiply to 1, so
// the -1 dimension resolves to 0 and the expected output shape is {1, 1, 0}.
{
|
||||
OpTester test("Reshape");
|
||||
test.AddInput<float>("data", {1, 0}, std::vector<float>());
|
||||
test.AddInput<int64_t>("shape", {3}, {1, 1, -1}, false);
|
||||
test.AddOutput<float>("reshaped", {1, 1, 0}, {});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT doesn't support empty dimension
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TensorOpTest, ReshapeWithInitializer) {
|
||||
OpTester test("Reshape");
|
||||
|
||||
|
|
|
|||
|
|
@ -70,7 +70,98 @@ graph = helper.make_graph(
|
|||
|
||||
save_model(graph, 'reshape_fusion_internal_node_is_graph_output.onnx')
|
||||
|
||||
graph = helper.make_graph(
|
||||
[ # nodes
|
||||
helper.make_node("Shape", ["SubgraphRoot"], ["shape2_out"], "shape2"),
|
||||
helper.make_node("Gather", ["shape2_out", "indices2"], ["gather2_out"], "gather2", axis=0),
|
||||
helper.make_node("Unsqueeze", ["gather2_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]),
|
||||
|
||||
helper.make_node("Concat", ["a", "unsqueeze2_out"], ["concat_out"], "concat", axis=0),
|
||||
helper.make_node("Reshape", ["SubgraphRoot", "concat_out"], ["Result"], "reshape"),
|
||||
],
|
||||
"Reshape_Fusion", #name
|
||||
[ # inputs
|
||||
helper.make_tensor_value_info('SubgraphRoot', TensorProto.FLOAT, [10, 20, 30]),
|
||||
],
|
||||
[ # outputs
|
||||
helper.make_tensor_value_info('Result', TensorProto.FLOAT, ['unk_0', 'unk_1', 'unk_2']),
|
||||
],
|
||||
[ # initializers
|
||||
helper.make_tensor('a', TensorProto.INT64, [2], [1, 200]),
|
||||
helper.make_tensor('indices2', TensorProto.INT64, [], [1]),
|
||||
]
|
||||
)
|
||||
|
||||
save_model(graph, 'reshape_fusion_multiple_values_in_initializer_tensor_1.onnx')
|
||||
|
||||
graph = helper.make_graph(
|
||||
[ # nodes
|
||||
helper.make_node("Shape", ["SubgraphRoot"], ["shape2_out"], "shape2"),
|
||||
helper.make_node("Gather", ["shape2_out", "indices2"], ["gather2_out"], "gather2", axis=0),
|
||||
helper.make_node("Unsqueeze", ["gather2_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]),
|
||||
|
||||
helper.make_node("Concat", ["a", "unsqueeze2_out"], ["concat_out"], "concat", axis=0),
|
||||
helper.make_node("Reshape", ["SubgraphRoot", "concat_out"], ["Result"], "reshape"),
|
||||
],
|
||||
"Reshape_Fusion", #name
|
||||
[ # inputs
|
||||
helper.make_tensor_value_info('SubgraphRoot', TensorProto.FLOAT, [10, 20, 30]),
|
||||
],
|
||||
[ # outputs
|
||||
helper.make_tensor_value_info('Result', TensorProto.FLOAT, ['unk_0', 'unk_1', 'unk_2']),
|
||||
],
|
||||
[ # initializers
|
||||
helper.make_tensor('a', TensorProto.INT64, [2], [1, 200]),
|
||||
helper.make_tensor('indices2', TensorProto.INT64, [], [2]),
|
||||
]
|
||||
)
|
||||
|
||||
save_model(graph, 'reshape_fusion_multiple_values_in_initializer_tensor_2.onnx')
|
||||
|
||||
graph = helper.make_graph(
|
||||
[ # nodes
|
||||
helper.make_node("Shape", ["AnotherInput"], ["shape2_out"], "shape2"),
|
||||
helper.make_node("Gather", ["shape2_out", "indices2"], ["gather2_out"], "gather2", axis=0),
|
||||
helper.make_node("Unsqueeze", ["gather2_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]),
|
||||
|
||||
helper.make_node("Concat", ["a", "unsqueeze2_out"], ["concat_out"], "concat", axis=0),
|
||||
helper.make_node("Reshape", ["SubgraphRoot", "concat_out"], ["Result"], "reshape"),
|
||||
],
|
||||
"Reshape_Fusion", #name
|
||||
[ # inputs
|
||||
helper.make_tensor_value_info('SubgraphRoot', TensorProto.FLOAT, [10, 20, 30]),
|
||||
helper.make_tensor_value_info('AnotherInput', TensorProto.FLOAT, ['input_dim_0', 'input_dim_1', 'input_dim_2']),
|
||||
],
|
||||
[ # outputs
|
||||
helper.make_tensor_value_info('Result', TensorProto.FLOAT, ['unk_0', 'unk_1', 'unk_2']),
|
||||
],
|
||||
[ # initializers
|
||||
helper.make_tensor('a', TensorProto.INT64, [2], [1, 200]),
|
||||
helper.make_tensor('indices2', TensorProto.INT64, [], [2]),
|
||||
]
|
||||
)
|
||||
|
||||
save_model(graph, 'reshape_fusion_input_is_graph_input.onnx')
|
||||
|
||||
graph = helper.make_graph(
|
||||
[ # nodes
|
||||
helper.make_node("Concat", ["a"], ["concat_out"], "concat", axis=0),
|
||||
helper.make_node("Reshape", ["SubgraphRoot", "concat_out"], ["Result"], "reshape"),
|
||||
],
|
||||
"Reshape_Fusion", #name
|
||||
[ # inputs
|
||||
helper.make_tensor_value_info('SubgraphRoot', TensorProto.FLOAT, [2, 3, 4]),
|
||||
helper.make_tensor_value_info('a', TensorProto.INT64, [3]),
|
||||
],
|
||||
[ # outputs
|
||||
helper.make_tensor_value_info('Result', TensorProto.FLOAT, ['unk_0', 'unk_1', 'unk_2']),
|
||||
],
|
||||
[ # initializers
|
||||
helper.make_tensor('a', TensorProto.INT64, [3], [1, 1, 2*3*4]),
|
||||
]
|
||||
)
|
||||
|
||||
save_model(graph, 'reshape_fusion_overridable_initializer.onnx')
|
||||
|
||||
graph = helper.make_graph(
|
||||
[ # nodes
|
||||
|
|
@ -96,4 +187,3 @@ graph = helper.make_graph(
|
|||
|
||||
save_model(graph, 'reshape_fusion_with_graph_inputs.onnx')
|
||||
|
||||
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/reshape_fusion_input_is_graph_input.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/reshape_fusion_input_is_graph_input.onnx
vendored
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/reshape_fusion_overridable_initializer.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/reshape_fusion_overridable_initializer.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue