From fe7f30aa142fff3163b4fc170afa206e1f5fba80 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Thu, 2 Sep 2021 13:45:14 -0400 Subject: [PATCH] Enable all-or-nothing fallback by default (#8911) --- .../python/training/ortmodule/__init__.py | 8 +- .../python/training/ortmodule/_fallback.py | 5 +- .../orttraining/test/python/_test_helpers.py | 13 ++ .../python/orttraining_test_ortmodule_api.py | 123 +++++++++++++----- .../orttraining_test_ortmodule_fallback.py | 28 ++++ 5 files changed, 144 insertions(+), 33 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/__init__.py b/orttraining/orttraining/python/training/ortmodule/__init__.py index ef02286ed3..742c92295a 100644 --- a/orttraining/orttraining/python/training/ortmodule/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/__init__.py @@ -24,8 +24,12 @@ MINIMUM_RUNTIME_PYTORCH_VERSION_STR = '1.8.1' TORCH_CPP_DIR = os.path.join(os.path.dirname(__file__), 'torch_cpp_extensions') _FALLBACK_INIT_EXCEPTION = None -ORTMODULE_FALLBACK_POLICY = _FallbackPolicy.FALLBACK_DISABLE -ORTMODULE_FALLBACK_RETRY = True +ORTMODULE_FALLBACK_POLICY = _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE |\ + _FallbackPolicy.FALLBACK_UNSUPPORTED_DATA |\ + _FallbackPolicy.FALLBACK_UNSUPPORTED_TORCH_MODEL |\ + _FallbackPolicy.FALLBACK_UNSUPPORTED_ONNX_MODEL |\ + _FallbackPolicy.FALLBACK_BAD_INITIALIZATION +ORTMODULE_FALLBACK_RETRY = False # Verify minimum PyTorch version is installed before proceding to ONNX Runtime initialization try: diff --git a/orttraining/orttraining/python/training/ortmodule/_fallback.py b/orttraining/orttraining/python/training/ortmodule/_fallback.py index 9810e19d93..53b0e25009 100644 --- a/orttraining/orttraining/python/training/ortmodule/_fallback.py +++ b/orttraining/orttraining/python/training/ortmodule/_fallback.py @@ -178,7 +178,7 @@ class _FallbackManager(object): self.policy.is_set(policy) and \ (policy.value in self._policy_exception_map and type(exception) in 
self._policy_exception_map[policy.value]): - if log_level <= _logger.LogLevel.WARNING: + if log_level <= _logger.LogLevel.INFO: warnings.warn( f'Fallback for policy {policy.name} is pending.', UserWarning) self._exception = exception @@ -208,7 +208,8 @@ class _FallbackManager(object): if log_level <= _logger.LogLevel.WARNING: warnings.warn( - (f'Fallback due to exception {type(self._exception)} was triggered. ' + (f'Fallback to PyTorch due to exception {type(self._exception)} was triggered. ' + 'Report this issue with a minimal repro at https://www.github.com/microsoft/onnxruntime. ' f'See details below:\n\n{get_exception_as_string(self._exception)}'), UserWarning) # Pending fallbacks are resetted to enforce retries diff --git a/orttraining/orttraining/test/python/_test_helpers.py b/orttraining/orttraining/test/python/_test_helpers.py index cf4f370e66..51050f7b4d 100644 --- a/orttraining/orttraining/test/python/_test_helpers.py +++ b/orttraining/orttraining/test/python/_test_helpers.py @@ -23,6 +23,19 @@ except Exception as e: pass raise +def is_all_or_nothing_fallback_enabled(model, policy=None): + from onnxruntime.training.ortmodule import ORTMODULE_FALLBACK_POLICY + from onnxruntime.training.ortmodule._fallback import _FallbackPolicy + + if not policy: + policy = _FallbackPolicy.FALLBACK_DISABLE + + fallback_on_env = policy in ORTMODULE_FALLBACK_POLICY + fallback_on_model = False + if model: + fallback_on_model = policy in model._torch_module._execution_manager(is_training=True)._fallback_manager.policy or\ + policy in model._torch_module._execution_manager(is_training=False)._fallback_manager.policy + return fallback_on_env or fallback_on_model def assert_model_outputs(output_a, output_b, verbose=False, rtol=1e-7, atol=0): r"""Asserts whether output_a and output_b difference is within specified tolerance diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py 
index 37acc2b8d3..d08c3924b9 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1543,15 +1543,26 @@ def test_named_tuple_return_value_module(device): def test_exception_raised_for_custom_class_return_value_module(device): N, D_in, H, D_out = 64, 784, 500, 10 - model = NeuralNetCustomClassOutput(D_in, H, D_out).to(device) - model = ORTModule(model) + pt_model = NeuralNetCustomClassOutput(D_in, H, D_out).to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) x = torch.randn(N, D_in, device=device) y = torch.randn(N, D_in, device=device) z = torch.randn(N, D_in, device=device) - with pytest.raises(_fallback.ORTModuleIOError) as runtime_error: - model(x, y, z) - assert 'ORTModule does not support the following model output type' in str(runtime_error.value) + from onnxruntime.training.ortmodule._fallback import _FallbackPolicy + if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DATA): + # Fallback + pt_out = pt_model(x, y, z) + ort_out = ort_model(x, y, z) + # Assert that the output from torch is the same as the one from ORTModule + _test_helpers.assert_values_are_close(pt_out.out1, ort_out.out1) + _test_helpers.assert_values_are_close(pt_out.out2, ort_out.out2) + _test_helpers.assert_values_are_close(pt_out.out3, ort_out.out3) + else: + # ORT backend + with pytest.raises(_fallback.ORTModuleIOError) as runtime_error: + ort_model(x, y, z) + assert 'ORTModule does not support the following model output type' in str(runtime_error.value) def test_dynamic_axes_config(): device = 'cuda' @@ -1588,9 +1599,19 @@ def test_model_with_multiple_devices_cpu_cuda(): pt_model = MultipleDeviceModel() x = torch.randn(20, 10) - with pytest.raises(_fallback.ORTModuleFallbackException) as e: - ort_model = ORTModule(pt_model) - assert str(e.value) == 'ORTModule supports a single device per model' + + from
onnxruntime.training.ortmodule._fallback import _FallbackPolicy + if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE): + # Fallback + ort_model = ORTModule(copy.deepcopy(pt_model)) + with pytest.raises(RuntimeError) as runtime_error: + ort_model(x) + assert f"Expected all tensors to be on the same device, but found at least two devices" in str(runtime_error.value) + else: + # ORT backend + with pytest.raises(_fallback.ORTModuleFallbackException) as e: + ort_model = ORTModule(pt_model) + assert str(e.value) == 'ORTModule supports a single device per model' def test_model_with_multiple_devices_to_to(): class MultipleDeviceModel(torch.nn.Module): @@ -1606,9 +1627,18 @@ def test_model_with_multiple_devices_to_to(): pt_model = MultipleDeviceModel() x = torch.randn(20, 10) - with pytest.raises(_fallback.ORTModuleFallbackException) as e: - ort_model = ORTModule(pt_model) - assert str(e.value) == 'ORTModule supports a single device per model' + from onnxruntime.training.ortmodule._fallback import _FallbackPolicy + if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE): + # Fallback + ort_model = ORTModule(copy.deepcopy(pt_model)) + with pytest.raises(RuntimeError) as runtime_error: + ort_model(x) + assert f"Expected all tensors to be on the same device, but found at least two devices" in str(runtime_error.value) + else: + # ORT backend + with pytest.raises(_fallback.ORTModuleFallbackException) as e: + ort_model = ORTModule(pt_model) + assert str(e.value) == 'ORTModule supports a single device per model' def test_model_with_multiple_devices_to_cpu(): class MultipleDeviceModel(torch.nn.Module): @@ -1624,9 +1654,18 @@ def test_model_with_multiple_devices_to_cpu(): pt_model = MultipleDeviceModel() x = torch.randn(20, 10) - with pytest.raises(_fallback.ORTModuleFallbackException) as e: - ort_model = ORTModule(pt_model) - assert str(e.value) == 'ORTModule supports a single device
per model' + from onnxruntime.training.ortmodule._fallback import _FallbackPolicy + if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE): + # Fallback + ort_model = ORTModule(copy.deepcopy(pt_model)) + with pytest.raises(RuntimeError) as runtime_error: + ort_model(x) + assert f"Expected all tensors to be on the same device, but found at least two devices" in str(runtime_error.value) + else: + # ORT backend + with pytest.raises(_fallback.ORTModuleFallbackException) as e: + ort_model = ORTModule(pt_model) + assert str(e.value) == 'ORTModule supports a single device per model' def test_model_with_multiple_devices_to_cuda(): class MultipleDeviceModel(torch.nn.Module): @@ -1642,10 +1681,18 @@ def test_model_with_multiple_devices_to_cuda(): pt_model = MultipleDeviceModel() x = torch.randn(20, 10) - with pytest.raises(_fallback.ORTModuleFallbackException) as e: - ort_model = ORTModule(pt_model) - - assert str(e.value) == 'ORTModule supports a single device per model' + from onnxruntime.training.ortmodule._fallback import _FallbackPolicy + if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE): + # Fallback + ort_model = ORTModule(copy.deepcopy(pt_model)) + with pytest.raises(RuntimeError) as runtime_error: + ort_model(x) + assert f"Expected all tensors to be on the same device, but found at least two devices" in str(runtime_error.value) + else: + # ORT backend + with pytest.raises(_fallback.ORTModuleFallbackException) as e: + ort_model = ORTModule(pt_model) + assert str(e.value) == 'ORTModule supports a single device per model' @pytest.mark.parametrize("device", ['cuda', 'cuda:0', 'cuda:1', 'cuda:2']) def test_model_with_different_cuda_devices(device): @@ -1823,9 +1870,17 @@ def test_forward_data_and_model_on_different_devices(data_device, model_device): # Now that the model has been exported, feed in data from device other than the model device x = torch.randn(N, D_in, 
device=data_device) - with pytest.raises(_fallback.ORTModuleDeviceException) as runtime_error: - ort_model(x) - assert f"Input argument to forward found on device {torch.device(x.device)}, but expected it to be on module device {ort_model._torch_module._execution_manager(ort_model._is_training())._device}." in str(runtime_error.value) + from onnxruntime.training.ortmodule._fallback import _FallbackPolicy, ORTModuleDeviceException + if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE): + # Fallback + with pytest.raises(RuntimeError) as runtime_error: + ort_model(x) + assert f"Expected all tensors to be on the same device, but found at least two devices" in str(runtime_error.value) + else: + # ORT backend + with pytest.raises(ORTModuleDeviceException) as runtime_error: + ort_model(x) + assert f"Input argument to forward found on device {torch.device(x.device)}, but expected it to be on module device {ort_model._torch_module._execution_manager(ort_model._is_training())._device}." 
in str(runtime_error.value) def test_forward_returns_none_type_as_output(): class NeuralNetNoneTypeOutput(torch.nn.Module): @@ -2990,11 +3045,21 @@ def test_hf_save_pretrained(): def test_input_with_string_exception(): - model = MyStrNet() - model = ORTModule(model) - with pytest.raises(_fallback.ORTModuleIOError) as ex_info: - _ = model(torch.randn(1, 2), 'hello') - assert "ORTModule does not support the following model data type " in str(ex_info.value) + pt_model = MyStrNet() + ort_model = ORTModule(copy.deepcopy(pt_model)) + x = torch.randn(1, 2) + + from onnxruntime.training.ortmodule._fallback import _FallbackPolicy + if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DATA): + # Fallback + pt_out = pt_model(x, 'hello') + ort_out = ort_model(x, 'hello') + _test_helpers.assert_values_are_close(pt_out, ort_out) + else: + # ORT backend + with pytest.raises(_fallback.ORTModuleIOError) as ex_info: + _ = ort_model(x, 'hello') + assert "ORTModule does not support the following model data type " in str(ex_info.value) def test_ortmodule_list_input(): class ListNet(torch.nn.Module): @@ -3212,15 +3277,15 @@ def test_ortmodule_gradient_accumulation_optimization_correctness(): loss = prediction.sum() loss.backward() return loss.detach() - + def run_optim_step(optimizer): optimizer.step() optimizer.zero_grad() - + GA_steps = 2 tgt_model.zero_grad() opt_model.zero_grad() - + for step in range(10): x = torch.randn(N, D_in, device=device) tgt_loss = run_step(tgt_model, x) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py index db55762117..ee6ce75ccc 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py @@ -510,3 +510,31 @@ def test_ortmodule_fallback_onnx_model__missing_op(is_training, fallback_enabled #
Initialize with fallback policy because Exception will happen during __init__ _ = ort_model(x, y) assert "There was an error while exporting the PyTorch model to ONNX" in str(ex_info.value) + +@pytest.mark.parametrize("is_training,persist_fallback", + list(itertools.product([True,False],repeat=2))) +def test_ortmodule_fallback_warn_message(is_training, persist_fallback): + # is_training: True for torch.nn.Module training model, eval mode otherwise + + policy = 'FALLBACK_UNSUPPORTED_DEVICE' + os.environ['ORTMODULE_FALLBACK_POLICY'] = policy + os.environ['ORTMODULE_FALLBACK_RETRY'] = str(not persist_fallback) + + data_device = 'cuda' + N, D_in, H, D_out = 64, 784, 500, 10 + + pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out) + ort_model = ORTModule(copy.deepcopy(pt_model)) + pt_model.train(is_training) + ort_model.train(is_training) + # For initial model export, use same device for data and model so that PyTorch model can be traced during export + _ = ort_model(torch.randn(N, D_in)) + + # Use data in different device for testing + inputs = torch.randn(N, D_in, device=data_device) + + for _ in range(3): + with pytest.raises(RuntimeError): + with pytest.warns(UserWarning) as warning_record: + ort_model(inputs) + assert "Fallback to PyTorch due to exception" in str(warning_record[0].message.args[0])