Enable PythonOp for --enable_training_torch_interop build (#12539)

* enable PythonOp by default when --enable_training_torch_interop is enabled during build

* clean up

* fix

* fix comment

* fix

* fix tests

* fix fallback test

* pylint format

* refine based on comments
This commit is contained in:
pengwa 2022-08-12 00:49:30 +08:00 committed by GitHub
parent b59ccbc75b
commit 24eab921be
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 112 additions and 50 deletions

View file

@ -486,6 +486,13 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
pool.UnRegisterFunctions();
#endif
});
// Python-visible helper: reports whether this binary was compiled with
// ENABLE_TRAINING_TORCH_INTEROP. The answer is fixed by the preprocessor at
// build time: true when the macro is defined, false otherwise.
m.def("is_torch_interop_default_on", []() -> bool {
#ifdef ENABLE_TRAINING_TORCH_INTEROP
return true;
#else
return false;
#endif
});
py::class_<TrainingConfigurationResult> config_result(m, "TrainingConfigurationResult", "pbdoc(Configuration result for training.)pbdoc");
config_result.def(py::init())

View file

@ -8,47 +8,81 @@ class Enabler(object):
def __init__(self):
self._state = False
# This is to indicate whether custom autograd support has been enabled or not (despite of the current state)
# in current process.
self._already_enabled = False
@property
def state(self):
return self._state
@property
def already_enabled(self):
return self._already_enabled
@state.setter
def state(self, val):
self._state = val
if self._already_enabled is False and val is True:
self._already_enabled = True
custom_autograd_function_enabler = Enabler()
# Legacy API to enable the custom autograd, keep its name with default value for compatibility.
# Switches custom autograd.Function (PythonOp) support on or off and records the result in
# custom_autograd_function_enabler.state.
# NOTE(review): this span appears to interleave the pre-change and post-change versions of the
# function (two `def` lines, duplicated imports) with indentation stripped, so the exact nesting
# of the branches below should be confirmed against the repository.
# NOTE(review): the bare `except:` fallbacks below also trap KeyboardInterrupt and SystemExit;
# consider narrowing them to `except Exception:`.
def enable_custom_autograd_support(enable=True):
def enable_custom_autograd_support():
# Initialize static objects needed to run custom autograd.Function's.
from onnxruntime.capi._pybind_state import (
register_forward_runner,
register_backward_runner,
unregister_python_functions,
)
from torch.onnx import register_custom_op_symbolic
from ._custom_autograd_function_exporter import _export
from ._custom_autograd_function_runner import call_python_forward_function, call_python_backward_function
from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_interop_utils
import atexit
register_forward_runner(call_python_forward_function)
register_backward_runner(call_python_backward_function)
from torch.onnx import register_custom_op_symbolic, unregister_custom_op_symbolic
# Unregister all python functions automatically upon normal interpreter termination.
atexit.register(unregister_python_functions)
# Clear all gradient functions, to avoid a deadlock issue.
# Check the called function for more detailed comments.
atexit.register(torch_interop_utils.clear_all_grad_fns)
from onnxruntime.capi._pybind_state import (
register_backward_runner,
register_forward_runner,
unregister_python_functions,
)
from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_interop_utils
try:
# This is for the latest Pytorch nightly after this commit:
# https://github.com/pytorch/pytorch/commit/11bc435622e6b7207bbf37ed1aafe999e1f296ec
register_custom_op_symbolic("prim::PythonOp", _export, 1)
except:
# This applies to Pytorch 1.9 and 1.9.1.
register_custom_op_symbolic("::prim_PythonOp", _export, 1)
from ._custom_autograd_function_exporter import _export
from ._custom_autograd_function_runner import call_python_backward_function, call_python_forward_function
custom_autograd_function_enabler.state = True
if enable is True:
if custom_autograd_function_enabler.already_enabled is False:
# Initialize static objects needed to run custom autograd.Function's.
register_forward_runner(call_python_forward_function)
register_backward_runner(call_python_backward_function)
# Unregister all python functions automatically upon normal interpreter termination.
atexit.register(unregister_python_functions)
# Clear all gradient functions, to avoid a deadlock issue.
# Check the called function for more detailed comments.
atexit.register(torch_interop_utils.clear_all_grad_fns)
try:
# This is for the latest Pytorch nightly after this commit:
# https://github.com/pytorch/pytorch/commit/11bc435622e6b7207bbf37ed1aafe999e1f296ec
register_custom_op_symbolic("prim::PythonOp", _export, 1)
except:
# This applies to Pytorch 1.9 and 1.9.1.
register_custom_op_symbolic("::prim_PythonOp", _export, 1)
custom_autograd_function_enabler.state = True
else:
if custom_autograd_function_enabler.already_enabled is True:
# We don't need to remove the registered runners because they won't be used once the feature is disabled.
# But we do need to unregister the PythonOp custom operator symbolic function.
try:
# This is for the latest Pytorch nightly after this commit:
# https://github.com/pytorch/pytorch/commit/11bc435622e6b7207bbf37ed1aafe999e1f296ec
unregister_custom_op_symbolic("prim::PythonOp", 1)
except:
# This applies to Pytorch 1.9 and 1.9.1.
unregister_custom_op_symbolic("::prim_PythonOp", 1)
custom_autograd_function_enabler.state = False
from onnxruntime.capi._pybind_state import is_torch_interop_default_on
# Enable the custom autograd by default when PythonOp backend support is enabled during build.
enable_custom_autograd_support(is_torch_interop_default_on())

