mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
Gemm/Transpose fusion - additional pattern coverage (#8941)
* gemm transpose fixes * enforce condition * add comments * rm redundant code Co-authored-by: Ethan Tao <ettao@OrtTrainingDev4.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
parent
eebcc20f10
commit
53eb79f9f6
5 changed files with 193 additions and 10 deletions
|
|
@ -26,17 +26,37 @@ Status GemmTransposeFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& m
|
|||
|
||||
// check if input A is a Transpose
|
||||
if (A_node_ptr != nullptr && A_node_ptr->OpType() == "Transpose") {
|
||||
Node& A_node = *graph.GetNode(A_node_ptr->Index());
|
||||
transA = !transA;
|
||||
nodes_to_remove.push_back(A_node);
|
||||
new_gemm_input_defs[0] = A_node.MutableInputDefs()[0];
|
||||
// make sure all consumers are gemm nodes to avoid possible double transpose
|
||||
std::vector<const Node*> gemm_nodes = graph_utils::FindChildrenByType(*A_node_ptr, "Gemm");
|
||||
if (gemm_nodes.size() == A_node_ptr->GetOutputEdgesCount()) {
|
||||
Node& A_node = *graph.GetNode(A_node_ptr->Index());
|
||||
transA = !transA;
|
||||
if (A_node.GetOutputEdgesCount() > 1) {
|
||||
// remove only the edge between the Transpose and Gemm nodes, the Transpose won't be removed
|
||||
// since it's still connected to other Gemm. When transformation for the last connected Gemm is
|
||||
// being processed, it would fall into the else {} below to remove the Transpose node
|
||||
int output_idx = graph_utils::GetNodeOutputIndexFromOutputName(A_node, gemm_node.MutableInputDefs()[0]->Name());
|
||||
graph.RemoveEdge(A_node.Index(), gemm_node.Index(), output_idx, 0);
|
||||
} else {
|
||||
nodes_to_remove.push_back(A_node);
|
||||
}
|
||||
new_gemm_input_defs[0] = A_node.MutableInputDefs()[0];
|
||||
}
|
||||
}
|
||||
// check if input B is a Transpose
|
||||
if (B_node_ptr != nullptr && B_node_ptr->OpType() == "Transpose") {
|
||||
Node& B_node = *graph.GetNode(B_node_ptr->Index());
|
||||
transB = !transB;
|
||||
nodes_to_remove.push_back(B_node);
|
||||
new_gemm_input_defs[1] = B_node.MutableInputDefs()[0];
|
||||
std::vector<const Node*> gemm_nodes = graph_utils::FindChildrenByType(*B_node_ptr, "Gemm");
|
||||
if (gemm_nodes.size() == B_node_ptr->GetOutputEdgesCount()) {
|
||||
Node& B_node = *graph.GetNode(B_node_ptr->Index());
|
||||
transB = !transB;
|
||||
if (B_node.GetOutputEdgesCount() > 1) {
|
||||
int output_idx = graph_utils::GetNodeOutputIndexFromOutputName(B_node, gemm_node.MutableInputDefs()[1]->Name());
|
||||
graph.RemoveEdge(B_node.Index(), gemm_node.Index(), output_idx, 1);
|
||||
} else {
|
||||
nodes_to_remove.push_back(B_node);
|
||||
}
|
||||
new_gemm_input_defs[1] = B_node.MutableInputDefs()[0];
|
||||
}
|
||||
}
|
||||
|
||||
nodes_to_remove.push_back(gemm_node);
|
||||
|
|
@ -82,11 +102,14 @@ bool GemmTransposeFusion::SatisfyCondition(const Graph& graph, const Node& node,
|
|||
// Fusion can be applied if there is a transpose at either of the inputs
|
||||
for (auto node_it = node.InputNodesBegin(); node_it != node.InputNodesEnd(); ++node_it) {
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(*node_it, "Transpose", {1, 13}) &&
|
||||
node_it->GetOutputEdgesCount() == 1 &&
|
||||
!graph.NodeProducesGraphOutput(*node_it) &&
|
||||
// Make sure the two nodes do not span execution providers.
|
||||
node_it->GetExecutionProviderType() == node.GetExecutionProviderType()) {
|
||||
return true;
|
||||
// acceptable if all consumer(s) are gemm node(s)
|
||||
std::vector<const Node*> gemm_nodes = graph_utils::FindChildrenByType(*node_it, "Gemm");
|
||||
if (gemm_nodes.size() == node_it->GetOutputEdgesCount()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1219,6 +1219,91 @@ TEST_F(GraphTransformationTests, GemmTransposeFusion2Inputs) {
|
|||
ASSERT_TRUE(new_input_defs[1]->Name() == "B");
|
||||
}
|
||||
|
||||
// (A')'B' = AB' where transpose has multiple consumers
|
||||
TEST_F(GraphTransformationTests, GemmTransposeFusion2OutputsFromTranspose) {
  // Model under test (see gemm_transpose_gen_2.py): Gemm(transA=1) reads "tp0" and "tp1", the
  // outputs of two Transpose nodes, and "tp0" is additionally consumed by an Identity node.
  auto model_uri = MODEL_FOLDER "fusion/gemm_transpose_2outputs_from_transpose.onnx";
  std::shared_ptr<Model> p_model;
  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
  Graph& graph = p_model->MainGraph();

  // Sanity-check the graph before the rewrite rule runs.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Transpose"], 2);
  ASSERT_EQ(op_to_count["Gemm"], 1);
  ASSERT_EQ(op_to_count["Identity"], 1);

  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
  rule_transformer_L1->Register(std::make_unique<GemmTransposeFusion>());
  graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

  // Only the Transpose on input B may be fused away; the one producing "tp0" has a non-Gemm
  // consumer (Identity) and therefore has to stay in the graph.
  op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Transpose"], 1);
  ASSERT_EQ(op_to_count["Gemm"], 1);
  ASSERT_EQ(op_to_count["Identity"], 1);

  auto gemm_node =
      std::find_if(
          graph.Nodes().cbegin(), graph.Nodes().cend(),
          [](const Node& node) { return node.Name() == "Gemm_transformed"; });

  // Fail the test cleanly instead of dereferencing an end iterator when the fused node is missing.
  ASSERT_TRUE(gemm_node != graph.Nodes().cend()) << "node 'Gemm_transformed' not found in graph";

  const auto& node = *gemm_node;
  ASSERT_EQ(node.OpType(), "Gemm");
  // transA is untouched (input A was not fused); transB flipped because TransposeB was absorbed.
  ASSERT_TRUE(static_cast<bool>(node.GetAttributes().at("transA").i()));
  ASSERT_TRUE(static_cast<bool>(node.GetAttributes().at("transB").i()));
  auto new_input_defs = node.InputDefs();
  ASSERT_EQ(new_input_defs[0]->Name(), "tp0");
  ASSERT_EQ(new_input_defs[1]->Name(), "B");
}
|
||||
|
||||
// (A')'B' = AB' and (B')'C = BC where transpose has multiple consumers
|
||||
TEST_F(GraphTransformationTests, GemmTransposeFusion2OutputsFromTransposeTo2Gemms) {
  // Model under test (see gemm_transpose_gen_2.py): Gemm1(transA=1) reads "tp0" and "tp1",
  // Gemm2(transA=1) reads "tp1" and "C", and an Identity node also reads "tp0".
  auto model_uri = MODEL_FOLDER "fusion/gemm_transpose_2outputs_from_transpose_to_2gemms.onnx";
  std::shared_ptr<Model> p_model;
  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
  Graph& graph = p_model->MainGraph();

  // Sanity-check the graph before the rewrite rule runs.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Transpose"], 2);
  ASSERT_EQ(op_to_count["Gemm"], 2);
  ASSERT_EQ(op_to_count["Identity"], 1);

  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
  rule_transformer_L1->Register(std::make_unique<GemmTransposeFusion>());
  graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

  // The Transpose producing "tp1" feeds only Gemm nodes and is fused into both of them; the one
  // producing "tp0" also feeds Identity and therefore survives.
  op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Transpose"], 1);
  ASSERT_EQ(op_to_count["Gemm"], 2);
  ASSERT_EQ(op_to_count["Identity"], 1);

  // Looks up a node by name; yields graph.Nodes().cend() when no such node exists.
  auto find_node_by_name = [&graph](const std::string& name) {
    return std::find_if(
        graph.Nodes().cbegin(), graph.Nodes().cend(),
        [&name](const Node& node) { return node.Name() == name; });
  };

  auto gemm1_node = find_node_by_name("Gemm1_transformed");
  // Fail the test cleanly instead of dereferencing an end iterator when the fused node is missing.
  ASSERT_TRUE(gemm1_node != graph.Nodes().cend()) << "node 'Gemm1_transformed' not found in graph";

  const auto& node1 = *gemm1_node;
  ASSERT_EQ(node1.OpType(), "Gemm");
  // Input A ("tp0") was left alone, input B absorbed its Transpose.
  ASSERT_TRUE(static_cast<bool>(node1.GetAttributes().at("transA").i()));
  ASSERT_TRUE(static_cast<bool>(node1.GetAttributes().at("transB").i()));
  auto new_input_defs1 = node1.InputDefs();
  ASSERT_EQ(new_input_defs1[0]->Name(), "tp0");
  ASSERT_EQ(new_input_defs1[1]->Name(), "B");

  auto gemm2_node = find_node_by_name("Gemm2_transformed");
  ASSERT_TRUE(gemm2_node != graph.Nodes().cend()) << "node 'Gemm2_transformed' not found in graph";

  const auto& node2 = *gemm2_node;
  ASSERT_EQ(node2.OpType(), "Gemm");
  // Absorbing the Transpose on input A cancels the original transA=1; input B is a plain graph input.
  ASSERT_FALSE(static_cast<bool>(node2.GetAttributes().at("transA").i()));
  ASSERT_FALSE(static_cast<bool>(node2.GetAttributes().at("transB").i()));
  auto new_input_defs2 = node2.InputDefs();
  ASSERT_EQ(new_input_defs2[0]->Name(), "B");
  ASSERT_EQ(new_input_defs2[1]->Name(), "C");
}
|
||||
|
||||
// (A'B)' = B'A
|
||||
TEST_F(GraphTransformationTests, GemmTransposeFusionOutput) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/gemm_transpose_output_transposed.onnx";
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/gemm_transpose_2outputs_from_transpose.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/gemm_transpose_2outputs_from_transpose.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/gemm_transpose_2outputs_from_transpose_to_2gemms.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/gemm_transpose_2outputs_from_transpose_to_2gemms.onnx
vendored
Normal file
Binary file not shown.
75
onnxruntime/test/testdata/transform/fusion/gemm_transpose_gen_2.py
vendored
Normal file
75
onnxruntime/test/testdata/transform/fusion/gemm_transpose_gen_2.py
vendored
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import TensorProto
|
||||
from onnx import OperatorSetIdProto
|
||||
|
||||
# Opset imports attached to every model generated by this script.

# Default ONNX operator set. An empty domain string (or an absent field) denotes the
# operator set defined by the ONNX specification itself.
onnxdomain = OperatorSetIdProto()
onnxdomain.domain = ""
onnxdomain.version = 12

# Microsoft contrib operator set.
msdomain = OperatorSetIdProto()
msdomain.domain = "com.microsoft"
msdomain.version = 1

opsets = [onnxdomain, msdomain]
|
||||
|
||||
|
||||
def save(model_path, nodes, inputs, outputs, initializers):
    """Build a model named "TransposeGemmTest" from the given parts and write it to ``model_path``.

    ``nodes``, ``inputs``, ``outputs`` and ``initializers`` are forwarded unchanged to
    ``helper.make_graph``; the module-level ``opsets`` list supplies the opset imports.
    """
    model = helper.make_model(
        helper.make_graph(nodes, "TransposeGemmTest", inputs, outputs, initializers),
        opset_imports=opsets,
        producer_name="onnxruntime-test",
    )
    onnx.save(model, model_path)
|
||||
|
||||
# (A')'B' = AB'
|
||||
def gemm_transpose_2outputs_from_transpose(model_path):
    """Write a model in which one Transpose output feeds both a Gemm and an Identity node.

    "tp0" (the transpose of A) is consumed by the Gemm and by Identity, while "tp1"
    (the transpose of B) is consumed by the Gemm only.
    """

    def float_info(name, shape):
        # All graph inputs and outputs in this model are float tensors.
        return helper.make_tensor_value_info(name, TensorProto.FLOAT, shape)

    transpose_a = helper.make_node("Transpose", ["A"], ["tp0"], "TransposeA")
    transpose_b = helper.make_node("Transpose", ["B"], ["tp1"], "TransposeB")
    gemm = helper.make_node("Gemm", ["tp0", "tp1"], ["output"], "Gemm", alpha=3.0, transA=1)
    identity = helper.make_node("Identity", ["tp0"], ["output2"], "IdentityAt")

    graph_inputs = [
        float_info("A", ['M', 'K']),
        float_info("B", ['N', 'K']),
    ]
    graph_outputs = [
        float_info("output", ['M', 'N']),
        float_info("output2", ['K', 'M']),
    ]

    save(model_path, [transpose_a, transpose_b, gemm, identity], graph_inputs, graph_outputs, [])
|
||||
|
||||
|
||||
# (A')'B' = AB' and (B')'C = BC
|
||||
def gemm_transpose_2outputs_from_transpose_to_2gemms(model_path):
    """Write a model in which one Transpose output is shared by two Gemm nodes.

    "tp1" (the transpose of B) is input B of Gemm1 and input A of Gemm2; "tp0"
    (the transpose of A) is consumed by Gemm1 and by an Identity node.
    """

    def float_info(name, shape):
        # All graph inputs and outputs in this model are float tensors.
        return helper.make_tensor_value_info(name, TensorProto.FLOAT, shape)

    transpose_a = helper.make_node("Transpose", ["A"], ["tp0"], "TransposeA")
    transpose_b = helper.make_node("Transpose", ["B"], ["tp1"], "TransposeB")
    gemm1 = helper.make_node("Gemm", ["tp0", "tp1"], ["output"], "Gemm1", alpha=3.0, transA=1)
    gemm2 = helper.make_node("Gemm", ["tp1", "C"], ["output3"], "Gemm2", alpha=3.0, transA=1)
    identity = helper.make_node("Identity", ["tp0"], ["output2"], "IdentityAt")

    graph_inputs = [
        float_info("A", ['M', 'K']),
        float_info("B", ['N', 'K']),
        float_info("C", ['K', 'L']),
    ]
    graph_outputs = [
        float_info("output", ['M', 'N']),
        float_info("output2", ['K', 'M']),
        float_info("output3", ['N', 'L']),
    ]

    save(model_path, [transpose_a, transpose_b, gemm1, gemm2, identity], graph_inputs, graph_outputs, [])
|
||||
|
||||
# Generate both test models; the relative file names are passed straight to onnx.save.
for build_model, file_name in (
    (gemm_transpose_2outputs_from_transpose,
     "gemm_transpose_2outputs_from_transpose.onnx"),
    (gemm_transpose_2outputs_from_transpose_to_2gemms,
     "gemm_transpose_2outputs_from_transpose_to_2gemms.onnx"),
):
    build_model(file_name)
|
||||
|
||||
Loading…
Reference in a new issue