mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-09 00:30:53 +00:00
### Description
- Updates the `qnn_preprocess_model()` method to set a name for any new nodes added to the graph (due to fusion).
- Updates the `qnn_preprocess_model()` method to set a name for any unnamed nodes that previously existed in the original graph.
- Adds unit tests for fusions (previously missing):
  - Checks that fused node names exist and are unique.
  - Checks that the fused graph is equivalent to the original graph.

### Motivation and Context
Nodes are not strictly required to have names. However, a planned/upcoming feature to support mixed-precision (integer) quantized models needs nodes to have names.
135 lines
5 KiB
Python
135 lines
5 KiB
Python
# -------------------------------------------------------------------------
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License. See License.txt in the project root for
|
|
# license information.
|
|
# --------------------------------------------------------------------------
|
|
from __future__ import annotations
|
|
|
|
import onnx
|
|
|
|
from ..onnx_model import ONNXModel
|
|
from .fusion import Fusion
|
|
|
|
|
|
class FusionLayerNormalization(Fusion):
    """Fuse a decomposed layer-normalization subgraph (ReduceMean/Sub/Pow/.../Div/Mul/Add) into a
    single LayerNormalization node."""

    def __init__(self, model: ONNXModel):
        # "LayerNormalization" is the fused op type; "ReduceMean" is the op type whose nodes are
        # handed to `fuse` as pattern anchors (presumably - see the Fusion base class).
        super().__init__(model, "LayerNormalization", "ReduceMean")

    @staticmethod
    def _first_consumer(
        output_name: str,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
    ) -> onnx.NodeProto | None:
        """Return the first node consuming `output_name`, or None when nothing consumes it."""
        # .get(): a tensor with no consumers (e.g. a graph output) may have no entry in the map,
        # and plain indexing would raise KeyError instead of simply skipping the fusion.
        consumers = input_name_to_nodes.get(output_name)
        return consumers[0] if consumers else None

    def fuse(
        self,
        reduce_mean_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ):
        """
        Interface function that tries to fuse a node sequence containing a ReduceMean node into a single
        LayerNormalization node.

              +----------------------+
              |                      |
              |                      v
          [Root] --> ReduceMean -->  Sub  --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
                     (axis=2 or -1)   |      (Y=2)   (axis=2 or -1)  (eps)             ^
                                      |                                                |
                                      +------------------------------------------------+

        where eps is the constant input of the Add feeding Sqrt and must satisfy 0 < eps <= 1e-4.

        It also handles cases of duplicated sub nodes exported from older version of PyTorch:

              +----------------------+
              |                      v
              |           +-------> Sub-----------------------------------------------+
              |           |                                                           |
              |           |                                                           v
          [Root] --> ReduceMean -->  Sub  --> Pow --> ReduceMean --> Add --> Sqrt --> Div  --> Mul --> Add
              |                      ^
              |                      |
              +----------------------+

        An optional Cast between Sub and Pow is also accepted.

        On success, every matched node is appended to `self.nodes_to_remove`, and a new
        LayerNormalization node is appended to `self.nodes_to_add`. That node reads the root input, the
        Mul's constant operand (scale) and the last Add's constant operand (bias), has epsilon = eps, and
        writes the last Add's output. On any mismatch the method returns early and leaves both lists
        untouched.

        NOTE(review): the ReduceMean axes are not validated, and the fused node does not set `axis`, so
        it relies on the LayerNormalization default (-1 per the ONNX spec). A pattern that reduces over
        a non-last axis would be fused incorrectly - consider checking the axes.

        Args:
            reduce_mean_node: ReduceMean candidate; its first input is treated as the pattern's root.
            input_name_to_nodes: Maps a tensor name to the nodes that consume it.
            output_name_to_node: Maps a tensor name to the node that produces it.
        """
        children = self.model.get_children(reduce_mean_node, input_name_to_nodes)
        if not 1 <= len(children) <= 2:
            return

        root_input = reduce_mean_node.input[0]

        # Each consumer of the mean must be a Sub computing (root - mean). There are two of them
        # when an older exporter duplicated the Sub.
        if any(child.op_type != "Sub" or child.input[0] != root_input for child in children):
            return

        # Locate the Div (normalization step) directly downstream of one of the Subs.
        div_node = None
        for child in children:
            div_node = self.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)
            if div_node is not None:
                break
        if div_node is None:
            return

        # Walk back from the Div's denominator: Sqrt <- Add(eps) <- ReduceMean <- Pow [<- Cast] <- Sub.
        path_id, parent_nodes, _ = self.match_parent_paths(
            div_node,
            [
                (["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]),
                (
                    ["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"],
                    [1, 0, 0, 0, 0, 0],
                ),
            ],
            output_name_to_node,
        )
        if path_id < 0:
            return

        # The variance branch must start from one of the Subs that consume this ReduceMean.
        sub_node = parent_nodes[-1]
        if sub_node not in children:
            return

        # parent_nodes[1] is the Add feeding Sqrt; its constant operand is epsilon.
        second_add_node = parent_nodes[1]
        _, add_weight = self.get_constant_input(second_add_node)
        if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4:
            # Skip fusion since epsilon value is not expected.
            return

        # parent_nodes[3] is the Pow; it must square its input (constant 2.0 at input index 1).
        pow_node = parent_nodes[3]
        if self.find_constant_input(pow_node, 2.0) != 1:
            return

        mul_node = self._first_consumer(div_node.output[0], input_name_to_nodes)
        if mul_node is None or mul_node.op_type != "Mul":
            return

        last_add_node = self._first_consumer(mul_node.output[0], input_name_to_nodes)
        if last_add_node is None or last_add_node.op_type != "Add":
            return

        # parent_nodes[:-1] drops the Sub, which is already included via `children`.
        subgraph_nodes = [
            reduce_mean_node,
            *children,
            *parent_nodes[:-1],
            last_add_node,
            mul_node,
            div_node,
        ]
        if not self.is_safe_to_fuse_nodes(
            subgraph_nodes,
            last_add_node.output,
            input_name_to_nodes,
            output_name_to_node,
        ):
            return

        # Scale and bias are the operands of Mul / last Add that are not the pattern's own tensors;
        # LayerNormalization needs them as rank-1 constants.
        weight_input = mul_node.input[1 - self.input_index(div_node.output[0], mul_node)]
        if not self.is_constant_with_specified_rank(weight_input, 1):
            return

        bias_input = last_add_node.input[1 - self.input_index(mul_node.output[0], last_add_node)]
        if not self.is_constant_with_specified_rank(bias_input, 1):
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        normalize_node = onnx.helper.make_node(
            "LayerNormalization",
            name=self.create_unique_node_name(),
            inputs=[reduce_mean_node.input[0], weight_input, bias_input],
            outputs=[last_add_node.output[0]],
        )
        normalize_node.attribute.extend([onnx.helper.make_attribute("epsilon", float(add_weight))])
        self.nodes_to_add.append(normalize_node)