mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Enhance Transpose, Cast and MatMul fusion when Cast and/or Transpose feed multiple nodes. (#7021)
* Added new Transpose+Cast+MatMul => Cast+FusedMatMul test scenarios. * The Cast node may feed more than one node. * The Transpose node may feed multiple nodes and may still be fused with MatMul nodes.
This commit is contained in:
parent
1a1dd4843d
commit
8bc275e93f
9 changed files with 113 additions and 54 deletions
|
|
@ -86,11 +86,9 @@ static size_t UpdateConsumerCount(Graph& graph, NodeArg* target, std::unordered_
|
|||
}
|
||||
}
|
||||
|
||||
/* GetTransposeNodeFromCast: Interchange Cast and Transpose nodes in the graph and return Transpose node if possible
|
||||
* Requirements to interchange Cast and Transpose nodes changing the order of the operations.
|
||||
* 1. Both Cast and Transpose are one-output nodes (assuming both have one-input only)
|
||||
* 2. Transpose only feeds the Cast node (and no other node)
|
||||
* 3. Cast only feeds the MatMul node (and no other node)
|
||||
/* ReorderCastAndTranspose:
|
||||
* Interchange Cast and Transpose nodes in the graph and return the new Transpose node if possible, otherwise nullptr.
|
||||
*
|
||||
*
|
||||
* Transform the following pattern
|
||||
* |
|
||||
|
|
@ -118,11 +116,13 @@ static size_t UpdateConsumerCount(Graph& graph, NodeArg* target, std::unordered_
|
|||
* |
|
||||
* V
|
||||
*/
|
||||
static Node* GetTransposeNodeFromCast(Graph& graph, Node* cast) {
|
||||
static Node* ReorderCastAndTranspose(Graph& graph, Node* cast,
|
||||
std::unordered_map<NodeArg*, size_t>& consumer_count,
|
||||
std::deque<onnxruntime::NodeIndex>& removed_nodes) {
|
||||
|
||||
ORT_ENFORCE(cast != nullptr);
|
||||
auto transpose = GetTransposeNodeFromOutput(graph, *cast->MutableInputDefs()[0]);
|
||||
if (transpose == nullptr || cast->GetOutputEdgesCount() != 1 || transpose->GetOutputEdgesCount() != 1) {
|
||||
if (transpose == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
NodeArg* cast_output = cast->MutableOutputDefs()[0];
|
||||
|
|
@ -159,10 +159,12 @@ static Node* GetTransposeNodeFromCast(Graph& graph, Node* cast) {
|
|||
&transpose->GetAttributes(),
|
||||
transpose->Domain());
|
||||
|
||||
size_t consumers = UpdateConsumerCount(graph, transpose->MutableOutputDefs()[0], consumer_count);
|
||||
graph_utils::RemoveNodeOutputEdges(graph, *cast);
|
||||
graph_utils::RemoveNodeOutputEdges(graph, *transpose);
|
||||
graph.RemoveNode(cast->Index());
|
||||
graph.RemoveNode(transpose->Index());
|
||||
if (consumers == 0) {
|
||||
removed_nodes.push_front(transpose->Index());
|
||||
}
|
||||
return &new_transpose;
|
||||
}
|
||||
|
||||
|
|
@ -280,16 +282,17 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
|
|||
NodeArg* right_input = node.MutableInputDefs()[1];
|
||||
auto right = GetTransposeNodeFromOutput(graph, *right_input);
|
||||
|
||||
if (!left && !right) {
|
||||
if (!left) {
|
||||
Node* left_node = graph.GetMutableProducerNode(left_input->Name());
|
||||
if (left_node && left_node->OpType() == "Cast") {
|
||||
left = GetTransposeNodeFromCast(graph, left_node);
|
||||
left = ReorderCastAndTranspose(graph, left_node, consumer_count, removed_nodes);
|
||||
}
|
||||
if (!left) {
|
||||
Node* right_node = graph.GetMutableProducerNode(right_input->Name());
|
||||
if (right_node && right_node->OpType() == "Cast") {
|
||||
right = GetTransposeNodeFromCast(graph, right_node);
|
||||
}
|
||||
}
|
||||
|
||||
if (!right) {
|
||||
Node* right_node = graph.GetMutableProducerNode(right_input->Name());
|
||||
if (right_node && right_node->OpType() == "Cast") {
|
||||
right = ReorderCastAndTranspose(graph, right_node, consumer_count, removed_nodes);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -847,12 +847,19 @@ TEST_F(GraphTransformationTests, TransposeMatmulFusion) {
|
|||
TEST_F(GraphTransformationTests, TransposeCastMatmulFusion) {
|
||||
const std::vector<PathString> model_uris = {
|
||||
MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion0.onnx", // Test fusion from the right input
|
||||
MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion1.onnx" // Test fusion from the left input
|
||||
MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion1.onnx", // Test fusion from the left input
|
||||
MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion2.onnx", // Test fusion both from the left and right inputs
|
||||
MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion3.onnx", // Cast nodes feed multiple MatMul nodes.
|
||||
MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion4.onnx", // Cast nodes feed one MatMul node and
|
||||
// the Transpose nodes feed another MatMul node.
|
||||
MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion5.onnx" // One Cast node and one Transpose node feed each
|
||||
// MatMul node.
|
||||
};
|
||||
for (const auto& model_uri : model_uris) {
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK());
|
||||
Graph& graph = p_model->MainGraph();
|
||||
std::map<std::string, int> orig_op_to_count = CountOpsInGraph(graph); // Original op count
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<MatmulTransposeFusion>(), TransformerLevel::Level1);
|
||||
|
|
@ -861,8 +868,8 @@ TEST_F(GraphTransformationTests, TransposeCastMatmulFusion) {
|
|||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Transpose"] == 0);
|
||||
ASSERT_TRUE(op_to_count["MatMul"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Cast"] == 2);
|
||||
ASSERT_TRUE(op_to_count["com.microsoft.FusedMatMul"] == 1);
|
||||
ASSERT_TRUE(op_to_count["Cast"] == orig_op_to_count["Cast"]);
|
||||
ASSERT_TRUE(op_to_count["com.microsoft.FusedMatMul"] == orig_op_to_count["MatMul"]);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion2.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion2.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion3.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion3.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion4.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion4.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion5.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion5.onnx
vendored
Normal file
Binary file not shown.
|
|
@ -129,46 +129,95 @@ gen_with_preserved_transpose(
|
|||
|
||||
|
||||
def gen_transpose_fusion_with_cast(model_path):
|
||||
nodes = [
|
||||
helper.make_node(
|
||||
"Cast",
|
||||
["input_1"],
|
||||
["casted_input_1"],
|
||||
to = 10
|
||||
),
|
||||
helper.make_node(
|
||||
"Transpose",
|
||||
["input_0"],
|
||||
["transposed_input_0"],
|
||||
perm = [0, 1, 3, 2]),
|
||||
helper.make_node(
|
||||
"Cast",
|
||||
["transposed_input_0"],
|
||||
["transposed_casted_input_0"],
|
||||
to = 10),
|
||||
helper.make_node(
|
||||
"MatMul",
|
||||
["transposed_casted_input_0", "casted_input_1"],
|
||||
["output_0"])
|
||||
]
|
||||
cast_1 = helper.make_node(
|
||||
"Cast",
|
||||
["input_1"],
|
||||
["casted_input_1"],
|
||||
"Cast_1",
|
||||
to = TensorProto.FLOAT16)
|
||||
transpose_0 = helper.make_node(
|
||||
"Transpose",
|
||||
["input_0"],
|
||||
["transposed_input_0"],
|
||||
"Transpose_0",
|
||||
perm = [0, 1, 3, 2])
|
||||
cast_0 = helper.make_node(
|
||||
"Cast",
|
||||
["transposed_input_0"],
|
||||
["transposed_casted_input_0"],
|
||||
"Cast_0",
|
||||
to = TensorProto.FLOAT16)
|
||||
matmul_0 = helper.make_node(
|
||||
"MatMul",
|
||||
["transposed_casted_input_0", "casted_input_1"],
|
||||
["output_0"],
|
||||
"MatMul_0")
|
||||
|
||||
inputs = [
|
||||
helper.make_tensor_value_info(
|
||||
"input_0", TensorProto.FLOAT, [3, 2, 'K', 'M']),
|
||||
helper.make_tensor_value_info(
|
||||
"input_1", TensorProto.FLOAT, [3, 2, 'K', 'N'])
|
||||
]
|
||||
|
||||
outputs = [
|
||||
helper.make_tensor_value_info(
|
||||
"output_0", TensorProto.FLOAT16, [3, 2, 'M', 'N'])
|
||||
]
|
||||
nodes = [transpose_0, cast_0, cast_1, matmul_0]
|
||||
|
||||
input_0 = helper.make_tensor_value_info("input_0", TensorProto.FLOAT, [3, 2, 'N', 'N'])
|
||||
input_1 = helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [3, 2, 'N', 'N'])
|
||||
inputs = [input_0, input_1]
|
||||
output_0 = helper.make_tensor_value_info("output_0", TensorProto.FLOAT16, [3, 2, 'N', 'N'])
|
||||
outputs = [output_0]
|
||||
# Testcase0: First input of MatMul is transposed
|
||||
save(model_path + "0.onnx", nodes, inputs, outputs, [])
|
||||
# Re-arrange nodes so that the transpose is on left input of matmul
|
||||
nodes = nodes[1:3] + nodes[0:1] + nodes[3:]
|
||||
|
||||
# Testcase1: Re-arrange nodes so that the transpose is on second input of matmul
|
||||
transpose_1 = helper.make_node(
|
||||
"Transpose",
|
||||
["input_1"],
|
||||
["transposed_input_1"],
|
||||
"Transpose_1",
|
||||
perm = [0, 1, 3, 2])
|
||||
cast_1.input[0] = "transposed_input_1"
|
||||
cast_1.output[0] = "transposed_casted_input_1"
|
||||
cast_0.input[0] = "input_0"
|
||||
cast_0.output[0] = "casted_input_0"
|
||||
matmul_0.input[0] = cast_0.output[0]
|
||||
matmul_0.input[1] = cast_1.output[0]
|
||||
nodes = [cast_0, transpose_1, cast_1, matmul_0]
|
||||
save(model_path + "1.onnx", nodes, inputs, outputs, [])
|
||||
|
||||
# Testcase2: Create an example with two Cast-ed Transpose-ed inputs feeding a MatMul
|
||||
cast_0.input[0] = "transposed_input_0"
|
||||
cast_0.output[0] = "transposed_casted_input_0"
|
||||
matmul_0.input[0] = cast_0.output[0]
|
||||
nodes = [transpose_0, cast_0, transpose_1, cast_1, matmul_0]
|
||||
save(model_path + "2.onnx", nodes, inputs, outputs, [])
|
||||
|
||||
# Testcase3: Create a second MatMul node using the outputs from the same Cast nodes as before
|
||||
# with each Cast node feeding more than one node.
|
||||
nodes.append(helper.make_node(
|
||||
"MatMul",
|
||||
["transposed_casted_input_0", "transposed_casted_input_1"],
|
||||
["output_1"],
|
||||
"MatMul_1"))
|
||||
output_1 = helper.make_tensor_value_info("output_1", TensorProto.FLOAT16, [3, 2, 'N', 'N'])
|
||||
outputs.append(output_1)
|
||||
save(model_path + "3.onnx", nodes, inputs, outputs, [])
|
||||
|
||||
# Testcase4: The second MatMul uses transposed inputs without cast.
|
||||
nodes.pop()
|
||||
outputs.pop()
|
||||
matmul_1 = helper.make_node(
|
||||
"MatMul",
|
||||
["transposed_input_0", "transposed_input_1"],
|
||||
["output_1"],
|
||||
"MatMul_1")
|
||||
nodes.append(matmul_1)
|
||||
|
||||
outputs.append(helper.make_tensor_value_info(
|
||||
"output_1", TensorProto.FLOAT, [3, 2, 'N', 'N']))
|
||||
save(model_path + "4.onnx", nodes, inputs, outputs, [])
|
||||
|
||||
# Testcase5: Each MatMul uses outputs from a Cast and a Transpose
|
||||
input_0.type.tensor_type.elem_type = TensorProto.FLOAT16
|
||||
cast_0.attribute[0].i = TensorProto.FLOAT
|
||||
matmul_0.input[0] = "transposed_input_0"
|
||||
matmul_1.input[0] = "transposed_casted_input_0"
|
||||
output_1.type.tensor_type.elem_type = TensorProto.FLOAT
|
||||
save(model_path + "5.onnx", nodes, inputs, outputs, [])
|
||||
|
||||
gen_transpose_fusion_with_cast(
|
||||
"transpose_cast_matmul_4d_fusion")
|
||||
|
|
|
|||
Loading…
Reference in a new issue