Ignore all string inputs to ORTModule AB#1310803 (#9344)

This commit is contained in:
baijumeswani 2021-10-19 16:34:47 -07:00 committed by GitHub
parent 4698b73725
commit 20eaed43e5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 46 additions and 31 deletions

View file

@ -11,7 +11,7 @@ import warnings
import gc
from ._fallback import _FallbackManager, ORTModuleIOError, ORTModuleONNXModelException, wrap_exception
from ._utils import warn_of_constant_inputs
class _OutputIdentityOp(torch.autograd.Function):
'''Internal class used to prepend Identity ops in model's outputs
@ -141,6 +141,9 @@ def _combine_input_buffers_initializers(params, onnx_input_names, input_info, bu
# The exporter handles input lists by expanding them so that each
# element of the list is its own input.
# ORTModule must match this behavior by also expanding the inputs.
if current_input is None or isinstance(current_input, str):
# Drop all None and string inputs
return
if isinstance(current_input, abc.Sequence):
# If the input is a sequence (like a list), expand the list so that
# each element of the list is an input by itself
@ -151,7 +154,7 @@ def _combine_input_buffers_initializers(params, onnx_input_names, input_info, bu
# each element of the dict is an input by itself
for _, val in current_input.items():
_expand_inputs(val, non_none_inputs)
elif current_input is not None:
else:
# else just collect all the non none inputs within non_none_inputs
non_none_inputs.append(current_input)
@ -314,8 +317,13 @@ def _extract_schema(data):
"""Extract the data schema by replacing every torch.Tensor value with _TensorStub"""
if data is None:
return None
return data
elif isinstance(data, str):
warn_of_constant_inputs(data)
return data
elif _PrimitiveType.is_primitive_type(data):
if isinstance(data, bool):
warn_of_constant_inputs(data)
return _TensorStub(dtype=_PrimitiveType.get_primitive_dtype(data), shape_dims=0)
# Depth first traversal to iterate over the data to replace every tensor with a stub
elif isinstance(data, torch.Tensor):
@ -324,7 +332,7 @@ def _extract_schema(data):
# Instead of replacing the tensor with a stub in the original user input, build the stubbed_schema
# from scratch from the user input.
stubbed_schema = None
if isinstance(data, abc.Sequence) and not isinstance(data, str):
if isinstance(data, abc.Sequence):
sequence_type = type(data)
stubbed_schema = [_extract_schema(val) for val in data]
try:
@ -431,8 +439,8 @@ def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, schema, input
def _add_input(name, input, onnx_graph, onnx_graph_input_names):
"""Returns number of expanded non none inputs that _add_input processed"""
if input is None:
# Drop all None inputs and return 0.
if input is None or isinstance(input, str):
# Drop all None and string inputs and return 0.
return 0
num_expanded_non_none_inputs = 0

View file

@ -226,3 +226,7 @@ def switch_backend_to_pytorch(ortmodule, pytorch_module):
ortmodule._load_state_dict_pre_hooks = pytorch_module._load_state_dict_pre_hooks
ortmodule._modules = pytorch_module._modules
ortmodule.forward = pytorch_module.forward
def warn_of_constant_inputs(data):
    """Warn that an input of ``data``'s type may be treated as a constant by ORT.

    Args:
        data: The input value; only its type is reported in the warning message.
    """
    warnings.warn(
        f"Received input of type {type(data)} which may be treated as a constant by ORT by default."
        " Please consider moving constant arguments to the model constructor.",
        UserWarning,
    )

View file

@ -5,10 +5,10 @@
import torch
class MyStrNet(torch.nn.Module):
def forward(self, x, my_str):
if my_str.lower() == 'hello':
print('hi')
class MyCustomClassInputNet(torch.nn.Module):
    """Test model whose ``forward`` takes a tensor plus an arbitrary object with an ``x`` attribute."""

    def forward(self, x, custom_class_obj):
        """Return ``x + 1`` when ``custom_class_obj.x == 1``; otherwise return ``x`` unchanged."""
        if custom_class_obj.x == 1:
            return x+1
        return x

View file

@ -318,7 +318,7 @@ class NeuralNetCustomClassOutput(torch.nn.Module):
class MyStrNet(torch.nn.Module):
def forward(self, x, my_str):
if my_str.lower() == 'hello':
print('hi')
return x+1
return x
@pytest.fixture(scope='session', autouse=True)
@ -3321,23 +3321,18 @@ def test_hf_save_pretrained():
for p1, p2 in zip(model1.parameters(), model2.parameters()):
assert p1.data.ne(p2.data).sum() == 0
def test_input_with_string_exception():
def test_ortmodule_string_inputs_are_ignored():
pt_model = MyStrNet()
ort_model = ORTModule(copy.deepcopy(pt_model))
x = torch.randn(1, 2)
from onnxruntime.training.ortmodule._fallback import _FallbackPolicy
if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DATA):
# Fallback
pt_out = pt_model(x, 'hello')
ort_out = pt_model(x, 'hello')
_test_helpers.assert_values_are_close(pt_out, ort_out)
else:
# ORT backend
with pytest.raises(_fallback.ORTModuleIOError) as ex_info:
_ = ort_model(x, 'hello')
assert "ORTModule does not support the following model data type <class 'str'>" in str(ex_info.value)
with pytest.warns(UserWarning) as warning_record:
out = ort_model(x, 'hello')
assert len(warning_record) == 2
assert "Received input of type <class 'str'> which may be treated as a constant by ORT by default." in warning_record[1].message.args[0]
_test_helpers.assert_values_are_close(out, x+1)
def test_ortmodule_list_input():
class ListNet(torch.nn.Module):

View file

@ -14,7 +14,7 @@ from onnxruntime.training.ortmodule.torch_cpp_extensions import is_installed as
import _test_helpers
from _orttraining_ortmodule_models import (NeuralNetSinglePositionalArgument,
NeuralNetCustomClassOutput,
MyStrNet,
MyCustomClassInputNet,
MyCustomFunctionReluModel)
# PyTorch model definitions for tests
@ -254,10 +254,14 @@ def test_ortmodule_fallback_input(is_training, fallback_enabled, matching_policy
os.environ['ORTMODULE_FALLBACK_POLICY'] = policy
os.environ['ORTMODULE_FALLBACK_RETRY'] = str(not persist_fallback)
pt_model = MyStrNet()
pt_model = MyCustomClassInputNet()
ort_model = ORTModule(copy.deepcopy(pt_model))
inputs = torch.randn(1, 2)
class CustomClass:
    """Minimal user-defined type that stores the constructor argument as ``self.x``."""

    def __init__(self, x):
        self.x = x
ort_model.train(is_training)
pt_model.train(is_training)
@ -267,17 +271,21 @@ def test_ortmodule_fallback_input(is_training, fallback_enabled, matching_policy
if i > 0 and persist_fallback:
assert ort_model._torch_module._execution_manager(
is_training=is_training)._fallback_manager._exception is not None
ort_out = ort_model(inputs, 'hello')
pt_out = pt_model(inputs, 'hello')
ort_out = ort_model(inputs, CustomClass(1))
pt_out = pt_model(inputs, CustomClass(1))
_test_helpers.assert_values_are_close(ort_out, pt_out, rtol=0, atol=0)
else:
with pytest.raises(_fallback.ORTModuleIOError) as ex_info:
_ = ort_model(torch.randn(1, 2), 'hello')
assert "ORTModule does not support the following model data type <class 'str'>" in str(ex_info.value)
_ = ort_model(torch.randn(1, 2), CustomClass(1))
assert "ORTModule does not support the following model data"\
" type <class 'orttraining_test_ortmodule_fallback."\
"test_ortmodule_fallback_input.<locals>.CustomClass'>" in str(ex_info.value)
else:
with pytest.raises(_fallback.ORTModuleIOError) as ex_info:
_ = ort_model(torch.randn(1, 2), 'hello')
assert "ORTModule does not support the following model data type <class 'str'>" in str(ex_info.value)
_ = ort_model(torch.randn(1, 2), CustomClass(1))
assert "ORTModule does not support the following model data"\
" type <class 'orttraining_test_ortmodule_fallback."\
"test_ortmodule_fallback_input.<locals>.CustomClass'>" in str(ex_info.value)
@pytest.mark.parametrize("is_training,fallback_enabled,matching_policy,persist_fallback",