diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 809a076443..c24b6b9be5 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -467,12 +467,21 @@ file(GLOB onnxruntime_python_quantization_cal_table_flatbuffers_src CONFIGURE_DE file(GLOB onnxruntime_python_transformers_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/*.py" ) +file(GLOB onnxruntime_python_transformers_models_bart_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/bart/*.py" +) +file(GLOB onnxruntime_python_transformers_models_bert_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/bert/*.py" +) file(GLOB onnxruntime_python_transformers_models_gpt2_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/gpt2/*.py" ) file(GLOB onnxruntime_python_transformers_models_longformer_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/longformer/*.py" ) +file(GLOB onnxruntime_python_transformers_models_stable_diffusion_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/stable_diffusion/*.py" +) file(GLOB onnxruntime_python_transformers_models_t5_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/t5/*.py" ) @@ -526,8 +535,11 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/tools/ort_format_model/ort_flatbuffers_py COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/bart + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/bert COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/gpt2 COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/longformer + COMMAND ${CMAKE_COMMAND} -E make_directory 
$/onnxruntime/transformers/models/stable_diffusion COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/t5 COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/operators @@ -606,12 +618,21 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_src} $/onnxruntime/transformers/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_transformers_models_bart_src} + $/onnxruntime/transformers/models/bart/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_transformers_models_bert_src} + $/onnxruntime/transformers/models/bert/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_gpt2_src} $/onnxruntime/transformers/models/gpt2/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_longformer_src} $/onnxruntime/transformers/models/longformer/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_transformers_models_stable_diffusion_src} + $/onnxruntime/transformers/models/stable_diffusion/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_t5_src} $/onnxruntime/transformers/models/t5/ diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index e86736726c..8f271ecfcb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -519,6 +519,7 @@ void InvokeAddBiasTranspose( cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block, const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const T* input, const T* biases, T* output, T* qkv_add_bias, const int v_head_size, int total_matrix_count) { + assert(num_heads <= max_threads_per_block); const dim3 grid(sequence_length, batch_size, num_matrices); if (qk_head_size * num_heads <= 
max_threads_per_block) { const dim3 block(qk_head_size, num_heads, 1); @@ -544,7 +545,7 @@ void InvokeAddBiasTranspose( AddBiasTranspose<<>>(input, biases, output); } } else { - const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); if (format == 2) { AddBiasTransposeTrtLarge<<>>(qk_head_size, input, biases, output); } else if (format == 1) { @@ -577,7 +578,7 @@ void LaunchAddBiasTranspose( const half* input, const half* biases, half* output, bool enable_half4, const int v_head_size, half* qkv_add_bias, int total_matrix_count) { total_matrix_count = std::max(num_matrices, total_matrix_count); - if (enable_half4 && 0 == (qk_head_size % 4) && 0 == (v_head_size % 4)) { + if (enable_half4 && 0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4))) { const int H = qk_head_size / 4; const int H_v = v_head_size / 4; const Half4* input2 = reinterpret_cast(input); @@ -587,7 +588,7 @@ void LaunchAddBiasTranspose( InvokeAddBiasTranspose(stream, num_matrices, format, max_threads_per_block, batch_size, sequence_length, num_heads, H, input2, biases2, output2, qkv_add_bias2, H_v, total_matrix_count); - } else if (0 == (qk_head_size & 1) && 0 == (v_head_size & 1)) { + } else if (0 == (qk_head_size & 1) && (v_head_size == -1 || 0 == (v_head_size & 1))) { const int H = qk_head_size / 2; const int H_v = v_head_size / 2; const half2* input2 = reinterpret_cast(input); @@ -612,7 +613,7 @@ void LaunchAddBiasTranspose( const float* input, const float* biases, float* output, bool /*enable_half4*/, const int v_head_size, float* qkv_add_bias, int total_matrix_count) { total_matrix_count = std::max(num_matrices, total_matrix_count); - if (0 == (qk_head_size % 4) && 0 == (v_head_size % 4)) { + if (0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4))) { const int H = qk_head_size / 4; const float4* input2 = reinterpret_cast(input); const float4* biases2 = 
reinterpret_cast(biases); @@ -622,7 +623,7 @@ void LaunchAddBiasTranspose( stream, num_matrices, format, max_threads_per_block, batch_size, sequence_length, num_heads, H, input2, biases2, output2, qkv_add_bias2, v_head_size / 4, total_matrix_count); - } else if (0 == (qk_head_size & 1) && 0 == (v_head_size & 1)) { + } else if (0 == (qk_head_size & 1) && (v_head_size == -1 || 0 == (v_head_size & 1))) { const int H = qk_head_size / 2; const float2* input2 = reinterpret_cast(input); const float2* biases2 = reinterpret_cast(biases); @@ -654,7 +655,7 @@ void InvokeAddBiasTransposeTrt( const dim3 block(head_size, num_heads, 1); AddBiasTransposeTrt<<>>(query, key, value, biases, output); } else { - const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); AddBiasTransposeTrtLarge<<>>(head_size, query, key, value, biases, output); } } else { // cross attention @@ -666,7 +667,7 @@ void InvokeAddBiasTransposeTrt( const dim3 block(head_size, num_heads, 1); AddBiasTransposeTrt<<>>(query, biases, output); } else { - const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); AddBiasTransposeTrtLarge<<>>(head_size, query, biases, output); } } @@ -680,7 +681,7 @@ void InvokeAddBiasTransposeTrt( const dim3 block(head_size, num_heads, 1); AddBiasTransposeTrtKV<<>>(key, value, biases, packed_kv); } else { - const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); AddBiasTransposeTrtKVLarge<<>>(head_size, key, value, biases, packed_kv); } } @@ -737,6 +738,7 @@ void InvokeAddBias( const int batch_size, const int sequence_length, const int kv_sequence_length, const int num_heads, const int head_size, const int v_head_size, const T* biases, const T* query, const T* key, const T* value, T* q, T* k, T* v) { + assert(num_heads <= 
max_threads_per_block); constexpr int num_matrices = 1; // Q { @@ -745,7 +747,7 @@ void InvokeAddBias( const dim3 block(head_size, num_heads, 1); AddBiasTransposeTrt<<>>(query, biases, q); } else { - const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); AddBiasTransposeTrtLarge<<>>(head_size, query, biases, q); } } @@ -758,7 +760,7 @@ void InvokeAddBias( const dim3 block(head_size, num_heads, 1); AddBiasTransposeTrt<<>>(key, biases_k, k); } else { - const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); AddBiasTransposeTrtLarge<<>>(head_size, key, biases_k, k); } } @@ -772,7 +774,7 @@ void InvokeAddBias( const dim3 block(v_head_size, num_heads, 1); AddBiasTransposeTrt<<>>(value, biases_v, v); } else { - const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); AddBiasTransposeTrtLarge<<>>(v_head_size, value, biases_v, v); } } diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index ae320279d7..842d5cd943 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -203,6 +203,7 @@ class SymbolicShapeInference: "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization, "GroupNorm": self._infer_GroupNorm, "BiasSplitGelu": self._infer_BiasSplitGelu, + "NhwcConv": self._infer_NhwcConv, } self.aten_op_dispatcher_ = { "embedding": self._infer_Gather, @@ -442,6 +443,7 @@ class SymbolicShapeInference: "MultiHeadAttention", "GroupNorm", "BiasSplitGelu", + "NhwcConv", ] if not skip_infer: @@ -623,13 +625,13 @@ class SymbolicShapeInference: def _new_symbolic_shape(self, rank, node, out_idx=0): return [self._new_symbolic_dim_from_output(node, out_idx, i) for i in range(rank)] - def 
_compute_conv_pool_shape(self, node): + def _compute_conv_pool_shape(self, node, channels_last=False): sympy_shape = self._get_sympy_shape(node, 0) if len(node.input) > 1: W_shape = self._get_sympy_shape(node, 1) rank = len(W_shape) - 2 # number of spatial axes - kernel_shape = W_shape[-rank:] - sympy_shape[1] = W_shape[0] + kernel_shape = W_shape[-rank - 1 : -1] if channels_last else W_shape[-rank:] + sympy_shape[-1 if channels_last else 1] = W_shape[0] else: W_shape = None kernel_shape = get_attribute(node, "kernel_shape") @@ -638,13 +640,17 @@ class SymbolicShapeInference: assert len(sympy_shape) == rank + 2 # only need to symbolic shape inference if input has symbolic dims in spatial axes - is_symbolic_dims = [not is_literal(i) for i in sympy_shape[-rank:]] + spatial_shape = sympy_shape[-rank - 1 : -1] if channels_last else sympy_shape[-rank:] + is_symbolic_dims = [not is_literal(i) for i in spatial_shape] if not any(is_symbolic_dims): shape = get_shape_from_value_info(self.known_vi_[node.output[0]]) if len(shape) > 0: assert len(sympy_shape) == len(shape) - sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]] + if channels_last: + sympy_shape[-rank - 1 : -1] = [sympy.Integer(d) for d in shape[-rank - 1 : -1]] + else: + sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]] return sympy_shape dilations = get_attribute(node, "dilations", [1] * rank) @@ -675,7 +681,7 @@ class SymbolicShapeInference: ceil_mode = get_attribute(node, "ceil_mode", 0) for i in range(rank): - effective_input_size = sympy_shape[-rank + i] + effective_input_size = sympy_shape[-rank + i + (-1 if channels_last else 0)] if len(total_pads) > 0: effective_input_size = effective_input_size + total_pads[i] if ceil_mode: @@ -684,7 +690,7 @@ class SymbolicShapeInference: ) else: strided_kernel_positions = (effective_input_size - effective_kernel_shape[i]) // strides[i] - sympy_shape[-rank + i] = strided_kernel_positions + 1 + sympy_shape[-rank + i + (-1 if channels_last else 
0)] = strided_kernel_positions + 1 return sympy_shape def _check_merged_dims(self, dims, allow_broadcast=True): @@ -918,6 +924,18 @@ class SymbolicShapeInference: ) ) + def _infer_NhwcConv(self, node): + sympy_shape = self._compute_conv_pool_shape(node, channels_last=True) + self._update_computed_dims(sympy_shape) + vi = self.known_vi_[node.output[0]] + vi.CopyFrom( + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(sympy_shape), + ) + ) + def _infer_Einsum(self, node): # ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275 equation = get_attribute(node, "equation") @@ -2459,6 +2477,7 @@ class SymbolicShapeInference: all_shapes_inferred = symbolic_shape_inference._infer_impl() symbolic_shape_inference._update_output_from_vi() if not all_shapes_inferred: + onnx.save_model(symbolic_shape_inference.out_mp_, "sym_shape_infer_temp.onnx", save_as_external_data=True) raise Exception("Incomplete symbolic shape inference") return symbolic_shape_inference.out_mp_ diff --git a/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py b/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py new file mode 100644 index 0000000000..d8ecb65280 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py @@ -0,0 +1,90 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +from logging import getLogger +from typing import List + +from fusion_base import Fusion +from onnx import TensorProto, helper, numpy_helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class FusionNhwcConv(Fusion): + """Convert Conv to NhwcConv""" + + def __init__(self, model: OnnxModel, update_weight=False): + super().__init__(model, "NhwcConv", ["Conv"], "NhwcConv") + self.update_weight = update_weight + + def create_transpose_node(self, input_name: str, perm: List[int], output_name=None): + """Append a Transpose node after an input""" + node_name = self.model.create_node_name("Transpose") + + if output_name is None: + output_name = node_name + "_out" + "-" + input_name + + transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name) + transpose_node.attribute.extend([helper.make_attribute("perm", perm)]) + + return transpose_node + + def fuse(self, conv, input_name_to_nodes, output_name_to_node): + # Add Transpose node to convert input from NCHW to NHWC + input_transpose_node = self.create_transpose_node(conv.input[0], [0, 2, 3, 1]) + + nhwc_conv_input = input_transpose_node.output[0] + + # Create a tensor for transposed weights (already in NHWC format). 
+ node_name = self.model.create_node_name("NhwcConv") + + # Make sure the weight is a 4D initializer + weight_tensor = self.model.get_initializer(conv.input[1]) + if weight_tensor is None: + return + weight = numpy_helper.to_array(weight_tensor) + if len(weight.shape) != 4: + return + + if self.update_weight: + # Transpose weights from NCHW to NHWC + weight = weight.transpose(0, 2, 3, 1) + + weight_name = node_name + "_weight_NHWC" + nhwc_weight = helper.make_tensor( + name=weight_name, + data_type=weight_tensor.data_type, + dims=list(weight.shape), + vals=weight.flatten().tolist(), + ) + self.model.add_initializer(nhwc_weight, self.this_graph_name) + weight_transpose_node = None + else: + weight_transpose_node = self.create_transpose_node(conv.input[1], [0, 2, 3, 1]) + weight_name = weight_transpose_node.output[0] + + nhwc_output_name = node_name + "_out" + "-" + conv.output[0] + nhwc_conv = helper.make_node( + "NhwcConv", + inputs=[nhwc_conv_input, weight_name] + conv.input[2:], + outputs=[nhwc_output_name], + name=node_name + "-" + conv.name, + ) + nhwc_conv.attribute.extend(conv.attribute) + nhwc_conv.domain = "com.microsoft" + + output_transpose_node = self.create_transpose_node(nhwc_conv.output[0], [0, 3, 1, 2], conv.output[0]) + + self.nodes_to_remove.append(conv) + + nodes_to_add = [input_transpose_node, nhwc_conv, output_transpose_node] + if weight_transpose_node: + nodes_to_add.append(weight_transpose_node) + for node in nodes_to_add: + self.node_name_to_graph_name[node.name] = self.this_graph_name + self.nodes_to_add.extend(nodes_to_add) + + self.increase_counter("NhwcConv") diff --git a/onnxruntime/python/tools/transformers/fusion_reshape.py b/onnxruntime/python/tools/transformers/fusion_reshape.py index 75caa255b1..853038f746 100644 --- a/onnxruntime/python/tools/transformers/fusion_reshape.py +++ b/onnxruntime/python/tools/transformers/fusion_reshape.py @@ -119,16 +119,15 @@ class FusionReshape(Fusion): shape_nodes.extend([path2[-1], path3[-1]]) shape.append(-1) elif 
len(concat_node.input) > 2: - concat_2 = self.model.get_initializer(concat_node.input[2]) - if concat_2 is None: + concat_value = self.model.get_constant_value(concat_node.input[2]) + if concat_value is None: return - concat_value = numpy_helper.to_array(concat_2) if isinstance(concat_value, np.ndarray): shape.extend(concat_value.tolist()) else: shape.append(concat_value) - if len(concat_node.input) == 4 and self.model.get_initializer(concat_node.input[3]) is None: + if len(concat_node.input) == 4 and self.model.get_constant_value(concat_node.input[3]) is None: if -1 in shape: return diff --git a/onnxruntime/python/tools/transformers/fusion_transpose.py b/onnxruntime/python/tools/transformers/fusion_transpose.py new file mode 100644 index 0000000000..d92ddd5f8e --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_transpose.py @@ -0,0 +1,81 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +from logging import getLogger +from typing import Dict, List + +from fusion_base import Fusion +from fusion_utils import FusionUtils +from onnx import NodeProto, helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class FusionTranspose(Fusion): + def __init__(self, model: OnnxModel): + super().__init__(model, "Transpose", "Transpose") + + def fuse( + self, + transpose_node: NodeProto, + input_name_to_nodes: Dict[str, List[NodeProto]], + output_name_to_node: Dict[str, NodeProto], + ): + """ + Case 1: + (input)-->Transpose(perm=a)-->Transpose(perm=b)--> + After: + (input)-->Transpose(perm=a)--> (this path can be removed if the output is not used anymore) + | + +----->Transpose(perm=a*b)--> + + Case 2 (Cast has only one child): + (input)-->Transpose(perm=a)--> Cast -->Transpose(perm=b)--> + After: + (input)-->Transpose(perm=a)--> (this path can be removed if the output is not used anymore) + | + +----->Cast --> Transpose(perm=a*b)--> + + + """ + transpose_b = transpose_node + if transpose_b.input[0] not in output_name_to_node: + return + + transpose_a = output_name_to_node[transpose_b.input[0]] + if transpose_a.op_type != "Cast": + cast_node = None + else: + cast_node = transpose_a + + cast_children = self.model.get_children(cast_node, input_name_to_nodes) + if cast_children and len(cast_children) > 1: + return + transpose_a = output_name_to_node.get(cast_node.input[0]) + + if transpose_a is None or transpose_a.op_type != "Transpose": + return + + permutation = OnnxModel.get_node_attribute(transpose_b, "perm") + assert isinstance(permutation, list) + + parent_permutation = OnnxModel.get_node_attribute(transpose_a, "perm") + assert isinstance(parent_permutation, list) + + assert len(parent_permutation) == len(permutation) + + output_permutation = [] + for index in permutation: + output_permutation.append(parent_permutation[index]) + + if cast_node is None: + if 
FusionUtils.skip_parent(self.model, transpose_b, transpose_a, input_name_to_nodes): + self.nodes_to_remove.append(transpose_a) + else: + if FusionUtils.skip_parent(self.model, cast_node, transpose_a, input_name_to_nodes): + self.nodes_to_remove.append(transpose_a) + transpose_b.ClearField("attribute") + transpose_b.attribute.extend([helper.make_attribute("perm", output_permutation)]) diff --git a/onnxruntime/python/tools/transformers/fusion_utils.py b/onnxruntime/python/tools/transformers/fusion_utils.py index 8363f2674c..07fdf49033 100644 --- a/onnxruntime/python/tools/transformers/fusion_utils.py +++ b/onnxruntime/python/tools/transformers/fusion_utils.py @@ -73,6 +73,32 @@ class FusionUtils: self.model.remove_node(node) self.model.replace_input_of_all_nodes(output_name, input_name) + @staticmethod + def skip_parent(model: OnnxModel, node, parent_node, input_name_to_nodes): + """ + Before: + (input)-->parent-->node-->(output) + After: + (input)-->parent--> + | + +----->node-->(output) + + This function returns a flag about whether the parent node can be removed. + Note that this function assumes the node has first input links from parent! + """ + parent_can_be_removed = False + input_name_to_nodes[node.input[0]].remove(node) + # We can remove the first Transpose if its output is not used (linked to graph output or other nodes) anymore. + if len(input_name_to_nodes[node.input[0]]) == 0 and not model.find_graph_output( + node.input[0] + ): # checks main graph output. TODO: deal with subgraph + parent_can_be_removed = True + # self.nodes_to_remove.append(transpose_a) + + input_name_to_nodes[parent_node.input[0]].append(node) + node.input[0] = parent_node.input[0] + return parent_can_be_removed + @staticmethod def check_node_attribute(node, attribute_name: str, expected_value, default_value=None): """Verify that a node has expected value for an attribute. 
@@ -228,7 +254,10 @@ class FusionUtils: graph_output_names = set(self.model.get_graphs_output_names()) for node in nodes_to_remove: if bool(set(node.output) & graph_output_names): - if not bool(set(node.input) & graph_input_names): + if ( + not bool(set(node.input) & graph_input_names) + and len(self.model.input_name_to_nodes()[node.input[0]]) == 1 # parent has only one child + ): self.model.replace_output_of_all_nodes(node.input[0], node.output[0]) else: continue diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index 580c5ef4c3..9a00dc8684 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -62,7 +62,7 @@ def get_ort_pipeline(model_name: str, directory: str, provider: str, disable_saf return pipe -def get_torch_pipeline(model_name: str, disable_channels_last: bool, disable_safety_checker: bool): +def get_torch_pipeline(model_name: str, disable_safety_checker: bool): from diffusers import StableDiffusionPipeline from torch import channels_last, float16 @@ -70,8 +70,7 @@ def get_torch_pipeline(model_name: str, disable_channels_last: bool, disable_saf model_name, torch_dtype=float16, revision="fp16", use_auth_token=True ).to("cuda") - if not disable_channels_last: - pipe.unet.to(memory_format=channels_last) # in-place operation + pipe.unet.to(memory_format=channels_last) # in-place operation if disable_safety_checker: pipe.safety_checker = None @@ -144,7 +143,7 @@ def run_ort(model_name: str, directory: str, provider: str, batch_size: int, dis run_ort_pipeline(pipe, batch_size, image_filename_prefix) -def run_torch(model_name: str, batch_size: int, disable_channels_last: bool, disable_safety_checker: bool): +def run_torch(model_name: str, batch_size: int, disable_safety_checker: bool): import torch torch.backends.cudnn.enabled = True 
@@ -154,13 +153,11 @@ def run_torch(model_name: str, batch_size: int, disable_channels_last: bool, dis torch.set_grad_enabled(False) load_start = time.time() - pipe = get_torch_pipeline(model_name, disable_channels_last, disable_safety_checker) + pipe = get_torch_pipeline(model_name, disable_safety_checker) load_end = time.time() print(f"Model loading took {load_end - load_start} seconds") - image_filename_prefix = get_image_filename_prefix("torch", model_name, batch_size, disable_safety_checker) + ( - "" if disable_channels_last else "_channels_last" - ) + image_filename_prefix = get_image_filename_prefix("torch", model_name, batch_size, disable_safety_checker) with torch.inference_mode(): run_torch_pipeline(pipe, batch_size, image_filename_prefix) @@ -196,15 +193,6 @@ def parse_arguments(): help="Directory of saved onnx pipeline. It could be output directory of optimize_pipeline.py.", ) - parser.add_argument( - "-c", - "--disable_channels_last", - required=False, - action="store_true", - help="Disable channels last for torch. 
It will be ignored for onnxruntime engine", - ) - parser.set_defaults(disable_channels_last=False) - parser.add_argument( "--enable_safety_checker", required=False, @@ -237,7 +225,7 @@ def main(): provider = "CUDAExecutionProvider" # TODO: use ["CUDAExecutionProvider", "CPUExecutionProvider"] in diffuers run_ort(sd_model, args.pipeline, provider, args.batch_size, not args.enable_safety_checker) else: - run_torch(sd_model, args.batch_size, args.disable_channels_last, not args.enable_safety_checker) + run_torch(sd_model, args.batch_size, not args.enable_safety_checker) if __name__ == "__main__": diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index 0979f0d2dd..932be4a19a 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -11,18 +11,15 @@ # huggingface-cli login # wget https://raw.githubusercontent.com/huggingface/diffusers/v0.12.1/scripts/convert_stable_diffusion_checkpoint_to_onnx.py # python convert_stable_diffusion_checkpoint_to_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path $ONNX_ROOT/stable-diffusion-v1-5-fp32 -# python convert_stable_diffusion_checkpoint_to_onnx.py --model_path stabilityai/stable-diffusion-2-1 --output_path $ONNX_ROOT/stable-diffusion-v2-1-fp32 -# Note that this script might not be compatible with older or newer version of diffusers/transformers. It is because fusion script need change accordingly when onnx graph is changed. +# Note that this script might not be compatible with older or newer version of diffusers. 
# Then you can use this script to convert them to float16 like the following: # python optimize_pipeline.py -i $ONNX_ROOT/stable-diffusion-v1-5-fp32 -o $ONNX_ROOT/stable-diffusion-v1-5-fp16 --float16 -# python optimize_pipeline.py -i $ONNX_ROOT/stable-diffusion-v2-1-fp32 -o $ONNX_ROOT/stable-diffusion-v2-1-fp16 --float16 # Or -# pip install -U onnxruntime-gpu >= 1.14 # python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i $ONNX_ROOT/stable-diffusion-v1-5-fp32 -o $ONNX_ROOT/stable-diffusion-v1-5-fp16 --float16 -# python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i $ONNX_ROOT/stable-diffusion-v2-1-fp32 -o $ONNX_ROOT/stable-diffusion-v2-1-fp16 --float16 - -# Note that float16 model is for CUDA Execution Provider. It might not run in CPU Execution Provider. +# +# Note that output model is for CUDA Execution Provider. It might not run in CPU Execution Provider. +# Stable diffusion 2.1 model will get black images using float16 Attention. It is a known issue that we are working on. import argparse import logging @@ -40,7 +37,7 @@ from optimizer import optimize_model # noqa: E402 logger = logging.getLogger(__name__) -def optimize_stable_diffusion_onnx_pipeline( +def optimize_sd_pipeline( source_dir: Path, target_dir: Path, overwrite: bool, use_external_data_format: bool, float16: bool ): """Optimize onnx models used in stable diffusion onnx pipeline and optionally convert to float16. @@ -66,23 +63,18 @@ def optimize_stable_diffusion_onnx_pipeline( raise RuntimeError(message) continue - num_heads = 0 - hidden_size = 0 - # Graph fusion before fp16 conversion, otherwise they cannot be fused later. # Right now, onnxruntime does not save >2GB model so we use script to optimize unet instead. logger.info(f"optimize {onnx_model_path}...") fusion_options = FusionOptions("unet") - # packed kv requires compute capacity >= 7.5 (like T4, A100, RTX 2060~4090. 
See https://developer.nvidia.com/cuda-gpus) - # Suggest to disable it if you are using older GPU like V100, RTX 1060/1070/1080, or using float32 model. fusion_options.enable_packed_kv = float16 m = optimize_model( str(onnx_model_path), model_type="unet", - num_heads=num_heads, - hidden_size=hidden_size, + num_heads=0, # will be deduced from graph + hidden_size=0, # will be deduced from graph opt_level=0, optimization_options=fusion_options, use_gpu=False, @@ -211,7 +203,7 @@ def main(): coloredlogs.install(fmt="%(funcName)20s: %(message)s") args = parse_arguments() copy_extra_directory(Path(args.input), Path(args.output), args.overwrite) - optimize_stable_diffusion_onnx_pipeline( + optimize_sd_pipeline( Path(args.input), Path(args.output), args.overwrite, args.use_external_data_format, args.float16 ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt new file mode 100644 index 0000000000..8b57df8852 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt @@ -0,0 +1,14 @@ +# Install the following package in python 3.10 +diffusers==0.12.1 +transformers==4.26.0 +numpy==1.24.1 +accelerate==0.15.0 +onnxruntime-gpu>=1.14 +onnx==1.13.0 +coloredlogs +packaging==23.0 +protobuf==3.20.3 +psutil==5.9.4 +sympy==1.11.1 +--extra-index-url https://download.pytorch.org/whl/cu117 +torch==1.13.1+cu117 diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 96c22b5894..42fd4d5909 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -128,6 +128,8 @@ class OnnxModel: for graph in self.graphs(): if node in graph.node: graph.node.remove(node) + return + logger.warning("Failed to remove node %s", node) # It might be a bug to hit this line. 
def remove_nodes(self, nodes_to_remove): for node in nodes_to_remove: @@ -182,6 +184,12 @@ class OnnxModel: node.output[j] = new_output_name def replace_output_of_all_nodes(self, old_output_name, new_output_name): + # This function shall be used carefully. For example: + # Add --[old_name]--> Cast ---> [new_name] + # | + # +----[old_name]--> Transpose --> + # If we want to remove the Cast node: replace output of Add to new_name is not enough; + # The input of Transpose shall also be updated to new_name. for node in self.model.graph.node: OnnxModel.replace_node_output(node, old_output_name, new_output_name) @@ -553,7 +561,9 @@ class OnnxModel: graph_output_names = set(self.get_graphs_output_names()) for node in nodes_to_remove: if bool(set(node.output) & graph_output_names): - if not bool(set(node.input) & graph_input_names): + if (not bool(set(node.input) & graph_input_names)) and len( + self.input_name_to_nodes()[node.input[0]] + ) == 1: self.replace_output_of_all_nodes(node.input[0], node.output[0]) else: continue diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py index feba717bd8..32a9814982 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_unet.py +++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py @@ -9,8 +9,11 @@ from typing import Optional from fusion_attention_unet import FusionAttentionUnet from fusion_biassplitgelu import FusionBiasSplitGelu from fusion_group_norm import FusionGroupNorm +from fusion_nhwc_conv import FusionNhwcConv from fusion_options import FusionOptions +from fusion_transpose import FusionTranspose from onnx import ModelProto +from onnx_model import OnnxModel from onnx_model_bert import BertOnnxModel logger = getLogger(__name__) @@ -30,10 +33,61 @@ class UnetOnnxModel(BertOnnxModel): super().__init__(model, num_heads=num_heads, hidden_size=hidden_size) def preprocess(self): - return + self.remove_useless_div() def postprocess(self): + 
self.merge_sequential_transpose() self.prune_graph() + self.remove_unused_constant() + + def remove_useless_div(self): + """Remove Div by 1""" + div_nodes = [node for node in self.nodes() if node.op_type == "Div"] + + nodes_to_remove = [] + for div in div_nodes: + if self.find_constant_input(div, 1.0) == 1: + nodes_to_remove.append(div) + + for node in nodes_to_remove: + self.replace_input_of_all_nodes(node.output[0], node.input[0]) + + if nodes_to_remove: + self.remove_nodes(nodes_to_remove) + logger.info("Removed %d useless Div (by 1) nodes", len(nodes_to_remove)) + + def convert_conv_to_nhwc(self): + # Do not update weight here since save external data has a bug + conv_to_nhwc_conv = FusionNhwcConv(self, update_weight=False) + conv_to_nhwc_conv.apply() + + def merge_sequential_transpose(self): + fusion_transpose = FusionTranspose(self) + fusion_transpose.apply() + + remove_count = 0 + nodes = self.get_nodes_by_op_type("Transpose") + for node in nodes: + permutation = OnnxModel.get_node_attribute(node, "perm") + assert isinstance(permutation, list) + if permutation != list(range(len(permutation))): + continue + assert not ( + self.find_graph_output(node.output[0]) + or self.find_graph_input(node.input[0]) + or self.find_graph_output(node.input[0]) + ) + + # Let all children nodes skip current Transpose node and link to its parent + # Note that we cannot update parent node output since parent node might have more than one children. + self.replace_input_of_all_nodes(node.output[0], node.input[0]) + + self.remove_node(node) + remove_count += 1 + + total = len(fusion_transpose.nodes_to_remove) + remove_count + if total: + logger.info("Removed %d Transpose nodes", total) def optimize(self, options: Optional[FusionOptions] = None): if (options is not None) and not options.enable_shape_inference: @@ -78,7 +132,7 @@ class UnetOnnxModel(BertOnnxModel): # Remove reshape nodes that having same shape of input and output based on symbolic shape inference. 
self.utils.remove_useless_reshape_nodes() - self.postprocess() + self.convert_conv_to_nhwc() if (options is None) or options.enable_bias_skip_layer_norm: # Fuse SkipLayerNormalization and Add Bias before it. @@ -87,6 +141,29 @@ class UnetOnnxModel(BertOnnxModel): if options is not None and options.enable_gelu_approximation: self.gelu_approximation() - self.remove_unused_constant() + self.postprocess() logger.info(f"opset version: {self.get_opset_version()}") + + def get_fused_operator_statistics(self): + """ + Returns node count of fused operators. + """ + op_count = {} + ops = [ + "Attention", + "MultiHeadAttention", + "Gelu", + "FastGelu", + "LayerNormalization", + "SkipLayerNormalization", + "BiasSplitGelu", + "GroupNorm", + "NhwcConv", + ] + for op in ops: + nodes = self.get_nodes_by_op_type(op) + op_count[op] = len(nodes) + + logger.info(f"Optimized operators:{op_count}") + return op_count diff --git a/setup.py b/setup.py index 0c10195dc3..294b975a56 100644 --- a/setup.py +++ b/setup.py @@ -481,9 +481,12 @@ packages = [ "onnxruntime.quantization.operators", "onnxruntime.quantization.CalTableFlatBuffers", "onnxruntime.transformers", + "onnxruntime.transformers.models.bart", + "onnxruntime.transformers.models.bert", "onnxruntime.transformers.models.gpt2", "onnxruntime.transformers.models.longformer", "onnxruntime.transformers.models.t5", + "onnxruntime.transformers.models.stable_diffusion", ] package_data = {"onnxruntime.tools.mobile_helpers": ["*.md", "*.config"]}