changing gelu backward op and adding required files (#10813)

* changing gelu backward op and adding required files

* cleaning up file and adding comments
This commit is contained in:
Abhishek Jindal 2022-03-09 16:54:51 -08:00 committed by GitHub
parent 0293e525ea
commit 1c313f4476
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 3 deletions

View file

@ -130,7 +130,7 @@ if (onnxruntime_ENABLE_EAGER_MODE)
endif()
if (MSVC)
target_compile_options(onnxruntime_pybind11_state PRIVATE "/wd4100" "/wd4324" "/wd4458" "/wd4127" "/wd4193" "/wd4624" "/wd4702")
target_compile_options(onnxruntime_pybind11_state PRIVATE "/bigobj" "/wd4275" "/wd4244" "/wd4267")
target_compile_options(onnxruntime_pybind11_state PRIVATE "/bigobj" "/wd4275" "/wd4244" "/wd4267" "/wd4067")
endif()
endif()

View file

@ -8,6 +8,11 @@ from opgen.generator import \
from opgen.onnxops import *
import torch
from packaging import version
TORCH_API_CHANGE_VERSION = "1.11.0"
kMSDomain = 'onnxruntime::kMSDomain'
class ReluGrad(ONNXOp):
@ -79,7 +84,6 @@ hand_implemented = {
'aten::softshrink': Shrink('self', bias='lambd', lambd='lambd'), #yes, bias is set to 'lambd'
'aten::hardshrink': Shrink('self', bias=0, lambd='lambd'),
'aten::gelu' : Gelu('self'),
'aten::gelu_backward' : GeluGrad('grad', 'self'),
'aten::max' : ReduceMax('self', keepdims=1),
'aten::min' : ReduceMin('self', keepdims=1),
'aten::_cat': Concat('tensors', 'dim'),
@ -95,6 +99,13 @@ hand_implemented = {
'aten::gt.Scalar_out' : MakeTorchFallback(),
}
# Signature of gelu_backward was changed in this commit id 983ba5e585485ed61a0c0012ef6944f5685e3d97 and PR 61439
# This is done to make sure it is backward and future compatible
if version.parse(torch.__version__) <= version.parse(TORCH_API_CHANGE_VERSION):
hand_implemented['aten::gelu_backward'] = GeluGrad('grad', 'self')
else:
hand_implemented['aten::gelu_backward'] = GeluGrad('grad_output', 'self')
ops = {**ops, **hand_implemented}
# TODO: this is a temporary allowlist for ops need type promotion
# Need to enhance the support for onnx type constrains to automatically

View file

@ -4,4 +4,5 @@
#pragma once
// include the pybind header first, it will disable linking to pythonX_d.lib on Windows in debug mode
#include "python/onnxruntime_pybind_state_common.h"
#include <torch/extension.h>
#include <torch/extension.h>
#include <ATen/Operators.h>