Gemm/Transpose fusion - additional pattern coverage (#8941)

* gemm transpose fixes

* enforce condition

* add comments

* rm redundant code

Co-authored-by: Ethan Tao <ettao@OrtTrainingDev4.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
ytaous 2021-09-03 15:24:47 -07:00 committed by GitHub
parent eebcc20f10
commit 53eb79f9f6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 193 additions and 10 deletions

View file

@ -26,17 +26,37 @@ Status GemmTransposeFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& m
// check if input A is a Transpose
if (A_node_ptr != nullptr && A_node_ptr->OpType() == "Transpose") {
Node& A_node = *graph.GetNode(A_node_ptr->Index());
transA = !transA;
nodes_to_remove.push_back(A_node);
new_gemm_input_defs[0] = A_node.MutableInputDefs()[0];
// make sure all consumers are gemm nodes to avoid possible double transpose
std::vector<const Node*> gemm_nodes = graph_utils::FindChildrenByType(*A_node_ptr, "Gemm");
if (gemm_nodes.size() == A_node_ptr->GetOutputEdgesCount()) {
Node& A_node = *graph.GetNode(A_node_ptr->Index());
transA = !transA;
if (A_node.GetOutputEdgesCount() > 1) {
// Remove only the edge between this Transpose and the current Gemm; the Transpose itself is kept
// because it still feeds other Gemm nodes. Once the last remaining Gemm consumer is processed,
// only one output edge is left, so the else branch below removes the Transpose node.
int output_idx = graph_utils::GetNodeOutputIndexFromOutputName(A_node, gemm_node.MutableInputDefs()[0]->Name());
graph.RemoveEdge(A_node.Index(), gemm_node.Index(), output_idx, 0);
} else {
nodes_to_remove.push_back(A_node);
}
new_gemm_input_defs[0] = A_node.MutableInputDefs()[0];
}
}
// check if input B is a Transpose
if (B_node_ptr != nullptr && B_node_ptr->OpType() == "Transpose") {
Node& B_node = *graph.GetNode(B_node_ptr->Index());
transB = !transB;
nodes_to_remove.push_back(B_node);
new_gemm_input_defs[1] = B_node.MutableInputDefs()[0];
std::vector<const Node*> gemm_nodes = graph_utils::FindChildrenByType(*B_node_ptr, "Gemm");
if (gemm_nodes.size() == B_node_ptr->GetOutputEdgesCount()) {
Node& B_node = *graph.GetNode(B_node_ptr->Index());
transB = !transB;
if (B_node.GetOutputEdgesCount() > 1) {
int output_idx = graph_utils::GetNodeOutputIndexFromOutputName(B_node, gemm_node.MutableInputDefs()[1]->Name());
graph.RemoveEdge(B_node.Index(), gemm_node.Index(), output_idx, 1);
} else {
nodes_to_remove.push_back(B_node);
}
new_gemm_input_defs[1] = B_node.MutableInputDefs()[0];
}
}
nodes_to_remove.push_back(gemm_node);
@ -82,11 +102,14 @@ bool GemmTransposeFusion::SatisfyCondition(const Graph& graph, const Node& node,
// Fusion can be applied if there is a transpose at either of the inputs
for (auto node_it = node.InputNodesBegin(); node_it != node.InputNodesEnd(); ++node_it) {
if (graph_utils::IsSupportedOptypeVersionAndDomain(*node_it, "Transpose", {1, 13}) &&
node_it->GetOutputEdgesCount() == 1 &&
!graph.NodeProducesGraphOutput(*node_it) &&
// Make sure the two nodes do not span execution providers.
node_it->GetExecutionProviderType() == node.GetExecutionProviderType()) {
return true;
// acceptable if all consumer(s) are gemm node(s)
std::vector<const Node*> gemm_nodes = graph_utils::FindChildrenByType(*node_it, "Gemm");
if (gemm_nodes.size() == node_it->GetOutputEdgesCount()) {
return true;
}
}
}

View file

@ -1219,6 +1219,91 @@ TEST_F(GraphTransformationTests, GemmTransposeFusion2Inputs) {
ASSERT_TRUE(new_input_defs[1]->Name() == "B");
}
// (A')'B' = AB' where transpose has multiple consumers.
// Expected outcome checked below: only one of the two Transpose nodes is removed, the fused Gemm
// still reads "tp0" (a Transpose output) as input A with transA set, and reads "B" directly as
// input B with transB set.
TEST_F(GraphTransformationTests, GemmTransposeFusion2OutputsFromTranspose) {
  auto model_uri = MODEL_FOLDER "fusion/gemm_transpose_2outputs_from_transpose.onnx";
  std::shared_ptr<Model> p_model;
  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
  Graph& graph = p_model->MainGraph();

  // Sanity-check the model before the transformation.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Transpose"], 2);
  ASSERT_EQ(op_to_count["Gemm"], 1);
  ASSERT_EQ(op_to_count["Identity"], 1);

  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
  rule_transformer_L1->Register(std::make_unique<GemmTransposeFusion>());
  graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

  op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Transpose"], 1);
  ASSERT_EQ(op_to_count["Gemm"], 1);
  ASSERT_EQ(op_to_count["Identity"], 1);

  auto gemm_node =
      std::find_if(
          graph.Nodes().cbegin(), graph.Nodes().cend(),
          [](const Node& node) { return node.Name() == "Gemm_transformed"; });
  // Fail the test cleanly instead of dereferencing the end iterator when the fused node is missing.
  ASSERT_TRUE(gemm_node != graph.Nodes().cend());

  const auto& node = *gemm_node;
  ASSERT_TRUE(node.OpType() == "Gemm");
  ASSERT_TRUE(static_cast<bool>(node.GetAttributes().at("transA").i()));
  ASSERT_TRUE(static_cast<bool>(node.GetAttributes().at("transB").i()));

  const auto& new_input_defs = node.InputDefs();
  ASSERT_TRUE(new_input_defs[0]->Name() == "tp0");
  ASSERT_TRUE(new_input_defs[1]->Name() == "B");
}
// (A')'B' = AB' and (B')'C = BC where transpose has multiple consumers.
// Expected outcome checked below: one Transpose is removed and both Gemm nodes are rewritten.
// The first fused Gemm reads "tp0" and "B" with transA and transB set; the second reads "B" and "C"
// with neither flag set.
TEST_F(GraphTransformationTests, GemmTransposeFusion2OutputsFromTransposeTo2Gemms) {
  auto model_uri = MODEL_FOLDER "fusion/gemm_transpose_2outputs_from_transpose_to_2gemms.onnx";
  std::shared_ptr<Model> p_model;
  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
  Graph& graph = p_model->MainGraph();

  // Sanity-check the model before the transformation.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Transpose"], 2);
  ASSERT_EQ(op_to_count["Gemm"], 2);
  ASSERT_EQ(op_to_count["Identity"], 1);

  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
  rule_transformer_L1->Register(std::make_unique<GemmTransposeFusion>());
  graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

  op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Transpose"], 1);
  ASSERT_EQ(op_to_count["Gemm"], 2);
  ASSERT_EQ(op_to_count["Identity"], 1);

  auto gemm1_node =
      std::find_if(
          graph.Nodes().cbegin(), graph.Nodes().cend(),
          [](const Node& node) { return node.Name() == "Gemm1_transformed"; });
  // Fail the test cleanly instead of dereferencing the end iterator when the fused node is missing.
  ASSERT_TRUE(gemm1_node != graph.Nodes().cend());

  const auto& node1 = *gemm1_node;
  ASSERT_TRUE(node1.OpType() == "Gemm");
  ASSERT_TRUE(static_cast<bool>(node1.GetAttributes().at("transA").i()));
  ASSERT_TRUE(static_cast<bool>(node1.GetAttributes().at("transB").i()));

  const auto& new_input_defs1 = node1.InputDefs();
  ASSERT_TRUE(new_input_defs1[0]->Name() == "tp0");
  ASSERT_TRUE(new_input_defs1[1]->Name() == "B");

  auto gemm2_node =
      std::find_if(
          graph.Nodes().cbegin(), graph.Nodes().cend(),
          [](const Node& node) { return node.Name() == "Gemm2_transformed"; });
  // Same guard for the second fused Gemm.
  ASSERT_TRUE(gemm2_node != graph.Nodes().cend());

  const auto& node2 = *gemm2_node;
  ASSERT_TRUE(node2.OpType() == "Gemm");
  ASSERT_FALSE(static_cast<bool>(node2.GetAttributes().at("transA").i()));
  ASSERT_FALSE(static_cast<bool>(node2.GetAttributes().at("transB").i()));

  const auto& new_input_defs2 = node2.InputDefs();
  ASSERT_TRUE(new_input_defs2[0]->Name() == "B");
  ASSERT_TRUE(new_input_defs2[1]->Name() == "C");
}
// (A'B)' = B'A
TEST_F(GraphTransformationTests, GemmTransposeFusionOutput) {
auto model_uri = MODEL_FOLDER "fusion/gemm_transpose_output_transposed.onnx";

View file

@ -0,0 +1,75 @@
import onnx
from onnx import helper
from onnx import TensorProto
from onnx import OperatorSetIdProto
# Opset imports attached to every model written by save().
# The empty string ("") or absence of the domain field implies the operator set that is defined
# as part of the ONNX specification.
onnxdomain = OperatorSetIdProto(domain="", version=12)
msdomain = OperatorSetIdProto(domain="com.microsoft", version=1)
opsets = [onnxdomain, msdomain]
def save(model_path, nodes, inputs, outputs, initializers):
    """Build an ONNX model from the given graph pieces and write it to ``model_path``.

    The graph is named "TransposeGemmTest"; the model is created with the
    module-level ``opsets`` and the producer name "onnxruntime-test".
    """
    graph_def = helper.make_graph(nodes, "TransposeGemmTest", inputs, outputs, initializer=initializers)
    model_def = helper.make_model(graph_def, producer_name="onnxruntime-test", opset_imports=opsets)
    onnx.save(model_def, model_path)
# (A')'B' = AB'
def gemm_transpose_2outputs_from_transpose(model_path):
    """Write a model where one Transpose output feeds both a Gemm and an Identity.

    "tp0" (Transpose of A) is consumed by the Gemm and by the Identity node;
    "tp1" (Transpose of B) is consumed by the Gemm only.
    """
    def float_info(name, dims):
        # Every graph input/output in this model is a FLOAT tensor with symbolic dims.
        return helper.make_tensor_value_info(name, TensorProto.FLOAT, dims)

    transpose_a = helper.make_node("Transpose", ["A"], ["tp0"], "TransposeA")
    transpose_b = helper.make_node("Transpose", ["B"], ["tp1"], "TransposeB")
    gemm = helper.make_node("Gemm", ["tp0", "tp1"], ["output"], "Gemm", alpha=3.0, transA=1)
    identity = helper.make_node("Identity", ["tp0"], ["output2"], "IdentityAt")

    graph_inputs = [
        float_info("A", ["M", "K"]),
        float_info("B", ["N", "K"]),
    ]
    graph_outputs = [
        float_info("output", ["M", "N"]),
        float_info("output2", ["K", "M"]),
    ]

    save(model_path, [transpose_a, transpose_b, gemm, identity], graph_inputs, graph_outputs, [])
# (A')'B' = AB' and (B')'C = BC
def gemm_transpose_2outputs_from_transpose_to_2gemms(model_path):
    """Write a model where one Transpose output feeds two Gemm nodes.

    "tp0" (Transpose of A) is consumed by Gemm1 and by the Identity node;
    "tp1" (Transpose of B) is consumed by Gemm1 (as input B) and Gemm2 (as input A).
    """
    def float_info(name, dims):
        # Every graph input/output in this model is a FLOAT tensor with symbolic dims.
        return helper.make_tensor_value_info(name, TensorProto.FLOAT, dims)

    transpose_a = helper.make_node("Transpose", ["A"], ["tp0"], "TransposeA")
    transpose_b = helper.make_node("Transpose", ["B"], ["tp1"], "TransposeB")
    gemm1 = helper.make_node("Gemm", ["tp0", "tp1"], ["output"], "Gemm1", alpha=3.0, transA=1)
    gemm2 = helper.make_node("Gemm", ["tp1", "C"], ["output3"], "Gemm2", alpha=3.0, transA=1)
    identity = helper.make_node("Identity", ["tp0"], ["output2"], "IdentityAt")

    graph_inputs = [
        float_info("A", ["M", "K"]),
        float_info("B", ["N", "K"]),
        float_info("C", ["K", "L"]),
    ]
    graph_outputs = [
        float_info("output", ["M", "N"]),
        float_info("output2", ["K", "M"]),
        float_info("output3", ["N", "L"]),
    ]

    save(model_path, [transpose_a, transpose_b, gemm1, gemm2, identity], graph_inputs, graph_outputs, [])
# Generate both test models; the file names are relative paths.
for build_model, model_file in (
    (gemm_transpose_2outputs_from_transpose, "gemm_transpose_2outputs_from_transpose.onnx"),
    (gemm_transpose_2outputs_from_transpose_to_2gemms, "gemm_transpose_2outputs_from_transpose_to_2gemms.onnx"),
):
    build_model(model_file)