From 20eaed43e513633ce46d9e7d11db3db4d0d77da4 Mon Sep 17 00:00:00 2001 From: baijumeswani Date: Tue, 19 Oct 2021 16:34:47 -0700 Subject: [PATCH] Ignore all string inputs to ORTModule AB#1310803 (#9344) --- .../python/training/ortmodule/_io.py | 20 +++++++++++----- .../python/training/ortmodule/_utils.py | 4 ++++ .../python/_orttraining_ortmodule_models.py | 8 +++---- .../python/orttraining_test_ortmodule_api.py | 21 +++++++--------- .../orttraining_test_ortmodule_fallback.py | 24 ++++++++++++------- 5 files changed, 46 insertions(+), 31 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index dbf59bc654..32890a32e1 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -11,7 +11,7 @@ import warnings import gc from ._fallback import _FallbackManager, ORTModuleIOError, ORTModuleONNXModelException, wrap_exception - +from ._utils import warn_of_constant_inputs class _OutputIdentityOp(torch.autograd.Function): '''Internal class used to prepend Identity ops in model's outputs @@ -141,6 +141,9 @@ def _combine_input_buffers_initializers(params, onnx_input_names, input_info, bu # The exporter handles input lists by expanding them so that each # element of the list is its own input. # ORTModule must match this behavior by also expanding the inputs. + if current_input is None or isinstance(current_input, str): + # Drop all None and string inputs + return if isinstance(current_input, abc.Sequence): # If the input is a sequence (like a list), expand the list so that # each element of the list is an input by itself @@ -151,7 +154,7 @@ def _combine_input_buffers_initializers(params, onnx_input_names, input_info, bu # each element of the dict is an input by itself for _, val in current_input.items(): _expand_inputs(val, non_none_inputs) - elif current_input is not None: + else: # else just collect all the non none inputs within non_none_inputs non_none_inputs.append(current_input) @@ -314,8 +317,13 @@ def _extract_schema(data): """Extract the data schema by replacing every torch.Tensor value with _TensorStub""" if data is None: - return None + return data + elif isinstance(data, str): + warn_of_constant_inputs(data) + return data elif _PrimitiveType.is_primitive_type(data): + if isinstance(data, bool): + warn_of_constant_inputs(data) return _TensorStub(dtype=_PrimitiveType.get_primitive_dtype(data), shape_dims=0) # Depth first traversal to iterate over the data to replace every tensor with a stub elif isinstance(data, torch.Tensor): @@ -324,7 +332,7 @@ def _extract_schema(data): # Instead of replacing the tensor with a stub in the original user input, build the stubbed_schema # from scratch from the user input. stubbed_schema = None - if isinstance(data, abc.Sequence) and not isinstance(data, str): + if isinstance(data, abc.Sequence): sequence_type = type(data) stubbed_schema = [_extract_schema(val) for val in data] try: @@ -431,8 +439,8 @@ def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, schema, input def _add_input(name, input, onnx_graph, onnx_graph_input_names): """Returns number of expanded non none inputs that _add_input processed""" - if input is None: - # Drop all None inputs and return 0. + if input is None or isinstance(input, str): + # Drop all None and string inputs and return 0. return 0 num_expanded_non_none_inputs = 0 diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index 87aca7ce86..3775c17634 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -226,3 +226,7 @@ def switch_backend_to_pytorch(ortmodule, pytorch_module): ortmodule._load_state_dict_pre_hooks = pytorch_module._load_state_dict_pre_hooks ortmodule._modules = pytorch_module._modules ortmodule.forward = pytorch_module.forward + +def warn_of_constant_inputs(data): + warnings.warn(f"Received input of type {type(data)} which may be treated as a constant by ORT by default." + " Please consider moving constant arguments to the model constructor.") diff --git a/orttraining/orttraining/test/python/_orttraining_ortmodule_models.py b/orttraining/orttraining/test/python/_orttraining_ortmodule_models.py index cc2355ac3f..0e3607ab0e 100644 --- a/orttraining/orttraining/test/python/_orttraining_ortmodule_models.py +++ b/orttraining/orttraining/test/python/_orttraining_ortmodule_models.py @@ -5,10 +5,10 @@ import torch -class MyStrNet(torch.nn.Module): - def forward(self, x, my_str): - if my_str.lower() == 'hello': - print('hi') +class MyCustomClassInputNet(torch.nn.Module): + def forward(self, x, custom_class_obj): + if custom_class_obj.x == 1: + return x+1 return x diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 8b17df4a07..8d644fa8bb 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -318,7 +318,7 @@ class NeuralNetCustomClassOutput(torch.nn.Module): class MyStrNet(torch.nn.Module): def forward(self, x, my_str): if my_str.lower() == 'hello': - print('hi') + return x+1 return x @pytest.fixture(scope='session', autouse=True) @@ -3321,23 +3321,18 @@ def test_hf_save_pretrained(): for p1, p2 in zip(model1.parameters(), model2.parameters()): assert p1.data.ne(p2.data).sum() == 0 -def test_input_with_string_exception(): +def test_ortmodule_string_inputs_are_ignored(): pt_model = MyStrNet() ort_model = ORTModule(copy.deepcopy(pt_model)) x = torch.randn(1, 2) - from onnxruntime.training.ortmodule._fallback import _FallbackPolicy - if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DATA): - # Fallback - pt_out = pt_model(x, 'hello') - ort_out = pt_model(x, 'hello') - _test_helpers.assert_values_are_close(pt_out, ort_out) - else: - # ORT backend - with pytest.raises(_fallback.ORTModuleIOError) as ex_info: - _ = ort_model(x, 'hello') - assert "ORTModule does not support the following model data type " in str(ex_info.value) + with pytest.warns(UserWarning) as warning_record: + out = ort_model(x, 'hello') + + assert len(warning_record) == 2 + assert "Received input of type which may be treated as a constant by ORT by default." in warning_record[1].message.args[0] + _test_helpers.assert_values_are_close(out, x+1) def test_ortmodule_list_input(): class ListNet(torch.nn.Module): diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py index 250014846d..2183d2123b 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py @@ -14,7 +14,7 @@ from onnxruntime.training.ortmodule.torch_cpp_extensions import is_installed as import _test_helpers from _orttraining_ortmodule_models import (NeuralNetSinglePositionalArgument, NeuralNetCustomClassOutput, - MyStrNet, + MyCustomClassInputNet, MyCustomFunctionReluModel) # PyTorch model definitions for tests @@ -254,10 +254,14 @@ def test_ortmodule_fallback_input(is_training, fallback_enabled, matching_policy os.environ['ORTMODULE_FALLBACK_POLICY'] = policy os.environ['ORTMODULE_FALLBACK_RETRY'] = str(not persist_fallback) - pt_model = MyStrNet() + pt_model = MyCustomClassInputNet() ort_model = ORTModule(copy.deepcopy(pt_model)) inputs = torch.randn(1, 2) + class CustomClass: + def __init__(self, x): + self.x = x + ort_model.train(is_training) pt_model.train(is_training) @@ -267,17 +271,21 @@ def test_ortmodule_fallback_input(is_training, fallback_enabled, matching_policy if i > 0 and persist_fallback: assert ort_model._torch_module._execution_manager( is_training=is_training)._fallback_manager._exception is not None - ort_out = ort_model(inputs, 'hello') - pt_out = pt_model(inputs, 'hello') + ort_out = ort_model(inputs, CustomClass(1)) + pt_out = pt_model(inputs, CustomClass(1)) _test_helpers.assert_values_are_close(ort_out, pt_out, rtol=0, atol=0) else: with pytest.raises(_fallback.ORTModuleIOError) as ex_info: - _ = ort_model(torch.randn(1, 2), 'hello') - assert "ORTModule does not support the following model data type " in str(ex_info.value) + _ = ort_model(torch.randn(1, 2), CustomClass(1)) + assert "ORTModule does not support the following model data"\ + " type .CustomClass'>" in str(ex_info.value) else: with pytest.raises(_fallback.ORTModuleIOError) as ex_info: - _ = ort_model(torch.randn(1, 2), 'hello') - assert "ORTModule does not support the following model data type " in str(ex_info.value) + _ = ort_model(torch.randn(1, 2), CustomClass(1)) + assert "ORTModule does not support the following model data"\ + " type .CustomClass'>" in str(ex_info.value) @pytest.mark.parametrize("is_training,fallback_enabled,matching_policy,persist_fallback",