mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
UNet fusion and fp16 conversion for stable diffusion (#14248)
Add script to fuse nodes to optimized operators in stable diffusion 1.5 models, and a script to convert fp32 models to fp16 models. Tested with stable diffusion 1.5. Note that the optimized model needs onnxruntime-gpu v1.14 (release candidate will be available soon). Note: We will update the script to work with latest diffusers and stable diffusion v2 and v2.1 models.
This commit is contained in:
parent
e57c312f9d
commit
a95fcb4345
7 changed files with 632 additions and 45 deletions
|
|
@ -433,6 +433,7 @@ class SymbolicShapeInference:
|
|||
"LongformerAttention",
|
||||
"SkipLayerNormalization",
|
||||
"PythonOp",
|
||||
"MultiHeadAttention",
|
||||
]
|
||||
|
||||
if not skip_infer:
|
||||
|
|
|
|||
294
onnxruntime/python/tools/transformers/fusion_attention_unet.py
Normal file
294
onnxruntime/python/tools/transformers/fusion_attention_unet.py
Normal file
|
|
@ -0,0 +1,294 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
from logging import getLogger
|
||||
from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from fusion_base import Fusion
|
||||
from fusion_utils import NumpyHelper
|
||||
from onnx import NodeProto, TensorProto, helper
|
||||
from onnx_model import OnnxModel
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class FusionAttentionUnet(Fusion):
    """
    Fuse Attention subgraph of UNet into one Attention node.

    Self-attention subgraphs are replaced by a single ``Attention`` node with packed
    Q/K/V weights; cross-attention subgraphs are replaced by a ``MultiHeadAttention``
    node that consumes the outputs of the existing Q/K/V MatMul nodes.
    """

    def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int, is_cross_attention: bool):
        """Initialize the fusion.

        Args:
            model (OnnxModel): model to be optimized
            hidden_size (int): user specified hidden size, used when it cannot be detected from the graph
            num_heads (int): user specified number of heads, used when it cannot be detected from the graph
            is_cross_attention (bool): fuse cross attention (MultiHeadAttention) instead of self attention
        """
        super().__init__(model, "MultiHeadAttention" if is_cross_attention else "Attention", ["LayerNormalization"])
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.is_cross_attention = is_cross_attention

        # Flags to show warning only once
        self.num_heads_warning = True
        self.hidden_size_warning = True

    def get_num_heads_and_hidden_size(self, reshape_q: NodeProto, layernorm_node: NodeProto) -> Tuple[int, int]:
        """Detect num_heads and hidden_size from a reshape node and a layer normalization node.

        Args:
            reshape_q (NodeProto): reshape node for Q
            layernorm_node (NodeProto): LayerNormalization node whose second input gives the hidden size

        Returns:
            Tuple[int, int]: num_heads and hidden_size. The user specified values are returned
                when detection is not possible.
        """

        # we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size]
        q_shape_value = self.model.get_constant_value(reshape_q.input[1])
        if q_shape_value is None:
            logger.debug(f"{reshape_q.input[1]} is not constant.")
            return self.num_heads, self.hidden_size  # Fall back to user specified value

        if len(q_shape_value) != 4 or q_shape_value[2] <= 0:
            logger.debug(f"q_shape_value={q_shape_value}. Expected value are like [0, 0, num_heads, -1].")
            return self.num_heads, self.hidden_size  # Fall back to user specified value

        # Convert to a plain int so the value matches the annotated return type.
        num_heads = int(q_shape_value[2])

        # input[1] of LayerNormalization is a 1-D initializer whose length is the hidden size
        # (presumably the scale tensor per the ONNX operator spec; only its length is used here).
        layernorm_scale = self.model.get_initializer(layernorm_node.input[1])
        if layernorm_scale is None:
            logger.debug(f"{layernorm_node.input[1]} is not initializer.")
            return self.num_heads, self.hidden_size  # Fall back to user specified value

        hidden_size = int(NumpyHelper.to_array(layernorm_scale).shape[0])

        if self.num_heads > 0 and num_heads != self.num_heads:
            if self.num_heads_warning:
                logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
                self.num_heads_warning = False  # Do not show the warning more than once

        if self.hidden_size > 0 and hidden_size != self.hidden_size:
            if self.hidden_size_warning:
                logger.warning(
                    f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
                )
                self.hidden_size_warning = False  # Do not show the warning more than once

        return num_heads, hidden_size

    def create_attention_node(
        self,
        q_matmul: NodeProto,
        k_matmul: NodeProto,
        v_matmul: NodeProto,
        num_heads: int,
        hidden_size: int,
        input: str,
        output: str,
    ) -> Union[NodeProto, None]:
        """Create an Attention (self attention) or MultiHeadAttention (cross attention) node.

        Args:
            q_matmul (NodeProto): MatMul node in fully connection for Q
            k_matmul (NodeProto): MatMul node in fully connection for K
            v_matmul (NodeProto): MatMul node in fully connection for V
            num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
            hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
            input (str): input name
            output (str): output name

        Returns:
            Union[NodeProto, None]: the node created or None if failed.

        Raises:
            ValueError: for self attention, hidden_size differs from the input dimension of the Q/K/V weights.
        """
        is_self_attention = not self.is_cross_attention

        if is_self_attention:
            # For self attention, Q, K and V projections must all read the same hidden state.
            if q_matmul.input[0] != input or k_matmul.input[0] != input or v_matmul.input[0] != input:
                logger.debug("q_matmul.input[0] != input or k_matmul.input[0] != input or v_matmul.input[0] != input")
                return None

        # The zero bias created below is sized from hidden_size, so an unknown hidden size
        # would produce an empty bias tensor. Skip fusion instead.
        if hidden_size <= 0:
            logger.debug(f"hidden size {hidden_size} is not positive; skip attention fusion")
            return None

        if (hidden_size % num_heads) != 0:
            logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
            return None

        q_weight = self.model.get_initializer(q_matmul.input[1])
        k_weight = self.model.get_initializer(k_matmul.input[1])
        v_weight = self.model.get_initializer(v_matmul.input[1])
        if q_weight is None or k_weight is None or v_weight is None:
            return None

        # Sometimes weights are stored in fp16
        if q_weight.data_type == TensorProto.FLOAT16:
            logger.debug("weights are in fp16. Please run fp16 conversion after optimization")
            return None

        qw = NumpyHelper.to_array(q_weight)
        kw = NumpyHelper.to_array(k_weight)
        vw = NumpyHelper.to_array(v_weight)
        logger.debug(f"qw={qw.shape} kw={kw.shape} vw={vw.shape} hidden_size={hidden_size}")

        if is_self_attention:
            # Q, K and V weights must have identical shapes to be packed into one tensor.
            if qw.shape != kw.shape or qw.shape != vw.shape:
                return None

            # Shapes are identical, so the input dimension is shared by Q, K and V.
            qw_in_size = qw.shape[0]

            if hidden_size != qw_in_size:
                raise ValueError(
                    f"Input hidden size ({hidden_size}) is not same as weight dimension of q,k,v ({qw_in_size}). "
                    "Please provide a correct input hidden size or pass in 0"
                )

            # For 2d weights, the shapes would be [in_size, out_size].
            # For 3d weights, shape would be [in_size, a, b] where a*b = out_size
            qw_out_size = np.prod(qw.shape[1:])

            # Stack on axis 1 so that flattening gives, per input row, Q then K then V columns.
            qkv_weight = np.stack((qw, kw, vw), axis=1)
            qkv_weight_dim = 3 * qw_out_size

            attention_node_name = self.model.create_node_name("Attention")

            weight = helper.make_tensor(
                name=attention_node_name + "_qkv_weight",
                data_type=TensorProto.FLOAT,
                dims=[qw_in_size, qkv_weight_dim],
                vals=qkv_weight.flatten().tolist(),
            )

            self.model.add_initializer(weight, self.this_graph_name)
        else:
            attention_node_name = self.model.create_node_name("MultiHeadAttention")

        # No bias, use zeros
        qkv_bias = np.zeros([3, hidden_size], dtype=np.float32)
        qkv_bias_dim = 3 * hidden_size

        bias = helper.make_tensor(
            name=attention_node_name + "_qkv_bias",
            data_type=TensorProto.FLOAT,
            dims=[qkv_bias_dim],
            vals=qkv_bias.flatten().tolist(),
        )
        self.model.add_initializer(bias, self.this_graph_name)

        if is_self_attention:
            attention_inputs = [
                input,
                attention_node_name + "_qkv_weight",
                attention_node_name + "_qkv_bias",
            ]
        else:
            # Cross attention keeps the original projections and consumes their outputs.
            attention_inputs = [
                q_matmul.output[0],
                k_matmul.output[0],
                v_matmul.output[0],
                attention_node_name + "_qkv_bias",
            ]

        attention_node = helper.make_node(
            "Attention" if is_self_attention else "MultiHeadAttention",
            inputs=attention_inputs,
            outputs=[output],
            name=attention_node_name,
        )
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

        return attention_node

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        """Try to fuse the attention subgraph that follows a LayerNormalization node.

        Args:
            normalize_node (NodeProto): the LayerNormalization node to start matching from
            input_name_to_nodes (dict): map from a tensor name to the nodes consuming it
            output_name_to_node (dict): map from a tensor name to the node producing it (unused here)
        """
        node_before_layernorm = self.model.match_parent(
            normalize_node, "Add" if self.is_cross_attention else "Reshape", 0
        )
        if node_before_layernorm is None:
            return

        root_input = node_before_layernorm.output[0]

        # The residual Add shares its input with the LayerNormalization node.
        children_nodes = input_name_to_nodes[root_input]
        skip_add = None
        for node in children_nodes:
            if node.op_type == "Add":  # or node.op_type == "SkipLayerNormalization":
                skip_add = node
                break
        if skip_add is None:
            return

        # Follow the Add input that is NOT the residual connection.
        another_input = 1 if skip_add.input[0] == root_input else 0
        qkv_nodes = self.model.match_parent_path(
            skip_add,
            ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
            [another_input, None, None, 0, 0, 0],
        )

        if qkv_nodes is None:
            return

        (_, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes

        # No bias. For cross-attention, the input of the MatMul is encoder_hidden_states graph input.
        v_nodes = self.model.match_parent_path(matmul_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
        if v_nodes is None:
            logger.debug("fuse_attention: failed to match v path")
            return
        matmul_v = v_nodes[-1]

        # Q*K path, with or without an extra Add between Mul and Softmax.
        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
        if qk_nodes is None:
            qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
            if qk_nodes is None:
                logger.debug("fuse_attention: failed to match qk path")
                return
        matmul_qk = qk_nodes[-1]

        q_nodes = self.model.match_parent_path(matmul_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0])
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match q path")
            return
        (_, _, reshape_q, matmul_q) = q_nodes

        k_nodes = self.model.match_parent_path(
            matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0, 0]
        )
        if k_nodes is None:
            logger.debug("fuse_attention: failed to match k path")
            return
        matmul_k = k_nodes[-1]

        attention_last_node = reshape_qkv

        q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node)
        if q_num_heads <= 0:
            logger.debug("fuse_attention: failed to detect num_heads")
            return

        # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads
        new_node = self.create_attention_node(
            matmul_q,
            matmul_k,
            matmul_v,
            q_num_heads,
            q_hidden_size,
            input=normalize_node.output[0],
            output=attention_last_node.output[0],
        )
        if new_node is None:
            return

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        self.nodes_to_remove.extend([attention_last_node, transpose_qkv])

        # Use prune graph to remove nodes since they are shared by all attention nodes.
        self.prune_graph = True
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
|
@ -0,0 +1,184 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
#
|
||||
# This script converts stable diffusion onnx models from float to half (mixed) precision for GPU inference.
|
||||
#
|
||||
# Before running this script, you need to convert the checkpoint to float32 onnx models as follows:
|
||||
# git clone https://github.com/huggingface/diffusers
|
||||
# cd diffusers
|
||||
# pip install -e .
|
||||
# huggingface-cli login
|
||||
# python3 scripts/convert_stable_diffusion_checkpoint_to_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path ../stable-diffusion-v1-5
|
||||
#
|
||||
# Then you can use this script to convert them to float16 like the following:
|
||||
#    pip3 install -U "onnxruntime-gpu>=1.14"
|
||||
# python3 -m onnxruntime.transformers.models.diffusion.convert_to_fp16 -i ../stable-diffusion-v1-5 -o ../stable-diffusion-v1-5-fp16
|
||||
# Note that float16 model is intended for CUDA Execution Provider. It might not run in CPU Execution Provider.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import coloredlogs
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
from optimizer import optimize_model # noqa: E402
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def convert_to_fp16(source_dir: Path, target_dir: Path, overwrite: bool, use_external_data_format: bool):
    """Optimize every ONNX model of a stable diffusion pipeline and save it in float16.

    Each sub-directory holding a ``model.onnx`` is fused with the python optimizer and then
    converted to half precision, except the VAE decoder which stays in float32.

    Args:
        source_dir (Path): source directory
        target_dir (Path): target directory
        overwrite (bool): overwrite if exists
        use_external_data_format (bool): save model to two files: one for onnx graph, another for weights

    Raises:
        RuntimeError: input onnx model does not exist
        RuntimeError: output onnx model path existed
    """
    for component in ("vae_encoder", "vae_decoder", "text_encoder", "safety_checker", "unet"):
        onnx_model_path = source_dir / component / "model.onnx"
        if not onnx_model_path.exists():
            raise RuntimeError(f"input onnx model does not exist: {onnx_model_path}")

        # Graph fusion before fp16 conversion, otherwise they cannot be fused later.
        # Right now, onnxruntime does not save >2GB model so we use script to optimize unet instead.
        # num_heads and hidden_size are passed as 0.
        optimized = optimize_model(
            str(onnx_model_path),
            model_type="unet",
            num_heads=0,
            hidden_size=0,
            opt_level=0,
            optimization_options=None,
            use_gpu=False,
        )

        if component == "vae_decoder":
            # VAE-decoder in fp16 reduced quality thus we exclude it here
            print("skip convert vae_decoder to fp16.")
        else:
            optimized.convert_float_to_float16(op_block_list=["RandomNormalLike", "Resize"])

        optimized_model_path = target_dir / component / "model.onnx"
        if optimized_model_path.exists() and not overwrite:
            raise RuntimeError(f"output onnx model path existed: {optimized_model_path}")

        # Start from an empty output directory for this component.
        output_dir = optimized_model_path.parent
        if output_dir.exists():
            shutil.rmtree(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        optimized.save_model_to_file(str(optimized_model_path), use_external_data_format=use_external_data_format)
        print(f"{onnx_model_path} => {optimized_model_path}")
|
||||
|
||||
|
||||
def copy_extra(source_dir: Path, target_dir: Path, overwrite: bool):
    """Copy the non-ONNX parts of the pipeline (config directories and index file).

    Args:
        source_dir (Path): source directory
        target_dir (Path): target directory
        overwrite (bool): overwrite if exists

    Raises:
        RuntimeError: source path does not exist
        RuntimeError: output path exists but overwrite is false.
    """
    # Each entry: (names, how to delete an existing target, how to copy).
    # Directories are processed before plain files.
    copy_plan = (
        (("scheduler", "tokenizer", "feature_extractor"), shutil.rmtree, shutil.copytree),
        (("model_index.json",), os.remove, shutil.copyfile),
    )

    for names, remove_existing, copy in copy_plan:
        for name in names:
            source_path = source_dir / name
            if not os.path.exists(source_path):
                raise RuntimeError(f"source path does not exist: {source_path}")

            target_path = target_dir / name
            if target_path.exists():
                if not overwrite:
                    raise RuntimeError(f"output path existed: {target_path}")
                remove_existing(target_path)

            copy(source_path, target_path)
            print(f"{source_path} => {target_path}")
|
||||
|
||||
|
||||
def parse_arguments():
    """Build the command line parser and parse ``sys.argv``.

    Returns:
        Namespace: parsed arguments with ``input``, ``output``, ``overwrite``
            and ``use_external_data_format`` attributes.
    """
    parser = argparse.ArgumentParser()

    # Required input and output pipeline roots.
    parser.add_argument(
        "-i",
        "--input",
        type=str,
        required=True,
        help="Root of input directory of stable diffusion onnx pipeline with float32 models.",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        required=True,
        help="Root of output directory of stable diffusion onnx pipeline with float16 models.",
    )

    # Optional switches, both off unless given.
    parser.add_argument(
        "--overwrite",
        action="store_true",
        required=False,
        default=False,
        help="Overwrite exists files.",
    )
    parser.add_argument(
        "-e",
        "--use_external_data_format",
        action="store_true",
        required=False,
        default=False,
        help="Onnx model larger than 2GB need to use external data format.",
    )

    return parser.parse_args()
|
||||
|
||||
|
||||
def main():
    """Command line entry point: copy the extra pipeline files, then convert the ONNX models."""
    coloredlogs.install(fmt="%(funcName)20s: %(message)s")
    args = parse_arguments()
    copy_extra(Path(args.input), Path(args.output), args.overwrite)
    convert_to_fp16(Path(args.input), Path(args.output), args.overwrite, args.use_external_data_format)


# Only run the conversion when executed as a script (or via ``python -m``),
# so that importing this module does not parse arguments or touch the file system.
if __name__ == "__main__":
    main()
|
||||
|
|
@ -39,7 +39,7 @@ class OnnxModel:
|
|||
try:
|
||||
if self.shape_infer_helper.infer(dynamic_axis_mapping):
|
||||
return self.shape_infer_helper
|
||||
except:
|
||||
except: # noqa
|
||||
self.enable_shape_infer = False # disable shape inference to suppress same error message.
|
||||
print("failed in shape inference", sys.exc_info()[0])
|
||||
|
||||
|
|
@ -267,7 +267,8 @@ class OnnxModel:
|
|||
):
|
||||
"""
|
||||
Find parent node based on constraints on op_type and index.
|
||||
When input_index is None, we will find the first parent node based on constraints, and return_indice will be appended the corresponding input index.
|
||||
When input_index is None, we will find the first parent node based on constraints,
|
||||
and return_indice will be appended the corresponding input index.
|
||||
|
||||
Args:
|
||||
node (str): current node name.
|
||||
|
|
@ -324,14 +325,16 @@ class OnnxModel:
|
|||
):
|
||||
"""
|
||||
Find a sequence of input edges based on constraints on parent op_type and index.
|
||||
When input_index is None, we will find the first parent node based on constraints, and return_indice will be appended the corresponding input index.
|
||||
When input_index is None, we will find the first parent node based on constraints,
|
||||
and return_indice will be appended the corresponding input index.
|
||||
|
||||
Args:
|
||||
node (str): current node name.
|
||||
parent_op_types (str): constraint of parent node op_type of each input edge.
|
||||
parent_input_index (list): constraint of input index of each input edge. None means no constraint.
|
||||
output_name_to_node (dict): dictionary with output name as key, and node as value.
|
||||
return_indice (list): a list to append the input index when there is no constraint on input index of an edge.
|
||||
return_indice (list): a list to append the input index
|
||||
When there is no constraint on input index of an edge.
|
||||
|
||||
Returns:
|
||||
parents: a list of matched parent node.
|
||||
|
|
@ -526,7 +529,7 @@ class OnnxModel:
|
|||
"""Remove cast nodes that are not needed: input and output has same data type."""
|
||||
shape_infer = self.infer_runtime_shape(update=True)
|
||||
if shape_infer is None:
|
||||
logger.info(f"Skip removing useless cast nodes since shape inference failed.")
|
||||
logger.info("Skip removing useless cast nodes since shape inference failed.")
|
||||
return
|
||||
|
||||
def get_data_type(input_or_output_name):
|
||||
|
|
@ -568,19 +571,26 @@ class OnnxModel:
|
|||
|
||||
def convert_float_to_float16(self, use_symbolic_shape_infer=True, **kwargs):
|
||||
"""Convert a model to half (default) or mixed precision.
|
||||
To use mixed precision, user need specify which graph inputs, outputs, operator type or list of nodes shall keep in float32.
|
||||
By default, we use symbolic shape inference to get shape and type information. If not, ONNX shape inference will be used.
|
||||
Note that symbolic/ONNX shape inference might fail, and the conversion might not proceed without shape and type information.
|
||||
To use mixed precision, user need specify which graph inputs, outputs, operator type
|
||||
or list of nodes shall keep in float32.
|
||||
|
||||
By default, we use symbolic shape inference to get shape and type information.
|
||||
If not, ONNX shape inference will be used.
|
||||
|
||||
Note that symbolic/ONNX shape inference might fail, and the conversion might not proceed
|
||||
without shape and type information.
|
||||
|
||||
Args:
|
||||
use_symbolic_shape_infer (bool, optional): use symbolic shape inference instead of onnx shape inference. Defaults to True.
|
||||
keep_io_types (Union[bool, List[str]], optional): It could be boolean or a list of float32 input/output names.
|
||||
If True, model inputs/outputs should be left as float32. Defaults to False.
|
||||
use_symbolic_shape_infer (bool, optional): use symbolic shape inference instead of onnx shape inference.
|
||||
Defaults to True.
|
||||
keep_io_types (Union[bool, List[str]], optional): boolean or a list of float32 input/output names.
|
||||
If True, model inputs/outputs should be left as float32.
|
||||
Defaults to False.
|
||||
op_block_list (List[str], optional): List of operator types to leave as float32.
|
||||
Defaults to None, which will use `float16.DEFAULT_OP_BLOCK_LIST` as default.
|
||||
Defaults to None, which will use `float16.DEFAULT_OP_BLOCK_LIST`.
|
||||
node_block_list (List[str], optional): List of node names to leave as float32. Defaults to None.
|
||||
force_fp16_initializers(bool): force converting all float initializers to float16.
|
||||
Default to false, which will convert only the one needed to avoid precision loss.
|
||||
Default to false.
|
||||
min_positive_val (float, optional): minimal positive value. Defaults to 1e-7.
|
||||
max_finite_val (float, optional): maximal finite value. Defaults to 1e4.
|
||||
"""
|
||||
|
|
@ -589,7 +599,8 @@ class OnnxModel:
|
|||
|
||||
model = self.model
|
||||
if use_symbolic_shape_infer:
|
||||
# Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc) are not recognized by onnx shape inference.
|
||||
# Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc)
|
||||
# are not recognized by onnx shape inference.
|
||||
shape_infer_helper = SymbolicShapeInferenceHelper(model)
|
||||
model = shape_infer_helper.infer_shapes(model, auto_merge=True, guess_output_rank=False)
|
||||
|
||||
|
|
@ -636,7 +647,8 @@ class OnnxModel:
|
|||
if prefix in self._node_name_suffix:
|
||||
suffix = self._node_name_suffix[prefix] + 1
|
||||
else:
|
||||
# Check existed node name only once for a prefix as we assume create_node_name is called for every new node in fusion.
|
||||
# Check existed node name only once for a prefix
|
||||
# as we assume create_node_name is called for every new node in fusion.
|
||||
for node in self.nodes():
|
||||
if node.name and node.name.startswith(prefix):
|
||||
try:
|
||||
|
|
@ -734,7 +746,7 @@ class OnnxModel:
|
|||
outputs (list): a list of graph outputs to retain. If it is None, all graph outputs will be kept.
|
||||
"""
|
||||
if len(self.graphs()) > 1:
|
||||
logger.debug(f"Skip prune_graph since graph has subgraph")
|
||||
logger.debug("Skip prune_graph since graph has subgraph")
|
||||
return
|
||||
|
||||
if outputs is None:
|
||||
|
|
@ -839,7 +851,9 @@ class OnnxModel:
|
|||
for impacted_node in input_name_to_nodes[output_to_remove]:
|
||||
if impacted_node not in nodes_to_remove:
|
||||
logger.debug(
|
||||
f"it is not safe to remove nodes since output {output_to_remove} is used by {impacted_node}"
|
||||
"it is not safe to remove nodes since output %s is used by %s",
|
||||
output_to_remove,
|
||||
impacted_node,
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
|
@ -960,14 +974,10 @@ class OnnxModel:
|
|||
save_model(model, output_path)
|
||||
|
||||
def save_model_to_file(self, output_path, use_external_data_format=False, all_tensors_to_one_file=True):
|
||||
logger.info(f"Sort graphs in topological order")
|
||||
logger.info("Sort graphs in topological order")
|
||||
self.topological_sort()
|
||||
|
||||
if output_path.endswith(".json"): # Output text for testing small model.
|
||||
with open(output_path, "w") as out:
|
||||
out.write(str(model))
|
||||
else:
|
||||
OnnxModel.save(self.model, output_path, use_external_data_format, all_tensors_to_one_file)
|
||||
OnnxModel.save(self.model, output_path, use_external_data_format, all_tensors_to_one_file)
|
||||
logger.info(f"Model saved to {output_path}")
|
||||
|
||||
def get_graph_inputs_excluding_initializers(self):
|
||||
|
|
|
|||
81
onnxruntime/python/tools/transformers/onnx_model_unet.py
Normal file
81
onnxruntime/python/tools/transformers/onnx_model_unet.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from logging import getLogger
|
||||
from typing import Optional
|
||||
|
||||
from fusion_attention_unet import FusionAttentionUnet
|
||||
from fusion_options import FusionOptions
|
||||
from onnx import ModelProto
|
||||
from onnx_model_bert import BertOnnxModel
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class UnetOnnxModel(BertOnnxModel):
    """Optimizer for UNet ONNX models, built on the BERT optimizer with UNet attention fusion."""

    def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
        """Initialize UNet ONNX Model.

        Args:
            model (ModelProto): the ONNX model
            num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically).
            hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically).
        """
        # Either both values are left for automatic detection, or num_heads is positive
        # and evenly divides hidden_size.
        auto_detect = num_heads == 0 and hidden_size == 0
        consistent = num_heads > 0 and hidden_size % num_heads == 0
        assert auto_detect or consistent

        super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)

    def preprocess(self):
        """No preprocessing is done for UNet."""
        return

    def postprocess(self):
        """Prune the graph after fusion."""
        self.prune_graph()

    def optimize(self, options: Optional[FusionOptions] = None):
        """Run the fusion passes in order.

        Args:
            options (FusionOptions, optional): switches for individual fusions.
                When None, every fusion except gelu approximation runs.
        """

        def enabled(flag: str):
            # A fusion runs when no options are given or when its switch is on.
            return options is None or getattr(options, flag)

        if options is not None and not options.enable_shape_inference:
            self.disable_shape_inference()

        self.utils.remove_identity_nodes()

        # Remove cast nodes that having same data type of input and output based on symbolic shape inference.
        self.utils.remove_useless_cast_nodes()

        if enabled("enable_layer_norm"):
            self.fuse_layer_norm()

        if enabled("enable_gelu"):
            self.fuse_gelu()

        self.preprocess()

        self.fuse_reshape()

        if enabled("enable_attention"):
            # Self attention is fused first, then cross attention.
            for is_cross_attention in (False, True):
                attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, is_cross_attention)
                attention_fusion.apply()

        if enabled("enable_skip_layer_norm"):
            self.fuse_skip_layer_norm()

        self.fuse_shape()

        # Remove reshape nodes that having same shape of input and output based on symbolic shape inference.
        self.utils.remove_useless_reshape_nodes()

        self.postprocess()

        if enabled("enable_bias_skip_layer_norm"):
            # Fuse SkipLayerNormalization and Add Bias before it.
            self.fuse_add_bias_skip_layer_norm()

        # Unlike the other passes, gelu approximation only runs when explicitly requested.
        if options is not None and options.enable_gelu_approximation:
            self.gelu_approximation()

        self.remove_unused_constant()

        logger.info(f"opset version: {self.get_opset_version()}")
|
||||
|
|
@ -31,6 +31,7 @@ from onnx_model_bert_keras import BertOnnxModelKeras
|
|||
from onnx_model_bert_tf import BertOnnxModelTF
|
||||
from onnx_model_gpt2 import Gpt2OnnxModel
|
||||
from onnx_model_tnlr import TnlrOnnxModel
|
||||
from onnx_model_unet import UnetOnnxModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -47,6 +48,7 @@ MODEL_TYPES = {
|
|||
0,
|
||||
), # might add a class for GPT2OnnxModel for TF later.
|
||||
"tnlr": (TnlrOnnxModel, "pytorch", 1),
|
||||
"unet": (UnetOnnxModel, "pytorch", 1),
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -139,16 +141,17 @@ def optimize_by_fusion(
|
|||
model (ModelProto): model object
|
||||
model_type (str, optional): model type - like bert, bert_tf, bert_keras or gpt2. Defaults to 'bert'.
|
||||
num_heads (int, optional): number of attention heads. Defaults to 0.
|
||||
0 allows detect the parameter from graph automatically (for model_type "bert" only).
|
||||
0 allows detecting the parameter from the graph automatically.
|
||||
hidden_size (int, optional): hidden size. Defaults to 0.
|
||||
0 allows detect the parameter from graph automatically (for model_type "bert" only).
|
||||
optimization_options (FusionOptions, optional): optimization options that turn on/off some fusions. Defaults to None.
|
||||
0 allows detecting the parameter from the graph automatically.
|
||||
optimization_options (FusionOptions, optional): optimization options that turn on/off some fusions.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
object of an optimizer class.
|
||||
"""
|
||||
if model_type != "bert" and (num_heads == 0 or hidden_size == 0):
|
||||
logger.warning("Please specify parameters of num_heads and hidden_size when model_type is not 'bert'")
|
||||
if model_type not in ["bert", "unet"] and (num_heads == 0 or hidden_size == 0):
|
||||
logger.warning(f"Please specify parameters of num_heads and hidden_size for model_type {model_type}")
|
||||
|
||||
(optimizer_class, producer, _) = MODEL_TYPES[model_type]
|
||||
|
||||
|
|
@ -198,7 +201,9 @@ def optimize_model(
|
|||
|
||||
When opt_level is 0 and only_onnxruntime is False, only python fusion logic is used and onnxruntime is disabled.
|
||||
|
||||
When opt_level > 1, use_gpu shall set properly since the optimized graph might contain operators for GPU or CPU only.
|
||||
When opt_level > 1, use_gpu shall be set properly
|
||||
since the optimized graph might contain operators for GPU or CPU only.
|
||||
|
||||
If your model is intended for GPU inference only (especially float16 or mixed precision model), it is recommended to
|
||||
set use_gpu to be True, otherwise the model is not optimized for GPU inference.
|
||||
|
||||
|
|
@ -208,24 +213,23 @@ def optimize_model(
|
|||
input (str): input model path.
|
||||
model_type (str, optional): model type - like bert, bert_tf, bert_keras or gpt2. Defaults to 'bert'.
|
||||
num_heads (int, optional): number of attention heads. Defaults to 0.
|
||||
0 allows detect the parameter from graph automatically (for model_type "bert" only).
|
||||
0 allows detecting the parameter from the graph automatically.
|
||||
hidden_size (int, optional): hidden size. Defaults to 0.
|
||||
0 allows detect the parameter from graph automatically (for model_type "bert" only).
|
||||
optimization_options (FusionOptions, optional): optimization options that turn on/off some fusions. Defaults to None.
|
||||
0 allows detecting the parameter from the graph automatically.
|
||||
optimization_options (FusionOptions, optional): optimization options that turn on/off some fusions.
|
||||
Defaults to None.
|
||||
opt_level (int, optional): onnxruntime graph optimization level (0, 1, 2 or 99) or None. Defaults to None.
|
||||
When the value is None, default value (1 for bert and gpt2, 0 for other model types) will be used.
|
||||
When the level > 0, onnxruntime will be used to optimize model first.
|
||||
When the value is None, default value (1 for bert and gpt2, 0 for other model types) will be used.
|
||||
When the level > 0, onnxruntime will be used to optimize model first.
|
||||
use_gpu (bool, optional): use gpu or not for onnxruntime. Defaults to False.
|
||||
only_onnxruntime (bool, optional): only use onnxruntime to optimize model, and no python fusion. Defaults to False.
|
||||
only_onnxruntime (bool, optional): only use onnxruntime to optimize model, and no python fusion.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
object of an optimizer class.
|
||||
"""
|
||||
assert opt_level is None or opt_level in [0, 1, 2, 99]
|
||||
|
||||
if model_type != "bert" and (num_heads == 0 or hidden_size == 0):
|
||||
logger.warning("Please specify parameters of num_heads and hidden_size when model_type is not 'bert'")
|
||||
|
||||
(optimizer_class, _producer, default_opt_level) = MODEL_TYPES[model_type]
|
||||
|
||||
if opt_level is None:
|
||||
|
|
@ -300,7 +304,8 @@ def get_fusion_statistics(optimized_model_path: str) -> Dict[str, int]:
|
|||
|
||||
def _parse_arguments():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Graph optimization tool for ONNX Runtime. It transforms ONNX graph to use optimized operators for Transformer models."
|
||||
description="Graph optimization tool for ONNX Runtime."
|
||||
"It transforms ONNX graph to use optimized operators for Transformer models."
|
||||
)
|
||||
parser.add_argument("--input", required=True, type=str, help="input onnx model path")
|
||||
|
||||
|
|
@ -320,7 +325,9 @@ def _parse_arguments():
|
|||
required=False,
|
||||
type=int,
|
||||
default=0,
|
||||
help="number of attention heads like 12 for bert-base and 16 for bert-large. Default is 0 to detect automatically for BERT. For other model type, this parameter need specify correctly.",
|
||||
help="number of attention heads like 12 for bert-base and 16 for bert-large. "
|
||||
"Default is 0 to detect automatically for BERT."
|
||||
"For other model type, this parameter need specify correctly.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
|
@ -328,14 +335,17 @@ def _parse_arguments():
|
|||
required=False,
|
||||
type=int,
|
||||
default=0,
|
||||
help="hidden size like 768 for bert-base and 1024 for bert-large. Default is 0 to detect automatically for BERT. For other model type, this parameter need specify correctly.",
|
||||
help="hidden size like 768 for bert-base and 1024 for bert-large. "
|
||||
"Default is 0 to detect automatically for BERT. "
|
||||
"For other model type, this parameter need specify correctly.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--input_int32",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use int32 (instead of int64) inputs. It could avoid unnecessary data cast when EmbedLayerNormalization is fused for BERT.",
|
||||
help="Use int32 (instead of int64) inputs. "
|
||||
"It could avoid unnecessary data cast when EmbedLayerNormalization is fused for BERT.",
|
||||
)
|
||||
parser.set_defaults(input_int32=False)
|
||||
|
||||
|
|
@ -343,7 +353,8 @@ def _parse_arguments():
|
|||
"--float16",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Convert all weights and nodes in float32 to float16. It has potential loss in precision compared to mixed precision conversion (see convert_float_to_float16).",
|
||||
help="Convert all weights and nodes in float32 to float16. "
|
||||
"It has potential loss in precision compared to mixed precision conversion.",
|
||||
)
|
||||
parser.set_defaults(float16=False)
|
||||
|
||||
|
|
@ -374,7 +385,9 @@ def _parse_arguments():
|
|||
type=int,
|
||||
choices=[0, 1, 2, 99],
|
||||
default=None,
|
||||
help="onnxruntime optimization level. 0 will disable onnxruntime graph optimization. The recommended value is 1. When opt_level > 1 is used, optimized model for GPU might not run in CPU. Level 2 and 99 are intended for --only_onnxruntime.",
|
||||
help="onnxruntime optimization level. 0 will disable onnxruntime graph optimization. "
|
||||
"The recommended value is 1. When opt_level > 1 is used, optimized model for GPU might not run in CPU. "
|
||||
"Level 2 and 99 are intended for --only_onnxruntime.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
|
@ -408,7 +421,7 @@ def main():
|
|||
logger.debug(f"arguments:{args}")
|
||||
|
||||
if os.path.realpath(args.input) == os.path.realpath(args.output):
|
||||
logger.warning(f"Specified the same input and output path. Note that this may overwrite the original model")
|
||||
logger.warning("Specified the same input and output path. Note that this may overwrite the original model")
|
||||
|
||||
optimization_options = FusionOptions.parse(args)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue