diff --git a/onnxruntime/core/optimizer/gemm_transpose_fusion.cc b/onnxruntime/core/optimizer/gemm_transpose_fusion.cc
index 8b41c417a6..6b4652d3ea 100644
--- a/onnxruntime/core/optimizer/gemm_transpose_fusion.cc
+++ b/onnxruntime/core/optimizer/gemm_transpose_fusion.cc
@@ -26,17 +26,38 @@ Status GemmTransposeFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& m
 
   // check if input A is a Transpose
   if (A_node_ptr != nullptr && A_node_ptr->OpType() == "Transpose") {
-    Node& A_node = *graph.GetNode(A_node_ptr->Index());
-    transA = !transA;
-    nodes_to_remove.push_back(A_node);
-    new_gemm_input_defs[0] = A_node.MutableInputDefs()[0];
+    // make sure all consumers are gemm nodes to avoid possible double transpose
+    std::vector<const Node*> gemm_nodes = graph_utils::FindChildrenByType(*A_node_ptr, "Gemm");
+    if (gemm_nodes.size() == A_node_ptr->GetOutputEdgesCount()) {
+      Node& A_node = *graph.GetNode(A_node_ptr->Index());
+      transA = !transA;
+      if (A_node.GetOutputEdgesCount() > 1) {
+        // remove only the edge between the Transpose and this Gemm node; the Transpose is not removed
+        // because it is still connected to other Gemm nodes. When the last connected Gemm is
+        // processed, control falls into the else {} below, which removes the Transpose node
+        int output_idx = graph_utils::GetNodeOutputIndexFromOutputName(A_node, gemm_node.MutableInputDefs()[0]->Name());
+        graph.RemoveEdge(A_node.Index(), gemm_node.Index(), output_idx, 0);
+      } else {
+        nodes_to_remove.push_back(A_node);
+      }
+      new_gemm_input_defs[0] = A_node.MutableInputDefs()[0];
+    }
   }
   // check if input B is a Transpose
   if (B_node_ptr != nullptr && B_node_ptr->OpType() == "Transpose") {
-    Node& B_node = *graph.GetNode(B_node_ptr->Index());
-    transB = !transB;
-    nodes_to_remove.push_back(B_node);
-    new_gemm_input_defs[1] = B_node.MutableInputDefs()[0];
+    // as for input A: fuse only when every consumer of the Transpose is a Gemm
+    std::vector<const Node*> gemm_nodes = graph_utils::FindChildrenByType(*B_node_ptr, "Gemm");
+    if (gemm_nodes.size() == B_node_ptr->GetOutputEdgesCount()) {
+      Node& B_node = *graph.GetNode(B_node_ptr->Index());
+      transB = !transB;
+      if (B_node.GetOutputEdgesCount() > 1) {
+        int output_idx = graph_utils::GetNodeOutputIndexFromOutputName(B_node, gemm_node.MutableInputDefs()[1]->Name());
+        graph.RemoveEdge(B_node.Index(), gemm_node.Index(), output_idx, 1);
+      } else {
+        nodes_to_remove.push_back(B_node);
+      }
+      new_gemm_input_defs[1] = B_node.MutableInputDefs()[0];
+    }
   }
 
   nodes_to_remove.push_back(gemm_node);
