From 81796a30810ca9038474260742e542fffa11fc71 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Tue, 12 Dec 2023 08:43:04 -0800 Subject: [PATCH] [QNN EP Quantization] Add fusion preprocessing to QNN quantization (#18719) ### Description - Adds graph fusions to preprocessing step that can be called before creating a QDQ model for QNN EP. - Fuse Erf sequence to Gelu (adapted from [optimizer.py](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/fusion_gelu.py)). Required by QNN EP. - Fuse ReduceMean sequence to LayerNormaliation (adapted from [optimizer.py](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/fusion_layernorm.py)). Not required by QNN EP. - Fuse ReduceL2 sequence to LpNormalization (new, specific to QNN EP). Required by QNN EP. Example use: ```python3 from quantization.execution_providers.qnn import get_qnn_qdq_config, qnn_preprocess_model # Added by this PR: model_updated = qnn_preprocess_model("model.fp32.onnx", "model.fp32.preprocessed.onnx", fuse_layernorm=True) model_to_quantize = "model.fp32.preprocessed.onnx" if model_updated else "model.fp32.onnx" # Quantize model ... qnn_config = get_qnn_qdq_config(model_to_quantize, data_reader, activation_type=QuantType.QUInt16) quantize(model_to_quantize, "model.qdq.onnx", qnn_config) ``` ### Motivation and Context Allow more models to be quantized for use with QNN EP --------- Signed-off-by: adrianlizarraga --- cmake/onnxruntime_python.cmake | 7 + .../execution_providers/qnn/__init__.py | 1 + .../execution_providers/qnn/fusion_lpnorm.py | 127 ++++++++ .../execution_providers/qnn/preprocess.py | 51 +++ .../tools/quantization/fusions/__init__.py | 3 + .../tools/quantization/fusions/fusion.py | 298 ++++++++++++++++++ .../tools/quantization/fusions/fusion_gelu.py | 269 ++++++++++++++++ .../quantization/fusions/fusion_layernorm.py | 134 ++++++++ .../python/tools/quantization/onnx_model.py | 67 +++- setup.py | 1 + 10 files changed, 953 insertions(+), 5 deletions(-) create mode 100644 onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py create mode 100644 onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py create mode 100644 onnxruntime/python/tools/quantization/fusions/__init__.py create mode 100644 onnxruntime/python/tools/quantization/fusions/fusion.py create mode 100644 onnxruntime/python/tools/quantization/fusions/fusion_gelu.py create mode 100644 onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index b93ccf77d5..6192296158 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -453,6 +453,9 @@ file(GLOB onnxruntime_python_quantization_operators_src CONFIGURE_DEPENDS file(GLOB onnxruntime_python_quantization_cal_table_flatbuffers_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/quantization/CalTableFlatBuffers/*.py" ) +file(GLOB onnxruntime_python_quantization_fusions_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/quantization/fusions/*.py" +) file(GLOB onnxruntime_python_quantization_ep_qnn_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/quantization/execution_providers/qnn/*.py" ) @@ -550,6 +553,7 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/operators COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/CalTableFlatBuffers + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/fusions COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/execution_providers COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/execution_providers/qnn COMMAND ${CMAKE_COMMAND} -E make_directory $/quantization @@ -622,6 +626,9 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_quantization_cal_table_flatbuffers_src} $/onnxruntime/quantization/CalTableFlatBuffers/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_quantization_fusions_src} + $/onnxruntime/quantization/fusions/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_quantization_ep_qnn_src} $/onnxruntime/quantization/execution_providers/qnn/ diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/__init__.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/__init__.py index c5f0b27f75..61a264c275 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/__init__.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/__init__.py @@ -1 +1,2 @@ +from .preprocess import qnn_preprocess_model # noqa: F401 from .quant_config import get_qnn_qdq_config # noqa: F401 diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py new file mode 100644 index 0000000000..9ebf400498 --- /dev/null +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py @@ -0,0 +1,127 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import annotations + +import onnx + +from ...fusions import Fusion +from ...onnx_model import ONNXModel + + +class FusionLpNormalization(Fusion): + def __init__(self, model: ONNXModel, epsilon: float = 1e-12): + super().__init__(model, "LpNormalization", "ReduceL2") + self.epsilon = epsilon + + def fuse( + self, + reduce_node: onnx.NodeProto, + input_name_to_nodes: dict[str, list[onnx.NodeProto]], + output_name_to_node: dict[str, onnx.NodeProto], + ): + """ + Interface function that tries to fuse a node sequence containing a ReduceL2 node into a single + LpNormalization node. + + Pattern 1: + [root] --> ReduceL2 -----> Clip --> Expand ----> Div --> + | (axis=-1) (min=epsilon) (shape=root) ^ + | (keepdims=True) | + | | + +-----------------------------------------------+ + Notes: + - ReduceL2 must use the last axis, and keepdims == True + - Clip must only have a min attribute that is ~1e-12 + - Expand must restore the shape to root.shape + - The output of Expand must be the second input to Div. + """ + if reduce_node.output[0] not in input_name_to_nodes: + return + + # ReduceL2 must have one Clip child + children = input_name_to_nodes[reduce_node.output[0]] + if len(children) != 1 or children[0].op_type != "Clip": + return + + # ReduceL2 must have keepdims == True + keepdims = self.get_node_attribute(reduce_node, "keepdims") + if not keepdims: + return + + # ReduceL2 axes must refer only to the last dimension. + # Axes became an input in opset 18. Before then, axes was an attribute + reduce_input_ttype = self.model.get_tensor_type(reduce_node.input[0]) + if not reduce_input_ttype: + return + + reduce_input_shape = self.tensor_shape_to_list(reduce_input_ttype) + if not reduce_input_shape: + return + + axes = self.get_node_attribute(reduce_node, "axes") + if not axes and len(reduce_node.input) > 1: + axes = self.model.get_constant_value(reduce_node.input[1]) + + if not axes or len(axes) != 1: + return + + last_dim = len(reduce_input_shape) - 1 + if axes[0] != -1 and axes[0] != last_dim: + return + + # Clip node must have a min attribute approximately equal to 1e-12 + clip_node = children[0] + clip_min = self.get_node_attribute(clip_node, "min") + if clip_min is None and len(clip_node.input) > 1: + clip_min = self.model.get_constant_value(clip_node.input[1]) + + clip_max = self.get_node_attribute(clip_node, "max") # TODO: clip_max could be FLOAT_MAX + if clip_max is None and len(clip_node.input) > 2: + clip_max = self.model.get_constant_value(clip_node.input[2]) + + if not (clip_max is None and clip_min is not None and clip_min > 0 and abs(clip_min - self.epsilon) < 1e-13): + return + + if clip_node.output[0] not in input_name_to_nodes: + return + + # Clip must have a single Expand child. + children = input_name_to_nodes[clip_node.output[0]] + if len(children) != 1 or children[0].op_type != "Expand": + return + + expand_node = children[0] + if expand_node.output[0] not in input_name_to_nodes: + return + + # Expand must have a single Div child + children = input_name_to_nodes[expand_node.output[0]] + if len(children) != 1 or children[0].op_type != "Div": + return + + div_node = children[0] + + # The first input to Div must be the root of the subgraph (i.e., reduce_node.input[0]) + # The second input to Div must be the output of the Expand. + # As long as these two inputs go to the same Div node, then ONNX validation will ensure that + # their shapes match. + if div_node.input[0] != reduce_node.input[0]: + return + if div_node.input[1] != expand_node.output[0]: + return + + subgraph_input = reduce_node.input[0] + subgraph_output = div_node.output[0] + + subgraph_nodes = [reduce_node, clip_node, expand_node, div_node] + if not self.is_safe_to_fuse_nodes(subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node): + return + + self.nodes_to_remove.extend(subgraph_nodes) + fused_node = onnx.helper.make_node( + self.fused_op_type, inputs=[subgraph_input], outputs=[subgraph_output], p=2, axis=-1 + ) + self.nodes_to_add.append(fused_node) diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py new file mode 100644 index 0000000000..becbaceab1 --- /dev/null +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py @@ -0,0 +1,51 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import logging +from pathlib import Path + +import onnx + +from ...fusions import FusionGelu, FusionLayerNormalization +from ...onnx_model import ONNXModel +from .fusion_lpnorm import FusionLpNormalization + + +def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm: bool = False) -> bool: + modified = False + model = onnx.load_model(model_input) + onnx_model = ONNXModel(model) + + # Fuse Erf sequence into a single Gelu + fusion_gelu = FusionGelu(onnx_model) + if fusion_gelu.apply(): + modified = True + + # Fuse ReduceL2 sequence into a single LpNormalization node with p == 2. + fusion_lpnorm = FusionLpNormalization(onnx_model) + if fusion_lpnorm.apply(): + modified = True + + # Optionally, fuse ReduceMean sequence into a single LayerNormalization node. + if fuse_layernorm: + onnx_opset = next(x for x in model.opset_import if x.domain == "" or x.domain == "ai.onnx") + + # Need opset >= 17 to use LayerNormalization. + if onnx_opset.version < 17: + logging.warning( + "Unable to fuse ReduceMean sequence into a LayerNormalization node. " + "ONNX model must use an opset >= 17 in order to use LayerNormalization, " + f"but found version {onnx_opset.version}. Please use onnx.version_converter to update your model." + ) + else: + fusion_layernorm = FusionLayerNormalization(onnx_model) + if fusion_layernorm.apply(): + modified = True + + if modified: + onnx_model.topological_sort() + onnx.save_model(model, model_output) + + return modified diff --git a/onnxruntime/python/tools/quantization/fusions/__init__.py b/onnxruntime/python/tools/quantization/fusions/__init__.py new file mode 100644 index 0000000000..f1576240a2 --- /dev/null +++ b/onnxruntime/python/tools/quantization/fusions/__init__.py @@ -0,0 +1,3 @@ +from .fusion import Fusion # noqa: F401 +from .fusion_gelu import FusionGelu # noqa: F401 +from .fusion_layernorm import FusionLayerNormalization # noqa: F401 diff --git a/onnxruntime/python/tools/quantization/fusions/fusion.py b/onnxruntime/python/tools/quantization/fusions/fusion.py new file mode 100644 index 0000000000..456a75eec2 --- /dev/null +++ b/onnxruntime/python/tools/quantization/fusions/fusion.py @@ -0,0 +1,298 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import annotations + +from collections import deque + +import onnx + +from ..onnx_model import ONNXModel + + +class Fusion: + """ + Base class for fusions. + """ + + def __init__(self, model: ONNXModel, fused_op_type: str, search_op_type: str): + self.search_op_type: str = search_op_type + self.fused_op_type: str = fused_op_type + self.model: ONNXModel = model + self.nodes_to_remove: list = [] + self.nodes_to_add: list = [] + + def fuse( + self, + node: onnx.NodeProto, + input_name_to_nodes: dict[str, list[onnx.NodeProto]], + output_name_to_node: dict[str, onnx.NodeProto], + ): + """ + Interface function for derived fusion classes. Tries to fuse a node sequence containing + the specified node. + """ + raise NotImplementedError + + def apply(self) -> bool: + """ + Apply graph fusion on the entire model graph. + """ + input_name_to_nodes = self.model.input_name_to_nodes() + output_name_to_node = self.model.output_name_to_node() + + for node in self.model.nodes(): + if node.op_type == self.search_op_type: + self.fuse(node, input_name_to_nodes, output_name_to_node) + + self.model.remove_nodes(self.nodes_to_remove) + self.model.add_nodes(self.nodes_to_add) + + graph_updated = bool(self.nodes_to_remove or self.nodes_to_add) + + if graph_updated: + self.model.remove_unused_constant() + + return graph_updated + + @staticmethod + def is_safe_to_fuse_nodes( + nodes_to_remove: list[onnx.NodeProto], + keep_outputs: list[str], + input_name_to_nodes: dict[str, list[onnx.NodeProto]], + output_name_to_node: dict[str, onnx.NodeProto], + ) -> bool: + for node_to_remove in nodes_to_remove: + for output_to_remove in node_to_remove.output: + if output_to_remove in keep_outputs: + continue + + if output_to_remove in input_name_to_nodes: + for impacted_node in input_name_to_nodes[output_to_remove]: + if impacted_node not in nodes_to_remove: + # Not safe to remove nodes since output is used by impacted_node + return False + return True + + @staticmethod + def get_node_attribute(node: onnx.NodeProto, attribute_name: str): + for attr in node.attribute: + if attr.name == attribute_name: + value = onnx.helper.get_attribute_value(attr) + return value + return None + + @staticmethod + def input_index(node_output: str, child_node: onnx.NodeProto) -> int: + index = 0 + for input_name in child_node.input: + if input_name == node_output: + return index + index += 1 + return -1 + + @staticmethod + def tensor_shape_to_list(tensor_type) -> list[int]: + shape_list = [] + for d in tensor_type.shape.dim: + if d.HasField("dim_value"): + shape_list.append(d.dim_value) # known dimension + elif d.HasField("dim_param"): + shape_list.append(d.dim_param) # unknown dimension with symbolic name + else: + shape_list.append("?") # shall not happen + return shape_list + + def get_constant_input(self, node: onnx.NodeProto): + for i, inp in enumerate(node.input): + value = self.model.get_constant_value(inp) + if value is not None: + return i, value + + return None, None + + def find_constant_input(self, node: onnx.NodeProto, expected_value: float, delta: float = 0.000001) -> int: + i, value = self.get_constant_input(node) + if value is not None and value.size == 1 and abs(value - expected_value) < delta: + return i + + return -1 + + def has_constant_input(self, node: onnx.NodeProto, expected_value: float, delta: float = 0.000001) -> bool: + return self.find_constant_input(node, expected_value, delta) >= 0 + + def is_constant_with_specified_rank(self, output_name: str, rank: int) -> bool: + value = self.model.get_constant_value(output_name) + if value is None: + return False # Not an initializer + + if len(value.shape) != rank: + return False # Wrong dimensions + + return True + + def match_first_parent( + self, + node: onnx.NodeProto, + parent_op_type: str, + output_name_to_node: dict[str, onnx.NodeProto] | None = None, + exclude: list[onnx.NodeProto] = [], # noqa: B006 + ) -> tuple[onnx.NodeProto | None, int | None]: + """ + Find parent node based on constraints on op_type. + + Args: + node: current node. + parent_op_type (str): constraint of parent node op_type. + output_name_to_node (dict): dictionary with output name as key, and node as value. + exclude (list): list of nodes that are excluded (not allowed to match as parent). + + Returns: + parent: The matched parent node. None if not found. + index: The input index of matched parent node. None if not found. + """ + if output_name_to_node is None: + output_name_to_node = self.model.output_name_to_node() + + for i, inp in enumerate(node.input): + if inp in output_name_to_node: + parent = output_name_to_node[inp] + if parent.op_type == parent_op_type and parent not in exclude: + return parent, i + + return None, None + + def match_parent( + self, + node: onnx.NodeProto, + parent_op_type: str, + input_index: int | None = None, + output_name_to_node: dict[str, onnx.NodeProto] | None = None, + exclude: list[onnx.NodeProto] = [], # noqa: B006 + return_indice: list[int] | None = None, + ) -> onnx.NodeProto | None: + """ + Find parent node based on constraints on op_type and index. + When input_index is None, we will find the first parent node based on constraints, + and return_indice will be appended the corresponding input index. + + Args: + node (str): current node name. + parent_op_type (str): constraint of parent node op_type. + input_index (int or None): only check the parent given input index of current node. + output_name_to_node (dict): dictionary with output name as key, and node as value. + exclude (list): list of nodes that are excluded (not allowed to match as parent). + return_indice (list): a list to append the input index when input_index is None. + + Returns: + parent: The matched parent node. + """ + assert node is not None + assert input_index is None or input_index >= 0 + + if output_name_to_node is None: + output_name_to_node = self.model.output_name_to_node() + + if input_index is None: + parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node, exclude) + if return_indice is not None: + return_indice.append(index) + return parent + + if input_index >= len(node.input): + # Input index out of bounds. + return None + + parent = self.model.get_parent(node, input_index, output_name_to_node) + if parent is not None and parent.op_type == parent_op_type and parent not in exclude: + return parent + + return None + + def match_parent_path( + self, + node: onnx.NodeProto, + parent_op_types: list[str], + parent_input_index: list[int] | None = None, + output_name_to_node: dict[str, onnx.NodeProto] | None = None, + return_indice: list[int] | None = None, + ) -> list[onnx.NodeProto] | None: + """ + Find a sequence of input edges based on constraints on parent op_type and index. + When input_index is None, we will find the first parent node based on constraints, + and return_indice will be appended the corresponding input index. + + Args: + node (str): current node name. + parent_op_types (str): constraint of parent node op_type of each input edge. + parent_input_index (list): constraint of input index of each input edge. None means no constraint. + output_name_to_node (dict): dictionary with output name as key, and node as value. + return_indice (list): a list to append the input index + When there is no constraint on input index of an edge. + + Returns: + parents: a list of matched parent node. + """ + if parent_input_index is not None: + assert len(parent_input_index) == len(parent_op_types) + + if output_name_to_node is None: + output_name_to_node = self.model.output_name_to_node() + + current_node = node + matched_parents = [] + for i, op_type in enumerate(parent_op_types): + matched_parent = self.match_parent( + current_node, + op_type, + parent_input_index[i] if parent_input_index is not None else None, + output_name_to_node, + exclude=[], + return_indice=return_indice, + ) + if matched_parent is None: + return None + + matched_parents.append(matched_parent) + current_node = matched_parent + + return matched_parents + + def match_parent_paths( + self, + node: onnx.NodeProto, + paths: list[tuple[list[str], list[int]]], + output_name_to_node: dict[str, onnx.NodeProto], + ) -> tuple[int, list[onnx.NodeProto] | None, list[int] | None]: + """ + Find a matching parent path to the given node. + """ + for i, path in enumerate(paths): + return_indice = [] + matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice) + if matched: + return i, matched, return_indice + return -1, None, None + + def find_first_child_by_type( + self, + node: onnx.NodeProto, + child_type: str, + input_name_to_nodes: dict[str, list[onnx.NodeProto]] | None = None, + recursive: bool = True, + ) -> onnx.NodeProto | None: + children = self.model.get_children(node, input_name_to_nodes) + dq = deque(children) + while len(dq) > 0: + current_node = dq.pop() + if current_node.op_type == child_type: + return current_node + + if recursive: + children = self.model.get_children(current_node, input_name_to_nodes) + for child in children: + dq.appendleft(child) + + return None diff --git a/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py b/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py new file mode 100644 index 0000000000..a20d6dbffd --- /dev/null +++ b/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py @@ -0,0 +1,269 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import annotations + +import onnx + +from ..onnx_model import ONNXModel +from .fusion import Fusion + + +class FusionGelu(Fusion): + def __init__(self, model: ONNXModel): + super().__init__(model, "Gelu", "Erf") + + def fuse( + self, + erf_node: onnx.NodeProto, + input_name_to_nodes: dict[str, list[onnx.NodeProto]], + output_name_to_node: dict[str, onnx.NodeProto], + ): + """ + Interface function that tries to fuse a node sequence containing an Erf node into a single + Gelu node. + """ + if ( + self.fuse_1(erf_node, input_name_to_nodes, output_name_to_node) + or self.fuse_2(erf_node, input_name_to_nodes, output_name_to_node) + or self.fuse_3(erf_node, input_name_to_nodes, output_name_to_node) + ): + self.model.set_opset_import("com.microsoft", 1) + + def fuse_1( + self, + erf_node: onnx.NodeProto, + input_name_to_nodes: dict[str, list[onnx.NodeProto]], + output_name_to_node: dict[str, onnx.NodeProto], + ) -> bool: + """ + This pattern is from PyTorch model + Fuse Gelu with Erf into one node: + Pattern 1: + +-------Mul(0.5)---------------------+ + | | + | v + [root] --> Div -----> Erf --> Add --> Mul --> + (B=1.4142...) (1) + + Pattern 2: + +------------------------------------+ + | | + | v + [root] --> Div -----> Erf --> Add --> Mul -->Mul --> + (B=1.4142...) (1) (0.5) + + Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine. + """ + if erf_node.output[0] not in input_name_to_nodes: + return False + children = input_name_to_nodes[erf_node.output[0]] + if len(children) != 1 or children[0].op_type != "Add": + return False + add_after_erf = children[0] + + if not self.has_constant_input(add_after_erf, 1): + return False + + if add_after_erf.output[0] not in input_name_to_nodes: + return False + + children = input_name_to_nodes[add_after_erf.output[0]] + if len(children) != 1 or children[0].op_type != "Mul": + return False + + mul_after_erf = children[0] + + div = self.match_parent(erf_node, "Div", 0, output_name_to_node) + if div is None: + return False + + if self.find_constant_input(div, 1.4142, delta=0.001) != 1: + return False + + subgraph_input = div.input[0] + + another = 1 if mul_after_erf.input[0] == add_after_erf.output[0] else 0 + if subgraph_input == mul_after_erf.input[another]: # pattern 2 + children = input_name_to_nodes[mul_after_erf.output[0]] + if len(children) != 1 or children[0].op_type != "Mul": + return False + mul_half = children[0] + if not self.has_constant_input(mul_half, 0.5): + return False + subgraph_output = mul_half.output[0] + else: # pattern 1 + mul_half = self.match_parent(mul_after_erf, "Mul", another, output_name_to_node) + if mul_half is None: + return False + + if not self.has_constant_input(mul_half, 0.5): + return False + + if subgraph_input not in mul_half.input: + return False + + subgraph_output = mul_after_erf.output[0] + + subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul_half] + if not self.is_safe_to_fuse_nodes(subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node): + return False + + self.nodes_to_remove.extend(subgraph_nodes) + fused_node = onnx.helper.make_node("Gelu", inputs=[subgraph_input], outputs=[subgraph_output]) + fused_node.domain = "com.microsoft" + self.nodes_to_add.append(fused_node) + return True + + def fuse_2( + self, + erf_node: onnx.NodeProto, + input_name_to_nodes: dict[str, list[onnx.NodeProto]], + output_name_to_node: dict[str, onnx.NodeProto], + ) -> bool: + """ + This pattern is from Keras model + Fuse Gelu with Erf into one node: + +------------------------------------------+ + | | + | v + [root] --> Div -----> Erf --> Add --> Mul -->Mul + (B=1.4142...) (A=1) (A=0.5) + + Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine. + """ + if erf_node.output[0] not in input_name_to_nodes: + return False + children = input_name_to_nodes[erf_node.output[0]] + if len(children) != 1 or children[0].op_type != "Add": + return False + add_after_erf = children[0] + + if not self.has_constant_input(add_after_erf, 1): + return False + + if add_after_erf.output[0] not in input_name_to_nodes: + return False + children = input_name_to_nodes[add_after_erf.output[0]] + if len(children) != 1 or children[0].op_type != "Mul": + return False + mul_after_erf = children[0] + + if not self.has_constant_input(mul_after_erf, 0.5): + return False + + if mul_after_erf.output[0] not in input_name_to_nodes: + return False + children = input_name_to_nodes[mul_after_erf.output[0]] + if len(children) != 1 or children[0].op_type != "Mul": + return False + mul = children[0] + + div = self.match_parent(erf_node, "Div", 0, output_name_to_node) + if div is None: + return False + + sqrt_node = None + if self.find_constant_input(div, 1.4142, delta=0.001) != 1: + sqrt_node = self.match_parent(div, "Sqrt", 1, output_name_to_node) + if sqrt_node is None: + return False + if not self.has_constant_input(sqrt_node, 2.0): + return False + + root_node = self.model.get_parent(div, 0, output_name_to_node) + if root_node is None: + return False + + if root_node.output[0] not in mul.input: + return False + + subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul] + if sqrt_node: + subgraph_nodes.append(sqrt_node) + + if not self.is_safe_to_fuse_nodes(subgraph_nodes, [mul.output[0]], input_name_to_nodes, output_name_to_node): + return False + + self.nodes_to_remove.extend(subgraph_nodes) + fused_node = onnx.helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[mul.output[0]]) + fused_node.domain = "com.microsoft" + self.nodes_to_add.append(fused_node) + return True + + def fuse_3( + self, + erf_node: onnx.NodeProto, + input_name_to_nodes: dict[str, list[onnx.NodeProto]], + output_name_to_node: dict[str, onnx.NodeProto], + ) -> bool: + """ + This pattern is from TensorFlow model + Fuse Gelu with Erf into one node: + +----------------------------------------------+ + | | + | v + [root] --> Mul -----> Erf --> Add --> Mul -->Mul + (A=0.7071067690849304) (B=1) (B=0.5) + + Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine. + """ + + if erf_node.output[0] not in input_name_to_nodes: + return False + children = input_name_to_nodes[erf_node.output[0]] + if len(children) != 1 or children[0].op_type != "Add": + return False + add_after_erf = children[0] + + if not self.has_constant_input(add_after_erf, 1): + return False + + if add_after_erf.output[0] not in input_name_to_nodes: + return False + children = input_name_to_nodes[add_after_erf.output[0]] + if len(children) != 1 or children[0].op_type != "Mul": + return False + mul_half = children[0] + + if not self.has_constant_input(mul_half, 0.5): + return False + + first_mul = self.match_parent(erf_node, "Mul", 0, output_name_to_node) + if first_mul is None: + return False + + i = self.find_constant_input(first_mul, 0.7071067690849304, delta=0.001) + if i < 0: + return False + + root_node = self.model.get_parent(first_mul, 0 if i == 1 else 1, output_name_to_node) + if root_node is None: + return False + + if mul_half.output[0] not in input_name_to_nodes: + return False + children = input_name_to_nodes[mul_half.output[0]] + if len(children) != 1 or children[0].op_type != "Mul": + return False + last_mul = children[0] + + if not (last_mul.input[0] == root_node.output[0] or last_mul.input[1] == root_node.output[0]): + return False + + subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul] + if not self.is_safe_to_fuse_nodes( + subgraph_nodes, + [last_mul.output[0]], + input_name_to_nodes, + output_name_to_node, + ): + return False + + self.nodes_to_remove.extend(subgraph_nodes) + fused_node = onnx.helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[last_mul.output[0]]) + fused_node.domain = "com.microsoft" + self.nodes_to_add.append(fused_node) + return True diff --git a/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py b/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py new file mode 100644 index 0000000000..d7fb89236d --- /dev/null +++ b/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py @@ -0,0 +1,134 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import annotations + +import onnx + +from ..onnx_model import ONNXModel +from .fusion import Fusion + + +class FusionLayerNormalization(Fusion): + def __init__(self, model: ONNXModel): + super().__init__(model, "LayerNormalization", "ReduceMean") + + def fuse( + self, + reduce_mean_node: onnx.NodeProto, + input_name_to_nodes: dict[str, list[onnx.NodeProto]], + output_name_to_node: dict[str, onnx.NodeProto], + ): + """ + Interface function that tries to fuse a node sequence containing a ReduceMean node into a single + LayerNormalization node. + + +----------------------+ + | | + | v + [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add + (axis=2 or -1) | (Y=2) (axis=2 or -1) (E-6 or E-12 or 0) ^ + | | + +-------------------------------------------------+ + + It also handles cases of duplicated sub nodes exported from older version of PyTorch: + + +----------------------+ + | v + | +-------> Sub-----------------------------------------------+ + | | | + | | v + [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add + | ^ + | | + +----------------------+ + """ + children = self.model.get_children(reduce_mean_node, input_name_to_nodes) + if len(children) == 0 or len(children) > 2: + return + + root_input = reduce_mean_node.input[0] + + if children[0].op_type != "Sub" or children[0].input[0] != root_input: + return + + if len(children) == 2: + if children[1].op_type != "Sub" or children[1].input[0] != root_input: + return + + div_node = None + for child in children: + div_node = self.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False) + if div_node is not None: + break + if div_node is None: + return + + path_id, parent_nodes, _ = self.match_parent_paths( + div_node, + [ + (["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]), + ( + ["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], + [1, 0, 0, 0, 0, 0], + ), + ], + output_name_to_node, + ) + if path_id < 0: + return + + sub_node = parent_nodes[-1] + if sub_node not in children: + return + + second_add_node = parent_nodes[1] + i, add_weight = self.get_constant_input(second_add_node) + if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4: + # Skip fusion since epsilon value is not expected. + return + + pow_node = parent_nodes[3] + if self.find_constant_input(pow_node, 2.0) != 1: + return + + mul_node = input_name_to_nodes[div_node.output[0]][0] + if mul_node.op_type != "Mul": + return + + last_add_node = input_name_to_nodes[mul_node.output[0]][0] + if last_add_node.op_type != "Add": + return + + subgraph_nodes = [reduce_mean_node] + subgraph_nodes.extend(children) + subgraph_nodes.extend(parent_nodes[:-1]) + + subgraph_nodes.extend([last_add_node, mul_node, div_node]) + if not self.is_safe_to_fuse_nodes( + subgraph_nodes, + last_add_node.output, + input_name_to_nodes, + output_name_to_node, + ): + return + + weight_input = mul_node.input[1 - self.input_index(div_node.output[0], mul_node)] + if not self.is_constant_with_specified_rank(weight_input, 1): + return + + bias_input = last_add_node.input[1 - self.input_index(mul_node.output[0], last_add_node)] + if not self.is_constant_with_specified_rank(bias_input, 1): + return + + self.nodes_to_remove.extend(subgraph_nodes) + + normalize_node = onnx.helper.make_node( + "LayerNormalization", + inputs=[reduce_mean_node.input[0], weight_input, bias_input], + outputs=[last_add_node.output[0]], + ) + normalize_node.attribute.extend([onnx.helper.make_attribute("epsilon", float(add_weight))]) + self.nodes_to_add.append(normalize_node) diff --git a/onnxruntime/python/tools/quantization/onnx_model.py b/onnxruntime/python/tools/quantization/onnx_model.py index e4342908f6..4591c9c950 100644 --- a/onnxruntime/python/tools/quantization/onnx_model.py +++ b/onnxruntime/python/tools/quantization/onnx_model.py @@ -1,3 +1,7 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- from pathlib import Path import onnx @@ -114,6 +118,14 @@ class ONNXModel: def opset_import(self): return self.model.opset_import + def set_opset_import(self, domain, version): + for opset in self.model.opset_import: + if opset.domain == domain: + opset.version = version + return + + self.model.opset_import.extend([onnx_helper.make_opsetid(domain, version)]) + def remove_node(self, node): if node in self.model.graph.node: self.model.graph.node.remove(node) @@ -140,6 +152,49 @@ class ONNXModel: return tensor return None + def find_graph_input(self, input_name): + for input in self.model.graph.input: + if input.name == input_name: + return input + return None + + def find_graph_output(self, output_name): + for output in self.model.graph.output: + if output.name == output_name: + return output + return None + + def get_tensor_type(self, tensor_name: str): + tensor_type_map = {obj.name: obj.type for obj in self.model.graph.value_info} + + if tensor_name in tensor_type_map: + return tensor_type_map[tensor_name].tensor_type + + g_input = self.find_graph_input(tensor_name) + if g_input: + return g_input.type.tensor_type + + g_output = self.find_graph_output(tensor_name) + if g_output: + return g_output.type.tensor_type + + return None + + def get_constant_value(self, output_name): + for node in self.model.graph.node: + if node.op_type == "Constant": + if node.output[0] == output_name: + for attr in node.attribute: + if attr.name == "value": + return onnx_numpy_helper.to_array(attr.t) + + # Fallback to initializer since constant folding may have been applied. + initializer = self.get_initializer(output_name) + if initializer is not None: + return onnx_numpy_helper.to_array(initializer) + + return None + def get_initializer_name_set(self): return {initializer.name for initializer in self.model.graph.initializer} @@ -167,17 +222,19 @@ class ONNXModel: input_name_to_nodes = {} for node in self.model.graph.node: for input_name in node.input: - if input_name not in input_name_to_nodes: - input_name_to_nodes[input_name] = [node] - else: - input_name_to_nodes[input_name].append(node) + if input_name: # Could be empty when it is optional + if input_name not in input_name_to_nodes: + input_name_to_nodes[input_name] = [node] + else: + input_name_to_nodes[input_name].append(node) return input_name_to_nodes def output_name_to_node(self): output_name_to_node = {} for node in self.model.graph.node: for output_name in node.output: - output_name_to_node[output_name] = node + if output_name: # Could be empty when it is optional + output_name_to_node[output_name] = node return output_name_to_node def get_children(self, node, input_name_to_nodes=None): diff --git a/setup.py b/setup.py index 2ede39915c..44c97937eb 100644 --- a/setup.py +++ b/setup.py @@ -408,6 +408,7 @@ packages = [ "onnxruntime.quantization", "onnxruntime.quantization.operators", "onnxruntime.quantization.CalTableFlatBuffers", + "onnxruntime.quantization.fusions", "onnxruntime.quantization.execution_providers.qnn", "onnxruntime.transformers", "onnxruntime.transformers.models.bart",