From f5b4b0f77d7fa625008d0d44c18fc8ef27e4c44a Mon Sep 17 00:00:00 2001 From: guyang3532 <62738430+guyang3532@users.noreply.github.com> Date: Thu, 12 Jan 2023 10:45:15 +0800 Subject: [PATCH] Add support for 'axes' attr of unsqueeze in opset 13 and add ut (#14071) Since opset 13, 'axes' attr of unsqueeze become an input of unsqueeze, add support for it and add ut. --- onnxruntime/core/optimizer/reshape_fusion.cc | 28 ++++- .../test/optimizer/graph_transform_test.cc | 114 ++++++++++++++++++ 2 files changed, 138 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/optimizer/reshape_fusion.cc b/onnxruntime/core/optimizer/reshape_fusion.cc index 6ef1d622d5..9e31dcf083 100644 --- a/onnxruntime/core/optimizer/reshape_fusion.cc +++ b/onnxruntime/core/optimizer/reshape_fusion.cc @@ -10,6 +10,17 @@ using namespace ONNX_NAMESPACE; using namespace onnxruntime::common; namespace onnxruntime { +bool GetAxesFromUnsqueezeNode(const Graph& graph, const Node& unsqueeze, InlinedVector& axes) { + if (graph_utils::MatchesOpSinceVersion(unsqueeze, {1, 11})) { + return graph_utils::GetRepeatedNodeAttributeValues(unsqueeze, "axes", axes); + } else if (graph_utils::MatchesOpSinceVersion(unsqueeze, {13})) { + const NodeArg* axes_node_arg = unsqueeze.InputDefs()[1]; + return optimizer_utils::AppendTensorFromInitializer(graph, *axes_node_arg, axes, true); + } + + return false; +} + Status ReshapeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); @@ -145,15 +156,23 @@ bool ReshapeFusion::Match_One_Element_Output_Subgraph_1(Graph& graph, const Node std::vector parent_path{ {0, index, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}, - {0, 0, "Shape", {1, 13}, kOnnxDomain}}; + {0, 0, "Shape", {1, 13, 15}, kOnnxDomain}}; std::vector edges; if (graph_utils::FindPath(concat, true, parent_path, edges, logger)) { const Node& unsqueeze = edges[0]->GetNode(); const Node& gather = edges[1]->GetNode(); const Node& shape = edges[2]->GetNode(); + if (graph_utils::MatchesOpSinceVersion(shape, {15})) { + const ONNX_NAMESPACE::AttributeProto* start_attr = graph_utils::GetNodeAttribute(shape, "start"); + const ONNX_NAMESPACE::AttributeProto* end_attr = graph_utils::GetNodeAttribute(shape, "end"); + if (!((!start_attr || static_cast(start_attr->i()) == 0) && (!end_attr))) { + return false; + } + } + InlinedVector axes; - if (!(graph_utils::GetRepeatedNodeAttributeValues(unsqueeze, "axes", axes) && axes.size() == 1 && axes[0] == 0)) { + if (!(GetAxesFromUnsqueezeNode(graph, unsqueeze, axes) && axes.size() == 1 && axes[0] == 0)) { return false; } @@ -275,7 +294,7 @@ bool ReshapeFusion::Is_One_Element_Output_Subgraph(Graph& graph, const NodeArg& graph_utils::FindPath(concat, true, unsqueeze_path, edges, logger)) { const Node& unsqueeze = edges[0]->GetNode(); InlinedVector axes; - if (!(graph_utils::GetRepeatedNodeAttributeValues(unsqueeze, "axes", axes) && axes.size() == 1 && axes[0] == 0)) { + if (!(GetAxesFromUnsqueezeNode(graph, unsqueeze, axes) && axes.size() == 1 && axes[0] == 0)) { return false; } // Unsqueeze_path is found, check for "one-element subgraph -> concat" or "shape -> slice -> squeeze -> @@ -344,7 +363,8 @@ bool ReshapeFusion::Fuse_Subgraph(Node& reshape, Graph& graph, const logging::Lo } const Node& concat = *p_concat; - if (!graph_utils::IsSupportedOptypeVersionAndDomain(concat, "Concat", {1, 4, 11, 13})) { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(concat, "Concat", {1, 4, 11, 13}) && + !graph_utils::IsSupportedOptypeVersionAndDomain(concat, "ConcatTraining", {1}, kMSDomain)) { return false; } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index ac8aad2177..d1d8f435b3 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -4312,6 +4312,120 @@ TEST_F(GraphTransformationTests, BitmaskDropoutFusionTest) { TestBitmaskDropoutFusion(MODEL_FOLDER "fusion/bitmask_bias_dropout_fusion_residual.onnx", true, *logger_, 0, 0, 0, 0, 1, 0, 1); } + +/* +This test build a graph like: + input0 input1 + \ / + Add + -----------------| + | | + | Shape + | / \ + | Gather0 Gather1 + | / \ + | Unsqueeze0 Unsqueeze1 (Constant Initializer) (Constant Initializer) + | \ / / / + | \ / / / + | ConcatTraining ------- ------------ + \ / + \ / + Reshape + + +After fusion, the graph become: + input0 input1 + \ / + Add (Constant Initializer) + \ / + Reshape + +*/ +TEST_F(GraphTransformationTests, ReshapeFusionOpsetTest) { + constexpr const int batch_size = 64; + constexpr const int seq_lenth = 1024; + constexpr const int hidden_size = 1024; + + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Add"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Shape"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 2); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 2); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["com.microsoft.ConcatTraining"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Reshape"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Add"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Shape"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["com.microsoft.ConcatTraining"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Reshape"] == 1); + return Status::OK(); + }; + + const std::vector opsets{11, 12, 13, 14, 15, 15}; + bool shape_test_for_opset15 = false; + + for (auto& opset_version : opsets) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg0 = builder.MakeInput({{batch_size, seq_lenth, hidden_size}}); + auto* input_arg1 = builder.MakeInput({{hidden_size}}); + auto* scalar_int_0 = builder.MakeInitializer({}, {0}); + auto* scalar_int_1 = builder.MakeInitializer({}, {1}); + auto* single_value_1d_int_0 = builder.MakeInitializer({1}, {0}); + auto* single_value_1d_int_16 = builder.MakeInitializer({1}, {16}); + auto* single_value_1d_int_64 = builder.MakeInitializer({1}, {64}); + auto* add_out = builder.MakeIntermediate(); + auto* shape_out = builder.MakeIntermediate(); + auto* gather_out_0 = builder.MakeIntermediate(); + auto* gather_out_1 = builder.MakeIntermediate(); + auto* unsqueeze_out_0 = builder.MakeIntermediate(); + + auto* unsqueeze_out_1 = builder.MakeIntermediate(); + auto* concattraining1_out = builder.MakeIntermediate(); + auto* concattraining1_length = builder.MakeIntermediate(); + auto* out = builder.MakeOutput(); + + builder.AddNode("Add", {input_arg0, input_arg1}, {add_out}); + if (opset_version == 15) { + if (shape_test_for_opset15) { + auto& shape_1 = builder.AddNode("Shape", {add_out}, {shape_out}); + shape_1.AddAttribute("start", (int64_t)1); + shape_1.AddAttribute("end", (int64_t)2); + } else { + builder.AddNode("Shape", {add_out}, {shape_out}).AddAttribute("start", (int64_t)0); + shape_test_for_opset15 = true; + } + } else { + builder.AddNode("Shape", {add_out}, {shape_out}); + } + builder.AddNode("Gather", {shape_out, scalar_int_0}, {gather_out_0}); + builder.AddNode("Gather", {shape_out, scalar_int_1}, {gather_out_1}); + if (opset_version >= 13) { + builder.AddNode("Unsqueeze", {gather_out_0, single_value_1d_int_0}, {unsqueeze_out_0}); + builder.AddNode("Unsqueeze", {gather_out_1, single_value_1d_int_0}, {unsqueeze_out_1}); + } else { + builder.AddNode("Unsqueeze", {gather_out_0}, {unsqueeze_out_0}).AddAttribute("axes", std::vector{0}); + builder.AddNode("Unsqueeze", {gather_out_1}, {unsqueeze_out_1}).AddAttribute("axes", std::vector{0}); + } + builder.AddNode("ConcatTraining", {unsqueeze_out_0, unsqueeze_out_1, single_value_1d_int_16, single_value_1d_int_64}, + {concattraining1_out, concattraining1_length}, "com.microsoft").AddAttribute("axis", static_cast(0)); + builder.AddNode("Reshape", {add_out, concattraining1_out}, {out}); + }; + + std::unique_ptr transformer = std::make_unique(); + if (opset_version == 15 && shape_test_for_opset15) { + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger_, std::move(transformer), TransformerLevel::Level1, 1, + pre_graph_checker, pre_graph_checker)); + } else{ + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger_, std::move(transformer), TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } + } +} #endif TEST_F(GraphTransformationTests, LayerNormFusionTest) {