Stable Diffusion CUDA optimizations Part 2 (#14597)

### Description
This is a follow-up of
https://github.com/microsoft/onnxruntime/pull/14428 for Stable Diffusion
CUDA optimizations:
(1) use NchwConv to replace Conv in onnx graph and add Tranpose nodes
accordingly
(2) reduce sequential Transpose nodes to at most one.
(3) symbolic shape infer of NchwConv
(4) fix add bias transpose which causes CUDA error (launching more than
1024 threads per block) in inferencing fp32 model.
(5) add models (bert, bart, stable_diffusion subdirectories) to package;
(6) remove option --disable_channels_last

Note that 
(1) We can add a few graph transformations to reduce Transpose nodes
further. It is not done in this PR due to time limit.
(2) Stable diffusion 2.1 model outputs black images. It seems that
forcing Attention to float32 could avoid the issue. However it is much
slow to use float32 Attention.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
Tianlei Wu 2023-02-07 07:49:15 -08:00 committed by GitHub
parent f88a4646cd
commit 742658d171
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 386 additions and 61 deletions

View file

@ -467,12 +467,21 @@ file(GLOB onnxruntime_python_quantization_cal_table_flatbuffers_src CONFIGURE_DE
file(GLOB onnxruntime_python_transformers_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/*.py"
)
file(GLOB onnxruntime_python_transformers_models_bart_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/bart/*.py"
)
file(GLOB onnxruntime_python_transformers_models_bert_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/bert/*.py"
)
file(GLOB onnxruntime_python_transformers_models_gpt2_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/gpt2/*.py"
)
file(GLOB onnxruntime_python_transformers_models_longformer_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/longformer/*.py"
)
file(GLOB onnxruntime_python_transformers_models_stable_diffusion_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/stable_diffusion/*.py"
)
file(GLOB onnxruntime_python_transformers_models_t5_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/t5/*.py"
)
@ -526,8 +535,11 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/tools/ort_format_model/ort_flatbuffers_py
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/bart
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/bert
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/gpt2
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/longformer
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/stable_diffusion
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/t5
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/operators
@ -606,12 +618,21 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_bart_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/bart/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_bert_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/bert/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_gpt2_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/gpt2/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_longformer_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/longformer/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_stable_diffusion_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/stable_diffusion/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_t5_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/t5/

View file

@ -519,6 +519,7 @@ void InvokeAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const T* input, const T* biases, T* output, T* qkv_add_bias, const int v_head_size, int total_matrix_count) {
assert(num_heads <= max_threads_per_block);
const dim3 grid(sequence_length, batch_size, num_matrices);
if (qk_head_size * num_heads <= max_threads_per_block) {
const dim3 block(qk_head_size, num_heads, 1);
@ -544,7 +545,7 @@ void InvokeAddBiasTranspose(
AddBiasTranspose<T><<<grid, block, 0, stream>>>(input, biases, output);
}
} else {
const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
if (format == 2) {
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output);
} else if (format == 1) {
@ -577,7 +578,7 @@ void LaunchAddBiasTranspose(
const half* input, const half* biases, half* output,
bool enable_half4, const int v_head_size, half* qkv_add_bias, int total_matrix_count) {
total_matrix_count = std::max(num_matrices, total_matrix_count);
if (enable_half4 && 0 == (qk_head_size % 4) && 0 == (v_head_size % 4)) {
if (enable_half4 && 0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4))) {
const int H = qk_head_size / 4;
const int H_v = v_head_size / 4;
const Half4* input2 = reinterpret_cast<const Half4*>(input);
@ -587,7 +588,7 @@ void LaunchAddBiasTranspose(
InvokeAddBiasTranspose<Half4>(stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, H, input2, biases2, output2,
qkv_add_bias2, H_v, total_matrix_count);
} else if (0 == (qk_head_size & 1) && 0 == (v_head_size & 1)) {
} else if (0 == (qk_head_size & 1) && (v_head_size == -1 || 0 == (v_head_size & 1))) {
const int H = qk_head_size / 2;
const int H_v = v_head_size / 2;
const half2* input2 = reinterpret_cast<const half2*>(input);
@ -612,7 +613,7 @@ void LaunchAddBiasTranspose(
const float* input, const float* biases, float* output,
bool /*enable_half4*/, const int v_head_size, float* qkv_add_bias, int total_matrix_count) {
total_matrix_count = std::max(num_matrices, total_matrix_count);
if (0 == (qk_head_size % 4) && 0 == (v_head_size % 4)) {
if (0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4))) {
const int H = qk_head_size / 4;
const float4* input2 = reinterpret_cast<const float4*>(input);
const float4* biases2 = reinterpret_cast<const float4*>(biases);
@ -622,7 +623,7 @@ void LaunchAddBiasTranspose(
stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, H, input2, biases2, output2,
qkv_add_bias2, v_head_size / 4, total_matrix_count);
} else if (0 == (qk_head_size & 1) && 0 == (v_head_size & 1)) {
} else if (0 == (qk_head_size & 1) && (v_head_size == -1 || 0 == (v_head_size & 1))) {
const int H = qk_head_size / 2;
const float2* input2 = reinterpret_cast<const float2*>(input);
const float2* biases2 = reinterpret_cast<const float2*>(biases);
@ -654,7 +655,7 @@ void InvokeAddBiasTransposeTrt(
const dim3 block(head_size, num_heads, 1);
AddBiasTransposeTrt<T><<<grid, block, 0, stream>>>(query, key, value, biases, output);
} else {
const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(head_size, query, key, value, biases, output);
}
} else { // cross attention
@ -666,7 +667,7 @@ void InvokeAddBiasTransposeTrt(
const dim3 block(head_size, num_heads, 1);
AddBiasTransposeTrt<T><<<grid, block, 0, stream>>>(query, biases, output);
} else {
const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(head_size, query, biases, output);
}
}
@ -680,7 +681,7 @@ void InvokeAddBiasTransposeTrt(
const dim3 block(head_size, num_heads, 1);
AddBiasTransposeTrtKV<T><<<grid, block, 0, stream>>>(key, value, biases, packed_kv);
} else {
const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
AddBiasTransposeTrtKVLarge<T><<<grid, block, 0, stream>>>(head_size, key, value, biases, packed_kv);
}
}
@ -737,6 +738,7 @@ void InvokeAddBias(
const int batch_size, const int sequence_length, const int kv_sequence_length,
const int num_heads, const int head_size, const int v_head_size,
const T* biases, const T* query, const T* key, const T* value, T* q, T* k, T* v) {
assert(num_heads <= max_threads_per_block);
constexpr int num_matrices = 1;
// Q
{
@ -745,7 +747,7 @@ void InvokeAddBias(
const dim3 block(head_size, num_heads, 1);
AddBiasTransposeTrt<T><<<grid, block, 0, stream>>>(query, biases, q);
} else {
const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(head_size, query, biases, q);
}
}
@ -758,7 +760,7 @@ void InvokeAddBias(
const dim3 block(head_size, num_heads, 1);
AddBiasTransposeTrt<T><<<grid, block, 0, stream>>>(key, biases_k, k);
} else {
const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(head_size, key, biases_k, k);
}
}
@ -772,7 +774,7 @@ void InvokeAddBias(
const dim3 block(v_head_size, num_heads, 1);
AddBiasTransposeTrt<T><<<grid, block, 0, stream>>>(value, biases_v, v);
} else {
const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(v_head_size, value, biases_v, v);
}
}

