mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Stable Diffusion CUDA Optimizations Part 5 (#14706)
Add a fusion to remove transpose in subgraph like ``` --> Gemm --> Unsqueeze(axes=[2]) --> Unsqueeze(axes=[3]) --> Add --> Transpose([0,2,3,1]) --> GroupNorm ``` With this fusion, we can remove 22 Transpose nodes in UNet, and reduce latency by 0.1 second per image in T4.
This commit is contained in:
parent
0f9d2432d2
commit
6f99fb9d4b
2 changed files with 85 additions and 2 deletions
|
|
@ -8,7 +8,7 @@ from typing import Dict, List
|
|||
|
||||
from fusion_base import Fusion
|
||||
from fusion_utils import FusionUtils
|
||||
from onnx import NodeProto, helper
|
||||
from onnx import NodeProto, TensorProto, helper
|
||||
from onnx_model import OnnxModel
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
@ -80,3 +80,83 @@ class FusionTranspose(Fusion):
|
|||
self.nodes_to_remove.append(transpose_a)
|
||||
transpose_b.ClearField("attribute")
|
||||
transpose_b.attribute.extend([helper.make_attribute("perm", output_permutation)])
|
||||
|
||||
|
||||
class FusionInsertTranspose(Fusion):
|
||||
def __init__(self, model: OnnxModel):
|
||||
super().__init__(model, "", "GroupNorm")
|
||||
|
||||
def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
|
||||
"""Append a Transpose node after an input"""
|
||||
node_name = self.model.create_node_name("Transpose")
|
||||
if output_name is None:
|
||||
output_name = node_name + "_out" + "-" + input_name
|
||||
transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
|
||||
transpose_node.attribute.extend([helper.make_attribute("perm", perm)])
|
||||
return transpose_node
|
||||
|
||||
def fuse(
|
||||
self,
|
||||
group_norm_node: NodeProto,
|
||||
input_name_to_nodes: Dict[str, List[NodeProto]],
|
||||
output_name_to_node: Dict[str, NodeProto],
|
||||
):
|
||||
"""
|
||||
This optimization will insert an Transpose, and onnxruntime transpose optimizer will remove it together with
|
||||
another Transpose so that we can get effect of reducing one Transpose after onnxruntime optimization.
|
||||
Before:
|
||||
--> Gemm --> Unsqueeze(axes=[2]) --> Unsqueeze(axes=[3]) --> Add --> Transpose([0,2,3,1]) --> GroupNorm
|
||||
After:
|
||||
--> Gemm --> Unsqueeze(axes=[1]) --> Unsqueeze(axes=[2]) -->Transpose([0,3,1,2]) --> Add --> Transpose([0,2,3,1]) --> GroupNorm
|
||||
"""
|
||||
gemm_path = self.model.match_parent_path(
|
||||
group_norm_node, ["Transpose", "Add", "Unsqueeze", "Unsqueeze", "Gemm"], [0, 0, None, 0, 0]
|
||||
)
|
||||
if gemm_path is None:
|
||||
return
|
||||
transpose, add, unsqueeze_3, unsqueeze_2, gemm = gemm_path
|
||||
if self.model.find_graph_output(unsqueeze_3.output[0]):
|
||||
return
|
||||
|
||||
permutation = OnnxModel.get_node_attribute(transpose, "perm")
|
||||
assert isinstance(permutation, list)
|
||||
if permutation != [0, 2, 3, 1]:
|
||||
return
|
||||
|
||||
if not (
|
||||
self.model.get_constant_value(unsqueeze_3.input[1]) == 3
|
||||
and self.model.get_constant_value(unsqueeze_2.input[1]) == 2
|
||||
and len(self.model.get_children(gemm, input_name_to_nodes)) == 1
|
||||
and len(self.model.get_children(unsqueeze_3, input_name_to_nodes)) == 1
|
||||
and len(self.model.get_children(unsqueeze_2, input_name_to_nodes)) == 1
|
||||
):
|
||||
return
|
||||
|
||||
# Here we use hard-coded name so that it could be shared for the whole model.
|
||||
axes_1 = "ort_const_unsqueeze_axes_1"
|
||||
if self.model.get_initializer(axes_1) is None:
|
||||
axes_1_tensor = helper.make_tensor(
|
||||
name=axes_1,
|
||||
data_type=TensorProto.INT64,
|
||||
dims=[1],
|
||||
vals=[1],
|
||||
)
|
||||
self.model.add_initializer(axes_1_tensor, self.this_graph_name)
|
||||
|
||||
axes_2 = "ort_const_unsqueeze_axes_2"
|
||||
if self.model.get_initializer(axes_2) is None:
|
||||
axes_2_tensor = helper.make_tensor(
|
||||
name=axes_2,
|
||||
data_type=TensorProto.INT64,
|
||||
dims=[1],
|
||||
vals=[2],
|
||||
)
|
||||
self.model.add_initializer(axes_2_tensor, self.this_graph_name)
|
||||
|
||||
unsqueeze_3.input[1] = "ort_const_unsqueeze_axes_2"
|
||||
unsqueeze_2.input[1] = "ort_const_unsqueeze_axes_1"
|
||||
transpose_output_name = self.model.create_node_name("Transpose") + "_NCHW"
|
||||
self.model.replace_input_of_all_nodes(unsqueeze_3.output[0], transpose_output_name)
|
||||
new_transpose = self.create_transpose_node(unsqueeze_3.output[0], [0, 3, 1, 2], transpose_output_name)
|
||||
self.model.add_node(new_transpose, self.this_graph_name)
|
||||
self.increase_counter("Insert Transpose")
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from fusion_biassplitgelu import FusionBiasSplitGelu
|
|||
from fusion_group_norm import FusionGroupNorm
|
||||
from fusion_nhwc_conv import FusionNhwcConv
|
||||
from fusion_options import FusionOptions
|
||||
from fusion_transpose import FusionTranspose
|
||||
from fusion_transpose import FusionInsertTranspose, FusionTranspose
|
||||
from onnx import ModelProto
|
||||
from onnx_model import OnnxModel
|
||||
from onnx_model_bert import BertOnnxModel
|
||||
|
|
@ -131,6 +131,9 @@ class UnetOnnxModel(BertOnnxModel):
|
|||
group_norm_fusion = FusionGroupNorm(self)
|
||||
group_norm_fusion.apply()
|
||||
|
||||
insert_transpose_fusion = FusionInsertTranspose(self)
|
||||
insert_transpose_fusion.apply()
|
||||
|
||||
if (options is None) or options.enable_bias_splitgelu:
|
||||
bias_split_gelu_fusion = FusionBiasSplitGelu(self)
|
||||
bias_split_gelu_fusion.apply()
|
||||
|
|
|
|||
Loading…
Reference in a new issue