Stable Diffusion CUDA Optimizations Part 5 (#14706)

Add a fusion that removes a Transpose in subgraphs like:
```
--> Gemm --> Unsqueeze(axes=[2]) --> Unsqueeze(axes=[3]) --> Add --> Transpose([0,2,3,1]) --> GroupNorm
```
With this fusion, we can remove 22 Transpose nodes in UNet and reduce
latency by 0.1 second per image on T4.
This commit is contained in:
Tianlei Wu 2023-02-16 01:10:00 -08:00 committed by GitHub
parent 0f9d2432d2
commit 6f99fb9d4b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 85 additions and 2 deletions

View file

@ -8,7 +8,7 @@ from typing import Dict, List
from fusion_base import Fusion
from fusion_utils import FusionUtils
from onnx import NodeProto, helper
from onnx import NodeProto, TensorProto, helper
from onnx_model import OnnxModel
logger = getLogger(__name__)
@ -80,3 +80,83 @@ class FusionTranspose(Fusion):
self.nodes_to_remove.append(transpose_a)
transpose_b.ClearField("attribute")
transpose_b.attribute.extend([helper.make_attribute("perm", output_permutation)])
class FusionInsertTranspose(Fusion):
    """Insert a Transpose in front of the Add that feeds ``Transpose -> GroupNorm``.

    The pass anchors on GroupNorm nodes and rewrites the
    ``Gemm -> Unsqueeze -> Unsqueeze -> Add -> Transpose([0,2,3,1])`` parent chain
    so that a later transpose-optimization step can cancel a pair of Transposes.
    """

    def __init__(self, model: OnnxModel):
        # NOTE(review): the Fusion base class is not visible here; the arguments are
        # presumably (model, fused op type, op type to search for) - confirm.
        super().__init__(model, "", "GroupNorm")

    def create_transpose_node(self, input_name: str, perm: List[int], output_name=None) -> NodeProto:
        """Build a Transpose node that consumes ``input_name``.

        The node is only created, not added to the graph; the caller adds it.

        Args:
            input_name: name of the tensor to transpose.
            perm: permutation stored in the node's ``perm`` attribute.
            output_name: name of the output tensor. When ``None`` a name is derived
                from the generated node name and ``input_name``.

        Returns:
            The new Transpose ``NodeProto``.
        """
        node_name = self.model.create_node_name("Transpose")
        if output_name is None:
            output_name = node_name + "_out" + "-" + input_name

        transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
        transpose_node.attribute.extend([helper.make_attribute("perm", perm)])
        return transpose_node

    def _ensure_axes_initializer(self, name: str, axis: int) -> str:
        """Add a one-element INT64 initializer ``name`` holding ``axis`` unless it already exists."""
        if self.model.get_initializer(name) is None:
            axes_tensor = helper.make_tensor(
                name=name,
                data_type=TensorProto.INT64,
                dims=[1],
                vals=[axis],
            )
            self.model.add_initializer(axes_tensor, self.this_graph_name)
        return name

    def fuse(
        self,
        group_norm_node: NodeProto,
        input_name_to_nodes: Dict[str, List[NodeProto]],
        output_name_to_node: Dict[str, NodeProto],
    ):
        """
        This optimization will insert a Transpose, and onnxruntime transpose optimizer will remove it together with
        another Transpose so that we can get effect of reducing one Transpose after onnxruntime optimization.
        Before:
            --> Gemm --> Unsqueeze(axes=[2]) --> Unsqueeze(axes=[3]) --> Add --> Transpose([0,2,3,1]) --> GroupNorm
        After:
            --> Gemm --> Unsqueeze(axes=[1]) --> Unsqueeze(axes=[2]) -->Transpose([0,3,1,2]) --> Add --> Transpose([0,2,3,1]) --> GroupNorm

        The graph is left untouched whenever the pattern does not match exactly.

        Args:
            group_norm_node: the GroupNorm node used as the anchor of the pattern.
            input_name_to_nodes: map from tensor name to the nodes consuming it.
            output_name_to_node: map from tensor name to its producer (unused here).
        """
        gemm_path = self.model.match_parent_path(
            group_norm_node, ["Transpose", "Add", "Unsqueeze", "Unsqueeze", "Gemm"], [0, 0, None, 0, 0]
        )
        if gemm_path is None:
            return
        transpose, _add, unsqueeze_3, unsqueeze_2, gemm = gemm_path

        # The output of the last Unsqueeze gets rerouted below, so it must not be a graph output.
        if self.model.find_graph_output(unsqueeze_3.output[0]):
            return

        # Skip (instead of asserting) when ``perm`` is missing or not the expected layout change:
        # a Transpose without an explicit ``perm`` is valid ONNX and must not abort the optimizer.
        permutation = OnnxModel.get_node_attribute(transpose, "perm")
        if not isinstance(permutation, list) or permutation != [0, 2, 3, 1]:
            return

        # Unsqueeze only takes ``axes`` as a second input from opset 13 on; older nodes carry it
        # as an attribute and have a single input, which this rewrite does not handle.
        if len(unsqueeze_3.input) < 2 or len(unsqueeze_2.input) < 2:
            return

        if not (
            self.model.get_constant_value(unsqueeze_3.input[1]) == 3
            and self.model.get_constant_value(unsqueeze_2.input[1]) == 2
            and len(self.model.get_children(gemm, input_name_to_nodes)) == 1
            and len(self.model.get_children(unsqueeze_3, input_name_to_nodes)) == 1
            and len(self.model.get_children(unsqueeze_2, input_name_to_nodes)) == 1
        ):
            return

        # Here we use hard-coded name so that it could be shared for the whole model.
        axes_1 = self._ensure_axes_initializer("ort_const_unsqueeze_axes_1", 1)
        axes_2 = self._ensure_axes_initializer("ort_const_unsqueeze_axes_2", 2)

        # Unsqueeze at axes 1 then 2 instead of 2 then 3.
        unsqueeze_3.input[1] = axes_2
        unsqueeze_2.input[1] = axes_1

        transpose_output_name = self.model.create_node_name("Transpose") + "_NCHW"

        # Reroute the existing consumers first; the new Transpose is added afterwards so that
        # its own input (the Unsqueeze output) is not rewritten to point at itself.
        self.model.replace_input_of_all_nodes(unsqueeze_3.output[0], transpose_output_name)
        new_transpose = self.create_transpose_node(unsqueeze_3.output[0], [0, 3, 1, 2], transpose_output_name)
        self.model.add_node(new_transpose, self.this_graph_name)
        self.increase_counter("Insert Transpose")

View file

@ -12,7 +12,7 @@ from fusion_biassplitgelu import FusionBiasSplitGelu
from fusion_group_norm import FusionGroupNorm
from fusion_nhwc_conv import FusionNhwcConv
from fusion_options import FusionOptions
from fusion_transpose import FusionTranspose
from fusion_transpose import FusionInsertTranspose, FusionTranspose
from onnx import ModelProto
from onnx_model import OnnxModel
from onnx_model_bert import BertOnnxModel
@ -131,6 +131,9 @@ class UnetOnnxModel(BertOnnxModel):
group_norm_fusion = FusionGroupNorm(self)
group_norm_fusion.apply()
insert_transpose_fusion = FusionInsertTranspose(self)
insert_transpose_fusion.apply()
if (options is None) or options.enable_bias_splitgelu:
bias_split_gelu_fusion = FusionBiasSplitGelu(self)
bias_split_gelu_fusion.apply()