View file

@ -203,6 +203,7 @@ class SymbolicShapeInference:
"SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization,
"GroupNorm": self._infer_GroupNorm,
"BiasSplitGelu": self._infer_BiasSplitGelu,
"NhwcConv": self._infer_NhwcConv,
}
self.aten_op_dispatcher_ = {
"embedding": self._infer_Gather,
@ -442,6 +443,7 @@ class SymbolicShapeInference:
"MultiHeadAttention",
"GroupNorm",
"BiasSplitGelu",
"NhwcConv",
]
if not skip_infer:
@ -623,13 +625,13 @@ class SymbolicShapeInference:
def _new_symbolic_shape(self, rank, node, out_idx=0):
return [self._new_symbolic_dim_from_output(node, out_idx, i) for i in range(rank)]
def _compute_conv_pool_shape(self, node):
def _compute_conv_pool_shape(self, node, channels_last=False):
sympy_shape = self._get_sympy_shape(node, 0)
if len(node.input) > 1:
W_shape = self._get_sympy_shape(node, 1)
rank = len(W_shape) - 2 # number of spatial axes
kernel_shape = W_shape[-rank:]
sympy_shape[1] = W_shape[0]
kernel_shape = W_shape[-rank - 1 : -1] if channels_last else W_shape[-rank:]
sympy_shape[3 if channels_last else 1] = W_shape[0]
else:
W_shape = None
kernel_shape = get_attribute(node, "kernel_shape")
@ -638,13 +640,17 @@ class SymbolicShapeInference:
assert len(sympy_shape) == rank + 2
# only need to symbolic shape inference if input has symbolic dims in spatial axes
is_symbolic_dims = [not is_literal(i) for i in sympy_shape[-rank:]]
spatial_shape = sympy_shape[-rank - 1 : -1] if channels_last else sympy_shape[-rank:]
is_symbolic_dims = [not is_literal(i) for i in spatial_shape]
if not any(is_symbolic_dims):
shape = get_shape_from_value_info(self.known_vi_[node.output[0]])
if len(shape) > 0:
assert len(sympy_shape) == len(shape)
sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]]
if channels_last:
sympy_shape[-rank - 1 : -1] = [sympy.Integer(d) for d in shape[-rank - 1 : -1]]
else:
sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]]
return sympy_shape
dilations = get_attribute(node, "dilations", [1] * rank)
@ -675,7 +681,7 @@ class SymbolicShapeInference:
ceil_mode = get_attribute(node, "ceil_mode", 0)
for i in range(rank):
effective_input_size = sympy_shape[-rank + i]
effective_input_size = sympy_shape[-rank + i + (-1 if channels_last else 0)]
if len(total_pads) > 0:
effective_input_size = effective_input_size + total_pads[i]
if ceil_mode:
@ -684,7 +690,7 @@ class SymbolicShapeInference:
)
else:
strided_kernel_positions = (effective_input_size - effective_kernel_shape[i]) // strides[i]
sympy_shape[-rank + i] = strided_kernel_positions + 1
sympy_shape[-rank + i + (-1 if channels_last else 0)] = strided_kernel_positions + 1
return sympy_shape
def _check_merged_dims(self, dims, allow_broadcast=True):
@ -918,6 +924,18 @@ class SymbolicShapeInference:
)
)
def _infer_NhwcConv(self, node):
sympy_shape = self._compute_conv_pool_shape(node, channels_last=True)
self._update_computed_dims(sympy_shape)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0],
self.known_vi_[node.input[0]].type.tensor_type.elem_type,
get_shape_from_sympy_shape(sympy_shape),
)
)
def _infer_Einsum(self, node):
# ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275
equation = get_attribute(node, "equation")
@ -2459,6 +2477,7 @@ class SymbolicShapeInference:
all_shapes_inferred = symbolic_shape_inference._infer_impl()
symbolic_shape_inference._update_output_from_vi()
if not all_shapes_inferred:
onnx.save_model(symbolic_shape_inference.out_mp_, "sym_shape_infer_temp.onnx", save_as_external_data=True)
raise Exception("Incomplete symbolic shape inference")
return symbolic_shape_inference.out_mp_

