mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Stable Diffusion CUDA optimizations Part 2 (#14597)
### Description This is a follow-up of https://github.com/microsoft/onnxruntime/pull/14428 for Stable Diffusion CUDA optimizations: (1) use NchwConv to replace Conv in onnx graph and add Tranpose nodes accordingly (2) reduce sequential Transpose nodes to at most one. (3) symbolic shape infer of NchwConv (4) fix add bias transpose which causes CUDA error (launching more than 1024 threads per block) in inferencing fp32 model. (5) add models (bert, bart, stable_diffusion subdirectories) to package; (6) remove option --disable_channels_last Note that (1) We can add a few graph transformations to reduce Transpose nodes further. It is not done in this PR due to time limit. (2) Stable diffusion 2.1 model outputs black images. It seems that forcing Attention to float32 could avoid the issue. However it is much slow to use float32 Attention. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
f88a4646cd
commit
742658d171
13 changed files with 386 additions and 61 deletions
|
|
@ -467,12 +467,21 @@ file(GLOB onnxruntime_python_quantization_cal_table_flatbuffers_src CONFIGURE_DE
|
|||
file(GLOB onnxruntime_python_transformers_src CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/python/tools/transformers/*.py"
|
||||
)
|
||||
file(GLOB onnxruntime_python_transformers_models_bart_src CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/bart/*.py"
|
||||
)
|
||||
file(GLOB onnxruntime_python_transformers_models_bert_src CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/bert/*.py"
|
||||
)
|
||||
file(GLOB onnxruntime_python_transformers_models_gpt2_src CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/gpt2/*.py"
|
||||
)
|
||||
file(GLOB onnxruntime_python_transformers_models_longformer_src CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/longformer/*.py"
|
||||
)
|
||||
file(GLOB onnxruntime_python_transformers_models_stable_diffusion_src CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/stable_diffusion/*.py"
|
||||
)
|
||||
file(GLOB onnxruntime_python_transformers_models_t5_src CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/t5/*.py"
|
||||
)
|
||||
|
|
@ -526,8 +535,11 @@ add_custom_command(
|
|||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/tools/ort_format_model/ort_flatbuffers_py
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/bart
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/bert
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/gpt2
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/longformer
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/stable_diffusion
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/t5
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/operators
|
||||
|
|
@ -606,12 +618,21 @@ add_custom_command(
|
|||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_transformers_src}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_transformers_models_bart_src}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/bart/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_transformers_models_bert_src}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/bert/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_transformers_models_gpt2_src}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/gpt2/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_transformers_models_longformer_src}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/longformer/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_transformers_models_stable_diffusion_src}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/stable_diffusion/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_transformers_models_t5_src}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/t5/
|
||||
|
|
|
|||
|
|
@ -519,6 +519,7 @@ void InvokeAddBiasTranspose(
|
|||
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
|
||||
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
|
||||
const T* input, const T* biases, T* output, T* qkv_add_bias, const int v_head_size, int total_matrix_count) {
|
||||
assert(num_heads <= max_threads_per_block);
|
||||
const dim3 grid(sequence_length, batch_size, num_matrices);
|
||||
if (qk_head_size * num_heads <= max_threads_per_block) {
|
||||
const dim3 block(qk_head_size, num_heads, 1);
|
||||
|
|
@ -544,7 +545,7 @@ void InvokeAddBiasTranspose(
|
|||
AddBiasTranspose<T><<<grid, block, 0, stream>>>(input, biases, output);
|
||||
}
|
||||
} else {
|
||||
const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
|
||||
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
|
||||
if (format == 2) {
|
||||
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output);
|
||||
} else if (format == 1) {
|
||||
|
|
@ -577,7 +578,7 @@ void LaunchAddBiasTranspose(
|
|||
const half* input, const half* biases, half* output,
|
||||
bool enable_half4, const int v_head_size, half* qkv_add_bias, int total_matrix_count) {
|
||||
total_matrix_count = std::max(num_matrices, total_matrix_count);
|
||||
if (enable_half4 && 0 == (qk_head_size % 4) && 0 == (v_head_size % 4)) {
|
||||
if (enable_half4 && 0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4))) {
|
||||
const int H = qk_head_size / 4;
|
||||
const int H_v = v_head_size / 4;
|
||||
const Half4* input2 = reinterpret_cast<const Half4*>(input);
|
||||
|
|
@ -587,7 +588,7 @@ void LaunchAddBiasTranspose(
|
|||
InvokeAddBiasTranspose<Half4>(stream, num_matrices, format, max_threads_per_block,
|
||||
batch_size, sequence_length, num_heads, H, input2, biases2, output2,
|
||||
qkv_add_bias2, H_v, total_matrix_count);
|
||||
} else if (0 == (qk_head_size & 1) && 0 == (v_head_size & 1)) {
|
||||
} else if (0 == (qk_head_size & 1) && (v_head_size == -1 || 0 == (v_head_size & 1))) {
|
||||
const int H = qk_head_size / 2;
|
||||
const int H_v = v_head_size / 2;
|
||||
const half2* input2 = reinterpret_cast<const half2*>(input);
|
||||
|
|
@ -612,7 +613,7 @@ void LaunchAddBiasTranspose(
|
|||
const float* input, const float* biases, float* output,
|
||||
bool /*enable_half4*/, const int v_head_size, float* qkv_add_bias, int total_matrix_count) {
|
||||
total_matrix_count = std::max(num_matrices, total_matrix_count);
|
||||
if (0 == (qk_head_size % 4) && 0 == (v_head_size % 4)) {
|
||||
if (0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4))) {
|
||||
const int H = qk_head_size / 4;
|
||||
const float4* input2 = reinterpret_cast<const float4*>(input);
|
||||
const float4* biases2 = reinterpret_cast<const float4*>(biases);
|
||||
|
|
@ -622,7 +623,7 @@ void LaunchAddBiasTranspose(
|
|||
stream, num_matrices, format, max_threads_per_block,
|
||||
batch_size, sequence_length, num_heads, H, input2, biases2, output2,
|
||||
qkv_add_bias2, v_head_size / 4, total_matrix_count);
|
||||
} else if (0 == (qk_head_size & 1) && 0 == (v_head_size & 1)) {
|
||||
} else if (0 == (qk_head_size & 1) && (v_head_size == -1 || 0 == (v_head_size & 1))) {
|
||||
const int H = qk_head_size / 2;
|
||||
const float2* input2 = reinterpret_cast<const float2*>(input);
|
||||
const float2* biases2 = reinterpret_cast<const float2*>(biases);
|
||||
|
|
@ -654,7 +655,7 @@ void InvokeAddBiasTransposeTrt(
|
|||
const dim3 block(head_size, num_heads, 1);
|
||||
AddBiasTransposeTrt<T><<<grid, block, 0, stream>>>(query, key, value, biases, output);
|
||||
} else {
|
||||
const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
|
||||
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
|
||||
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(head_size, query, key, value, biases, output);
|
||||
}
|
||||
} else { // cross attention
|
||||
|
|
@ -666,7 +667,7 @@ void InvokeAddBiasTransposeTrt(
|
|||
const dim3 block(head_size, num_heads, 1);
|
||||
AddBiasTransposeTrt<T><<<grid, block, 0, stream>>>(query, biases, output);
|
||||
} else {
|
||||
const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
|
||||
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
|
||||
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(head_size, query, biases, output);
|
||||
}
|
||||
}
|
||||
|
|
@ -680,7 +681,7 @@ void InvokeAddBiasTransposeTrt(
|
|||
const dim3 block(head_size, num_heads, 1);
|
||||
AddBiasTransposeTrtKV<T><<<grid, block, 0, stream>>>(key, value, biases, packed_kv);
|
||||
} else {
|
||||
const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
|
||||
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
|
||||
AddBiasTransposeTrtKVLarge<T><<<grid, block, 0, stream>>>(head_size, key, value, biases, packed_kv);
|
||||
}
|
||||
}
|
||||
|
|
@ -737,6 +738,7 @@ void InvokeAddBias(
|
|||
const int batch_size, const int sequence_length, const int kv_sequence_length,
|
||||
const int num_heads, const int head_size, const int v_head_size,
|
||||
const T* biases, const T* query, const T* key, const T* value, T* q, T* k, T* v) {
|
||||
assert(num_heads <= max_threads_per_block);
|
||||
constexpr int num_matrices = 1;
|
||||
// Q
|
||||
{
|
||||
|
|
@ -745,7 +747,7 @@ void InvokeAddBias(
|
|||
const dim3 block(head_size, num_heads, 1);
|
||||
AddBiasTransposeTrt<T><<<grid, block, 0, stream>>>(query, biases, q);
|
||||
} else {
|
||||
const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
|
||||
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
|
||||
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(head_size, query, biases, q);
|
||||
}
|
||||
}
|
||||
|
|
@ -758,7 +760,7 @@ void InvokeAddBias(
|
|||
const dim3 block(head_size, num_heads, 1);
|
||||
AddBiasTransposeTrt<T><<<grid, block, 0, stream>>>(key, biases_k, k);
|
||||
} else {
|
||||
const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
|
||||
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
|
||||
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(head_size, key, biases_k, k);
|
||||
}
|
||||
}
|
||||
|
|
@ -772,7 +774,7 @@ void InvokeAddBias(
|
|||
const dim3 block(v_head_size, num_heads, 1);
|
||||
AddBiasTransposeTrt<T><<<grid, block, 0, stream>>>(value, biases_v, v);
|
||||
} else {
|
||||
const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
|
||||
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
|
||||
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(v_head_size, value, biases_v, v);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -203,6 +203,7 @@ class SymbolicShapeInference:
|
|||
"SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization,
|
||||
"GroupNorm": self._infer_GroupNorm,
|
||||
"BiasSplitGelu": self._infer_BiasSplitGelu,
|
||||
"NhwcConv": self._infer_NhwcConv,
|
||||
}
|
||||
self.aten_op_dispatcher_ = {
|
||||
"embedding": self._infer_Gather,
|
||||
|
|
@ -442,6 +443,7 @@ class SymbolicShapeInference:
|
|||
"MultiHeadAttention",
|
||||
"GroupNorm",
|
||||
"BiasSplitGelu",
|
||||
"NhwcConv",
|
||||
]
|
||||
|
||||
if not skip_infer:
|
||||
|
|
@ -623,13 +625,13 @@ class SymbolicShapeInference:
|
|||
def _new_symbolic_shape(self, rank, node, out_idx=0):
|
||||
return [self._new_symbolic_dim_from_output(node, out_idx, i) for i in range(rank)]
|
||||
|
||||
def _compute_conv_pool_shape(self, node):
|
||||
def _compute_conv_pool_shape(self, node, channels_last=False):
|
||||
sympy_shape = self._get_sympy_shape(node, 0)
|
||||
if len(node.input) > 1:
|
||||
W_shape = self._get_sympy_shape(node, 1)
|
||||
rank = len(W_shape) - 2 # number of spatial axes
|
||||
kernel_shape = W_shape[-rank:]
|
||||
sympy_shape[1] = W_shape[0]
|
||||
kernel_shape = W_shape[-rank - 1 : -1] if channels_last else W_shape[-rank:]
|
||||
sympy_shape[3 if channels_last else 1] = W_shape[0]
|
||||
else:
|
||||
W_shape = None
|
||||
kernel_shape = get_attribute(node, "kernel_shape")
|
||||
|
|
@ -638,13 +640,17 @@ class SymbolicShapeInference:
|
|||
assert len(sympy_shape) == rank + 2
|
||||
|
||||
# only need to symbolic shape inference if input has symbolic dims in spatial axes
|
||||
is_symbolic_dims = [not is_literal(i) for i in sympy_shape[-rank:]]
|
||||
spatial_shape = sympy_shape[-rank - 1 : -1] if channels_last else sympy_shape[-rank:]
|
||||
is_symbolic_dims = [not is_literal(i) for i in spatial_shape]
|
||||
|
||||
if not any(is_symbolic_dims):
|
||||
shape = get_shape_from_value_info(self.known_vi_[node.output[0]])
|
||||
if len(shape) > 0:
|
||||
assert len(sympy_shape) == len(shape)
|
||||
sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]]
|
||||
if channels_last:
|
||||
sympy_shape[-rank - 1 : -1] = [sympy.Integer(d) for d in shape[-rank - 1 : -1]]
|
||||
else:
|
||||
sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]]
|
||||
return sympy_shape
|
||||
|
||||
dilations = get_attribute(node, "dilations", [1] * rank)
|
||||
|
|
@ -675,7 +681,7 @@ class SymbolicShapeInference:
|
|||
|
||||
ceil_mode = get_attribute(node, "ceil_mode", 0)
|
||||
for i in range(rank):
|
||||
effective_input_size = sympy_shape[-rank + i]
|
||||
effective_input_size = sympy_shape[-rank + i + (-1 if channels_last else 0)]
|
||||
if len(total_pads) > 0:
|
||||
effective_input_size = effective_input_size + total_pads[i]
|
||||
if ceil_mode:
|
||||
|
|
@ -684,7 +690,7 @@ class SymbolicShapeInference:
|
|||
)
|
||||
else:
|
||||
strided_kernel_positions = (effective_input_size - effective_kernel_shape[i]) // strides[i]
|
||||
sympy_shape[-rank + i] = strided_kernel_positions + 1
|
||||
sympy_shape[-rank + i + (-1 if channels_last else 0)] = strided_kernel_positions + 1
|
||||
return sympy_shape
|
||||
|
||||
def _check_merged_dims(self, dims, allow_broadcast=True):
|
||||
|
|
@ -918,6 +924,18 @@ class SymbolicShapeInference:
|
|||
)
|
||||
)
|
||||
|
||||
def _infer_NhwcConv(self, node):
|
||||
sympy_shape = self._compute_conv_pool_shape(node, channels_last=True)
|
||||
self._update_computed_dims(sympy_shape)
|
||||
vi = self.known_vi_[node.output[0]]
|
||||
vi.CopyFrom(
|
||||
helper.make_tensor_value_info(
|
||||
node.output[0],
|
||||
self.known_vi_[node.input[0]].type.tensor_type.elem_type,
|
||||
get_shape_from_sympy_shape(sympy_shape),
|
||||
)
|
||||
)
|
||||
|
||||
def _infer_Einsum(self, node):
|
||||
# ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275
|
||||
equation = get_attribute(node, "equation")
|
||||
|
|
@ -2459,6 +2477,7 @@ class SymbolicShapeInference:
|
|||
all_shapes_inferred = symbolic_shape_inference._infer_impl()
|
||||
symbolic_shape_inference._update_output_from_vi()
|
||||
if not all_shapes_inferred:
|
||||
onnx.save_model(symbolic_shape_inference.out_mp_, "sym_shape_infer_temp.onnx", save_as_external_data=True)
|
||||
raise Exception("Incomplete symbolic shape inference")
|
||||
return symbolic_shape_inference.out_mp_
|
||||
|
||||
|
|
|
|||
90
onnxruntime/python/tools/transformers/fusion_nhwc_conv.py
Normal file
90
onnxruntime/python/tools/transformers/fusion_nhwc_conv.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from logging import getLogger
|
||||
from typing import List
|
||||
|
||||
from fusion_base import Fusion
|
||||
from onnx import TensorProto, helper, numpy_helper
|
||||
from onnx_model import OnnxModel
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class FusionNhwcConv(Fusion):
|
||||
"""Convert Conv to NhwcConv"""
|
||||
|
||||
def __init__(self, model: OnnxModel, update_weight=False):
|
||||
super().__init__(model, "NhwcConv", ["Conv"], "NhwcConv")
|
||||
self.update_weight = update_weight
|
||||
|
||||
def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
|
||||
"""Append a Transpose node after an input"""
|
||||
node_name = self.model.create_node_name("Transpose")
|
||||
|
||||
if output_name is None:
|
||||
output_name = node_name + "_out" + "-" + input_name
|
||||
|
||||
transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
|
||||
transpose_node.attribute.extend([helper.make_attribute("perm", perm)])
|
||||
|
||||
return transpose_node
|
||||
|
||||
def fuse(self, conv, input_name_to_nodes, output_name_to_node):
|
||||
# Add Transpose node to convert input from NCHW to NHWC
|
||||
input_transpose_node = self.create_transpose_node(conv.input[0], [0, 2, 3, 1])
|
||||
|
||||
nhwc_conv_input = input_transpose_node.output[0]
|
||||
|
||||
# Create a tensor for transposed weights (already in NHWC format).
|
||||
node_name = self.model.create_node_name("NhwcConv")
|
||||
|
||||
# Make sure the weights is 4D
|
||||
weight_tensor = self.model.get_initializer(conv.input[1])
|
||||
if weight_tensor is None:
|
||||
return
|
||||
weight = numpy_helper.to_array(weight_tensor)
|
||||
if len(weight.shape) != 4:
|
||||
return
|
||||
|
||||
if self.update_weight:
|
||||
# Transpose weights from NCHW to NHWC
|
||||
weight = weight.transpose(0, 2, 3, 1)
|
||||
|
||||
weight_name = node_name + "_weight_NHWC"
|
||||
nhwc_weight = helper.make_tensor(
|
||||
name=weight_name,
|
||||
data_type=TensorProto.FLOAT,
|
||||
dims=list(weight.shape),
|
||||
vals=weight.flatten().tolist(),
|
||||
)
|
||||
self.model.add_initializer(nhwc_weight, self.this_graph_name)
|
||||
weight_transpose_node = None
|
||||
else:
|
||||
weight_transpose_node = self.create_transpose_node(conv.input[1], [0, 2, 3, 1])
|
||||
weight_name = weight_transpose_node.output[0]
|
||||
|
||||
nhwc_output_name = node_name + "_out" + "-" + conv.output[0]
|
||||
nhwc_conv = helper.make_node(
|
||||
"NhwcConv",
|
||||
inputs=[nhwc_conv_input, weight_name] + conv.input[2:],
|
||||
outputs=[nhwc_output_name],
|
||||
name=node_name + "-" + conv.name,
|
||||
)
|
||||
nhwc_conv.attribute.extend(conv.attribute)
|
||||
nhwc_conv.domain = "com.microsoft"
|
||||
|
||||
output_transpose_node = self.create_transpose_node(nhwc_conv.output[0], [0, 3, 1, 2], conv.output[0])
|
||||
|
||||
self.nodes_to_remove.append(conv)
|
||||
|
||||
nodes_to_add = [input_transpose_node, nhwc_conv, output_transpose_node]
|
||||
if weight_transpose_node:
|
||||
nodes_to_add.append(weight_transpose_node)
|
||||
for node in nodes_to_add:
|
||||
self.node_name_to_graph_name[node.name] = self.this_graph_name
|
||||
self.nodes_to_add.extend(nodes_to_add)
|
||||
|
||||
self.increase_counter("NhwcConv")
|
||||
|
|
@ -119,16 +119,15 @@ class FusionReshape(Fusion):
|
|||
shape_nodes.extend([path2[-1], path3[-1]])
|
||||
shape.append(-1)
|
||||
elif len(concat_node.input) > 2:
|
||||
concat_2 = self.model.get_initializer(concat_node.input[2])
|
||||
if concat_2 is None:
|
||||
concat_value = self.model.get_constant_value(concat_node.input[2])
|
||||
if concat_value is None:
|
||||
return
|
||||
concat_value = numpy_helper.to_array(concat_2)
|
||||
if isinstance(concat_value, np.ndarray):
|
||||
shape.extend(concat_value.tolist())
|
||||
else:
|
||||
shape.append(concat_value)
|
||||
|
||||
if len(concat_node.input) == 4 and self.model.get_initializer(concat_node.input[3]) is None:
|
||||
if len(concat_node.input) == 4 and self.model.get_constant_value(concat_node.input[3]) is None:
|
||||
if -1 in shape:
|
||||
return
|
||||
|
||||
|
|
|
|||
81
onnxruntime/python/tools/transformers/fusion_transpose.py
Normal file
81
onnxruntime/python/tools/transformers/fusion_transpose.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from logging import getLogger
|
||||
from typing import Dict, List
|
||||
|
||||
from fusion_base import Fusion
|
||||
from fusion_utils import FusionUtils
|
||||
from onnx import NodeProto, helper
|
||||
from onnx_model import OnnxModel
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class FusionTranspose(Fusion):
|
||||
def __init__(self, model: OnnxModel):
|
||||
super().__init__(model, "Transpose", "Transpose")
|
||||
|
||||
def fuse(
|
||||
self,
|
||||
transpose_node: NodeProto,
|
||||
input_name_to_nodes: Dict[str, List[NodeProto]],
|
||||
output_name_to_node: Dict[str, NodeProto],
|
||||
):
|
||||
"""
|
||||
Case 1:
|
||||
(input)-->Transpose(perm=a)-->Transpose(perm=b)-->
|
||||
After:
|
||||
(input)-->Transpose(perm=a)--> (this path can be removed if the output is not used anymore)
|
||||
|
|
||||
+----->Transpose(perm=a*b)-->
|
||||
|
||||
Case 2 (Cast has only one child):
|
||||
(input)-->Transpose(perm=a)--> Cast -->Transpose(perm=b)-->
|
||||
After:
|
||||
(input)-->Transpose(perm=a)--> (this path can be removed if the output is not used anymore)
|
||||
|
|
||||
+----->Cast --> Transpose(perm=a*b)-->
|
||||
|
||||
|
||||
"""
|
||||
transpose_b = transpose_node
|
||||
if transpose_b.input[0] not in output_name_to_node:
|
||||
return
|
||||
|
||||
transpose_a = output_name_to_node[transpose_b.input[0]]
|
||||
if transpose_a.op_type != "Cast":
|
||||
cast_node = None
|
||||
else:
|
||||
cast_node = transpose_a
|
||||
|
||||
cast_children = self.model.get_children(cast_node, input_name_to_nodes)
|
||||
if cast_children and len(cast_children) > 1:
|
||||
return
|
||||
transpose_a = output_name_to_node[cast_node.input[0]]
|
||||
|
||||
if transpose_a.op_type != "Transpose":
|
||||
return
|
||||
|
||||
permutation = OnnxModel.get_node_attribute(transpose_b, "perm")
|
||||
assert isinstance(permutation, list)
|
||||
|
||||
parent_permutation = OnnxModel.get_node_attribute(transpose_a, "perm")
|
||||
assert isinstance(parent_permutation, list)
|
||||
|
||||
assert len(parent_permutation) == len(permutation)
|
||||
|
||||
output_permutation = []
|
||||
for j, index in enumerate(permutation):
|
||||
output_permutation.append(parent_permutation[index])
|
||||
|
||||
if cast_node is None:
|
||||
if FusionUtils.skip_parent(self.model, transpose_b, transpose_a, input_name_to_nodes):
|
||||
self.nodes_to_remove.append(transpose_a)
|
||||
else:
|
||||
if FusionUtils.skip_parent(self.model, cast_node, transpose_a, input_name_to_nodes):
|
||||
self.nodes_to_remove.append(transpose_a)
|
||||
transpose_b.ClearField("attribute")
|
||||
transpose_b.attribute.extend([helper.make_attribute("perm", output_permutation)])
|
||||
|
|
@ -73,6 +73,32 @@ class FusionUtils:
|
|||
self.model.remove_node(node)
|
||||
self.model.replace_input_of_all_nodes(output_name, input_name)
|
||||
|
||||
@staticmethod
|
||||
def skip_parent(model: OnnxModel, node, parent_node, input_name_to_nodes):
|
||||
"""
|
||||
Before:
|
||||
(input)-->parent-->node-->(output)
|
||||
After:
|
||||
(input)-->parent-->
|
||||
|
|
||||
+----->node-->(output)
|
||||
|
||||
This function returns a flag about whether the parent node can be removed.
|
||||
Note that this function assumes the node has first input links from parent!
|
||||
"""
|
||||
parent_can_be_removed = False
|
||||
input_name_to_nodes[node.input[0]].remove(node)
|
||||
# We can remove the first Transpose if its output is not used (linked to graph output or other nodes) anymore.
|
||||
if len(input_name_to_nodes[node.input[0]]) == 0 and not model.find_graph_output(
|
||||
node.input[0]
|
||||
): # checks main graph output. TODO: deal with subgraph
|
||||
parent_can_be_removed = True
|
||||
# self.nodes_to_remove.append(transpose_a)
|
||||
|
||||
input_name_to_nodes[parent_node.input[0]].append(node)
|
||||
node.input[0] = parent_node.input[0]
|
||||
return parent_can_be_removed
|
||||
|
||||
@staticmethod
|
||||
def check_node_attribute(node, attribute_name: str, expected_value, default_value=None):
|
||||
"""Verify that a node has expected value for an attribute.
|
||||
|
|
@ -228,7 +254,10 @@ class FusionUtils:
|
|||
graph_output_names = set(self.model.get_graphs_output_names())
|
||||
for node in nodes_to_remove:
|
||||
if bool(set(node.output) & graph_output_names):
|
||||
if not bool(set(node.input) & graph_input_names):
|
||||
if (
|
||||
not bool(set(node.input) & graph_input_names)
|
||||
and len(self.model.input_name_to_nodes()[node.input[0]]) == 1 # parent has only one child
|
||||
):
|
||||
self.model.replace_output_of_all_nodes(node.input[0], node.output[0])
|
||||
else:
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ def get_ort_pipeline(model_name: str, directory: str, provider: str, disable_saf
|
|||
return pipe
|
||||
|
||||
|
||||
def get_torch_pipeline(model_name: str, disable_channels_last: bool, disable_safety_checker: bool):
|
||||
def get_torch_pipeline(model_name: str, disable_safety_checker: bool):
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from torch import channels_last, float16
|
||||
|
||||
|
|
@ -70,8 +70,7 @@ def get_torch_pipeline(model_name: str, disable_channels_last: bool, disable_saf
|
|||
model_name, torch_dtype=float16, revision="fp16", use_auth_token=True
|
||||
).to("cuda")
|
||||
|
||||
if not disable_channels_last:
|
||||
pipe.unet.to(memory_format=channels_last) # in-place operation
|
||||
pipe.unet.to(memory_format=channels_last) # in-place operation
|
||||
|
||||
if disable_safety_checker:
|
||||
pipe.safety_checker = None
|
||||
|
|
@ -144,7 +143,7 @@ def run_ort(model_name: str, directory: str, provider: str, batch_size: int, dis
|
|||
run_ort_pipeline(pipe, batch_size, image_filename_prefix)
|
||||
|
||||
|
||||
def run_torch(model_name: str, batch_size: int, disable_channels_last: bool, disable_safety_checker: bool):
|
||||
def run_torch(model_name: str, batch_size: int, disable_safety_checker: bool):
|
||||
import torch
|
||||
|
||||
torch.backends.cudnn.enabled = True
|
||||
|
|
@ -154,13 +153,11 @@ def run_torch(model_name: str, batch_size: int, disable_channels_last: bool, dis
|
|||
torch.set_grad_enabled(False)
|
||||
|
||||
load_start = time.time()
|
||||
pipe = get_torch_pipeline(model_name, disable_channels_last, disable_safety_checker)
|
||||
pipe = get_torch_pipeline(model_name, disable_safety_checker)
|
||||
load_end = time.time()
|
||||
print(f"Model loading took {load_end - load_start} seconds")
|
||||
|
||||
image_filename_prefix = get_image_filename_prefix("torch", model_name, batch_size, disable_safety_checker) + (
|
||||
"" if disable_channels_last else "_channels_last"
|
||||
)
|
||||
image_filename_prefix = get_image_filename_prefix("torch", model_name, batch_size, disable_safety_checker)
|
||||
with torch.inference_mode():
|
||||
run_torch_pipeline(pipe, batch_size, image_filename_prefix)
|
||||
|
||||
|
|
@ -196,15 +193,6 @@ def parse_arguments():
|
|||
help="Directory of saved onnx pipeline. It could be output directory of optimize_pipeline.py.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--disable_channels_last",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Disable channels last for torch. It will be ignored for onnxruntime engine",
|
||||
)
|
||||
parser.set_defaults(disable_channels_last=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--enable_safety_checker",
|
||||
required=False,
|
||||
|
|
@ -237,7 +225,7 @@ def main():
|
|||
provider = "CUDAExecutionProvider" # TODO: use ["CUDAExecutionProvider", "CPUExecutionProvider"] in diffuers
|
||||
run_ort(sd_model, args.pipeline, provider, args.batch_size, not args.enable_safety_checker)
|
||||
else:
|
||||
run_torch(sd_model, args.batch_size, args.disable_channels_last, not args.enable_safety_checker)
|
||||
run_torch(sd_model, args.batch_size, not args.enable_safety_checker)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -11,18 +11,15 @@
|
|||
# huggingface-cli login
|
||||
# wget https://raw.githubusercontent.com/huggingface/diffusers/v0.12.1/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||
# python convert_stable_diffusion_checkpoint_to_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path $ONNX_ROOT/stable-diffusion-v1-5-fp32
|
||||
# python convert_stable_diffusion_checkpoint_to_onnx.py --model_path stabilityai/stable-diffusion-2-1 --output_path $ONNX_ROOT/stable-diffusion-v2-1-fp32
|
||||
# Note that this script might not be compatible with older or newer version of diffusers/transformers. It is because fusion script need change accordingly when onnx graph is changed.
|
||||
# Note that this script might not be compatible with older or newer version of diffusers.
|
||||
|
||||
# Then you can use this script to convert them to float16 like the following:
|
||||
# python optimize_pipeline.py -i $ONNX_ROOT/stable-diffusion-v1-5-fp32 -o $ONNX_ROOT/stable-diffusion-v1-5-fp16 --float16
|
||||
# python optimize_pipeline.py -i $ONNX_ROOT/stable-diffusion-v2-1-fp32 -o $ONNX_ROOT/stable-diffusion-v2-1-fp16 --float16
|
||||
# Or
|
||||
# pip install -U onnxruntime-gpu >= 1.14
|
||||
# python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i $ONNX_ROOT/stable-diffusion-v1-5-fp32 -o $ONNX_ROOT/stable-diffusion-v1-5-fp16 --float16
|
||||
# python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i $ONNX_ROOT/stable-diffusion-v2-1-fp32 -o $ONNX_ROOT/stable-diffusion-v2-1-fp16 --float16
|
||||
|
||||
# Note that float16 model is for CUDA Execution Provider. It might not run in CPU Execution Provider.
|
||||
#
|
||||
# Note that output model is for CUDA Execution Provider. It might not run in CPU Execution Provider.
|
||||
# Stable diffusion 2.1 model will get black images using float16 Attention. It is a known issue that we are working on.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
|
@ -40,7 +37,7 @@ from optimizer import optimize_model # noqa: E402
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def optimize_stable_diffusion_onnx_pipeline(
|
||||
def optimize_sd_pipeline(
|
||||
source_dir: Path, target_dir: Path, overwrite: bool, use_external_data_format: bool, float16: bool
|
||||
):
|
||||
"""Optimize onnx models used in stable diffusion onnx pipeline and optionally convert to float16.
|
||||
|
|
@ -66,23 +63,18 @@ def optimize_stable_diffusion_onnx_pipeline(
|
|||
raise RuntimeError(message)
|
||||
continue
|
||||
|
||||
num_heads = 0
|
||||
hidden_size = 0
|
||||
|
||||
# Graph fusion before fp16 conversion, otherwise they cannot be fused later.
|
||||
# Right now, onnxruntime does not save >2GB model so we use script to optimize unet instead.
|
||||
logger.info(f"optimize {onnx_model_path}...")
|
||||
|
||||
fusion_options = FusionOptions("unet")
|
||||
# packed kv requires compute capacity >= 7.5 (like T4, A100, RTX 2060~4090. See https://developer.nvidia.com/cuda-gpus)
|
||||
# Suggest to disable it if you are using older GPU like V100, RTX 1060/1070/1080, or using float32 model.
|
||||
fusion_options.enable_packed_kv = float16
|
||||
|
||||
m = optimize_model(
|
||||
str(onnx_model_path),
|
||||
model_type="unet",
|
||||
num_heads=num_heads,
|
||||
hidden_size=hidden_size,
|
||||
num_heads=0, # will be deduced from graph
|
||||
hidden_size=0, # will be deduced from graph
|
||||
opt_level=0,
|
||||
optimization_options=fusion_options,
|
||||
use_gpu=False,
|
||||
|
|
@ -211,7 +203,7 @@ def main():
|
|||
coloredlogs.install(fmt="%(funcName)20s: %(message)s")
|
||||
args = parse_arguments()
|
||||
copy_extra_directory(Path(args.input), Path(args.output), args.overwrite)
|
||||
optimize_stable_diffusion_onnx_pipeline(
|
||||
optimize_sd_pipeline(
|
||||
Path(args.input), Path(args.output), args.overwrite, args.use_external_data_format, args.float16
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,14 @@
|
|||
# Install the following package in python 3.10
|
||||
diffusers==0.12.1
|
||||
transformers==4.26.0
|
||||
numpy==1.24.1
|
||||
accelerate==0.15.0
|
||||
onnxruntime-gpu>=1.14
|
||||
onnx==1.13.0
|
||||
coloredlogs
|
||||
packaging==23.0
|
||||
protobuf==3.20.3
|
||||
psutil==5.9.4
|
||||
sympy==1.11.1
|
||||
--extra-index-url https://download.pytorch.org/whl/cu117
|
||||
torch==1.13.1+cu117
|
||||
|
|
@ -128,6 +128,8 @@ class OnnxModel:
|
|||
for graph in self.graphs():
|
||||
if node in graph.node:
|
||||
graph.node.remove(node)
|
||||
return
|
||||
logger.warning("Failed to remove node %s", node) # It might be a bug to hit this line.
|
||||
|
||||
def remove_nodes(self, nodes_to_remove):
|
||||
for node in nodes_to_remove:
|
||||
|
|
@ -182,6 +184,12 @@ class OnnxModel:
|
|||
node.output[j] = new_output_name
|
||||
|
||||
def replace_output_of_all_nodes(self, old_output_name, new_output_name):
|
||||
# This function shall be used carefully. For example:
|
||||
# Add --[old_name]--> Cast ---> [new_name]
|
||||
# |
|
||||
# +----[old_name]--> Transpose -->
|
||||
# If we want to remove the Cast node: replace output of Add to new_name is not enough;
|
||||
# The input of Transpose shall also be updated to new_name.
|
||||
for node in self.model.graph.node:
|
||||
OnnxModel.replace_node_output(node, old_output_name, new_output_name)
|
||||
|
||||
|
|
@ -553,7 +561,9 @@ class OnnxModel:
|
|||
graph_output_names = set(self.get_graphs_output_names())
|
||||
for node in nodes_to_remove:
|
||||
if bool(set(node.output) & graph_output_names):
|
||||
if not bool(set(node.input) & graph_input_names):
|
||||
if (not bool(set(node.input) & graph_input_names)) and len(
|
||||
self.input_name_to_nodes()[node.input[0]]
|
||||
) == 1:
|
||||
self.replace_output_of_all_nodes(node.input[0], node.output[0])
|
||||
else:
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -9,8 +9,11 @@ from typing import Optional
|
|||
from fusion_attention_unet import FusionAttentionUnet
|
||||
from fusion_biassplitgelu import FusionBiasSplitGelu
|
||||
from fusion_group_norm import FusionGroupNorm
|
||||
from fusion_nhwc_conv import FusionNhwcConv
|
||||
from fusion_options import FusionOptions
|
||||
from fusion_transpose import FusionTranspose
|
||||
from onnx import ModelProto
|
||||
from onnx_model import OnnxModel
|
||||
from onnx_model_bert import BertOnnxModel
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
@ -30,10 +33,61 @@ class UnetOnnxModel(BertOnnxModel):
|
|||
super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)
|
||||
|
||||
def preprocess(self):
|
||||
return
|
||||
self.remove_useless_div()
|
||||
|
||||
def postprocess(self):
|
||||
self.merge_sequential_transpose()
|
||||
self.prune_graph()
|
||||
self.remove_unused_constant()
|
||||
|
||||
def remove_useless_div(self):
|
||||
"""Remove Div by 1"""
|
||||
div_nodes = [node for node in self.nodes() if node.op_type == "Div"]
|
||||
|
||||
nodes_to_remove = []
|
||||
for div in div_nodes:
|
||||
if self.find_constant_input(div, 1.0) == 1:
|
||||
nodes_to_remove.append(div)
|
||||
|
||||
for node in nodes_to_remove:
|
||||
self.replace_input_of_all_nodes(node.output[0], node.input[0])
|
||||
|
||||
if nodes_to_remove:
|
||||
self.remove_nodes(nodes_to_remove)
|
||||
logger.info("Removed %d useless Div (by 1) nodes", len(nodes_to_remove))
|
||||
|
||||
def convert_conv_to_nhwc(self):
|
||||
# Do not update weight here since save external data has a bug
|
||||
conv_to_nhwc_conv = FusionNhwcConv(self, update_weight=False)
|
||||
conv_to_nhwc_conv.apply()
|
||||
|
||||
def merge_sequential_transpose(self):
|
||||
fusion_transpose = FusionTranspose(self)
|
||||
fusion_transpose.apply()
|
||||
|
||||
remove_count = 0
|
||||
nodes = self.get_nodes_by_op_type("Transpose")
|
||||
for node in nodes:
|
||||
permutation = OnnxModel.get_node_attribute(node, "perm")
|
||||
assert isinstance(permutation, list)
|
||||
if permutation != list(range(len(permutation))):
|
||||
continue
|
||||
assert not (
|
||||
self.find_graph_output(node.output[0])
|
||||
or self.find_graph_input(node.input[0])
|
||||
or self.find_graph_output(node.input[0])
|
||||
)
|
||||
|
||||
# Let all children nodes skip current Transpose node and link to its parent
|
||||
# Note that we cannot update parent node output since parent node might have more than one children.
|
||||
self.replace_input_of_all_nodes(node.output[0], node.input[0])
|
||||
|
||||
self.remove_node(node)
|
||||
remove_count += 1
|
||||
|
||||
total = len(fusion_transpose.nodes_to_remove) + remove_count
|
||||
if total:
|
||||
logger.info("Removed %d Transpose nodes", total)
|
||||
|
||||
def optimize(self, options: Optional[FusionOptions] = None):
|
||||
if (options is not None) and not options.enable_shape_inference:
|
||||
|
|
@ -78,7 +132,7 @@ class UnetOnnxModel(BertOnnxModel):
|
|||
# Remove reshape nodes that having same shape of input and output based on symbolic shape inference.
|
||||
self.utils.remove_useless_reshape_nodes()
|
||||
|
||||
self.postprocess()
|
||||
self.convert_conv_to_nhwc()
|
||||
|
||||
if (options is None) or options.enable_bias_skip_layer_norm:
|
||||
# Fuse SkipLayerNormalization and Add Bias before it.
|
||||
|
|
@ -87,6 +141,29 @@ class UnetOnnxModel(BertOnnxModel):
|
|||
if options is not None and options.enable_gelu_approximation:
|
||||
self.gelu_approximation()
|
||||
|
||||
self.remove_unused_constant()
|
||||
self.postprocess()
|
||||
|
||||
logger.info(f"opset version: {self.get_opset_version()}")
|
||||
|
||||
def get_fused_operator_statistics(self):
|
||||
"""
|
||||
Returns node count of fused operators.
|
||||
"""
|
||||
op_count = {}
|
||||
ops = [
|
||||
"Attention",
|
||||
"MultiHeadAttention",
|
||||
"Gelu",
|
||||
"FastGelu",
|
||||
"LayerNormalization",
|
||||
"SkipLayerNormalization",
|
||||
"BiasSplitGelu",
|
||||
"GroupNorm",
|
||||
"NhwcConv",
|
||||
]
|
||||
for op in ops:
|
||||
nodes = self.get_nodes_by_op_type(op)
|
||||
op_count[op] = len(nodes)
|
||||
|
||||
logger.info(f"Optimized operators:{op_count}")
|
||||
return op_count
|
||||
|
|
|
|||
3
setup.py
3
setup.py
|
|
@ -481,9 +481,12 @@ packages = [
|
|||
"onnxruntime.quantization.operators",
|
||||
"onnxruntime.quantization.CalTableFlatBuffers",
|
||||
"onnxruntime.transformers",
|
||||
"onnxruntime.transformers.models.bart",
|
||||
"onnxruntime.transformers.models.bert",
|
||||
"onnxruntime.transformers.models.gpt2",
|
||||
"onnxruntime.transformers.models.longformer",
|
||||
"onnxruntime.transformers.models.t5",
|
||||
"onnxruntime.transformers.models.stable_diffusion",
|
||||
]
|
||||
|
||||
package_data = {"onnxruntime.tools.mobile_helpers": ["*.md", "*.config"]}
|
||||
|
|
|
|||
Loading…
Reference in a new issue