diff --git a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc index acd43b1df3..bdffc7701d 100644 --- a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc @@ -86,6 +86,177 @@ static size_t UpdateConsumerCount(Graph& graph, NodeArg* target, std::unordered_ } } +/* GetTransposeNodeFromCast: Interchange Cast and Transpose nodes in the graph and return the new Transpose node if possible, nullptr otherwise +* Requirements to interchange Cast and Transpose nodes changing the order of the operations. +* 1. Both Cast and Transpose are one-output nodes (assuming both have one-input only) +* 2. Transpose only feeds the Cast node (and no other node) +* 3. Cast only feeds the MatMul node (and no other node) +* +* Transform the following pattern +* | +* _____|______ +* |Transpose | +* |__________| +* | +* | +* _____V______ +* | Cast | +* |__________| +* | +* V +* +* to +* | +* _____|______ +* | Cast | +* |__________| +* | +* | +* _____V______ +* | Transpose| +* |__________| +* | +* V +*/ +static Node* GetTransposeNodeFromCast(Graph& graph, Node* cast) { + + ORT_ENFORCE(cast != nullptr); + auto transpose = GetTransposeNodeFromOutput(graph, *cast->MutableInputDefs()[0]); + if (transpose == nullptr || cast->GetOutputEdgesCount() != 1 || transpose->GetOutputEdgesCount() != 1) { + return nullptr; + } + NodeArg* cast_output = cast->MutableOutputDefs()[0]; + NodeArg* transpose_input = transpose->MutableInputDefs()[0]; + + // Create a new NodeArg to feed the output from the new Cast to the new Transpose. + // The shape of the new NodeArg is same as the original input to Transpose but type + // should match that of the output from the original Cast. 
+ + auto new_cast_output_type_proto = *transpose_input->TypeAsProto(); + const ONNX_NAMESPACE::TensorProto_DataType element_type = + static_cast(cast_output->TypeAsProto()->tensor_type().elem_type()); /* NOTE(review): the static_cast target type looks lost in extraction; presumably ONNX_NAMESPACE::TensorProto_DataType, confirm */ + new_cast_output_type_proto.mutable_tensor_type()->set_elem_type(element_type); + auto& new_cast_output = graph.GetOrCreateNodeArg(cast_output->Name() + "_transformed", &new_cast_output_type_proto); + + const std::vector new_cast_input_defs {transpose_input}; /* new Cast reads the input of the original Transpose */ + const std::vector new_cast_output_defs {&new_cast_output}; + const std::vector new_transpose_input_defs = {&new_cast_output}; + const std::vector new_transpose_output_defs = {cast_output}; /* new Transpose writes the NodeArg the original Cast produced */ + + (void) graph.AddNode(graph.GenerateNodeName(cast->Name() + "_transformed"), + cast->OpType(), + "Created a new Cast node to interchange Cast and Transpose nodes", + new_cast_input_defs, + new_cast_output_defs, + &cast->GetAttributes(), + cast->Domain()); + + Node& new_transpose = graph.AddNode(graph.GenerateNodeName(transpose->Name() + "_transformed"), + transpose->OpType(), + "Created a new Transpose node to interchange Cast and Transpose nodes", + new_transpose_input_defs, + new_transpose_output_defs, + &transpose->GetAttributes(), + transpose->Domain()); + + graph_utils::RemoveNodeOutputEdges(graph, *cast); + graph_utils::RemoveNodeOutputEdges(graph, *transpose); + graph.RemoveNode(cast->Index()); + graph.RemoveNode(transpose->Index()); /* the original Cast and Transpose are removed; the swapped pair added above replaces them */ + return &new_transpose; +} + +/********************************************************************************************* + +Case I: The following is a scenario where Transpose output feeds MatMul. The Transpose input can be either on the left or right. 
+ The input graph + __________ __________ + | input0 | | input1 | + |________| |________| + | | + | | + | | + _____V______ | + |Transpose | | + |__________| | + | | + | | + |______________ _____________| + | | + | | + | | + __V___________V__ + | MatMul | + |_______________| + | + V + is transformed to the following + + __________ __________ + | input0 | | input1 | + |________| |________| + | | + | | + | | + |_____________ _____________| + | | + | | + | | + __V___________V__ + | FusedMatMul | + |_______________| + | + V + +Case II: The output of Transpose feeds Cast and the output from the Cast feeds MatMul + The input graph + __________ __________ + | input0 | | input1 | + |________| |________| + | | + | | + _____V______ | + |Transpose | | + |__________| | + | | + | | + _____V______ | + | Cast | | + |__________| | + | | + |______________ _____________| + | | + | | + | | + __V___________V__ + | MatMul | + |_______________| + | + V + is transformed to the following + + __________ __________ + | input0 | | input1 | + |________| |________| + | | + | | + | | + _____V______ | + | Cast | | + |__________| | + | | + |______________ _____________| + | | + | | + | | + __V___________V__ + | FusedMatMul | + |_______________| + | + V + +********************************************************************************************************************/ + Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); @@ -109,6 +280,19 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ NodeArg* right_input = node.MutableInputDefs()[1]; auto right = GetTransposeNodeFromOutput(graph, *right_input); + if (!left && !right) { + Node* left_node = graph.GetMutableProducerNode(left_input->Name()); + if (left_node && left_node->OpType() == "Cast") { + left = GetTransposeNodeFromCast(graph, 
left_node); + } + if (!left) { + Node* right_node = graph.GetMutableProducerNode(right_input->Name()); + if (right_node && right_node->OpType() == "Cast") { + right = GetTransposeNodeFromCast(graph, right_node); + } + } + } + if (!left && !right) { continue; } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 87d2ec6d87..1c88e15e6e 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -844,6 +844,28 @@ TEST_F(GraphTransformationTests, TransposeMatmulFusion) { ASSERT_TRUE(op_to_count["com.microsoft.FusedMatMul"] == 1); } +TEST_F(GraphTransformationTests, TransposeCastMatmulFusion) { + const std::vector model_uris = { + MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion0.onnx", // Test fusion from the right input (NOTE(review): transpose_matmul_gen.py wires Transpose->Cast into the first MatMul input for both models and only reorders nodes - confirm the right-input path is really covered) + MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion1.onnx" // Test fusion from the left input + }; + for (const auto& model_uri : model_uris) { + std::shared_ptr p_model; + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK()); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level1); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); + ASSERT_TRUE(ret.IsOK()); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Transpose"] == 0); + ASSERT_TRUE(op_to_count["MatMul"] == 0); + ASSERT_TRUE(op_to_count["Cast"] == 2); /* the Transpose is expected to be absorbed by the fusion while two Cast nodes remain */ + ASSERT_TRUE(op_to_count["com.microsoft.FusedMatMul"] == 1); + } +} + TEST_F(GraphTransformationTests, TransposeMatmulFusionOnTwoTranspose) { auto model_uri = MODEL_FOLDER "fusion/transpose_matmul_4d_fusion_2_transpose.onnx"; std::shared_ptr p_model; diff --git a/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion0.onnx 
b/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion0.onnx new file mode 100644 index 0000000000..8d5c73c82f Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion0.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion1.onnx b/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion1.onnx new file mode 100644 index 0000000000..1ff3a3eec6 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion1.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/transpose_matmul_gen.py b/onnxruntime/test/testdata/transform/fusion/transpose_matmul_gen.py index ea3f6e1c28..2f21d4917e 100644 --- a/onnxruntime/test/testdata/transform/fusion/transpose_matmul_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/transpose_matmul_gen.py @@ -126,3 +126,49 @@ def gen_with_preserved_transpose(model_path): gen_with_preserved_transpose( "transpose_matmul_2d_fusion_with_preserved_transpose.onnx") + + +def gen_transpose_fusion_with_cast(model_path): + nodes = [ + helper.make_node( + "Cast", + ["input_1"], + ["casted_input_1"], + to = 10 + ), + helper.make_node( + "Transpose", + ["input_0"], + ["transposed_input_0"], + perm = [0, 1, 3, 2]), + helper.make_node( + "Cast", + ["transposed_input_0"], + ["transposed_casted_input_0"], + to = 10), + helper.make_node( + "MatMul", + ["transposed_casted_input_0", "casted_input_1"], + ["output_0"]) + ] + + inputs = [ + helper.make_tensor_value_info( + "input_0", TensorProto.FLOAT, [3, 2, 'K', 'M']), + helper.make_tensor_value_info( + "input_1", TensorProto.FLOAT, [3, 2, 'K', 'N']) + ] + + outputs = [ + helper.make_tensor_value_info( + "output_0", TensorProto.FLOAT16, [3, 2, 'M', 'N']) + ] + + save(model_path + "0.onnx", nodes, inputs, outputs, []) + # Re-arrange only the node order (Transpose and its Cast first); the MatMul inputs are unchanged, so the transpose is on the left input of matmul in both saved models + nodes = nodes[1:3] + nodes[0:1] + 
nodes[3:] + save(model_path + "1.onnx", nodes, inputs, outputs, []) + + +gen_transpose_fusion_with_cast( + "transpose_cast_matmul_4d_fusion")