diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 8974edf172..ae6111d48d 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -486,6 +486,13 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn pool.UnRegisterFunctions(); #endif }); + m.def("is_torch_interop_default_on", []() -> bool { +#ifdef ENABLE_TRAINING_TORCH_INTEROP + return true; +#else + return false; +#endif + }); py::class_ config_result(m, "TrainingConfigurationResult", "pbdoc(Configuration result for training.)pbdoc"); config_result.def(py::init()) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py index 6f3c7c7d29..19d08d68a6 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py @@ -8,47 +8,81 @@ class Enabler(object): def __init__(self): self._state = False + # This is to indicate whether custom autograd support has been enabled or not (despite of the current state) + # in current process. + self._already_enabled = False + @property def state(self): return self._state + @property + def already_enabled(self): + return self._already_enabled + @state.setter def state(self, val): self._state = val + if self._already_enabled is False and val is True: + self._already_enabled = True custom_autograd_function_enabler = Enabler() +# Legacy API to enable the custom autograd, keep its name with default value for compatibility. +def enable_custom_autograd_support(enable=True): -def enable_custom_autograd_support(): - # Initialize static objects needed to run custom autograd.Function's. - - from onnxruntime.capi._pybind_state import ( - register_forward_runner, - register_backward_runner, - unregister_python_functions, - ) - from torch.onnx import register_custom_op_symbolic - from ._custom_autograd_function_exporter import _export - from ._custom_autograd_function_runner import call_python_forward_function, call_python_backward_function - from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_interop_utils import atexit - register_forward_runner(call_python_forward_function) - register_backward_runner(call_python_backward_function) + from torch.onnx import register_custom_op_symbolic, unregister_custom_op_symbolic - # Unregister all python functions automatically upon normal interpreter termination. - atexit.register(unregister_python_functions) - # Clear all gradient functions, to avoid a deadlock issue. - # Check the called function for more detailed comments. - atexit.register(torch_interop_utils.clear_all_grad_fns) + from onnxruntime.capi._pybind_state import ( + register_backward_runner, + register_forward_runner, + unregister_python_functions, + ) + from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_interop_utils - try: - # This is for the latest Pytorch nightly after this commit: - # https://github.com/pytorch/pytorch/commit/11bc435622e6b7207bbf37ed1aafe999e1f296ec - register_custom_op_symbolic("prim::PythonOp", _export, 1) - except: - # This applies to Pytorch 1.9 and 1.9.1. - register_custom_op_symbolic("::prim_PythonOp", _export, 1) + from ._custom_autograd_function_exporter import _export + from ._custom_autograd_function_runner import call_python_backward_function, call_python_forward_function - custom_autograd_function_enabler.state = True + if enable is True: + if custom_autograd_function_enabler.already_enabled is False: + # Initialize static objects needed to run custom autograd.Function's. + register_forward_runner(call_python_forward_function) + register_backward_runner(call_python_backward_function) + + # Unregister all python functions automatically upon normal interpreter termination. + atexit.register(unregister_python_functions) + # Clear all gradient functions, to avoid a deadlock issue. + # Check the called function for more detailed comments. + atexit.register(torch_interop_utils.clear_all_grad_fns) + + try: + # This is for the latest Pytorch nightly after this commit: + # https://github.com/pytorch/pytorch/commit/11bc435622e6b7207bbf37ed1aafe999e1f296ec + register_custom_op_symbolic("prim::PythonOp", _export, 1) + except: + # This applies to Pytorch 1.9 and 1.9.1. + register_custom_op_symbolic("::prim_PythonOp", _export, 1) + + custom_autograd_function_enabler.state = True + else: + if custom_autograd_function_enabler.already_enabled is True: + # We don't need remove the registered runner because it won't be used if we disable the feature. + # But we need unregister the PythonOp custom operator function. + try: + # This is for the latest Pytorch nightly after this commit: + # https://github.com/pytorch/pytorch/commit/11bc435622e6b7207bbf37ed1aafe999e1f296ec + unregister_custom_op_symbolic("prim::PythonOp", 1) + except: + # This applies to Pytorch 1.9 and 1.9.1. + unregister_custom_op_symbolic("::prim_PythonOp", 1) + + custom_autograd_function_enabler.state = False + + +from onnxruntime.capi._pybind_state import is_torch_interop_default_on + +# Enable the custom autograd by default when PythonOp backend support is enabled during build. +enable_custom_autograd_support(is_torch_interop_default_on()) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 130ef6691b..0e194caf87 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -22,7 +22,6 @@ from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference from onnxruntime.training import ortmodule from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils -from ._custom_autograd_function import custom_autograd_function_enabler from ._custom_autograd_function_exporter import _post_process_after_export from ._fallback import ( ORTModuleDeviceException, @@ -147,6 +146,8 @@ class GraphExecutionManager(GraphExecutionInterface): self._run_symbolic_shape_infer = True # PyTorch custom Autograd function support + from ._custom_autograd_function import custom_autograd_function_enabler + self._enable_custom_autograd_function = custom_autograd_function_enabler.state self._input_info = None diff --git a/orttraining/orttraining/test/python/_test_helpers.py b/orttraining/orttraining/test/python/_test_helpers.py index e888aade9f..95c3b58521 100644 --- a/orttraining/orttraining/test/python/_test_helpers.py +++ b/orttraining/orttraining/test/python/_test_helpers.py @@ -1,17 +1,17 @@ import copy -import numpy as np import os -import torch +import numpy as np +import torch from numpy.testing import assert_allclose + from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer from onnxruntime.training import orttrainer try: from onnxruntime.training.ortmodule import ORTModule - from onnxruntime.training.ortmodule._custom_autograd_function import enable_custom_autograd_support - from onnxruntime.training.ortmodule._graph_execution_manager_factory import GraphExecutionManagerFactory from onnxruntime.training.ortmodule._fallback import ORTModuleInitException + from onnxruntime.training.ortmodule._graph_execution_manager_factory import GraphExecutionManagerFactory except ImportError: # Some pipelines do not contain ORTModule pass @@ -231,10 +231,6 @@ def assert_values_are_close(input, other, rtol=1e-05, atol=1e-06): assert False, err_msg -def enable_custom_autograd_function(module): - enable_custom_autograd_support() - - def _run_model_on_device(device, model, input_list, label_input, is_eval_mode=False, run_forward_twice=False): if is_eval_mode: model.eval() @@ -287,7 +283,6 @@ def run_with_ort_on_device(device, model, input_list, label_input, is_eval_mode= with torch.no_grad(): model = copy.deepcopy(model) model.to(device) - enable_custom_autograd_function(model) model = ORTModule(model) return _run_model_on_device(device, model, input_list, label_input, is_eval_mode, run_forward_twice) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py index d0de69a431..99d7f99f93 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py @@ -1,6 +1,9 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +# pylint: disable=missing-docstring +# pylint: disable=C0103 + from distutils.version import LooseVersion import pytest @@ -22,7 +25,19 @@ def torch_version_lower_than(v): return LooseVersion(torch.__version__) < LooseVersion(v) -def test_GeLU(): +@pytest.fixture(scope="session", autouse=True) +def run_before_test_session(request): + def insert_disable_fallback_in_env(): + os.environ["ORTMODULE_FALLBACK_POLICY"] = "FALLBACK_DISABLE" + + def remove_disable_fallback_from_env(): + del os.environ["ORTMODULE_FALLBACK_POLICY"] + + insert_disable_fallback_in_env() + request.addfinalizer(remove_disable_fallback_from_env) + + +def test_gelu(): @torch.jit.script def bias_gelu(bias, y): x = bias + y @@ -741,7 +756,6 @@ def test_InnerModuleCall(): ctx.device = device ctx.inner = InnerModel(dim, device).to(device) if use_ort: - enable_custom_autograd_function(ctx.inner) ctx.inner = ORTModule(ctx.inner) z = ctx.inner(x) return z @@ -1076,9 +1090,6 @@ def test_non_differentiable_autograd_function(): print("Ref:") print(y_ref) - from onnxruntime.training.ortmodule._custom_autograd_function import enable_custom_autograd_support - - enable_custom_autograd_support() m = ORTModule(m) # Inferene mode. diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py index cc3c2e47a4..1ee2cec568 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py @@ -4,23 +4,24 @@ import copy import itertools -import os import math -import numpy as np -import torch -import pytest +import os import warnings -from onnxruntime.training.ortmodule import ORTModule, _fallback, ORTMODULE_TORCH_CPP_DIR -from onnxruntime.training.ortmodule.torch_cpp_extensions import is_installed as is_torch_cpp_extensions_installed import _test_helpers +import numpy as np +import pytest +import torch from _orttraining_ortmodule_models import ( - NeuralNetSinglePositionalArgument, - NeuralNetCustomClassOutput, MyCustomClassInputNet, MyCustomFunctionReluModel, + NeuralNetCustomClassOutput, + NeuralNetSinglePositionalArgument, ) +from onnxruntime.training.ortmodule import ORTMODULE_TORCH_CPP_DIR, ORTModule, _fallback +from onnxruntime.training.ortmodule.torch_cpp_extensions import is_installed as is_torch_cpp_extensions_installed + # PyTorch model definitions for tests @@ -373,6 +374,7 @@ def test_ortmodule_fallback_init__torch_version(is_training, fallback_enabled, m # Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen from packaging import version + from onnxruntime.training.ortmodule import MINIMUM_RUNTIME_PYTORCH_VERSION_STR runtime_pytorch_version = version.parse(torch.__version__.split("+")[0]) @@ -486,6 +488,15 @@ def test_ortmodule_fallback_init__missing_cpp_extensions( def test_ortmodule_fallback_onnx_model__custom_autograd( is_training, fallback_enabled, matching_policy, persist_fallback ): + from onnxruntime.training.ortmodule._custom_autograd_function import ( + custom_autograd_function_enabler, + enable_custom_autograd_support, + ) + + # Disable the autograd support to test the fallback. + old_state = custom_autograd_function_enabler.state + enable_custom_autograd_support(False) + # is_training: True for torch.nn.Module training model, eval mode otherwise # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend # matching_policy: True matches FALLBACK_UNSUPPORTED_ONNX_MODEL policy to ORTModuleDeviceException exception. @@ -536,6 +547,9 @@ def test_ortmodule_fallback_onnx_model__custom_autograd( _ = ort_model(x.mm(w1)).mm(w2) assert "There was an error while exporting the PyTorch model to ONNX" in str(ex_info.value) + # Restore the autograd support state. + enable_custom_autograd_support(old_state) + @pytest.mark.parametrize( "is_training,fallback_enabled,matching_policy,persist_fallback", list(itertools.product([True, False], repeat=4))