@@ -82,11 +103,14 @@ bool GemmTransposeFusion::SatisfyCondition(const Graph& graph, const Node& node,
   // Fusion can be applied if there is a transpose at either of the inputs
   for (auto node_it = node.InputNodesBegin(); node_it != node.InputNodesEnd(); ++node_it) {
     if (graph_utils::IsSupportedOptypeVersionAndDomain(*node_it, "Transpose", {1, 13}) &&
-        node_it->GetOutputEdgesCount() == 1 &&
         !graph.NodeProducesGraphOutput(*node_it) &&
         // Make sure the two nodes do not span execution providers.
         node_it->GetExecutionProviderType() == node.GetExecutionProviderType()) {
-      return true;
+      // acceptable if all consumer(s) are gemm node(s)
+      std::vector<const Node*> gemm_nodes = graph_utils::FindChildrenByType(*node_it, "Gemm");
+      if (gemm_nodes.size() == node_it->GetOutputEdgesCount()) {
+        return true;
+      }
     }
   }
 
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index 839fe3adbb..3c773128bc 100644
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -1219,6 +1219,94 @@ TEST_F(GraphTransformationTests, GemmTransposeFusion2Inputs) {
   ASSERT_TRUE(new_input_defs[1]->Name() == "B");
 }
 
+// (A')'B' = AB' where transpose has multiple consumers
+TEST_F(GraphTransformationTests, GemmTransposeFusion2OutputsFromTranspose) {
+  auto model_uri = MODEL_FOLDER "fusion/gemm_transpose_2outputs_from_transpose.onnx";
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["Transpose"], 2);
+  ASSERT_EQ(op_to_count["Gemm"], 1);
+  ASSERT_EQ(op_to_count["Identity"], 1);
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
+  rule_transformer_L1->Register(std::make_unique<GemmTransposeFusion>());
+  graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["Transpose"], 1);
+  ASSERT_EQ(op_to_count["Gemm"], 1);
+  ASSERT_EQ(op_to_count["Identity"], 1);
+
+  auto gemm_node =
+      std::find_if(
+          graph.Nodes().cbegin(), graph.Nodes().cend(),
+          [](const Node& node) { return node.Name() == "Gemm_transformed"; });
+  ASSERT_TRUE(gemm_node != graph.Nodes().cend()) << "fused node Gemm_transformed not found";
+
+  auto& node = *gemm_node;
+  ASSERT_TRUE(node.OpType() == "Gemm");
+  ASSERT_TRUE(static_cast<bool>(node.GetAttributes().at("transA").i()));
+  ASSERT_TRUE(static_cast<bool>(node.GetAttributes().at("transB").i()));
+  auto new_input_defs = node.InputDefs();
+  ASSERT_TRUE(new_input_defs[0]->Name() == "tp0");
+  ASSERT_TRUE(new_input_defs[1]->Name() == "B");
+}
+
+// (A')'B' = AB' and (B')'C = BC where transpose has multiple consumers
+TEST_F(GraphTransformationTests, GemmTransposeFusion2OutputsFromTransposeTo2Gemms) {
+  auto model_uri = MODEL_FOLDER "fusion/gemm_transpose_2outputs_from_transpose_to_2gemms.onnx";
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["Transpose"], 2);
+  ASSERT_EQ(op_to_count["Gemm"], 2);
+  ASSERT_EQ(op_to_count["Identity"], 1);
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
+  rule_transformer_L1->Register(std::make_unique<GemmTransposeFusion>());
+  graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["Transpose"], 1);
+  ASSERT_EQ(op_to_count["Gemm"], 2);
+  ASSERT_EQ(op_to_count["Identity"], 1);
+
+  auto gemm1_node =
+      std::find_if(
+          graph.Nodes().cbegin(), graph.Nodes().cend(),
+          [](const Node& node) { return node.Name() == "Gemm1_transformed"; });
+  ASSERT_TRUE(gemm1_node != graph.Nodes().cend()) << "fused node Gemm1_transformed not found";
+
+  auto& node1 = *gemm1_node;
+  ASSERT_TRUE(node1.OpType() == "Gemm");
+  ASSERT_TRUE(static_cast<bool>(node1.GetAttributes().at("transA").i()));
+  ASSERT_TRUE(static_cast<bool>(node1.GetAttributes().at("transB").i()));
+  auto new_input_defs1 = node1.InputDefs();
+  ASSERT_TRUE(new_input_defs1[0]->Name() == "tp0");
+  ASSERT_TRUE(new_input_defs1[1]->Name() == "B");
+
+  auto gemm2_node =
+      std::find_if(
+          graph.Nodes().cbegin(), graph.Nodes().cend(),
+          [](const Node& node) { return node.Name() == "Gemm2_transformed"; });
+  ASSERT_TRUE(gemm2_node != graph.Nodes().cend()) << "fused node Gemm2_transformed not found";
+
+  auto& node2 = *gemm2_node;
+  ASSERT_TRUE(node2.OpType() == "Gemm");
+  ASSERT_FALSE(static_cast<bool>(node2.GetAttributes().at("transA").i()));
+  ASSERT_FALSE(static_cast<bool>(node2.GetAttributes().at("transB").i()));
+  auto new_input_defs2 = node2.InputDefs();
+  ASSERT_TRUE(new_input_defs2[0]->Name() == "B");
+  ASSERT_TRUE(new_input_defs2[1]->Name() == "C");
+}
+
 // (A'B)' = B'A
 TEST_F(GraphTransformationTests, GemmTransposeFusionOutput) {
   auto model_uri = MODEL_FOLDER "fusion/gemm_transpose_output_transposed.onnx";
diff --git a/onnxruntime/test/testdata/transform/fusion/gemm_transpose_2outputs_from_transpose.onnx b/onnxruntime/test/testdata/transform/fusion/gemm_transpose_2outputs_from_transpose.onnx
new file mode 100644
index 0000000000..6d679f0054
Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/gemm_transpose_2outputs_from_transpose.onnx differ
diff --git a/onnxruntime/test/testdata/transform/fusion/gemm_transpose_2outputs_from_transpose_to_2gemms.onnx b/onnxruntime/test/testdata/transform/fusion/gemm_transpose_2outputs_from_transpose_to_2gemms.onnx
new file mode 100644
index 0000000000..d28b0bb9c6
Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/gemm_transpose_2outputs_from_transpose_to_2gemms.onnx differ
diff --git a/onnxruntime/test/testdata/transform/fusion/gemm_transpose_gen_2.py b/onnxruntime/test/testdata/transform/fusion/gemm_transpose_gen_2.py
new file mode 100644
index 0000000000..0e1215376c
--- /dev/null
+++ b/onnxruntime/test/testdata/transform/fusion/gemm_transpose_gen_2.py
@@ -0,0 +1,75 @@
+import onnx
+from onnx import helper
+from onnx import TensorProto
+from onnx import OperatorSetIdProto
+
+onnxdomain = OperatorSetIdProto()
+onnxdomain.version = 12
+# The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
+onnxdomain.domain = ""
+msdomain = OperatorSetIdProto()
+msdomain.version = 1
+msdomain.domain = "com.microsoft"
+opsets = [onnxdomain, msdomain]
+
+
+def save(model_path, nodes, inputs, outputs, initializers):
+    graph = helper.make_graph(
+        nodes,
+        "TransposeGemmTest",
+        inputs, outputs, initializers)
+
+    model = helper.make_model(
+        graph, opset_imports=opsets, producer_name="onnxruntime-test")
+
+    onnx.save(model, model_path)
+
+# (A')'B' = AB'
+def gemm_transpose_2outputs_from_transpose(model_path):
+    nodes = [
+        helper.make_node("Transpose", ["A"], ["tp0"], "TransposeA"),
+        helper.make_node("Transpose", ["B"], ["tp1"], "TransposeB"),
+        helper.make_node("Gemm", ["tp0", "tp1"], ["output"], "Gemm", alpha=3.0, transA=1),
+        helper.make_node("Identity", ["tp0"], ["output2"], "IdentityAt"),
+    ]
+
+    inputs = [
+        helper.make_tensor_value_info("A", TensorProto.FLOAT, ['M', 'K']),
+        helper.make_tensor_value_info("B", TensorProto.FLOAT, ['N', 'K'])
+    ]
+
+    outputs = [
+        helper.make_tensor_value_info("output", TensorProto.FLOAT, ['M', 'N']),
+        helper.make_tensor_value_info("output2", TensorProto.FLOAT, ['K', 'M'])
+    ]
+
+    save(model_path, nodes, inputs, outputs, [])
+
+
+# (A')'B' = AB' and (B')'C = BC
+def gemm_transpose_2outputs_from_transpose_to_2gemms(model_path):
+    nodes = [
+        helper.make_node("Transpose", ["A"], ["tp0"], "TransposeA"),
+        helper.make_node("Transpose", ["B"], ["tp1"], "TransposeB"),
+        helper.make_node("Gemm", ["tp0", "tp1"], ["output"], "Gemm1", alpha=3.0, transA=1),
+        helper.make_node("Gemm", ["tp1", "C"], ["output3"], "Gemm2", alpha=3.0, transA=1),
+        helper.make_node("Identity", ["tp0"], ["output2"], "IdentityAt"),
+    ]
+
+    inputs = [
+        helper.make_tensor_value_info("A", TensorProto.FLOAT, ['M', 'K']),
+        helper.make_tensor_value_info("B", TensorProto.FLOAT, ['N', 'K']),
+        helper.make_tensor_value_info("C", TensorProto.FLOAT, ['K', 'L'])
+    ]
+
+    outputs = [
+        helper.make_tensor_value_info("output", TensorProto.FLOAT, ['M', 'N']),
+        helper.make_tensor_value_info("output2", TensorProto.FLOAT, ['K', 'M']),
+        helper.make_tensor_value_info("output3", TensorProto.FLOAT, ['N', 'L'])
+    ]
+
+    save(model_path, nodes, inputs, outputs, [])
+
+gemm_transpose_2outputs_from_transpose("gemm_transpose_2outputs_from_transpose.onnx")
+gemm_transpose_2outputs_from_transpose_to_2gemms("gemm_transpose_2outputs_from_transpose_to_2gemms.onnx")
+