2022-04-26 16:35:16 +00:00
|
|
|
# -------------------------------------------------------------------------
|
2020-05-28 08:16:41 +00:00
|
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
|
|
|
# Licensed under the MIT License.
|
2022-04-26 16:35:16 +00:00
|
|
|
# --------------------------------------------------------------------------
|
2020-05-28 08:16:41 +00:00
|
|
|
|
2022-04-26 16:35:16 +00:00
|
|
|
from fusion_base import Fusion
|
2020-05-28 08:16:41 +00:00
|
|
|
from onnx import helper
|
2020-06-16 16:36:51 +00:00
|
|
|
from onnx_model import OnnxModel
|
2020-05-28 08:16:41 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class FusionGeluApproximation(Fusion):
    """Replace each matched ``Gelu`` / ``BiasGelu`` node with a ``FastGelu`` node.

    The replacement node keeps the inputs and outputs of the node it replaces
    and is placed in the ``com.microsoft`` operator domain.
    """

    def __init__(self, model: OnnxModel):
        """Configure the base ``Fusion`` for this pass.

        Args:
            model: The model wrapper handed to the ``Fusion`` base class.
        """
        # NOTE(review): presumably (fused op type, op types to search for,
        # description) -- confirm against Fusion.__init__ in fusion_base.
        source_op_types = ["Gelu", "BiasGelu"]
        super().__init__(model, "FastGelu", source_op_types, "GeluApproximation")

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Queue the replacement of ``node`` by an equivalent-wired ``FastGelu`` node.

        Args:
            node: The node to replace; its ``input``, ``output`` and ``op_type``
                are read to build the new node.
            input_name_to_nodes: Not used by this fusion.
            output_name_to_node: Not used by this fusion.
        """
        # The new node's name is derived from the original op type.
        name_suffix = node.op_type + "_Approximation"
        fast_gelu_name = self.model.create_node_name("FastGelu", name_suffix)

        fast_gelu = helper.make_node(
            "FastGelu",
            inputs=node.input,
            outputs=node.output,
            name=fast_gelu_name,
        )
        fast_gelu.domain = "com.microsoft"

        # Record the swap on the pass's bookkeeping lists; nothing else is
        # modified here.
        self.nodes_to_remove.append(node)
        self.nodes_to_add.append(fast_gelu)
        self.node_name_to_graph_name[fast_gelu.name] = self.this_graph_name
|