onnxruntime/onnxruntime/python/tools/transformers/fusion_gelu_approximation.py
Ye Wang 5e8086ad8e
Support fusions inside subgraphs in optimizer tool (#7701)
* skip subgraph when updating model

* interim checkin

* interim checkin 2

* support transformers optimizations in subgraph

* change more files

* fix comments typo
2021-05-17 12:43:55 -07:00

24 lines
1.1 KiB
Python

#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#--------------------------------------------------------------------------
from logging import getLogger
from onnx import helper
from onnx_model import OnnxModel
from fusion_base import Fusion
class FusionGeluApproximation(Fusion):
    """Rewrite Gelu / BiasGelu nodes as the approximated FastGelu operator.

    The replacement node is a ``FastGelu`` in the ``com.microsoft`` domain.
    It reuses the matched node's inputs and outputs unchanged, so it takes the
    original node's place in the graph without any rewiring.
    """

    def __init__(self, model: OnnxModel):
        # Positional arguments to Fusion: the op type this fusion emits, the op
        # types it looks for, and (presumably) a description of the pass --
        # NOTE(review): confirm the meaning of the last argument in Fusion.__init__.
        super().__init__(model, 'FastGelu', ['Gelu', 'BiasGelu'], 'GeluApproximation')

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Queue ``node`` to be replaced by an equivalent FastGelu node.

        Args:
            node: the matched node; given the op types registered in
                ``__init__``, this is expected to be a Gelu or BiasGelu node.
            input_name_to_nodes: not used by this fusion; kept to match the
                ``Fusion.fuse`` signature.
            output_name_to_node: not used by this fusion; kept to match the
                ``Fusion.fuse`` signature.
        """
        fast_gelu_name = self.model.create_node_name("FastGelu", node.op_type + "_Approximation")
        # Inputs/outputs are carried over verbatim: a single X for Gelu, or
        # (X, bias) for BiasGelu.
        fast_gelu = helper.make_node("FastGelu",
                                     inputs=node.input,
                                     outputs=node.output,
                                     name=fast_gelu_name,
                                     domain="com.microsoft")

        # Record the swap; the actual graph edit is left to the code that
        # consumes nodes_to_remove / nodes_to_add.
        self.nodes_to_remove.append(node)
        self.nodes_to_add.append(fast_gelu)
        # Remember which graph this node belongs to (fusions may run inside
        # subgraphs), keyed by the new node's name.
        self.node_name_to_graph_name[fast_gelu.name] = self.this_graph_name