Support older version of slice in reshape fusion (#5574)

* support older version of slice in reshape fusion

* fix

* review partial comments

* add test

* add gen file
This commit is contained in:
Ye Wang 2020-10-24 14:48:18 -07:00 committed by GitHub
parent 860cb22260
commit 51af108af5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 82 additions and 9 deletions

View file

@ -190,16 +190,24 @@ bool ReshapeFusion::Match_One_Element_Output_Subgraph_2(Graph& graph, const Node
}
// Check if Slice op slices 1d array (result of shape) to one element.
std::vector<int64_t> slice_inputs;
if (optimizer_utils::AppendTensorFromInitializer(graph, *(slice.InputDefs()[1]), slice_inputs, true) &&
optimizer_utils::AppendTensorFromInitializer(graph, *(slice.InputDefs()[2]), slice_inputs, true)) {
const int64_t slice_start = slice_inputs[0];
const int64_t slice_end = slice_inputs[1];
if (!(slice_end >= INT_MAX && slice_start == -1) && abs(slice_end - slice_start) != 1) {
return false;
}
return true;
std::vector<int64_t> starts_values;
std::vector<int64_t> ends_values;
if (slice.GetInputEdgesCount() >= 3) {
optimizer_utils::AppendTensorFromInitializer(graph, *(slice.InputDefs()[1]), starts_values, true);
optimizer_utils::AppendTensorFromInitializer(graph, *(slice.InputDefs()[2]), ends_values, true);
} else { // Support older version of Slice node
graph_utils::GetRepeatedNodeAttributeValues<int64_t>(slice, "starts", starts_values);
graph_utils::GetRepeatedNodeAttributeValues<int64_t>(slice, "ends", ends_values);
}
if (starts_values.size() != 1 || ends_values.size() != 1) {
return false;
}
int64_t slice_start = starts_values[0];
int64_t slice_end = ends_values[0];
if (!(slice_end >= INT_MAX && slice_start == -1) && abs(slice_end - slice_start) != 1) {
return false;
}
return true;
}
return false;

View file

@ -1467,6 +1467,41 @@ TEST_F(GraphTransformationTests, ReshapeFusionConcatSubgraph) {
}
}
TEST_F(GraphTransformationTests, ReshapeFusionWithSlice1) {
auto model_uri = MODEL_FOLDER "fusion/reshape_fusion_with_slice1.onnx";
std::shared_ptr<Model> p_model;
ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK());
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<ReshapeFusion>(), TransformerLevel::Level1);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
ASSERT_TRUE(ret.IsOK());
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_EQ(op_to_count["Shape"], 0);
ASSERT_EQ(op_to_count["Gather"], 0);
ASSERT_EQ(op_to_count["Unsqueeze"], 0);
ASSERT_EQ(op_to_count["Slice"], 0);
ASSERT_EQ(op_to_count["Concat"], 0);
ASSERT_EQ(op_to_count["Reshape"], 1);
for (const Node& node : graph.Nodes()) {
if (node.OpType() == "Reshape") {
const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name());
ASSERT_TRUE(tensor_proto != nullptr);
auto initializer = onnxruntime::make_unique<Initializer>(*tensor_proto, graph.ModelPath());
EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64);
EXPECT_EQ(initializer->size(), 3);
const int64_t* val = initializer->data<int64_t>();
EXPECT_EQ(val[0], 0);
EXPECT_EQ(val[1], 0);
EXPECT_EQ(val[2], -1);
}
}
}
TEST_F(GraphTransformationTests, ReshapeFusionConcatSubgraphNotTriggered) {
auto model_uri = MODEL_FOLDER "fusion/reshape_fusion_concat_subgraph_not_triggered.onnx";
std::shared_ptr<Model> p_model;

View file

@ -219,6 +219,36 @@ graph = helper.make_graph(
save_model(graph, 'reshape_fusion_concat_subgraph.onnx')
graph = helper.make_graph(
[ # nodes
helper.make_node("Shape", ["SubgraphRoot"], ["shape0_out"], "shape0"),
helper.make_node("Shape", ["SubgraphRoot"], ["shape1_out"], "shape1"),
helper.make_node("Gather", ["shape0_out", "indices0"], ["gather0_out"], "gather0", axis=0),
helper.make_node("Gather", ["shape1_out", "indices1"], ["gather1_out"], "gather1", axis=0),
helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]),
helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]),
helper.make_node("Shape", ["SubgraphRoot"], ["shape2_out"], "shape2"),
helper.make_node("Slice", ["shape2_out"], ["slice_out"], "slice1", starts = [2], ends = [3]),
helper.make_node("Concat", ["unsqueeze0_out", "unsqueeze1_out", "slice_out"], ["concat_out"], "concat", axis=0),
helper.make_node("Reshape", ["SubgraphRoot", "concat_out"], ["Result"], "reshape"),
],
"Reshape_Fusion", #name
[ # inputs
helper.make_tensor_value_info('SubgraphRoot', TensorProto.FLOAT, [10, 20, 30]),
],
[ # outputs
helper.make_tensor_value_info('Result', TensorProto.FLOAT, [10, 20, 'unk']),
],
[ # initializers
helper.make_tensor('indices0', TensorProto.INT64, [], [0]),
helper.make_tensor('indices1', TensorProto.INT64, [], [1]),
]
)
# Save this model without checking
onnx.save(helper.make_model(graph), 'reshape_fusion_with_slice1.onnx')
graph = helper.make_graph(
[ # nodes
helper.make_node("Shape", ["SubgraphRoot"], ["shape0_out"], "shape0"),