mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-26 03:00:54 +00:00
Interchange Cast and Transpose operations to facilitate Transpose-MatMul fusion (#6924)
* Added support to interchange Cast and Transpose operations. * Added ONNX models for the Transpose-Cast-MatMul fusion testcases. * Added python code to generate the ONNX models required for testing Transpose+Cast+Matmul fusion to Cast+FusedMatMul. * Added diagram of the Transpose+MatMul fusion documentation
This commit is contained in:
parent
91c6a330c0
commit
48eebed869
5 changed files with 252 additions and 0 deletions
|
|
@ -86,6 +86,177 @@ static size_t UpdateConsumerCount(Graph& graph, NodeArg* target, std::unordered_
|
|||
}
|
||||
}
|
||||
|
||||
/* GetTransposeNodeFromCast: Interchange Cast and Transpose nodes in the graph and return Transpose node if possible
|
||||
* Requirements to interchange Cast and Transpose nodes changing the order of the operations.
|
||||
* 1. Both Cast and Transpose are one-output nodes (assuming both have one-input only)
|
||||
* 2. Transpose only feeds the Cast node (and no other node)
|
||||
* 3. Cast only feeds the MalMul node (and no other node)
|
||||
*
|
||||
* Transform the following pattern
|
||||
* |
|
||||
* _____|______
|
||||
* |Transpose |
|
||||
* |__________|
|
||||
* |
|
||||
* |
|
||||
* _____V______
|
||||
* | Cast |
|
||||
* |__________|
|
||||
* |
|
||||
* V
|
||||
*
|
||||
* to
|
||||
* |
|
||||
* _____|______
|
||||
* | Cast |
|
||||
* |__________|
|
||||
* |
|
||||
* |
|
||||
* _____V______
|
||||
* | Transpose|
|
||||
* |__________|
|
||||
* |
|
||||
* V
|
||||
*/
|
||||
static Node* GetTransposeNodeFromCast(Graph& graph, Node* cast) {
|
||||
|
||||
ORT_ENFORCE(cast != nullptr);
|
||||
auto transpose = GetTransposeNodeFromOutput(graph, *cast->MutableInputDefs()[0]);
|
||||
if (transpose == nullptr || cast->GetOutputEdgesCount() != 1 || transpose->GetOutputEdgesCount() != 1) {
|
||||
return nullptr;
|
||||
}
|
||||
NodeArg* cast_output = cast->MutableOutputDefs()[0];
|
||||
NodeArg* transpose_input = transpose->MutableInputDefs()[0];
|
||||
|
||||
// Create a new NodeArg to feed the output from the new Cast to the new Transpose.
|
||||
// The shape of the new NodeArg is same as the original input to Transport but type
|
||||
// should match that of the output from the original Cast.
|
||||
|
||||
auto new_cast_output_type_proto = *transpose_input->TypeAsProto();
|
||||
const ONNX_NAMESPACE::TensorProto_DataType element_type =
|
||||
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(cast_output->TypeAsProto()->tensor_type().elem_type());
|
||||
new_cast_output_type_proto.mutable_tensor_type()->set_elem_type(element_type);
|
||||
auto& new_cast_output = graph.GetOrCreateNodeArg(cast_output->Name() + "_transformed", &new_cast_output_type_proto);
|
||||
|
||||
const std::vector<NodeArg*> new_cast_input_defs {transpose_input};
|
||||
const std::vector<NodeArg*> new_cast_output_defs {&new_cast_output};
|
||||
const std::vector<NodeArg*> new_transpose_input_defs = {&new_cast_output};
|
||||
const std::vector<NodeArg*> new_transpose_output_defs = {cast_output};
|
||||
|
||||
(void) graph.AddNode(graph.GenerateNodeName(cast->Name() + "_transformed"),
|
||||
cast->OpType(),
|
||||
"Created a new Cast node to interchange Cast and Transpose nodes",
|
||||
new_cast_input_defs,
|
||||
new_cast_output_defs,
|
||||
&cast->GetAttributes(),
|
||||
cast->Domain());
|
||||
|
||||
Node& new_transpose = graph.AddNode(graph.GenerateNodeName(transpose->Name() + "_transformed"),
|
||||
transpose->OpType(),
|
||||
"Created a new Transpose node to interchange Cast and Transpose nodes",
|
||||
new_transpose_input_defs,
|
||||
new_transpose_output_defs,
|
||||
&transpose->GetAttributes(),
|
||||
transpose->Domain());
|
||||
|
||||
graph_utils::RemoveNodeOutputEdges(graph, *cast);
|
||||
graph_utils::RemoveNodeOutputEdges(graph, *transpose);
|
||||
graph.RemoveNode(cast->Index());
|
||||
graph.RemoveNode(transpose->Index());
|
||||
return &new_transpose;
|
||||
}
|
||||
|
||||
/*********************************************************************************************
|
||||
|
||||
Case I: The followin is a scenario where Transpose output feeds MatMul. The Transpose input can be either on the left or right.
|
||||
The input graph
|
||||
__________ __________
|
||||
| input0 | | input1 |
|
||||
|________| |________|
|
||||
| |
|
||||
| |
|
||||
| |
|
||||
_____V______ |
|
||||
|Transpose | |
|
||||
|__________| |
|
||||
| |
|
||||
| |
|
||||
|______________ _____________|
|
||||
| |
|
||||
| |
|
||||
| |
|
||||
__V___________V__
|
||||
| MatMul |
|
||||
|_______________|
|
||||
|
|
||||
V
|
||||
is transformed to the following
|
||||
|
||||
__________ __________
|
||||
| input0 | | input1 |
|
||||
|________| |________|
|
||||
| |
|
||||
| |
|
||||
| |
|
||||
|_____________ _____________|
|
||||
| |
|
||||
| |
|
||||
| |
|
||||
__V___________V__
|
||||
| FusedMatMul |
|
||||
|_______________|
|
||||
|
|
||||
V
|
||||
|
||||
Case II: The output of Tanspose feeds Cast and the output from the Cast feeds MatMul
|
||||
The input graph
|
||||
__________ __________
|
||||
| input0 | | input1 |
|
||||
|________| |________|
|
||||
| |
|
||||
| |
|
||||
_____V______ |
|
||||
|Transpose | |
|
||||
|__________| |
|
||||
| |
|
||||
| |
|
||||
_____V______ |
|
||||
| Cast | |
|
||||
|__________| |
|
||||
| |
|
||||
|______________ _____________|
|
||||
| |
|
||||
| |
|
||||
| |
|
||||
__V___________V__
|
||||
| MatMul |
|
||||
|_______________|
|
||||
|
|
||||
V
|
||||
is transformed to the following
|
||||
|
||||
__________ __________
|
||||
| input0 | | input1 |
|
||||
|________| |________|
|
||||
| |
|
||||
| |
|
||||
| |
|
||||
_____V______ |
|
||||
| Cast | |
|
||||
|__________| |
|
||||
| |
|
||||
|______________ _____________|
|
||||
| |
|
||||
| |
|
||||
| |
|
||||
__V___________V__
|
||||
| FusedMatMul |
|
||||
|_______________|
|
||||
|
|
||||
V
|
||||
|
||||
********************************************************************************************************************/
|
||||
|
||||
Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
|
@ -109,6 +280,19 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
|
|||
NodeArg* right_input = node.MutableInputDefs()[1];
|
||||
auto right = GetTransposeNodeFromOutput(graph, *right_input);
|
||||
|
||||
if (!left && !right) {
|
||||
Node* left_node = graph.GetMutableProducerNode(left_input->Name());
|
||||
if (left_node && left_node->OpType() == "Cast") {
|
||||
left = GetTransposeNodeFromCast(graph, left_node);
|
||||
}
|
||||
if (!left) {
|
||||
Node* right_node = graph.GetMutableProducerNode(right_input->Name());
|
||||
if (right_node && right_node->OpType() == "Cast") {
|
||||
right = GetTransposeNodeFromCast(graph, right_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!left && !right) {
|
||||
continue;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -844,6 +844,28 @@ TEST_F(GraphTransformationTests, TransposeMatmulFusion) {
|
|||
ASSERT_TRUE(op_to_count["com.microsoft.FusedMatMul"] == 1);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, TransposeCastMatmulFusion) {
|
||||
const std::vector<PathString> model_uris = {
|
||||
MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion0.onnx", // Test fusion from the right input
|
||||
MODEL_FOLDER "fusion/transpose_cast_matmul_4d_fusion1.onnx" // Test fusion from the left input
|
||||
};
|
||||
for (const auto& model_uri : model_uris) {
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK());
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<MatmulTransposeFusion>(), TransformerLevel::Level1);
|
||||
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Transpose"] == 0);
|
||||
ASSERT_TRUE(op_to_count["MatMul"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Cast"] == 2);
|
||||
ASSERT_TRUE(op_to_count["com.microsoft.FusedMatMul"] == 1);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, TransposeMatmulFusionOnTwoTranspose) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/transpose_matmul_4d_fusion_2_transpose.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion0.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion0.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion1.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/transpose_cast_matmul_4d_fusion1.onnx
vendored
Normal file
Binary file not shown.
|
|
@ -126,3 +126,49 @@ def gen_with_preserved_transpose(model_path):
|
|||
|
||||
gen_with_preserved_transpose(
|
||||
"transpose_matmul_2d_fusion_with_preserved_transpose.onnx")
|
||||
|
||||
|
||||
def gen_transpose_fusion_with_cast(model_path):
|
||||
nodes = [
|
||||
helper.make_node(
|
||||
"Cast",
|
||||
["input_1"],
|
||||
["casted_input_1"],
|
||||
to = 10
|
||||
),
|
||||
helper.make_node(
|
||||
"Transpose",
|
||||
["input_0"],
|
||||
["transposed_input_0"],
|
||||
perm = [0, 1, 3, 2]),
|
||||
helper.make_node(
|
||||
"Cast",
|
||||
["transposed_input_0"],
|
||||
["transposed_casted_input_0"],
|
||||
to = 10),
|
||||
helper.make_node(
|
||||
"MatMul",
|
||||
["transposed_casted_input_0", "casted_input_1"],
|
||||
["output_0"])
|
||||
]
|
||||
|
||||
inputs = [
|
||||
helper.make_tensor_value_info(
|
||||
"input_0", TensorProto.FLOAT, [3, 2, 'K', 'M']),
|
||||
helper.make_tensor_value_info(
|
||||
"input_1", TensorProto.FLOAT, [3, 2, 'K', 'N'])
|
||||
]
|
||||
|
||||
outputs = [
|
||||
helper.make_tensor_value_info(
|
||||
"output_0", TensorProto.FLOAT16, [3, 2, 'M', 'N'])
|
||||
]
|
||||
|
||||
save(model_path + "0.onnx", nodes, inputs, outputs, [])
|
||||
# Re-arragne nodes so that the transpose is on left input of matmul
|
||||
nodes = nodes[1:3] + nodes[0:1] + nodes[3:]
|
||||
save(model_path + "1.onnx", nodes, inputs, outputs, [])
|
||||
|
||||
|
||||
gen_transpose_fusion_with_cast(
|
||||
"transpose_cast_matmul_4d_fusion")
|
||||
|
|
|
|||
Loading…
Reference in a new issue