class FusionInsertTranspose(Fusion):
    """Insert a Transpose between the Unsqueeze chain and the Add that feeds a GroupNorm.

    The pass is driven by GroupNorm nodes (see ``fuse``). It does not create a fused
    operator; it only rewrites the Unsqueeze axes and adds one Transpose node.
    """

    def __init__(self, model: OnnxModel):
        # NOTE(review): the arguments are presumably (model, fused_op_type, search_op_type);
        # the empty string would then mean "no fused op is produced" - confirm against Fusion.
        super().__init__(model, "", "GroupNorm")

    def create_transpose_node(self, input_name: str, perm: List[int], output_name=None) -> NodeProto:
        """Create (but do not add to the graph) a Transpose node that consumes ``input_name``.

        Args:
            input_name: name of the tensor to transpose.
            perm: permutation stored in the node's ``perm`` attribute.
            output_name: name of the output tensor. When None, a name is derived from the
                generated node name and ``input_name``.

        Returns:
            The new Transpose node.
        """
        node_name = self.model.create_node_name("Transpose")
        if output_name is None:
            output_name = f"{node_name}_out-{input_name}"
        transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
        transpose_node.attribute.extend([helper.make_attribute("perm", perm)])
        return transpose_node

    def _ensure_axes_initializer(self, name: str, axis: int) -> str:
        """Add a one-element INT64 initializer ``name`` holding ``axis`` unless it already exists."""
        if self.model.get_initializer(name) is None:
            axes_tensor = helper.make_tensor(
                name=name,
                data_type=TensorProto.INT64,
                dims=[1],
                vals=[axis],
            )
            self.model.add_initializer(axes_tensor, self.this_graph_name)
        return name

    def _has_single_constant_axis(self, unsqueeze: NodeProto, axis: int) -> bool:
        """Return True when ``unsqueeze`` takes its axes from a constant second input equal to ``[axis]``."""
        import numpy as np

        # Unsqueeze only has an `axes` input since opset 13; older nodes carry it as an
        # attribute and have a single input, so there is nothing to rewrite for them.
        if len(unsqueeze.input) < 2:
            return False

        value = self.model.get_constant_value(unsqueeze.input[1])
        if value is None:
            return False

        # Flatten before comparing so that a multi-element axes tensor is rejected instead of
        # raising "truth value of an array is ambiguous" in a boolean context.
        return np.asarray(value).reshape(-1).tolist() == [axis]

    def fuse(
        self,
        group_norm_node: NodeProto,
        input_name_to_nodes: Dict[str, List[NodeProto]],
        output_name_to_node: Dict[str, NodeProto],
    ):
        """
        This optimization will insert a Transpose, and onnxruntime transpose optimizer will remove it together with
        another Transpose so that we can get effect of reducing one Transpose after onnxruntime optimization.
        Before:
            --> Gemm --> Unsqueeze(axes=[2]) --> Unsqueeze(axes=[3]) --> Add --> Transpose([0,2,3,1]) --> GroupNorm
        After:
            --> Gemm --> Unsqueeze(axes=[1]) --> Unsqueeze(axes=[2]) -->Transpose([0,3,1,2]) --> Add --> Transpose([0,2,3,1]) --> GroupNorm

        The graph is left untouched whenever any part of the pattern does not match.

        Args:
            group_norm_node: the GroupNorm node whose parents are matched.
            input_name_to_nodes: map from tensor name to its consumer nodes.
            output_name_to_node: map from tensor name to its producer node (unused here).
        """
        gemm_path = self.model.match_parent_path(
            group_norm_node, ["Transpose", "Add", "Unsqueeze", "Unsqueeze", "Gemm"], [0, 0, None, 0, 0]
        )
        if gemm_path is None:
            return
        transpose, _add, unsqueeze_3, unsqueeze_2, gemm = gemm_path

        # The Unsqueeze output is about to change layout, so it must not be visible outside the graph.
        if self.model.find_graph_output(unsqueeze_3.output[0]):
            return

        # A missing `perm` (None) or any other permutation simply means the pattern does not
        # apply; skip it rather than aborting the whole optimization with an AssertionError.
        permutation = OnnxModel.get_node_attribute(transpose, "perm")
        if permutation is None or list(permutation) != [0, 2, 3, 1]:
            return

        if not (self._has_single_constant_axis(unsqueeze_3, 3) and self._has_single_constant_axis(unsqueeze_2, 2)):
            return

        # Every intermediate tensor must have exactly one consumer, otherwise changing the
        # axes would alter what other consumers see.
        for node in (gemm, unsqueeze_3, unsqueeze_2):
            if len(self.model.get_children(node, input_name_to_nodes)) != 1:
                return

        # Here we use hard-coded names so that the initializers could be shared for the whole model.
        axes_1 = self._ensure_axes_initializer("ort_const_unsqueeze_axes_1", 1)
        axes_2 = self._ensure_axes_initializer("ort_const_unsqueeze_axes_2", 2)

        unsqueeze_3.input[1] = axes_2
        unsqueeze_2.input[1] = axes_1

        transpose_output_name = self.model.create_node_name("Transpose") + "_NCHW"
        # Redirect the existing consumers first: the new Transpose is created afterwards so
        # that its own input keeps pointing at the Unsqueeze output.
        self.model.replace_input_of_all_nodes(unsqueeze_3.output[0], transpose_output_name)
        new_transpose = self.create_transpose_node(unsqueeze_3.output[0], [0, 3, 1, 2], transpose_output_name)
        self.model.add_node(new_transpose, self.this_graph_name)
        self.increase_counter("Insert Transpose")
bias_split_gelu_fusion = FusionBiasSplitGelu(self) bias_split_gelu_fusion.apply()