Interchange Cast and Transpose operations to facilitate Transpose-MatMul fusion (#6924)

* Added support to interchange Cast and Transpose operations.

* Added ONNX models for the Transpose-Cast-MatMul fusion testcases.

* Added python code to generate the ONNX models required for testing Transpose+Cast+Matmul fusion to Cast+FusedMatMul.

* Added diagram of the Transpose+MatMul fusion documentation
This commit is contained in:
satyajandhyala 2021-03-09 08:54:56 -08:00 committed by GitHub
parent 91c6a330c0
commit 48eebed869
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 252 additions and 0 deletions

View file

@ -86,6 +86,177 @@ static size_t UpdateConsumerCount(Graph& graph, NodeArg* target, std::unordered_
}
}
/* GetTransposeNodeFromCast: Interchange the Cast and Transpose nodes in the graph and return the new Transpose node, if possible.
* Requirements for interchanging the Cast and Transpose nodes (i.e. changing the order of the two operations):
* 1. Both Cast and Transpose are one-output nodes (assuming both have one input only)
* 2. Transpose only feeds the Cast node (and no other node)
* 3. Cast only feeds the MatMul node (and no other node)
*
* Transform the following pattern
* |
* _____|______
* |Transpose |
* |__________|
* |
* |
* _____V______
* | Cast |
* |__________|
* |
* V
*
* to
* |
* _____|______
* | Cast |
* |__________|
* |
* |
* _____V______
* | Transpose|
* |__________|
* |
* V
*/
static Node* GetTransposeNodeFromCast(Graph& graph, Node* cast) {
  // Rewrites `Transpose -> Cast` into `Cast -> Transpose`.
  //   graph: graph that owns `cast`; it is mutated only when the rewrite is applied.
  //   cast:  the Cast node feeding the consumer; must not be null.
  // Returns the newly created Transpose node, or nullptr (graph untouched) when the
  // pattern does not match or the required type information is missing.
  ORT_ENFORCE(cast != nullptr);

  auto transpose = GetTransposeNodeFromOutput(graph, *cast->MutableInputDefs()[0]);
  // Both nodes must have exactly one outgoing edge, otherwise reordering them would change
  // what some other consumer sees.
  // NOTE(review): GetOutputEdgesCount() is an edge count; presumably a value that is also a
  // graph output is not reflected in it - confirm the Transpose output cannot be a graph output.
  if (transpose == nullptr || cast->GetOutputEdgesCount() != 1 || transpose->GetOutputEdgesCount() != 1) {
    return nullptr;
  }

  NodeArg* cast_output = cast->MutableOutputDefs()[0];
  NodeArg* transpose_input = transpose->MutableInputDefs()[0];

  // The new intermediate NodeArg is derived from these two type protos. Bail out before
  // touching the graph if either one is unavailable instead of dereferencing a null pointer.
  const auto* transpose_input_type = transpose_input->TypeAsProto();
  const auto* cast_output_type = cast_output->TypeAsProto();
  if (transpose_input_type == nullptr || cast_output_type == nullptr) {
    return nullptr;
  }

  // Create a new NodeArg to feed the output from the new Cast to the new Transpose.
  // The shape of the new NodeArg is the same as the original input to Transpose but its
  // element type must match that of the output from the original Cast.
  auto new_cast_output_type_proto = *transpose_input_type;
  const ONNX_NAMESPACE::TensorProto_DataType element_type =
      static_cast<ONNX_NAMESPACE::TensorProto_DataType>(cast_output_type->tensor_type().elem_type());
  new_cast_output_type_proto.mutable_tensor_type()->set_elem_type(element_type);
  auto& new_cast_output = graph.GetOrCreateNodeArg(cast_output->Name() + "_transformed", &new_cast_output_type_proto);

  const std::vector<NodeArg*> new_cast_input_defs{transpose_input};
  const std::vector<NodeArg*> new_cast_output_defs{&new_cast_output};
  const std::vector<NodeArg*> new_transpose_input_defs{&new_cast_output};
  // The new Transpose reuses the original Cast output NodeArg, so the downstream consumer
  // keeps reading the same value name.
  const std::vector<NodeArg*> new_transpose_output_defs{cast_output};

  Node& new_cast = graph.AddNode(graph.GenerateNodeName(cast->Name() + "_transformed"),
                                 cast->OpType(),
                                 "Created a new Cast node to interchange Cast and Transpose nodes",
                                 new_cast_input_defs,
                                 new_cast_output_defs,
                                 &cast->GetAttributes(),
                                 cast->Domain());
  Node& new_transpose = graph.AddNode(graph.GenerateNodeName(transpose->Name() + "_transformed"),
                                      transpose->OpType(),
                                      "Created a new Transpose node to interchange Cast and Transpose nodes",
                                      new_transpose_input_defs,
                                      new_transpose_output_defs,
                                      &transpose->GetAttributes(),
                                      transpose->Domain());

  // Keep the replacement nodes on the same execution provider as the nodes they replace.
  new_cast.SetExecutionProviderType(cast->GetExecutionProviderType());
  new_transpose.SetExecutionProviderType(transpose->GetExecutionProviderType());

  // Detach and delete the original nodes now that their replacements exist.
  graph_utils::RemoveNodeOutputEdges(graph, *cast);
  graph_utils::RemoveNodeOutputEdges(graph, *transpose);
  graph.RemoveNode(cast->Index());
  graph.RemoveNode(transpose->Index());
  return &new_transpose;
}
/*********************************************************************************************
Case I: The following is a scenario where the Transpose output feeds MatMul. The Transpose input can be either on the left or right.
The input graph
__________ __________
| input0 | | input1 |
|________| |________|
| |
| |
| |
_____V______ |
|Transpose | |
|__________| |
| |
| |
|______________ _____________|
| |
| |
| |
__V___________V__
| MatMul |
|_______________|
|
V
is transformed to the following
__________ __________
| input0 | | input1 |
|________| |________|
| |
| |
| |
|_____________ _____________|
| |
| |
| |
__V___________V__
| FusedMatMul |
|_______________|
|
V
Case II: The output of Transpose feeds Cast, and the output from the Cast feeds MatMul
The input graph
__________ __________
| input0 | | input1 |
|________| |________|
| |
| |
_____V______ |
|Transpose | |
|__________| |
| |
| |
_____V______ |
| Cast | |
|__________| |
| |
|______________ _____________|
| |
| |
| |
__V___________V__
| MatMul |
|_______________|
|
V
is transformed to the following
__________ __________
| input0 | | input1 |
|________| |________|
| |
| |
| |
_____V______ |
| Cast | |
|__________| |
| |
|______________ _____________|
| |
| |
| |
__V___________V__
| FusedMatMul |
|_______________|
|
V
********************************************************************************************************************/
Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
@ -109,6 +280,19 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
NodeArg* right_input = node.MutableInputDefs()[1];
auto right = GetTransposeNodeFromOutput(graph, *right_input);
// Neither `left` nor `right` holds a Transpose node at this point. Fall back to the
// Transpose -> Cast -> MatMul pattern: if an input is produced by a Cast, try to swap that
// Cast with the Transpose feeding it so that the Transpose ends up adjacent to this node.
if (!left && !right) {
// Try the left input first.
Node* left_node = graph.GetMutableProducerNode(left_input->Name());
if (left_node && left_node->OpType() == "Cast") {
// Returns nullptr when the Cast/Transpose pair cannot be interchanged.
left = GetTransposeNodeFromCast(graph, left_node);
}
// The right input is only examined when the left input did not yield a Transpose, so at
// most one Cast/Transpose pair is interchanged in this fallback.
if (!left) {
Node* right_node = graph.GetMutableProducerNode(right_input->Name());
if (right_node && right_node->OpType() == "Cast") {
right = GetTransposeNodeFromCast(graph, right_node);
}
}
}
if (!left && !right) {
continue;
}

