mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
Support older version of slice in reshape fusion (#5574)
* support older version of slice in reshape fusion * fix * review partial comments * add test * add gen file
This commit is contained in:
parent
860cb22260
commit
51af108af5
4 changed files with 82 additions and 9 deletions
|
|
@ -190,16 +190,24 @@ bool ReshapeFusion::Match_One_Element_Output_Subgraph_2(Graph& graph, const Node
|
|||
}
|
||||
|
||||
// Check if Slice op slices 1d array (result of shape) to one element.
|
||||
std::vector<int64_t> slice_inputs;
|
||||
if (optimizer_utils::AppendTensorFromInitializer(graph, *(slice.InputDefs()[1]), slice_inputs, true) &&
|
||||
optimizer_utils::AppendTensorFromInitializer(graph, *(slice.InputDefs()[2]), slice_inputs, true)) {
|
||||
const int64_t slice_start = slice_inputs[0];
|
||||
const int64_t slice_end = slice_inputs[1];
|
||||
if (!(slice_end >= INT_MAX && slice_start == -1) && abs(slice_end - slice_start) != 1) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
std::vector<int64_t> starts_values;
|
||||
std::vector<int64_t> ends_values;
|
||||
if (slice.GetInputEdgesCount() >= 3) {
|
||||
optimizer_utils::AppendTensorFromInitializer(graph, *(slice.InputDefs()[1]), starts_values, true);
|
||||
optimizer_utils::AppendTensorFromInitializer(graph, *(slice.InputDefs()[2]), ends_values, true);
|
||||
} else { // Support older version of Slice node
|
||||
graph_utils::GetRepeatedNodeAttributeValues<int64_t>(slice, "starts", starts_values);
|
||||
graph_utils::GetRepeatedNodeAttributeValues<int64_t>(slice, "ends", ends_values);
|
||||
}
|
||||
if (starts_values.size() != 1 || ends_values.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
int64_t slice_start = starts_values[0];
|
||||
int64_t slice_end = ends_values[0];
|
||||
if (!(slice_end >= INT_MAX && slice_start == -1) && abs(slice_end - slice_start) != 1) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -1467,6 +1467,41 @@ TEST_F(GraphTransformationTests, ReshapeFusionConcatSubgraph) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, ReshapeFusionWithSlice1) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/reshape_fusion_with_slice1.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK());
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<ReshapeFusion>(), TransformerLevel::Level1);
|
||||
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_EQ(op_to_count["Shape"], 0);
|
||||
ASSERT_EQ(op_to_count["Gather"], 0);
|
||||
ASSERT_EQ(op_to_count["Unsqueeze"], 0);
|
||||
ASSERT_EQ(op_to_count["Slice"], 0);
|
||||
ASSERT_EQ(op_to_count["Concat"], 0);
|
||||
ASSERT_EQ(op_to_count["Reshape"], 1);
|
||||
for (const Node& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Reshape") {
|
||||
const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name());
|
||||
ASSERT_TRUE(tensor_proto != nullptr);
|
||||
|
||||
auto initializer = onnxruntime::make_unique<Initializer>(*tensor_proto, graph.ModelPath());
|
||||
EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
EXPECT_EQ(initializer->size(), 3);
|
||||
|
||||
const int64_t* val = initializer->data<int64_t>();
|
||||
EXPECT_EQ(val[0], 0);
|
||||
EXPECT_EQ(val[1], 0);
|
||||
EXPECT_EQ(val[2], -1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, ReshapeFusionConcatSubgraphNotTriggered) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/reshape_fusion_concat_subgraph_not_triggered.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
|
|
|
|||
|
|
@ -219,6 +219,36 @@ graph = helper.make_graph(
|
|||
|
||||
save_model(graph, 'reshape_fusion_concat_subgraph.onnx')
|
||||
|
||||
graph = helper.make_graph(
|
||||
[ # nodes
|
||||
helper.make_node("Shape", ["SubgraphRoot"], ["shape0_out"], "shape0"),
|
||||
helper.make_node("Shape", ["SubgraphRoot"], ["shape1_out"], "shape1"),
|
||||
helper.make_node("Gather", ["shape0_out", "indices0"], ["gather0_out"], "gather0", axis=0),
|
||||
helper.make_node("Gather", ["shape1_out", "indices1"], ["gather1_out"], "gather1", axis=0),
|
||||
helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]),
|
||||
helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]),
|
||||
|
||||
helper.make_node("Shape", ["SubgraphRoot"], ["shape2_out"], "shape2"),
|
||||
helper.make_node("Slice", ["shape2_out"], ["slice_out"], "slice1", starts = [2], ends = [3]),
|
||||
|
||||
helper.make_node("Concat", ["unsqueeze0_out", "unsqueeze1_out", "slice_out"], ["concat_out"], "concat", axis=0),
|
||||
helper.make_node("Reshape", ["SubgraphRoot", "concat_out"], ["Result"], "reshape"),
|
||||
],
|
||||
"Reshape_Fusion", #name
|
||||
[ # inputs
|
||||
helper.make_tensor_value_info('SubgraphRoot', TensorProto.FLOAT, [10, 20, 30]),
|
||||
],
|
||||
[ # outputs
|
||||
helper.make_tensor_value_info('Result', TensorProto.FLOAT, [10, 20, 'unk']),
|
||||
],
|
||||
[ # initializers
|
||||
helper.make_tensor('indices0', TensorProto.INT64, [], [0]),
|
||||
helper.make_tensor('indices1', TensorProto.INT64, [], [1]),
|
||||
]
|
||||
)
|
||||
# Save this model without checking
|
||||
onnx.save(helper.make_model(graph), 'reshape_fusion_with_slice1.onnx')
|
||||
|
||||
graph = helper.make_graph(
|
||||
[ # nodes
|
||||
helper.make_node("Shape", ["SubgraphRoot"], ["shape0_out"], "shape0"),
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/reshape_fusion_with_slice1.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/reshape_fusion_with_slice1.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue