UNet fusion and fp16 conversion for stable diffusion (#14248)

Add script to fuse nodes to optimized operators in stable diffusion 1.5
models, and a script to convert fp32 models to fp16 models. Tested with
stable diffusion 1.5.

Note that the optimized model needs onnxruntime-gpu v1.14 (release candidate
will be available soon).

Note: We will update the script to work with the latest diffusers and with
stable diffusion v2 and v2.1 models.
This commit is contained in:
Tianlei Wu 2023-01-21 10:16:44 -08:00 committed by GitHub
parent e57c312f9d
commit a95fcb4345
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 632 additions and 45 deletions

View file

@ -433,6 +433,7 @@ class SymbolicShapeInference:
"LongformerAttention",
"SkipLayerNormalization",
"PythonOp",
"MultiHeadAttention",
]
if not skip_infer:

View file

@ -0,0 +1,294 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Tuple, Union
import numpy as np
from fusion_base import Fusion
from fusion_utils import NumpyHelper
from onnx import NodeProto, TensorProto, helper
from onnx_model import OnnxModel
logger = getLogger(__name__)
class FusionAttentionUnet(Fusion):
    """Fuse the attention subgraph of a UNet into a single attention node.

    Self-attention is fused into an ``Attention`` node with a packed QKV weight initializer.
    Cross-attention is fused into a ``MultiHeadAttention`` node that consumes the outputs of
    the existing Q, K and V MatMul nodes. Both nodes are created in the ``com.microsoft`` domain.
    """

    def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int, is_cross_attention: bool):
        """Initialize the fusion.

        Args:
            model (OnnxModel): model to apply the fusion on
            hidden_size (int): user specified hidden size, or 0 to rely on detection
            num_heads (int): user specified number of attention heads, or 0 to rely on detection
            is_cross_attention (bool): True to fuse cross-attention (MultiHeadAttention),
                False to fuse self-attention (Attention)
        """
        super().__init__(model, "MultiHeadAttention" if is_cross_attention else "Attention", ["LayerNormalization"])
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.is_cross_attention = is_cross_attention

        # Flags to show warning only once
        self.num_heads_warning = True
        self.hidden_size_warning = True

    def get_num_heads_and_hidden_size(self, reshape_q: NodeProto, layernorm_node: NodeProto) -> Tuple[int, int]:
        """Detect num_heads and hidden_size from the graph.

        Args:
            reshape_q (NodeProto): reshape node for Q. Its shape input is expected to be a constant
                like [0, 0, num_heads, head_size].
            layernorm_node (NodeProto): layer normalization node before the attention subgraph.
                The length of its second input (an initializer) is used as hidden size.

        Returns:
            Tuple[int, int]: num_heads and hidden_size. When detection fails, the user specified
                values are returned instead.
        """
        # we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size]
        q_shape_value = self.model.get_constant_value(reshape_q.input[1])
        if q_shape_value is None:
            logger.debug(f"{reshape_q.input[1]} is not constant.")
            return self.num_heads, self.hidden_size  # Fall back to user specified value

        if len(q_shape_value) != 4 or q_shape_value[2] <= 0:
            logger.debug(f"q_shape_value={q_shape_value}. Expected value are like [0, 0, num_heads, -1].")
            return self.num_heads, self.hidden_size  # Fall back to user specified value

        # Convert to a plain int so that it can be stored as a node attribute later.
        num_heads = int(q_shape_value[2])

        # The second input of the layer normalization node is a 1-D initializer; its length is the hidden size.
        layernorm_weight = self.model.get_initializer(layernorm_node.input[1])
        if layernorm_weight is None:
            logger.debug(f"{layernorm_node.input[1]} is not initializer.")
            return self.num_heads, self.hidden_size  # Fall back to user specified value

        hidden_size = int(NumpyHelper.to_array(layernorm_weight).shape[0])

        if self.num_heads > 0 and num_heads != self.num_heads:
            if self.num_heads_warning:
                logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
                self.num_heads_warning = False  # Do not show the warning more than once

        if self.hidden_size > 0 and hidden_size != self.hidden_size:
            if self.hidden_size_warning:
                logger.warning(
                    f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
                )
                self.hidden_size_warning = False  # Do not show the warning more than once

        return num_heads, hidden_size

    def create_attention_node(
        self,
        q_matmul: NodeProto,
        k_matmul: NodeProto,
        v_matmul: NodeProto,
        num_heads: int,
        hidden_size: int,
        input: str,
        output: str,
    ) -> Union[NodeProto, None]:
        """Create an Attention (self-attention) or MultiHeadAttention (cross-attention) node.

        Args:
            q_matmul (NodeProto): MatMul node in fully connection for Q
            k_matmul (NodeProto): MatMul node in fully connection for K
            v_matmul (NodeProto): MatMul node in fully connection for V
            num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
            hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
            input (str): input name
            output (str): output name

        Returns:
            Union[NodeProto, None]: the node created or None if failed.

        Raises:
            ValueError: for self-attention, hidden_size is positive but differs from the input
                dimension of the Q/K/V weights.
        """
        is_self_attention = not self.is_cross_attention

        if is_self_attention:
            # For self-attention, Q, K and V must all be projected from the same hidden state.
            if q_matmul.input[0] != input or k_matmul.input[0] != input or v_matmul.input[0] != input:
                logger.debug(
                    "For self attention, q/k/v MatMul nodes shall share input %s. Got q=%s k=%s v=%s",
                    input,
                    q_matmul.input[0],
                    k_matmul.input[0],
                    v_matmul.input[0],
                )
                return None

        if hidden_size > 0 and (hidden_size % num_heads) != 0:
            logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
            return None

        q_weight = self.model.get_initializer(q_matmul.input[1])
        k_weight = self.model.get_initializer(k_matmul.input[1])
        v_weight = self.model.get_initializer(v_matmul.input[1])
        if q_weight is None or k_weight is None or v_weight is None:
            return None

        # Sometimes weights are stored in fp16
        if q_weight.data_type == TensorProto.FLOAT16:
            logger.debug("weights are in fp16. Please run fp16 conversion after optimization")
            return None

        qw = NumpyHelper.to_array(q_weight)
        kw = NumpyHelper.to_array(k_weight)
        vw = NumpyHelper.to_array(v_weight)
        logger.debug(f"qw={qw.shape} kw={kw.shape} vw={vw.shape} hidden_size={hidden_size}")

        if is_self_attention:
            # Packing requires q, k and v weights to have identical shapes.
            if qw.shape != kw.shape or qw.shape != vw.shape:
                return None

            qw_in_size = qw.shape[0]

            if hidden_size > 0 and hidden_size != qw_in_size:
                raise ValueError(
                    f"Input hidden size ({hidden_size}) is not same as weight dimension of q,k,v ({qw_in_size}). "
                    "Please provide a correct input hidden size or pass in 0"
                )

            # For 2d weights, the shapes would be [in_size, out_size].
            # For 3d weights, shape would be [in_size, a, b] where a*b = out_size
            qw_out_size = int(np.prod(qw.shape[1:]))

            # Stack on axis 1 so that the flattened layout is [in_size, 3 * out_size].
            qkv_weight = np.stack((qw, kw, vw), axis=1)
            qkv_weight_dim = 3 * qw_out_size

            attention_node_name = self.model.create_node_name("Attention")

            weight = helper.make_tensor(
                name=attention_node_name + "_qkv_weight",
                data_type=TensorProto.FLOAT,
                dims=[qw_in_size, qkv_weight_dim],
                vals=qkv_weight.flatten().tolist(),
            )
            self.model.add_initializer(weight, self.this_graph_name)
        else:
            attention_node_name = self.model.create_node_name("MultiHeadAttention")

        # No bias, use zeros
        qkv_bias = np.zeros([3, hidden_size], dtype=np.float32)
        qkv_bias_dim = 3 * hidden_size

        bias = helper.make_tensor(
            name=attention_node_name + "_qkv_bias",
            data_type=TensorProto.FLOAT,
            dims=[qkv_bias_dim],
            vals=qkv_bias.flatten().tolist(),
        )
        self.model.add_initializer(bias, self.this_graph_name)

        if is_self_attention:
            attention_inputs = [
                input,
                attention_node_name + "_qkv_weight",
                attention_node_name + "_qkv_bias",
            ]
        else:
            # Cross-attention keeps the original MatMul nodes and feeds their outputs to the fused node.
            attention_inputs = [
                q_matmul.output[0],
                k_matmul.output[0],
                v_matmul.output[0],
                attention_node_name + "_qkv_bias",
            ]

        attention_node = helper.make_node(
            "Attention" if is_self_attention else "MultiHeadAttention",
            inputs=attention_inputs,
            outputs=[output],
            name=attention_node_name,
        )
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

        return attention_node

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        """Match the attention subgraph that follows a layer normalization node and fuse it.

        Args:
            normalize_node (NodeProto): layer normalization node whose output feeds the attention subgraph
            input_name_to_nodes (dict): mapping from a tensor name to the nodes that consume it
            output_name_to_node (dict): mapping from a tensor name to the node that produces it (not used here)
        """
        node_before_layernorm = self.model.match_parent(
            normalize_node, "Add" if self.is_cross_attention else "Reshape", 0
        )
        if node_before_layernorm is None:
            return

        root_input = node_before_layernorm.output[0]

        # The residual Add consumes both the layer normalization input and the attention output.
        skip_add = None
        for node in input_name_to_nodes[root_input]:
            if node.op_type == "Add":
                skip_add = node
                break
        if skip_add is None:
            return

        another_input = 1 if skip_add.input[0] == root_input else 0
        qkv_nodes = self.model.match_parent_path(
            skip_add,
            ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
            [another_input, None, None, 0, 0, 0],
        )
        if qkv_nodes is None:
            return

        (_, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes

        # No bias. For cross-attention, the input of the MatMul is encoder_hidden_states graph input.
        v_nodes = self.model.match_parent_path(matmul_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
        if v_nodes is None:
            logger.debug("fuse_attention: failed to match v path")
            return
        matmul_v = v_nodes[-1]

        # The softmax input is either Mul(MatMul) or Add(Mul(MatMul)); only the last MatMul is needed.
        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
        if qk_nodes is None:
            qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
            if qk_nodes is None:
                logger.debug("fuse_attention: failed to match qk path")
                return
        matmul_qk = qk_nodes[-1]

        q_nodes = self.model.match_parent_path(matmul_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0])
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match q path")
            return
        (_, _, reshape_q, matmul_q) = q_nodes

        k_nodes = self.model.match_parent_path(
            matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0, 0]
        )
        if k_nodes is None:
            logger.debug("fuse_attention: failed to match k path")
            return
        matmul_k = k_nodes[-1]

        attention_last_node = reshape_qkv

        q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node)
        if q_num_heads <= 0:
            logger.debug("fuse_attention: failed to detect num_heads")
            return

        # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads
        new_node = self.create_attention_node(
            matmul_q,
            matmul_k,
            matmul_v,
            q_num_heads,
            q_hidden_size,
            input=normalize_node.output[0],
            output=attention_last_node.output[0],
        )
        if new_node is None:
            return

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        self.nodes_to_remove.extend([attention_last_node, transpose_qkv])

        # Use prune graph to remove nodes since they are shared by all attention nodes.
        self.prune_graph = True

View file

@ -0,0 +1,4 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

View file

@ -0,0 +1,184 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
#
# This script converts stable diffusion onnx models from float to half (mixed) precision for GPU inference.
#
# Before running this script, you need to convert the checkpoint to float32 onnx models as follows:
# git clone https://github.com/huggingface/diffusers
# cd diffusers
# pip install -e .
# huggingface-cli login
# python3 scripts/convert_stable_diffusion_checkpoint_to_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path ../stable-diffusion-v1-5
#
# Then you can use this script to convert them to float16 like the following:
# pip3 install -U "onnxruntime-gpu>=1.14"
# python3 -m onnxruntime.transformers.models.diffusion.convert_to_fp16 -i ../stable-diffusion-v1-5 -o ../stable-diffusion-v1-5-fp16
# Note that float16 model is intended for CUDA Execution Provider. It might not run in CPU Execution Provider.
import argparse
import logging
import os
import shutil
import sys
from pathlib import Path
import coloredlogs
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from optimizer import optimize_model # noqa: E402
logger = logging.getLogger(__name__)
def convert_to_fp16(source_dir: Path, target_dir: Path, overwrite: bool, use_external_data_format: bool):
    """Fuse and convert every onnx model of a stable diffusion pipeline to float16.

    Each model is read from ``<source_dir>/<component>/model.onnx`` and the result is written to
    ``<target_dir>/<component>/model.onnx``. The vae_decoder is fused but kept in float32.

    Args:
        source_dir (Path): source directory
        target_dir (Path): target directory
        overwrite (bool): overwrite if exists
        use_external_data_format (bool): save model to two files: one for onnx graph, another for weights

    Raises:
        RuntimeError: input onnx model does not exist
        RuntimeError: output onnx model path existed
    """
    for component in ("vae_encoder", "vae_decoder", "text_encoder", "safety_checker", "unet"):
        fp32_model_path = source_dir / component / "model.onnx"
        if not os.path.exists(fp32_model_path):
            raise RuntimeError(f"input onnx model does not exist: {fp32_model_path}")

        # Graph fusion before fp16 conversion, otherwise they cannot be fused later.
        # Right now, onnxruntime does not save >2GB model so we use script to optimize unet instead.
        # num_heads and hidden_size are 0 so that they are detected from the graph.
        fused_model = optimize_model(
            str(fp32_model_path),
            model_type="unet",
            num_heads=0,
            hidden_size=0,
            opt_level=0,
            optimization_options=None,
            use_gpu=False,
        )

        # VAE-decoder in fp16 reduced quality thus we exclude it here
        if component == "vae_decoder":
            print("skip convert vae_decoder to fp16.")
        else:
            fused_model.convert_float_to_float16(op_block_list=["RandomNormalLike", "Resize"])

        fp16_model_path = target_dir / component / "model.onnx"
        if fp16_model_path.exists() and not overwrite:
            raise RuntimeError(f"output onnx model path existed: {fp16_model_path}")

        # Start from an empty component directory so that no stale files are left behind.
        component_dir = fp16_model_path.parent
        if component_dir.exists():
            shutil.rmtree(component_dir)
        component_dir.mkdir(parents=True, exist_ok=True)

        fused_model.save_model_to_file(str(fp16_model_path), use_external_data_format=use_external_data_format)
        print(f"{fp32_model_path} => {fp16_model_path}")
