From 8bc275e93fa31c9698110a1f1d9665a237481be7 Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Thu, 18 Mar 2021 11:41:58 -0700 Subject: [PATCH] Enhance Transpose, Cast and MatMul fusion when Cast and/or Fusion feeds multiple nodes. (#7021) * Added new Transpose+Cast+MatMul => Cast+FusedMatMul test scenarios. * The Cast node may feed more than one node. * Transpose node may feed multiple nodes and still may be fused with MatMul nodes. --- .../core/optimizer/matmul_transpose_fusion.cc | 35 +++--- .../test/optimizer/graph_transform_test.cc | 13 +- .../transpose_cast_matmul_4d_fusion0.onnx | Bin 415 -> 454 bytes .../transpose_cast_matmul_4d_fusion1.onnx | Bin 415 -> 454 bytes .../transpose_cast_matmul_4d_fusion2.onnx | Bin 0 -> 561 bytes .../transpose_cast_matmul_4d_fusion3.onnx | Bin 0 -> 683 bytes .../transpose_cast_matmul_4d_fusion4.onnx | Bin 0 -> 669 bytes .../transpose_cast_matmul_4d_fusion5.onnx | Bin 0 -> 669 bytes .../transform/fusion/transpose_matmul_gen.py | 119 ++++++++++++------ 9 files changed, 113 insertions(+), 54 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion2.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion3.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion4.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion5.onnx diff --git a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc index bdffc7701d..6fac01a033 100644 --- a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc @@ -86,11 +86,9 @@ static size_t UpdateConsumerCount(Graph& graph, NodeArg* target, std::unordered_ } } -/* GetTransposeNodeFromCast: Interchange Cast and Transpose nodes in the graph and return Transpose node if possible -* Requirements to interchange Cast and Transpose nodes changing the order of the operations. -* 1. Both Cast and Transpose are one-output nodes (assuming both have one-input only) -* 2. Transpose only feeds the Cast node (and no other node) -* 3. Cast only feeds the MalMul node (and no other node) +/* ReorderCastAndTranspose: +* Interchange Cast and Transpose nodes in the graph and return the new Transpose node if possible else nullptr. +* * * Transform the following pattern * | @@ -118,11 +116,13 @@ static size_t UpdateConsumerCount(Graph& graph, NodeArg* target, std::unordered_ * | * V */ -static Node* GetTransposeNodeFromCast(Graph& graph, Node* cast) { +static Node* ReorderCastAndTranspose(Graph& graph, Node* cast, + std::unordered_map& consumer_count, + std::deque& removed_nodes) { ORT_ENFORCE(cast != nullptr); auto transpose = GetTransposeNodeFromOutput(graph, *cast->MutableInputDefs()[0]); - if (transpose == nullptr || cast->GetOutputEdgesCount() != 1 || transpose->GetOutputEdgesCount() != 1) { + if (transpose == nullptr) { return nullptr; } NodeArg* cast_output = cast->MutableOutputDefs()[0]; @@ -159,10 +159,12 @@ static Node* GetTransposeNodeFromCast(Graph& graph, Node* cast) { &transpose->GetAttributes(), transpose->Domain()); + size_t consumers = UpdateConsumerCount(graph, transpose->MutableOutputDefs()[0], consumer_count); graph_utils::RemoveNodeOutputEdges(graph, *cast); - graph_utils::RemoveNodeOutputEdges(graph, *transpose); graph.RemoveNode(cast->Index()); - graph.RemoveNode(transpose->Index()); + if (consumers == 0) { + removed_nodes.push_front(transpose->Index()); + } return &new_transpose; } @@ -280,16 +282,17 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ NodeArg* right_input = node.MutableInputDefs()[1]; auto right = GetTransposeNodeFromOutput(graph, *right_input); - if (!left && !right) { + if (!left) { Node* left_node = graph.GetMutableProducerNode(left_input->Name()); if (left_node && left_node->OpType() == "Cast") { - left = GetTransposeNodeFromCast(graph, left_node); + left = ReorderCastAndTranspose(graph, left_node, consumer_count, removed_nodes); } - if (!left) { - Node* right_node = graph.GetMutableProducerNode(right_input->Name()); - if (right_node && right_node->OpType() == "Cast") { - right = GetTransposeNodeFromCast(graph, right_node); - } + } + + if (!right) { + Node* right_node = graph.GetMutableProducerNode(right_input->Name()); + if (right_node && right_node->OpType() == "Cast") { + right = ReorderCastAndTranspose(graph, right_node, consumer_count, removed_nodes); } } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 1c88e15e6e..049616570b 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -847,12 +847,19 @@ TEST_F(GraphTransformationTests, TransposeMatmulFusion) { TEST_F(GraphTransformationTests, TransposeCastMatmulFusion) { const std::vector model_uris = { MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion0.onnx", // Test fusion from the right input - MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion1.onnx" // Test fusion from the left input + MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion1.onnx", // Test fusion from the left input + MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion2.onnx", // Test fusion both from the left and right inputs + MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion3.onnx", // Cast nodes feed multiple MatMul nodes. + MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion4.onnx", // Cast nodes feed one MatMul node and + // the Transpose nodes feed another MatMul node. + MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion5.onnx" // One Cast node and one Transpose node feed each + // MatMul nodes. }; for (const auto& model_uri : model_uris) { std::shared_ptr p_model; ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK()); Graph& graph = p_model->MainGraph(); + std::map orig_op_to_count = CountOpsInGraph(graph); // Original op count onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level1); @@ -861,8 +868,8 @@ TEST_F(GraphTransformationTests, TransposeCastMatmulFusion) { std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Transpose"] == 0); ASSERT_TRUE(op_to_count["MatMul"] == 0); - ASSERT_TRUE(op_to_count["Cast"] == 2); - ASSERT_TRUE(op_to_count["com.microsoft.FusedMatMul"] == 1); + ASSERT_TRUE(op_to_count["Cast"] == orig_op_to_count["Cast"]); + ASSERT_TRUE(op_to_count["com.microsoft.FusedMatMul"] == orig_op_to_count["MatMul"]); } } diff --git a/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion0.onnx b/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion0.onnx index 8d5c73c82fc9957d9049111e9ebe23d30738aa12..d2d4b7b32e3a4f63fb22a14f84bc7d82ccc99436 100644 GIT binary patch delta 160 zcmbQwe2jU5sLnKIE)OpD%)Elql6V6lp^~D+yyAlV;?$IQsF)OY2uy5Zf+C~G#3p@9 zDK_WCVxW2@77(q)$;DKXFTu5dk%`L)W|pB4Uoub*ZjK>X&qRGuM)!$FOgW@Dd=pE2 eOLHcEw`23;Viscbn{3S}4lGK!VC|`=r8OV<}P+|emTAW-=CHWFu z3mBQWJYaeZg@j6q67z}+@{8g645he3U}Eux6V2pBJ-D!_nwTY}YY4T%5Yq~GE;QG0 sP0W^@_>P-Hio-Xt#J4nO;&(eXKQ3k=M!(6{jPgJ-no$)_C&S7$78xWN30?aS4*yY|cPF&{!4_ zt;Na3RFW^jwSbWc;U+^YZZbr6lOY*yGQ{mBLn(-x3^Cml#D&`@xTOq*IPyzNz`-lU z;hR|ETbcv(FdKv=Bn%Hvh(HK1PNI|%5hB4Q#=$5g#KpwH48%-a%tDNQU;?4YPza}@ oBo!{WP0r8N%gszK$}i4OD-mD>0L99m;Q#;t literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion3.onnx b/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion3.onnx new file mode 100644 index 0000000000000000000000000000000000000000..518d7f89d90122e6cc0d35689f170dc149c0b337 GIT binary patch literal 683 zcmbu6u};G<5Qb|fYTN^))JUjg1g1)$-pUk}g{2)(CzdFLAVnf}6kjABr|-#vN!lb5 z$wCL)cjsT9|MLk?`dX>4vQm`g={==}9=~091}{@ZMbS!1Dr02|ZPNGXPz%m>$6BV9$8x2?so3-xL1B z6OMbrpz(x%J-xtqC*xNrg|4Xn-au?(+Ekym;b{}WsZ4B{cx1}Uc1F!3J3Hss+iu7U zxDrfC;KH$sgCp6t-Z~M=QPJBSv<=lN7((?_4Qo)fc46Y5ujg=)==>qil2RM}K|cEd D%5cM- literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion4.onnx b/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion4.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e2d295c288ba637b44a95702e75012b64c322787 GIT binary patch literal 669 zcmbV}u};G<5Qb|fYTN^)^eUl}5tu50>Xj)f3|-2AI_n(+o=7t^kOHw{nQ2p((`v($Qu>?ra`!9}-~7P~i_ zeZK0~x`cDV5P=KF91ae$$L8uq5=TWZ*U+`lsKACAsQzGWT$uQ$n>m~%I={`cq|}vu HBcFW%V9341 literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion5.onnx b/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion5.onnx new file mode 100644 index 0000000000000000000000000000000000000000..69481810f2f792da1596c23575979a62de867774 GIT binary patch literal 669 zcmbV}PfNov7{-^htl6tVmx0V4)zd)fd-YW2MHut6lb2F!g@R2=lY$>-U(oMh+Pbzv z>cK<)yy3|&&y$yV+?PUpR<+Q1k=^M`>4z^Dp1{jRS?f6DTvv%urBvBl{BstZ9ZqL) zc)jS0%@W))tBPnAk;sYYk$8q=0r0C4G)!l3hnzdS+XX=h6Zlj)E%S?l&+1GZ1}zWy*t=iSc0CNGATcLo^tFg(Hds_#li zKbX$G+zf49!zBaGfilMo3JxdF?KFsxHzHyaz1~8fq*Yl_Bh??QtqU{%e7l5;R2KI| Lo>o%HckPpJGseBd literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/transpose_matmul_gen.py b/onnxruntime/test/testdata/transform/fusion/transpose_matmul_gen.py index 2f21d4917e..dda0efb108 100644 --- a/onnxruntime/test/testdata/transform/fusion/transpose_matmul_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/transpose_matmul_gen.py @@ -129,46 +129,95 @@ gen_with_preserved_transpose( def gen_transpose_fusion_with_cast(model_path): - nodes = [ - helper.make_node( - "Cast", - ["input_1"], - ["casted_input_1"], - to = 10 - ), - helper.make_node( - "Transpose", - ["input_0"], - ["transposed_input_0"], - perm = [0, 1, 3, 2]), - helper.make_node( - "Cast", - ["transposed_input_0"], - ["transposed_casted_input_0"], - to = 10), - helper.make_node( - "MatMul", - ["transposed_casted_input_0", "casted_input_1"], - ["output_0"]) - ] + cast_1 = helper.make_node( + "Cast", + ["input_1"], + ["casted_input_1"], + "Cast_1", + to = TensorProto.FLOAT16) + transpose_0 = helper.make_node( + "Transpose", + ["input_0"], + ["transposed_input_0"], + "Transpose_0", + perm = [0, 1, 3, 2]) + cast_0 = helper.make_node( + "Cast", + ["transposed_input_0"], + ["transposed_casted_input_0"], + "Cast_0", + to = TensorProto.FLOAT16) + matmul_0 = helper.make_node( + "MatMul", + ["transposed_casted_input_0", "casted_input_1"], + ["output_0"], + "MatMul_0") - inputs = [ - helper.make_tensor_value_info( - "input_0", TensorProto.FLOAT, [3, 2, 'K', 'M']), - helper.make_tensor_value_info( - "input_1", TensorProto.FLOAT, [3, 2, 'K', 'N']) - ] - - outputs = [ - helper.make_tensor_value_info( - "output_0", TensorProto.FLOAT16, [3, 2, 'M', 'N']) - ] + nodes = [transpose_0, cast_0, cast_1, matmul_0] + input_0 = helper.make_tensor_value_info("input_0", TensorProto.FLOAT, [3, 2, 'N', 'N']) + input_1 = helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [3, 2, 'N', 'N']) + inputs = [input_0, input_1] + output_0 = helper.make_tensor_value_info("output_0", TensorProto.FLOAT16, [3, 2, 'N', 'N']) + outputs = [output_0] + # Testcase0: First input of MatMul is transposed save(model_path + "0.onnx", nodes, inputs, outputs, []) - # Re-arragne nodes so that the transpose is on left input of matmul - nodes = nodes[1:3] + nodes[0:1] + nodes[3:] + + # Testcase1: Re-arragne nodes so that the transpose is on second input of matmul + transpose_1 = helper.make_node( + "Transpose", + ["input_1"], + ["transposed_input_1"], + "Transpose_1", + perm = [0, 1, 3, 2]) + cast_1.input[0] = "transposed_input_1" + cast_1.output[0] = "transposed_casted_input_1" + cast_0.input[0] = "input_0" + cast_0.output[0] = "casted_input_0" + matmul_0.input[0] = cast_0.output[0] + matmul_0.input[1] = cast_1.output[0] + nodes = [cast_0, transpose_1, cast_1, matmul_0] save(model_path + "1.onnx", nodes, inputs, outputs, []) + # Testcase2: Create an example with two Cast-ed Transpose-ed inputs feeding a MatMul + cast_0.input[0] = "transposed_input_0" + cast_0.output[0] = "transposed_casted_input_0" + matmul_0.input[0] = cast_0.output[0] + nodes = [transpose_0, cast_0, transpose_1, cast_1, matmul_0] + save(model_path + "2.onnx", nodes, inputs, outputs, []) + + # Testcase3: Create a second MatMul node using the outputs from the same Cast nodes as before + # with each Cast node feeding more than one node. + nodes.append(helper.make_node( + "MatMul", + ["transposed_casted_input_0", "transposed_casted_input_1"], + ["output_1"], + "MatMul_1")) + output_1 = helper.make_tensor_value_info("output_1", TensorProto.FLOAT16, [3, 2, 'N', 'N']) + outputs.append(output_1) + save(model_path + "3.onnx", nodes, inputs, outputs, []) + + # Testcase4: The second MatMul uses transposed inputs without cast. + nodes.pop() + outputs.pop() + matmul_1 = helper.make_node( + "MatMul", + ["transposed_input_0", "transposed_input_1"], + ["output_1"], + "MatMul_1") + nodes.append(matmul_1) + + outputs.append(helper.make_tensor_value_info( + "output_1", TensorProto.FLOAT, [3, 2, 'N', 'N'])) + save(model_path + "4.onnx", nodes, inputs, outputs, []) + + # Testcase5: Each MatMul uses outputs from a Cast and a Transpose + input_0.type.tensor_type.elem_type = TensorProto.FLOAT16 + cast_0.attribute[0].i = TensorProto.FLOAT + matmul_0.input[0] = "transposed_input_0" + matmul_1.input[0] = "transposed_casted_input_0" + output_1.type.tensor_type.elem_type = TensorProto.FLOAT + save(model_path + "5.onnx", nodes, inputs, outputs, []) gen_transpose_fusion_with_cast( "transpose_cast_matmul_4d_fusion")