mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
Enable PythonOp for --enable_training_torch_interop build (#12539)
* enable PythonOp by default when --enable_training_torch_interop is enabled during build * clean up * fix * fix comment * fix * fix tests * fix fallback test * pylint format * refine based on comments
This commit is contained in:
parent
b59ccbc75b
commit
24eab921be
6 changed files with 112 additions and 50 deletions
|
|
@ -486,6 +486,13 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
|
|||
pool.UnRegisterFunctions();
|
||||
#endif
|
||||
});
|
||||
// Python-visible query "is_torch_interop_default_on": the answer is fixed at compile time.
// The lambda returns true when ENABLE_TRAINING_TORCH_INTEROP was defined for this build and false otherwise.
m.def("is_torch_interop_default_on", []() -> bool {
|
||||
#ifdef ENABLE_TRAINING_TORCH_INTEROP
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
});
|
||||
|
||||
py::class_<TrainingConfigurationResult> config_result(m, "TrainingConfigurationResult", "pbdoc(Configuration result for training.)pbdoc");
|
||||
config_result.def(py::init())
|
||||
|
|
|
|||
|
|
@ -8,47 +8,81 @@ class Enabler(object):
|
|||
# Create the enabler with custom autograd support off and never yet turned on.
def __init__(self):
|
||||
# Current on/off flag, read and written through the `state` property.
self._state = False
|
||||
|
||||
# This indicates whether custom autograd support has ever been enabled in the current process,
|
||||
# regardless of the current state.
|
||||
self._already_enabled = False
|
||||
|
||||
# Read-only view of the current on/off flag stored in `_state`.
@property
|
||||
def state(self):
|
||||
return self._state
|
||||
|
||||
# Whether `state` has been set to True at least once on this object (see the `state` setter).
@property
|
||||
def already_enabled(self):
|
||||
return self._already_enabled
|
||||
|
||||
# Store the new flag and latch `_already_enabled` the first time the flag is set to True.
# This setter never resets `_already_enabled` back to False.
@state.setter
|
||||
def state(self, val):
|
||||
self._state = val
|
||||
# Identity checks: only the literal True latches the flag, not merely truthy values.
if self._already_enabled is False and val is True:
|
||||
self._already_enabled = True
|
||||
|
||||
|
||||
custom_autograd_function_enabler = Enabler()
|
||||
|
||||
# Legacy API to enable the custom autograd, keep its name with default value for compatibility.
|
||||
def enable_custom_autograd_support(enable=True):
|
||||
|
||||
def enable_custom_autograd_support():
|
||||
# Initialize static objects needed to run custom autograd.Function's.
|
||||
|
||||
from onnxruntime.capi._pybind_state import (
|
||||
register_forward_runner,
|
||||
register_backward_runner,
|
||||
unregister_python_functions,
|
||||
)
|
||||
from torch.onnx import register_custom_op_symbolic
|
||||
from ._custom_autograd_function_exporter import _export
|
||||
from ._custom_autograd_function_runner import call_python_forward_function, call_python_backward_function
|
||||
from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_interop_utils
|
||||
import atexit
|
||||
|
||||
register_forward_runner(call_python_forward_function)
|
||||
register_backward_runner(call_python_backward_function)
|
||||
from torch.onnx import register_custom_op_symbolic, unregister_custom_op_symbolic
|
||||
|
||||
# Unregister all python functions automatically upon normal interpreter termination.
|
||||
atexit.register(unregister_python_functions)
|
||||
# Clear all gradient functions, to avoid a deadlock issue.
|
||||
# Check the called function for more detailed comments.
|
||||
atexit.register(torch_interop_utils.clear_all_grad_fns)
|
||||
from onnxruntime.capi._pybind_state import (
|
||||
register_backward_runner,
|
||||
register_forward_runner,
|
||||
unregister_python_functions,
|
||||
)
|
||||
from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_interop_utils
|
||||
|
||||
try:
|
||||
# This is for the latest Pytorch nightly after this commit:
|
||||
# https://github.com/pytorch/pytorch/commit/11bc435622e6b7207bbf37ed1aafe999e1f296ec
|
||||
register_custom_op_symbolic("prim::PythonOp", _export, 1)
|
||||
except:
|
||||
# This applies to Pytorch 1.9 and 1.9.1.
|
||||
register_custom_op_symbolic("::prim_PythonOp", _export, 1)
|
||||
from ._custom_autograd_function_exporter import _export
|
||||
from ._custom_autograd_function_runner import call_python_backward_function, call_python_forward_function
|
||||
|
||||
custom_autograd_function_enabler.state = True
|
||||
if enable is True:
|
||||
if custom_autograd_function_enabler.already_enabled is False:
|
||||
# Initialize static objects needed to run custom autograd.Function's.
|
||||
register_forward_runner(call_python_forward_function)
|
||||
register_backward_runner(call_python_backward_function)
|
||||
|
||||
# Unregister all python functions automatically upon normal interpreter termination.
|
||||
atexit.register(unregister_python_functions)
|
||||
# Clear all gradient functions, to avoid a deadlock issue.
|
||||
# Check the called function for more detailed comments.
|
||||
atexit.register(torch_interop_utils.clear_all_grad_fns)
|
||||
|
||||
try:
|
||||
# This is for the latest Pytorch nightly after this commit:
|
||||
# https://github.com/pytorch/pytorch/commit/11bc435622e6b7207bbf37ed1aafe999e1f296ec
|
||||
register_custom_op_symbolic("prim::PythonOp", _export, 1)
|
||||
except:
|
||||
# This applies to Pytorch 1.9 and 1.9.1.
|
||||
register_custom_op_symbolic("::prim_PythonOp", _export, 1)
|
||||
|
||||
custom_autograd_function_enabler.state = True
|
||||
else:
|
||||
if custom_autograd_function_enabler.already_enabled is True:
|
||||
# We don't need to remove the registered runners because they won't be used if we disable the feature.
|
||||
# But we need to unregister the PythonOp custom operator function.
|
||||
try:
|
||||
# This is for the latest Pytorch nightly after this commit:
|
||||
# https://github.com/pytorch/pytorch/commit/11bc435622e6b7207bbf37ed1aafe999e1f296ec
|
||||
unregister_custom_op_symbolic("prim::PythonOp", 1)
|
||||
except:
|
||||
# This applies to Pytorch 1.9 and 1.9.1.
|
||||
unregister_custom_op_symbolic("::prim_PythonOp", 1)
|
||||
|
||||
custom_autograd_function_enabler.state = False
|
||||
|
||||
|
||||
from onnxruntime.capi._pybind_state import is_torch_interop_default_on
|
||||
|
||||
# Enable the custom autograd by default when PythonOp backend support is enabled during build.
|
||||
enable_custom_autograd_support(is_torch_interop_default_on())
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
|
|||
from onnxruntime.training import ortmodule
|
||||
|
||||
from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils
|
||||
from ._custom_autograd_function import custom_autograd_function_enabler
|
||||
from ._custom_autograd_function_exporter import _post_process_after_export
|
||||
from ._fallback import (
|
||||
ORTModuleDeviceException,
|
||||
|
|
@ -147,6 +146,8 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
self._run_symbolic_shape_infer = True
|
||||
|
||||
# PyTorch custom Autograd function support
|
||||
from ._custom_autograd_function import custom_autograd_function_enabler
|
||||
|
||||
self._enable_custom_autograd_function = custom_autograd_function_enabler.state
|
||||
|
||||
self._input_info = None
|
||||
|
|
|
|||
|
|
@ -1,17 +1,17 @@
|
|||
import copy
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer
|
||||
from onnxruntime.training import orttrainer
|
||||
|
||||
try:
|
||||
from onnxruntime.training.ortmodule import ORTModule
|
||||
from onnxruntime.training.ortmodule._custom_autograd_function import enable_custom_autograd_support
|
||||
from onnxruntime.training.ortmodule._graph_execution_manager_factory import GraphExecutionManagerFactory
|
||||
from onnxruntime.training.ortmodule._fallback import ORTModuleInitException
|
||||
from onnxruntime.training.ortmodule._graph_execution_manager_factory import GraphExecutionManagerFactory
|
||||
except ImportError:
|
||||
# Some pipelines do not contain ORTModule
|
||||
pass
|
||||
|
|
@ -231,10 +231,6 @@ def assert_values_are_close(input, other, rtol=1e-05, atol=1e-06):
|
|||
assert False, err_msg
|
||||
|
||||
|
||||
def enable_custom_autograd_function(module):
|
||||
enable_custom_autograd_support()
|
||||
|
||||
|
||||
def _run_model_on_device(device, model, input_list, label_input, is_eval_mode=False, run_forward_twice=False):
|
||||
if is_eval_mode:
|
||||
model.eval()
|
||||
|
|
@ -287,7 +283,6 @@ def run_with_ort_on_device(device, model, input_list, label_input, is_eval_mode=
|
|||
with torch.no_grad():
|
||||
model = copy.deepcopy(model)
|
||||
model.to(device)
|
||||
enable_custom_autograd_function(model)
|
||||
model = ORTModule(model)
|
||||
|
||||
return _run_model_on_device(device, model, input_list, label_input, is_eval_mode, run_forward_twice)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# pylint: disable=missing-docstring
|
||||
# pylint: disable=C0103
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import pytest
|
||||
|
|
@ -22,7 +25,19 @@ def torch_version_lower_than(v):
|
|||
return LooseVersion(torch.__version__) < LooseVersion(v)
|
||||
|
||||
|
||||
def test_GeLU():
|
||||
@pytest.fixture(scope="session", autouse=True)
def run_before_test_session(request):
    """Force ORTMODULE_FALLBACK_POLICY=FALLBACK_DISABLE for the whole pytest session.

    The variable is set before the first test runs. A finalizer registered on
    ``request`` puts the environment back the way it was found when the session
    ends: a value that existed beforehand is restored, otherwise the variable is
    removed.

    Args:
        request: pytest's built-in fixture request object, used to register the finalizer.
    """
    env_name = "ORTMODULE_FALLBACK_POLICY"
    # Remember what the caller had exported so teardown restores it instead of discarding it.
    previous_value = os.environ.get(env_name)

    def insert_disable_fallback_in_env():
        os.environ[env_name] = "FALLBACK_DISABLE"

    def remove_disable_fallback_from_env():
        if previous_value is None:
            # pop() with a default does not raise if a test already removed the variable.
            os.environ.pop(env_name, None)
        else:
            os.environ[env_name] = previous_value

    insert_disable_fallback_in_env()
    request.addfinalizer(remove_disable_fallback_from_env)
|
||||
|
||||
|
||||
def test_gelu():
|
||||
@torch.jit.script
|
||||
def bias_gelu(bias, y):
|
||||
x = bias + y
|
||||
|
|
@ -741,7 +756,6 @@ def test_InnerModuleCall():
|
|||
ctx.device = device
|
||||
ctx.inner = InnerModel(dim, device).to(device)
|
||||
if use_ort:
|
||||
enable_custom_autograd_function(ctx.inner)
|
||||
ctx.inner = ORTModule(ctx.inner)
|
||||
z = ctx.inner(x)
|
||||
return z
|
||||
|
|
@ -1076,9 +1090,6 @@ def test_non_differentiable_autograd_function():
|
|||
print("Ref:")
|
||||
print(y_ref)
|
||||
|
||||
from onnxruntime.training.ortmodule._custom_autograd_function import enable_custom_autograd_support
|
||||
|
||||
enable_custom_autograd_support()
|
||||
m = ORTModule(m)
|
||||
|
||||
# Inference mode.
|
||||
|
|
|
|||
|
|
@ -4,23 +4,24 @@
|
|||
|
||||
import copy
|
||||
import itertools
|
||||
import os
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import pytest
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from onnxruntime.training.ortmodule import ORTModule, _fallback, ORTMODULE_TORCH_CPP_DIR
|
||||
from onnxruntime.training.ortmodule.torch_cpp_extensions import is_installed as is_torch_cpp_extensions_installed
|
||||
import _test_helpers
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from _orttraining_ortmodule_models import (
|
||||
NeuralNetSinglePositionalArgument,
|
||||
NeuralNetCustomClassOutput,
|
||||
MyCustomClassInputNet,
|
||||
MyCustomFunctionReluModel,
|
||||
NeuralNetCustomClassOutput,
|
||||
NeuralNetSinglePositionalArgument,
|
||||
)
|
||||
|
||||
from onnxruntime.training.ortmodule import ORTMODULE_TORCH_CPP_DIR, ORTModule, _fallback
|
||||
from onnxruntime.training.ortmodule.torch_cpp_extensions import is_installed as is_torch_cpp_extensions_installed
|
||||
|
||||
# PyTorch model definitions for tests
|
||||
|
||||
|
||||
|
|
@ -373,6 +374,7 @@ def test_ortmodule_fallback_init__torch_version(is_training, fallback_enabled, m
|
|||
# Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen
|
||||
|
||||
from packaging import version
|
||||
|
||||
from onnxruntime.training.ortmodule import MINIMUM_RUNTIME_PYTORCH_VERSION_STR
|
||||
|
||||
runtime_pytorch_version = version.parse(torch.__version__.split("+")[0])
|
||||
|
|
@ -486,6 +488,15 @@ def test_ortmodule_fallback_init__missing_cpp_extensions(
|
|||
def test_ortmodule_fallback_onnx_model__custom_autograd(
|
||||
is_training, fallback_enabled, matching_policy, persist_fallback
|
||||
):
|
||||
from onnxruntime.training.ortmodule._custom_autograd_function import (
|
||||
custom_autograd_function_enabler,
|
||||
enable_custom_autograd_support,
|
||||
)
|
||||
|
||||
# Disable the autograd support to test the fallback.
|
||||
old_state = custom_autograd_function_enabler.state
|
||||
enable_custom_autograd_support(False)
|
||||
|
||||
# is_training: True for torch.nn.Module training model, eval mode otherwise
|
||||
# fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
|
||||
# matching_policy: True matches FALLBACK_UNSUPPORTED_ONNX_MODEL policy to ORTModuleDeviceException exception.
|
||||
|
|
@ -536,6 +547,9 @@ def test_ortmodule_fallback_onnx_model__custom_autograd(
|
|||
_ = ort_model(x.mm(w1)).mm(w2)
|
||||
assert "There was an error while exporting the PyTorch model to ONNX" in str(ex_info.value)
|
||||
|
||||
# Restore the autograd support state.
|
||||
enable_custom_autograd_support(old_state)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"is_training,fallback_enabled,matching_policy,persist_fallback", list(itertools.product([True, False], repeat=4))
|
||||
|
|
|
|||
Loading…
Reference in a new issue