def copy_extra(source_dir: Path, target_dir: Path, overwrite: bool):
    """Copy the pipeline folders and files that do not hold onnx models.

    Args:
        source_dir (Path): source directory
        target_dir (Path): target directory
        overwrite (bool): overwrite if exists

    Raises:
        RuntimeError: source path does not exist
        RuntimeError: output path exists but overwrite is false.
    """

    def _resolve(item: str):
        # Validate one entry and report whether an old copy must be removed first.
        src = source_dir / item
        if not os.path.exists(src):
            raise RuntimeError(f"source path does not exist: {src}")

        dst = target_dir / item
        stale = dst.exists()
        if stale and not overwrite:
            raise RuntimeError(f"output path existed: {dst}")
        return src, dst, stale

    for folder in ("scheduler", "tokenizer", "feature_extractor"):
        src, dst, stale = _resolve(folder)
        if stale:
            shutil.rmtree(dst)
        shutil.copytree(src, dst)
        print(f"{src} => {dst}")

    for filename in ("model_index.json",):
        src, dst, stale = _resolve(filename)
        if stale:
            os.remove(dst)
        shutil.copyfile(src, dst)
        print(f"{src} => {dst}")
def parse_arguments():
    """Parse the command line of the conversion script.

    Returns:
        Namespace: arguments with attributes input, output, overwrite and use_external_data_format
    """
    parser = argparse.ArgumentParser()

    # Required locations of the float32 pipeline (input) and the float16 pipeline (output).
    parser.add_argument(
        "-i",
        "--input",
        type=str,
        required=True,
        help="Root of input directory of stable diffusion onnx pipeline with float32 models.",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        required=True,
        help="Root of output directory of stable diffusion onnx pipeline with float16 models.",
    )

    # Optional switches; both are off unless given on the command line.
    parser.add_argument(
        "--overwrite",
        action="store_true",
        default=False,
        help="Overwrite exists files.",
    )
    parser.add_argument(
        "-e",
        "--use_external_data_format",
        action="store_true",
        default=False,
        help="Onnx model larger than 2GB need to use external data format.",
    )

    return parser.parse_args()
