mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
Enable all-or-nothing fallback by default (#8911)
This commit is contained in:
parent
1a34775fe9
commit
fe7f30aa14
5 changed files with 144 additions and 33 deletions
|
|
@ -24,8 +24,12 @@ MINIMUM_RUNTIME_PYTORCH_VERSION_STR = '1.8.1'
|
|||
TORCH_CPP_DIR = os.path.join(os.path.dirname(__file__),
|
||||
'torch_cpp_extensions')
|
||||
_FALLBACK_INIT_EXCEPTION = None
|
||||
ORTMODULE_FALLBACK_POLICY = _FallbackPolicy.FALLBACK_DISABLE
|
||||
ORTMODULE_FALLBACK_RETRY = True
|
||||
ORTMODULE_FALLBACK_POLICY = _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE |\
|
||||
_FallbackPolicy.FALLBACK_UNSUPPORTED_DATA |\
|
||||
_FallbackPolicy.FALLBACK_UNSUPPORTED_TORCH_MODEL |\
|
||||
_FallbackPolicy.FALLBACK_UNSUPPORTED_ONNX_MODEL |\
|
||||
_FallbackPolicy.FALLBACK_BAD_INITIALIZATION
|
||||
ORTMODULE_FALLBACK_RETRY = False
|
||||
|
||||
# Verify minimum PyTorch version is installed before proceding to ONNX Runtime initialization
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -178,7 +178,7 @@ class _FallbackManager(object):
|
|||
self.policy.is_set(policy) and \
|
||||
(policy.value in self._policy_exception_map and type(exception) in self._policy_exception_map[policy.value]):
|
||||
|
||||
if log_level <= _logger.LogLevel.WARNING:
|
||||
if log_level <= _logger.LogLevel.INFO:
|
||||
warnings.warn(
|
||||
f'Fallback for policy {policy.name} is pending.', UserWarning)
|
||||
self._exception = exception
|
||||
|
|
@ -208,7 +208,8 @@ class _FallbackManager(object):
|
|||
|
||||
if log_level <= _logger.LogLevel.WARNING:
|
||||
warnings.warn(
|
||||
(f'Fallback due to exception {type(self._exception)} was triggered. '
|
||||
(f'Fallback to PyTorch due to exception {type(self._exception)} was triggered. '
|
||||
'Report this issue with a minimal repro at https://www.github.com/microsoft/onnxruntime. '
|
||||
f'See details below:\n\n{get_exception_as_string(self._exception)}'), UserWarning)
|
||||
|
||||
# Pending fallbacks are resetted to enforce retries
|
||||
|
|
|
|||
|
|
@ -23,6 +23,19 @@ except Exception as e:
|
|||
pass
|
||||
raise
|
||||
|
||||
def is_all_or_nothing_fallback_enabled(model, policy=None):
|
||||
from onnxruntime.training.ortmodule import ORTMODULE_FALLBACK_POLICY
|
||||
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
|
||||
|
||||
if not policy:
|
||||
policy = _FallbackPolicy.FALLBACK_DISABLE
|
||||
|
||||
fallback_on_env = policy in ORTMODULE_FALLBACK_POLICY
|
||||
fallback_on_model = False
|
||||
if model:
|
||||
fallback_on_model = policy in model._torch_module._execution_manager(is_training=True)._fallback_manager.policy or\
|
||||
policy in model._torch_module._execution_manager(is_training=False)._fallback_manager.policy
|
||||
return fallback_on_env or fallback_on_model
|
||||
|
||||
def assert_model_outputs(output_a, output_b, verbose=False, rtol=1e-7, atol=0):
|
||||
r"""Asserts whether output_a and output_b difference is within specified tolerance
|
||||
|
|
|
|||
|
|
@ -1543,15 +1543,26 @@ def test_named_tuple_return_value_module(device):
|
|||
def test_exception_raised_for_custom_class_return_value_module(device):
|
||||
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
model = NeuralNetCustomClassOutput(D_in, H, D_out).to(device)
|
||||
model = ORTModule(model)
|
||||
pt_model = NeuralNetCustomClassOutput(D_in, H, D_out).to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
y = torch.randn(N, D_in, device=device)
|
||||
z = torch.randn(N, D_in, device=device)
|
||||
|
||||
with pytest.raises(_fallback.ORTModuleIOError) as runtime_error:
|
||||
model(x, y, z)
|
||||
assert 'ORTModule does not support the following model output type' in str(runtime_error.value)
|
||||
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
|
||||
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DATA):
|
||||
# Fallback
|
||||
pt_out = pt_model(x, y, z)
|
||||
ort_out = pt_model(x, y, z)
|
||||
# Assert that the output from torch is the same as the one from ORTModule
|
||||
_test_helpers.assert_values_are_close(pt_out.out1, ort_out.out1)
|
||||
_test_helpers.assert_values_are_close(pt_out.out2, ort_out.out2)
|
||||
_test_helpers.assert_values_are_close(pt_out.out3, ort_out.out3)
|
||||
else:
|
||||
# ORT backend
|
||||
with pytest.raises(_fallback.ORTModuleIOError) as runtime_error:
|
||||
model(x, y, z)
|
||||
assert 'ORTModule does not support the following model output type' in str(runtime_error.value)
|
||||
|
||||
def test_dynamic_axes_config():
|
||||
device = 'cuda'
|
||||
|
|
@ -1588,9 +1599,19 @@ def test_model_with_multiple_devices_cpu_cuda():
|
|||
|
||||
pt_model = MultipleDeviceModel()
|
||||
x = torch.randn(20, 10)
|
||||
with pytest.raises(_fallback.ORTModuleFallbackException) as e:
|
||||
ort_model = ORTModule(pt_model)
|
||||
assert str(e.value) == 'ORTModule supports a single device per model'
|
||||
|
||||
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
|
||||
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
|
||||
# Fallback
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
with pytest.raises(RuntimeError) as runtime_error:
|
||||
ort_model(x)
|
||||
assert f"Expected all tensors to be on the same device, but found at least two devices" in str(runtime_error.value)
|
||||
else:
|
||||
# ORT backend
|
||||
with pytest.raises(_fallback.ORTModuleFallbackException) as e:
|
||||
ort_model = ORTModule(pt_model)
|
||||
assert str(e.value) == 'ORTModule supports a single device per model'
|
||||
|
||||
def test_model_with_multiple_devices_to_to():
|
||||
class MultipleDeviceModel(torch.nn.Module):
|
||||
|
|
@ -1606,9 +1627,18 @@ def test_model_with_multiple_devices_to_to():
|
|||
|
||||
pt_model = MultipleDeviceModel()
|
||||
x = torch.randn(20, 10)
|
||||
with pytest.raises(_fallback.ORTModuleFallbackException) as e:
|
||||
ort_model = ORTModule(pt_model)
|
||||
assert str(e.value) == 'ORTModule supports a single device per model'
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
|
||||
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
|
||||
# Fallback
|
||||
with pytest.raises(RuntimeError) as runtime_error:
|
||||
ort_model(x)
|
||||
assert f"Expected all tensors to be on the same device, but found at least two devices" in str(runtime_error.value)
|
||||
else:
|
||||
# ORT backend
|
||||
with pytest.raises(_fallback.ORTModuleFallbackException) as e:
|
||||
ort_model = ORTModule(pt_model)
|
||||
assert str(e.value) == 'ORTModule supports a single device per model'
|
||||
|
||||
def test_model_with_multiple_devices_to_cpu():
|
||||
class MultipleDeviceModel(torch.nn.Module):
|
||||
|
|
@ -1624,9 +1654,18 @@ def test_model_with_multiple_devices_to_cpu():
|
|||
|
||||
pt_model = MultipleDeviceModel()
|
||||
x = torch.randn(20, 10)
|
||||
with pytest.raises(_fallback.ORTModuleFallbackException) as e:
|
||||
ort_model = ORTModule(pt_model)
|
||||
assert str(e.value) == 'ORTModule supports a single device per model'
|
||||
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
|
||||
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
|
||||
# Fallback
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
with pytest.raises(RuntimeError) as runtime_error:
|
||||
ort_model(x)
|
||||
assert f"Expected all tensors to be on the same device, but found at least two devices" in str(runtime_error.value)
|
||||
else:
|
||||
# ORT backend
|
||||
with pytest.raises(_fallback.ORTModuleFallbackException) as e:
|
||||
ort_model = ORTModule(pt_model)
|
||||
assert str(e.value) == 'ORTModule supports a single device per model'
|
||||
|
||||
def test_model_with_multiple_devices_to_cuda():
|
||||
class MultipleDeviceModel(torch.nn.Module):
|
||||
|
|
@ -1642,10 +1681,18 @@ def test_model_with_multiple_devices_to_cuda():
|
|||
|
||||
pt_model = MultipleDeviceModel()
|
||||
x = torch.randn(20, 10)
|
||||
with pytest.raises(_fallback.ORTModuleFallbackException) as e:
|
||||
ort_model = ORTModule(pt_model)
|
||||
|
||||
assert str(e.value) == 'ORTModule supports a single device per model'
|
||||
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
|
||||
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
|
||||
# Fallback
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
with pytest.raises(RuntimeError) as runtime_error:
|
||||
ort_model(x)
|
||||
assert f"Expected all tensors to be on the same device, but found at least two devices" in str(runtime_error.value)
|
||||
else:
|
||||
# ORT backend
|
||||
with pytest.raises(_fallback.ORTModuleFallbackException) as e:
|
||||
ort_model = ORTModule(pt_model)
|
||||
assert str(e.value) == 'ORTModule supports a single device per model'
|
||||
|
||||
@pytest.mark.parametrize("device", ['cuda', 'cuda:0', 'cuda:1', 'cuda:2'])
|
||||
def test_model_with_different_cuda_devices(device):
|
||||
|
|
@ -1823,9 +1870,17 @@ def test_forward_data_and_model_on_different_devices(data_device, model_device):
|
|||
|
||||
# Now that the model has been exported, feed in data from device other than the model device
|
||||
x = torch.randn(N, D_in, device=data_device)
|
||||
with pytest.raises(_fallback.ORTModuleDeviceException) as runtime_error:
|
||||
ort_model(x)
|
||||
assert f"Input argument to forward found on device {torch.device(x.device)}, but expected it to be on module device {ort_model._torch_module._execution_manager(ort_model._is_training())._device}." in str(runtime_error.value)
|
||||
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy, ORTModuleDeviceException
|
||||
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE):
|
||||
# Fallback
|
||||
with pytest.raises(RuntimeError) as runtime_error:
|
||||
ort_model(x)
|
||||
assert f"Expected all tensors to be on the same device, but found at least two devices" in str(runtime_error.value)
|
||||
else:
|
||||
# ORT backend
|
||||
with pytest.raises(ORTModuleDeviceException) as runtime_error:
|
||||
ort_model(x)
|
||||
assert f"Input argument to forward found on device {torch.device(x.device)}, but expected it to be on module device {ort_model._torch_module._execution_manager(ort_model._is_training())._device}." in str(runtime_error.value)
|
||||
|
||||
def test_forward_returns_none_type_as_output():
|
||||
class NeuralNetNoneTypeOutput(torch.nn.Module):
|
||||
|
|
@ -2990,11 +3045,21 @@ def test_hf_save_pretrained():
|
|||
|
||||
def test_input_with_string_exception():
|
||||
|
||||
model = MyStrNet()
|
||||
model = ORTModule(model)
|
||||
with pytest.raises(_fallback.ORTModuleIOError) as ex_info:
|
||||
_ = model(torch.randn(1, 2), 'hello')
|
||||
assert "ORTModule does not support the following model data type <class 'str'>" in str(ex_info.value)
|
||||
pt_model = MyStrNet()
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
x = torch.randn(1, 2)
|
||||
|
||||
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
|
||||
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DATA):
|
||||
# Fallback
|
||||
pt_out = pt_model(x, 'hello')
|
||||
ort_out = pt_model(x, 'hello')
|
||||
_test_helpers.assert_values_are_close(pt_out, ort_out)
|
||||
else:
|
||||
# ORT backend
|
||||
with pytest.raises(_fallback.ORTModuleIOError) as ex_info:
|
||||
_ = ort_model(x, 'hello')
|
||||
assert "ORTModule does not support the following model data type <class 'str'>" in str(ex_info.value)
|
||||
|
||||
def test_ortmodule_list_input():
|
||||
class ListNet(torch.nn.Module):
|
||||
|
|
@ -3212,15 +3277,15 @@ def test_ortmodule_gradient_accumulation_optimization_correctness():
|
|||
loss = prediction.sum()
|
||||
loss.backward()
|
||||
return loss.detach()
|
||||
|
||||
|
||||
def run_optim_step(optimizer):
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
||||
GA_steps = 2
|
||||
tgt_model.zero_grad()
|
||||
opt_model.zero_grad()
|
||||
|
||||
|
||||
for step in range(10):
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
tgt_loss = run_step(tgt_model, x)
|
||||
|
|
|
|||
|
|
@ -510,3 +510,31 @@ def test_ortmodule_fallback_onnx_model__missing_op(is_training, fallback_enabled
|
|||
# Initialize with fallback policy because Exception will happen during __init__
|
||||
_ = ort_model(x, y)
|
||||
assert "There was an error while exporting the PyTorch model to ONNX" in str(ex_info.value)
|
||||
|
||||
@pytest.mark.parametrize("is_training,persist_fallback",
|
||||
list(itertools.product([True,False],repeat=2)))
|
||||
def test_ortmodule_fallback_warn_message(is_training, persist_fallback):
|
||||
# is_training: True for torch.nn.Module training model, eval mode otherwise
|
||||
|
||||
policy = 'FALLBACK_UNSUPPORTED_DEVICE'
|
||||
os.environ['ORTMODULE_FALLBACK_POLICY'] = policy
|
||||
os.environ['ORTMODULE_FALLBACK_RETRY'] = str(not persist_fallback)
|
||||
|
||||
data_device = 'cuda'
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
|
||||
pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
pt_model.train(is_training)
|
||||
ort_model.train(is_training)
|
||||
# For initial model export, use same device for data and model so that PyTorch model can be traced during export
|
||||
_ = ort_model(torch.randn(N, D_in))
|
||||
|
||||
# Use data in different device for testing
|
||||
inputs = torch.randn(N, D_in, device=data_device)
|
||||
|
||||
for _ in range(3):
|
||||
with pytest.raises(RuntimeError):
|
||||
with pytest.warns(UserWarning) as warning_record:
|
||||
ort_model(inputs)
|
||||
assert "Fallback to PyTorch due to exception" in str(warning_record[0].message.args[0])
|
||||
|
|
|
|||
Loading…
Reference in a new issue