View file

@ -0,0 +1,90 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import List
from fusion_base import Fusion
from onnx import TensorProto, helper, numpy_helper
from onnx_model import OnnxModel
logger = getLogger(__name__)
class FusionNhwcConv(Fusion):
"""Convert Conv to NhwcConv"""
def __init__(self, model: OnnxModel, update_weight=False):
super().__init__(model, "NhwcConv", ["Conv"], "NhwcConv")
self.update_weight = update_weight
def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
"""Append a Transpose node after an input"""
node_name = self.model.create_node_name("Transpose")
if output_name is None:
output_name = node_name + "_out" + "-" + input_name
transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
transpose_node.attribute.extend([helper.make_attribute("perm", perm)])
return transpose_node
def fuse(self, conv, input_name_to_nodes, output_name_to_node):
# Add Transpose node to convert input from NCHW to NHWC
input_transpose_node = self.create_transpose_node(conv.input[0], [0, 2, 3, 1])
nhwc_conv_input = input_transpose_node.output[0]
# Create a tensor for transposed weights (already in NHWC format).
node_name = self.model.create_node_name("NhwcConv")
# Make sure the weights is 4D
weight_tensor = self.model.get_initializer(conv.input[1])
if weight_tensor is None:
return
weight = numpy_helper.to_array(weight_tensor)
if len(weight.shape) != 4:
return
if self.update_weight:
# Transpose weights from NCHW to NHWC
weight = weight.transpose(0, 2, 3, 1)
weight_name = node_name + "_weight_NHWC"
nhwc_weight = helper.make_tensor(
name=weight_name,
data_type=TensorProto.FLOAT,
dims=list(weight.shape),
vals=weight.flatten().tolist(),
)
self.model.add_initializer(nhwc_weight, self.this_graph_name)
weight_transpose_node = None
else:
weight_transpose_node = self.create_transpose_node(conv.input[1], [0, 2, 3, 1])
weight_name = weight_transpose_node.output[0]
nhwc_output_name = node_name + "_out" + "-" + conv.output[0]
nhwc_conv = helper.make_node(
"NhwcConv",
inputs=[nhwc_conv_input, weight_name] + conv.input[2:],
outputs=[nhwc_output_name],
name=node_name + "-" + conv.name,
)
nhwc_conv.attribute.extend(conv.attribute)
nhwc_conv.domain = "com.microsoft"
output_transpose_node = self.create_transpose_node(nhwc_conv.output[0], [0, 3, 1, 2], conv.output[0])
self.nodes_to_remove.append(conv)
nodes_to_add = [input_transpose_node, nhwc_conv, output_transpose_node]
if weight_transpose_node:
nodes_to_add.append(weight_transpose_node)
for node in nodes_to_add:
self.node_name_to_graph_name[node.name] = self.this_graph_name
self.nodes_to_add.extend(nodes_to_add)
self.increase_counter("NhwcConv")