View file

@ -22,7 +22,6 @@ from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
from onnxruntime.training import ortmodule
from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils
from ._custom_autograd_function import custom_autograd_function_enabler
from ._custom_autograd_function_exporter import _post_process_after_export
from ._fallback import (
ORTModuleDeviceException,
@ -147,6 +146,8 @@ class GraphExecutionManager(GraphExecutionInterface):
self._run_symbolic_shape_infer = True
# PyTorch custom Autograd function support
from ._custom_autograd_function import custom_autograd_function_enabler
self._enable_custom_autograd_function = custom_autograd_function_enabler.state
self._input_info = None

View file

@ -1,17 +1,17 @@
import copy
import numpy as np
import os
import torch
import numpy as np
import torch
from numpy.testing import assert_allclose
from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer
from onnxruntime.training import orttrainer
try:
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.ortmodule._custom_autograd_function import enable_custom_autograd_support
from onnxruntime.training.ortmodule._graph_execution_manager_factory import GraphExecutionManagerFactory
from onnxruntime.training.ortmodule._fallback import ORTModuleInitException
from onnxruntime.training.ortmodule._graph_execution_manager_factory import GraphExecutionManagerFactory
except ImportError:
# Some pipelines do not contain ORTModule
pass
@ -231,10 +231,6 @@ def assert_values_are_close(input, other, rtol=1e-05, atol=1e-06):
assert False, err_msg
def enable_custom_autograd_function(module):
    """Test helper that calls ``enable_custom_autograd_support()`` with no arguments.

    ``module`` is accepted for the caller's convenience but is not used by this function.
    """
    enable_custom_autograd_support()
def _run_model_on_device(device, model, input_list, label_input, is_eval_mode=False, run_forward_twice=False):
if is_eval_mode:
model.eval()
@ -287,7 +283,6 @@ def run_with_ort_on_device(device, model, input_list, label_input, is_eval_mode=
with torch.no_grad():
model = copy.deepcopy(model)
model.to(device)
enable_custom_autograd_function(model)
model = ORTModule(model)
return _run_model_on_device(device, model, input_list, label_input, is_eval_mode, run_forward_twice)

View file

@ -1,6 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# pylint: disable=missing-docstring
# pylint: disable=C0103
from distutils.version import LooseVersion
import pytest
@ -22,7 +25,19 @@ def torch_version_lower_than(v):
return LooseVersion(torch.__version__) < LooseVersion(v)
def test_GeLU():
@pytest.fixture(scope="session", autouse=True)
def run_before_test_session(request):
    """Force ``ORTMODULE_FALLBACK_POLICY=FALLBACK_DISABLE`` for the whole test session.

    The variable is set when the session-scoped fixture is first set up, and a
    finalizer restores the environment to what it was before: the previous value
    is put back if there was one, otherwise the variable is removed.
    """
    env_key = "ORTMODULE_FALLBACK_POLICY"
    # Remember any value exported by the caller so teardown restores it instead of dropping it.
    previous_policy = os.environ.get(env_key)

    def insert_disable_fallback_in_env():
        os.environ[env_key] = "FALLBACK_DISABLE"

    def remove_disable_fallback_from_env():
        if previous_policy is None:
            # pop() with a default tolerates a test having already removed the variable.
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous_policy

    insert_disable_fallback_in_env()
    request.addfinalizer(remove_disable_fallback_from_env)
def test_gelu():
@torch.jit.script
def bias_gelu(bias, y):
x = bias + y
@ -741,7 +756,6 @@ def test_InnerModuleCall():
ctx.device = device
ctx.inner = InnerModel(dim, device).to(device)
if use_ort:
enable_custom_autograd_function(ctx.inner)
ctx.inner = ORTModule(ctx.inner)
z = ctx.inner(x)
return z
@ -1076,9 +1090,6 @@ def test_non_differentiable_autograd_function():
print("Ref:")
print(y_ref)
from onnxruntime.training.ortmodule._custom_autograd_function import enable_custom_autograd_support
enable_custom_autograd_support()
m = ORTModule(m)
# Inference mode.

View file

@ -4,23 +4,24 @@
import copy
import itertools
import os
import math
import numpy as np
import torch
import pytest
import os
import warnings
from onnxruntime.training.ortmodule import ORTModule, _fallback, ORTMODULE_TORCH_CPP_DIR
from onnxruntime.training.ortmodule.torch_cpp_extensions import is_installed as is_torch_cpp_extensions_installed
import _test_helpers
import numpy as np
import pytest
import torch
from _orttraining_ortmodule_models import (
NeuralNetSinglePositionalArgument,
NeuralNetCustomClassOutput,
MyCustomClassInputNet,
MyCustomFunctionReluModel,
NeuralNetCustomClassOutput,
NeuralNetSinglePositionalArgument,
)
from onnxruntime.training.ortmodule import ORTMODULE_TORCH_CPP_DIR, ORTModule, _fallback
from onnxruntime.training.ortmodule.torch_cpp_extensions import is_installed as is_torch_cpp_extensions_installed
# PyTorch model definitions for tests
@ -373,6 +374,7 @@ def test_ortmodule_fallback_init__torch_version(is_training, fallback_enabled, m
# Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen
from packaging import version
from onnxruntime.training.ortmodule import MINIMUM_RUNTIME_PYTORCH_VERSION_STR
runtime_pytorch_version = version.parse(torch.__version__.split("+")[0])
@ -486,6 +488,15 @@ def test_ortmodule_fallback_init__missing_cpp_extensions(
def test_ortmodule_fallback_onnx_model__custom_autograd(
is_training, fallback_enabled, matching_policy, persist_fallback
):
from onnxruntime.training.ortmodule._custom_autograd_function import (
custom_autograd_function_enabler,
enable_custom_autograd_support,
)
# Disable the autograd support to test the fallback.
old_state = custom_autograd_function_enabler.state
enable_custom_autograd_support(False)
# is_training: True for torch.nn.Module training model, eval mode otherwise
# fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
# matching_policy: True matches FALLBACK_UNSUPPORTED_ONNX_MODEL policy to ORTModuleDeviceException exception.
@ -536,6 +547,9 @@ def test_ortmodule_fallback_onnx_model__custom_autograd(
_ = ort_model(x.mm(w1)).mm(w2)
assert "There was an error while exporting the PyTorch model to ONNX" in str(ex_info.value)
# Restore the autograd support state.
enable_custom_autograd_support(old_state)
@pytest.mark.parametrize(
"is_training,fallback_enabled,matching_policy,persist_fallback", list(itertools.product([True, False], repeat=4))