Enable all-or-nothing fallback by default (#8911)

This commit is contained in:
Thiago Crepaldi 2021-09-02 13:45:14 -04:00 committed by GitHub
parent 1a34775fe9
commit fe7f30aa14
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 144 additions and 33 deletions

View file

@ -24,8 +24,12 @@ MINIMUM_RUNTIME_PYTORCH_VERSION_STR = '1.8.1'
TORCH_CPP_DIR = os.path.join(os.path.dirname(__file__),
'torch_cpp_extensions')
_FALLBACK_INIT_EXCEPTION = None
ORTMODULE_FALLBACK_POLICY = _FallbackPolicy.FALLBACK_DISABLE
ORTMODULE_FALLBACK_RETRY = True
ORTMODULE_FALLBACK_POLICY = _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE |\
_FallbackPolicy.FALLBACK_UNSUPPORTED_DATA |\
_FallbackPolicy.FALLBACK_UNSUPPORTED_TORCH_MODEL |\
_FallbackPolicy.FALLBACK_UNSUPPORTED_ONNX_MODEL |\
_FallbackPolicy.FALLBACK_BAD_INITIALIZATION
ORTMODULE_FALLBACK_RETRY = False
# Verify minimum PyTorch version is installed before proceding to ONNX Runtime initialization
try:

View file

@ -178,7 +178,7 @@ class _FallbackManager(object):
self.policy.is_set(policy) and \
(policy.value in self._policy_exception_map and type(exception) in self._policy_exception_map[policy.value]):
if log_level <= _logger.LogLevel.WARNING:
if log_level <= _logger.LogLevel.INFO:
warnings.warn(
f'Fallback for policy {policy.name} is pending.', UserWarning)
self._exception = exception
@ -208,7 +208,8 @@ class _FallbackManager(object):
if log_level <= _logger.LogLevel.WARNING:
warnings.warn(
(f'Fallback due to exception {type(self._exception)} was triggered. '
(f'Fallback to PyTorch due to exception {type(self._exception)} was triggered. '
'Report this issue with a minimal repro at https://www.github.com/microsoft/onnxruntime. '
f'See details below:\n\n{get_exception_as_string(self._exception)}'), UserWarning)
# Pending fallbacks are resetted to enforce retries

View file

@ -23,6 +23,19 @@ except Exception as e:
pass
raise
def is_all_or_nothing_fallback_enabled(model, policy=None):
from onnxruntime.training.ortmodule import ORTMODULE_FALLBACK_POLICY
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
if not policy:
policy = _FallbackPolicy.FALLBACK_DISABLE
fallback_on_env = policy in ORTMODULE_FALLBACK_POLICY
fallback_on_model = False
if model:
fallback_on_model = policy in model._torch_module._execution_manager(is_training=True)._fallback_manager.policy or\
policy in model._torch_module._execution_manager(is_training=False)._fallback_manager.policy
return fallback_on_env or fallback_on_model
def assert_model_outputs(output_a, output_b, verbose=False, rtol=1e-7, atol=0):
r"""Asserts whether output_a and output_b difference is within specified tolerance

View file

@ -1543,15 +1543,26 @@ def test_named_tuple_return_value_module(device):
def test_exception_raised_for_custom_class_return_value_module(device):
N, D_in, H, D_out = 64, 784, 500, 10
model = NeuralNetCustomClassOutput(D_in, H, D_out).to(device)
model = ORTModule(model)
pt_model = NeuralNetCustomClassOutput(D_in, H, D_out).to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
x = torch.randn(N, D_in, device=device)
y = torch.randn(N, D_in, device=device)
z = torch.randn(N, D_in, device=device)
with pytest.raises(_fallback.ORTModuleIOError) as runtime_error:
model(x, y, z)
assert 'ORTModule does not support the following model output type' in str(runtime_error.value)
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DATA):
# Fallback
pt_out = pt_model(x, y, z)
ort_out = pt_model(x, y, z)
# Assert that the output from torch is the same as the one from ORTModule
_test_helpers.assert_values_are_close(pt_out.out1, ort_out.out1)
_test_helpers.assert_values_are_close(pt_out.out2, ort_out.out2)
_test_helpers.assert_values_are_close(pt_out.out3, ort_out.out3)
else:
# ORT backend
with pytest.raises(_fallback.ORTModuleIOError) as runtime_error:
model(x, y, z)
assert 'ORTModule does not support the following model output type' in str(runtime_error.value)
def test_dynamic_axes_config():
device = 'cuda'
@ -1588,9 +1599,19 @@ def test_model_with_multiple_devices_cpu_cuda():
pt_model = MultipleDeviceModel()
x = torch.randn(20, 10)
with pytest.raises(_fallback.ORTModuleFallbackException) as e:
ort_model = ORTModule(pt_model)
assert str(e.value) == 'ORTModule supports a single device per model'
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
# Fallback
ort_model = ORTModule(copy.deepcopy(pt_model))
with pytest.raises(RuntimeError) as runtime_error:
ort_model(x)
assert f"Expected all tensors to be on the same device, but found at least two devices" in str(runtime_error.value)
else:
# ORT backend
with pytest.raises(_fallback.ORTModuleFallbackException) as e:
ort_model = ORTModule(pt_model)
assert str(e.value) == 'ORTModule supports a single device per model'
def test_model_with_multiple_devices_to_to():
class MultipleDeviceModel(torch.nn.Module):
@ -1606,9 +1627,18 @@ def test_model_with_multiple_devices_to_to():
pt_model = MultipleDeviceModel()
x = torch.randn(20, 10)
with pytest.raises(_fallback.ORTModuleFallbackException) as e:
ort_model = ORTModule(pt_model)
assert str(e.value) == 'ORTModule supports a single device per model'
ort_model = ORTModule(copy.deepcopy(pt_model))
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
# Fallback
with pytest.raises(RuntimeError) as runtime_error:
ort_model(x)
assert f"Expected all tensors to be on the same device, but found at least two devices" in str(runtime_error.value)
else:
# ORT backend
with pytest.raises(_fallback.ORTModuleFallbackException) as e:
ort_model = ORTModule(pt_model)
assert str(e.value) == 'ORTModule supports a single device per model'
def test_model_with_multiple_devices_to_cpu():
class MultipleDeviceModel(torch.nn.Module):
@ -1624,9 +1654,18 @@ def test_model_with_multiple_devices_to_cpu():
pt_model = MultipleDeviceModel()
x = torch.randn(20, 10)
with pytest.raises(_fallback.ORTModuleFallbackException) as e:
ort_model = ORTModule(pt_model)
assert str(e.value) == 'ORTModule supports a single device per model'
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
# Fallback
ort_model = ORTModule(copy.deepcopy(pt_model))
with pytest.raises(RuntimeError) as runtime_error:
ort_model(x)
assert f"Expected all tensors to be on the same device, but found at least two devices" in str(runtime_error.value)
else:
# ORT backend
with pytest.raises(_fallback.ORTModuleFallbackException) as e:
ort_model = ORTModule(pt_model)
assert str(e.value) == 'ORTModule supports a single device per model'
def test_model_with_multiple_devices_to_cuda():
class MultipleDeviceModel(torch.nn.Module):
@ -1642,10 +1681,18 @@ def test_model_with_multiple_devices_to_cuda():
pt_model = MultipleDeviceModel()
x = torch.randn(20, 10)
with pytest.raises(_fallback.ORTModuleFallbackException) as e:
ort_model = ORTModule(pt_model)
assert str(e.value) == 'ORTModule supports a single device per model'
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
# Fallback
ort_model = ORTModule(copy.deepcopy(pt_model))
with pytest.raises(RuntimeError) as runtime_error:
ort_model(x)
assert f"Expected all tensors to be on the same device, but found at least two devices" in str(runtime_error.value)
else:
# ORT backend
with pytest.raises(_fallback.ORTModuleFallbackException) as e:
ort_model = ORTModule(pt_model)
assert str(e.value) == 'ORTModule supports a single device per model'
@pytest.mark.parametrize("device", ['cuda', 'cuda:0', 'cuda:1', 'cuda:2'])
def test_model_with_different_cuda_devices(device):
@ -1823,9 +1870,17 @@ def test_forward_data_and_model_on_different_devices(data_device, model_device):
# Now that the model has been exported, feed in data from device other than the model device
x = torch.randn(N, D_in, device=data_device)
with pytest.raises(_fallback.ORTModuleDeviceException) as runtime_error:
ort_model(x)
assert f"Input argument to forward found on device {torch.device(x.device)}, but expected it to be on module device {ort_model._torch_module._execution_manager(ort_model._is_training())._device}." in str(runtime_error.value)
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy, ORTModuleDeviceException
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
# Fallback
with pytest.raises(RuntimeError) as runtime_error:
ort_model(x)
assert f"Expected all tensors to be on the same device, but found at least two devices" in str(runtime_error.value)
else:
# ORT backend
with pytest.raises(ORTModuleDeviceException) as runtime_error:
ort_model(x)
assert f"Input argument to forward found on device {torch.device(x.device)}, but expected it to be on module device {ort_model._torch_module._execution_manager(ort_model._is_training())._device}." in str(runtime_error.value)
def test_forward_returns_none_type_as_output():
class NeuralNetNoneTypeOutput(torch.nn.Module):
@ -2990,11 +3045,21 @@ def test_hf_save_pretrained():
def test_input_with_string_exception():
model = MyStrNet()
model = ORTModule(model)
with pytest.raises(_fallback.ORTModuleIOError) as ex_info:
_ = model(torch.randn(1, 2), 'hello')
assert "ORTModule does not support the following model data type <class 'str'>" in str(ex_info.value)
pt_model = MyStrNet()
ort_model = ORTModule(copy.deepcopy(pt_model))
x = torch.randn(1, 2)
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DATA):
# Fallback
pt_out = pt_model(x, 'hello')
ort_out = pt_model(x, 'hello')
_test_helpers.assert_values_are_close(pt_out, ort_out)
else:
# ORT backend
with pytest.raises(_fallback.ORTModuleIOError) as ex_info:
_ = ort_model(x, 'hello')
assert "ORTModule does not support the following model data type <class 'str'>" in str(ex_info.value)
def test_ortmodule_list_input():
class ListNet(torch.nn.Module):
@ -3212,15 +3277,15 @@ def test_ortmodule_gradient_accumulation_optimization_correctness():
loss = prediction.sum()
loss.backward()
return loss.detach()
def run_optim_step(optimizer):
optimizer.step()
optimizer.zero_grad()
GA_steps = 2
tgt_model.zero_grad()
opt_model.zero_grad()
for step in range(10):
x = torch.randn(N, D_in, device=device)
tgt_loss = run_step(tgt_model, x)

View file

@ -510,3 +510,31 @@ def test_ortmodule_fallback_onnx_model__missing_op(is_training, fallback_enabled
# Initialize with fallback policy because Exception will happen during __init__
_ = ort_model(x, y)
assert "There was an error while exporting the PyTorch model to ONNX" in str(ex_info.value)
@pytest.mark.parametrize("is_training,persist_fallback",
list(itertools.product([True,False],repeat=2)))
def test_ortmodule_fallback_warn_message(is_training, persist_fallback):
# is_training: True for torch.nn.Module training model, eval mode otherwise
policy = 'FALLBACK_UNSUPPORTED_DEVICE'
os.environ['ORTMODULE_FALLBACK_POLICY'] = policy
os.environ['ORTMODULE_FALLBACK_RETRY'] = str(not persist_fallback)
data_device = 'cuda'
N, D_in, H, D_out = 64, 784, 500, 10
pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
ort_model = ORTModule(copy.deepcopy(pt_model))
pt_model.train(is_training)
ort_model.train(is_training)
# For initial model export, use same device for data and model so that PyTorch model can be traced during export
_ = ort_model(torch.randn(N, D_in))
# Use data in different device for testing
inputs = torch.randn(N, D_in, device=data_device)
for _ in range(3):
with pytest.raises(RuntimeError):
with pytest.warns(UserWarning) as warning_record:
ort_model(inputs)
assert "Fallback to PyTorch due to exception" in str(warning_record[0].message.args[0])