mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
### Description This is a follow-up of https://github.com/microsoft/onnxruntime/pull/14428 for Stable Diffusion CUDA optimizations: (1) use NhwcConv to replace Conv in the ONNX graph and add Transpose nodes accordingly; (2) reduce sequential Transpose nodes to at most one; (3) add symbolic shape inference for NhwcConv; (4) fix the add-bias transpose, which caused a CUDA error (launching more than 1024 threads per block) when running inference on an fp32 model; (5) add models (bert, bart, stable_diffusion subdirectories) to the package; (6) remove the option --disable_channels_last. Note that (1) we can add a few graph transformations to reduce Transpose nodes further; this is not done in this PR due to time limits. (2) The Stable Diffusion 2.1 model outputs black images. It seems that forcing Attention to float32 could avoid the issue; however, float32 Attention is much slower. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
90 lines
3.4 KiB
Python
90 lines
3.4 KiB
Python
# -------------------------------------------------------------------------
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
# --------------------------------------------------------------------------
|
|
|
|
from logging import getLogger
|
|
from typing import List
|
|
|
|
from fusion_base import Fusion
|
|
from onnx import TensorProto, helper, numpy_helper
|
|
from onnx_model import OnnxModel
|
|
|
|
# Module-level logger, named after this module.
logger = getLogger(__name__)
|
|
|
|
|
|
class FusionNhwcConv(Fusion):
    """Convert Conv to NhwcConv.

    Each fused ``Conv`` node is replaced by a ``com.microsoft`` ``NhwcConv``
    node. A Transpose is inserted before it (NCHW -> NHWC) and another after
    it (back to NCHW, written to the original Conv output name), so consumers
    of the original output are unaffected.
    """

    def __init__(self, model: OnnxModel, update_weight=False):
        """Set up the fusion.

        Args:
            model: Model wrapper handed to the ``Fusion`` base class.
            update_weight: If True, ``fuse`` stores a transposed copy of the
                weight as a new initializer. If False, ``fuse`` adds a
                Transpose node on the original weight instead.
        """
        super().__init__(model, "NhwcConv", ["Conv"], "NhwcConv")
        self.update_weight = update_weight

    def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
        """Append a Transpose node after an input.

        The node is only created and returned; it is not added to the graph
        here.

        Args:
            input_name: Name of the tensor the Transpose consumes.
            perm: Value of the Transpose ``perm`` attribute.
            output_name: Name of the Transpose output. When None, a name is
                derived from the generated node name and ``input_name``.

        Returns:
            The new Transpose node.
        """
        node_name = self.model.create_node_name("Transpose")

        if output_name is None:
            output_name = node_name + "_out" + "-" + input_name

        transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
        transpose_node.attribute.extend([helper.make_attribute("perm", perm)])

        return transpose_node

    def fuse(self, conv, input_name_to_nodes, output_name_to_node):
        """Replace one Conv node with Transpose -> NhwcConv -> Transpose.

        The Conv is left untouched when its weight (``conv.input[1]``) is not
        an initializer or is not 4-D.

        Args:
            conv: The Conv node to convert.
            input_name_to_nodes: Unused here; part of the fusion callback signature.
            output_name_to_node: Unused here; part of the fusion callback signature.
        """
        # Add Transpose node to convert input from NCHW to NHWC
        input_transpose_node = self.create_transpose_node(conv.input[0], [0, 2, 3, 1])

        nhwc_conv_input = input_transpose_node.output[0]

        node_name = self.model.create_node_name("NhwcConv")

        # Make sure the weight is a 4D initializer
        weight_tensor = self.model.get_initializer(conv.input[1])
        if weight_tensor is None:
            return
        weight = numpy_helper.to_array(weight_tensor)
        if len(weight.shape) != 4:
            return

        if self.update_weight:
            # Transpose weights from NCHW to NHWC
            weight = weight.transpose(0, 2, 3, 1)

            weight_name = node_name + "_weight_NHWC"
            # Build the initializer straight from the array so its element
            # type follows the original weight (e.g. float16 stays float16)
            # rather than being forced to float32.
            nhwc_weight = numpy_helper.from_array(weight, name=weight_name)
            self.model.add_initializer(nhwc_weight, self.this_graph_name)
            weight_transpose_node = None
        else:
            # Keep the original initializer and transpose it inside the graph.
            weight_transpose_node = self.create_transpose_node(conv.input[1], [0, 2, 3, 1])
            weight_name = weight_transpose_node.output[0]

        nhwc_output_name = node_name + "_out" + "-" + conv.output[0]
        nhwc_conv = helper.make_node(
            "NhwcConv",
            # Any inputs after the weight (conv.input[2:]) are passed through unchanged.
            inputs=[nhwc_conv_input, weight_name] + conv.input[2:],
            outputs=[nhwc_output_name],
            name=node_name + "-" + conv.name,
        )
        nhwc_conv.attribute.extend(conv.attribute)
        nhwc_conv.domain = "com.microsoft"

        # Transpose back and reuse the original Conv output name so that
        # existing consumers of that tensor need no rewiring.
        output_transpose_node = self.create_transpose_node(nhwc_conv.output[0], [0, 3, 1, 2], conv.output[0])

        self.nodes_to_remove.append(conv)

        nodes_to_add = [input_transpose_node, nhwc_conv, output_transpose_node]
        if weight_transpose_node is not None:
            nodes_to_add.append(weight_transpose_node)
        for node in nodes_to_add:
            self.node_name_to_graph_name[node.name] = self.this_graph_name
        self.nodes_to_add.extend(nodes_to_add)

        self.increase_counter("NhwcConv")
|