diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 746ec55d47..416d50a0b2 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -130,7 +130,7 @@ if (onnxruntime_ENABLE_EAGER_MODE) endif() if (MSVC) target_compile_options(onnxruntime_pybind11_state PRIVATE "/wd4100" "/wd4324" "/wd4458" "/wd4127" "/wd4193" "/wd4624" "/wd4702") - target_compile_options(onnxruntime_pybind11_state PRIVATE "/bigobj" "/wd4275" "/wd4244" "/wd4267") + target_compile_options(onnxruntime_pybind11_state PRIVATE "/bigobj" "/wd4275" "/wd4244" "/wd4267" "/wd4067") endif() endif() diff --git a/orttraining/orttraining/eager/opgen/opgen/atenops.py b/orttraining/orttraining/eager/opgen/opgen/atenops.py index 8bb882571b..d2d3076455 100644 --- a/orttraining/orttraining/eager/opgen/opgen/atenops.py +++ b/orttraining/orttraining/eager/opgen/opgen/atenops.py @@ -8,6 +8,11 @@ from opgen.generator import \ from opgen.onnxops import * +import torch +from packaging import version + +TORCH_API_CHANGE_VERSION = "1.11.0" + kMSDomain = 'onnxruntime::kMSDomain' class ReluGrad(ONNXOp): @@ -79,7 +84,6 @@ hand_implemented = { 'aten::softshrink': Shrink('self', bias='lambd', lambd='lambd'), #yes, bias is set to 'lambd' 'aten::hardshrink': Shrink('self', bias=0, lambd='lambd'), 'aten::gelu' : Gelu('self'), - 'aten::gelu_backward' : GeluGrad('grad', 'self'), 'aten::max' : ReduceMax('self', keepdims=1), 'aten::min' : ReduceMin('self', keepdims=1), 'aten::_cat': Concat('tensors', 'dim'), @@ -95,6 +99,13 @@ hand_implemented = { 'aten::gt.Scalar_out' : MakeTorchFallback(), } +# Signature of gelu_backward was changed in this commit id 983ba5e585485ed61a0c0012ef6944f5685e3d97 and PR 61439 +# This is done to make sure it is backward and future compatible +if version.parse(torch.__version__) <= version.parse(TORCH_API_CHANGE_VERSION): + hand_implemented['aten::gelu_backward'] = GeluGrad('grad', 'self') +else: + hand_implemented['aten::gelu_backward'] = GeluGrad('grad_output', 'self') + ops = {**ops, **hand_implemented} # TODO: this is a temporary allowlist for ops need type promotion # Need to enhance the support for onnx type constrains to automatically diff --git a/orttraining/orttraining/eager/ort_eager_common.h b/orttraining/orttraining/eager/ort_eager_common.h index 3de3c2d1b8..e7f54b8d33 100644 --- a/orttraining/orttraining/eager/ort_eager_common.h +++ b/orttraining/orttraining/eager/ort_eager_common.h @@ -4,4 +4,5 @@ #pragma once // include the pybind header first, it will disable linking to pythonX_d.lib on Windows in debug mode #include "python/onnxruntime_pybind_state_common.h" -#include \ No newline at end of file +#include +#include \ No newline at end of file