[QNN EP Quantization] Add fusion preprocessing to QNN quantization (#18719)

### Description
- Adds graph fusions to preprocessing step that can be called before
creating a QDQ model for QNN EP.
- Fuse Erf sequence to Gelu (adapted from
[optimizer.py](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/fusion_gelu.py)).
Required by QNN EP.
- Fuse ReduceMean sequence to LayerNormaliation (adapted from
[optimizer.py](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/fusion_layernorm.py)).
Not required by QNN EP.
- Fuse ReduceL2 sequence to LpNormalization (new, specific to QNN EP).
Required by QNN EP.

Example use:
```python3
from quantization.execution_providers.qnn import get_qnn_qdq_config, qnn_preprocess_model

# Added by this PR:
model_updated = qnn_preprocess_model("model.fp32.onnx", "model.fp32.preprocessed.onnx", fuse_layernorm=True)
model_to_quantize = "model.fp32.preprocessed.onnx" if model_updated else "model.fp32.onnx"

# Quantize model ...
qnn_config = get_qnn_qdq_config(model_to_quantize, data_reader, activation_type=QuantType.QUInt16)
quantize(model_to_quantize, "model.qdq.onnx", qnn_config)
```
### Motivation and Context
Allow more models to be quantized for use with QNN EP

---------

Signed-off-by: adrianlizarraga <adlizarraga@microsoft.com>
This commit is contained in:
Adrian Lizarraga 2023-12-12 08:43:04 -08:00 committed by GitHub
parent 65300610e2
commit 81796a3081
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 953 additions and 5 deletions

View file

@ -453,6 +453,9 @@ file(GLOB onnxruntime_python_quantization_operators_src CONFIGURE_DEPENDS
file(GLOB onnxruntime_python_quantization_cal_table_flatbuffers_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/quantization/CalTableFlatBuffers/*.py"
)
file(GLOB onnxruntime_python_quantization_fusions_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/quantization/fusions/*.py"
)
file(GLOB onnxruntime_python_quantization_ep_qnn_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/quantization/execution_providers/qnn/*.py"
)
@ -550,6 +553,7 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/operators
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/CalTableFlatBuffers
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/fusions
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/execution_providers
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/execution_providers/qnn
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/quantization
@ -622,6 +626,9 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_quantization_cal_table_flatbuffers_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/CalTableFlatBuffers/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_quantization_fusions_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/fusions/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_quantization_ep_qnn_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/execution_providers/qnn/

View file

@ -1 +1,2 @@
from .preprocess import qnn_preprocess_model # noqa: F401
from .quant_config import get_qnn_qdq_config # noqa: F401

View file

@ -0,0 +1,127 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations
import onnx
from ...fusions import Fusion
from ...onnx_model import ONNXModel
class FusionLpNormalization(Fusion):
def __init__(self, model: ONNXModel, epsilon: float = 1e-12):
super().__init__(model, "LpNormalization", "ReduceL2")
self.epsilon = epsilon
def fuse(
self,
reduce_node: onnx.NodeProto,
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
output_name_to_node: dict[str, onnx.NodeProto],
):
"""
Interface function that tries to fuse a node sequence containing a ReduceL2 node into a single
LpNormalization node.
Pattern 1:
[root] --> ReduceL2 -----> Clip --> Expand ----> Div -->
| (axis=-1) (min=epsilon) (shape=root) ^
| (keepdims=True) |
| |
+-----------------------------------------------+
Notes:
- ReduceL2 must use the last axis, and keepdims == True
- Clip must only have a min attribute that is ~1e-12
- Expand must restore the shape to root.shape
- The output of Expand must be the second input to Div.
"""
if reduce_node.output[0] not in input_name_to_nodes:
return
# ReduceL2 must have one Clip child
children = input_name_to_nodes[reduce_node.output[0]]
if len(children) != 1 or children[0].op_type != "Clip":
return
# ReduceL2 must have keepdims == True
keepdims = self.get_node_attribute(reduce_node, "keepdims")
if not keepdims:
return
# ReduceL2 axes must refer only to the last dimension.
# Axes became an input in opset 18. Before then, axes was an attribute
reduce_input_ttype = self.model.get_tensor_type(reduce_node.input[0])
if not reduce_input_ttype:
return
reduce_input_shape = self.tensor_shape_to_list(reduce_input_ttype)
if not reduce_input_shape:
return
axes = self.get_node_attribute(reduce_node, "axes")
if not axes and len(reduce_node.input) > 1:
axes = self.model.get_constant_value(reduce_node.input[1])
if not axes or len(axes) != 1:
return
last_dim = len(reduce_input_shape) - 1
if axes[0] != -1 and axes[0] != last_dim:
return
# Clip node must have a min attribute approximately equal to 1e-12
clip_node = children[0]
clip_min = self.get_node_attribute(clip_node, "min")
if clip_min is None and len(clip_node.input) > 1:
clip_min = self.model.get_constant_value(clip_node.input[1])
clip_max = self.get_node_attribute(clip_node, "max") # TODO: clip_max could be FLOAT_MAX
if clip_max is None and len(clip_node.input) > 2:
clip_max = self.model.get_constant_value(clip_node.input[2])
if not (clip_max is None and clip_min is not None and clip_min > 0 and abs(clip_min - self.epsilon) < 1e-13):
return
if clip_node.output[0] not in input_name_to_nodes:
return
# Clip must have a single Expand child.
children = input_name_to_nodes[clip_node.output[0]]
if len(children) != 1 or children[0].op_type != "Expand":
return
expand_node = children[0]
if expand_node.output[0] not in input_name_to_nodes:
return
# Expand must have a single Div child
children = input_name_to_nodes[expand_node.output[0]]
if len(children) != 1 or children[0].op_type != "Div":
return
div_node = children[0]
# The first input to Div must be the root of the subgraph (i.e., reduce_node.input[0])
# The second input to Div must be the output of the Expand.
# As long as these two inputs go to the same Div node, then ONNX validation will ensure that
# their shapes match.
if div_node.input[0] != reduce_node.input[0]:
return
if div_node.input[1] != expand_node.output[0]:
return
subgraph_input = reduce_node.input[0]
subgraph_output = div_node.output[0]
subgraph_nodes = [reduce_node, clip_node, expand_node, div_node]
if not self.is_safe_to_fuse_nodes(subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node):
return
self.nodes_to_remove.extend(subgraph_nodes)
fused_node = onnx.helper.make_node(
self.fused_op_type, inputs=[subgraph_input], outputs=[subgraph_output], p=2, axis=-1
)
self.nodes_to_add.append(fused_node)

View file

@ -0,0 +1,51 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
from pathlib import Path
import onnx
from ...fusions import FusionGelu, FusionLayerNormalization
from ...onnx_model import ONNXModel
from .fusion_lpnorm import FusionLpNormalization
def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm: bool = False) -> bool:
modified = False
model = onnx.load_model(model_input)
onnx_model = ONNXModel(model)
# Fuse Erf sequence into a single Gelu
fusion_gelu = FusionGelu(onnx_model)
if fusion_gelu.apply():
modified = True
# Fuse ReduceL2 sequence into a single LpNormalization node with p == 2.
fusion_lpnorm = FusionLpNormalization(onnx_model)
if fusion_lpnorm.apply():
modified = True
# Optionally, fuse ReduceMean sequence into a single LayerNormalization node.
if fuse_layernorm:
onnx_opset = next(x for x in model.opset_import if x.domain == "" or x.domain == "ai.onnx")
# Need opset >= 17 to use LayerNormalization.
if onnx_opset.version < 17:
logging.warning(
"Unable to fuse ReduceMean sequence into a LayerNormalization node. "
"ONNX model must use an opset >= 17 in order to use LayerNormalization, "
f"but found version {onnx_opset.version}. Please use onnx.version_converter to update your model."
)
else:
fusion_layernorm = FusionLayerNormalization(onnx_model)
if fusion_layernorm.apply():
modified = True
if modified:
onnx_model.topological_sort()
onnx.save_model(model, model_output)
return modified

View file

@ -0,0 +1,3 @@
from .fusion import Fusion # noqa: F401
from .fusion_gelu import FusionGelu # noqa: F401
from .fusion_layernorm import FusionLayerNormalization # noqa: F401

View file

@ -0,0 +1,298 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations
from collections import deque
import onnx
from ..onnx_model import ONNXModel
class Fusion:
"""
Base class for fusions.
"""
def __init__(self, model: ONNXModel, fused_op_type: str, search_op_type: str):
self.search_op_type: str = search_op_type
self.fused_op_type: str = fused_op_type
self.model: ONNXModel = model
self.nodes_to_remove: list = []
self.nodes_to_add: list = []
def fuse(
self,
node: onnx.NodeProto,
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
output_name_to_node: dict[str, onnx.NodeProto],
):
"""
Interface function for derived fusion classes. Tries to fuse a node sequence containing
the specified node.
"""
raise NotImplementedError
def apply(self) -> bool:
"""
Apply graph fusion on the entire model graph.
"""
input_name_to_nodes = self.model.input_name_to_nodes()
output_name_to_node = self.model.output_name_to_node()
for node in self.model.nodes():
if node.op_type == self.search_op_type:
self.fuse(node, input_name_to_nodes, output_name_to_node)
self.model.remove_nodes(self.nodes_to_remove)
self.model.add_nodes(self.nodes_to_add)
graph_updated = bool(self.nodes_to_remove or self.nodes_to_add)
if graph_updated:
self.model.remove_unused_constant()
return graph_updated
@staticmethod
def is_safe_to_fuse_nodes(
nodes_to_remove: list[onnx.NodeProto],
keep_outputs: list[str],
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
output_name_to_node: dict[str, onnx.NodeProto],
) -> bool:
for node_to_remove in nodes_to_remove:
for output_to_remove in node_to_remove.output:
if output_to_remove in keep_outputs:
continue
if output_to_remove in input_name_to_nodes:
for impacted_node in input_name_to_nodes[output_to_remove]:
if impacted_node not in nodes_to_remove:
# Not safe to remove nodes since output is used by impacted_node
return False
return True
@staticmethod
def get_node_attribute(node: onnx.NodeProto, attribute_name: str):
for attr in node.attribute:
if attr.name == attribute_name:
value = onnx.helper.get_attribute_value(attr)
return value
return None
@staticmethod
def input_index(node_output: str, child_node: onnx.NodeProto) -> int:
index = 0
for input_name in child_node.input:
if input_name == node_output:
return index
index += 1
return -1
@staticmethod
def tensor_shape_to_list(tensor_type) -> list[int]:
shape_list = []
for d in tensor_type.shape.dim:
if d.HasField("dim_value"):
shape_list.append(d.dim_value) # known dimension
elif d.HasField("dim_param"):
shape_list.append(d.dim_param) # unknown dimension with symbolic name
else:
shape_list.append("?") # shall not happen
return shape_list
def get_constant_input(self, node: onnx.NodeProto):
for i, inp in enumerate(node.input):
value = self.model.get_constant_value(inp)
if value is not None:
return i, value
return None, None
def find_constant_input(self, node: onnx.NodeProto, expected_value: float, delta: float = 0.000001) -> int:
i, value = self.get_constant_input(node)
if value is not None and value.size == 1 and abs(value - expected_value) < delta:
return i
return -1
def has_constant_input(self, node: onnx.NodeProto, expected_value: float, delta: float = 0.000001) -> bool:
return self.find_constant_input(node, expected_value, delta) >= 0
def is_constant_with_specified_rank(self, output_name: str, rank: int) -> bool:
value = self.model.get_constant_value(output_name)
if value is None:
return False # Not an initializer
if len(value.shape) != rank:
return False # Wrong dimensions
return True
def match_first_parent(
self,
node: onnx.NodeProto,
parent_op_type: str,
output_name_to_node: dict[str, onnx.NodeProto] | None = None,
exclude: list[onnx.NodeProto] = [], # noqa: B006
) -> tuple[onnx.NodeProto | None, int | None]:
"""
Find parent node based on constraints on op_type.
Args:
node: current node.
parent_op_type (str): constraint of parent node op_type.
output_name_to_node (dict): dictionary with output name as key, and node as value.
exclude (list): list of nodes that are excluded (not allowed to match as parent).
Returns:
parent: The matched parent node. None if not found.
index: The input index of matched parent node. None if not found.
"""
if output_name_to_node is None:
output_name_to_node = self.model.output_name_to_node()
for i, inp in enumerate(node.input):
if inp in output_name_to_node:
parent = output_name_to_node[inp]
if parent.op_type == parent_op_type and parent not in exclude:
return parent, i
return None, None
def match_parent(
self,
node: onnx.NodeProto,
parent_op_type: str,
input_index: int | None = None,
output_name_to_node: dict[str, onnx.NodeProto] | None = None,
exclude: list[onnx.NodeProto] = [], # noqa: B006
return_indice: list[int] | None = None,
) -> onnx.NodeProto | None:
"""
Find parent node based on constraints on op_type and index.
When input_index is None, we will find the first parent node based on constraints,
and return_indice will be appended the corresponding input index.
Args:
node (str): current node name.
parent_op_type (str): constraint of parent node op_type.
input_index (int or None): only check the parent given input index of current node.
output_name_to_node (dict): dictionary with output name as key, and node as value.
exclude (list): list of nodes that are excluded (not allowed to match as parent).
return_indice (list): a list to append the input index when input_index is None.
Returns:
parent: The matched parent node.
"""
assert node is not None
assert input_index is None or input_index >= 0
if output_name_to_node is None:
output_name_to_node = self.model.output_name_to_node()
if input_index is None:
parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node, exclude)
if return_indice is not None:
return_indice.append(index)
return parent
if input_index >= len(node.input):
# Input index out of bounds.
return None
parent = self.model.get_parent(node, input_index, output_name_to_node)
if parent is not None and parent.op_type == parent_op_type and parent not in exclude:
return parent
return None
def match_parent_path(
self,
node: onnx.NodeProto,
parent_op_types: list[str],
parent_input_index: list[int] | None = None,
output_name_to_node: dict[str, onnx.NodeProto] | None = None,
return_indice: list[int] | None = None,
) -> list[onnx.NodeProto] | None:
"""
Find a sequence of input edges based on constraints on parent op_type and index.
When input_index is None, we will find the first parent node based on constraints,
and return_indice will be appended the corresponding input index.
Args:
node (str): current node name.
parent_op_types (str): constraint of parent node op_type of each input edge.
parent_input_index (list): constraint of input index of each input edge. None means no constraint.
output_name_to_node (dict): dictionary with output name as key, and node as value.
return_indice (list): a list to append the input index
When there is no constraint on input index of an edge.
Returns:
parents: a list of matched parent node.
"""
if parent_input_index is not None:
assert len(parent_input_index) == len(parent_op_types)
if output_name_to_node is None:
output_name_to_node = self.model.output_name_to_node()
current_node = node
matched_parents = []
for i, op_type in enumerate(parent_op_types):
matched_parent = self.match_parent(
current_node,
op_type,
parent_input_index[i] if parent_input_index is not None else None,
output_name_to_node,
exclude=[],
return_indice=return_indice,
)
if matched_parent is None:
return None
matched_parents.append(matched_parent)
current_node = matched_parent
return matched_parents
def match_parent_paths(
self,
node: onnx.NodeProto,
paths: list[tuple[list[str], list[int]]],
output_name_to_node: dict[str, onnx.NodeProto],
) -> tuple[int, list[onnx.NodeProto] | None, list[int] | None]:
"""
Find a matching parent path to the given node.
"""
for i, path in enumerate(paths):
return_indice = []
matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice)
if matched:
return i, matched, return_indice
return -1, None, None
def find_first_child_by_type(
self,
node: onnx.NodeProto,
child_type: str,
input_name_to_nodes: dict[str, list[onnx.NodeProto]] | None = None,
recursive: bool = True,
) -> onnx.NodeProto | None:
children = self.model.get_children(node, input_name_to_nodes)
dq = deque(children)
while len(dq) > 0:
current_node = dq.pop()
if current_node.op_type == child_type:
return current_node
if recursive:
children = self.model.get_children(current_node, input_name_to_nodes)
for child in children:
dq.appendleft(child)
return None

View file

@ -0,0 +1,269 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations
import onnx
from ..onnx_model import ONNXModel
from .fusion import Fusion
class FusionGelu(Fusion):
def __init__(self, model: ONNXModel):
super().__init__(model, "Gelu", "Erf")
def fuse(
self,
erf_node: onnx.NodeProto,
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
output_name_to_node: dict[str, onnx.NodeProto],
):
"""
Interface function that tries to fuse a node sequence containing an Erf node into a single
Gelu node.
"""
if (
self.fuse_1(erf_node, input_name_to_nodes, output_name_to_node)
or self.fuse_2(erf_node, input_name_to_nodes, output_name_to_node)
or self.fuse_3(erf_node, input_name_to_nodes, output_name_to_node)
):
self.model.set_opset_import("com.microsoft", 1)
def fuse_1(
self,
erf_node: onnx.NodeProto,
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
output_name_to_node: dict[str, onnx.NodeProto],
) -> bool:
"""
This pattern is from PyTorch model
Fuse Gelu with Erf into one node:
Pattern 1:
+-------Mul(0.5)---------------------+
| |
| v
[root] --> Div -----> Erf --> Add --> Mul -->
(B=1.4142...) (1)
Pattern 2:
+------------------------------------+
| |
| v
[root] --> Div -----> Erf --> Add --> Mul -->Mul -->
(B=1.4142...) (1) (0.5)
Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
"""
if erf_node.output[0] not in input_name_to_nodes:
return False
children = input_name_to_nodes[erf_node.output[0]]
if len(children) != 1 or children[0].op_type != "Add":
return False
add_after_erf = children[0]
if not self.has_constant_input(add_after_erf, 1):
return False
if add_after_erf.output[0] not in input_name_to_nodes:
return False
children = input_name_to_nodes[add_after_erf.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return False
mul_after_erf = children[0]
div = self.match_parent(erf_node, "Div", 0, output_name_to_node)
if div is None:
return False
if self.find_constant_input(div, 1.4142, delta=0.001) != 1:
return False
subgraph_input = div.input[0]
another = 1 if mul_after_erf.input[0] == add_after_erf.output[0] else 0
if subgraph_input == mul_after_erf.input[another]: # pattern 2
children = input_name_to_nodes[mul_after_erf.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return False
mul_half = children[0]
if not self.has_constant_input(mul_half, 0.5):
return False
subgraph_output = mul_half.output[0]
else: # pattern 1
mul_half = self.match_parent(mul_after_erf, "Mul", another, output_name_to_node)
if mul_half is None:
return False
if not self.has_constant_input(mul_half, 0.5):
return False
if subgraph_input not in mul_half.input:
return False
subgraph_output = mul_after_erf.output[0]
subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul_half]
if not self.is_safe_to_fuse_nodes(subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node):
return False
self.nodes_to_remove.extend(subgraph_nodes)
fused_node = onnx.helper.make_node("Gelu", inputs=[subgraph_input], outputs=[subgraph_output])
fused_node.domain = "com.microsoft"
self.nodes_to_add.append(fused_node)
return True
def fuse_2(
self,
erf_node: onnx.NodeProto,
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
output_name_to_node: dict[str, onnx.NodeProto],
) -> bool:
"""
This pattern is from Keras model
Fuse Gelu with Erf into one node:
+------------------------------------------+
| |
| v
[root] --> Div -----> Erf --> Add --> Mul -->Mul
(B=1.4142...) (A=1) (A=0.5)
Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
"""
if erf_node.output[0] not in input_name_to_nodes:
return False
children = input_name_to_nodes[erf_node.output[0]]
if len(children) != 1 or children[0].op_type != "Add":
return False
add_after_erf = children[0]
if not self.has_constant_input(add_after_erf, 1):
return False
if add_after_erf.output[0] not in input_name_to_nodes:
return False
children = input_name_to_nodes[add_after_erf.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return False
mul_after_erf = children[0]
if not self.has_constant_input(mul_after_erf, 0.5):
return False
if mul_after_erf.output[0] not in input_name_to_nodes:
return False
children = input_name_to_nodes[mul_after_erf.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return False
mul = children[0]
div = self.match_parent(erf_node, "Div", 0, output_name_to_node)
if div is None:
return False
sqrt_node = None
if self.find_constant_input(div, 1.4142, delta=0.001) != 1:
sqrt_node = self.match_parent(div, "Sqrt", 1, output_name_to_node)
if sqrt_node is None:
return False
if not self.has_constant_input(sqrt_node, 2.0):
return False
root_node = self.model.get_parent(div, 0, output_name_to_node)
if root_node is None:
return False
if root_node.output[0] not in mul.input:
return False
subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul]
if sqrt_node:
subgraph_nodes.append(sqrt_node)
if not self.is_safe_to_fuse_nodes(subgraph_nodes, [mul.output[0]], input_name_to_nodes, output_name_to_node):
return False
self.nodes_to_remove.extend(subgraph_nodes)
fused_node = onnx.helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[mul.output[0]])
fused_node.domain = "com.microsoft"
self.nodes_to_add.append(fused_node)
return True
def fuse_3(
self,
erf_node: onnx.NodeProto,
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
output_name_to_node: dict[str, onnx.NodeProto],
) -> bool:
"""
This pattern is from TensorFlow model
Fuse Gelu with Erf into one node:
+----------------------------------------------+
| |
| v
[root] --> Mul -----> Erf --> Add --> Mul -->Mul
(A=0.7071067690849304) (B=1) (B=0.5)
Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
"""
if erf_node.output[0] not in input_name_to_nodes:
return False
children = input_name_to_nodes[erf_node.output[0]]
if len(children) != 1 or children[0].op_type != "Add":
return False
add_after_erf = children[0]
if not self.has_constant_input(add_after_erf, 1):
return False
if add_after_erf.output[0] not in input_name_to_nodes:
return False
children = input_name_to_nodes[add_after_erf.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return False
mul_half = children[0]
if not self.has_constant_input(mul_half, 0.5):
return False
first_mul = self.match_parent(erf_node, "Mul", 0, output_name_to_node)
if first_mul is None:
return False
i = self.find_constant_input(first_mul, 0.7071067690849304, delta=0.001)
if i < 0:
return False
root_node = self.model.get_parent(first_mul, 0 if i == 1 else 1, output_name_to_node)
if root_node is None:
return False
if mul_half.output[0] not in input_name_to_nodes:
return False
children = input_name_to_nodes[mul_half.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return False
last_mul = children[0]
if not (last_mul.input[0] == root_node.output[0] or last_mul.input[1] == root_node.output[0]):
return False
subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul]
if not self.is_safe_to_fuse_nodes(
subgraph_nodes,
[last_mul.output[0]],
input_name_to_nodes,
output_name_to_node,
):
return False
self.nodes_to_remove.extend(subgraph_nodes)
fused_node = onnx.helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[last_mul.output[0]])
fused_node.domain = "com.microsoft"
self.nodes_to_add.append(fused_node)
return True

View file

@ -0,0 +1,134 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations
import onnx
from ..onnx_model import ONNXModel
from .fusion import Fusion
class FusionLayerNormalization(Fusion):
def __init__(self, model: ONNXModel):
super().__init__(model, "LayerNormalization", "ReduceMean")
def fuse(
self,
reduce_mean_node: onnx.NodeProto,
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
output_name_to_node: dict[str, onnx.NodeProto],
):
"""
Interface function that tries to fuse a node sequence containing a ReduceMean node into a single
LayerNormalization node.
+----------------------+
| |
| v
[Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
(axis=2 or -1) | (Y=2) (axis=2 or -1) (E-6 or E-12 or 0) ^
| |
+-------------------------------------------------+
It also handles cases of duplicated sub nodes exported from older version of PyTorch:
+----------------------+
| v
| +-------> Sub-----------------------------------------------+
| | |
| | v
[Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
| ^
| |
+----------------------+
"""
children = self.model.get_children(reduce_mean_node, input_name_to_nodes)
if len(children) == 0 or len(children) > 2:
return
root_input = reduce_mean_node.input[0]
if children[0].op_type != "Sub" or children[0].input[0] != root_input:
return
if len(children) == 2:
if children[1].op_type != "Sub" or children[1].input[0] != root_input:
return
div_node = None
for child in children:
div_node = self.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)
if div_node is not None:
break
if div_node is None:
return
path_id, parent_nodes, _ = self.match_parent_paths(
div_node,
[
(["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]),
(
["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"],
[1, 0, 0, 0, 0, 0],
),
],
output_name_to_node,
)
if path_id < 0:
return
sub_node = parent_nodes[-1]
if sub_node not in children:
return
second_add_node = parent_nodes[1]
i, add_weight = self.get_constant_input(second_add_node)
if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4:
# Skip fusion since epsilon value is not expected.
return
pow_node = parent_nodes[3]
if self.find_constant_input(pow_node, 2.0) != 1:
return
mul_node = input_name_to_nodes[div_node.output[0]][0]
if mul_node.op_type != "Mul":
return
last_add_node = input_name_to_nodes[mul_node.output[0]][0]
if last_add_node.op_type != "Add":
return
subgraph_nodes = [reduce_mean_node]
subgraph_nodes.extend(children)
subgraph_nodes.extend(parent_nodes[:-1])
subgraph_nodes.extend([last_add_node, mul_node, div_node])
if not self.is_safe_to_fuse_nodes(
subgraph_nodes,
last_add_node.output,
input_name_to_nodes,
output_name_to_node,
):
return
weight_input = mul_node.input[1 - self.input_index(div_node.output[0], mul_node)]
if not self.is_constant_with_specified_rank(weight_input, 1):
return
bias_input = last_add_node.input[1 - self.input_index(mul_node.output[0], last_add_node)]
if not self.is_constant_with_specified_rank(bias_input, 1):
return
self.nodes_to_remove.extend(subgraph_nodes)
normalize_node = onnx.helper.make_node(
"LayerNormalization",
inputs=[reduce_mean_node.input[0], weight_input, bias_input],
outputs=[last_add_node.output[0]],
)
normalize_node.attribute.extend([onnx.helper.make_attribute("epsilon", float(add_weight))])
self.nodes_to_add.append(normalize_node)

View file

@ -1,3 +1,7 @@
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from pathlib import Path
import onnx
@ -114,6 +118,14 @@ class ONNXModel:
def opset_import(self):
return self.model.opset_import
def set_opset_import(self, domain, version):
for opset in self.model.opset_import:
if opset.domain == domain:
opset.version = version
return
self.model.opset_import.extend([onnx_helper.make_opsetid(domain, version)])
def remove_node(self, node):
if node in self.model.graph.node:
self.model.graph.node.remove(node)
@ -140,6 +152,49 @@ class ONNXModel:
return tensor
return None
def find_graph_input(self, input_name):
for input in self.model.graph.input:
if input.name == input_name:
return input
return None
def find_graph_output(self, output_name):
for output in self.model.graph.output:
if output.name == output_name:
return output
return None
def get_tensor_type(self, tensor_name: str):
tensor_type_map = {obj.name: obj.type for obj in self.model.graph.value_info}
if tensor_name in tensor_type_map:
return tensor_type_map[tensor_name].tensor_type
g_input = self.find_graph_input(tensor_name)
if g_input:
return g_input.type.tensor_type
g_output = self.find_graph_output(tensor_name)
if g_output:
return g_output.type.tensor_type
return None
def get_constant_value(self, output_name):
for node in self.model.graph.node:
if node.op_type == "Constant":
if node.output[0] == output_name:
for attr in node.attribute:
if attr.name == "value":
return onnx_numpy_helper.to_array(attr.t)
# Fallback to initializer since constant folding may have been applied.
initializer = self.get_initializer(output_name)
if initializer is not None:
return onnx_numpy_helper.to_array(initializer)
return None
def get_initializer_name_set(self):
return {initializer.name for initializer in self.model.graph.initializer}
@ -167,17 +222,19 @@ class ONNXModel:
input_name_to_nodes = {}
for node in self.model.graph.node:
for input_name in node.input:
if input_name not in input_name_to_nodes:
input_name_to_nodes[input_name] = [node]
else:
input_name_to_nodes[input_name].append(node)
if input_name: # Could be empty when it is optional
if input_name not in input_name_to_nodes:
input_name_to_nodes[input_name] = [node]
else:
input_name_to_nodes[input_name].append(node)
return input_name_to_nodes
def output_name_to_node(self):
output_name_to_node = {}
for node in self.model.graph.node:
for output_name in node.output:
output_name_to_node[output_name] = node
if output_name: # Could be empty when it is optional
output_name_to_node[output_name] = node
return output_name_to_node
def get_children(self, node, input_name_to_nodes=None):

View file

@ -408,6 +408,7 @@ packages = [
"onnxruntime.quantization",
"onnxruntime.quantization.operators",
"onnxruntime.quantization.CalTableFlatBuffers",
"onnxruntime.quantization.fusions",
"onnxruntime.quantization.execution_providers.qnn",
"onnxruntime.transformers",
"onnxruntime.transformers.models.bart",