onnxruntime/onnxruntime/python/tools/transformers/fusion_gelu_approximation.py
Tianlei Wu d65aa5400c
clean up transformers scripts (#17179)
(1) Remove class BertOptimizationOptions, which was deprecated a long time ago
(2) Move sys path settings to `__init__.py`, and update imports
(3) Fix bert_perf_test to run properly.
(4) Fix an ONNX path in a Whisper test case
(5) Fix a few typos
(6) Update comments in bert_perf_test regarding graph inputs
2023-08-17 23:14:49 -07:00

25 lines
1,004 B
Python

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from fusion_base import Fusion
from onnx import helper
from onnx_model import OnnxModel
class FusionGeluApproximation(Fusion):
    """Rewrite Gelu and BiasGelu nodes as FastGelu nodes in the com.microsoft domain.

    Each matched node is swapped one-for-one for a ``FastGelu`` node.
    The new node reuses the original node's input and output names, so neighbouring
    nodes need no rewiring.

    As the class name suggests, this replaces the original GELU with an
    approximated variant. NOTE(review): presumably this is the approximation
    implemented by the ONNX Runtime FastGelu contrib op; confirm against the op spec.
    """

    def __init__(self, model: OnnxModel):
        # NOTE(review): the Fusion base arguments are assumed to be
        # (model, fused op type, op types to search for, description).
        # TODO confirm against fusion_base.Fusion.
        target_op = "FastGelu"
        searched_ops = ["Gelu", "BiasGelu"]
        super().__init__(model, target_op, searched_ops, "GeluApproximation")

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Queue replacement of one matched Gelu/BiasGelu ``node`` by a FastGelu node.

        Args:
            node: The matched node; its op type, inputs and outputs are read.
            input_name_to_nodes: Unused here; kept to match the fuse() signature.
            output_name_to_node: Unused here; kept to match the fuse() signature.

        The graph is not edited directly. The old node is appended to
        ``nodes_to_remove`` and the new node to ``nodes_to_add``; presumably the
        Fusion base class applies these lists once the search completes.
        """
        # Name hint derived from the source op type, e.g. "BiasGelu_Approximation".
        # create_node_name presumably turns it into a name unique within the model.
        replacement_name = self.model.create_node_name("FastGelu", node.op_type + "_Approximation")

        # Forward every input and output name unchanged. For BiasGelu this includes
        # its bias input, which FastGelu is assumed to accept as an optional input.
        replacement = helper.make_node(
            "FastGelu",
            inputs=node.input,
            outputs=node.output,
            name=replacement_name,
            domain="com.microsoft",
        )

        self.nodes_to_remove.append(node)
        self.nodes_to_add.append(replacement)
        # Record the owning graph for the new node. this_graph_name is not set in
        # this class, so it presumably comes from the Fusion base.
        self.node_name_to_graph_name[replacement.name] = self.this_graph_name