onnxruntime/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py
Tianlei Wu 742658d171
Stable Diffusion CUDA optimizations Part 2 (#14597)
### Description
This is a follow-up of
https://github.com/microsoft/onnxruntime/pull/14428 for Stable Diffusion
CUDA optimizations:
(1) use NhwcConv to replace Conv in onnx graph and add Transpose nodes
accordingly
(2) reduce sequential Transpose nodes to at most one.
(3) symbolic shape inference of NhwcConv
(4) fix add bias transpose which causes CUDA error (launching more than
1024 threads per block) when running inference on an fp32 model.
(5) add models (bert, bart, stable_diffusion subdirectories) to package;
(6) remove option --disable_channels_last

Note that 
(1) We can add a few graph transformations to reduce Transpose nodes
further. It is not done in this PR due to time constraints.
(2) The Stable Diffusion 2.1 model outputs black images. It seems that
forcing Attention to float32 could avoid the issue. However, using
float32 Attention is much slower.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2023-02-07 07:49:15 -08:00

90 lines
3.4 KiB
Python

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import List
from fusion_base import Fusion
from onnx import TensorProto, helper, numpy_helper
from onnx_model import OnnxModel
# Module-level logger, named after this module.
logger = getLogger(__name__)
class FusionNhwcConv(Fusion):
    """Convert Conv to NhwcConv.

    Every fused ``Conv`` node is replaced by the sequence
    ``Transpose(perm=[0, 2, 3, 1]) -> NhwcConv -> Transpose(perm=[0, 3, 1, 2])``.
    The final Transpose writes to the original Conv output name, so downstream
    consumers are left untouched. The weight is permuted with ``[0, 2, 3, 1]``
    too, either offline as a new initializer or at runtime through an extra
    Transpose node, depending on ``update_weight``.
    """

    def __init__(self, model: OnnxModel, update_weight=False):
        """Initialize the fusion.

        Args:
            model: Model wrapper whose ``Conv`` nodes will be rewritten.
            update_weight: If true, store a pre-transposed copy of the weight
                as a new initializer. Otherwise keep the original initializer
                and insert a Transpose node in front of the weight input.
        """
        super().__init__(model, "NhwcConv", ["Conv"], "NhwcConv")
        self.update_weight = update_weight

    def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
        """Create a Transpose node that consumes ``input_name``.

        The node is only built and returned; it is not added to the graph here.

        Args:
            input_name: Name of the tensor to transpose.
            perm: Axis permutation stored in the node's ``perm`` attribute.
            output_name: Name of the output tensor. When omitted, a name is
                derived from the generated node name and ``input_name``.

        Returns:
            The new Transpose ``NodeProto``.
        """
        node_name = self.model.create_node_name("Transpose")
        if output_name is None:
            output_name = node_name + "_out" + "-" + input_name
        return helper.make_node(
            "Transpose",
            inputs=[input_name],
            outputs=[output_name],
            name=node_name,
            perm=perm,
        )

    def fuse(self, conv, input_name_to_nodes, output_name_to_node):
        """Replace one Conv node with Transpose -> NhwcConv -> Transpose.

        The node is left untouched when it has no weight input, when its
        weight is not an initializer, or when the weight is not 4-D.

        Args:
            conv: The ``Conv`` node to rewrite.
            input_name_to_nodes: Unused; part of the fusion callback signature.
            output_name_to_node: Unused; part of the fusion callback signature.
        """
        # Validate before creating any node or name so that a skipped Conv
        # leaves nothing behind.
        if len(conv.input) < 2:
            return

        weight_tensor = self.model.get_initializer(conv.input[1])
        if weight_tensor is None:
            return

        # Make sure the weight is 4D.
        weight = numpy_helper.to_array(weight_tensor)
        if weight.ndim != 4:
            return

        # Add Transpose node to convert input from NCHW to NHWC.
        input_transpose_node = self.create_transpose_node(conv.input[0], [0, 2, 3, 1])
        nhwc_conv_input = input_transpose_node.output[0]

        node_name = self.model.create_node_name("NhwcConv")

        if self.update_weight:
            # Store the weight already permuted with [0, 2, 3, 1].
            # from_array() takes the element type from the numpy array, so a
            # float16 weight stays float16 instead of being re-declared FLOAT.
            weight_name = node_name + "_weight_NHWC"
            nhwc_weight = numpy_helper.from_array(weight.transpose(0, 2, 3, 1), weight_name)
            self.model.add_initializer(nhwc_weight, self.this_graph_name)
            weight_transpose_node = None
        else:
            # Keep the original initializer and permute it at runtime.
            weight_transpose_node = self.create_transpose_node(conv.input[1], [0, 2, 3, 1])
            weight_name = weight_transpose_node.output[0]

        nhwc_output_name = node_name + "_out" + "-" + conv.output[0]
        nhwc_conv = helper.make_node(
            "NhwcConv",
            # Inputs after the weight (e.g. bias) are forwarded unchanged.
            inputs=[nhwc_conv_input, weight_name, *conv.input[2:]],
            outputs=[nhwc_output_name],
            name=node_name + "-" + conv.name,
        )
        nhwc_conv.attribute.extend(conv.attribute)
        nhwc_conv.domain = "com.microsoft"

        # Transpose back and reuse the original output name so that consumers
        # of the Conv output do not need to be rewired.
        output_transpose_node = self.create_transpose_node(nhwc_conv.output[0], [0, 3, 1, 2], conv.output[0])

        self.nodes_to_remove.append(conv)

        nodes_to_add = [input_transpose_node, nhwc_conv, output_transpose_node]
        if weight_transpose_node is not None:
            nodes_to_add.append(weight_transpose_node)
        for node in nodes_to_add:
            self.node_name_to_graph_name[node.name] = self.this_graph_name
        self.nodes_to_add.extend(nodes_to_add)

        self.increase_counter("NhwcConv")