View file

@ -119,16 +119,15 @@ class FusionReshape(Fusion):
shape_nodes.extend([path2[-1], path3[-1]])
shape.append(-1)
elif len(concat_node.input) > 2:
concat_2 = self.model.get_initializer(concat_node.input[2])
if concat_2 is None:
concat_value = self.model.get_constant_value(concat_node.input[2])
if concat_value is None:
return
concat_value = numpy_helper.to_array(concat_2)
if isinstance(concat_value, np.ndarray):
shape.extend(concat_value.tolist())
else:
shape.append(concat_value)
if len(concat_node.input) == 4 and self.model.get_initializer(concat_node.input[3]) is None:
if len(concat_node.input) == 4 and self.model.get_constant_value(concat_node.input[3]) is None:
if -1 in shape:
return

View file

@ -0,0 +1,81 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict, List
from fusion_base import Fusion
from fusion_utils import FusionUtils
from onnx import NodeProto, helper
from onnx_model import OnnxModel
logger = getLogger(__name__)
class FusionTranspose(Fusion):
def __init__(self, model: OnnxModel):
super().__init__(model, "Transpose", "Transpose")
def fuse(
self,
transpose_node: NodeProto,
input_name_to_nodes: Dict[str, List[NodeProto]],
output_name_to_node: Dict[str, NodeProto],
):
"""
Case 1:
(input)-->Transpose(perm=a)-->Transpose(perm=b)-->
After:
(input)-->Transpose(perm=a)--> (this path can be removed if the output is not used anymore)
|
+----->Transpose(perm=a*b)-->
Case 2 (Cast has only one child):
(input)-->Transpose(perm=a)--> Cast -->Transpose(perm=b)-->
After:
(input)-->Transpose(perm=a)--> (this path can be removed if the output is not used anymore)
|
+----->Cast --> Transpose(perm=a*b)-->
"""
transpose_b = transpose_node
if transpose_b.input[0] not in output_name_to_node:
return
transpose_a = output_name_to_node[transpose_b.input[0]]
if transpose_a.op_type != "Cast":
cast_node = None
else:
cast_node = transpose_a
cast_children = self.model.get_children(cast_node, input_name_to_nodes)
if cast_children and len(cast_children) > 1:
return
transpose_a = output_name_to_node[cast_node.input[0]]
if transpose_a.op_type != "Transpose":
return
permutation = OnnxModel.get_node_attribute(transpose_b, "perm")
assert isinstance(permutation, list)
parent_permutation = OnnxModel.get_node_attribute(transpose_a, "perm")
assert isinstance(parent_permutation, list)
assert len(parent_permutation) == len(permutation)
output_permutation = []
for j, index in enumerate(permutation):
output_permutation.append(parent_permutation[index])
if cast_node is None:
if FusionUtils.skip_parent(self.model, transpose_b, transpose_a, input_name_to_nodes):
self.nodes_to_remove.append(transpose_a)
else:
if FusionUtils.skip_parent(self.model, cast_node, transpose_a, input_name_to_nodes):
self.nodes_to_remove.append(transpose_a)
transpose_b.ClearField("attribute")
transpose_b.attribute.extend([helper.make_attribute("perm", output_permutation)])

