diff --git a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc index bdffc7701d..6fac01a033 100644 --- a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc @@ -86,11 +86,9 @@ static size_t UpdateConsumerCount(Graph& graph, NodeArg* target, std::unordered_ } } -/* GetTransposeNodeFromCast: Interchange Cast and Transpose nodes in the graph and return Transpose node if possible -* Requirements to interchange Cast and Transpose nodes changing the order of the operations. -* 1. Both Cast and Transpose are one-output nodes (assuming both have one-input only) -* 2. Transpose only feeds the Cast node (and no other node) -* 3. Cast only feeds the MalMul node (and no other node) +/* ReorderCastAndTranspose: +* Interchange Cast and Transpose nodes in the graph and return the new Transpose node if possible else nullptr. +* * * Transform the following pattern * | @@ -118,11 +116,13 @@ static size_t UpdateConsumerCount(Graph& graph, NodeArg* target, std::unordered_ * | * V */ -static Node* GetTransposeNodeFromCast(Graph& graph, Node* cast) { +static Node* ReorderCastAndTranspose(Graph& graph, Node* cast, + std::unordered_map& consumer_count, + std::deque& removed_nodes) { ORT_ENFORCE(cast != nullptr); auto transpose = GetTransposeNodeFromOutput(graph, *cast->MutableInputDefs()[0]); - if (transpose == nullptr || cast->GetOutputEdgesCount() != 1 || transpose->GetOutputEdgesCount() != 1) { + if (transpose == nullptr) { return nullptr; } NodeArg* cast_output = cast->MutableOutputDefs()[0]; @@ -159,10 +159,12 @@ static Node* GetTransposeNodeFromCast(Graph& graph, Node* cast) { &transpose->GetAttributes(), transpose->Domain()); + size_t consumers = UpdateConsumerCount(graph, transpose->MutableOutputDefs()[0], consumer_count); graph_utils::RemoveNodeOutputEdges(graph, *cast); - graph_utils::RemoveNodeOutputEdges(graph, *transpose); graph.RemoveNode(cast->Index()); - graph.RemoveNode(transpose->Index()); + if (consumers == 0) { + removed_nodes.push_front(transpose->Index()); + } return &new_transpose; } @@ -280,16 +282,17 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ NodeArg* right_input = node.MutableInputDefs()[1]; auto right = GetTransposeNodeFromOutput(graph, *right_input); - if (!left && !right) { + if (!left) { Node* left_node = graph.GetMutableProducerNode(left_input->Name()); if (left_node && left_node->OpType() == "Cast") { - left = GetTransposeNodeFromCast(graph, left_node); + left = ReorderCastAndTranspose(graph, left_node, consumer_count, removed_nodes); } - if (!left) { - Node* right_node = graph.GetMutableProducerNode(right_input->Name()); - if (right_node && right_node->OpType() == "Cast") { - right = GetTransposeNodeFromCast(graph, right_node); - } + } + + if (!right) { + Node* right_node = graph.GetMutableProducerNode(right_input->Name()); + if (right_node && right_node->OpType() == "Cast") { + right = ReorderCastAndTranspose(graph, right_node, consumer_count, removed_nodes); } } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 1c88e15e6e..049616570b 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -847,12 +847,19 @@ TEST_F(GraphTransformationTests, TransposeMatmulFusion) { TEST_F(GraphTransformationTests, TransposeCastMatmulFusion) { const std::vector model_uris = { MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion0.onnx", // Test fusion from the right input - MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion1.onnx" // Test fusion from the left input + MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion1.onnx", // Test fusion from the left input + MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion2.onnx", // Test fusion both from the left and right inputs + MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion3.onnx", // Cast nodes feed multiple MatMul nodes. + MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion4.onnx", // Cast nodes feed one MatMul node and + // the Transpose nodes feed another MatMul node. + MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion5.onnx" // One Cast node and one Transpose node feed each + // MatMul nodes. }; for (const auto& model_uri : model_uris) { std::shared_ptr p_model; ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK()); Graph& graph = p_model->MainGraph(); + std::map orig_op_to_count = CountOpsInGraph(graph); // Original op count onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level1); @@ -861,8 +868,8 @@ TEST_F(GraphTransformationTests, TransposeCastMatmulFusion) { std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Transpose"] == 0); ASSERT_TRUE(op_to_count["MatMul"] == 0); - ASSERT_TRUE(op_to_count["Cast"] == 2); - ASSERT_TRUE(op_to_count["com.microsoft.FusedMatMul"] == 1); + ASSERT_TRUE(op_to_count["Cast"] == orig_op_to_count["Cast"]); + ASSERT_TRUE(op_to_count["com.microsoft.FusedMatMul"] == orig_op_to_count["MatMul"]); } } diff --git a/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion0.onnx b/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion0.onnx index 8d5c73c82f..d2d4b7b32e 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion0.onnx and b/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion0.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion1.onnx b/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion1.onnx index 1ff3a3eec6..6911cc4b9a 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion1.onnx and b/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion1.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion2.onnx b/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion2.onnx new file mode 100644 index 0000000000..8004f68f78 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion2.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion3.onnx b/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion3.onnx new file mode 100644 index 0000000000..518d7f89d9 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion3.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion4.onnx b/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion4.onnx new file mode 100644 index 0000000000..e2d295c288 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion4.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion5.onnx b/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion5.onnx new file mode 100644 index 0000000000..69481810f2 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion5.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/transpose_matmul_gen.py b/onnxruntime/test/testdata/transform/fusion/transpose_matmul_gen.py index 2f21d4917e..dda0efb108 100644 --- a/onnxruntime/test/testdata/transform/fusion/transpose_matmul_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/transpose_matmul_gen.py @@ -129,46 +129,95 @@ gen_with_preserved_transpose( def gen_transpose_fusion_with_cast(model_path): - nodes = [ - helper.make_node( - "Cast", - ["input_1"], - ["casted_input_1"], - to = 10 - ), - helper.make_node( - "Transpose", - ["input_0"], - ["transposed_input_0"], - perm = [0, 1, 3, 2]), - helper.make_node( - "Cast", - ["transposed_input_0"], - ["transposed_casted_input_0"], - to = 10), - helper.make_node( - "MatMul", - ["transposed_casted_input_0", "casted_input_1"], - ["output_0"]) - ] + cast_1 = helper.make_node( + "Cast", + ["input_1"], + ["casted_input_1"], + "Cast_1", + to = TensorProto.FLOAT16) + transpose_0 = helper.make_node( + "Transpose", + ["input_0"], + ["transposed_input_0"], + "Transpose_0", + perm = [0, 1, 3, 2]) + cast_0 = helper.make_node( + "Cast", + ["transposed_input_0"], + ["transposed_casted_input_0"], + "Cast_0", + to = TensorProto.FLOAT16) + matmul_0 = helper.make_node( + "MatMul", + ["transposed_casted_input_0", "casted_input_1"], + ["output_0"], + "MatMul_0") - inputs = [ - helper.make_tensor_value_info( - "input_0", TensorProto.FLOAT, [3, 2, 'K', 'M']), - helper.make_tensor_value_info( - "input_1", TensorProto.FLOAT, [3, 2, 'K', 'N']) - ] - - outputs = [ - helper.make_tensor_value_info( - "output_0", TensorProto.FLOAT16, [3, 2, 'M', 'N']) - ] + nodes = [transpose_0, cast_0, cast_1, matmul_0] + input_0 = helper.make_tensor_value_info("input_0", TensorProto.FLOAT, [3, 2, 'N', 'N']) + input_1 = helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [3, 2, 'N', 'N']) + inputs = [input_0, input_1] + output_0 = helper.make_tensor_value_info("output_0", TensorProto.FLOAT16, [3, 2, 'N', 'N']) + outputs = [output_0] + # Testcase0: First input of MatMul is transposed save(model_path + "0.onnx", nodes, inputs, outputs, []) - # Re-arragne nodes so that the transpose is on left input of matmul - nodes = nodes[1:3] + nodes[0:1] + nodes[3:] + + # Testcase1: Re-arragne nodes so that the transpose is on second input of matmul + transpose_1 = helper.make_node( + "Transpose", + ["input_1"], + ["transposed_input_1"], + "Transpose_1", + perm = [0, 1, 3, 2]) + cast_1.input[0] = "transposed_input_1" + cast_1.output[0] = "transposed_casted_input_1" + cast_0.input[0] = "input_0" + cast_0.output[0] = "casted_input_0" + matmul_0.input[0] = cast_0.output[0] + matmul_0.input[1] = cast_1.output[0] + nodes = [cast_0, transpose_1, cast_1, matmul_0] save(model_path + "1.onnx", nodes, inputs, outputs, []) + # Testcase2: Create an example with two Cast-ed Transpose-ed inputs feeding a MatMul + cast_0.input[0] = "transposed_input_0" + cast_0.output[0] = "transposed_casted_input_0" + matmul_0.input[0] = cast_0.output[0] + nodes = [transpose_0, cast_0, transpose_1, cast_1, matmul_0] + save(model_path + "2.onnx", nodes, inputs, outputs, []) + + # Testcase3: Create a second MatMul node using the outputs from the same Cast nodes as before + # with each Cast node feeding more than one node. + nodes.append(helper.make_node( + "MatMul", + ["transposed_casted_input_0", "transposed_casted_input_1"], + ["output_1"], + "MatMul_1")) + output_1 = helper.make_tensor_value_info("output_1", TensorProto.FLOAT16, [3, 2, 'N', 'N']) + outputs.append(output_1) + save(model_path + "3.onnx", nodes, inputs, outputs, []) + + # Testcase4: The second MatMul uses transposed inputs without cast. + nodes.pop() + outputs.pop() + matmul_1 = helper.make_node( + "MatMul", + ["transposed_input_0", "transposed_input_1"], + ["output_1"], + "MatMul_1") + nodes.append(matmul_1) + + outputs.append(helper.make_tensor_value_info( + "output_1", TensorProto.FLOAT, [3, 2, 'N', 'N'])) + save(model_path + "4.onnx", nodes, inputs, outputs, []) + + # Testcase5: Each MatMul uses outputs from a Cast and a Transpose + input_0.type.tensor_type.elem_type = TensorProto.FLOAT16 + cast_0.attribute[0].i = TensorProto.FLOAT + matmul_0.input[0] = "transposed_input_0" + matmul_1.input[0] = "transposed_casted_input_0" + output_1.type.tensor_type.elem_type = TensorProto.FLOAT + save(model_path + "5.onnx", nodes, inputs, outputs, []) gen_transpose_fusion_with_cast( "transpose_cast_matmul_4d_fusion")