mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
35 lines
1.1 KiB
Python
35 lines
1.1 KiB
Python
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
#
|
|
# Register PyTorch symbolic functions so torch.onnx export can emit ONNX Runtime contrib ops
|
|
|
|
from torch.onnx import register_custom_op_symbolic
|
|
|
|
|
|
_onnx_opset_version = 1
|
|
|
|
|
|
def register_custom_op():
    """Register torch.onnx symbolic functions that export PyTorch ops as ONNX Runtime contrib ops.

    Call this before ``torch.onnx.export``. The exported graph then contains ops
    from the ``com.microsoft`` domain, which only ONNX Runtime can execute:

    * ``inverse``        -> ``com.microsoft::Inverse``
    * ``gelu``           -> ``com.microsoft::Gelu`` for the exact form; any other
      ``approximate`` mode falls back to PyTorch's standard ONNX export of gelu
    * ``triu`` / ``tril`` -> ``com.microsoft::Trilu`` with ``upper`` = 1 / 0
    """

    # Symbolic definitions: each receives the export graph ``g`` plus the op's
    # inputs, and emits the replacement node via ``g.op``.
    def inverse(g, self):
        return g.op("com.microsoft::Inverse", self)

    def gelu(g, self, approximate="none"):
        # Newer PyTorch (aten::gelu with an ``approximate`` argument) passes that
        # argument to the symbolic as a graph constant. Older versions pass nothing,
        # hence the default. Without this parameter, export raises TypeError.
        if not isinstance(approximate, str):
            from torch.onnx import symbolic_helper

            approximate = symbolic_helper._maybe_get_const(approximate, "s")
        # com.microsoft::Gelu only implements the exact (non-approximated) GELU.
        # Emitting it for e.g. approximate="tanh" would silently change numerics.
        if approximate == "none":
            return g.op("com.microsoft::Gelu", self)
        from torch.onnx import symbolic_opset9

        return symbolic_opset9.gelu(g, self, approximate)

    # Trilu takes the diagonal offset as its second input; the ``upper``
    # attribute selects the upper (1) or lower (0) triangle.
    def triu(g, self, diagonal):
        return g.op("com.microsoft::Trilu", self, diagonal, upper_i=1)

    def tril(g, self, diagonal):
        return g.op("com.microsoft::Trilu", self, diagonal, upper_i=0)

    # Op registration: map each PyTorch op name to its symbolic above.
    register_custom_op_symbolic('::inverse', inverse, _onnx_opset_version)
    register_custom_op_symbolic('::gelu', gelu, _onnx_opset_version)
    register_custom_op_symbolic('::triu', triu, _onnx_opset_version)
    register_custom_op_symbolic('::tril', tril, _onnx_opset_version)
|