mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
changing gelu backward op and adding required files (#10813)
* changing gelu backward op and adding required files * cleaning up file and adding comments
This commit is contained in:
parent
0293e525ea
commit
1c313f4476
3 changed files with 15 additions and 3 deletions
|
|
@ -130,7 +130,7 @@ if (onnxruntime_ENABLE_EAGER_MODE)
|
|||
endif()
|
||||
if (MSVC)
|
||||
target_compile_options(onnxruntime_pybind11_state PRIVATE "/wd4100" "/wd4324" "/wd4458" "/wd4127" "/wd4193" "/wd4624" "/wd4702")
|
||||
target_compile_options(onnxruntime_pybind11_state PRIVATE "/bigobj" "/wd4275" "/wd4244" "/wd4267")
|
||||
target_compile_options(onnxruntime_pybind11_state PRIVATE "/bigobj" "/wd4275" "/wd4244" "/wd4267" "/wd4067")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,11 @@ from opgen.generator import \
|
|||
|
||||
from opgen.onnxops import *
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
TORCH_API_CHANGE_VERSION = "1.11.0"
|
||||
|
||||
kMSDomain = 'onnxruntime::kMSDomain'
|
||||
|
||||
class ReluGrad(ONNXOp):
|
||||
|
|
@ -79,7 +84,6 @@ hand_implemented = {
|
|||
'aten::softshrink': Shrink('self', bias='lambd', lambd='lambd'), #yes, bias is set to 'lambd'
|
||||
'aten::hardshrink': Shrink('self', bias=0, lambd='lambd'),
|
||||
'aten::gelu' : Gelu('self'),
|
||||
'aten::gelu_backward' : GeluGrad('grad', 'self'),
|
||||
'aten::max' : ReduceMax('self', keepdims=1),
|
||||
'aten::min' : ReduceMin('self', keepdims=1),
|
||||
'aten::_cat': Concat('tensors', 'dim'),
|
||||
|
|
@ -95,6 +99,13 @@ hand_implemented = {
|
|||
'aten::gt.Scalar_out' : MakeTorchFallback(),
|
||||
}
|
||||
|
||||
# Signature of gelu_backward was changed in this commit id 983ba5e585485ed61a0c0012ef6944f5685e3d97 and PR 61439
|
||||
# This is done to make sure it is backward and future compatible
|
||||
if version.parse(torch.__version__) <= version.parse(TORCH_API_CHANGE_VERSION):
|
||||
hand_implemented['aten::gelu_backward'] = GeluGrad('grad', 'self')
|
||||
else:
|
||||
hand_implemented['aten::gelu_backward'] = GeluGrad('grad_output', 'self')
|
||||
|
||||
ops = {**ops, **hand_implemented}
|
||||
# TODO: this is a temporary allowlist for ops need type promotion
|
||||
# Need to enhance the support for onnx type constrains to automatically
|
||||
|
|
|
|||
|
|
@ -4,4 +4,5 @@
|
|||
#pragma once
|
||||
// include the pybind header first, it will disable linking to pythonX_d.lib on Windows in debug mode
|
||||
#include "python/onnxruntime_pybind_state_common.h"
|
||||
#include <torch/extension.h>
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/Operators.h>
|
||||
Loading…
Reference in a new issue