def main():
    """Run the conversion from the command line.

    The folders and files without onnx models are copied first, then every onnx model
    of the pipeline is fused and converted to float16.
    """
    coloredlogs.install(fmt="%(funcName)20s: %(message)s")
    args = parse_arguments()
    copy_extra(Path(args.input), Path(args.output), args.overwrite)
    convert_to_fp16(Path(args.input), Path(args.output), args.overwrite, args.use_external_data_format)


# Only run when executed as a script (python -m ... or python file.py), so that importing
# this module does not parse the command line or start a conversion.
if __name__ == "__main__":
    main()

View file

@ -39,7 +39,7 @@ class OnnxModel:
try:
if self.shape_infer_helper.infer(dynamic_axis_mapping):
return self.shape_infer_helper
except:
except: # noqa
self.enable_shape_infer = False # disable shape inference to suppress same error message.
print("failed in shape inference", sys.exc_info()[0])
@ -267,7 +267,8 @@ class OnnxModel:
):
"""
Find parent node based on constraints on op_type and index.
When input_index is None, we will find the first parent node based on constraints, and return_indice will be appended the corresponding input index.
When input_index is None, we will find the first parent node based on constraints,
and return_indice will be appended the corresponding input index.
Args:
node (str): current node name.
@ -324,14 +325,16 @@ class OnnxModel:
):
"""
Find a sequence of input edges based on constraints on parent op_type and index.
When input_index is None, we will find the first parent node based on constraints, and return_indice will be appended the corresponding input index.
When input_index is None, we will find the first parent node based on constraints,
and return_indice will be appended the corresponding input index.
Args:
node (str): current node name.
parent_op_types (str): constraint of parent node op_type of each input edge.
parent_input_index (list): constraint of input index of each input edge. None means no constraint.
output_name_to_node (dict): dictionary with output name as key, and node as value.
return_indice (list): a list to append the input index when there is no constraint on input index of an edge.
return_indice (list): a list to append the input index
When there is no constraint on input index of an edge.
Returns:
parents: a list of matched parent node.
@ -526,7 +529,7 @@ class OnnxModel:
"""Remove cast nodes that are not needed: input and output has same data type."""
shape_infer = self.infer_runtime_shape(update=True)
if shape_infer is None:
logger.info(f"Skip removing useless cast nodes since shape inference failed.")
logger.info("Skip removing useless cast nodes since shape inference failed.")
return
def get_data_type(input_or_output_name):
@ -568,19 +571,26 @@ class OnnxModel:
def convert_float_to_float16(self, use_symbolic_shape_infer=True, **kwargs):
"""Convert a model to half (default) or mixed precision.
To use mixed precision, user need specify which graph inputs, outputs, operator type or list of nodes shall keep in float32.
By default, we use symbolic shape inference to get shape and type information. If not, ONNX shape inference will be used.
Note that symbolic/ONNX shape inference might fail, and the conversion might not proceed without shape and type information.
To use mixed precision, user need specify which graph inputs, outputs, operator type
or list of nodes shall keep in float32.
By default, we use symbolic shape inference to get shape and type information.
If not, ONNX shape inference will be used.
Note that symbolic/ONNX shape inference might fail, and the conversion might not proceed
without shape and type information.
Args:
use_symbolic_shape_infer (bool, optional): use symbolic shape inference instead of onnx shape inference. Defaults to True.
keep_io_types (Union[bool, List[str]], optional): It could be boolean or a list of float32 input/output names.
If True, model inputs/outputs should be left as float32. Defaults to False.
use_symbolic_shape_infer (bool, optional): use symbolic shape inference instead of onnx shape inference.
Defaults to True.
keep_io_types (Union[bool, List[str]], optional): boolean or a list of float32 input/output names.
If True, model inputs/outputs should be left as float32.
Defaults to False.
op_block_list (List[str], optional): List of operator types to leave as float32.
Defaults to None, which will use `float16.DEFAULT_OP_BLOCK_LIST` as default.
Defaults to None, which will use `float16.DEFAULT_OP_BLOCK_LIST`.
node_block_list (List[str], optional): List of node names to leave as float32. Defaults to None.
force_fp16_initializers(bool): force converting all float initializers to float16.
Default to false, which will convert only the one needed to avoid precision loss.
Default to false.
min_positive_val (float, optional): minimal positive value. Defaults to 1e-7.
max_finite_val (float, optional): maximal finite value. Defaults to 1e4.
"""
@ -589,7 +599,8 @@ class OnnxModel:
model = self.model
if use_symbolic_shape_infer:
# Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc) are not recognized by onnx shape inference.
# Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc)
# are not recognized by onnx shape inference.
shape_infer_helper = SymbolicShapeInferenceHelper(model)
model = shape_infer_helper.infer_shapes(model, auto_merge=True, guess_output_rank=False)
@ -636,7 +647,8 @@ class OnnxModel:
if prefix in self._node_name_suffix:
suffix = self._node_name_suffix[prefix] + 1
else:
# Check existed node name only once for a prefix as we assume create_node_name is called for every new node in fusion.
# Check existed node name only once for a prefix
# as we assume create_node_name is called for every new node in fusion.
for node in self.nodes():
if node.name and node.name.startswith(prefix):
try:
@ -734,7 +746,7 @@ class OnnxModel:
outputs (list): a list of graph outputs to retain. If it is None, all graph outputs will be kept.
"""
if len(self.graphs()) > 1:
logger.debug(f"Skip prune_graph since graph has subgraph")
logger.debug("Skip prune_graph since graph has subgraph")
return
if outputs is None:
@ -839,7 +851,9 @@ class OnnxModel:
for impacted_node in input_name_to_nodes[output_to_remove]:
if impacted_node not in nodes_to_remove:
logger.debug(
f"it is not safe to remove nodes since output {output_to_remove} is used by {impacted_node}"
"it is not safe to remove nodes since output %s is used by %s",
output_to_remove,
impacted_node,
)
return False
return True
@ -960,14 +974,10 @@ class OnnxModel:
save_model(model, output_path)
def save_model_to_file(self, output_path, use_external_data_format=False, all_tensors_to_one_file=True):
logger.info(f"Sort graphs in topological order")
logger.info("Sort graphs in topological order")
self.topological_sort()
if output_path.endswith(".json"): # Output text for testing small model.
with open(output_path, "w") as out:
out.write(str(model))
else:
OnnxModel.save(self.model, output_path, use_external_data_format, all_tensors_to_one_file)
OnnxModel.save(self.model, output_path, use_external_data_format, all_tensors_to_one_file)
logger.info(f"Model saved to {output_path}")
def get_graph_inputs_excluding_initializers(self):

View file

@ -0,0 +1,81 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Optional
from fusion_attention_unet import FusionAttentionUnet
from fusion_options import FusionOptions
from onnx import ModelProto
from onnx_model_bert import BertOnnxModel
logger = getLogger(__name__)
class UnetOnnxModel(BertOnnxModel):
    """ONNX model wrapper that runs the graph fusions used for UNet models."""

    def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
        """Initialize UNet ONNX Model.

        Args:
            model (ModelProto): the ONNX model
            num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically).
            hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically).
        """
        # Either both values are left for detection, or hidden_size must be divisible by a positive num_heads.
        auto_detect = num_heads == 0 and hidden_size == 0
        assert auto_detect or (num_heads > 0 and hidden_size % num_heads == 0)

        super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)

    def preprocess(self):
        """Do nothing before the fusions."""
        return

    def postprocess(self):
        """Prune the graph after the fusions."""
        self.prune_graph()

    def optimize(self, options: Optional[FusionOptions] = None):
        """Apply the fusions in a fixed order.

        Args:
            options (FusionOptions, optional): switches for individual fusions. When it is None,
                every fusion is enabled except gelu approximation.
        """
        use_defaults = options is None

        if not use_defaults and not options.enable_shape_inference:
            self.disable_shape_inference()

        self.utils.remove_identity_nodes()

        # Remove cast nodes that having same data type of input and output based on symbolic shape inference.
        self.utils.remove_useless_cast_nodes()

        if use_defaults or options.enable_layer_norm:
            self.fuse_layer_norm()

        if use_defaults or options.enable_gelu:
            self.fuse_gelu()

        self.preprocess()

        self.fuse_reshape()

        if use_defaults or options.enable_attention:
            # Self-attention is fused first, then cross-attention.
            for is_cross_attention in (False, True):
                attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, is_cross_attention)
                attention_fusion.apply()

        if use_defaults or options.enable_skip_layer_norm:
            self.fuse_skip_layer_norm()

        self.fuse_shape()

        # Remove reshape nodes that having same shape of input and output based on symbolic shape inference.
        self.utils.remove_useless_reshape_nodes()

        self.postprocess()

        if use_defaults or options.enable_bias_skip_layer_norm:
            # Fuse SkipLayerNormalization and Add Bias before it.
            self.fuse_add_bias_skip_layer_norm()

        if not use_defaults and options.enable_gelu_approximation:
            self.gelu_approximation()

        self.remove_unused_constant()

        logger.info(f"opset version: {self.get_opset_version()}")

View file

@ -31,6 +31,7 @@ from onnx_model_bert_keras import BertOnnxModelKeras
from onnx_model_bert_tf import BertOnnxModelTF
from onnx_model_gpt2 import Gpt2OnnxModel
from onnx_model_tnlr import TnlrOnnxModel
from onnx_model_unet import UnetOnnxModel
logger = logging.getLogger(__name__)
@ -47,6 +48,7 @@ MODEL_TYPES = {
0,
), # might add a class for GPT2OnnxModel for TF later.
"tnlr": (TnlrOnnxModel, "pytorch", 1),
"unet": (UnetOnnxModel, "pytorch", 1),
}
@ -139,16 +141,17 @@ def optimize_by_fusion(
model (ModelProto): model object
model_type (str, optional): model type - like bert, bert_tf, bert_keras or gpt2. Defaults to 'bert'.
num_heads (int, optional): number of attention heads. Defaults to 0.
0 allows detect the parameter from graph automatically (for model_type "bert" only).
0 allows detect the parameter from graph automatically.
hidden_size (int, optional): hidden size. Defaults to 0.
0 allows detect the parameter from graph automatically (for model_type "bert" only).
optimization_options (FusionOptions, optional): optimization options that turn on/off some fusions. Defaults to None.
0 allows detect the parameter from graph automatically.
optimization_options (FusionOptions, optional): optimization options that turn on/off some fusions.
Defaults to None.
Returns:
object of an optimizer class.
"""
if model_type != "bert" and (num_heads == 0 or hidden_size == 0):
logger.warning("Please specify parameters of num_heads and hidden_size when model_type is not 'bert'")
if model_type not in ["bert", "unet"] and (num_heads == 0 or hidden_size == 0):
logger.warning(f"Please specify parameters of num_heads and hidden_size for model_type {model_type}")
(optimizer_class, producer, _) = MODEL_TYPES[model_type]
@ -198,7 +201,9 @@ def optimize_model(
When opt_level is 0 and only_onnxruntime is False, only python fusion logic is used and onnxruntime is disabled.
When opt_level > 1, use_gpu shall set properly since the optimized graph might contain operators for GPU or CPU only.
When opt_level > 1, use_gpu shall set properly
since the optimized graph might contain operators for GPU or CPU only.
If your model is intended for GPU inference only (especially float16 or mixed precision model), it is recommended to
set use_gpu to be True, otherwise the model is not optimized for GPU inference.
@ -208,24 +213,23 @@ def optimize_model(
input (str): input model path.
model_type (str, optional): model type - like bert, bert_tf, bert_keras or gpt2. Defaults to 'bert'.
num_heads (int, optional): number of attention heads. Defaults to 0.
0 allows detect the parameter from graph automatically (for model_type "bert" only).
0 allows detect the parameter from graph automatically.
hidden_size (int, optional): hidden size. Defaults to 0.
0 allows detect the parameter from graph automatically (for model_type "bert" only).
optimization_options (FusionOptions, optional): optimization options that turn on/off some fusions. Defaults to None.
0 allows detect the parameter from graph automatically.
optimization_options (FusionOptions, optional): optimization options that turn on/off some fusions.
Defaults to None.
opt_level (int, optional): onnxruntime graph optimization level (0, 1, 2 or 99) or None. Defaults to None.
When the value is None, default value (1 for bert and gpt2, 0 for other model types) will be used.
When the level > 0, onnxruntime will be used to optimize model first.
When the value is None, default value (1 for bert and gpt2, 0 for other model types) will be used.
When the level > 0, onnxruntime will be used to optimize model first.
use_gpu (bool, optional): use gpu or not for onnxruntime. Defaults to False.
only_onnxruntime (bool, optional): only use onnxruntime to optimize model, and no python fusion. Defaults to False.
only_onnxruntime (bool, optional): only use onnxruntime to optimize model, and no python fusion.
Defaults to False.
Returns:
object of an optimizer class.
"""
assert opt_level is None or opt_level in [0, 1, 2, 99]
if model_type != "bert" and (num_heads == 0 or hidden_size == 0):
logger.warning("Please specify parameters of num_heads and hidden_size when model_type is not 'bert'")
(optimizer_class, _producer, default_opt_level) = MODEL_TYPES[model_type]
if opt_level is None:
@ -300,7 +304,8 @@ def get_fusion_statistics(optimized_model_path: str) -> Dict[str, int]:
def _parse_arguments():
parser = argparse.ArgumentParser(
description="Graph optimization tool for ONNX Runtime. It transforms ONNX graph to use optimized operators for Transformer models."
description="Graph optimization tool for ONNX Runtime."
"It transforms ONNX graph to use optimized operators for Transformer models."
)
parser.add_argument("--input", required=True, type=str, help="input onnx model path")
@ -320,7 +325,9 @@ def _parse_arguments():
required=False,
type=int,
default=0,
help="number of attention heads like 12 for bert-base and 16 for bert-large. Default is 0 to detect automatically for BERT. For other model type, this parameter need specify correctly.",
help="number of attention heads like 12 for bert-base and 16 for bert-large. "
"Default is 0 to detect automatically for BERT."
"For other model type, this parameter need specify correctly.",
)
parser.add_argument(
@ -328,14 +335,17 @@ def _parse_arguments():
required=False,
type=int,
default=0,
help="hidden size like 768 for bert-base and 1024 for bert-large. Default is 0 to detect automatically for BERT. For other model type, this parameter need specify correctly.",
help="hidden size like 768 for bert-base and 1024 for bert-large. "
"Default is 0 to detect automatically for BERT. "
"For other model type, this parameter need specify correctly.",
)
parser.add_argument(
"--input_int32",
required=False,
action="store_true",
help="Use int32 (instead of int64) inputs. It could avoid unnecessary data cast when EmbedLayerNormalization is fused for BERT.",
help="Use int32 (instead of int64) inputs. "
"It could avoid unnecessary data cast when EmbedLayerNormalization is fused for BERT.",
)
parser.set_defaults(input_int32=False)
@ -343,7 +353,8 @@ def _parse_arguments():
"--float16",
required=False,
action="store_true",
help="Convert all weights and nodes in float32 to float16. It has potential loss in precision compared to mixed precision conversion (see convert_float_to_float16).",
help="Convert all weights and nodes in float32 to float16. "
"It has potential loss in precision compared to mixed precision conversion.",
)
parser.set_defaults(float16=False)
@ -374,7 +385,9 @@ def _parse_arguments():
type=int,
choices=[0, 1, 2, 99],
default=None,
help="onnxruntime optimization level. 0 will disable onnxruntime graph optimization. The recommended value is 1. When opt_level > 1 is used, optimized model for GPU might not run in CPU. Level 2 and 99 are intended for --only_onnxruntime.",
help="onnxruntime optimization level. 0 will disable onnxruntime graph optimization. "
"The recommended value is 1. When opt_level > 1 is used, optimized model for GPU might not run in CPU. "
"Level 2 and 99 are intended for --only_onnxruntime.",
)
parser.add_argument(
@ -408,7 +421,7 @@ def main():
logger.debug(f"arguments:{args}")
if os.path.realpath(args.input) == os.path.realpath(args.output):
logger.warning(f"Specified the same input and output path. Note that this may overwrite the original model")
logger.warning("Specified the same input and output path. Note that this may overwrite the original model")
optimization_options = FusionOptions.parse(args)