View file

@ -844,6 +844,28 @@ TEST_F(GraphTransformationTests, TransposeMatmulFusion) {
ASSERT_TRUE(op_to_count["com.microsoft.FusedMatMul"] == 1);
}
// Verifies that Transpose -> Cast -> MatMul is rewritten to Cast -> FusedMatMul.
TEST_F(GraphTransformationTests, TransposeCastMatmulFusion) {
  // Per the generator script, both models transpose the LEFT MatMul input; they differ only in
  // the order in which the nodes are listed in the model.
  // NOTE(review): neither model exercises a Transpose+Cast chain on the right MatMul input.
  const std::vector<PathString> model_uris = {
      MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion0.onnx",  // right-input Cast listed first
      MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion1.onnx"   // Transpose/Cast chain listed first
  };
  for (const auto& model_uri : model_uris) {
    std::shared_ptr<Model> p_model;
    const auto load_status = Model::Load(model_uri, p_model, nullptr, *logger_);
    ASSERT_TRUE(load_status.IsOK()) << load_status.ErrorMessage();
    Graph& graph = p_model->MainGraph();

    onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
    graph_transformation_mgr.Register(onnxruntime::make_unique<MatmulTransposeFusion>(), TransformerLevel::Level1);
    const auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
    ASSERT_TRUE(ret.IsOK()) << ret.ErrorMessage();

    // ASSERT_EQ reports the actual count on failure, unlike ASSERT_TRUE(a == b).
    std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
    ASSERT_EQ(op_to_count["Transpose"], 0);
    ASSERT_EQ(op_to_count["MatMul"], 0);
    // One Cast per MatMul input remains; only the Transpose is absorbed by the fusion.
    ASSERT_EQ(op_to_count["Cast"], 2);
    ASSERT_EQ(op_to_count["com.microsoft.FusedMatMul"], 1);
  }
}
TEST_F(GraphTransformationTests, TransposeMatmulFusionOnTwoTranspose) {
auto model_uri = MODEL_FOLDER "fusion/transpose_matmul_4d_fusion_2_transpose.onnx";
std::shared_ptr<Model> p_model;

View file

@ -126,3 +126,49 @@ def gen_with_preserved_transpose(model_path):
# Emit the model in which the Transpose node is preserved.
gen_with_preserved_transpose(
    "transpose_matmul_2d_fusion_with_preserved_transpose.onnx"
)
def gen_transpose_fusion_with_cast(model_path):
    """Generate two Transpose -> Cast -> MatMul test models.

    Both models contain the same four nodes: a Cast on ``input_1``, and a
    Transpose followed by a Cast on ``input_0``, feeding a MatMul whose left
    input is the transposed/cast ``input_0``. The two models differ only in
    the order in which the nodes are listed.

    Args:
        model_path: Path prefix of the output files; ``"0.onnx"`` and
            ``"1.onnx"`` are appended to it.
    """
    nodes = [
        # Right MatMul input: float -> float16.
        helper.make_node(
            "Cast",
            ["input_1"],
            ["casted_input_1"],
            to=TensorProto.FLOAT16),
        # Left MatMul input: swap the last two axes, then float -> float16.
        helper.make_node(
            "Transpose",
            ["input_0"],
            ["transposed_input_0"],
            perm=[0, 1, 3, 2]),
        helper.make_node(
            "Cast",
            ["transposed_input_0"],
            ["transposed_casted_input_0"],
            to=TensorProto.FLOAT16),
        helper.make_node(
            "MatMul",
            ["transposed_casted_input_0", "casted_input_1"],
            ["output_0"])
    ]
    inputs = [
        helper.make_tensor_value_info(
            "input_0", TensorProto.FLOAT, [3, 2, 'K', 'M']),
        helper.make_tensor_value_info(
            "input_1", TensorProto.FLOAT, [3, 2, 'K', 'N'])
    ]
    outputs = [
        helper.make_tensor_value_info(
            "output_0", TensorProto.FLOAT16, [3, 2, 'M', 'N'])
    ]
    # Model 0: the Cast for the right input is listed before the Transpose/Cast chain.
    save(model_path + "0.onnx", nodes, inputs, outputs, [])
    # Model 1: re-arrange the node list so the Transpose/Cast chain comes first.
    # Only the listing order changes; the Transpose still feeds the left MatMul input.
    nodes = nodes[1:3] + nodes[0:1] + nodes[3:]
    save(model_path + "1.onnx", nodes, inputs, outputs, [])
# Emit transpose_cast_matmul_4d_fusion0.onnx and transpose_cast_matmul_4d_fusion1.onnx.
gen_transpose_fusion_with_cast("transpose_cast_matmul_4d_fusion")