Enhance Transpose, Cast and MatMul fusion when Cast and/or Transpose feeds multiple nodes. (#7021)

* Added new Transpose+Cast+MatMul => Cast+FusedMatMul test scenarios.

* The Cast node may feed more than one node.

* Transpose node may feed multiple nodes and still may be fused with MatMul nodes.
This commit is contained in:
satyajandhyala 2021-03-18 11:41:58 -07:00 committed by GitHub
parent 1a1dd4843d
commit 8bc275e93f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 113 additions and 54 deletions

View file

@ -86,11 +86,9 @@ static size_t UpdateConsumerCount(Graph& graph, NodeArg* target, std::unordered_
}
}
/* GetTransposeNodeFromCast: Interchange Cast and Transpose nodes in the graph and return Transpose node if possible
* Requirements to interchange Cast and Transpose nodes changing the order of the operations.
* 1. Both Cast and Transpose are one-output nodes (assuming both have one-input only)
* 2. Transpose only feeds the Cast node (and no other node)
 * 3. Cast only feeds the MatMul node (and no other node)
/* ReorderCastAndTranspose:
* Interchange Cast and Transpose nodes in the graph and return the new Transpose node if possible else nullptr.
*
*
* Transform the following pattern
* |
@ -118,11 +116,13 @@ static size_t UpdateConsumerCount(Graph& graph, NodeArg* target, std::unordered_
* |
* V
*/
static Node* GetTransposeNodeFromCast(Graph& graph, Node* cast) {
static Node* ReorderCastAndTranspose(Graph& graph, Node* cast,
std::unordered_map<NodeArg*, size_t>& consumer_count,
std::deque<onnxruntime::NodeIndex>& removed_nodes) {
ORT_ENFORCE(cast != nullptr);
auto transpose = GetTransposeNodeFromOutput(graph, *cast->MutableInputDefs()[0]);
if (transpose == nullptr || cast->GetOutputEdgesCount() != 1 || transpose->GetOutputEdgesCount() != 1) {
if (transpose == nullptr) {
return nullptr;
}
NodeArg* cast_output = cast->MutableOutputDefs()[0];
@ -159,10 +159,12 @@ static Node* GetTransposeNodeFromCast(Graph& graph, Node* cast) {
&transpose->GetAttributes(),
transpose->Domain());
size_t consumers = UpdateConsumerCount(graph, transpose->MutableOutputDefs()[0], consumer_count);
graph_utils::RemoveNodeOutputEdges(graph, *cast);
graph_utils::RemoveNodeOutputEdges(graph, *transpose);
graph.RemoveNode(cast->Index());
graph.RemoveNode(transpose->Index());
if (consumers == 0) {
removed_nodes.push_front(transpose->Index());
}
return &new_transpose;
}
@ -280,16 +282,17 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
NodeArg* right_input = node.MutableInputDefs()[1];
auto right = GetTransposeNodeFromOutput(graph, *right_input);
if (!left && !right) {
if (!left) {
Node* left_node = graph.GetMutableProducerNode(left_input->Name());
if (left_node && left_node->OpType() == "Cast") {
left = GetTransposeNodeFromCast(graph, left_node);
left = ReorderCastAndTranspose(graph, left_node, consumer_count, removed_nodes);
}
if (!left) {
Node* right_node = graph.GetMutableProducerNode(right_input->Name());
if (right_node && right_node->OpType() == "Cast") {
right = GetTransposeNodeFromCast(graph, right_node);
}
}
if (!right) {
Node* right_node = graph.GetMutableProducerNode(right_input->Name());
if (right_node && right_node->OpType() == "Cast") {
right = ReorderCastAndTranspose(graph, right_node, consumer_count, removed_nodes);
}
}

View file

@ -847,12 +847,19 @@ TEST_F(GraphTransformationTests, TransposeMatmulFusion) {
TEST_F(GraphTransformationTests, TransposeCastMatmulFusion) {
const std::vector<PathString> model_uris = {
MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion0.onnx", // Test fusion from the right input
MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion1.onnx" // Test fusion from the left input
MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion1.onnx", // Test fusion from the left input
MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion2.onnx", // Test fusion both from the left and right inputs
MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion3.onnx", // Cast nodes feed multiple MatMul nodes.
MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion4.onnx", // Cast nodes feed one MatMul node and
// the Transpose nodes feed another MatMul node.
MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion5.onnx" // One Cast node and one Transpose node feed each
// MatMul nodes.
};
for (const auto& model_uri : model_uris) {
std::shared_ptr<Model> p_model;
ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK());
Graph& graph = p_model->MainGraph();
std::map<std::string, int> orig_op_to_count = CountOpsInGraph(graph); // Original op count
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<MatmulTransposeFusion>(), TransformerLevel::Level1);
@ -861,8 +868,8 @@ TEST_F(GraphTransformationTests, TransposeCastMatmulFusion) {
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Transpose"] == 0);
ASSERT_TRUE(op_to_count["MatMul"] == 0);
ASSERT_TRUE(op_to_count["Cast"] == 2);
ASSERT_TRUE(op_to_count["com.microsoft.FusedMatMul"] == 1);
ASSERT_TRUE(op_to_count["Cast"] == orig_op_to_count["Cast"]);
ASSERT_TRUE(op_to_count["com.microsoft.FusedMatMul"] == orig_op_to_count["MatMul"]);
}
}

View file

@ -129,46 +129,95 @@ gen_with_preserved_transpose(
def gen_transpose_fusion_with_cast(model_path):
nodes = [
helper.make_node(
"Cast",
["input_1"],
["casted_input_1"],
to = 10
),
helper.make_node(
"Transpose",
["input_0"],
["transposed_input_0"],
perm = [0, 1, 3, 2]),
helper.make_node(
"Cast",
["transposed_input_0"],
["transposed_casted_input_0"],
to = 10),
helper.make_node(
"MatMul",
["transposed_casted_input_0", "casted_input_1"],
["output_0"])
]
cast_1 = helper.make_node(
"Cast",
["input_1"],
["casted_input_1"],
"Cast_1",
to = TensorProto.FLOAT16)
transpose_0 = helper.make_node(
"Transpose",
["input_0"],
["transposed_input_0"],
"Transpose_0",
perm = [0, 1, 3, 2])
cast_0 = helper.make_node(
"Cast",
["transposed_input_0"],
["transposed_casted_input_0"],
"Cast_0",
to = TensorProto.FLOAT16)
matmul_0 = helper.make_node(
"MatMul",
["transposed_casted_input_0", "casted_input_1"],
["output_0"],
"MatMul_0")
inputs = [
helper.make_tensor_value_info(
"input_0", TensorProto.FLOAT, [3, 2, 'K', 'M']),
helper.make_tensor_value_info(
"input_1", TensorProto.FLOAT, [3, 2, 'K', 'N'])
]
outputs = [
helper.make_tensor_value_info(
"output_0", TensorProto.FLOAT16, [3, 2, 'M', 'N'])
]
nodes = [transpose_0, cast_0, cast_1, matmul_0]
input_0 = helper.make_tensor_value_info("input_0", TensorProto.FLOAT, [3, 2, 'N', 'N'])
input_1 = helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [3, 2, 'N', 'N'])
inputs = [input_0, input_1]
output_0 = helper.make_tensor_value_info("output_0", TensorProto.FLOAT16, [3, 2, 'N', 'N'])
outputs = [output_0]
# Testcase0: First input of MatMul is transposed
save(model_path + "0.onnx", nodes, inputs, outputs, [])
    # Re-arrange nodes so that the transpose is on left input of matmul
nodes = nodes[1:3] + nodes[0:1] + nodes[3:]
    # Testcase1: Re-arrange nodes so that the transpose is on second input of matmul
transpose_1 = helper.make_node(
"Transpose",
["input_1"],
["transposed_input_1"],
"Transpose_1",
perm = [0, 1, 3, 2])
cast_1.input[0] = "transposed_input_1"
cast_1.output[0] = "transposed_casted_input_1"
cast_0.input[0] = "input_0"
cast_0.output[0] = "casted_input_0"
matmul_0.input[0] = cast_0.output[0]
matmul_0.input[1] = cast_1.output[0]
nodes = [cast_0, transpose_1, cast_1, matmul_0]
save(model_path + "1.onnx", nodes, inputs, outputs, [])
# Testcase2: Create an example with two Cast-ed Transpose-ed inputs feeding a MatMul
cast_0.input[0] = "transposed_input_0"
cast_0.output[0] = "transposed_casted_input_0"
matmul_0.input[0] = cast_0.output[0]
nodes = [transpose_0, cast_0, transpose_1, cast_1, matmul_0]
save(model_path + "2.onnx", nodes, inputs, outputs, [])
# Testcase3: Create a second MatMul node using the outputs from the same Cast nodes as before
# with each Cast node feeding more than one node.
nodes.append(helper.make_node(
"MatMul",
["transposed_casted_input_0", "transposed_casted_input_1"],
["output_1"],
"MatMul_1"))
output_1 = helper.make_tensor_value_info("output_1", TensorProto.FLOAT16, [3, 2, 'N', 'N'])
outputs.append(output_1)
save(model_path + "3.onnx", nodes, inputs, outputs, [])
# Testcase4: The second MatMul uses transposed inputs without cast.
nodes.pop()
outputs.pop()
matmul_1 = helper.make_node(
"MatMul",
["transposed_input_0", "transposed_input_1"],
["output_1"],
"MatMul_1")
nodes.append(matmul_1)
outputs.append(helper.make_tensor_value_info(
"output_1", TensorProto.FLOAT, [3, 2, 'N', 'N']))
save(model_path + "4.onnx", nodes, inputs, outputs, [])
# Testcase5: Each MatMul uses outputs from a Cast and a Transpose
input_0.type.tensor_type.elem_type = TensorProto.FLOAT16
cast_0.attribute[0].i = TensorProto.FLOAT
matmul_0.input[0] = "transposed_input_0"
matmul_1.input[0] = "transposed_casted_input_0"
output_1.type.tensor_type.elem_type = TensorProto.FLOAT
save(model_path + "5.onnx", nodes, inputs, outputs, [])
gen_transpose_fusion_with_cast(
"transpose_cast_matmul_4d_fusion")