onnxruntime/tools/python/register_custom_ops_pytorch_exporter.py

58 lines
1.9 KiB
Python

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
# Register pytorch symbolic for export using ONNX Runtime contrib ops
from torch.onnx import register_custom_op_symbolic
_onnx_opset_version = 1
def register_custom_op():
"""
This function registers symbolic functions for
custom ops that are implemented as part of ONNX Runtime
"""
# Symbolic definition
def inverse(g, self):
return g.op("com.microsoft::Inverse", self)
def gelu(g, self):
return g.op("com.microsoft::Gelu", self)
def triu(g, self, diagonal):
return g.op("com.microsoft::Trilu", self, diagonal, upper_i=1)
def tril(g, self, diagonal):
return g.op("com.microsoft::Trilu", self, diagonal, upper_i=0)
# Op Registration
register_custom_op_symbolic('::inverse', inverse, _onnx_opset_version)
register_custom_op_symbolic('::gelu', gelu, _onnx_opset_version)
register_custom_op_symbolic('::triu', triu, _onnx_opset_version)
register_custom_op_symbolic('::tril', tril, _onnx_opset_version)
def unregister_custom_op():
"""
This function unregisters symbolic functions for
custom ops that are implemented as part of ONNX Runtime
"""
import torch.onnx.symbolic_registry as sym_registry
# TODO: replace this once PyTorch supports unregister natively.
def unregister(name, opset_version):
ns, kind = name.split("::")
from torch.onnx.symbolic_helper import _onnx_stable_opsets
for version in _onnx_stable_opsets:
if version >= opset_version and sym_registry.is_registered_op(kind, ns, version):
del sym_registry._registry[(ns, version)][kind]
unregister('::inverse', _onnx_opset_version)
unregister('::gelu', _onnx_opset_version)
unregister('::triu', _onnx_opset_version)
unregister('::tril', _onnx_opset_version)