View file

@ -73,6 +73,32 @@ class FusionUtils:
self.model.remove_node(node)
self.model.replace_input_of_all_nodes(output_name, input_name)
@staticmethod
def skip_parent(model: OnnxModel, node, parent_node, input_name_to_nodes):
"""
Before:
(input)-->parent-->node-->(output)
After:
(input)-->parent-->
|
+----->node-->(output)
This function returns a flag about whether the parent node can be removed.
Note that this function assumes the node has first input links from parent!
"""
parent_can_be_removed = False
input_name_to_nodes[node.input[0]].remove(node)
# We can remove the first Transpose if its output is not used (linked to graph output or other nodes) anymore.
if len(input_name_to_nodes[node.input[0]]) == 0 and not model.find_graph_output(
node.input[0]
): # checks main graph output. TODO: deal with subgraph
parent_can_be_removed = True
# self.nodes_to_remove.append(transpose_a)
input_name_to_nodes[parent_node.input[0]].append(node)
node.input[0] = parent_node.input[0]
return parent_can_be_removed
@staticmethod
def check_node_attribute(node, attribute_name: str, expected_value, default_value=None):
"""Verify that a node has expected value for an attribute.
@ -228,7 +254,10 @@ class FusionUtils:
graph_output_names = set(self.model.get_graphs_output_names())
for node in nodes_to_remove:
if bool(set(node.output) & graph_output_names):
if not bool(set(node.input) & graph_input_names):
if (
not bool(set(node.input) & graph_input_names)
and len(self.model.input_name_to_nodes()[node.input[0]]) == 1 # parent has only one child
):
self.model.replace_output_of_all_nodes(node.input[0], node.output[0])
else:
continue

View file

@ -62,7 +62,7 @@ def get_ort_pipeline(model_name: str, directory: str, provider: str, disable_saf
return pipe
def get_torch_pipeline(model_name: str, disable_channels_last: bool, disable_safety_checker: bool):
def get_torch_pipeline(model_name: str, disable_safety_checker: bool):
from diffusers import StableDiffusionPipeline
from torch import channels_last, float16
@ -70,8 +70,7 @@ def get_torch_pipeline(model_name: str, disable_channels_last: bool, disable_saf
model_name, torch_dtype=float16, revision="fp16", use_auth_token=True
).to("cuda")
if not disable_channels_last:
pipe.unet.to(memory_format=channels_last) # in-place operation
pipe.unet.to(memory_format=channels_last) # in-place operation
if disable_safety_checker:
pipe.safety_checker = None
@ -144,7 +143,7 @@ def run_ort(model_name: str, directory: str, provider: str, batch_size: int, dis
run_ort_pipeline(pipe, batch_size, image_filename_prefix)
def run_torch(model_name: str, batch_size: int, disable_channels_last: bool, disable_safety_checker: bool):
def run_torch(model_name: str, batch_size: int, disable_safety_checker: bool):
import torch
torch.backends.cudnn.enabled = True
@ -154,13 +153,11 @@ def run_torch(model_name: str, batch_size: int, disable_channels_last: bool, dis
torch.set_grad_enabled(False)
load_start = time.time()
pipe = get_torch_pipeline(model_name, disable_channels_last, disable_safety_checker)
pipe = get_torch_pipeline(model_name, disable_safety_checker)
load_end = time.time()
print(f"Model loading took {load_end - load_start} seconds")
image_filename_prefix = get_image_filename_prefix("torch", model_name, batch_size, disable_safety_checker) + (
"" if disable_channels_last else "_channels_last"
)
image_filename_prefix = get_image_filename_prefix("torch", model_name, batch_size, disable_safety_checker)
with torch.inference_mode():
run_torch_pipeline(pipe, batch_size, image_filename_prefix)
@ -196,15 +193,6 @@ def parse_arguments():
help="Directory of saved onnx pipeline. It could be output directory of optimize_pipeline.py.",
)
parser.add_argument(
"-c",
"--disable_channels_last",
required=False,
action="store_true",
help="Disable channels last for torch. It will be ignored for onnxruntime engine",
)
parser.set_defaults(disable_channels_last=False)
parser.add_argument(
"--enable_safety_checker",
required=False,
@ -237,7 +225,7 @@ def main():
provider = "CUDAExecutionProvider" # TODO: use ["CUDAExecutionProvider", "CPUExecutionProvider"] in diffuers
run_ort(sd_model, args.pipeline, provider, args.batch_size, not args.enable_safety_checker)
else:
run_torch(sd_model, args.batch_size, args.disable_channels_last, not args.enable_safety_checker)
run_torch(sd_model, args.batch_size, not args.enable_safety_checker)
if __name__ == "__main__":

View file

@ -11,18 +11,15 @@
# huggingface-cli login
# wget https://raw.githubusercontent.com/huggingface/diffusers/v0.12.1/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
# python convert_stable_diffusion_checkpoint_to_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path $ONNX_ROOT/stable-diffusion-v1-5-fp32
# python convert_stable_diffusion_checkpoint_to_onnx.py --model_path stabilityai/stable-diffusion-2-1 --output_path $ONNX_ROOT/stable-diffusion-v2-1-fp32
# Note that this script might not be compatible with older or newer version of diffusers/transformers. It is because fusion script need change accordingly when onnx graph is changed.
# Note that this script might not be compatible with older or newer version of diffusers.
# Then you can use this script to convert them to float16 like the following:
# python optimize_pipeline.py -i $ONNX_ROOT/stable-diffusion-v1-5-fp32 -o $ONNX_ROOT/stable-diffusion-v1-5-fp16 --float16
# python optimize_pipeline.py -i $ONNX_ROOT/stable-diffusion-v2-1-fp32 -o $ONNX_ROOT/stable-diffusion-v2-1-fp16 --float16
# Or
# pip install -U onnxruntime-gpu >= 1.14
# python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i $ONNX_ROOT/stable-diffusion-v1-5-fp32 -o $ONNX_ROOT/stable-diffusion-v1-5-fp16 --float16
# python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i $ONNX_ROOT/stable-diffusion-v2-1-fp32 -o $ONNX_ROOT/stable-diffusion-v2-1-fp16 --float16
# Note that float16 model is for CUDA Execution Provider. It might not run in CPU Execution Provider.
#
# Note that output model is for CUDA Execution Provider. It might not run in CPU Execution Provider.
# Stable diffusion 2.1 model will get black images using float16 Attention. It is a known issue that we are working on.
import argparse
import logging
@ -40,7 +37,7 @@ from optimizer import optimize_model # noqa: E402
logger = logging.getLogger(__name__)
def optimize_stable_diffusion_onnx_pipeline(
def optimize_sd_pipeline(
source_dir: Path, target_dir: Path, overwrite: bool, use_external_data_format: bool, float16: bool
):
"""Optimize onnx models used in stable diffusion onnx pipeline and optionally convert to float16.
@ -66,23 +63,18 @@ def optimize_stable_diffusion_onnx_pipeline(
raise RuntimeError(message)
continue
num_heads = 0
hidden_size = 0
# Graph fusion before fp16 conversion, otherwise they cannot be fused later.
# Right now, onnxruntime does not save >2GB model so we use script to optimize unet instead.
logger.info(f"optimize {onnx_model_path}...")
fusion_options = FusionOptions("unet")
# packed kv requires compute capacity >= 7.5 (like T4, A100, RTX 2060~4090. See https://developer.nvidia.com/cuda-gpus)
# Suggest to disable it if you are using older GPU like V100, RTX 1060/1070/1080, or using float32 model.
fusion_options.enable_packed_kv = float16
m = optimize_model(
str(onnx_model_path),
model_type="unet",
num_heads=num_heads,
hidden_size=hidden_size,
num_heads=0, # will be deduced from graph
hidden_size=0, # will be deduced from graph
opt_level=0,
optimization_options=fusion_options,
use_gpu=False,
@ -211,7 +203,7 @@ def main():
coloredlogs.install(fmt="%(funcName)20s: %(message)s")
args = parse_arguments()
copy_extra_directory(Path(args.input), Path(args.output), args.overwrite)
optimize_stable_diffusion_onnx_pipeline(
optimize_sd_pipeline(
Path(args.input), Path(args.output), args.overwrite, args.use_external_data_format, args.float16
)

View file

@ -0,0 +1,14 @@
# Install the following package in python 3.10
diffusers==0.12.1
transformers==4.26.0
numpy==1.24.1
accelerate==0.15.0
onnxruntime-gpu>=1.14
onnx==1.13.0
coloredlogs
packaging==23.0
protobuf==3.20.3
psutil==5.9.4
sympy==1.11.1
--extra-index-url https://download.pytorch.org/whl/cu117
torch==1.13.1+cu117

View file

@ -128,6 +128,8 @@ class OnnxModel:
for graph in self.graphs():
if node in graph.node:
graph.node.remove(node)
return
logger.warning("Failed to remove node %s", node) # It might be a bug to hit this line.
def remove_nodes(self, nodes_to_remove):
for node in nodes_to_remove:
@ -182,6 +184,12 @@ class OnnxModel:
node.output[j] = new_output_name
def replace_output_of_all_nodes(self, old_output_name, new_output_name):
# This function shall be used carefully. For example:
# Add --[old_name]--> Cast ---> [new_name]
# |
# +----[old_name]--> Transpose -->
# If we want to remove the Cast node: replace output of Add to new_name is not enough;
# The input of Transpose shall also be updated to new_name.
for node in self.model.graph.node:
OnnxModel.replace_node_output(node, old_output_name, new_output_name)
@ -553,7 +561,9 @@ class OnnxModel:
graph_output_names = set(self.get_graphs_output_names())
for node in nodes_to_remove:
if bool(set(node.output) & graph_output_names):
if not bool(set(node.input) & graph_input_names):
if (not bool(set(node.input) & graph_input_names)) and len(
self.input_name_to_nodes()[node.input[0]]
) == 1:
self.replace_output_of_all_nodes(node.input[0], node.output[0])
else:
continue

View file

@ -9,8 +9,11 @@ from typing import Optional
from fusion_attention_unet import FusionAttentionUnet
from fusion_biassplitgelu import FusionBiasSplitGelu
from fusion_group_norm import FusionGroupNorm
from fusion_nhwc_conv import FusionNhwcConv
from fusion_options import FusionOptions
from fusion_transpose import FusionTranspose
from onnx import ModelProto
from onnx_model import OnnxModel
from onnx_model_bert import BertOnnxModel
logger = getLogger(__name__)
@ -30,10 +33,61 @@ class UnetOnnxModel(BertOnnxModel):
super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)
def preprocess(self):
return
self.remove_useless_div()
def postprocess(self):
self.merge_sequential_transpose()
self.prune_graph()
self.remove_unused_constant()
def remove_useless_div(self):
"""Remove Div by 1"""
div_nodes = [node for node in self.nodes() if node.op_type == "Div"]
nodes_to_remove = []
for div in div_nodes:
if self.find_constant_input(div, 1.0) == 1:
nodes_to_remove.append(div)
for node in nodes_to_remove:
self.replace_input_of_all_nodes(node.output[0], node.input[0])
if nodes_to_remove:
self.remove_nodes(nodes_to_remove)
logger.info("Removed %d useless Div (by 1) nodes", len(nodes_to_remove))
def convert_conv_to_nhwc(self):
# Do not update weight here since save external data has a bug
conv_to_nhwc_conv = FusionNhwcConv(self, update_weight=False)
conv_to_nhwc_conv.apply()
def merge_sequential_transpose(self):
fusion_transpose = FusionTranspose(self)
fusion_transpose.apply()
remove_count = 0
nodes = self.get_nodes_by_op_type("Transpose")
for node in nodes:
permutation = OnnxModel.get_node_attribute(node, "perm")
assert isinstance(permutation, list)
if permutation != list(range(len(permutation))):
continue
assert not (
self.find_graph_output(node.output[0])
or self.find_graph_input(node.input[0])
or self.find_graph_output(node.input[0])
)
# Let all children nodes skip current Transpose node and link to its parent
# Note that we cannot update parent node output since parent node might have more than one children.
self.replace_input_of_all_nodes(node.output[0], node.input[0])
self.remove_node(node)
remove_count += 1
total = len(fusion_transpose.nodes_to_remove) + remove_count
if total:
logger.info("Removed %d Transpose nodes", total)
def optimize(self, options: Optional[FusionOptions] = None):
if (options is not None) and not options.enable_shape_inference:
@ -78,7 +132,7 @@ class UnetOnnxModel(BertOnnxModel):
# Remove reshape nodes that having same shape of input and output based on symbolic shape inference.
self.utils.remove_useless_reshape_nodes()
self.postprocess()
self.convert_conv_to_nhwc()
if (options is None) or options.enable_bias_skip_layer_norm:
# Fuse SkipLayerNormalization and Add Bias before it.
@ -87,6 +141,29 @@ class UnetOnnxModel(BertOnnxModel):
if options is not None and options.enable_gelu_approximation:
self.gelu_approximation()
self.remove_unused_constant()
self.postprocess()
logger.info(f"opset version: {self.get_opset_version()}")
def get_fused_operator_statistics(self):
"""
Returns node count of fused operators.
"""
op_count = {}
ops = [
"Attention",
"MultiHeadAttention",
"Gelu",
"FastGelu",
"LayerNormalization",
"SkipLayerNormalization",
"BiasSplitGelu",
"GroupNorm",
"NhwcConv",
]
for op in ops:
nodes = self.get_nodes_by_op_type(op)
op_count[op] = len(nodes)
logger.info(f"Optimized operators:{op_count}")
return op_count

View file

@ -481,9 +481,12 @@ packages = [
"onnxruntime.quantization.operators",
"onnxruntime.quantization.CalTableFlatBuffers",
"onnxruntime.transformers",
"onnxruntime.transformers.models.bart",
"onnxruntime.transformers.models.bert",
"onnxruntime.transformers.models.gpt2",
"onnxruntime.transformers.models.longformer",
"onnxruntime.transformers.models.t5",
"onnxruntime.transformers.models.stable_diffusion",
]
package_data = {"onnxruntime.tools.mobile_helpers": ["*.md", "*.config"]}