mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[QNN EP Quantization] Add fusion preprocessing to QNN quantization (#18719)
### Description - Adds graph fusions to preprocessing step that can be called before creating a QDQ model for QNN EP. - Fuse Erf sequence to Gelu (adapted from [optimizer.py](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/fusion_gelu.py)). Required by QNN EP. - Fuse ReduceMean sequence to LayerNormalization (adapted from [optimizer.py](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/fusion_layernorm.py)). Not required by QNN EP. - Fuse ReduceL2 sequence to LpNormalization (new, specific to QNN EP). Required by QNN EP. Example use: ```python3 from quantization.execution_providers.qnn import get_qnn_qdq_config, qnn_preprocess_model # Added by this PR: model_updated = qnn_preprocess_model("model.fp32.onnx", "model.fp32.preprocessed.onnx", fuse_layernorm=True) model_to_quantize = "model.fp32.preprocessed.onnx" if model_updated else "model.fp32.onnx" # Quantize model ... qnn_config = get_qnn_qdq_config(model_to_quantize, data_reader, activation_type=QuantType.QUInt16) quantize(model_to_quantize, "model.qdq.onnx", qnn_config) ``` ### Motivation and Context Allow more models to be quantized for use with QNN EP --------- Signed-off-by: adrianlizarraga <adlizarraga@microsoft.com>
This commit is contained in:
parent
65300610e2
commit
81796a3081
10 changed files with 953 additions and 5 deletions
|
|
@ -453,6 +453,9 @@ file(GLOB onnxruntime_python_quantization_operators_src CONFIGURE_DEPENDS
|
|||
file(GLOB onnxruntime_python_quantization_cal_table_flatbuffers_src CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/python/tools/quantization/CalTableFlatBuffers/*.py"
|
||||
)
|
||||
file(GLOB onnxruntime_python_quantization_fusions_src CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/python/tools/quantization/fusions/*.py"
|
||||
)
|
||||
file(GLOB onnxruntime_python_quantization_ep_qnn_src CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/python/tools/quantization/execution_providers/qnn/*.py"
|
||||
)
|
||||
|
|
@ -550,6 +553,7 @@ add_custom_command(
|
|||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/operators
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/CalTableFlatBuffers
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/fusions
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/execution_providers
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/execution_providers/qnn
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/quantization
|
||||
|
|
@ -622,6 +626,9 @@ add_custom_command(
|
|||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_quantization_cal_table_flatbuffers_src}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/CalTableFlatBuffers/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_quantization_fusions_src}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/fusions/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_quantization_ep_qnn_src}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/execution_providers/qnn/
|
||||
|
|
|
|||
|
|
@ -1 +1,2 @@
|
|||
from .preprocess import qnn_preprocess_model # noqa: F401
|
||||
from .quant_config import get_qnn_qdq_config # noqa: F401
|
||||
|
|
|
|||
|
|
@ -0,0 +1,127 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import onnx
|
||||
|
||||
from ...fusions import Fusion
|
||||
from ...onnx_model import ONNXModel
|
||||
|
||||
|
||||
class FusionLpNormalization(Fusion):
    """
    Fuses a ReduceL2 -> Clip -> Expand -> Div node sequence into a single
    LpNormalization node with p == 2 (a form required by QNN EP).
    """

    def __init__(self, model: ONNXModel, epsilon: float = 1e-12):
        """
        Args:
            model: ONNXModel to search and modify in place.
            epsilon: Expected value of the Clip node's 'min' (it guards against
                division by zero in the unfused sequence).
        """
        super().__init__(model, "LpNormalization", "ReduceL2")
        self.epsilon = epsilon

    def fuse(
        self,
        reduce_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ):
        """
        Interface function that tries to fuse a node sequence containing a ReduceL2 node into a single
        LpNormalization node.

        Pattern 1:
            [root] --> ReduceL2 -----> Clip  --> Expand ----> Div -->
               |      (axis=-1)    (min=epsilon) (shape=root)  ^
               |      (keepdims=True)                          |
               |                                               |
               +-----------------------------------------------+
        Notes:
            - ReduceL2 must use the last axis, and keepdims == True
            - Clip must only have a min attribute that is ~1e-12
            - Expand must restore the shape to root.shape
            - The output of Expand must be the second input to Div.
        """
        if reduce_node.output[0] not in input_name_to_nodes:
            return

        # ReduceL2 must have one Clip child
        children = input_name_to_nodes[reduce_node.output[0]]
        if len(children) != 1 or children[0].op_type != "Clip":
            return

        # ReduceL2 must have keepdims == True
        keepdims = self.get_node_attribute(reduce_node, "keepdims")
        if not keepdims:
            return

        # ReduceL2 axes must refer only to the last dimension.
        # Axes became an input in opset 18. Before then, axes was an attribute
        reduce_input_ttype = self.model.get_tensor_type(reduce_node.input[0])
        if not reduce_input_ttype:
            return

        reduce_input_shape = self.tensor_shape_to_list(reduce_input_ttype)
        if not reduce_input_shape:
            return

        axes = self.get_node_attribute(reduce_node, "axes")
        if axes is None and len(reduce_node.input) > 1:
            axes = self.model.get_constant_value(reduce_node.input[1])

        # BUG FIX: `axes` is a Python list when read from the attribute but a
        # numpy array when read via get_constant_value(). The previous check
        # (`if not axes ...`) raised "truth value of an array is ambiguous" for
        # multi-element arrays and mis-rejected a one-element array holding
        # axis 0 (which is falsy). Compare against None explicitly instead.
        if axes is None or len(axes) != 1:
            return

        last_dim = len(reduce_input_shape) - 1
        if axes[0] != -1 and axes[0] != last_dim:
            return

        # Clip node must have a min attribute approximately equal to 1e-12
        clip_node = children[0]
        clip_min = self.get_node_attribute(clip_node, "min")
        if clip_min is None and len(clip_node.input) > 1:
            clip_min = self.model.get_constant_value(clip_node.input[1])

        clip_max = self.get_node_attribute(clip_node, "max")  # TODO: clip_max could be FLOAT_MAX
        if clip_max is None and len(clip_node.input) > 2:
            clip_max = self.model.get_constant_value(clip_node.input[2])

        # Clip must be a pure "lower bound by epsilon": no max, and min ~= epsilon.
        if not (clip_max is None and clip_min is not None and clip_min > 0 and abs(clip_min - self.epsilon) < 1e-13):
            return

        if clip_node.output[0] not in input_name_to_nodes:
            return

        # Clip must have a single Expand child.
        children = input_name_to_nodes[clip_node.output[0]]
        if len(children) != 1 or children[0].op_type != "Expand":
            return

        expand_node = children[0]
        if expand_node.output[0] not in input_name_to_nodes:
            return

        # Expand must have a single Div child
        children = input_name_to_nodes[expand_node.output[0]]
        if len(children) != 1 or children[0].op_type != "Div":
            return

        div_node = children[0]

        # The first input to Div must be the root of the subgraph (i.e., reduce_node.input[0])
        # The second input to Div must be the output of the Expand.
        # As long as these two inputs go to the same Div node, then ONNX validation will ensure that
        # their shapes match.
        if div_node.input[0] != reduce_node.input[0]:
            return
        if div_node.input[1] != expand_node.output[0]:
            return

        subgraph_input = reduce_node.input[0]
        subgraph_output = div_node.output[0]

        # Only fuse if no intermediate output is consumed outside the matched subgraph.
        subgraph_nodes = [reduce_node, clip_node, expand_node, div_node]
        if not self.is_safe_to_fuse_nodes(subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node):
            return

        self.nodes_to_remove.extend(subgraph_nodes)
        fused_node = onnx.helper.make_node(
            self.fused_op_type, inputs=[subgraph_input], outputs=[subgraph_output], p=2, axis=-1
        )
        self.nodes_to_add.append(fused_node)
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import onnx
|
||||
|
||||
from ...fusions import FusionGelu, FusionLayerNormalization
|
||||
from ...onnx_model import ONNXModel
|
||||
from .fusion_lpnorm import FusionLpNormalization
|
||||
|
||||
|
||||
def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm: bool = False) -> bool:
    """
    Applies graph fusions that make a float model easier to quantize for QNN EP,
    saving the result to `model_output` only if something changed.

    Args:
        model_input: Path to the input ONNX model.
        model_output: Path where the preprocessed model is saved (written only
            when the graph was actually modified).
        fuse_layernorm: Also try to fuse ReduceMean sequences into
            LayerNormalization nodes (requires ONNX opset >= 17).

    Returns:
        True if the model was modified and saved to `model_output`, False otherwise.
    """
    modified = False
    model = onnx.load_model(model_input)
    onnx_model = ONNXModel(model)

    # Fuse Erf sequence into a single Gelu
    fusion_gelu = FusionGelu(onnx_model)
    if fusion_gelu.apply():
        modified = True

    # Fuse ReduceL2 sequence into a single LpNormalization node with p == 2.
    fusion_lpnorm = FusionLpNormalization(onnx_model)
    if fusion_lpnorm.apply():
        modified = True

    # Optionally, fuse ReduceMean sequence into a single LayerNormalization node.
    if fuse_layernorm:
        # Robustness fix: a bare next() raises StopIteration when the model has
        # no default-domain opset import; supply a default and warn instead.
        onnx_opset = next((x for x in model.opset_import if x.domain == "" or x.domain == "ai.onnx"), None)

        if onnx_opset is None:
            logging.warning(
                "Unable to fuse ReduceMean sequence into a LayerNormalization node. "
                "The ONNX model does not declare a default-domain (ai.onnx) opset import."
            )
        # Need opset >= 17 to use LayerNormalization.
        elif onnx_opset.version < 17:
            logging.warning(
                "Unable to fuse ReduceMean sequence into a LayerNormalization node. "
                "ONNX model must use an opset >= 17 in order to use LayerNormalization, "
                f"but found version {onnx_opset.version}. Please use onnx.version_converter to update your model."
            )
        else:
            fusion_layernorm = FusionLayerNormalization(onnx_model)
            if fusion_layernorm.apply():
                modified = True

    if modified:
        onnx_model.topological_sort()
        onnx.save_model(model, model_output)

    return modified
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
from .fusion import Fusion # noqa: F401
|
||||
from .fusion_gelu import FusionGelu # noqa: F401
|
||||
from .fusion_layernorm import FusionLayerNormalization # noqa: F401
|
||||
298
onnxruntime/python/tools/quantization/fusions/fusion.py
Normal file
298
onnxruntime/python/tools/quantization/fusions/fusion.py
Normal file
|
|
@ -0,0 +1,298 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
|
||||
import onnx
|
||||
|
||||
from ..onnx_model import ONNXModel
|
||||
|
||||
|
||||
class Fusion:
    """
    Base class for fusions.

    A subclass implements `fuse()` to match a node sequence rooted at a node of
    `search_op_type` and, on a match, schedule the sequence's replacement with a
    single `fused_op_type` node by appending to `nodes_to_remove`/`nodes_to_add`.
    `apply()` drives the search and commits the scheduled changes.
    """

    def __init__(self, model: ONNXModel, fused_op_type: str, search_op_type: str):
        # Op type that triggers a fusion attempt when seen during apply().
        self.search_op_type: str = search_op_type
        # Op type of the single node that replaces a matched sequence.
        self.fused_op_type: str = fused_op_type
        self.model: ONNXModel = model
        # Pending graph edits, committed in bulk by apply().
        self.nodes_to_remove: list = []
        self.nodes_to_add: list = []

    def fuse(
        self,
        node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ):
        """
        Interface function for derived fusion classes. Tries to fuse a node sequence containing
        the specified node.

        Args:
            node: candidate node whose op_type equals `search_op_type`.
            input_name_to_nodes: maps a tensor name to the nodes consuming it.
            output_name_to_node: maps a tensor name to the node producing it.
        """
        raise NotImplementedError

    def apply(self) -> bool:
        """
        Apply graph fusion on the entire model graph.

        Returns:
            True if any nodes were removed or added (i.e., the graph changed).
        """
        input_name_to_nodes = self.model.input_name_to_nodes()
        output_name_to_node = self.model.output_name_to_node()

        for node in self.model.nodes():
            if node.op_type == self.search_op_type:
                self.fuse(node, input_name_to_nodes, output_name_to_node)

        self.model.remove_nodes(self.nodes_to_remove)
        self.model.add_nodes(self.nodes_to_add)

        graph_updated = bool(self.nodes_to_remove or self.nodes_to_add)

        if graph_updated:
            # Fusions may leave constant/initializer inputs dangling; drop them.
            self.model.remove_unused_constant()

        return graph_updated

    @staticmethod
    def is_safe_to_fuse_nodes(
        nodes_to_remove: list[onnx.NodeProto],
        keep_outputs: list[str],
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ) -> bool:
        """
        Returns True if every output of `nodes_to_remove` (other than those in
        `keep_outputs`) is consumed only by nodes that are also being removed,
        i.e. removing the subgraph cannot orphan an external consumer.
        """
        for node_to_remove in nodes_to_remove:
            for output_to_remove in node_to_remove.output:
                if output_to_remove in keep_outputs:
                    continue

                if output_to_remove in input_name_to_nodes:
                    for impacted_node in input_name_to_nodes[output_to_remove]:
                        if impacted_node not in nodes_to_remove:
                            # Not safe to remove nodes since output is used by impacted_node
                            return False
        return True

    @staticmethod
    def get_node_attribute(node: onnx.NodeProto, attribute_name: str):
        """Returns the decoded value of the named attribute, or None if absent."""
        for attr in node.attribute:
            if attr.name == attribute_name:
                value = onnx.helper.get_attribute_value(attr)
                return value
        return None

    @staticmethod
    def input_index(node_output: str, child_node: onnx.NodeProto) -> int:
        """Returns the index at which `child_node` consumes `node_output`, or -1."""
        index = 0
        for input_name in child_node.input:
            if input_name == node_output:
                return index
            index += 1
        return -1

    @staticmethod
    def tensor_shape_to_list(tensor_type) -> list[int]:
        """
        Converts a TypeProto.Tensor shape to a list. Known dimensions become
        ints; symbolic dimensions keep their names (so entries may be str).
        """
        shape_list = []
        for d in tensor_type.shape.dim:
            if d.HasField("dim_value"):
                shape_list.append(d.dim_value)  # known dimension
            elif d.HasField("dim_param"):
                shape_list.append(d.dim_param)  # unknown dimension with symbolic name
            else:
                shape_list.append("?")  # shall not happen
        return shape_list

    def get_constant_input(self, node: onnx.NodeProto):
        """
        Returns (index, value) for the first input of `node` that is a constant
        (initializer or Constant node), or (None, None) if there is none.
        """
        for i, inp in enumerate(node.input):
            value = self.model.get_constant_value(inp)
            if value is not None:
                return i, value

        return None, None

    def find_constant_input(self, node: onnx.NodeProto, expected_value: float, delta: float = 0.000001) -> int:
        """
        Returns the input index of a scalar constant input whose value is within
        `delta` of `expected_value`, or -1 if no such input exists.
        """
        i, value = self.get_constant_input(node)
        if value is not None and value.size == 1 and abs(value - expected_value) < delta:
            return i

        return -1

    def has_constant_input(self, node: onnx.NodeProto, expected_value: float, delta: float = 0.000001) -> bool:
        """Returns True if some input of `node` is a scalar constant ~= `expected_value`."""
        return self.find_constant_input(node, expected_value, delta) >= 0

    def is_constant_with_specified_rank(self, output_name: str, rank: int) -> bool:
        """Returns True if `output_name` is a constant tensor of the given rank."""
        value = self.model.get_constant_value(output_name)
        if value is None:
            return False  # Not an initializer

        if len(value.shape) != rank:
            return False  # Wrong dimensions

        return True

    def match_first_parent(
        self,
        node: onnx.NodeProto,
        parent_op_type: str,
        output_name_to_node: dict[str, onnx.NodeProto] | None = None,
        exclude: list[onnx.NodeProto] = [],  # noqa: B006
    ) -> tuple[onnx.NodeProto | None, int | None]:
        """
        Find parent node based on constraints on op_type.

        Args:
            node: current node.
            parent_op_type (str): constraint of parent node op_type.
            output_name_to_node (dict): dictionary with output name as key, and node as value.
            exclude (list): list of nodes that are excluded (not allowed to match as parent).

        Returns:
            parent: The matched parent node. None if not found.
            index: The input index of matched parent node. None if not found.
        """
        if output_name_to_node is None:
            output_name_to_node = self.model.output_name_to_node()

        for i, inp in enumerate(node.input):
            if inp in output_name_to_node:
                parent = output_name_to_node[inp]
                if parent.op_type == parent_op_type and parent not in exclude:
                    return parent, i

        return None, None

    def match_parent(
        self,
        node: onnx.NodeProto,
        parent_op_type: str,
        input_index: int | None = None,
        output_name_to_node: dict[str, onnx.NodeProto] | None = None,
        exclude: list[onnx.NodeProto] = [],  # noqa: B006
        return_indice: list[int] | None = None,
    ) -> onnx.NodeProto | None:
        """
        Find parent node based on constraints on op_type and index.
        When input_index is None, we will find the first parent node based on constraints,
        and return_indice will be appended the corresponding input index.

        Args:
            node (str): current node name.
            parent_op_type (str): constraint of parent node op_type.
            input_index (int or None): only check the parent given input index of current node.
            output_name_to_node (dict): dictionary with output name as key, and node as value.
            exclude (list): list of nodes that are excluded (not allowed to match as parent).
            return_indice (list): a list to append the input index when input_index is None.

        Returns:
            parent: The matched parent node.
        """
        assert node is not None
        assert input_index is None or input_index >= 0

        if output_name_to_node is None:
            output_name_to_node = self.model.output_name_to_node()

        if input_index is None:
            parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node, exclude)
            if return_indice is not None:
                return_indice.append(index)
            return parent

        if input_index >= len(node.input):
            # Input index out of bounds.
            return None

        parent = self.model.get_parent(node, input_index, output_name_to_node)
        if parent is not None and parent.op_type == parent_op_type and parent not in exclude:
            return parent

        return None

    def match_parent_path(
        self,
        node: onnx.NodeProto,
        parent_op_types: list[str],
        parent_input_index: list[int] | None = None,
        output_name_to_node: dict[str, onnx.NodeProto] | None = None,
        return_indice: list[int] | None = None,
    ) -> list[onnx.NodeProto] | None:
        """
        Find a sequence of input edges based on constraints on parent op_type and index.
        When input_index is None, we will find the first parent node based on constraints,
        and return_indice will be appended the corresponding input index.

        Args:
            node (str): current node name.
            parent_op_types (str): constraint of parent node op_type of each input edge.
            parent_input_index (list): constraint of input index of each input edge. None means no constraint.
            output_name_to_node (dict): dictionary with output name as key, and node as value.
            return_indice (list): a list to append the input index
                                  When there is no constraint on input index of an edge.

        Returns:
            parents: a list of matched parent node.
        """
        if parent_input_index is not None:
            assert len(parent_input_index) == len(parent_op_types)

        if output_name_to_node is None:
            output_name_to_node = self.model.output_name_to_node()

        # Walk upward one edge at a time, re-rooting at each matched parent.
        current_node = node
        matched_parents = []
        for i, op_type in enumerate(parent_op_types):
            matched_parent = self.match_parent(
                current_node,
                op_type,
                parent_input_index[i] if parent_input_index is not None else None,
                output_name_to_node,
                exclude=[],
                return_indice=return_indice,
            )
            if matched_parent is None:
                return None

            matched_parents.append(matched_parent)
            current_node = matched_parent

        return matched_parents

    def match_parent_paths(
        self,
        node: onnx.NodeProto,
        paths: list[tuple[list[str], list[int]]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ) -> tuple[int, list[onnx.NodeProto] | None, list[int] | None]:
        """
        Find a matching parent path to the given node.

        Tries each (op_types, input_indices) candidate in `paths` in order and
        returns (path_index, matched_nodes, matched_input_indices) for the first
        match, or (-1, None, None) if none match.
        """
        for i, path in enumerate(paths):
            return_indice = []
            matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice)
            if matched:
                return i, matched, return_indice
        return -1, None, None

    def find_first_child_by_type(
        self,
        node: onnx.NodeProto,
        child_type: str,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]] | None = None,
        recursive: bool = True,
    ) -> onnx.NodeProto | None:
        """
        Returns the first downstream node of op_type `child_type`, searching
        direct children only (recursive=False) or breadth-first through all
        descendants (recursive=True). None if no such node exists.
        """
        children = self.model.get_children(node, input_name_to_nodes)
        dq = deque(children)
        while len(dq) > 0:
            current_node = dq.pop()
            if current_node.op_type == child_type:
                return current_node

            if recursive:
                children = self.model.get_children(current_node, input_name_to_nodes)
                for child in children:
                    dq.appendleft(child)

        return None
|
||||
269
onnxruntime/python/tools/quantization/fusions/fusion_gelu.py
Normal file
269
onnxruntime/python/tools/quantization/fusions/fusion_gelu.py
Normal file
|
|
@ -0,0 +1,269 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import onnx
|
||||
|
||||
from ..onnx_model import ONNXModel
|
||||
from .fusion import Fusion
|
||||
|
||||
|
||||
class FusionGelu(Fusion):
    """
    Fuses Erf-based Gelu approximations (as exported by PyTorch, Keras, and
    TensorFlow) into a single com.microsoft Gelu node.
    """

    def __init__(self, model: ONNXModel):
        super().__init__(model, "Gelu", "Erf")

    def fuse(
        self,
        erf_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ):
        """
        Interface function that tries to fuse a node sequence containing an Erf node into a single
        Gelu node.
        """
        if (
            self.fuse_1(erf_node, input_name_to_nodes, output_name_to_node)
            or self.fuse_2(erf_node, input_name_to_nodes, output_name_to_node)
            or self.fuse_3(erf_node, input_name_to_nodes, output_name_to_node)
        ):
            # Gelu is a contrib op; make sure the model imports the MS domain.
            self.model.set_opset_import("com.microsoft", 1)

    def fuse_1(
        self,
        erf_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ) -> bool:
        """
        This pattern is from PyTorch model
        Fuse Gelu with Erf into one node:
        Pattern 1:
                       +-------Mul(0.5)---------------------+
                       |                                    |
                       |                                    v
                    [root] --> Div -----> Erf  --> Add --> Mul -->
                              (B=1.4142...)       (1)

        Pattern 2:
                       +------------------------------------+
                       |                                    |
                       |                                    v
                    [root] --> Div -----> Erf  --> Add --> Mul -->Mul -->
                              (B=1.4142...)       (1)            (0.5)

        Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.

        Returns:
            True if a fusion was scheduled, False otherwise.
        """
        if erf_node.output[0] not in input_name_to_nodes:
            return False
        children = input_name_to_nodes[erf_node.output[0]]
        if len(children) != 1 or children[0].op_type != "Add":
            return False
        add_after_erf = children[0]

        if not self.has_constant_input(add_after_erf, 1):
            return False

        if add_after_erf.output[0] not in input_name_to_nodes:
            return False

        children = input_name_to_nodes[add_after_erf.output[0]]
        if len(children) != 1 or children[0].op_type != "Mul":
            return False

        mul_after_erf = children[0]

        div = self.match_parent(erf_node, "Div", 0, output_name_to_node)
        if div is None:
            return False

        if self.find_constant_input(div, 1.4142, delta=0.001) != 1:
            return False

        subgraph_input = div.input[0]

        # `another` is the index of mul_after_erf's input that is NOT the Add output.
        another = 1 if mul_after_erf.input[0] == add_after_erf.output[0] else 0
        if subgraph_input == mul_after_erf.input[another]:  # pattern 2
            # BUG FIX: guard the lookup — if the Mul output is a graph output it
            # has no consumers, and indexing input_name_to_nodes directly would
            # raise KeyError (every other lookup in this file is guarded).
            if mul_after_erf.output[0] not in input_name_to_nodes:
                return False
            children = input_name_to_nodes[mul_after_erf.output[0]]
            if len(children) != 1 or children[0].op_type != "Mul":
                return False
            mul_half = children[0]
            if not self.has_constant_input(mul_half, 0.5):
                return False
            subgraph_output = mul_half.output[0]
        else:  # pattern 1
            mul_half = self.match_parent(mul_after_erf, "Mul", another, output_name_to_node)
            if mul_half is None:
                return False

            if not self.has_constant_input(mul_half, 0.5):
                return False

            if subgraph_input not in mul_half.input:
                return False

            subgraph_output = mul_after_erf.output[0]

        subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul_half]
        if not self.is_safe_to_fuse_nodes(subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node):
            return False

        self.nodes_to_remove.extend(subgraph_nodes)
        fused_node = onnx.helper.make_node("Gelu", inputs=[subgraph_input], outputs=[subgraph_output])
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        return True

    def fuse_2(
        self,
        erf_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ) -> bool:
        """
        This pattern is from Keras model
        Fuse Gelu with Erf into one node:
                       +------------------------------------------+
                       |                                          |
                       |                                          v
                    [root] --> Div -----> Erf  --> Add --> Mul -->Mul
                              (B=1.4142...)       (A=1)   (A=0.5)

        Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.

        Returns:
            True if a fusion was scheduled, False otherwise.
        """
        if erf_node.output[0] not in input_name_to_nodes:
            return False
        children = input_name_to_nodes[erf_node.output[0]]
        if len(children) != 1 or children[0].op_type != "Add":
            return False
        add_after_erf = children[0]

        if not self.has_constant_input(add_after_erf, 1):
            return False

        if add_after_erf.output[0] not in input_name_to_nodes:
            return False
        children = input_name_to_nodes[add_after_erf.output[0]]
        if len(children) != 1 or children[0].op_type != "Mul":
            return False
        mul_after_erf = children[0]

        if not self.has_constant_input(mul_after_erf, 0.5):
            return False

        if mul_after_erf.output[0] not in input_name_to_nodes:
            return False
        children = input_name_to_nodes[mul_after_erf.output[0]]
        if len(children) != 1 or children[0].op_type != "Mul":
            return False
        mul = children[0]

        div = self.match_parent(erf_node, "Div", 0, output_name_to_node)
        if div is None:
            return False

        # The divisor may be the literal sqrt(2) constant, or Sqrt(2) computed by a node.
        sqrt_node = None
        if self.find_constant_input(div, 1.4142, delta=0.001) != 1:
            sqrt_node = self.match_parent(div, "Sqrt", 1, output_name_to_node)
            if sqrt_node is None:
                return False
            if not self.has_constant_input(sqrt_node, 2.0):
                return False

        root_node = self.model.get_parent(div, 0, output_name_to_node)
        if root_node is None:
            return False

        if root_node.output[0] not in mul.input:
            return False

        subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul]
        if sqrt_node:
            subgraph_nodes.append(sqrt_node)

        if not self.is_safe_to_fuse_nodes(subgraph_nodes, [mul.output[0]], input_name_to_nodes, output_name_to_node):
            return False

        self.nodes_to_remove.extend(subgraph_nodes)
        fused_node = onnx.helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[mul.output[0]])
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        return True

    def fuse_3(
        self,
        erf_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ) -> bool:
        """
        This pattern is from TensorFlow model
        Fuse Gelu with Erf into one node:
                       +----------------------------------------------+
                       |                                              |
                       |                                              v
                    [root] --> Mul -----> Erf    -->   Add --> Mul -->Mul
                              (A=0.7071067690849304)  (B=1)  (B=0.5)

        Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.

        Returns:
            True if a fusion was scheduled, False otherwise.
        """

        if erf_node.output[0] not in input_name_to_nodes:
            return False
        children = input_name_to_nodes[erf_node.output[0]]
        if len(children) != 1 or children[0].op_type != "Add":
            return False
        add_after_erf = children[0]

        if not self.has_constant_input(add_after_erf, 1):
            return False

        if add_after_erf.output[0] not in input_name_to_nodes:
            return False
        children = input_name_to_nodes[add_after_erf.output[0]]
        if len(children) != 1 or children[0].op_type != "Mul":
            return False
        mul_half = children[0]

        if not self.has_constant_input(mul_half, 0.5):
            return False

        first_mul = self.match_parent(erf_node, "Mul", 0, output_name_to_node)
        if first_mul is None:
            return False

        # 0.7071... == 1/sqrt(2); either input slot may hold the constant.
        i = self.find_constant_input(first_mul, 0.7071067690849304, delta=0.001)
        if i < 0:
            return False

        root_node = self.model.get_parent(first_mul, 0 if i == 1 else 1, output_name_to_node)
        if root_node is None:
            return False

        if mul_half.output[0] not in input_name_to_nodes:
            return False
        children = input_name_to_nodes[mul_half.output[0]]
        if len(children) != 1 or children[0].op_type != "Mul":
            return False
        last_mul = children[0]

        if not (last_mul.input[0] == root_node.output[0] or last_mul.input[1] == root_node.output[0]):
            return False

        subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul]
        if not self.is_safe_to_fuse_nodes(
            subgraph_nodes,
            [last_mul.output[0]],
            input_name_to_nodes,
            output_name_to_node,
        ):
            return False

        self.nodes_to_remove.extend(subgraph_nodes)
        fused_node = onnx.helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[last_mul.output[0]])
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        return True
|
||||
|
|
@ -0,0 +1,134 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import onnx
|
||||
|
||||
from ..onnx_model import ONNXModel
|
||||
from .fusion import Fusion
|
||||
|
||||
|
||||
class FusionLayerNormalization(Fusion):
    """Fuses the decomposed mean/variance node sequence produced by exporters into a
    single LayerNormalization node."""

    def __init__(self, model: ONNXModel):
        # Anchor the search on "ReduceMean" nodes; fused nodes are emitted as
        # "LayerNormalization".
        super().__init__(model, "LayerNormalization", "ReduceMean")

    def fuse(
        self,
        reduce_mean_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ):
        """
        Interface function that tries to fuse a node sequence containing a ReduceMean node into a single
        LayerNormalization node.

              +----------------------+
              |                      |
              |                      v
          [Root] --> ReduceMean -->  Sub  --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
                     (axis=2 or -1)  |      (Y=2)   (axis=2 or -1)  (E-6 or E-12 or 0)    ^
                                     |                                                    |
                                     +----------------------------------------------------+

        It also handles cases of duplicated sub nodes exported from older version of PyTorch:

              +----------------------+
              |                      v
              |           +-------> Sub-----------------------------------------------+
              |           |                                                           |
              |           |                                                           v
          [Root] --> ReduceMean -->  Sub  --> Pow --> ReduceMean --> Add --> Sqrt --> Div  --> Mul --> Add
              |                      ^
              |                      |
              +----------------------+
        """
        # Expect one Sub consumer of the mean (normal export) or two (older
        # PyTorch duplicates the Sub node, per the second diagram above).
        children = self.model.get_children(reduce_mean_node, input_name_to_nodes)
        if len(children) == 0 or len(children) > 2:
            return

        root_input = reduce_mean_node.input[0]

        # Every Sub child must compute (root - mean): first input is the root tensor.
        if children[0].op_type != "Sub" or children[0].input[0] != root_input:
            return

        if len(children) == 2:
            if children[1].op_type != "Sub" or children[1].input[0] != root_input:
                return

        # Locate the Div node (x - mean) / sqrt(var + eps) directly below one of
        # the Sub children (non-recursive: Div must be an immediate child).
        div_node = None
        for child in children:
            div_node = self.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)
            if div_node is not None:
                break
        if div_node is None:
            return

        # Walk upward from Div along its denominator (input 1) to match the
        # variance computation; an optional Cast after Pow is tolerated.
        path_id, parent_nodes, _ = self.match_parent_paths(
            div_node,
            [
                (["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]),
                (
                    ["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"],
                    [1, 0, 0, 0, 0, 0],
                ),
            ],
            output_name_to_node,
        )
        if path_id < 0:
            return

        # The Sub feeding the variance branch must be one of the mean's children,
        # which ties both branches back to the same (root - mean) value.
        sub_node = parent_nodes[-1]
        if sub_node not in children:
            return

        # The Add inside the Sqrt path carries the epsilon constant; reject
        # values outside (0, 1e-4] as not a plausible LayerNorm epsilon.
        second_add_node = parent_nodes[1]
        i, add_weight = self.get_constant_input(second_add_node)
        if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4:
            # Skip fusion since epsilon value is not expected.
            return

        # The Pow must be an exact square (exponent constant == 2.0 at input 1).
        pow_node = parent_nodes[3]
        if self.find_constant_input(pow_node, 2.0) != 1:
            return

        # NOTE(review): assumes div_node.output[0] has at least one consumer
        # recorded in input_name_to_nodes; a Div feeding only a graph output
        # would raise KeyError here — confirm upstream guarantees a consumer.
        mul_node = input_name_to_nodes[div_node.output[0]][0]
        if mul_node.op_type != "Mul":
            return

        # Mul applies the scale (gamma); the following Add applies the bias (beta).
        last_add_node = input_name_to_nodes[mul_node.output[0]][0]
        if last_add_node.op_type != "Add":
            return

        # Collect every node of the matched pattern; parent_nodes[:-1] excludes
        # sub_node, which is already included via `children`.
        subgraph_nodes = [reduce_mean_node]
        subgraph_nodes.extend(children)
        subgraph_nodes.extend(parent_nodes[:-1])

        subgraph_nodes.extend([last_add_node, mul_node, div_node])
        # Only fuse when no intermediate output of the pattern is consumed
        # outside of it (other than the final Add's output).
        if not self.is_safe_to_fuse_nodes(
            subgraph_nodes,
            last_add_node.output,
            input_name_to_nodes,
            output_name_to_node,
        ):
            return

        # The scale is whichever Mul input is not the Div output; it must be a
        # rank-1 constant.
        weight_input = mul_node.input[1 - self.input_index(div_node.output[0], mul_node)]
        if not self.is_constant_with_specified_rank(weight_input, 1):
            return

        # Likewise the bias is the non-Mul input of the final Add, rank-1 constant.
        bias_input = last_add_node.input[1 - self.input_index(mul_node.output[0], last_add_node)]
        if not self.is_constant_with_specified_rank(bias_input, 1):
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        # Replace the whole pattern with a single LayerNormalization node that
        # reuses the original root input and final output names.
        normalize_node = onnx.helper.make_node(
            "LayerNormalization",
            inputs=[reduce_mean_node.input[0], weight_input, bias_input],
            outputs=[last_add_node.output[0]],
        )
        normalize_node.attribute.extend([onnx.helper.make_attribute("epsilon", float(add_weight))])
        self.nodes_to_add.append(normalize_node)
|
||||
|
|
@ -1,3 +1,7 @@
|
|||
# --------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
from pathlib import Path
|
||||
|
||||
import onnx
|
||||
|
|
@ -114,6 +118,14 @@ class ONNXModel:
|
|||
def opset_import(self):
|
||||
return self.model.opset_import
|
||||
|
||||
def set_opset_import(self, domain, version):
    """Set the opset version for *domain*, appending a new opset entry when the
    domain is not imported yet."""
    # Update the existing entry in place when the domain is already imported.
    for existing in self.model.opset_import:
        if existing.domain == domain:
            existing.version = version
            return
    # Domain not present: add a fresh opset id entry.
    self.model.opset_import.extend([onnx_helper.make_opsetid(domain, version)])
|
||||
|
||||
def remove_node(self, node):
    """Remove *node* from the graph if it is still present; silently no-op otherwise."""
    graph_nodes = self.model.graph.node
    if node in graph_nodes:
        graph_nodes.remove(node)
|
||||
|
|
@ -140,6 +152,49 @@ class ONNXModel:
|
|||
return tensor
|
||||
return None
|
||||
|
||||
def find_graph_input(self, input_name):
    """Return the graph input named *input_name*, or None when absent."""
    # First match wins; avoids shadowing the builtin `input`.
    return next((graph_in for graph_in in self.model.graph.input if graph_in.name == input_name), None)
|
||||
|
||||
def find_graph_output(self, output_name):
    """Return the graph output named *output_name*, or None when absent."""
    # First match wins, same contract as the loop-based original.
    return next((graph_out for graph_out in self.model.graph.output if graph_out.name == output_name), None)
|
||||
|
||||
def get_tensor_type(self, tensor_name: str):
    """Look up the tensor type proto for *tensor_name*.

    Checks graph value_info first, then graph inputs, then graph outputs.
    Returns None when the name is unknown.
    """
    # Build the lookup once; later duplicate names overwrite earlier ones,
    # matching the dict-comprehension behavior of the original.
    type_by_name = {entry.name: entry.type for entry in self.model.graph.value_info}
    if tensor_name in type_by_name:
        return type_by_name[tensor_name].tensor_type

    graph_input = self.find_graph_input(tensor_name)
    if graph_input:
        return graph_input.type.tensor_type

    graph_output = self.find_graph_output(tensor_name)
    if graph_output:
        return graph_output.type.tensor_type

    return None
|
||||
|
||||
def get_constant_value(self, output_name):
    """Return the numpy value of the Constant node producing *output_name*.

    Falls back to an initializer of the same name (constant folding may have
    replaced the Constant node); returns None when neither exists.
    """
    for graph_node in self.model.graph.node:
        if graph_node.op_type == "Constant" and graph_node.output[0] == output_name:
            for attribute in graph_node.attribute:
                if attribute.name == "value":
                    return onnx_numpy_helper.to_array(attribute.t)

    # Fallback to initializer since constant folding may have been applied.
    initializer = self.get_initializer(output_name)
    if initializer is None:
        return None
    return onnx_numpy_helper.to_array(initializer)
|
||||
|
||||
def get_initializer_name_set(self):
    """Return the set of all initializer tensor names in the graph."""
    return {init.name for init in self.model.graph.initializer}
|
||||
|
||||
|
|
@ -167,17 +222,19 @@ class ONNXModel:
|
|||
def input_name_to_nodes(self):
    """Map each non-empty input name to the list of nodes that consume it.

    The span in SOURCE contained both the pre-diff (unguarded) and post-diff
    (guarded) insertion logic; this is the coherent post-image, which skips
    empty input names (an empty string marks an omitted optional input in ONNX).
    """
    input_name_to_nodes = {}
    for node in self.model.graph.node:
        for input_name in node.input:
            if input_name:  # Could be empty when it is optional
                if input_name not in input_name_to_nodes:
                    input_name_to_nodes[input_name] = [node]
                else:
                    input_name_to_nodes[input_name].append(node)
    return input_name_to_nodes
|
||||
|
||||
def output_name_to_node(self):
    """Map each non-empty output name to the node that produces it.

    The span in SOURCE contained both the pre-diff (unguarded) and post-diff
    (guarded) assignment; this is the coherent post-image, which skips empty
    output names (an empty string marks an omitted optional output in ONNX).
    """
    output_name_to_node = {}
    for node in self.model.graph.node:
        for output_name in node.output:
            if output_name:  # Could be empty when it is optional
                output_name_to_node[output_name] = node
    return output_name_to_node
|
||||
|
||||
def get_children(self, node, input_name_to_nodes=None):
|
||||
|
|
|
|||
1
setup.py
1
setup.py
|
|
@ -408,6 +408,7 @@ packages = [
|
|||
"onnxruntime.quantization",
|
||||
"onnxruntime.quantization.operators",
|
||||
"onnxruntime.quantization.CalTableFlatBuffers",
|
||||
"onnxruntime.quantization.fusions",
|
||||
"onnxruntime.quantization.execution_providers.qnn",
|
||||
"onnxruntime.transformers",
|
||||
"onnxruntime.transformers.models.bart",
|
||||
|
|
|
|||
Loading…
Reference in a new issue